import time
from datetime import datetime
from functools import wraps
from typing import Any

from common_logging import get_logger

# Module-level logger. NOTE(review): code below uses .bind()/.opt(), which
# suggests common_logging returns a loguru-style logger — confirm against
# common_logging's implementation.
logger = get_logger(__name__)



class GraphMetrics:
    """In-memory metrics for knowledge-graph operations.

    Tracks counts and cumulative durations for graph builds, queries, and
    entity extractions, plus cache hit/miss totals and a bounded history of
    recent operations. Not synchronized — assumes single-threaded recording
    (TODO confirm callers never record from multiple threads).
    """

    def __init__(self, max_history: int = 10_000):
        """Initialize zeroed counters and an empty history.

        Args:
            max_history: Maximum number of operation-history entries to
                retain; oldest entries are dropped once the cap is exceeded.
                Previously the history grew without bound, leaking memory in
                long-running processes.
        """
        self.max_history = max_history
        self.metrics = self._fresh_metrics()
        self.operation_history: list[dict[str, Any]] = []

    @staticmethod
    def _fresh_metrics() -> dict[str, Any]:
        """Return a new zeroed metrics dict (shared by __init__ and reset_metrics)."""
        return {
            "graph_builds": 0,
            "graph_queries": 0,
            "cache_hits": 0,
            "cache_misses": 0,
            "entity_extractions": 0,
            "errors": 0,
            "total_build_time": 0.0,
            "total_query_time": 0.0,
            "total_extraction_time": 0.0,
        }

    def _append_history(self, entry: dict[str, Any]) -> None:
        """Append an entry, trimming the history to the newest max_history items."""
        self.operation_history.append(entry)
        if len(self.operation_history) > self.max_history:
            # Keep only the most recent max_history entries.
            del self.operation_history[: -self.max_history]

    def record_graph_build(self, duration: float, success: bool, entity_count: int) -> None:
        """Record one graph-build attempt; a failed build also counts as an error."""
        self.metrics["graph_builds"] += 1
        self.metrics["total_build_time"] += duration
        if not success:
            self.metrics["errors"] += 1
        self._append_history(
            {
                "type": "graph_build",
                "timestamp": datetime.now().isoformat(),
                "duration": duration,
                "success": success,
                "entity_count": entity_count,
            }
        )
        logger.info(
            f"Graph build completed: duration={duration:.2f}s, success={success}, entities={entity_count}"
        )

    def record_graph_query(self, duration: float, result_count: int, cache_hit: bool) -> None:
        """Record one graph query and whether it was served from cache."""
        self.metrics["graph_queries"] += 1
        self.metrics["total_query_time"] += duration
        if cache_hit:
            self.metrics["cache_hits"] += 1
        else:
            self.metrics["cache_misses"] += 1
        self._append_history(
            {
                "type": "graph_query",
                "timestamp": datetime.now().isoformat(),
                "duration": duration,
                "result_count": result_count,
                "cache_hit": cache_hit,
            }
        )
        logger.debug(
            f"Graph query completed: duration={duration:.3f}s, results={result_count}, cache_hit={cache_hit}"
        )

    def record_entity_extraction(self, duration: float, entity_count: int) -> None:
        """Record one entity-extraction run."""
        self.metrics["entity_extractions"] += 1
        self.metrics["total_extraction_time"] += duration
        self._append_history(
            {
                "type": "entity_extraction",
                "timestamp": datetime.now().isoformat(),
                "duration": duration,
                "entity_count": entity_count,
            }
        )
        logger.info(
            f"Entity extraction completed: duration={duration:.2f}s, entities={entity_count}"
        )

    def record_error(self, operation: str, error: str) -> None:
        """Record an error for the named operation type."""
        self.metrics["errors"] += 1
        self._append_history(
            {
                "type": "error",
                "timestamp": datetime.now().isoformat(),
                "operation": operation,
                "error": error,
            }
        )
        logger.error(f"Graph operation error: {operation} - {error}")

    def get_metrics(self) -> dict[str, Any]:
        """Return a snapshot of raw counters plus derived averages and hit rate.

        Derived values (cache_hit_rate, avg_*_time) are 0.0 when their
        denominator is zero, so this is safe to call before any activity.
        """
        cache_total = self.metrics["cache_hits"] + self.metrics["cache_misses"]
        cache_hit_rate = self.metrics["cache_hits"] / cache_total if cache_total > 0 else 0.0

        def _avg(total_key: str, count_key: str) -> float:
            # Average duration with a divide-by-zero guard.
            count = self.metrics[count_key]
            return self.metrics[total_key] / count if count > 0 else 0.0

        return {
            "total_operations": (
                self.metrics["graph_builds"]
                + self.metrics["graph_queries"]
                + self.metrics["entity_extractions"]
            ),
            "graph_builds": self.metrics["graph_builds"],
            "graph_queries": self.metrics["graph_queries"],
            "entity_extractions": self.metrics["entity_extractions"],
            "errors": self.metrics["errors"],
            "cache_hit_rate": cache_hit_rate,
            "avg_build_time": _avg("total_build_time", "graph_builds"),
            "avg_query_time": _avg("total_query_time", "graph_queries"),
            "avg_extraction_time": _avg("total_extraction_time", "entity_extractions"),
            "total_build_time": self.metrics["total_build_time"],
            "total_query_time": self.metrics["total_query_time"],
            "total_extraction_time": self.metrics["total_extraction_time"],
        }

    def get_recent_operations(self, limit: int = 100) -> list:
        """Return up to the newest `limit` history entries, oldest first.

        A non-positive limit returns an empty list (previously `limit=0`
        returned the ENTIRE history, since `[-0:]` slices from the start).
        """
        if limit <= 0:
            return []
        return self.operation_history[-limit:]

    def reset_metrics(self) -> None:
        """Zero all counters and clear the operation history."""
        self.metrics = self._fresh_metrics()
        self.operation_history = []
        logger.info("Graph metrics reset")


# Lazily-created module-wide singleton; always access through get_graph_metrics().
graph_metrics: GraphMetrics | None = None


def get_graph_metrics() -> GraphMetrics:
    """Return the shared GraphMetrics instance, creating it on first use."""
    global graph_metrics
    instance = graph_metrics
    if instance is None:
        instance = GraphMetrics()
        graph_metrics = instance
    return instance


def monitor_graph_operation(operation_type: str):
    """Decorator factory that times a graph operation and records metrics.

    Args:
        operation_type: "build", "query", or "extraction". Any other value
            still times the call but records nothing on success; on failure
            an error is always recorded under this name.

    Behavior per type (derived from the wrapped function's return value):
        - "build": reads "success" / "entity_count" from a dict result.
        - "query": result_count is len(result) for list results;
          cache_hit is read from the call's keyword arguments.
        - "extraction": entity_count is len(result["entities"]) for dict results.

    Exceptions are recorded via record_error and re-raised unchanged.
    """

    def decorator(func):

        @wraps(func)
        def wrapper(*args, **kwargs):
            metrics = get_graph_metrics()
            # perf_counter is monotonic — immune to wall-clock adjustments,
            # unlike time.time(), which could yield negative durations.
            start_time = time.perf_counter()
            try:
                result = func(*args, **kwargs)
                duration = time.perf_counter() - start_time
                if operation_type == "build":
                    # Guard non-dict results (mirrors the extraction branch);
                    # previously a non-dict build result raised AttributeError.
                    if isinstance(result, dict):
                        success = result.get("success", False)
                        entity_count = result.get("entity_count", 0)
                    else:
                        success = False
                        entity_count = 0
                    metrics.record_graph_build(duration, success, entity_count)
                elif operation_type == "query":
                    result_count = len(result) if isinstance(result, list) else 0
                    cache_hit = kwargs.get("cache_hit", False)
                    metrics.record_graph_query(duration, result_count, cache_hit)
                elif operation_type == "extraction":
                    entity_count = (
                        len(result.get("entities", [])) if isinstance(result, dict) else 0
                    )
                    metrics.record_entity_extraction(duration, entity_count)
                return result
            except Exception as e:
                duration = time.perf_counter() - start_time
                metrics.record_error(operation_type, str(e))
                raise

        return wrapper

    return decorator


class GraphLogger:
    """Structured-logging facade for graph operations.

    Each method attaches structured fields via logger.bind(...) and emits a
    short human-readable message. NOTE(review): bind()/opt() are a
    loguru-style API — confirm common_logging's logger provides them.
    All methods are static; the class only namespaces them.
    """

    @staticmethod
    def log_graph_build_start(kb_id: int, document_count: int):
        """Log the start of a graph build for a knowledge base."""
        logger.bind(event="graph_build_start", kb_id=kb_id, document_count=document_count).info(
            f"Starting graph build for knowledge base {kb_id}"
        )

    @staticmethod
    def log_graph_build_complete(kb_id: int, duration: float, entity_count: int, success: bool):
        """Log build completion at INFO on success, ERROR on failure."""
        msg = f"Graph build {('completed' if success else 'failed')} for kb {kb_id}"
        bound = logger.bind(
            event="graph_build_complete", kb_id=kb_id, duration=duration,
            entity_count=entity_count, success=success,
        )
        if success:
            bound.info(msg)
        else:
            bound.error(msg)

    @staticmethod
    def log_graph_query(query_type: str, duration: float, result_count: int):
        """Log an executed graph query at DEBUG with timing and result count."""
        logger.bind(
            event="graph_query", query_type=query_type, duration=duration, result_count=result_count,
        ).debug(f"Graph query executed: {query_type}")

    @staticmethod
    def log_cache_operation(operation: str, key: str, hit: bool):
        """Log a cache hit/miss at DEBUG for the given operation and key."""
        logger.bind(event="cache_operation", operation=operation, key=key, hit=hit).debug(
            f"Cache {operation}: {('hit' if hit else 'miss')}"
        )

    @staticmethod
    def log_entity_extraction(document_id: int, entity_count: int, duration: float):
        """Log the result of extracting entities from one document."""
        logger.bind(
            event="entity_extraction", document_id=document_id,
            entity_count=entity_count, duration=duration,
        ).info(f"Extracted {entity_count} entities from document {document_id}")

    @staticmethod
    def log_error(operation: str, error: Exception, context: dict[str, Any] | None = None):
        """Log a graph error at ERROR with exception traceback attached.

        Args:
            operation: Name of the failing operation.
            error: The caught exception; its type name and str() are bound
                as structured fields.
            context: Optional extra structured fields (defaults to {}).
                Annotation fixed: the default is None, so the type must be
                Optional.
        """
        logger.bind(
            event="graph_error",
            operation=operation,
            error=str(error),
            error_type=type(error).__name__,
            context=context or {},
        ).opt(exception=True).error(f"Graph operation error: {operation}")


# Shared convenience instance; every GraphLogger method is a @staticmethod,
# so a single stateless instance is safe to reuse across the process.
graph_logger = GraphLogger()
