from difflib import SequenceMatcher
from typing import Any

from app.services.graph.graph_cache import get_graph_cache
from app.services.graph.neo4j_client import Neo4jClient

from common_logging import get_logger

# Module-level logger from the project's common_logging package, named after this module.
logger = get_logger(__name__)




class EntityLinker:
    """Find and merge near-duplicate ``Entity`` nodes inside one knowledge base.

    Candidate pairs come from Neo4j (same ``type``, one name contained in the
    other, case-insensitively). They are scored locally with
    ``difflib.SequenceMatcher``, and pairs scoring at or above a threshold are
    merged. A merge keeps both nodes: the alias is linked from the canonical
    node with ``ALIAS_OF`` and flagged with ``is_alias`` / ``canonical_name``.
    """

    def __init__(self, neo4j_client: Neo4jClient):
        self.neo4j_client = neo4j_client
        self.cache = get_graph_cache()

    def link_entities(
        self, tenant_id: int, kb_id: int, similarity_threshold: float = 0.85
    ) -> dict[str, Any]:
        """Run one entity-linking pass over a knowledge base.

        Args:
            tenant_id: Tenant that owns the knowledge base.
            kb_id: Knowledge base to process.
            similarity_threshold: Minimum score from
                ``_calculate_entity_similarity`` for a pair to be merged.

        Returns:
            On success: ``{"success": True, "kb_id", "candidates_evaluated",
            "entities_merged"}``, where ``entities_merged`` counts only merges
            that were actually applied. On any exception: ``{"success": False,
            "kb_id", "error"}``. The exception is logged, not re-raised.
        """
        try:
            logger.info(f"Starting entity linking for kb {kb_id}")
            candidates = self._find_candidate_pairs(tenant_id, kb_id)
            logger.info(f"Found {len(candidates)} candidate pairs")
            merged_count = 0
            # Names demoted to alias during this pass. Any later pair touching
            # one of them is skipped, so we never build alias-of-alias chains
            # or merge the same node twice in one run.
            aliased: set[str] = set()
            for entity1, entity2 in candidates:
                if entity1["name"] in aliased or entity2["name"] in aliased:
                    continue
                similarity = self._calculate_entity_similarity(entity1, entity2)
                if similarity < similarity_threshold:
                    continue
                alias_name = self._merge_entities(entity1, entity2, tenant_id, kb_id)
                if alias_name is not None:
                    aliased.add(alias_name)
                    merged_count += 1
            self._create_canonical_names(tenant_id, kb_id)
            self.cache.invalidate_knowledge_base(kb_id, tenant_id)
            logger.info(f"Entity linking completed for kb {kb_id}: {merged_count} entities merged")
            return {
                "success": True,
                "kb_id": kb_id,
                "candidates_evaluated": len(candidates),
                "entities_merged": merged_count,
            }
        except Exception as e:
            logger.error(f"Entity linking failed for kb {kb_id}: {e}")
            return {"success": False, "kb_id": kb_id, "error": str(e)}

    def _find_candidate_pairs(
        self, tenant_id: int, kb_id: int
    ) -> list[tuple[dict[str, Any], dict[str, Any]]]:
        """Return up to 1000 same-type entity pairs whose names contain one another.

        Containment is case-insensitive. Entities already flagged ``is_alias``
        are excluded, so re-running linking does not merge them again.
        Missing descriptions are normalised to ``""``.
        """
        # Only plain CONTAINS is used here. An earlier version also built a
        # regex from raw entity names, so any name with regex metacharacters
        # (e.g. "C++") made the whole query fail. That regex was redundant
        # with the case-insensitive CONTAINS checks anyway.
        query = """
        MATCH (e1:Entity {tenant_id: $tenant_id, kb_id: $kb_id})
        MATCH (e2:Entity {tenant_id: $tenant_id, kb_id: $kb_id})
        WHERE e1.type = e2.type
          AND id(e1) < id(e2)
          AND coalesce(e1.is_alias, false) <> true
          AND coalesce(e2.is_alias, false) <> true
          AND (
            toLower(e1.name) CONTAINS toLower(e2.name)
            OR toLower(e2.name) CONTAINS toLower(e1.name)
          )
        RETURN
            e1.name as name1,
            e1.type as type1,
            e1.description as desc1,
            e2.name as name2,
            e2.type as type2,
            e2.description as desc2
        LIMIT 1000
        """
        results = self.neo4j_client.execute_query(
            query, parameters={"tenant_id": tenant_id, "kb_id": kb_id}
        )
        candidates = []
        for r in results:
            entity1 = {"name": r["name1"], "type": r["type1"], "description": r.get("desc1") or ""}
            entity2 = {"name": r["name2"], "type": r["type2"], "description": r.get("desc2") or ""}
            candidates.append((entity1, entity2))
        return candidates

    def _calculate_entity_similarity(
        self, entity1: dict[str, Any], entity2: dict[str, Any]
    ) -> float:
        """Score two entities in [0, 1].

        The score is ``0.6 * name_ratio + 0.3 * description_ratio +
        0.1 * type_match``. Ratios come from ``SequenceMatcher`` on lower-cased
        text. The description ratio is 0 unless both descriptions are non-empty.

        NOTE(review): when either description is empty the maximum score is
        0.7, which is below the default threshold of 0.85. Such pairs can
        therefore never merge by default. Confirm this is intended.
        """
        name_sim = SequenceMatcher(
            None, entity1["name"].lower(), entity2["name"].lower()
        ).ratio()
        # `or ""` also covers a description key explicitly set to None.
        desc1 = (entity1.get("description") or "").lower()
        desc2 = (entity2.get("description") or "").lower()
        desc_sim = SequenceMatcher(None, desc1, desc2).ratio() if desc1 and desc2 else 0.0
        type_match = 1.0 if entity1["type"] == entity2["type"] else 0.0
        return 0.6 * name_sim + 0.3 * desc_sim + 0.1 * type_match

    def _merge_entities(
        self, entity1: dict[str, Any], entity2: dict[str, Any], tenant_id: int, kb_id: int
    ) -> str | None:
        """Merge two entities, keeping the better-connected one as canonical.

        The node with more distinct relationships (ties go to ``entity1``)
        becomes canonical. The other node is marked as its alias, and its
        ``CONTAINS`` / ``RELATED_TO`` / ``CO_OCCURS_WITH`` relationships are
        copied onto the canonical node with their direction preserved. This
        copy uses ``apoc.create.relationship``. If that query fails, the code
        falls back to ``_simple_merge``, which only records the alias link.

        Returns:
            The name of the entity demoted to alias, or ``None`` if the
            canonical/alias lookup returned no rows (nothing was merged).
        """
        query = """
        MATCH (e1:Entity {name: $name1, tenant_id: $tenant_id, kb_id: $kb_id})
        MATCH (e2:Entity {name: $name2, tenant_id: $tenant_id, kb_id: $kb_id})
        OPTIONAL MATCH (e1)-[r1]-()
        OPTIONAL MATCH (e2)-[r2]-()
        WITH e1, e2, count(DISTINCT r1) as conn1, count(DISTINCT r2) as conn2
        RETURN
            CASE WHEN conn1 >= conn2 THEN e1.name ELSE e2.name END as canonical,
            CASE WHEN conn1 >= conn2 THEN e2.name ELSE e1.name END as alias
        """
        result = self.neo4j_client.execute_query(
            query,
            parameters={
                "name1": entity1["name"],
                "name2": entity2["name"],
                "tenant_id": tenant_id,
                "kb_id": kb_id,
            },
        )
        if not result:
            return None
        canonical_name = result[0]["canonical"]
        alias_name = result[0]["alias"]
        # The alias is marked first, so marking happens even when there is
        # nothing to copy. The copy then uses MATCH (not OPTIONAL MATCH), so no
        # null relationship ever reaches apoc.create.relationship. src/dst are
        # chosen from startNode(r) so incoming relationships keep their
        # direction.
        merge_query = """
        MATCH (canonical:Entity {name: $canonical, tenant_id: $tenant_id, kb_id: $kb_id})
        MATCH (alias:Entity {name: $alias, tenant_id: $tenant_id, kb_id: $kb_id})

        // Create alias relationship
        MERGE (canonical)-[:ALIAS_OF]->(alias)
        SET alias.is_alias = true,
            alias.canonical_name = canonical.name

        // Copy all relationships from alias to canonical, preserving direction
        WITH canonical, alias
        MATCH (alias)-[r:CONTAINS|RELATED_TO|CO_OCCURS_WITH]-(other)
        WHERE other <> canonical
        WITH type(r) AS rel_type,
             properties(r) AS props,
             CASE WHEN startNode(r) = alias THEN canonical ELSE other END AS src,
             CASE WHEN startNode(r) = alias THEN other ELSE canonical END AS dst
        CALL apoc.create.relationship(src, rel_type, props, dst) YIELD rel

        RETURN count(rel) AS transferred
        """
        try:
            self.neo4j_client.execute_write(
                merge_query,
                parameters={
                    "canonical": canonical_name,
                    "alias": alias_name,
                    "tenant_id": tenant_id,
                    "kb_id": kb_id,
                },
            )
            logger.info(f"Merged entity '{alias_name}' into '{canonical_name}'")
        except Exception as e:
            logger.warning(f"APOC merge failed, using simple merge: {e}")
            self._simple_merge(canonical_name, alias_name, tenant_id, kb_id)
        return alias_name

    def _simple_merge(self, canonical_name: str, alias_name: str, tenant_id: int, kb_id: int) -> None:
        """Mark ``alias_name`` as an alias of ``canonical_name`` without copying relationships.

        This is the fallback used by ``_merge_entities`` when the APOC-based
        merge raises.
        """
        query = """
        MATCH (canonical:Entity {name: $canonical, tenant_id: $tenant_id, kb_id: $kb_id})
        MATCH (alias:Entity {name: $alias, tenant_id: $tenant_id, kb_id: $kb_id})

        // Mark as alias
        SET alias.is_alias = true,
            alias.canonical_name = canonical.name

        // Create alias relationship
        MERGE (canonical)-[:ALIAS_OF]->(alias)
        """
        self.neo4j_client.execute_write(
            query,
            parameters={
                "canonical": canonical_name,
                "alias": alias_name,
                "tenant_id": tenant_id,
                "kb_id": kb_id,
            },
        )

    def _create_canonical_names(self, tenant_id: int, kb_id: int) -> None:
        """Set ``canonical_name = name`` on every entity not flagged as an alias.

        The old predicate ``NOT e.is_alias = true`` evaluates to null (and so
        filters the row out) when ``is_alias`` was never set. That skipped
        exactly the entities this method must update. ``coalesce`` treats a
        missing flag as ``false``.
        """
        query = """
        MATCH (e:Entity {tenant_id: $tenant_id, kb_id: $kb_id})
        WHERE coalesce(e.is_alias, false) <> true
        SET e.canonical_name = e.name
        """
        self.neo4j_client.execute_write(query, parameters={"tenant_id": tenant_id, "kb_id": kb_id})

    def resolve_entity(self, entity_name: str, tenant_id: int, kb_id: int) -> str | None:
        """Return the canonical name for ``entity_name``, or ``None`` if no such entity exists.

        Falls back to the entity's own name when ``canonical_name`` is unset.
        """
        query = """
        MATCH (e:Entity {name: $name, tenant_id: $tenant_id, kb_id: $kb_id})
        RETURN COALESCE(e.canonical_name, e.name) as canonical_name
        """
        result = self.neo4j_client.execute_query(
            query, parameters={"name": entity_name, "tenant_id": tenant_id, "kb_id": kb_id}
        )
        return result[0]["canonical_name"] if result else None

    def get_entity_aliases(self, canonical_name: str, tenant_id: int, kb_id: int) -> list[str]:
        """Return the names of entities linked from ``canonical_name`` by ``ALIAS_OF``."""
        query = """
        MATCH (canonical:Entity {name: $name, tenant_id: $tenant_id, kb_id: $kb_id})
        MATCH (canonical)-[:ALIAS_OF]->(alias:Entity)
        RETURN alias.name as alias_name
        """
        results = self.neo4j_client.execute_query(
            query, parameters={"name": canonical_name, "tenant_id": tenant_id, "kb_id": kb_id}
        )
        return [r["alias_name"] for r in results]

    def get_linking_stats(self, tenant_id: int, kb_id: int) -> dict[str, Any]:
        """Return entity/alias counts for a knowledge base.

        The keys are ``total_entities``, ``alias_count``, ``canonical_count``
        (distinct non-null ``canonical_name`` values) and ``unique_entities``
        (total minus aliases). All are zero if the query returns no row.
        """
        query = """
        MATCH (e:Entity {tenant_id: $tenant_id, kb_id: $kb_id})
        WITH
            count(e) as total_entities,
            sum(CASE WHEN e.is_alias = true THEN 1 ELSE 0 END) as alias_count,
            count(DISTINCT e.canonical_name) as canonical_count
        RETURN
            total_entities,
            alias_count,
            canonical_count,
            total_entities - alias_count as unique_entities
        """
        result = self.neo4j_client.execute_query(
            query, parameters={"tenant_id": tenant_id, "kb_id": kb_id}
        )
        if result:
            return result[0]
        return {"total_entities": 0, "alias_count": 0, "canonical_count": 0, "unique_entities": 0}


# Shared EntityLinker instance; stays None until get_entity_linker() first runs.
entity_linker: EntityLinker | None = None


def get_entity_linker() -> EntityLinker:
    """Return the shared ``EntityLinker``, constructing it on the first call.

    The Neo4j client is imported inside the function rather than at module
    top level. This is presumably to avoid an import cycle or import-time side
    effects; confirm. No locking is performed around the check-and-assign.
    """
    global entity_linker
    if entity_linker is not None:
        return entity_linker

    from app.services.graph.neo4j_client import neo4j_client

    entity_linker = EntityLinker(neo4j_client)
    return entity_linker
