import time
from typing import Any

from app.celery_app import celery_app
from app.services.graph.entity_linker import get_entity_linker
from app.services.graph.graph_builder import get_graph_builder
from app.services.graph.graph_monitoring import get_graph_metrics, graph_logger

from common_logging import get_logger

logger = get_logger(__name__)




@celery_app.task(bind=True, name="graph.build_knowledge_base")
def build_knowledge_base_task(
    self, kb_id: int, tenant_id: int, documents: list[dict[str, Any]], link_entities: bool = True
) -> dict[str, Any]:
    """Build the knowledge graph for every document in a knowledge base.

    Processes each document through the graph builder, reporting progress via
    Celery task state, then optionally runs entity linking across the whole KB.

    Args:
        kb_id: Knowledge base identifier.
        tenant_id: Tenant the knowledge base belongs to.
        documents: Document dicts; each requires "id", "title", "content" and
            may carry "summary", "category_id", "tag_ids".
        link_entities: When True, run cross-document entity linking after a
            successful build of at least one document.

    Returns:
        On success: {"success": True, "kb_id", "documents_processed",
        "documents_failed", "total_entities"}.
        On failure: {"success": False, "kb_id", "error"}.
    """
    # Wall-clock start so the recorded build duration is real, not a placeholder.
    start_time = time.monotonic()
    try:
        logger.info("Starting async graph build for kb %s, %d documents", kb_id, len(documents))
        graph_logger.log_graph_build_start(kb_id, len(documents))
        self.update_state(
            state="PROGRESS",
            meta={"current": 0, "total": len(documents), "status": "Building graph..."},
        )
        builder = get_graph_builder()
        success_count = 0
        failed_count = 0
        total_entities = 0
        for idx, doc in enumerate(documents):
            try:
                result = builder.build_document_graph(
                    document_id=doc["id"],
                    title=doc["title"],
                    content=doc["content"],
                    summary=doc.get("summary"),
                    tenant_id=tenant_id,
                    kb_id=kb_id,
                    category_id=doc.get("category_id"),
                    tag_ids=doc.get("tag_ids", []),
                )
                if result["success"]:
                    success_count += 1
                    total_entities += result["entity_count"]
                else:
                    failed_count += 1
                self.update_state(
                    state="PROGRESS",
                    meta={
                        "current": idx + 1,
                        "total": len(documents),
                        "status": f"Processed {idx + 1}/{len(documents)} documents",
                        "success": success_count,
                        "failed": failed_count,
                    },
                )
            except Exception:
                # doc.get("id"): the failure may itself be a missing "id" key,
                # so don't re-raise from inside the handler.
                logger.exception("Failed to build graph for document %s", doc.get("id"))
                failed_count += 1
        if link_entities and success_count > 0:
            self.update_state(
                state="PROGRESS",
                meta={
                    "current": len(documents),
                    "total": len(documents),
                    "status": "Linking entities...",
                },
            )
            linker = get_entity_linker()
            linking_result = linker.link_entities(tenant_id, kb_id)
            logger.info("Entity linking completed: %s", linking_result)
        # Record the actual elapsed time (previously hard-coded to 0).
        duration = time.monotonic() - start_time
        metrics = get_graph_metrics()
        metrics.record_graph_build(duration=duration, success=True, entity_count=total_entities)
        result = {
            "success": True,
            "kb_id": kb_id,
            "documents_processed": success_count,
            "documents_failed": failed_count,
            "total_entities": total_entities,
        }
        graph_logger.log_graph_build_complete(
            kb_id=kb_id, duration=duration, entity_count=total_entities, success=True
        )
        return result
    except Exception as e:
        logger.exception("Async graph build failed for kb %s", kb_id)
        graph_logger.log_error("async_build", e, {"kb_id": kb_id})
        return {"success": False, "kb_id": kb_id, "error": str(e)}


@celery_app.task(bind=True, name="graph.build_document")
def build_document_task(
    self,
    document_id: int,
    title: str,
    content: str,
    summary: str | None,
    tenant_id: int,
    kb_id: int,
    category_id: int | None = None,
    tag_ids: list[int] | None = None,
    doc_number: str | None = None,
    doc_status: str | None = None,
    supersedes_doc_ids: list[int] | None = None,
) -> dict[str, Any]:
    """Build the knowledge graph for a single document.

    Delegates to the graph builder; progress is surfaced through Celery
    task state. Exceptions are caught and returned as an error payload so
    the task itself never fails hard.

    Returns:
        The builder's result dict on success, or
        {"success": False, "document_id", "error"} on failure.
    """
    try:
        logger.info("Starting async graph build for document %s", document_id)
        self.update_state(state="PROGRESS", meta={"status": "Extracting entities..."})
        builder = get_graph_builder()
        result = builder.build_document_graph(
            document_id=document_id,
            title=title,
            content=content,
            summary=summary,
            tenant_id=tenant_id,
            kb_id=kb_id,
            category_id=category_id,
            tag_ids=tag_ids or [],
            doc_number=doc_number,
            doc_status=doc_status,
            supersedes_doc_ids=supersedes_doc_ids,
        )
        return result
    except Exception as e:
        # logger.exception preserves the traceback in the log record.
        logger.exception("Async document build failed for %s", document_id)
        return {"success": False, "document_id": document_id, "error": str(e)}


@celery_app.task(bind=True, name="graph.link_entities")
def link_entities_task(
    self, tenant_id: int, kb_id: int, similarity_threshold: float = 0.85
) -> dict[str, Any]:
    """Link similar entities across a knowledge base's graph.

    Args:
        tenant_id: Tenant owning the knowledge base.
        kb_id: Knowledge base identifier.
        similarity_threshold: Minimum similarity score for two entities to
            be considered the same (default 0.85).

    Returns:
        The linker's result dict on success, or
        {"success": False, "kb_id", "error"} on failure.
    """
    try:
        logger.info("Starting async entity linking for kb %s", kb_id)
        self.update_state(state="PROGRESS", meta={"status": "Finding candidate pairs..."})
        linker = get_entity_linker()
        result = linker.link_entities(tenant_id, kb_id, similarity_threshold)
        return result
    except Exception as e:
        # logger.exception preserves the traceback in the log record.
        logger.exception("Async entity linking failed for kb %s", kb_id)
        return {"success": False, "kb_id": kb_id, "error": str(e)}


@celery_app.task(name="graph.cleanup_orphaned_entities")
def cleanup_orphaned_entities_task(tenant_id: int, kb_id: int) -> dict[str, Any]:
    """Delete Entity nodes in a KB that no document CONTAINS anymore.

    Args:
        tenant_id: Tenant owning the knowledge base.
        kb_id: Knowledge base identifier.

    Returns:
        {"success": True, "kb_id", "deleted_count"} on success, or
        {"success": False, "kb_id", "error"} on failure.
    """
    try:
        # Imported lazily so the Neo4j driver is only loaded by the worker
        # that actually executes this task.
        from app.services.graph.neo4j_client import neo4j_client

        # Entities with no incoming CONTAINS edge are orphans left behind
        # by document deletions; remove them along with their relationships.
        query = """
        MATCH (e:Entity {tenant_id: $tenant_id, kb_id: $kb_id})
        WHERE NOT (e)<-[:CONTAINS]-()
        WITH e
        DETACH DELETE e
        RETURN count(e) as deleted_count
        """
        result = neo4j_client.execute_write(
            query, parameters={"tenant_id": tenant_id, "kb_id": kb_id}
        )
        deleted_count = result[0]["deleted_count"] if result else 0
        logger.info("Cleaned up %d orphaned entities for kb %s", deleted_count, kb_id)
        return {"success": True, "kb_id": kb_id, "deleted_count": deleted_count}
    except Exception as e:
        # logger.exception preserves the traceback in the log record.
        logger.exception("Cleanup failed for kb %s", kb_id)
        return {"success": False, "kb_id": kb_id, "error": str(e)}
