from typing import Any

from app.services.graph.entity_extractor import EntityExtractor
from app.services.graph.neo4j_client import Neo4jClient
from app.services.graph.relation_extractor import RelationExtractor
from common_logging import get_logger

# Module-level logger used by GraphBuilder for progress and failure messages.
logger = get_logger(__name__)


class GraphBuilder:
    """Build the Neo4j knowledge graph for documents, chunks and knowledge bases.

    The builder combines three collaborators: a Neo4j client that executes the
    writes, an entity extractor that produces entities/relations from document
    text, and a relation extractor that materialises relationships.
    """

    def __init__(
        self,
        neo4j_client: Neo4jClient,
        entity_extractor: EntityExtractor,
        relation_extractor: RelationExtractor,
    ):
        """Store the collaborators used by the build methods.

        Args:
            neo4j_client: Client used for every graph write.
            entity_extractor: Provides ``extract_entities``.
            relation_extractor: Provides the relation-creation helpers.
        """
        self.neo4j_client = neo4j_client
        self.entity_extractor = entity_extractor
        self.relation_extractor = relation_extractor

    @staticmethod
    def _normalize_tenant_id(tenant_id: int | None) -> int:
        """Map a missing tenant to ``0``, the value Document nodes are stored under."""
        return tenant_id if tenant_id is not None else 0

    def build_document_graph(
        self,
        document_id: int,
        title: str,
        content: str,
        summary: str | None,
        tenant_id: int,
        kb_id: int,
        category_id: int | None = None,
        tag_ids: list | None = None,
        db_session=None,
        doc_number: str | None = None,
        doc_status: str | None = None,
        supersedes_doc_ids: list[int] | None = None,
        parent_doc_id: int | None = None,
    ) -> dict[str, Any]:
        """Create or update the graph for a single document.

        Writes the Document node, extracts and stores entities, creates the
        entity/document relations, and links tags, category, superseded
        documents and the parent document when they are supplied.

        Args:
            document_id: Id of the document.
            title: Document title (also passed to the entity extractor).
            content: Document text used for extraction and reference inference.
            summary: Optional summary stored on the Document node.
            tenant_id: Owning tenant; ``None`` is treated as tenant ``0``.
            kb_id: Knowledge base the document belongs to.
            category_id: Category to link; falsy values (``None``/``0``) are skipped.
            tag_ids: Tags to link; empty/``None`` is skipped.
            db_session: Forwarded unchanged to the entity extractor.
            doc_number: Optional document number stored on the node.
            doc_status: Optional status stored on the node.
            supersedes_doc_ids: Documents this one supersedes.
            parent_doc_id: Parent document when this one is an attachment;
                falsy values are skipped.

        Returns:
            ``{"success": True, "document_id", "entity_count", "relation_count"}``
            on success, or ``{"success": False, "document_id", "error"}`` when
            any step raises. Exceptions are logged, never propagated.
        """
        try:
            logger.info(f"Building graph for document {document_id}: {title}")
            # The Document node is written under tenant 0 when no tenant is
            # given. Every later step has to use that same value: a null
            # property never matches in MATCH and is rejected by MERGE, so the
            # node that was just created could not be found otherwise.
            tenant_id = self._normalize_tenant_id(tenant_id)
            self._create_document_node(
                document_id,
                title,
                summary,
                tenant_id,
                kb_id,
                doc_number=doc_number,
                doc_status=doc_status,
            )
            extraction_result = self.entity_extractor.extract_entities(
                title=title, content=content, db_session=db_session
            )
            # ``or []`` also covers keys that are present but set to None, which
            # would otherwise break ``len(relations)`` after the graph is built.
            entities = extraction_result.get("entities") or []
            relations = extraction_result.get("relations") or []
            entity_count = self._create_entity_nodes(entities, document_id, tenant_id, kb_id)
            self.relation_extractor.create_entity_relations(
                entities, relations, document_id, tenant_id, kb_id
            )
            self.relation_extractor.create_document_relations(document_id, tenant_id, kb_id)
            self.relation_extractor.infer_document_references(
                document_id, content, tenant_id, kb_id
            )
            if tag_ids:
                self._link_tags(document_id, tag_ids, tenant_id, kb_id)
            if category_id:
                self._link_category(document_id, category_id, tenant_id, kb_id)
            if supersedes_doc_ids:
                for target_id in supersedes_doc_ids:
                    self.neo4j_client.create_supersedes_relation(document_id, target_id)
            if parent_doc_id:
                # Both directions are stored so the attachment can be reached
                # from the parent and vice versa.
                query = """
                MATCH (parent:Document {id: $parent_id, tenant_id: $tenant_id, kb_id: $kb_id})
                MATCH (att:Document {id: $doc_id, tenant_id: $tenant_id, kb_id: $kb_id})
                MERGE (parent)-[:HAS_ATTACHMENT]->(att)
                MERGE (att)-[:IS_ATTACHMENT_OF]->(parent)
                """
                self.neo4j_client.execute_write(
                    query,
                    parameters={
                        "doc_id": document_id,
                        "parent_id": parent_doc_id,
                        "tenant_id": tenant_id,
                        "kb_id": kb_id,
                    },
                )
            logger.info(
                f"Graph built successfully for document {document_id}: {entity_count} entities, {len(relations)} relations"
            )
            return {
                "success": True,
                "document_id": document_id,
                "entity_count": entity_count,
                "relation_count": len(relations),
            }
        except Exception as e:
            logger.error(f"Failed to build graph for document {document_id}: {e}")
            return {"success": False, "document_id": document_id, "error": str(e)}

    def _create_document_node(
        self,
        document_id: int,
        title: str,
        summary: str | None,
        tenant_id: int,
        kb_id: int,
        doc_number: str | None = None,
        doc_status: str | None = None,
    ):
        """Upsert the Document node and refresh its descriptive properties.

        Missing ``summary``/``doc_number`` are stored as empty strings and a
        missing ``doc_status`` as ``"effective"``.
        """
        tenant_id = self._normalize_tenant_id(tenant_id)
        query = """
        MERGE (d:Document {id: $doc_id, tenant_id: $tenant_id, kb_id: $kb_id})
        SET d.title = $title,
            d.summary = $summary,
            d.doc_number = $doc_number,
            d.doc_status = $doc_status,
            d.updated_at = datetime()
        """
        # NOTE(review): ``$tenant_id`` is not in ``parameters`` here; this
        # relies on ``execute_write`` binding it from the ``tenant_id`` keyword
        # - confirm against Neo4jClient.
        self.neo4j_client.execute_write(
            query,
            parameters={
                "doc_id": document_id,
                "title": title,
                "summary": summary or "",
                "kb_id": kb_id,
                "doc_number": doc_number or "",
                "doc_status": doc_status or "effective",
            },
            tenant_id=tenant_id,
        )

    def _create_entity_nodes(
        self, entities: list, document_id: int, tenant_id: int, kb_id: int
    ) -> int:
        """Upsert Entity nodes and link them to the document via ``CONTAINS``.

        Returns:
            The number of entities submitted (``0`` when the list is empty);
            this is the input length, not a count reported by the database.
        """
        if not entities:
            return 0
        # Keep the tenant consistent with the Document node this query matches.
        tenant_id = self._normalize_tenant_id(tenant_id)
        query = """
        UNWIND $entities AS entity
        MERGE (e:Entity {
            name: entity.name,
            tenant_id: $tenant_id,
            kb_id: $kb_id
        })
        SET e.type = entity.type,
            e.description = entity.description,
            e.confidence = entity.salience
        WITH e, entity
        MATCH (d:Document {id: $doc_id, tenant_id: $tenant_id, kb_id: $kb_id})
        MERGE (d)-[r:CONTAINS]->(e)
        SET r.salience = entity.salience,
            r.frequency = 1
        """
        self.neo4j_client.execute_write(
            query,
            parameters={"entities": entities, "doc_id": document_id, "kb_id": kb_id},
            tenant_id=tenant_id,
        )
        return len(entities)

    def _link_tags(self, document_id: int, tag_ids: list, tenant_id: int, kb_id: int):
        """Upsert a Tag node per id and attach each to the document via ``HAS_TAG``."""
        tenant_id = self._normalize_tenant_id(tenant_id)
        query = """
        MATCH (d:Document {id: $doc_id, tenant_id: $tenant_id, kb_id: $kb_id})
        UNWIND $tag_ids AS tag_id
        MERGE (t:Tag {id: tag_id, tenant_id: $tenant_id, kb_id: $kb_id})
        MERGE (d)-[:HAS_TAG]->(t)
        """
        self.neo4j_client.execute_write(
            query,
            parameters={
                "doc_id": document_id,
                "tag_ids": tag_ids,
                "tenant_id": tenant_id,
                "kb_id": kb_id,
            },
        )

    def _link_category(self, document_id: int, category_id: int, tenant_id: int, kb_id: int):
        """Upsert the Category node and attach the document via ``IN_CATEGORY``."""
        tenant_id = self._normalize_tenant_id(tenant_id)
        query = """
        MATCH (d:Document {id: $doc_id, tenant_id: $tenant_id, kb_id: $kb_id})
        MERGE (c:Category {id: $cat_id, tenant_id: $tenant_id, kb_id: $kb_id})
        MERGE (d)-[:IN_CATEGORY]->(c)
        """
        self.neo4j_client.execute_write(
            query,
            parameters={
                "doc_id": document_id,
                "cat_id": category_id,
                "tenant_id": tenant_id,
                "kb_id": kb_id,
            },
        )

    def build_chunk_graph(
        self, document_id: int, chunks: list[dict[str, Any]], tenant_id: int, kb_id: int
    ) -> dict[str, Any]:
        """Create chunk nodes for a document and the references between them.

        Runs in two passes: first every chunk with a ``chunk_id`` becomes a
        node, then each chunk's ``references`` entries are turned into
        reference edges. A reference that fails to be created is logged and
        skipped; it does not fail the build.

        Args:
            document_id: Document the chunks belong to.
            chunks: Chunk dicts. Keys read here: ``chunk_id``, ``text``,
                ``chunk_level``, ``chunk_index``, ``is_parent``,
                ``parent_chunk_id``, ``metadata`` (``doc_type``,
                ``doc_number``) and ``references``.
            tenant_id: Owning tenant, forwarded to the client unchanged.
            kb_id: Knowledge base id, forwarded to the client unchanged.

        Returns:
            ``{"success": True, "document_id", "chunk_count", "reference_count"}``
            or ``{"success": False, "document_id", "error"}`` when a chunk node
            cannot be created.
        """
        try:
            logger.info(f"Building chunk graph for document {document_id}: {len(chunks)} chunks")
            chunk_count = 0
            reference_count = 0
            for chunk in chunks:
                chunk_id = chunk.get("chunk_id")
                if not chunk_id:
                    logger.warning(
                        f"Chunk missing chunk_id, skipping: {chunk.get('chunk_index', 'unknown')}"
                    )
                    continue
                # ``metadata`` may be present but None; a plain default of {}
                # would not cover that and the whole build would abort.
                metadata = chunk.get("metadata") or {}
                self.neo4j_client.create_chunk_node(
                    chunk_id=chunk_id,
                    document_id=document_id,
                    chunk_text=chunk.get("text", ""),
                    chunk_level=chunk.get("chunk_level", ""),
                    chunk_index=chunk.get("chunk_index", 0),
                    is_parent=chunk.get("is_parent", False),
                    parent_chunk_id=chunk.get("parent_chunk_id"),
                    tenant_id=tenant_id,
                    kb_id=kb_id,
                    doc_type=metadata.get("doc_type"),
                    doc_number=metadata.get("doc_number"),
                )
                chunk_count += 1
            for chunk in chunks:
                chunk_id = chunk.get("chunk_id")
                references = chunk.get("references") or []
                if not chunk_id or not references:
                    continue
                for ref in references:
                    # Only dict-shaped references naming a target document are usable.
                    if not isinstance(ref, dict):
                        continue
                    target_doc_number = ref.get("target_doc_number")
                    if not target_doc_number:
                        continue
                    try:
                        self.neo4j_client.create_chunk_reference(
                            source_chunk_id=chunk_id,
                            target_doc_number=target_doc_number,
                            target_article=ref.get("article_number"),
                            confidence=ref.get("confidence", 0.8),
                            tenant_id=tenant_id,
                            kb_id=kb_id,
                        )
                        reference_count += 1
                    except Exception as e:
                        logger.warning(
                            f"Failed to create reference from {chunk_id} to {target_doc_number}: {e}"
                        )
            logger.info(
                f"Chunk graph built successfully for document {document_id}: {chunk_count} chunks, {reference_count} references"
            )
            return {
                "success": True,
                "document_id": document_id,
                "chunk_count": chunk_count,
                "reference_count": reference_count,
            }
        except Exception as e:
            logger.error(f"Failed to build chunk graph for document {document_id}: {e}")
            return {"success": False, "document_id": document_id, "error": str(e)}

    @staticmethod
    def _safe_document_id(doc: Any) -> Any:
        """Return ``doc["id"]``, or ``None`` when it cannot be read."""
        try:
            return doc["id"]
        except (KeyError, IndexError, TypeError):
            return None

    @staticmethod
    def _safe_rollback(db) -> None:
        """Roll back ``db``, logging instead of raising if that fails."""
        try:
            db.rollback()
        except Exception as e:
            logger.error(f"Failed to roll back database session: {e}")

    @staticmethod
    def _set_document_graph_status(
        db, document_model, document_id, status: str, entity_count: int | None = None
    ) -> None:
        """Persist ``graph_status`` (and optionally ``entity_count``) for one document.

        Does nothing when no row with ``document_id`` exists.
        """
        doc_obj = db.query(document_model).filter(document_model.id == document_id).first()
        if doc_obj:
            doc_obj.graph_status = status
            if entity_count is not None:
                doc_obj.entity_count = entity_count
            db.commit()

    def _mark_document_failed(self, db, document_model, document_id) -> None:
        """Best-effort: reset the session and record ``failed`` for a document."""
        # Assuming a SQLAlchemy-style session: after a failed flush/commit it
        # must be rolled back before it can be used for the next document.
        self._safe_rollback(db)
        if document_id is None:
            return
        try:
            self._set_document_graph_status(db, document_model, document_id, "failed")
        except Exception as e:
            logger.error(
                f"Failed to record failed graph status for document {document_id}: {e}"
            )
            self._safe_rollback(db)

    def rebuild_knowledge_base_graph(
        self, kb_id: int, tenant_id: int, documents: list, db_session=None
    ) -> dict[str, Any]:
        """Rebuild the graph for every document of a knowledge base.

        Marks the knowledge base as ``building``, builds each document's graph
        and records its ``graph_status``, creates co-occurrence relations, and
        finally marks the knowledge base ``completed`` (no failures) or
        ``failed``. A failure on one document does not stop the others.

        Args:
            kb_id: Knowledge base to rebuild.
            tenant_id: Owning tenant.
            documents: Mappings with ``id``, ``title`` and ``content`` and
                optional ``summary``, ``category_id`` and ``tag_ids``.
            db_session: Currently unused; a dedicated session is always opened
                from ``SessionLocal`` and closed before returning.

        Returns:
            ``{"success": True, "kb_id", "success_count", "failed_count",
            "total_entities"}``. ``success`` is ``True`` even when individual
            documents failed; inspect ``failed_count``.

        Raises:
            Exception: Any error outside the per-document handling is
                re-raised after the knowledge base is marked ``failed``.
        """
        from app.db.session import SessionLocal
        from app.models.knowledge_base import KnowledgeBase, KnowledgeDocument

        logger.info(f"Rebuilding graph for knowledge base {kb_id}")
        db = SessionLocal()
        try:
            kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
            if kb:
                kb.graph_status = "building"
                db.commit()
            success_count = 0
            failed_count = 0
            total_entities = 0
            for doc in documents:
                # Read the id defensively so the error path below can never
                # raise itself and abort the whole rebuild.
                doc_id = self._safe_document_id(doc)
                try:
                    result = self.build_document_graph(
                        document_id=doc["id"],
                        title=doc["title"],
                        content=doc["content"],
                        summary=doc.get("summary"),
                        tenant_id=tenant_id,
                        kb_id=kb_id,
                        category_id=doc.get("category_id"),
                        tag_ids=doc.get("tag_ids", []),
                        db_session=db,
                    )
                    succeeded = bool(result["success"])
                    if succeeded:
                        entity_count = result["entity_count"]
                        self._set_document_graph_status(
                            db,
                            KnowledgeDocument,
                            doc["id"],
                            "completed",
                            entity_count=entity_count,
                        )
                    else:
                        self._set_document_graph_status(
                            db, KnowledgeDocument, doc["id"], "failed"
                        )
                except Exception as e:
                    failed_count += 1
                    logger.error(f"Failed to process document {doc_id}: {e}")
                    self._mark_document_failed(db, KnowledgeDocument, doc_id)
                    continue
                # Counters are updated only once the status is persisted, so a
                # document is never counted as both succeeded and failed.
                if succeeded:
                    success_count += 1
                    total_entities += entity_count
                else:
                    failed_count += 1
                logger.info(
                    f"Processed document {doc['id']}: {result['success']}, entities: {result.get('entity_count', 0)}"
                )
            try:
                self.relation_extractor.create_co_occurrence_relations(tenant_id, kb_id)
            except Exception as e:
                # Co-occurrence edges are best-effort and do not fail the rebuild.
                logger.error(f"Failed to create co-occurrence relations: {e}")
            kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
            if kb:
                kb.graph_status = "completed" if failed_count == 0 else "failed"
                kb.entity_count = total_entities
                db.commit()
            logger.info(
                f"Graph rebuild completed for kb {kb_id}: {success_count} success, {failed_count} failed, {total_entities} total entities"
            )
            return {
                "success": True,
                "kb_id": kb_id,
                "success_count": success_count,
                "failed_count": failed_count,
                "total_entities": total_entities,
            }
        except Exception as e:
            logger.error(f"Graph rebuild failed for kb {kb_id}: {e}")
            # Clear any failed transaction before trying to record the failure.
            self._safe_rollback(db)
            try:
                kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
                if kb:
                    kb.graph_status = "failed"
                    db.commit()
            except Exception as status_error:
                logger.error(f"Failed to mark kb {kb_id} graph as failed: {status_error}")
            raise
        finally:
            db.close()


# Process-wide GraphBuilder instance; stays None until get_graph_builder() runs.
graph_builder: GraphBuilder | None = None


def get_graph_builder() -> GraphBuilder:
    """Return the shared GraphBuilder, constructing it on the first call.

    The instance is cached in the module-level ``graph_builder`` variable, so
    later calls return the same object.
    """
    global graph_builder
    if graph_builder is not None:
        return graph_builder

    # Collaborators are imported at call time rather than at module import;
    # presumably to avoid import cycles or eager client setup - confirm.
    from app.services.graph.entity_extractor import entity_extractor as extractor
    from app.services.graph.neo4j_client import neo4j_client as client
    from app.services.graph.relation_extractor import get_relation_extractor

    relations = get_relation_extractor()
    graph_builder = GraphBuilder(client, extractor, relations)
    return graph_builder
