import hashlib
from collections.abc import Callable
from functools import wraps
from typing import Any

from common_logging import get_logger

logger = get_logger(__name__)


try:
    import orjson

    USE_ORJSON = True
    logger.info("Using orjson for high-performance JSON serialization")
except ImportError:
    import json

    USE_ORJSON = False
    logger.info("Using standard json library (orjson not available)")


class QueryCacheService:

    def __init__(self):
        try:
            import redis

            from app.config import settings


            if settings.REDIS_URL:
                self.redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
            else:
                self.redis_client = redis.Redis(
                    host=settings.REDIS_HOST,
                    port=settings.REDIS_PORT,
                    db=settings.REDIS_DB,
                    password=settings.REDIS_PASSWORD,
                    decode_responses=True,
                )
            self.enabled = True
            logger.info("Query cache service initialized with Redis")
        except Exception as e:
            logger.warning(f"Failed to initialize Redis cache: {e}. Caching disabled.")
            self.redis_client = None
            self.enabled = False

    def _generate_cache_key(self, prefix: str, tenant_id: int, *args, **kwargs) -> str:
        key_data = f"{prefix}:{tenant_id}:{str(args)}:{str(sorted(kwargs.items()))}"
        key_hash = hashlib.sha256(key_data.encode()).hexdigest()[:16]
        return f"cache:tenant:{tenant_id}:{prefix}:{key_hash}"

    def get(self, key: str) -> Any | None:
        if not self.enabled:
            return None
        try:
            value = self.redis_client.get(key)
            if value:
                if USE_ORJSON:
                    if isinstance(value, str):
                        value = value.encode("utf-8")
                    return orjson.loads(value)
                else:
                    return json.loads(value)
            return None
        except Exception as e:
            logger.error(f"Cache get error: {e}")
            return None

    def set(self, key: str, value: Any, ttl: int = 300):
        if not self.enabled:
            return
        try:
            if USE_ORJSON:
                serialized = orjson.dumps(value, default=str).decode("utf-8")
            else:
                serialized = json.dumps(value, default=str)
            self.redis_client.setex(key, ttl, serialized)
        except Exception as e:
            logger.error(f"Cache set error: {e}")

    def delete(self, key: str):
        if not self.enabled:
            return
        try:
            self.redis_client.delete(key)
        except Exception as e:
            logger.error(f"Cache delete error: {e}")

    def delete_pattern(self, pattern: str):
        if not self.enabled:
            return
        try:
            keys = self.redis_client.keys(pattern)
            if keys:
                self.redis_client.delete(*keys)
        except Exception as e:
            logger.error(f"Cache delete pattern error: {e}")

    def cached(self, prefix: str, ttl: int = 300):

        def decorator(func: Callable):

            @wraps(func)
            def wrapper(*args, **kwargs):
                tenant_id = kwargs.get("tenant_id")
                if tenant_id is None and len(args) > 0:
                    for arg in args:
                        if hasattr(arg, "tenant_id"):
                            tenant_id = arg.tenant_id
                            break
                if tenant_id is None:
                    logger.warning(
                        f"No tenant_id found for cached function {func.__name__}, skipping cache"
                    )
                    return func(*args, **kwargs)
                cache_key = self._generate_cache_key(prefix, tenant_id, *args, **kwargs)
                cached_value = self.get(cache_key)
                if cached_value is not None:
                    logger.debug(f"Cache hit: {cache_key}")
                    return cached_value
                logger.debug(f"Cache miss: {cache_key}")
                result = func(*args, **kwargs)
                self.set(cache_key, result, ttl)
                return result

            return wrapper

        return decorator

    def invalidate_knowledge_base_cache(self, tenant_id: int, kb_id: int | None = None):
        if kb_id:
            self.delete_pattern(f"cache:tenant:{tenant_id}:knowledge_bases:*")
        else:
            self.delete_pattern(f"cache:tenant:{tenant_id}:knowledge_bases:*")
        logger.info(f"Invalidated knowledge base cache for tenant={tenant_id}, kb_id={kb_id}")

    def invalidate_category_cache(self, tenant_id: int, category_id: int | None = None):
        if category_id:
            self.delete_pattern(f"cache:tenant:{tenant_id}:categories:*")
        else:
            self.delete_pattern(f"cache:tenant:{tenant_id}:categories:*")
        logger.info(f"Invalidated category cache for tenant={tenant_id}, category_id={category_id}")

    def invalidate_model_config_cache(self, tenant_id: int):
        self.delete_pattern(f"cache:tenant:{tenant_id}:model_configs:*")
        logger.info(f"Invalidated model configuration cache for tenant={tenant_id}")


# Lazily-created process-wide instance returned by get_query_cache_service().
_cache_service: QueryCacheService | None = None


def get_query_cache_service() -> QueryCacheService:
    """Return the shared QueryCacheService, constructing it on the first call.

    NOTE(review): the check-then-create below is unsynchronized, so concurrent
    first calls from several threads may each build an instance (the last
    assignment wins). That only costs an extra Redis client; confirm this is
    acceptable for the deployment model.
    """
    global _cache_service
    service = _cache_service
    if service is None:
        service = QueryCacheService()
        _cache_service = service
    return service
