from typing import Any

from app.services.graph.neo4j_client import Neo4jClient

from common_logging import get_logger

# Module-level logger obtained from common_logging, keyed by this module's __name__.
logger = get_logger(__name__)




class RelationExtractor:
    """Write relationship edges between Document and Entity nodes in Neo4j.

    Every query is scoped to a single ``tenant_id`` / ``kb_id`` pair. All
    public methods are best-effort: exceptions raised while talking to Neo4j
    are logged and swallowed rather than propagated, and every method
    returns ``None``.
    """

    # Link a document to the other documents of the same tenant/KB that share
    # at least ``$min_shared`` entities with it. DISTINCT keeps parallel
    # CONTAINS edges from inflating the overlap count. The MERGE pattern has no
    # direction, so an existing SIMILAR_TO edge in either direction is reused.
    # NOTE: similarity is shared/10.0 and is therefore not capped at 1.0.
    _DOCUMENT_RELATIONS_QUERY = """
        MATCH (d1:Document {id: $doc_id, tenant_id: $tenant_id, kb_id: $kb_id})
        MATCH (d1)-[:CONTAINS]->(e:Entity)
        MATCH (e)<-[:CONTAINS]-(d2:Document)
        WHERE d2.tenant_id = $tenant_id
          AND d2.kb_id = $kb_id
          AND d2.id <> $doc_id
        WITH d1, d2, count(DISTINCT e) AS shared_entities
        WHERE shared_entities >= $min_shared
        MERGE (d1)-[r:SIMILAR_TO]-(d2)
        SET r.similarity = toFloat(shared_entities) / 10.0,
            r.method = 'entity_overlap'
        RETURN d2.id AS related_doc_id, shared_entities
        """

    # Upsert one typed RELATED_TO edge between two entities of the same
    # tenant/KB. The SET overwrites confidence/document_id on an existing
    # edge, so the most recent write wins.
    _ENTITY_RELATION_QUERY = """
        MATCH (e1:Entity {name: $source, tenant_id: $tenant_id, kb_id: $kb_id})
        MATCH (e2:Entity {name: $target, tenant_id: $tenant_id, kb_id: $kb_id})
        MERGE (e1)-[r:RELATED_TO {type: $rel_type}]->(e2)
        SET r.confidence = $confidence,
            r.document_id = $doc_id
        """

    # Count, per entity pair, the documents containing both entities.
    # ``e1.name < e2.name`` visits each unordered pair once. DISTINCT keeps
    # parallel CONTAINS edges from double-counting a document.
    _CO_OCCURRENCE_QUERY = """
        MATCH (d:Document {tenant_id: $tenant_id, kb_id: $kb_id})
        MATCH (d)-[:CONTAINS]->(e1:Entity)
        MATCH (d)-[:CONTAINS]->(e2:Entity)
        WHERE e1.name < e2.name
        WITH e1, e2, collect(DISTINCT d.id) AS docs, count(DISTINCT d) AS co_count
        WHERE co_count >= $min_count
        MERGE (e1)-[r:CO_OCCURS_WITH]-(e2)
        SET r.count = co_count,
            r.documents = docs
        """

    # Add a REFERENCES edge to every other document whose title occurs
    # verbatim in $content. Empty/whitespace-only titles are excluded because
    # ``x CONTAINS ''`` is always true and would link to every such document.
    _DOCUMENT_REFERENCES_QUERY = """
        MATCH (d1:Document {id: $doc_id, tenant_id: $tenant_id, kb_id: $kb_id})
        MATCH (d2:Document {tenant_id: $tenant_id, kb_id: $kb_id})
        WHERE d2.id <> $doc_id
          AND d2.title IS NOT NULL
          AND trim(d2.title) <> ''
          AND $content CONTAINS d2.title
        MERGE (d1)-[r:REFERENCES]->(d2)
        SET r.context = 'title_mention',
            r.confidence = 0.8
        RETURN d2.id AS referenced_doc_id
        """

    def __init__(self, neo4j_client: Neo4jClient):
        """Store the Neo4j client used for all writes.

        Args:
            neo4j_client: Client exposing ``execute_write(query, parameters=...)``.
        """
        self.neo4j_client = neo4j_client

    def create_document_relations(
        self,
        document_id: int,
        tenant_id: int,
        kb_id: int,
        *,
        min_shared_entities: int = 2,
    ):
        """Create SIMILAR_TO edges from a document to entity-overlapping peers.

        Args:
            document_id: ``id`` of the source Document node.
            tenant_id: Tenant scope applied to both documents.
            kb_id: Knowledge-base scope applied to both documents.
            min_shared_entities: Minimum number of distinct shared entities
                required before two documents are linked (default 2).
        """
        try:
            results = self.neo4j_client.execute_write(
                self._DOCUMENT_RELATIONS_QUERY,
                parameters={
                    "doc_id": document_id,
                    "tenant_id": tenant_id,
                    "kb_id": kb_id,
                    "min_shared": min_shared_entities,
                },
            )
            # Guard against an empty/None result so a successful write is not
            # reported as a failure by a TypeError from len().
            logger.info(f"Created {len(results or [])} document relations for doc {document_id}")
        except Exception as e:
            logger.error(f"Failed to create document relations: {e}")

    def create_entity_relations(
        self,
        entities: list[dict[str, Any]],
        relations: list[dict[str, Any]],
        document_id: int,
        tenant_id: int,
        kb_id: int,
    ):
        """Create RELATED_TO edges for relations whose endpoints are known entities.

        Each relation is written independently. A failing or malformed relation
        (e.g. missing ``"type"`` or ``"confidence"``) is logged and skipped
        without affecting the rest.

        Args:
            entities: Entity dicts. Only their ``"name"`` key is read, and
                entities without a name are ignored.
            relations: Relation dicts with ``"source"``, ``"target"``,
                ``"type"`` and ``"confidence"`` keys.
            document_id: Stored on each edge as ``document_id``.
            tenant_id: Tenant scope used to match both Entity nodes.
            kb_id: Knowledge-base scope used to match both Entity nodes.
        """
        if not relations:
            return
        entity_names = {
            name for entity in entities if (name := entity.get("name")) is not None
        }
        for relation in relations:
            source = relation.get("source")
            target = relation.get("target")
            # Only link entities that were extracted alongside these relations.
            if source not in entity_names or target not in entity_names:
                continue
            try:
                # Parameters are built inside the try so a missing key is
                # handled like any other per-relation failure.
                parameters = {
                    "source": source,
                    "target": target,
                    "rel_type": relation["type"],
                    "confidence": relation["confidence"],
                    "doc_id": document_id,
                    "tenant_id": tenant_id,
                    "kb_id": kb_id,
                }
                self.neo4j_client.execute_write(
                    self._ENTITY_RELATION_QUERY, parameters=parameters
                )
            except Exception as e:
                logger.error(f"Failed to create entity relation {source}->{target}: {e}")

    def create_co_occurrence_relations(
        self, tenant_id: int, kb_id: int, min_co_occurrence: int = 2
    ):
        """Upsert CO_OCCURS_WITH edges for entity pairs that share documents.

        Existing edges whose pair now falls below the threshold are not
        removed by this method.

        Args:
            tenant_id: Tenant scope of the documents scanned.
            kb_id: Knowledge-base scope of the documents scanned.
            min_co_occurrence: Minimum number of distinct documents that must
                contain both entities.
        """
        try:
            self.neo4j_client.execute_write(
                self._CO_OCCURRENCE_QUERY,
                parameters={"tenant_id": tenant_id, "kb_id": kb_id, "min_count": min_co_occurrence},
            )
            logger.info(f"Created co-occurrence relations for kb {kb_id}")
        except Exception as e:
            logger.error(f"Failed to create co-occurrence relations: {e}")

    def infer_document_references(self, document_id: int, content: str, tenant_id: int, kb_id: int):
        """Create REFERENCES edges to documents whose title appears in ``content``.

        Matching is a case-sensitive substring test (Cypher ``CONTAINS``).
        Documents with null, empty or whitespace-only titles are never matched.

        Args:
            document_id: ``id`` of the referencing Document node.
            content: Text of the referencing document. An empty value skips
                the query entirely.
            tenant_id: Tenant scope for both documents.
            kb_id: Knowledge-base scope for both documents.
        """
        if not content:
            # No non-empty title can be contained in empty content.
            return
        try:
            results = self.neo4j_client.execute_write(
                self._DOCUMENT_REFERENCES_QUERY,
                parameters={
                    "doc_id": document_id,
                    "content": content,
                    "tenant_id": tenant_id,
                    "kb_id": kb_id,
                },
            )
            if results:
                logger.info(f"Inferred {len(results)} document references for doc {document_id}")
        except Exception as e:
            logger.error(f"Failed to infer document references: {e}")


relation_extractor: RelationExtractor | None = None


def get_relation_extractor() -> RelationExtractor:
    """Return the module-wide RelationExtractor, building it on first use.

    The shared ``neo4j_client`` is imported only when the extractor is first
    needed (presumably to avoid an import cycle or import-time connection
    setup; confirm). Later calls return the cached instance.
    """
    global relation_extractor
    if relation_extractor is not None:
        return relation_extractor

    from app.services.graph.neo4j_client import neo4j_client

    relation_extractor = RelationExtractor(neo4j_client)
    return relation_extractor
