import time
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any

from neo4j import Driver, GraphDatabase, Session

from app.config import settings
from app.core.exceptions import ExternalServiceError
from common_logging import get_logger

logger = get_logger(__name__)


class Neo4jClient:
    """Thin wrapper around the Neo4j Python driver.

    Manages driver lifecycle, tenant/kb-scoped read and write queries,
    index creation, and the chunk/reference graph operations used by the
    knowledge-graph feature.  Initialization is skipped entirely when
    ``settings.ENABLE_KNOWLEDGE_GRAPH`` is false, in which case
    ``self.driver`` stays ``None`` and query methods raise ``RuntimeError``.
    """

    def __init__(self):
        # driver is None when the knowledge graph is disabled or (in the
        # test environment) when the connection attempt failed.
        self.driver: Driver | None = None
        # Session-level target database; see _initialize_driver for why
        # this is not passed to GraphDatabase.driver().
        self._database: str | None = None
        self._initialize_driver()

    def _initialize_driver(self):
        """Create and verify the Neo4j driver from application settings.

        Raises:
            ExternalServiceError: if the connection fails outside the
                test environment.
        """
        if not settings.ENABLE_KNOWLEDGE_GRAPH:
            logger.info("Knowledge graph is disabled, skipping Neo4j initialization")
            return
        try:
            # ``database`` is a *session*-level option in the Neo4j Python
            # driver, not a driver config key, so it is stored here and
            # applied per-session in get_session() instead of being passed
            # to GraphDatabase.driver().
            self._database = settings.NEO4J_DATABASE
            self.driver = GraphDatabase.driver(
                settings.NEO4J_URI,
                auth=(settings.NEO4J_USER, settings.NEO4J_PASSWORD),
                max_connection_lifetime=3600,
                max_connection_pool_size=50,
                connection_acquisition_timeout=60,
            )
            self.driver.verify_connectivity()
            logger.info(f"Neo4j driver initialized successfully: {settings.NEO4J_URI}")
        except Exception as e:
            logger.error(f"Failed to initialize Neo4j driver: {e}")
            # Tests may run without a Neo4j instance; degrade gracefully.
            if settings.ENVIRONMENT == "test":
                # Drop any half-initialized driver so get_session() fails
                # with a clear RuntimeError instead of using a dead driver.
                self.driver = None
                logger.warning(
                    "Neo4j connection failed in test environment, continuing without graph support"
                )
                return
            raise ExternalServiceError(
                "Neo4j", f"Failed to connect to Neo4j at {settings.NEO4J_URI}: {str(e)}"
            ) from e

    @contextmanager
    def get_session(self) -> Iterator[Session]:
        """Yield a session bound to the configured database; always closes it.

        Raises:
            RuntimeError: if the driver was never initialized.
        """
        if not self.driver:
            raise RuntimeError(
                "Neo4j driver not initialized. Enable ENABLE_KNOWLEDGE_GRAPH in settings."
            )
        session = (
            self.driver.session(database=self._database)
            if self._database
            else self.driver.session()
        )
        try:
            yield session
        finally:
            session.close()

    @staticmethod
    def _scoped_params(
        parameters: dict[str, Any] | None, tenant_id: int | None
    ) -> dict[str, Any]:
        """Return a copy of *parameters* with the tenant scope injected.

        Copying avoids mutating the caller's dict; tenant_id defaults to 0
        (the "global" tenant) when not provided.
        """
        params = dict(parameters) if parameters else {}
        params["tenant_id"] = tenant_id if tenant_id is not None else 0
        return params

    def execute_query(
        self,
        query: str,
        parameters: dict[str, Any] | None = None,
        tenant_id: int | None = None,
    ) -> list[dict[str, Any]]:
        """Run a read query and return its records as plain dicts.

        Args:
            query: Cypher query text; may reference ``$tenant_id``.
            parameters: Query parameters (never mutated).
            tenant_id: Tenant scope injected as ``$tenant_id`` (0 = global).

        Raises:
            RuntimeError: if the driver was never initialized.
            Exception: any driver error is logged and re-raised.
        """
        if not self.driver:
            raise RuntimeError("Neo4j driver not initialized")
        params = self._scoped_params(parameters, tenant_id)
        # perf_counter is monotonic, so the logged duration can't go
        # negative on system clock adjustments.
        start = time.perf_counter()
        try:
            with self.get_session() as session:
                result = session.run(query, params)
                records = [dict(record) for record in result]
                elapsed_ms = round((time.perf_counter() - start) * 1000, 2)
                logger.bind(duration_ms=elapsed_ms, result_count=len(records)).info(
                    "Neo4j query executed"
                )
                return records
        except Exception as e:
            elapsed_ms = round((time.perf_counter() - start) * 1000, 2)
            logger.bind(duration_ms=elapsed_ms).error(f"Query execution failed: {e}")
            raise

    def execute_write(
        self,
        query: str,
        parameters: dict[str, Any] | None = None,
        tenant_id: int | None = None,
    ) -> list[dict[str, Any]]:
        """Run a write query inside a managed (retryable) write transaction.

        Args:
            query: Cypher statement; may reference ``$tenant_id``.
            parameters: Query parameters (never mutated).
            tenant_id: Tenant scope injected as ``$tenant_id`` (0 = global).

        Returns:
            Records produced by the statement, as plain dicts.

        Raises:
            RuntimeError: if the driver was never initialized.
            Exception: any driver error is logged and re-raised.
        """
        if not self.driver:
            raise RuntimeError("Neo4j driver not initialized")
        params = self._scoped_params(parameters, tenant_id)

        def _write_tx(tx):
            # Materialize inside the transaction function: results must be
            # consumed before the transaction closes.
            result = tx.run(query, params)
            return [dict(record) for record in result]

        start = time.perf_counter()
        try:
            with self.get_session() as session:
                records = session.execute_write(_write_tx)
                elapsed_ms = round((time.perf_counter() - start) * 1000, 2)
                logger.bind(duration_ms=elapsed_ms).info(
                    "Neo4j write transaction executed"
                )
                return records
        except Exception as e:
            elapsed_ms = round((time.perf_counter() - start) * 1000, 2)
            logger.bind(duration_ms=elapsed_ms).error(f"Write transaction failed: {e}")
            raise

    def create_indexes(self):
        """Create (idempotently) the indexes used by graph queries.

        Best-effort: failures are logged, never raised, and one failing
        statement does not prevent the remaining indexes from being created.
        """
        if not self.driver:
            logger.warning("Neo4j driver not initialized, skipping index creation")
            return
        indexes = [
            "CREATE INDEX document_tenant_kb IF NOT EXISTS FOR (d:Document) ON (d.tenant_id, d.kb_id)",
            "CREATE INDEX document_id IF NOT EXISTS FOR (d:Document) ON (d.id)",
            "CREATE INDEX entity_tenant_kb IF NOT EXISTS FOR (e:Entity) ON (e.tenant_id, e.kb_id)",
            "CREATE INDEX entity_name IF NOT EXISTS FOR (e:Entity) ON (e.name)",
            "CREATE INDEX entity_type IF NOT EXISTS FOR (e:Entity) ON (e.type)",
            "CREATE INDEX tag_tenant_kb IF NOT EXISTS FOR (t:Tag) ON (t.tenant_id, t.kb_id)",
            "CREATE INDEX category_tenant_kb IF NOT EXISTS FOR (c:Category) ON (c.tenant_id, c.kb_id)",
            "CREATE INDEX user_tenant IF NOT EXISTS FOR (u:User) ON (u.tenant_id)",
            "CREATE INDEX chunk_tenant_kb IF NOT EXISTS FOR (c:Chunk) ON (c.tenant_id, c.kb_id)",
            "CREATE INDEX chunk_id IF NOT EXISTS FOR (c:Chunk) ON (c.chunk_id)",
            "CREATE INDEX chunk_doc_number IF NOT EXISTS FOR (c:Chunk) ON (c.doc_number)",
            "CREATE INDEX chunk_level IF NOT EXISTS FOR (c:Chunk) ON (c.chunk_level)",
        ]
        try:
            with self.get_session() as session:
                for index_query in indexes:
                    try:
                        session.run(index_query)
                    except Exception as e:
                        logger.error(f"Failed to create indexes: {e}")
                logger.info("Neo4j indexes created successfully")
        except Exception as e:
            logger.error(f"Failed to create indexes: {e}")

    def clear_tenant_data(self, tenant_id: int):
        """Detach-delete every node belonging to *tenant_id*."""
        query = """
        MATCH (n)
        WHERE n.tenant_id = $tenant_id
        DETACH DELETE n
        """
        self.execute_write(query, tenant_id=tenant_id)
        logger.info(f"Cleared graph data for tenant {tenant_id}")

    def delete_kb_data(self, kb_id: int, tenant_id: int) -> int:
        """Detach-delete all nodes of one knowledge base; return the count.

        Raises:
            Exception: any deletion failure is logged and re-raised.
        """
        # Count before deleting so the number of removed nodes can be
        # returned from the same statement.
        query = """
        MATCH (n)
        WHERE n.tenant_id = $tenant_id AND n.kb_id = $kb_id
        WITH count(n) AS deleted_count, collect(n) AS nodes
        FOREACH (n IN nodes | DETACH DELETE n)
        RETURN deleted_count
        """
        try:
            results = self.execute_write(query, parameters={"kb_id": kb_id}, tenant_id=tenant_id)
            deleted_count = results[0].get("deleted_count", 0) if results else 0
            logger.info(
                f"Deleted {deleted_count} Neo4j nodes for kb_id={kb_id}, tenant_id={tenant_id}"
            )
            return deleted_count
        except Exception as e:
            logger.error(f"Failed to delete Neo4j graph data for kb_id={kb_id}: {e}")
            raise

    def create_chunk_node(
        self,
        chunk_id: str,
        document_id: int,
        chunk_text: str,
        chunk_level: str,
        chunk_index: int,
        is_parent: bool,
        parent_chunk_id: str | None,
        tenant_id: int,
        kb_id: int,
        doc_type: str | None = None,
        doc_number: str | None = None,
    ):
        """Upsert a Chunk node and link it to its Document.

        The node is keyed by (chunk_id, tenant_id, kb_id); all other
        properties are overwritten on every call.
        """
        query = """
        MERGE (c:Chunk {chunk_id: $chunk_id, tenant_id: $tenant_id, kb_id: $kb_id})
        SET c.document_id = $document_id,
            c.text = $chunk_text,
            c.chunk_level = $chunk_level,
            c.chunk_index = $chunk_index,
            c.is_parent = $is_parent,
            c.parent_chunk_id = $parent_chunk_id,
            c.doc_type = $doc_type,
            c.doc_number = $doc_number,
            c.updated_at = datetime()
        WITH c
        MATCH (d:Document {id: $document_id, tenant_id: $tenant_id, kb_id: $kb_id})
        MERGE (d)-[:CONTAINS]->(c)
        """
        self.execute_write(
            query,
            parameters={
                "chunk_id": chunk_id,
                "document_id": document_id,
                # Only a preview of the text is stored on the node.
                "chunk_text": chunk_text[:500],
                "chunk_level": chunk_level,
                "chunk_index": chunk_index,
                "is_parent": is_parent,
                "parent_chunk_id": parent_chunk_id,
                "doc_type": doc_type,
                "doc_number": doc_number,
                "kb_id": kb_id,
            },
            tenant_id=tenant_id,
        )

    def create_chunk_reference(
        self,
        source_chunk_id: str,
        target_doc_number: str,
        target_article: str | None,
        confidence: float,
        tenant_id: int,
        kb_id: int,
    ):
        """Create/update a REFERENCES edge from a chunk to a cited document.

        When *target_article* is given, the edge targets the article-level
        chunk containing that article text; otherwise it targets a single
        parent chunk of the referenced document.
        """
        if target_article:
            query = """
            MATCH (source:Chunk {chunk_id: $source_chunk_id, tenant_id: $tenant_id, kb_id: $kb_id})
            MATCH (target:Chunk {doc_number: $target_doc_number, tenant_id: $tenant_id, kb_id: $kb_id})
            WHERE target.chunk_level = 'article' AND target.text CONTAINS $target_article
            MERGE (source)-[r:REFERENCES]->(target)
            SET r.confidence = $confidence,
                r.target_article = $target_article,
                r.created_at = datetime()
            """
        else:
            query = """
            MATCH (source:Chunk {chunk_id: $source_chunk_id, tenant_id: $tenant_id, kb_id: $kb_id})
            MATCH (target:Chunk {doc_number: $target_doc_number, tenant_id: $tenant_id, kb_id: $kb_id})
            WHERE target.is_parent = true
            WITH source, target
            LIMIT 1
            MERGE (source)-[r:REFERENCES]->(target)
            SET r.confidence = $confidence,
                r.created_at = datetime()
            """
        self.execute_write(
            query,
            parameters={
                "source_chunk_id": source_chunk_id,
                "target_doc_number": target_doc_number,
                "target_article": target_article,
                "confidence": confidence,
                "kb_id": kb_id,
            },
            tenant_id=tenant_id,
        )

    def get_referenced_chunks(
        self,
        chunk_id: str,
        tenant_id: int,
        kb_id: int,
        max_depth: int = 1,
        min_confidence: float = 0.5,
    ) -> list[dict[str, Any]]:
        """Follow REFERENCES edges from a chunk up to *max_depth* hops.

        Only paths whose every edge meets *min_confidence* are returned,
        ordered by depth then first-hop confidence, capped at 20 rows.
        """
        # Variable-length bounds cannot be Cypher parameters, so max_depth
        # is coerced to a positive int before being interpolated — this
        # also rules out query injection through a non-int value.
        depth = max(1, int(max_depth))
        query = f"""
        MATCH path = (source:Chunk {{chunk_id: $chunk_id, tenant_id: $tenant_id, kb_id: $kb_id}})
                     -[r:REFERENCES*1..{depth}]->(target:Chunk)
        WHERE ALL(rel IN relationships(path) WHERE rel.confidence >= $min_confidence)
        RETURN DISTINCT
            target.chunk_id as chunk_id,
            target.document_id as document_id,
            target.text as text,
            target.chunk_level as chunk_level,
            target.chunk_index as chunk_index,
            target.doc_number as doc_number,
            target.doc_type as doc_type,
            target.is_parent as is_parent,
            target.parent_chunk_id as parent_chunk_id,
            [rel IN relationships(path) | rel.confidence] as confidence_path,
            length(path) as depth
        ORDER BY depth, confidence_path[0] DESC
        LIMIT 20
        """
        return self.execute_query(
            query,
            parameters={"chunk_id": chunk_id, "min_confidence": min_confidence, "kb_id": kb_id},
            tenant_id=tenant_id,
        )

    def get_chunks_referencing(
        self, chunk_id: str, tenant_id: int, kb_id: int, min_confidence: float = 0.5
    ) -> list[dict[str, Any]]:
        """Return up to 20 chunks that reference *chunk_id* directly."""
        query = """
        MATCH (source:Chunk)-[r:REFERENCES]->(target:Chunk {chunk_id: $chunk_id, tenant_id: $tenant_id, kb_id: $kb_id})
        WHERE r.confidence >= $min_confidence
        RETURN
            source.chunk_id as chunk_id,
            source.document_id as document_id,
            source.text as text,
            source.chunk_level as chunk_level,
            source.doc_number as doc_number,
            r.confidence as confidence
        ORDER BY r.confidence DESC
        LIMIT 20
        """
        return self.execute_query(
            query,
            parameters={"chunk_id": chunk_id, "min_confidence": min_confidence, "kb_id": kb_id},
            tenant_id=tenant_id,
        )

    def create_supersedes_relation(self, source_doc_id: int, target_doc_id: int) -> None:
        """Record that *source_doc_id* SUPERSEDES *target_doc_id*.

        NOTE(review): unlike the other writers this is not tenant/kb
        scoped — the Document nodes are merged on id alone. Confirm that
        is intentional for cross-tenant supersession.
        """
        query = """
        MERGE (s:Document {id: $source_id})
        MERGE (t:Document {id: $target_id})
        MERGE (s)-[r:SUPERSEDES]->(t)
        SET r.created_at = datetime()
        """
        self.execute_write(
            query, parameters={"source_id": source_doc_id, "target_id": target_doc_id}
        )

    def get_stats(self, tenant_id: int, kb_id: int | None = None) -> dict[str, int]:
        """Return node counts per label and the relationship total.

        Args:
            tenant_id: tenant to count nodes for.
            kb_id: optional knowledge-base filter.
        """
        where_clause = "WHERE n.tenant_id = $tenant_id"
        params: dict[str, Any] = {"tenant_id": tenant_id}
        if kb_id is not None:
            where_clause += " AND n.kb_id = $kb_id"
            params["kb_id"] = kb_id
        logger.info(f"Getting stats with params: {params}")
        query = f"""
        MATCH (n)
        {where_clause}
        RETURN
            labels(n)[0] as label,
            count(n) as count
        """
        results = self.execute_query(query, params)
        logger.info(f"Node query results: {results}")
        stats = {
            "total_nodes": sum(r["count"] for r in results),
            "documents": 0,
            "entities": 0,
            "chunks": 0,
            "tags": 0,
            "categories": 0,
            "users": 0,
        }
        # Map node labels onto the stats keys; unknown labels only
        # contribute to total_nodes.
        label_keys = {
            "Document": "documents",
            "Entity": "entities",
            "Chunk": "chunks",
            "Tag": "tags",
            "Category": "categories",
            "User": "users",
        }
        for record in results:
            key = label_keys.get(record.get("label", ""))
            if key is not None:
                stats[key] = record.get("count", 0)
        kb_filter_n1 = "AND COALESCE(n1.kb_id, -1) = $kb_id" if kb_id is not None else ""
        kb_filter_n2 = "AND COALESCE(n2.kb_id, -1) = $kb_id" if kb_id is not None else ""
        # The undirected pattern matches each relationship twice, but
        # count(DISTINCT r) already deduplicates, so the count must NOT
        # be halved afterwards.
        rel_query = f"""
        MATCH (n1)-[r]-(n2)
        WHERE COALESCE(n1.tenant_id, -1) = $tenant_id
        {kb_filter_n1}
        AND COALESCE(n2.tenant_id, -1) = $tenant_id
        {kb_filter_n2}
        RETURN count(DISTINCT r) as rel_count
        """
        rel_results = self.execute_query(rel_query, params)
        logger.info(f"Relationship query results: {rel_results}")
        stats["total_relationships"] = rel_results[0]["rel_count"] if rel_results else 0
        logger.info(f"Final stats: {stats}")
        return stats

    def close(self):
        """Close the driver and release its connection pool."""
        if self.driver:
            self.driver.close()
            # Prevent accidental reuse of a closed driver.
            self.driver = None
            logger.info("Neo4j driver closed")


# Lazily-created process-wide client, shared by all callers of
# get_neo4j_client().
_neo4j_client_instance = None


def get_neo4j_client() -> Neo4jClient:
    """Return the shared Neo4jClient, constructing it on first use."""
    global _neo4j_client_instance
    client = _neo4j_client_instance
    if client is None:
        client = Neo4jClient()
        _neo4j_client_instance = client
    return client


# Best-effort module-level default client: left as None when construction
# fails so that importing this module never raises.
try:
    neo4j_client = Neo4jClient()
except Exception as e:
    neo4j_client = None
    logger.warning(f"Failed to create default Neo4j client instance: {e}")