import os
from pathlib import Path
from typing import Any

import torch

from app.config import settings
from common_logging import get_logger, log_performance

logger = get_logger(__name__)


class LocalRerankService:
    """Rerank retrieved documents with a locally hosted BGE cross-encoder.

    The HuggingFace model/tokenizer are loaded lazily on the first call to
    ``rerank`` and can be released with ``unload_model`` (or automatically
    after each call when ``settings.LOCAL_BGE_RERANKER_AUTO_UNLOAD`` is set).
    """

    def __init__(self, model_name: str = "BAAI/bge-reranker-v2-m3"):
        """Store configuration and resolve the torch device; no model I/O here.

        Args:
            model_name: HuggingFace repo id or a local model directory path.
        """
        self.model_name = model_name
        self.model = None
        self.tokenizer = None
        self.device = self._get_device()
        logger.info(
            f"LocalRerankService initialized with model: {model_name}, device: {self.device}"
        )

    def _get_device(self) -> str:
        """Return the torch device string, honoring the configured preference.

        An explicit "cpu"/"mps"/"cuda" preference is respected, falling back
        to CPU when the requested accelerator is unavailable; "auto" (or any
        unrecognized value) auto-detects MPS first, then CUDA, then CPU.
        """
        preferred = (settings.LOCAL_BGE_RERANKER_DEVICE or "auto").lower()
        if preferred == "cpu":
            return "cpu"
        if preferred == "mps":
            return "mps" if torch.backends.mps.is_available() else "cpu"
        if preferred == "cuda":
            return "cuda" if torch.cuda.is_available() else "cpu"
        # Auto-detection path: prefer Apple MPS, then CUDA, then CPU.
        if torch.backends.mps.is_available():
            return "mps"
        if torch.cuda.is_available():
            return "cuda"
        return "cpu"

    def _resolve_model_source(self) -> str:
        """Resolve where the model should be loaded from.

        Resolution order:
          1. ``LOCAL_BGE_RERANKER_MODEL_PATH`` env var, if it names an existing path;
          2. ``self.model_name`` itself, if it is an existing local path;
          3. a directory under ``settings.BASE_MODELS_DIR`` (the alias table keeps
             the full repo-style name rather than only the last path component);
          4. otherwise the raw model name, left for transformers to resolve.
        """
        env_model_path = os.getenv("LOCAL_BGE_RERANKER_MODEL_PATH")
        if env_model_path and Path(env_model_path).exists():
            return env_model_path
        model_path = Path(self.model_name)
        if model_path.exists():
            return str(model_path)
        model_aliases = {"BAAI/bge-reranker-v2-m3": "BAAI/bge-reranker-v2-m3"}
        local_name = model_aliases.get(self.model_name, model_path.name or self.model_name)
        local_default = Path(settings.BASE_MODELS_DIR) / local_name
        if local_default.exists():
            return str(local_default)
        return self.model_name

    def _load_model(self):
        """Load tokenizer and model onto ``self.device``. Idempotent.

        Only local files are used (``local_files_only=True``); no network
        download is attempted.

        Raises:
            Exception: re-raised after logging when loading fails.
        """
        if self.model is not None:
            return
        try:
            from transformers import AutoModelForSequenceClassification, AutoTokenizer

            model_source = self._resolve_model_source()
            logger.info(f"Loading BGE Reranker model: {model_source}")
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_source, local_files_only=True, use_fast=False
            )
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_source, local_files_only=True
            )
            self.model.to(self.device)
            self.model.eval()
            logger.info(f"BGE Reranker model loaded successfully on device: {self.device}")
        except Exception as e:
            logger.error(f"Failed to load BGE Reranker model: {e}")
            raise

    def unload_model(self):
        """Drop model/tokenizer references and release accelerator caches."""
        if self.model is not None:
            del self.model
            self.model = None
        if self.tokenizer is not None:
            del self.tokenizer
            self.tokenizer = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        elif torch.backends.mps.is_available():
            # NOTE(review): torch.mps.empty_cache needs a reasonably recent
            # torch build — presumably available wherever MPS itself is.
            torch.mps.empty_cache()
        import gc

        gc.collect()

    @log_performance(threshold_ms=500)
    def rerank(
        self, query: str, documents: list[dict[str, Any]], top_k: int = 5
    ) -> list[dict[str, Any]]:
        """Score and reorder ``documents`` by cross-encoder relevance to ``query``.

        Each document's text is read from its "text" key, falling back to
        "chunk_text". The top ``top_k`` documents are returned in descending
        score order, each annotated in place with a "rerank_score" key
        (i.e. the caller's dicts are mutated).

        Edge cases: an empty list returns ``[]``; when
        ``len(documents) <= top_k`` the input is returned unchanged (no
        scoring, no "rerank_score" key). On any scoring failure the first
        ``top_k`` documents are returned unscored as a best-effort fallback.
        """
        if not documents:
            logger.warning("Empty document list provided for reranking")
            return []
        if len(documents) <= top_k:
            logger.info(
                f"Document count ({len(documents)}) <= top_k ({top_k}), returning all documents"
            )
            return documents
        try:
            self._load_model()
            batch_size = max(1, settings.LOCAL_BGE_RERANKER_BATCH_SIZE)
            max_length = settings.LOCAL_BGE_RERANKER_MAX_LENGTH
            pairs = [
                [query, doc.get("text", "") or doc.get("chunk_text", "")]
                for doc in documents
            ]
            all_scores: list[float] = []
            with torch.no_grad():
                for start in range(0, len(pairs), batch_size):
                    batch_pairs = pairs[start : start + batch_size]
                    inputs = self.tokenizer(
                        batch_pairs,
                        padding=True,
                        truncation=True,
                        max_length=max_length,
                        return_tensors="pt",
                    ).to(self.device)
                    scores = self.model(**inputs, return_dict=True).logits.view(-1).float()
                    all_scores.extend(scores.cpu().tolist())
            # strict=True: a score/document count mismatch is a bug — raise
            # (handled by the fallback below) rather than silently truncate.
            doc_score_pairs = list(zip(documents, all_scores, strict=True))
            doc_score_pairs.sort(key=lambda pair: pair[1], reverse=True)
            reranked_docs = []
            for doc, score in doc_score_pairs[:top_k]:
                doc["rerank_score"] = float(score)
                reranked_docs.append(doc)
            logger.info(f"Reranked {len(documents)} documents, returning top {top_k}")
            logger.debug(f"Top 3 rerank scores: {[float(s) for _, s in doc_score_pairs[:3]]}")
            return reranked_docs
        except Exception as e:
            logger.error(f"Failed to rerank documents: {e}")
            import traceback

            logger.error(traceback.format_exc())
            # Best-effort fallback: preserve the original order, unscored.
            return documents[:top_k]
        finally:
            if settings.LOCAL_BGE_RERANKER_AUTO_UNLOAD:
                self.unload_model()


_local_rerank_service = None


def unload_local_rerank_service() -> None:
    """Release the cached singleton service (and its model), if one exists."""
    global _local_rerank_service
    service = _local_rerank_service
    if service is None:
        return
    service.unload_model()
    _local_rerank_service = None


def get_local_rerank_service(model_name: str = "BAAI/bge-reranker-v2-m3") -> LocalRerankService:
    """Return the process-wide :class:`LocalRerankService` singleton.

    Creates the service on first use. Previously a differing ``model_name``
    on later calls was silently ignored; now the cached service is unloaded
    and replaced so the requested model is actually honored. Callers using
    the default argument see identical behavior.

    Args:
        model_name: HuggingFace repo id or local path for the reranker model.
    """
    global _local_rerank_service
    if (
        _local_rerank_service is not None
        and _local_rerank_service.model_name != model_name
    ):
        _local_rerank_service.unload_model()
        _local_rerank_service = None
    if _local_rerank_service is None:
        _local_rerank_service = LocalRerankService(model_name)
    return _local_rerank_service
