"""Local BGE embedding service with a sentence-transformers backend and a transformers fallback."""

import gc
import os
from pathlib import Path

import torch
import torch.nn.functional as F

from app.config import settings
from common_logging import get_logger, log_performance

logger = get_logger(__name__)


class LocalEmbeddingService:
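    """Local embedding service for BGE models.

    Lazily loads the model on first use, preferring sentence-transformers and
    falling back to Hugging Face transformers with masked mean pooling. When
    settings.LOCAL_BGE_AUTO_UNLOAD is set, the model is released after each
    encode call to keep memory usage low.
    """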

    def __init__(self, model_name: str = "mlx-community/bge-m3-mlx-fp16"):
        self.model_name = model_name
        self.model = None
        self.tokenizer = None
        self.dimension = None
        self.backend = None
        self.device = self._get_device()
        logger.info(
            f"LocalEmbeddingService initialized with model: {model_name}, device: {self.device}"
        )

    def _get_device(self) -> str:
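        """Resolve the torch device.

        An explicit LOCAL_BGE_DEVICE setting ("cpu", "mps", or "cuda") is
        honored when that backend is available, otherwise falling back to CPU;
        with "auto" (or unset) the best available device is picked: MPS, then
        CUDA, then CPU.
        """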
        preferred = (settings.LOCAL_BGE_DEVICE or "auto").lower()
        if preferred == "cpu":
            return "cpu"
        if preferred == "mps":
            return "mps" if torch.backends.mps.is_available() else "cpu"
        if preferred == "cuda":
            return "cuda" if torch.cuda.is_available() else "cpu"
        if torch.backends.mps.is_available():
            return "mps"
        elif torch.cuda.is_available():
            return "cuda"
        else:
            return "cpu"

    def _resolve_model_source(self) -> str:
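        """Resolve where to load the model from, in priority order: the
        LOCAL_BGE_MODEL_PATH environment variable, the model name as an
        existing local path, a known alias under settings.BASE_MODELS_DIR, and
        finally the model name itself (treated as a Hugging Face model id).
        """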
        env_model_path = os.getenv("LOCAL_BGE_MODEL_PATH")
        if env_model_path and Path(env_model_path).exists():
            return env_model_path
        model_path = Path(self.model_name)
        if model_path.exists():
            return str(model_path)
        model_aliases = {
            "mlx-community/bge-m3-mlx-fp16": "BAAI/bge-m3",
            "BAAI/bge-m3": "BAAI/bge-m3",
        }
        local_name = model_aliases.get(self.model_name, model_path.name or self.model_name)
        local_default = Path(settings.BASE_MODELS_DIR) / local_name
        if local_default.exists():
            return str(local_default)
        return self.model_name

    def _load_model(self):
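        """Load the model lazily on first use.

        Tries sentence-transformers first; on failure falls back to raw
        transformers (local files only), with pooling handled manually in the
        encode methods.
        """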
        if self.model is not None:
            return
        model_source = self._resolve_model_source()
        try:
            logger.info(f"Loading BGE model with sentence-transformers: {model_source}")
            from sentence_transformers import SentenceTransformer

            self.model = SentenceTransformer(model_source, device=self.device)
            self.tokenizer = None
            self.dimension = self.model.get_sentence_embedding_dimension()
            self.backend = "sentence_transformers"
            logger.info(f"BGE model loaded with sentence-transformers on device: {self.device}")
            return
        except Exception as e:
            logger.warning(f"Sentence-transformers load failed, fallback to transformers: {e}")
        try:
            from transformers import AutoModel, AutoTokenizer

            logger.info(f"Loading BGE model with transformers: {model_source}")
            self.tokenizer = AutoTokenizer.from_pretrained(model_source, local_files_only=True)
            self.model = AutoModel.from_pretrained(model_source, local_files_only=True)
            self.model.to(self.device)
            self.model.eval()
            hidden_size = getattr(self.model.config, "hidden_size", None)
            if hidden_size:
                self.dimension = int(hidden_size)
            self.backend = "transformers"
            logger.info(f"BGE model loaded with transformers on device: {self.device}")
        except Exception as e:
            logger.error(f"Failed to load BGE model: {e}")
            raise

    @log_performance(threshold_ms=1000)
    def encode(self, text: str, normalize: bool = True) -> list[float] | None:
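        """Encode a single text into an embedding vector.

        Returns None for empty input or on failure. If LOCAL_BGE_AUTO_UNLOAD is
        set, the model is unloaded after the call.
        """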
        if not text or not text.strip():
            logger.warning("Empty text provided for encoding")
            return None
        try:
            self._load_model()
            if self.backend == "sentence_transformers":
                embedding = self.model.encode(
                    text, normalize_embeddings=normalize, convert_to_numpy=True
                )
                embedding_list = embedding.tolist()
                logger.debug(f"Encoded text to {len(embedding_list)}-dim vector")
                return embedding_list
            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=settings.LOCAL_BGE_MAX_LENGTH,
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs)
            token_embeddings = outputs.last_hidden_state
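            # Masked mean pooling: average token embeddings, ignoring padding
            # positions when an attention mask is available.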
            attention_mask = inputs.get("attention_mask")
            if attention_mask is None:
                pooled = token_embeddings.mean(dim=1)
            else:
                mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
                sum_embeddings = torch.sum(token_embeddings * mask, dim=1)
                sum_mask = torch.clamp(mask.sum(dim=1), min=1e-09)
                pooled = sum_embeddings / sum_mask
            if normalize:
                pooled = F.normalize(pooled, p=2, dim=1)
            embedding_list = pooled[0].detach().cpu().tolist()
            logger.debug(f"Encoded text to {len(embedding_list)}-dim vector")
            return embedding_list
        except Exception as e:
            logger.error(f"Failed to encode text: {e}")
            return None
        finally:
            if settings.LOCAL_BGE_AUTO_UNLOAD:
                self.unload_model()

    @log_performance(threshold_ms=1000)
    def encode_batch(
        self,
        texts: list[str],
        batch_size: int = settings.LOCAL_BGE_BATCH_SIZE,
        normalize: bool = True,
    ) -> list[list[float] | None]:
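        """Encode a list of texts in batches.

        Returns one embedding per input text, or a list of Nones if encoding
        fails. batch_size is clamped to settings.LOCAL_BGE_BATCH_SIZE.
        """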
        if not texts:
            logger.warning("Empty text list provided for batch encoding")
            return []
        try:
            self._load_model()
            batch_size = max(1, min(batch_size, settings.LOCAL_BGE_BATCH_SIZE))
            if self.backend == "sentence_transformers":
                embeddings = self.model.encode(
                    texts,
                    batch_size=batch_size,
                    normalize_embeddings=normalize,
                    convert_to_numpy=True,
                    show_progress_bar=False,
                )
                embeddings_list = [emb.tolist() for emb in embeddings]
                logger.info(
                    f"Batch encoded {len(texts)} texts to {len(embeddings_list[0])}-dim vectors"
                )
                return embeddings_list
            all_embeddings: list[list[float]] = []
            for i in range(0, len(texts), batch_size):
                batch_texts = texts[i : i + batch_size]
                inputs = self.tokenizer(
                    batch_texts,
                    return_tensors="pt",
                    truncation=True,
                    padding=True,
                    max_length=settings.LOCAL_BGE_MAX_LENGTH,
                )
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                with torch.no_grad():
                    outputs = self.model(**inputs)
                token_embeddings = outputs.last_hidden_state
                attention_mask = inputs.get("attention_mask")
                if attention_mask is None:
                    pooled = token_embeddings.mean(dim=1)
                else:
                    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
                    sum_embeddings = torch.sum(token_embeddings * mask, dim=1)
                    sum_mask = torch.clamp(mask.sum(dim=1), min=1e-09)
                    pooled = sum_embeddings / sum_mask
                if normalize:
                    pooled = F.normalize(pooled, p=2, dim=1)
                all_embeddings.extend(pooled.detach().cpu().tolist())
            logger.info(f"Batch encoded {len(texts)} texts to {len(all_embeddings[0])}-dim vectors")
            return all_embeddings
        except Exception as e:
            logger.error(f"Failed to batch encode texts: {e}")
            return [None] * len(texts)
        finally:
            if settings.LOCAL_BGE_AUTO_UNLOAD:
                self.unload_model()

    def get_dimension(self) -> int:
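        """Return the embedding dimension, loading the model if necessary.

        Falls back to 1024 (the BGE-M3 hidden size) when the dimension cannot
        be determined from the loaded model.
        """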
        self._load_model()
        if self.dimension is not None:
            return self.dimension
        if self.backend == "sentence_transformers":
            return self.model.get_sentence_embedding_dimension()
        hidden_size = getattr(self.model.config, "hidden_size", None)
        return int(hidden_size) if hidden_size else 1024

    def unload_model(self):
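        """Release the model and tokenizer and clear device memory caches."""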
        if self.model is not None:
            logger.info(f"Unloading BGE model: {self.model_name}")
            del self.model
            self.model = None
            if self.tokenizer is not None:
                del self.tokenizer
                self.tokenizer = None
            self.dimension = None
            self.backend = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            elif torch.backends.mps.is_available():
                torch.mps.empty_cache()
            gc.collect()
            logger.info("BGE model unloaded and memory cleared")


_local_embedding_services: dict[str, LocalEmbeddingService] = {}


def get_local_embedding_service(
    model_name: str = "mlx-community/bge-m3-mlx-fp16",
) -> LocalEmbeddingService:
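    """Return a cached LocalEmbeddingService for the given model name, creating it on first use."""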
    if model_name not in _local_embedding_services:
        _local_embedding_services[model_name] = LocalEmbeddingService(model_name)
    return _local_embedding_services[model_name]


def unload_local_embedding_service(model_name: str | None = None):
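    """Unload a cached service by model name, or all cached services when model_name is None."""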
    if model_name is not None:
        service = _local_embedding_services.get(model_name)
        if service is not None:
            service.unload_model()
            del _local_embedding_services[model_name]
    else:
        for service in _local_embedding_services.values():
            service.unload_model()
        _local_embedding_services.clear()
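

# Illustrative usage sketch (not executed on import); assumes the LOCAL_BGE_*
# settings referenced above are configured in app.config:
#
#     service = get_local_embedding_service()
#     vector = service.encode("What is dense retrieval?")
#     vectors = service.encode_batch(["first passage", "second passage"])
#     unload_local_embedding_service()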
