from typing import Any

from sqlalchemy.orm import Session

from app.models.provider import Model, ModelProvider

from common_logging import get_logger

logger = get_logger(__name__)




class RerankBackendFactory:
    """Rerank retrieved documents by calling the provider configured for a model.

    Supported provider protocols are ``"openai_compatible"`` and
    ``"custom_http"``. On any failure (unknown model, missing provider,
    unsupported protocol, missing base URL, HTTP/transport/JSON errors, or a
    response with no usable results) the first ``top_k`` input documents are
    returned unchanged; errors and unexpected response shapes are logged.
    """

    # Per-request HTTP timeout in seconds; override on a subclass or instance
    # to tune it without touching the request code.
    request_timeout: float = 30.0

    def rerank(
        self,
        query: str,
        documents: list[dict[str, Any]],
        model_id: int,
        db: Session,
        top_k: int = 5,
    ) -> list[dict[str, Any]]:
        """Rerank ``documents`` against ``query`` using the model ``model_id``.

        Args:
            query: Text the documents are scored against.
            documents: Candidate documents. The text sent to the backend is
                taken from ``"text"``, falling back to ``"chunk_text"``.
            model_id: Primary key of the ``Model`` row whose provider is called.
            db: Session used to load the model.
            top_k: Maximum number of documents to return.

        Returns:
            Up to ``top_k`` shallow copies of input documents, in the order the
            backend returned them, each with an added ``"rerank_score"`` key;
            or ``documents[:top_k]`` unchanged when reranking is not possible.
        """
        if not documents:
            # Nothing to rank: skip the DB lookup and the network round trip.
            return []

        model = db.query(Model).filter(Model.id == model_id).first()
        if not model:
            logger.error(f"Model {model_id} not found")
            return documents[:top_k]
        provider = model.provider
        if not provider:
            logger.error(f"Provider not found for model {model_id}")
            return documents[:top_k]

        if provider.protocol == "openai_compatible":
            return self._rerank_openai_compatible(query, documents, model, provider, top_k)
        if provider.protocol == "custom_http":
            return self._rerank_custom_http(query, documents, model, provider, top_k)
        logger.error(f"Unsupported provider protocol for rerank: {provider.protocol}")
        return documents[:top_k]

    def _rerank_openai_compatible(
        self,
        query: str,
        documents: list[dict[str, Any]],
        model: Model,
        provider: ModelProvider,
        top_k: int,
    ) -> list[dict[str, Any]]:
        """Rerank via ``POST {base_url}/rerank`` with an OpenAI-compatible payload.

        Sends ``query``, the document texts, the remote model name (falling
        back to ``model.code``) and ``top_n``. Reads ``results[].index`` and
        ``results[].relevance_score`` from the JSON response. Falls back to
        ``documents[:top_k]`` on any error or when no result maps to an input
        document.
        """
        try:
            base_url = self._resolve_base_url(provider)
            if not base_url:
                return documents[:top_k]
            payload = {
                "query": query,
                "documents": self._extract_texts(documents),
                "model": model.remote_model_id or model.code,
                "top_n": top_k,
            }
            data = self._post_json(
                f"{base_url.rstrip('/')}/rerank", self._build_headers(provider), payload
            )
            reranked_docs = self._attach_scores(
                data.get("results", []),
                documents,
                top_k,
                index_keys=("index",),
                score_keys=("relevance_score",),
            )
            return reranked_docs or documents[:top_k]
        except Exception as e:
            # Best-effort: any failure keeps the original retrieval order.
            logger.error(f"OpenAI-compatible rerank failed: {e}")
            return documents[:top_k]

    def _rerank_custom_http(
        self,
        query: str,
        documents: list[dict[str, Any]],
        model: Model,
        provider: ModelProvider,
        top_k: int,
    ) -> list[dict[str, Any]]:
        """Rerank via a provider-defined HTTP endpoint.

        The path comes from ``provider.extra_config["rerank_path"]`` (default
        ``"/rerank"``). Sends ``query``, the document texts and ``top_k``.
        ``model`` is accepted for signature parity with the other backend but
        is not sent. Each response result may carry its position under
        ``index`` or ``doc_index`` and its score under ``score`` or
        ``relevance_score``. Falls back to ``documents[:top_k]`` on any error,
        on a response without ``"results"``, or when no result maps to an
        input document.
        """
        try:
            base_url = self._resolve_base_url(provider)
            if not base_url:
                return documents[:top_k]
            url = self._custom_rerank_url(base_url, provider)
            payload = {"query": query, "documents": self._extract_texts(documents), "top_k": top_k}
            data = self._post_json(url, self._build_headers(provider), payload)
            if not isinstance(data, dict) or "results" not in data:
                logger.warning(f"Unexpected rerank response format: {data}")
                return documents[:top_k]
            reranked_docs = self._attach_scores(
                data["results"],
                documents,
                top_k,
                index_keys=("index", "doc_index"),
                score_keys=("score", "relevance_score"),
            )
            return reranked_docs or documents[:top_k]
        except Exception as e:
            # Best-effort: any failure keeps the original retrieval order.
            logger.error(f"Custom HTTP rerank failed: {e}")
            return documents[:top_k]

    @staticmethod
    def _resolve_base_url(provider: ModelProvider) -> "str | None":
        """Return the provider's explicit base URL, else its default; log and return None if neither is set."""
        base_url = provider.base_url or provider.default_base_url
        if not base_url:
            logger.error(f"No base_url configured for provider {provider.id}")
            return None
        return base_url

    @staticmethod
    def _build_headers(provider: ModelProvider) -> dict[str, str]:
        """Build JSON request headers plus the auth header selected by ``provider.auth_type``.

        No auth header is added when the provider has no API key, or when its
        ``auth_type`` is not one of ``"bearer"``, ``"api_key"`` or ``"x_api_key"``.
        """
        headers = {"Content-Type": "application/json"}
        if provider.auth_type == "bearer" and provider.api_key:
            headers["Authorization"] = f"Bearer {provider.api_key}"
        elif provider.auth_type == "api_key" and provider.api_key:
            headers["api-key"] = provider.api_key
        elif provider.auth_type == "x_api_key" and provider.api_key:
            headers["X-API-Key"] = provider.api_key
        return headers

    @staticmethod
    def _custom_rerank_url(base_url: str, provider: ModelProvider) -> str:
        """Join ``base_url`` with ``extra_config["rerank_path"]`` (default ``"/rerank"``).

        A null path uses the default. A path configured without a leading ``/``
        (e.g. ``"v1/rerank"``) gets one, so it is not glued onto the last URL
        segment. An explicitly empty path posts to ``base_url`` itself.
        """
        extra_config = provider.extra_config
        path = extra_config.get("rerank_path", "/rerank") if extra_config else "/rerank"
        if path is None:
            path = "/rerank"
        path = str(path)
        if path and not path.startswith(("/", "?")):
            path = f"/{path}"
        return f"{base_url.rstrip('/')}{path}"

    def _post_json(self, url: str, headers: dict[str, str], payload: dict[str, Any]) -> Any:
        """POST ``payload`` as JSON and return the decoded response body.

        Raises on transport errors, HTTP error statuses and undecodable JSON.
        Both callers catch these and fall back.
        """
        # Imported lazily so an unavailable httpx surfaces inside the callers'
        # try blocks (and degrades to the fallback) rather than at module import.
        import httpx

        with httpx.Client(timeout=self.request_timeout) as client:
            response = client.post(url, headers=headers, json=payload)
            response.raise_for_status()
            return response.json()

    @staticmethod
    def _extract_texts(documents: list[dict[str, Any]]) -> list[Any]:
        """Return the text to rank for each document: ``"text"``, else ``"chunk_text"``, else ``""``."""
        return [doc.get("text") or doc.get("chunk_text") or "" for doc in documents]

    @staticmethod
    def _attach_scores(
        results: Any,
        documents: list[dict[str, Any]],
        top_k: int,
        index_keys: tuple[str, ...],
        score_keys: tuple[str, ...],
    ) -> list[dict[str, Any]]:
        """Turn backend result entries into scored copies of the input documents.

        For each entry, the first key of ``index_keys`` that is present gives
        the position in ``documents``. The first key of ``score_keys`` that is
        present gives the score (``0.0`` if none is). An entry is skipped when:

        - it is not a dict;
        - it has no index, or the index is not an integer;
        - the index lies outside ``range(len(documents))`` (negative values
          would otherwise silently pick documents from the end);
        - the index was already used.

        At most ``top_k`` documents are returned, in backend order.
        """
        reranked: list[dict[str, Any]] = []
        used: set[int] = set()
        for result in results:
            if not isinstance(result, dict):
                continue
            idx = next((result[key] for key in index_keys if key in result), None)
            if not isinstance(idx, int) or isinstance(idx, bool):
                continue
            if not 0 <= idx < len(documents) or idx in used:
                continue
            used.add(idx)
            doc = documents[idx].copy()
            doc["rerank_score"] = next((result[key] for key in score_keys if key in result), 0.0)
            reranked.append(doc)
        return reranked[:top_k]


# Module-level cache for the shared factory; filled in by the first
# get_rerank_factory() call and reused afterwards.
_rerank_factory = None


def get_rerank_factory() -> RerankBackendFactory:
    """Return the shared ``RerankBackendFactory``, constructing it on first use.

    Initialization is not guarded by a lock, so concurrent first calls may
    each construct an instance before one is cached.
    """
    global _rerank_factory
    if _rerank_factory is not None:
        return _rerank_factory
    _rerank_factory = RerankBackendFactory()
    return _rerank_factory
