import json
from contextvars import ContextVar
from pathlib import Path
from typing import Any

from common_logging import get_logger

# Module-level logger, obtained from the project's common_logging helper.
logger = get_logger(__name__)
# Context-local locale code. Defaults to "zh_CN" when nothing has been set
# in the current context; written by set_locale() and read by get_locale()
# and get_translator().
current_locale: ContextVar[str] = ContextVar("current_locale", default="zh_CN")


class Translator:

    def __init__(self, locale: str = "zh_CN"):
        self.locale = locale
        self.translations: dict[str, Any] = {}
        self._load_translations()

    def _load_translations(self):
        locale_file = Path(__file__).parent / "locales" / f"{self.locale}.json"
        if locale_file.exists():
            with open(locale_file, encoding="utf-8") as f:
                self.translations = json.load(f)
            logger.info(f"Translations loaded for locale: {self.locale}")
        else:
            logger.warning(f"Translation file not found for locale: {self.locale}, falling back to zh_CN")
            fallback_file = Path(__file__).parent / "locales" / "zh_CN.json"
            if fallback_file.exists():
                with open(fallback_file, encoding="utf-8") as f:
                    self.translations = json.load(f)

    def t(self, key: str, **kwargs) -> str:
        keys = key.split(".")
        value = self.translations
        try:
            for k in keys:
                value = value[k]
            if isinstance(value, str) and kwargs:
                return value.format(**kwargs)
            return value if isinstance(value, str) else key
        except (KeyError, TypeError):
            return key

    def get(self, key: str, default: str | None = None, **kwargs) -> str:
        result = self.t(key, **kwargs)
        if result == key and default is not None:
            return default
        return result


# Module-level cache of Translator instances keyed by locale string;
# populated lazily by get_translator().
_translators: dict[str, Translator] = {}


def get_translator(locale: str | None = None) -> Translator:
    """Return the cached translator for a locale, creating it on first use.

    Args:
        locale: Locale code to translate into. When ``None``, the value of
            the ``current_locale`` context variable is used.

    Returns:
        The ``Translator`` stored in the module cache for that locale.
    """
    active = current_locale.get() if locale is None else locale

    cached = _translators.get(active)
    if cached is not None:
        return cached

    # First request for this locale: build the translator and remember it.
    created = Translator(active)
    _translators[active] = created
    return created


def set_locale(locale: str) -> None:
    """Set the locale stored in the ``current_locale`` context variable.

    Args:
        locale: Locale code to make active for the current context.
    """
    current_locale.set(locale)


def get_locale() -> str:
    """Return the locale held by the ``current_locale`` context variable.

    Falls back to the variable's default ("zh_CN") when no locale has been
    set in the current context.
    """
    return current_locale.get()
