from typing import List, Optional

from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware

from .translator import set_locale
from common_logging import get_logger

logger = get_logger(__name__)


class LanguageMiddleware(BaseHTTPMiddleware):
    """Resolve the locale for each request and activate it via ``set_locale``.

    Precedence:
      1. the ``lang`` query parameter (an unsupported value yields
         ``default_locale``);
      2. the ``Accept-Language`` header, tried in q-value order until a
         supported language is found;
      3. ``default_locale``.
    """

    # Supported locales, keyed by lower-cased tags with "-" replaced by "_".
    _LOCALE_MAP = {
        "en": "en_US",
        "en_us": "en_US",
        "zh": "zh_CN",
        "zh_cn": "zh_CN",
        "zh_hans": "zh_CN",
    }

    def __init__(self, app, default_locale: str = "zh_CN"):
        """Wrap ``app``; ``default_locale`` is used when nothing else matches."""
        super().__init__(app)
        self.default_locale = default_locale

    async def dispatch(self, request: Request, call_next):
        """Set the detected locale, then forward the request downstream."""
        locale = self._detect_locale(request)
        set_locale(locale)
        logger.info(f"Language set to: {locale}", lang=locale)
        return await call_next(request)

    def _detect_locale(self, request: Request) -> str:
        """Return the canonical locale to use for ``request``."""
        lang_param = request.query_params.get("lang")
        if lang_param:
            # An explicit request parameter wins outright, even if unsupported.
            return self._normalize_locale(lang_param)

        accept_language = request.headers.get("accept-language")
        if accept_language:
            # Walk the client's preferences best-first and take the first one
            # we actually support, instead of only looking at the first entry.
            for tag in self._parse_accept_language(accept_language):
                matched = self._match_locale(tag)
                if matched is not None:
                    return matched
        return self.default_locale

    @staticmethod
    def _parse_accept_language(header: str) -> List[str]:
        """Return the language tags of an Accept-Language header, best first.

        Entries are ordered by descending q-value; ties keep header order.
        The ``*`` wildcard, empty entries, ``q=0`` entries and entries with an
        unparseable q-value are dropped.
        """
        weighted = []
        for index, item in enumerate(header.split(",")):
            tag, *params = (part.strip() for part in item.split(";"))
            if not tag or tag == "*":
                continue
            quality = 1.0
            for param in params:
                name, _, value = param.partition("=")
                if name.strip().lower() == "q":
                    try:
                        quality = float(value)
                    except ValueError:
                        quality = 0.0
            # "q > 0" also rejects NaN, since NaN compares false.
            if quality > 0:
                # The index keeps the sort stable for equal q-values.
                weighted.append((-quality, index, tag))
        weighted.sort()
        return [tag for _, _, tag in weighted]

    def _match_locale(self, locale: str) -> Optional[str]:
        """Map a language tag to a supported locale, or ``None`` if unsupported.

        An exact (case-insensitive, "-"/"_"-agnostic) match is tried first,
        then the primary language subtag (e.g. ``en-GB`` -> ``en`` -> ``en_US``).
        """
        key = locale.strip().lower().replace("-", "_")
        if key in self._LOCALE_MAP:
            return self._LOCALE_MAP[key]
        primary = key.split("_", 1)[0]
        return self._LOCALE_MAP.get(primary)

    def _normalize_locale(self, locale: str) -> str:
        """Map ``locale`` to a supported locale, falling back to the default."""
        matched = self._match_locale(locale)
        return matched if matched is not None else self.default_locale
