
from sqlalchemy.orm import Session

from app.models.provider import Model, ModelProvider

from common_logging import get_logger

logger = get_logger(__name__)




class EmbeddingBackendFactory:

    def generate_embedding(
        self, text: str, model_id: int, db: Session, normalize: bool = True
    ) -> list[float] | None:
        model = db.query(Model).filter(Model.id == model_id).first()
        if not model:
            logger.error(f"Model {model_id} not found")
            return None
        provider = model.provider
        if not provider:
            logger.error(f"Provider not found for model {model_id}")
            return None
        if provider.protocol in {"openai_compatible", "openai"}:
            return self._generate_embedding_openai_compatible(text, model, provider, normalize)
        else:
            logger.error(f"Unsupported provider protocol: {provider.protocol}")
            return None

    def generate_embeddings_batch(
        self,
        texts: list[str],
        model_id: int,
        db: Session,
        batch_size: int = 32,
        normalize: bool = True,
    ) -> list[list[float] | None]:
        model = db.query(Model).filter(Model.id == model_id).first()
        if not model:
            logger.error(f"Model {model_id} not found")
            return [None] * len(texts)
        provider = model.provider
        if not provider:
            logger.error(f"Provider not found for model {model_id}")
            return [None] * len(texts)
        if provider.protocol in {"openai_compatible", "openai"}:
            return self._generate_embeddings_batch_openai_compatible(
                texts, model, provider, batch_size, normalize
            )
        else:
            logger.error(f"Unsupported provider protocol: {provider.protocol}")
            return [None] * len(texts)

    def _generate_embedding_openai_compatible(
        self, text: str, model: Model, provider: ModelProvider, normalize: bool
    ) -> list[float] | None:
        try:
            import httpx

            base_url = provider.base_url or provider.default_base_url
            if not base_url:
                logger.error(f"No base_url configured for provider {provider.id}")
                return None
            headers = {"Content-Type": "application/json"}
            if provider.auth_type == "bearer" and provider.api_key:
                headers["Authorization"] = f"Bearer {provider.api_key}"
            elif provider.auth_type == "api_key" and provider.api_key:
                headers["api-key"] = provider.api_key
            elif provider.auth_type == "x_api_key" and provider.api_key:
                headers["X-API-Key"] = provider.api_key
            model_name = model.remote_model_id or model.code
            url = f"{base_url.rstrip('/')}/embeddings"
            with httpx.Client(timeout=30.0) as client:
                response = client.post(
                    url, headers=headers, json={"input": text, "model": model_name}
                )
                response.raise_for_status()
                data = response.json()
                embedding = data["data"][0]["embedding"]
                return embedding
        except Exception as e:
            logger.error(f"OpenAI-compatible embedding generation failed: {e}")
            return None

    def _generate_embeddings_batch_openai_compatible(
        self,
        texts: list[str],
        model: Model,
        provider: ModelProvider,
        batch_size: int,
        normalize: bool,
    ) -> list[list[float] | None]:
        try:
            import httpx


            base_url = provider.base_url or provider.default_base_url
            if not base_url:
                logger.error(f"No base_url configured for provider {provider.id}")
                return [None] * len(texts)
            headers = {"Content-Type": "application/json"}
            if provider.auth_type == "bearer" and provider.api_key:
                headers["Authorization"] = f"Bearer {provider.api_key}"
            elif provider.auth_type == "api_key" and provider.api_key:
                headers["api-key"] = provider.api_key
            elif provider.auth_type == "x_api_key" and provider.api_key:
                headers["X-API-Key"] = provider.api_key
            model_name = model.remote_model_id or model.code
            url = f"{base_url.rstrip('/')}/embeddings"
            all_embeddings = []
            with httpx.Client(timeout=60.0) as client:
                for i in range(0, len(texts), batch_size):
                    batch = texts[i : i + batch_size]
                    response = client.post(
                        url, headers=headers, json={"input": batch, "model": model_name}
                    )
                    response.raise_for_status()
                    data = response.json()
                    embeddings = [item["embedding"] for item in data["data"]]
                    all_embeddings.extend(embeddings)
            return all_embeddings
        except Exception as e:
            logger.error(f"OpenAI-compatible batch embedding generation failed: {e}")
            return [None] * len(texts)


# Lazily created module-wide singleton; access it via get_embedding_factory().
_factory_singleton = None


def get_embedding_factory() -> EmbeddingBackendFactory:
    """Return the shared EmbeddingBackendFactory instance.

    The factory is created on first use and reused for all later calls.
    """
    global _factory_singleton
    if _factory_singleton is None:
        _factory_singleton = EmbeddingBackendFactory()
    return _factory_singleton
