import asyncio
import json
from collections.abc import Iterator
from typing import Any

import httpx
from sqlalchemy.orm import Session

from app.core.exceptions import ExternalServiceError, ModelNotFoundError, ProviderNotConfiguredError
from app.models.local_model import LocalModel
from app.models.provider import Model, ModelProvider

from common_logging import get_logger
from common_langfuse import trace_llm_call

logger = get_logger(__name__)


# Module-wide AsyncClient, created on first use and reused afterwards.
_shared_http_client: httpx.AsyncClient | None = None


def _get_shared_client(timeout: float) -> httpx.AsyncClient:
    """Return the module-level ``httpx.AsyncClient``, building it when needed.

    A new client is constructed when none exists yet or when the cached one
    has been closed.

    Args:
        timeout: Timeout handed to the ``AsyncClient`` constructor. It is only
            consulted when a client is actually built; an already-open shared
            client is returned unchanged.

    Returns:
        The open shared client.
    """
    global _shared_http_client
    cached = _shared_http_client
    if cached is not None and not cached.is_closed:
        return cached
    pool_limits = httpx.Limits(max_connections=20, max_keepalive_connections=10)
    _shared_http_client = httpx.AsyncClient(timeout=timeout, limits=pool_limits)
    return _shared_http_client


class BaseChatClient:

    async def chat_completion(
        self,
        messages: list[dict[str, str]],
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int = 2048,
        stream: bool = False,
        enable_table_format: bool = True,
    ) -> dict[str, Any]:
        raise NotImplementedError

    def extract_response_text(self, api_response: dict[str, Any]) -> str:
        raise NotImplementedError

    def get_usage_info(self, api_response: dict[str, Any]) -> dict[str, int]:
        return {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}


class OpenAICompatibleChatClient(BaseChatClient):
    """Chat client for providers exposing an OpenAI-style ``/chat/completions`` endpoint.

    Connection details (base URL, auth scheme, API key, remote model id) are
    copied from the provider/model records in ``__init__``; the records are
    not referenced afterwards.
    """

    def __init__(
        self, provider: ModelProvider, model: Model, max_retries: int = 3, timeout: float = 300.0
    ):
        """Capture provider settings and retry/timeout configuration.

        Args:
            provider: Provider record; ``base_url`` falls back to
                ``default_base_url`` when unset.
            model: Model record; ``remote_model_id`` is preferred over ``code``.
            max_retries: Total number of attempts made by ``chat_completion``.
            timeout: Per-request timeout in seconds.
        """
        self._provider_name = provider.name
        self._base_url = provider.base_url or provider.default_base_url
        self._auth_type = provider.auth_type
        self._api_key = provider.get_api_key()
        self._model_code = model.remote_model_id or model.code
        self.max_retries = max_retries
        self.timeout = timeout

    def _build_headers(self) -> dict[str, str]:
        """Return request headers, adding the auth header matching ``_auth_type``."""
        headers = {"Content-Type": "application/json"}
        api_key = self._api_key
        if not api_key:
            return headers
        if self._auth_type == "bearer":
            headers["Authorization"] = f"Bearer {api_key}"
        elif self._auth_type == "api_key":
            headers["api-key"] = api_key
        elif self._auth_type == "x_api_key":
            headers["X-API-Key"] = api_key
        return headers

    def _prepare_request(
        self,
        *,
        messages: list[dict[str, str]],
        model: str | None,
        temperature: float,
        max_tokens: int,
        stream: bool,
    ) -> tuple[str, dict[str, str], dict[str, Any]]:
        """Build the ``(url, headers, payload)`` triple for a completions request.

        Raises:
            ProviderNotConfiguredError: If no base URL is available.
        """
        base_url = self._base_url
        if not base_url:
            raise ProviderNotConfiguredError(self._provider_name)
        url = f"{base_url.rstrip('/')}/chat/completions"
        payload = {
            "model": model or self._model_code,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "stream": stream,
            # Sent with every request; presumably switches off "thinking"
            # output on backends that honour it -- TODO confirm.
            "chat_template_kwargs": {"enable_thinking": False},
        }
        return url, self._build_headers(), payload

    def _record_generation(self, generation: Any, result: Any) -> None:
        """Attach response text and token usage to the tracing generation.

        Tracing is bookkeeping only: a malformed or empty response, or a
        failure inside ``generation.update``, is logged and swallowed so it
        cannot turn a successful completion into an error.
        """
        try:
            # ``or`` guards cover both a missing key and an explicit null/empty value.
            choices = result.get("choices") or [{}]
            message = choices[0].get("message") or {}
            usage = result.get("usage") or {}
            generation.update(
                output=message.get("content", ""),
                usage={
                    "input": usage.get("prompt_tokens", 0),
                    "output": usage.get("completion_tokens", 0),
                },
            )
        except Exception as e:
            logger.warning(f"Failed to record LLM trace for {self._provider_name}: {e}")

    async def chat_completion(
        self,
        messages: list[dict[str, str]],
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int = 2048,
        stream: bool = False,
        enable_table_format: bool = True,
    ) -> dict[str, Any]:
        """POST a chat completion request, retrying transient failures.

        Timeouts, transport errors and 5xx responses are retried up to
        ``max_retries`` attempts with exponential backoff (1s, 2s, 4s, ...).
        Other HTTP error statuses fail immediately.

        Args:
            messages: Chat messages forwarded verbatim in the payload.
            model: Optional override for the model id sent to the provider.
            temperature: Sampling temperature forwarded in the payload.
            max_tokens: Completion token limit forwarded in the payload.
            stream: Forwarded in the payload. The response is always parsed
                with ``response.json()``, so use ``async_stream_completion``
                for actual streaming.
            enable_table_format: Accepted for interface compatibility; unused here.

        Returns:
            The decoded JSON response body.

        Raises:
            ProviderNotConfiguredError: If the provider has no base URL.
            ExternalServiceError: On a non-retryable HTTP status, when all
                attempts are exhausted, or on any other unexpected failure.
        """
        url, headers, payload = self._prepare_request(
            messages=messages,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=stream,
        )
        model_name = payload["model"]
        last_error: Exception | None = None
        for attempt in range(self.max_retries):
            is_last_attempt = attempt >= self.max_retries - 1
            try:
                client = _get_shared_client(self.timeout)
                with trace_llm_call(
                    name=f"{self._provider_name}/chat",
                    model=model_name,
                    messages=messages,
                ) as generation:
                    # The shared client keeps whatever timeout it was first
                    # created with, so this instance's timeout is passed
                    # explicitly on every request.
                    response = await client.post(
                        url, headers=headers, json=payload, timeout=self.timeout
                    )
                    response.raise_for_status()
                    result = response.json()
                    if generation is not None:
                        self._record_generation(generation, result)
                    return result
            except httpx.TimeoutException as e:
                last_error = e
                logger.warning(f"Chat API timeout on attempt {attempt + 1}/{self.max_retries}: {e}")
            except httpx.HTTPStatusError as e:
                last_error = e
                status_code = e.response.status_code
                retryable = 500 <= status_code < 600
                if not retryable or is_last_attempt:
                    raise ExternalServiceError(
                        self._provider_name, f"HTTP {status_code}: {e.response.text}"
                    ) from e
                logger.warning(
                    f"Chat API server error {status_code} on attempt {attempt + 1}/{self.max_retries}"
                )
            except httpx.RequestError as e:
                last_error = e
                logger.warning(
                    f"Chat API request error on attempt {attempt + 1}/{self.max_retries}: {e}"
                )
            except Exception as e:
                # Keep the original exception chained for debugging.
                raise ExternalServiceError(self._provider_name, str(e)) from e
            # Only retryable failures reach this point.
            if not is_last_attempt:
                await asyncio.sleep(2**attempt)
        if isinstance(last_error, httpx.TimeoutException):
            raise ExternalServiceError(
                self._provider_name, f"Request timeout after {self.timeout}s"
            ) from last_error
        raise ExternalServiceError(
            self._provider_name, f"Failed after {self.max_retries} retries: {last_error}"
        ) from last_error

    async def async_stream_completion(
        self, messages: list[dict[str, str]], temperature: float = 0.7, max_tokens: int = 2048
    ):
        """Stream a chat completion, yielding non-empty text deltas as they arrive.

        Reads the response line by line, handling lines that start with
        ``data:`` and stopping at the ``[DONE]`` marker. Chunks that cannot be
        parsed or that carry no text are skipped. No retry is performed and
        httpx errors propagate unchanged.

        Args:
            messages: Chat messages forwarded verbatim in the payload.
            temperature: Sampling temperature forwarded in the payload.
            max_tokens: Completion token limit forwarded in the payload.

        Yields:
            Each non-empty ``choices[0].delta.content`` string.

        Raises:
            ProviderNotConfiguredError: If the provider has no base URL
                (raised on first iteration, since this is an async generator).
            httpx.HTTPStatusError: If the provider answers with an error status.
        """
        url, headers, payload = self._prepare_request(
            messages=messages,
            model=None,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=True,
        )
        client = _get_shared_client(self.timeout)
        async with client.stream(
            "POST", url, headers=headers, json=payload, timeout=self.timeout
        ) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                if not line.startswith("data:"):
                    continue
                data = line[len("data:") :].strip()
                if data == "[DONE]":
                    return
                try:
                    chunk_data = json.loads(data)
                    delta = chunk_data["choices"][0].get("delta") or {}
                    text = delta.get("content") or ""
                except (json.JSONDecodeError, KeyError, IndexError, TypeError, AttributeError):
                    # Malformed or unexpected chunk shape: skip it, keep streaming.
                    continue
                if text:
                    yield text

    def extract_response_text(self, api_response: dict[str, Any]) -> str:
        """Return the assistant text of the first choice.

        List-valued content is flattened by joining the ``text`` field of each
        dict part. A fixed apology string is returned when there are no
        choices or when extraction fails.
        """
        try:
            choices = api_response.get("choices", [])
            if not choices:
                return "Sorry, I cannot generate a response."
            message = choices[0].get("message", {})
            content = message.get("content", "")
            if isinstance(content, list):
                return "".join(
                    part.get("text", "") for part in content if isinstance(part, dict)
                )
            return content or ""
        except Exception as e:
            logger.error(f"Failed to extract response text: {e}")
            return "Sorry, an error occurred while processing the response."

    def get_usage_info(self, api_response: dict[str, Any]) -> dict[str, int]:
        """Return token usage under ``input_tokens``/``output_tokens``/``total_tokens``.

        Accepts both ``prompt_tokens``/``completion_tokens`` and
        ``input_tokens``/``output_tokens`` key spellings; a missing total is
        computed as their sum. Falls back to all zeros on failure.
        """
        try:
            # ``or {}`` also covers an explicit ``"usage": null``.
            usage = api_response.get("usage") or {}
            prompt_tokens = usage.get("prompt_tokens", usage.get("input_tokens", 0))
            completion_tokens = usage.get("completion_tokens", usage.get("output_tokens", 0))
            total_tokens = usage.get("total_tokens", prompt_tokens + completion_tokens)
            return {
                "input_tokens": prompt_tokens,
                "output_tokens": completion_tokens,
                "total_tokens": total_tokens,
            }
        except Exception as e:
            logger.error(f"Failed to get usage info: {e}")
            return super().get_usage_info(api_response)


class LocalAssetChatClient(BaseChatClient):
    """Chat client that delegates generation to ``local_chat_service`` for a local model.

    Only system messages and the last user message are used: system contents
    are joined with blank lines into one system prompt, and the content of the
    final ``user`` message becomes the prompt. Other roles are ignored.
    """

    def __init__(self, local_model: LocalModel):
        """Store the local model record whose identifier/path are used for generation."""
        self.local_model = local_model

    @staticmethod
    def _split_messages(messages: list[dict[str, str]]) -> tuple[str, str | None]:
        """Reduce a message list to ``(user_prompt, system_prompt)``.

        Returns:
            The content of the last ``user`` message (``""`` if none), and the
            non-empty system contents joined by blank lines, or ``None`` when
            there is no system text.
        """
        system_parts: list[str] = []
        user_prompt = ""
        for message in messages:
            role = message.get("role")
            # Treat a missing or null content uniformly as an empty string.
            content = message.get("content") or ""
            if role == "system":
                if content:
                    system_parts.append(content)
            elif role == "user":
                user_prompt = content
        return user_prompt, "\n\n".join(system_parts).strip() or None

    def _generation_kwargs(
        self, messages: list[dict[str, str]], temperature: float, max_tokens: int
    ) -> dict[str, Any]:
        """Build the keyword arguments shared by all ``local_chat_service`` calls."""
        user_prompt, system_prompt = self._split_messages(messages)
        return {
            "identifier": self.local_model.identifier,
            "model_path": self.local_model.model_path,
            "prompt": user_prompt,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "system_prompt": system_prompt,
        }

    async def chat_completion(
        self,
        messages: list[dict[str, str]],
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int = 2048,
        stream: bool = False,
        enable_table_format: bool = True,
    ) -> dict[str, Any]:
        """Generate a full response and wrap it in an OpenAI-shaped dict.

        ``model``, ``stream`` and ``enable_table_format`` are accepted for
        interface compatibility and are not used.

        Returns:
            ``{"choices": [{"message": {...}}], "usage": {}}`` with the
            generated text as the assistant message content.
        """
        # Imported at call time, as in the rest of this class.
        from app.services.llm.local_chat_service import local_chat_service

        # NOTE(review): this call is synchronous, so the event loop is blocked
        # for its duration unless the service offloads work itself.
        response_text = local_chat_service.generate_response(
            **self._generation_kwargs(messages, temperature, max_tokens)
        )
        return {
            "choices": [{"message": {"role": "assistant", "content": response_text}}],
            "usage": {},
        }

    def stream_completion(
        self, messages: list[dict[str, str]], temperature: float = 0.7, max_tokens: int = 2048
    ) -> Iterator[str]:
        """Return the service's synchronous response stream for ``messages``."""
        from app.services.llm.local_chat_service import local_chat_service

        return local_chat_service.generate_response_stream(
            **self._generation_kwargs(messages, temperature, max_tokens)
        )

    async def async_stream_completion(
        self, messages: list[dict[str, str]], temperature: float = 0.7, max_tokens: int = 2048
    ):
        """Async-generator wrapper around the service's synchronous stream.

        Yields each chunk from ``generate_response_stream`` and then awaits
        ``asyncio.sleep(0)`` so other tasks can run between chunks.
        """
        from app.services.llm.local_chat_service import local_chat_service

        stream = local_chat_service.generate_response_stream(
            **self._generation_kwargs(messages, temperature, max_tokens)
        )
        for chunk in stream:
            yield chunk
            await asyncio.sleep(0)

    def extract_response_text(self, api_response: dict[str, Any]) -> str:
        """Return the first choice's message content, or ``""`` if absent or null."""
        try:
            content = api_response["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError):
            return ""
        return "" if content is None else content


class ChatBackendFactory:
    """Builds chat clients from database records."""

    def get_client(self, model_id: int, db: Session) -> BaseChatClient:
        """Return a chat client for the remote model with the given id.

        Args:
            model_id: Primary key of the ``Model`` row.
            db: SQLAlchemy session used for the lookup.

        Returns:
            An ``OpenAICompatibleChatClient`` when the provider's protocol is
            ``"openai_compatible"`` or ``"openai"``.

        Raises:
            ModelNotFoundError: If no model with ``model_id`` exists.
            ExternalServiceError: If the model has no provider, or the
                provider's protocol is not supported.
        """
        model = db.query(Model).filter(Model.id == model_id).first()
        if not model:
            raise ModelNotFoundError(model_id)
        provider = model.provider
        if not provider:
            raise ExternalServiceError("chat", f"Provider not found for model {model_id}")
        if provider.protocol in {"openai_compatible", "openai"}:
            return OpenAICompatibleChatClient(provider=provider, model=model)
        raise ExternalServiceError(provider.name, f"Unsupported chat protocol: {provider.protocol}")

    def get_local_asset_client(
        self, identifier: str, db: Session
    ) -> LocalAssetChatClient | None:
        """Return a client for the active, non-deleted local model, if any.

        Args:
            identifier: Value matched against ``LocalModel.identifier``.
            db: SQLAlchemy session used for the lookup.

        Returns:
            A ``LocalAssetChatClient`` for the first matching row, or ``None``
            when no row matches.
        """
        local_model = (
            db.query(LocalModel)
            .filter(
                LocalModel.identifier == identifier,
                LocalModel.status == "active",
                # Must be a SQL expression: Python's ``not`` would be evaluated
                # on the mapped attribute itself instead of producing a
                # predicate for the WHERE clause.
                LocalModel.is_deleted.is_(False),
            )
            .first()
        )
        if not local_model:
            return None
        return LocalAssetChatClient(local_model)


# Lazily-created module-level factory instance; populated by get_chat_factory().
_chat_factory = None


def get_chat_factory() -> ChatBackendFactory:
    """Return the module-wide ``ChatBackendFactory``, creating it on first call."""
    global _chat_factory
    factory = _chat_factory
    if factory is None:
        factory = ChatBackendFactory()
        _chat_factory = factory
    return factory
