import hashlib
import json
from typing import Any

try:
    import redis

    REDIS_AVAILABLE = True
except ImportError:
    REDIS_AVAILABLE = False
from app.config import settings

from common_logging import get_logger

logger = get_logger(__name__)



class GraphCache:

    def __init__(self):
        self.enabled = REDIS_AVAILABLE and settings.ENABLE_KNOWLEDGE_GRAPH
        self.client = None
        if self.enabled:
            try:
                if settings.REDIS_URL:
                    self.client = redis.from_url(
                        settings.REDIS_URL, decode_responses=True, socket_connect_timeout=5
                    )
                else:
                    self.client = redis.Redis(
                        host=settings.REDIS_HOST,
                        port=settings.REDIS_PORT,
                        db=settings.REDIS_DB,
                        password=settings.REDIS_PASSWORD,
                        decode_responses=True,
                        socket_connect_timeout=5,
                    )
                self.client.ping()
                logger.info("Redis cache initialized successfully")
            except Exception as e:
                logger.warning(f"Failed to connect to Redis: {e}. Cache disabled.")
                self.enabled = False
                self.client = None

    def _generate_key(self, prefix: str, tenant_id: int, kb_id: int, **kwargs) -> str:
        sorted_params = sorted(kwargs.items())
        param_str = json.dumps(sorted_params, sort_keys=True)
        param_hash = hashlib.sha256(param_str.encode()).hexdigest()[:16]
        return f"graph:tenant:{tenant_id}:kb:{kb_id}:{prefix}:{param_hash}"

    def get(self, key: str) -> list[dict[str, Any]] | None:
        if not self.enabled or not self.client:
            return None
        try:
            value = self.client.get(key)
            if value:
                logger.debug(f"Cache hit: {key}")
                return json.loads(value)
            logger.debug(f"Cache miss: {key}")
            return None
        except Exception as e:
            logger.error(f"Cache get error: {e}")
            return None

    def set(self, key: str, value: list[dict[str, Any]], ttl: int | None = None) -> bool:
        if not self.enabled or not self.client:
            return False
        try:
            ttl = ttl or settings.GRAPH_CACHE_TTL
            serialized = json.dumps(value)
            self.client.setex(key, ttl, serialized)
            logger.debug(f"Cache set: {key} (TTL: {ttl}s)")
            return True
        except Exception as e:
            logger.error(f"Cache set error: {e}")
            return False

    def delete(self, key: str) -> bool:
        if not self.enabled or not self.client:
            return False
        try:
            self.client.delete(key)
            logger.debug(f"Cache deleted: {key}")
            return True
        except Exception as e:
            logger.error(f"Cache delete error: {e}")
            return False

    def invalidate_document(self, document_id: int, tenant_id: int, kb_id: int):
        if not self.enabled or not self.client:
            return
        try:
            pattern = f"graph:tenant:{tenant_id}:kb:{kb_id}:*:*{document_id}*"
            cursor = 0
            deleted_count = 0
            while True:
                cursor, keys = self.client.scan(cursor, match=pattern, count=100)
                if keys:
                    self.client.delete(*keys)
                    deleted_count += len(keys)
                if cursor == 0:
                    break
            logger.info(
                f"Invalidated {deleted_count} cache entries for document {document_id} in tenant {tenant_id}, kb {kb_id}"
            )
        except Exception as e:
            logger.error(f"Cache invalidation error: {e}")

    def invalidate_knowledge_base(self, kb_id: int, tenant_id: int):
        if not self.enabled or not self.client:
            return
        try:
            pattern = f"graph:tenant:{tenant_id}:kb:{kb_id}:*"
            cursor = 0
            deleted_count = 0
            while True:
                cursor, keys = self.client.scan(cursor, match=pattern, count=100)
                if keys:
                    self.client.delete(*keys)
                    deleted_count += len(keys)
                if cursor == 0:
                    break
            logger.info(
                f"Invalidated {deleted_count} cache entries for kb {kb_id} in tenant {tenant_id}"
            )
        except Exception as e:
            logger.error(f"Cache invalidation error: {e}")

    def clear_all(self):
        if not self.enabled or not self.client:
            return
        try:
            pattern = "graph:*"
            cursor = 0
            deleted_count = 0
            while True:
                cursor, keys = self.client.scan(cursor, match=pattern, count=100)
                if keys:
                    self.client.delete(*keys)
                    deleted_count += len(keys)
                if cursor == 0:
                    break
            logger.info(f"Cleared {deleted_count} cache entries")
        except Exception as e:
            logger.error(f"Cache clear error: {e}")


# Process-wide singleton; created lazily on first call to get_graph_cache().
graph_cache: GraphCache | None = None


def get_graph_cache() -> GraphCache:
    """Return the shared GraphCache instance, constructing it on first use."""
    global graph_cache
    if graph_cache is not None:
        return graph_cache
    graph_cache = GraphCache()
    return graph_cache
