
from cryptography.fernet import Fernet, InvalidToken

from app.config import settings
from common_logging import get_logger

# Module-scoped logger obtained from the project's common_logging helper.
logger = get_logger(__name__)


class EncryptionService:
    """Encrypt and decrypt string values with a Fernet key from settings.

    The key is read once from ``settings.ENCRYPTION_KEY`` when the service is
    constructed. If no key is configured, the service degrades to a
    pass-through: ``encrypt`` and ``decrypt`` return their input unchanged and
    log a warning.
    """

    def __init__(self):
        # None means "encryption disabled" (no ENCRYPTION_KEY configured).
        self._fernet: Fernet | None = None
        self._initialize()

    def _initialize(self) -> None:
        """Build the Fernet instance from ``settings.ENCRYPTION_KEY``.

        Leaves ``self._fernet`` as ``None`` (and logs a warning) when the key
        is empty or unset.

        Raises:
            ValueError: if the configured key cannot be used to build a
                ``Fernet`` instance. The original error is chained as
                ``__cause__``.
        """
        if not settings.ENCRYPTION_KEY:
            logger.warning("ENCRYPTION_KEY not configured, API keys will be stored in plaintext")
            return
        try:
            key = settings.ENCRYPTION_KEY
            # Fernet is given bytes; a str key from settings is encoded first.
            if isinstance(key, str):
                key = key.encode()
            self._fernet = Fernet(key)
        except Exception as e:
            # Broad on purpose: any key-parsing failure is normalized into a
            # single ValueError for callers, with the original chained.
            logger.error(f"Failed to initialize encryption: {e}")
            raise ValueError("Invalid ENCRYPTION_KEY") from e
        else:
            logger.info("Encryption service initialized successfully")

    def encrypt(self, plaintext: str) -> str:
        """Encrypt *plaintext* and return the Fernet token as a str.

        Falsy input (e.g. ``""``) is returned unchanged. When encryption is not
        configured, *plaintext* is returned unchanged and a warning is logged.

        Raises:
            Any exception from encoding or ``Fernet.encrypt`` is logged and
            re-raised unchanged.
        """
        if not plaintext:
            return plaintext
        if not self._fernet:
            logger.warning("Encryption not configured, storing plaintext")
            return plaintext
        try:
            encrypted = self._fernet.encrypt(plaintext.encode())
        except Exception as e:
            # Log-and-re-raise boundary: nothing is swallowed here.
            logger.error(f"Encryption failed: {e}")
            raise
        # DEBUG rather than INFO: this runs once per encrypted value.
        logger.debug("Data encrypted successfully")
        return encrypted.decode()

    def decrypt(self, ciphertext: str) -> str:
        """Decrypt a Fernet token produced by ``encrypt`` and return the plaintext.

        Falsy input is returned unchanged. When encryption is not configured,
        or *ciphertext* cannot be decrypted with the current key, the input is
        returned unchanged and a warning is logged. Presumably this keeps
        values stored before a key was configured readable (confirm with
        callers). Note that it also means a token encrypted under a different
        key is handed back as-is.
        """
        if not ciphertext:
            return ciphertext
        if not self._fernet:
            logger.warning("Encryption not configured, returning value as-is")
            return ciphertext
        try:
            plaintext = self._fernet.decrypt(ciphertext.encode()).decode()
        except (InvalidToken, UnicodeError) as e:
            # InvalidToken is raised without a message, so formatting `e`
            # alone yields an empty reason; log the exception type instead.
            logger.warning(f"Decryption failed ({type(e).__name__}), returning value as-is")
            return ciphertext
        # DEBUG rather than INFO: this runs once per decrypted value.
        logger.debug("Data decrypted successfully")
        return plaintext

    def is_encrypted(self, value: str) -> bool:
        """Return True if *value* decrypts successfully under the current key.

        Always returns False for falsy input or when encryption is not
        configured.
        """
        if not value or not self._fernet:
            return False
        try:
            self._fernet.decrypt(value.encode())
        except (InvalidToken, UnicodeError):
            return False
        return True


# Shared module-level instance, created at import time. Construction calls
# _initialize(), so an invalid ENCRYPTION_KEY raises ValueError when this
# module is imported.
encryption_service = EncryptionService()
