from typing import Any

from app.services.graph.entity_extractor import EntityExtractor
from app.services.graph.graph_cache import get_graph_cache
from app.services.graph.neo4j_client import Neo4jClient
from app.services.graph.relation_extractor import RelationExtractor

from common_logging import get_logger

logger = get_logger(__name__)




class IncrementalGraphUpdater:
    """Apply per-document changes to the knowledge graph without a full rebuild.

    Every Cypher query issued here is scoped by ``tenant_id`` and ``kb_id``, so
    documents, entities, tags and categories of different tenants / knowledge
    bases are kept apart. Writes go through ``Neo4jClient.execute_write`` and,
    after a successful change, the document's entry in the graph cache is
    invalidated.
    """

    def __init__(
        self,
        neo4j_client: Neo4jClient,
        entity_extractor: EntityExtractor,
        relation_extractor: RelationExtractor,
    ):
        """Store collaborators and obtain the graph cache.

        Args:
            neo4j_client: Client used to run the Cypher write queries.
            entity_extractor: Expected to return a mapping with ``entities`` and
                ``relations`` lists from a document's title and content.
            relation_extractor: Creates entity-to-entity relations and
                document-to-document (``SIMILAR_TO``) relations.
        """
        self.neo4j_client = neo4j_client
        self.entity_extractor = entity_extractor
        self.relation_extractor = relation_extractor
        self.cache = get_graph_cache()

    def update_document(
        self,
        document_id: int,
        title: str,
        content: str,
        summary: str | None,
        tenant_id: int,
        kb_id: int,
        category_id: int | None = None,
        tag_ids: list[int] | None = None,
    ) -> dict[str, Any]:
        """Re-extract a document's entities and rewrite its part of the graph.

        Entity extraction runs before any graph mutation, so an extractor
        failure leaves the document's existing graph data untouched.

        Args:
            document_id: Id of the Document node to (re)build.
            title: Document title; stored on the node and fed to the extractor.
            content: Document body fed to the extractor.
            summary: Stored on the node; ``None`` is stored as ``""``.
            tenant_id: Tenant scope for every query.
            kb_id: Knowledge-base scope for every query.
            category_id: New category; ``None`` leaves the category unchanged.
            tag_ids: New tag set; ``None`` leaves tags unchanged, ``[]`` clears them.

        Returns:
            ``{"success": True, "document_id", "entity_count", "relation_count"}``
            on success, where ``entity_count`` is the number of entities written
            and ``relation_count`` the number of relations the extractor returned.
            On any exception: ``{"success": False, "document_id", "error"}``.
            Exceptions are logged and reported, never raised.
        """
        try:
            logger.info(f"Incrementally updating graph for document {document_id}")
            # Extract first: it does not touch the graph, so if it fails the
            # document's current entity links survive instead of being deleted
            # and never rebuilt.
            extraction_result = self.entity_extractor.extract_entities(title=title, content=content)
            # `or []` also covers a key that is present but set to None.
            entities = extraction_result.get("entities") or []
            relations = extraction_result.get("relations") or []

            self._remove_document_entities(document_id, tenant_id, kb_id)
            self._update_document_node(document_id, title, summary, tenant_id, kb_id)
            entity_count = self._create_entity_nodes(entities, document_id, tenant_id, kb_id)
            # Entities this document no longer mentions just lost their CONTAINS
            # edge; remove the ones nothing else contains (as delete_document
            # does) so they don't accumulate. Done before create_entity_relations
            # so nothing that call creates can be mistaken for an orphan.
            self._cleanup_orphaned_entities(tenant_id, kb_id)
            self.relation_extractor.create_entity_relations(
                entities, relations, document_id, tenant_id, kb_id
            )
            self._update_document_relations(document_id, tenant_id, kb_id)
            if tag_ids is not None:
                self._update_tags(document_id, tag_ids, tenant_id, kb_id)
            if category_id is not None:
                self._update_category(document_id, category_id, tenant_id, kb_id)
            self.cache.invalidate_document(document_id, tenant_id, kb_id)
            logger.info(
                f"Graph updated successfully for document {document_id}: {entity_count} entities"
            )
            return {
                "success": True,
                "document_id": document_id,
                "entity_count": entity_count,
                "relation_count": len(relations),
            }
        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception(f"Failed to update graph for document {document_id}: {e}")
            return {"success": False, "document_id": document_id, "error": str(e)}

    def delete_document(self, document_id: int, tenant_id: int, kb_id: int) -> dict[str, Any]:
        """Remove a Document node and all its relationships, then prune orphans.

        Returns:
            ``{"success": True, "document_id"}`` on success, otherwise
            ``{"success": False, "document_id", "error"}``. Exceptions are
            logged and reported, never raised.
        """
        try:
            logger.info(f"Removing document {document_id} from graph")
            query = "\n            MATCH (d:Document {id: $doc_id, tenant_id: $tenant_id, kb_id: $kb_id})\n            DETACH DELETE d\n            "
            self.neo4j_client.execute_write(
                query, parameters={"doc_id": document_id, "tenant_id": tenant_id, "kb_id": kb_id}
            )
            self._cleanup_orphaned_entities(tenant_id, kb_id)
            self.cache.invalidate_document(document_id, tenant_id, kb_id)
            logger.info(f"Document {document_id} removed from graph")
            return {"success": True, "document_id": document_id}
        except Exception as e:
            logger.exception(f"Failed to delete document {document_id} from graph: {e}")
            return {"success": False, "document_id": document_id, "error": str(e)}

    def _remove_document_entities(self, document_id: int, tenant_id: int, kb_id: int):
        """Delete the document's CONTAINS edges; the Entity nodes themselves are kept."""
        query = "\n        MATCH (d:Document {id: $doc_id, tenant_id: $tenant_id, kb_id: $kb_id})\n        MATCH (d)-[r:CONTAINS]->(e:Entity)\n        DELETE r\n        "
        self.neo4j_client.execute_write(
            query, parameters={"doc_id": document_id, "tenant_id": tenant_id, "kb_id": kb_id}
        )

    def _update_document_node(
        self, document_id: int, title: str, summary: str | None, tenant_id: int, kb_id: int
    ):
        """Upsert the Document node and set its title, summary and ``updated_at``.

        A ``None`` summary is stored as an empty string.
        """
        query = "\n        MERGE (d:Document {id: $doc_id, tenant_id: $tenant_id, kb_id: $kb_id})\n        SET d.title = $title,\n            d.summary = $summary,\n            d.updated_at = datetime()\n        "
        self.neo4j_client.execute_write(
            query,
            parameters={
                "doc_id": document_id,
                "title": title,
                "summary": summary or "",
                "tenant_id": tenant_id,
                "kb_id": kb_id,
            },
        )

    def _create_entity_nodes(
        self, entities: list[dict[str, Any]], document_id: int, tenant_id: int, kb_id: int
    ) -> int:
        """Upsert Entity nodes and link each to the document with a CONTAINS edge.

        Entities are keyed by ``(name, tenant_id, kb_id)``. Entries without a
        usable ``name`` are skipped.

        Returns:
            The number of entities sent to Neo4j (after skipping nameless ones).
        """
        if not entities:
            return 0
        # MERGE cannot match on a null property: one nameless entity would make
        # Neo4j reject the whole UNWIND batch, so drop such entries up front.
        named_entities = [entity for entity in entities if entity.get("name")]
        if not named_entities:
            return 0
        # NOTE(review): e.confidence is populated from entity.salience — confirm
        # this is intended rather than a separate confidence field.
        query = "\n        UNWIND $entities AS entity\n        MERGE (e:Entity {\n            name: entity.name,\n            tenant_id: $tenant_id,\n            kb_id: $kb_id\n        })\n        SET e.type = entity.type,\n            e.description = entity.description,\n            e.confidence = entity.salience\n        WITH e, entity\n        MATCH (d:Document {id: $doc_id, tenant_id: $tenant_id, kb_id: $kb_id})\n        MERGE (d)-[r:CONTAINS]->(e)\n        SET r.salience = entity.salience,\n            r.frequency = 1\n        "
        self.neo4j_client.execute_write(
            query,
            parameters={
                "entities": named_entities,
                "doc_id": document_id,
                "tenant_id": tenant_id,
                "kb_id": kb_id,
            },
        )
        return len(named_entities)

    def _update_document_relations(self, document_id: int, tenant_id: int, kb_id: int):
        """Drop every SIMILAR_TO edge touching the document (either direction) and recompute them."""
        remove_query = "\n        MATCH (d:Document {id: $doc_id, tenant_id: $tenant_id, kb_id: $kb_id})\n        MATCH (d)-[r:SIMILAR_TO]-()\n        DELETE r\n        "
        self.neo4j_client.execute_write(
            remove_query, parameters={"doc_id": document_id, "tenant_id": tenant_id, "kb_id": kb_id}
        )
        self.relation_extractor.create_document_relations(document_id, tenant_id, kb_id)

    def _update_tags(self, document_id: int, tag_ids: list[int], tenant_id: int, kb_id: int):
        """Replace the document's HAS_TAG edges with edges to ``tag_ids``.

        Tag nodes are created on demand via MERGE. An empty ``tag_ids`` only
        clears the existing edges.
        """
        remove_query = "\n        MATCH (d:Document {id: $doc_id, tenant_id: $tenant_id, kb_id: $kb_id})\n        MATCH (d)-[r:HAS_TAG]->()\n        DELETE r\n        "
        self.neo4j_client.execute_write(
            remove_query, parameters={"doc_id": document_id, "tenant_id": tenant_id, "kb_id": kb_id}
        )
        if tag_ids:
            add_query = "\n            MATCH (d:Document {id: $doc_id, tenant_id: $tenant_id, kb_id: $kb_id})\n            UNWIND $tag_ids AS tag_id\n            MERGE (t:Tag {id: tag_id, tenant_id: $tenant_id, kb_id: $kb_id})\n            MERGE (d)-[:HAS_TAG]->(t)\n            "
            self.neo4j_client.execute_write(
                add_query,
                parameters={
                    "doc_id": document_id,
                    "tag_ids": tag_ids,
                    "tenant_id": tenant_id,
                    "kb_id": kb_id,
                },
            )

    def _update_category(self, document_id: int, category_id: int, tenant_id: int, kb_id: int):
        """Point the document's IN_CATEGORY edge at ``category_id``, creating the Category if needed."""
        remove_query = "\n        MATCH (d:Document {id: $doc_id, tenant_id: $tenant_id, kb_id: $kb_id})\n        MATCH (d)-[r:IN_CATEGORY]->()\n        DELETE r\n        "
        self.neo4j_client.execute_write(
            remove_query, parameters={"doc_id": document_id, "tenant_id": tenant_id, "kb_id": kb_id}
        )
        add_query = "\n        MATCH (d:Document {id: $doc_id, tenant_id: $tenant_id, kb_id: $kb_id})\n        MERGE (c:Category {id: $cat_id, tenant_id: $tenant_id, kb_id: $kb_id})\n        MERGE (d)-[:IN_CATEGORY]->(c)\n        "
        self.neo4j_client.execute_write(
            add_query,
            parameters={
                "doc_id": document_id,
                "cat_id": category_id,
                "tenant_id": tenant_id,
                "kb_id": kb_id,
            },
        )

    def _cleanup_orphaned_entities(self, tenant_id: int, kb_id: int):
        """Delete this tenant/kb's Entity nodes with no incoming CONTAINS edge, plus all their relationships.

        Scans every Entity in the tenant/kb scope.
        """
        query = "\n        MATCH (e:Entity {tenant_id: $tenant_id, kb_id: $kb_id})\n        WHERE NOT (e)<-[:CONTAINS]-()\n        DETACH DELETE e\n        "
        self.neo4j_client.execute_write(query, parameters={"tenant_id": tenant_id, "kb_id": kb_id})


# Module-level singleton; stays None until get_incremental_updater() first runs.
incremental_updater: IncrementalGraphUpdater | None = None


def get_incremental_updater() -> IncrementalGraphUpdater:
    """Return the shared IncrementalGraphUpdater, building it on the first call.

    The updater is wired to the module-level ``neo4j_client`` and
    ``entity_extractor`` objects and to the relation extractor returned by
    ``get_relation_extractor()``. They are imported inside the function rather
    than at module top, presumably to defer their construction or to avoid an
    import cycle (not verifiable from this file).

    NOTE(review): the None check and the assignment are not guarded by a lock,
    so concurrent first calls could each build an updater.
    """
    global incremental_updater
    if incremental_updater is not None:
        return incremental_updater

    from app.services.graph.entity_extractor import entity_extractor as shared_entity_extractor
    from app.services.graph.neo4j_client import neo4j_client as shared_neo4j_client
    from app.services.graph.relation_extractor import get_relation_extractor

    shared_relation_extractor = get_relation_extractor()
    incremental_updater = IncrementalGraphUpdater(
        neo4j_client=shared_neo4j_client,
        entity_extractor=shared_entity_extractor,
        relation_extractor=shared_relation_extractor,
    )
    return incremental_updater
