import gc
import json
import time
from collections.abc import Iterator
from pathlib import Path

from mlx_lm import load, stream_generate
from mlx_lm.sample_utils import make_sampler

from app.config import settings
from common_logging import get_logger, log_performance
from common_metrics import llm_request_duration, llm_first_token_seconds

logger = get_logger(__name__)


class LocalChatService:
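    """Loads MLX base models with LoRA adapters and generates chat completions locally."""
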
    DEFAULT_CHAT_BASE_MODEL = "Qwen3.5-4B-MLX"
    DEFAULT_MAX_TOKENS = 128
    MAX_TOKENS_CAP = 1024
    DEFAULT_NUM_LAYERS = 16
    DEFAULT_LORA_RANK = 16
    DEFAULT_LORA_ALPHA = 32
    DEFAULT_LORA_DROPOUT = 0.05

    def __init__(self):
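        # Cache of loaded (model, tokenizer) pairs keyed by identifier.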
        self.loaded_models: dict[str, tuple] = {}

    def _unload_identifier(self, identifier: str) -> None:
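        """Drop a cached (model, tokenizer) pair and trigger garbage collection."""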
        loaded = self.loaded_models.pop(identifier, None)
        if not loaded:
            return
        model, tokenizer = loaded
        del model
        del tokenizer
        gc.collect()

    def unload_all_models(self, keep_identifier: str | None = None) -> None:
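        """Unload every cached model, optionally keeping the one matching keep_identifier."""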
        identifiers = list(self.loaded_models.keys())
        for identifier in identifiers:
            if keep_identifier and identifier == keep_identifier:
                continue
            self._unload_identifier(identifier)

    def _ensure_adapter_config_compatibility(self, adapter_path: Path) -> None:
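        """Backfill missing fields in adapter_config.json so mlx_lm can load the adapter.

        Adapter configs may lack fine_tune_type, num_layers, or the nested
        lora_parameters block; defaults are written back to disk when anything is missing.
        """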
        config_path = adapter_path / "adapter_config.json"
        if not config_path.exists():
            raise FileNotFoundError(f"Adapter config not found: {config_path}")
        with config_path.open("r", encoding="utf-8") as f:
            config = json.load(f)
        if not isinstance(config, dict):
            raise ValueError(f"Invalid adapter config format: {config_path}")
        updated = False
        if "fine_tune_type" not in config:
            config["fine_tune_type"] = "lora"
            updated = True
        if "num_layers" not in config:
            config["num_layers"] = self.DEFAULT_NUM_LAYERS
            updated = True
        lora_rank = int(config.get("lora_rank", self.DEFAULT_LORA_RANK))
        lora_alpha = float(config.get("lora_alpha", self.DEFAULT_LORA_ALPHA))
        lora_dropout = float(config.get("lora_dropout", self.DEFAULT_LORA_DROPOUT))
        lora_parameters = config.get("lora_parameters")
        if not isinstance(lora_parameters, dict):
            lora_parameters = {}
            updated = True
        if "rank" not in lora_parameters:
            lora_parameters["rank"] = lora_rank
            updated = True
        if "dropout" not in lora_parameters:
            lora_parameters["dropout"] = lora_dropout
            updated = True
        if "scale" not in lora_parameters:
            lora_parameters["scale"] = lora_alpha / max(lora_rank, 1)
            updated = True
        if updated:
            config["lora_parameters"] = lora_parameters
            with config_path.open("w", encoding="utf-8") as f:
                json.dump(config, f, indent=2)

    def _resolve_base_model_path(self, adapter_path: Path) -> str:
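        """Resolve the base model path from the adapter config.

        Tries base_model_name, base_model, then model from adapter_config.json,
        checking each as a literal path and under settings.BASE_MODELS_DIR before
        falling back to DEFAULT_CHAT_BASE_MODEL.
        """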
        config_path = adapter_path / "adapter_config.json"
        config = {}
        if config_path.exists():
            with config_path.open("r", encoding="utf-8") as f:
                loaded = json.load(f)
                if isinstance(loaded, dict):
                    config = loaded
        candidate_names = [
            config.get("base_model_name"),
            config.get("base_model"),
            config.get("model"),
            self.DEFAULT_CHAT_BASE_MODEL,
        ]
        for candidate in candidate_names:
            if not candidate:
                continue
            candidate_path = Path(candidate)
            if candidate_path.exists():
                return str(candidate_path)
            model_dir_name = candidate_path.name or str(candidate)
            resolved_path = Path(settings.BASE_MODELS_DIR) / model_dir_name
            if resolved_path.exists():
                return str(resolved_path)
        return str(Path(settings.BASE_MODELS_DIR) / self.DEFAULT_CHAT_BASE_MODEL)

    def load_model(self, model_path: str, identifier: str):
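        """Load the base model and adapter from model_path, caching the pair under identifier.

        When settings.LOCAL_CHAT_SINGLE_MODEL_CACHE is enabled, all other cached
        models are unloaded before a new one is loaded.
        """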
        if identifier in self.loaded_models:
            return self.loaded_models[identifier]
        if settings.LOCAL_CHAT_SINGLE_MODEL_CACHE:
            self.unload_all_models()
        adapter_path = Path(model_path)
        if not adapter_path.exists():
            raise FileNotFoundError(f"Model not found: {adapter_path}")
        self._ensure_adapter_config_compatibility(adapter_path)
        base_model_path = self._resolve_base_model_path(adapter_path)
        model, tokenizer = load(base_model_path, adapter_path=str(adapter_path))
        self.loaded_models[identifier] = (model, tokenizer)
        return (model, tokenizer)

    def _prepare_prompt(self, tokenizer, prompt: str, system_prompt: str | None = None) -> str:
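        """Render the prompt with the tokenizer's chat template when one is available."""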
        prompt_text = prompt.strip()
        if not prompt_text:
            raise ValueError("Prompt cannot be empty")
        if not getattr(tokenizer, "has_chat_template", False):
            return prompt_text
        messages = []
        if system_prompt and system_prompt.strip():
            messages.append({"role": "system", "content": system_prompt.strip()})
        messages.append({"role": "user", "content": prompt_text})
        template_kwargs = {"tokenize": False, "add_generation_prompt": True}
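        # enable_thinking is not accepted by every chat template; fall back without it.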
        try:
            return tokenizer.apply_chat_template(messages, enable_thinking=False, **template_kwargs)
        except TypeError:
            return tokenizer.apply_chat_template(messages, **template_kwargs)

    def _resolve_max_tokens(self, max_tokens: int) -> int:
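        """Clamp the requested token budget to [1, MAX_TOKENS_CAP], defaulting when falsy."""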
        requested = max_tokens or self.DEFAULT_MAX_TOKENS
        return max(1, min(requested, self.MAX_TOKENS_CAP))

    @log_performance(threshold_ms=3000)
    def generate_response(
        self,
        identifier: str,
        model_path: str,
        prompt: str,
        max_tokens: int = 512,
        temperature: float = 0.7,
        system_prompt: str | None = None,
    ) -> str:
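        """Return the full generated response as a single string by draining the stream."""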
        return "".join(
            self.generate_response_stream(
                identifier=identifier,
                model_path=model_path,
                prompt=prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                system_prompt=system_prompt,
            )
        )

    def generate_response_stream(
        self,
        identifier: str,
        model_path: str,
        prompt: str,
        max_tokens: int = 512,
        temperature: float = 0.7,
        system_prompt: str | None = None,
    ) -> Iterator[str]:
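        """Stream generated text chunks for the given prompt.

        Records time-to-first-token and total request duration when the
        corresponding metric collectors are configured.

        Example (model_path points at a directory containing adapter_config.json):

            for chunk in local_chat_service.generate_response_stream(
                identifier="my-adapter",
                model_path="/path/to/adapter",
                prompt="Hello!",
            ):
                print(chunk, end="", flush=True)
        """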
        model, tokenizer = self.load_model(model_path, identifier)
        prompt_text = self._prepare_prompt(tokenizer, prompt, system_prompt)
        effective_max_tokens = self._resolve_max_tokens(max_tokens)
        sampler = make_sampler(temp=max(temperature, 0.0))
        yielded_chars = 0
        first_token_recorded = False
        t_start = time.perf_counter()
        for response in stream_generate(
            model, tokenizer, prompt=prompt_text, max_tokens=effective_max_tokens, sampler=sampler
        ):
            if response.text:
                if not first_token_recorded:
                    if llm_first_token_seconds:
                        llm_first_token_seconds.labels(model=identifier).observe(
                            time.perf_counter() - t_start
                        )
                    first_token_recorded = True
                yielded_chars += len(response.text)
                yield response.text
        if llm_request_duration:
            llm_request_duration.labels(model=identifier).observe(time.perf_counter() - t_start)
        logger.bind(
            identifier=identifier,
            requested_max_tokens=max_tokens,
            effective_max_tokens=effective_max_tokens,
            yielded_chars=yielded_chars,
        ).info("Local model response generated")


local_chat_service = LocalChatService()
