from typing import Any

from app.services.graph.graph_cache import get_graph_cache
from app.services.graph.neo4j_client import Neo4jClient

from common_logging import get_logger

logger = get_logger(__name__)

class GraphQueryService:

    def __init__(self, neo4j_client: Neo4jClient):
        self.neo4j_client = neo4j_client
        self.cache = get_graph_cache()

    def expand_neighbors(
        self,
        document_ids: list[int],
        tenant_id: int,
        kb_id: int,
        depth: int = 1,
        relation_types: list[str] | None = None,
        limit: int = 20,
    ) -> list[dict[str, Any]]:
        if not document_ids:
            return []
        cache_key = self.cache._generate_key(
            "neighbors",
            doc_ids=sorted(document_ids),
            tenant_id=tenant_id,
            kb_id=kb_id,
            depth=depth,
            relation_types=relation_types,
            limit=limit,
        )
        cached_result = self.cache.get(cache_key)
        if cached_result is not None:
            return cached_result
        if relation_types is None:
            relation_types = ["REFERENCES", "SIMILAR_TO"]
        rel_pattern = "|".join(relation_types)
        if depth == 1:
            query = f"\n            MATCH (d:Document)\n            WHERE d.id IN $doc_ids\n              AND d.tenant_id = $tenant_id\n              AND d.kb_id = $kb_id\n            MATCH (d)-[r:{rel_pattern}]-(neighbor:Document)\n            WHERE neighbor.tenant_id = $tenant_id\n              AND neighbor.kb_id = $kb_id\n              AND NOT (neighbor.id IN $doc_ids)\n              AND (neighbor.doc_status IS NULL OR NOT (neighbor.doc_status IN ['obsolete', 'expired']))\n            RETURN DISTINCT\n                neighbor.id as id,\n                neighbor.title as title,\n                neighbor.summary as summary,\n                type(r) as relation_type,\n                COALESCE(r.similarity, r.confidence, 0.5) as score\n            ORDER BY score DESC\n            LIMIT $limit\n            "
        else:
            query = f"\n            MATCH (d:Document)\n            WHERE d.id IN $doc_ids\n              AND d.tenant_id = $tenant_id\n              AND d.kb_id = $kb_id\n            MATCH path = (d)-[:{rel_pattern}*1..2]-(neighbor:Document)\n            WHERE neighbor.tenant_id = $tenant_id\n              AND neighbor.kb_id = $kb_id\n              AND NOT (neighbor.id IN $doc_ids)\n              AND (neighbor.doc_status IS NULL OR NOT (neighbor.doc_status IN ['obsolete', 'expired']))\n            WITH DISTINCT neighbor, length(path) as distance\n            RETURN\n                neighbor.id as id,\n                neighbor.title as title,\n                neighbor.summary as summary,\n                'multi_hop' as relation_type,\n                1.0 / distance as score\n            ORDER BY score DESC\n            LIMIT $limit\n            "
        try:
            results = self.neo4j_client.execute_query(
                query,
                parameters={
                    "doc_ids": document_ids,
                    "tenant_id": tenant_id,
                    "kb_id": kb_id,
                    "limit": limit,
                },
            )
            logger.info(
                f"Expanded {len(results)} neighbors from {len(document_ids)} seed documents"
            )
            self.cache.set(cache_key, results)
            return results
        except Exception as e:
            logger.error(f"Failed to expand neighbors: {e}")
            return []

    def get_document_entities(
        self, document_id: int, tenant_id: int, kb_id: int
    ) -> list[dict[str, Any]]:
        cache_key = self.cache._generate_key(
            "doc_entities", doc_id=document_id, tenant_id=tenant_id, kb_id=kb_id
        )
        cached_result = self.cache.get(cache_key)
        if cached_result is not None:
            return cached_result
        query = "\n        MATCH (d:Document {id: $doc_id, tenant_id: $tenant_id, kb_id: $kb_id})\n        MATCH (d)-[r:CONTAINS]->(e:Entity)\n        RETURN\n            e.name as name,\n            e.type as type,\n            e.description as description,\n            r.salience as salience\n        ORDER BY r.salience DESC\n        "
        try:
            results = self.neo4j_client.execute_query(
                query, parameters={"doc_id": document_id, "tenant_id": tenant_id, "kb_id": kb_id}
            )
            self.cache.set(cache_key, results)
            return results
        except Exception as e:
            logger.error(f"Failed to get document entities: {e}")
            return []

    def find_entity_documents(
        self, entity_name: str, tenant_id: int, kb_id: int, limit: int = 10
    ) -> list[dict[str, Any]]:
        query = "\n        MATCH (e:Entity {name: $entity_name, tenant_id: $tenant_id, kb_id: $kb_id})\n        MATCH (e)<-[r:CONTAINS]-(d:Document)\n        WHERE d.tenant_id = $tenant_id\n          AND d.kb_id = $kb_id\n        RETURN\n            d.id as id,\n            d.title as title,\n            d.summary as summary,\n            r.salience as salience\n        ORDER BY r.salience DESC\n        LIMIT $limit\n        "
        try:
            results = self.neo4j_client.execute_query(
                query,
                parameters={
                    "entity_name": entity_name,
                    "tenant_id": tenant_id,
                    "kb_id": kb_id,
                    "limit": limit,
                },
            )
            return results
        except Exception as e:
            logger.error(f"Failed to find entity documents: {e}")
            return []

    def get_related_entities(
        self, entity_name: str, tenant_id: int, kb_id: int, limit: int = 10
    ) -> list[dict[str, Any]]:
        query = "\n        MATCH (e1:Entity {name: $entity_name, tenant_id: $tenant_id, kb_id: $kb_id})\n        MATCH (e1)-[r:RELATED_TO|CO_OCCURS_WITH]-(e2:Entity)\n        WHERE e2.tenant_id = $tenant_id\n          AND e2.kb_id = $kb_id\n        RETURN\n            e2.name as name,\n            e2.type as type,\n            e2.description as description,\n            type(r) as relation_type,\n            COALESCE(r.confidence, r.count, 0.5) as score\n        ORDER BY score DESC\n        LIMIT $limit\n        "
        try:
            results = self.neo4j_client.execute_query(
                query,
                parameters={
                    "entity_name": entity_name,
                    "tenant_id": tenant_id,
                    "kb_id": kb_id,
                    "limit": limit,
                },
            )
            return results
        except Exception as e:
            logger.error(f"Failed to get related entities: {e}")
            return []

    def search_entities(
        self,
        query_text: str,
        tenant_id: int,
        kb_id: int,
        entity_type: str | None = None,
        limit: int = 20,
    ) -> list[dict[str, Any]]:
        cache_key = self.cache._generate_key(
            "search_entities",
            query=query_text,
            tenant_id=tenant_id,
            kb_id=kb_id,
            entity_type=entity_type,
            limit=limit,
        )
        cached_result = self.cache.get(cache_key)
        if cached_result is not None:
            return cached_result
        where_clauses = [
            "e.tenant_id = $tenant_id",
            "e.kb_id = $kb_id",
            "toLower(e.name) CONTAINS toLower($query)",
        ]
        if entity_type:
            where_clauses.append("e.type = $entity_type")
        where_clause = " AND ".join(where_clauses)
        query = f"\n        MATCH (e:Entity)\n        WHERE {where_clause}\n        RETURN\n            e.name as name,\n            e.type as type,\n            e.description as description\n        LIMIT $limit\n        "
        params = {"query": query_text, "tenant_id": tenant_id, "kb_id": kb_id, "limit": limit}
        if entity_type:
            params["entity_type"] = entity_type
        try:
            results = self.neo4j_client.execute_query(query, parameters=params)
            self.cache.set(cache_key, results)
            return results
        except Exception as e:
            logger.error(f"Failed to search entities: {e}")
            return []

    def get_shortest_path(
        self, source_doc_id: int, target_doc_id: int, tenant_id: int, kb_id: int, max_depth: int = 3
    ) -> dict[str, Any] | None:
        query = "\n        MATCH (d1:Document {id: $source_id, tenant_id: $tenant_id, kb_id: $kb_id})\n        MATCH (d2:Document {id: $target_id, tenant_id: $tenant_id, kb_id: $kb_id})\n        MATCH path = shortestPath((d1)-[*..{max_depth}]-(d2))\n        RETURN\n            [node in nodes(path) | node.id] as node_ids,\n            [rel in relationships(path) | type(rel)] as relation_types,\n            length(path) as path_length\n        "
        try:
            results = self.neo4j_client.execute_query(
                query.replace("{max_depth}", str(max_depth)),
                parameters={
                    "source_id": source_doc_id,
                    "target_id": target_doc_id,
                    "tenant_id": tenant_id,
                    "kb_id": kb_id,
                },
            )
            return results[0] if results else None
        except Exception as e:
            logger.error(f"Failed to find shortest path: {e}")
            return None


# Module-level service instance; stays None until get_graph_query_service() builds it.
graph_query_service: GraphQueryService | None = None


def get_graph_query_service() -> GraphQueryService:
    """Return the module-level ``GraphQueryService``, creating it on the first call.

    The ``neo4j_client`` object is imported inside the function, so the import
    only happens when the service is first built.
    """
    global graph_query_service
    if graph_query_service is not None:
        return graph_query_service

    from app.services.graph.neo4j_client import neo4j_client as shared_client

    graph_query_service = GraphQueryService(shared_client)
    return graph_query_service
