import math
from typing import Any

from app.services.graph.neo4j_client import Neo4jClient

from common_logging import get_logger

logger = get_logger(__name__)


# Multiplier applied to the combined rerank score of a document whose status is "obsolete".
_OBSOLETE_SCORE_PENALTY: float = 0.15
# Multiplier applied to the combined rerank score of a document whose status is "amended".
_AMENDED_SCORE_PENALTY: float = 0.6


class GraphReranker:

    def __init__(self, neo4j_client: Neo4jClient):
        self.neo4j_client = neo4j_client

    def rerank(
        self,
        results: list[dict[str, Any]],
        query: str,
        tenant_id: int,
        kb_id: int,
        weights: dict[str, float] | None = None,
        db=None,
    ) -> list[dict[str, Any]]:
        if not results:
            return results
        if weights is None:
            weights = {
                "vector_score": 0.4,
                "pagerank": 0.2,
                "entity_overlap": 0.2,
                "path_distance": 0.2,
            }
        doc_ids = [r["document_id"] for r in results]
        doc_status_map: dict[int, str] = {}
        if db is not None:
            try:
                from app.models.knowledge_base import KnowledgeDocument

                rows = (
                    db.query(KnowledgeDocument.id, KnowledgeDocument.doc_status)
                    .filter(KnowledgeDocument.id.in_(doc_ids))
                    .all()
                )
                doc_status_map = {row.id: row.doc_status or "effective" for row in rows}
            except Exception as e:
                logger.warning(f"[GraphReranker] doc_status 查询失败（非致命）: {e}")
        pagerank_scores = self._calculate_pagerank(doc_ids, tenant_id, kb_id)
        entity_overlap_scores = self._calculate_entity_overlap(doc_ids, tenant_id, kb_id)
        path_distances = self._calculate_path_distances(doc_ids, tenant_id, kb_id)
        for result in results:
            doc_id = result["document_id"]
            vector_score = result.get("score", 0.5)
            pagerank = pagerank_scores.get(doc_id, 0.0)
            entity_overlap = entity_overlap_scores.get(doc_id, 0.0)
            path_distance = path_distances.get(doc_id, 0.0)
            combined_score = (
                weights["vector_score"] * vector_score
                + weights["pagerank"] * pagerank
                + weights["entity_overlap"] * entity_overlap
                + weights["path_distance"] * path_distance
            )
            status = doc_status_map.get(doc_id, "effective")
            if status == "obsolete":
                combined_score *= _OBSOLETE_SCORE_PENALTY
                result["doc_status"] = "obsolete"
                result["status_penalized"] = True
            elif status == "amended":
                combined_score *= _AMENDED_SCORE_PENALTY
                result["doc_status"] = "amended"
                result["status_penalized"] = True
            else:
                result["doc_status"] = status
                result["status_penalized"] = False
            result["original_score"] = vector_score
            result["pagerank_score"] = pagerank
            result["entity_overlap_score"] = entity_overlap
            result["path_distance_score"] = path_distance
            result["score"] = combined_score
        results.sort(key=lambda x: x["score"], reverse=True)
        penalized_count = sum(1 for r in results if r.get("status_penalized"))
        logger.info(
            f"Reranked {len(results)} results using graph features ({penalized_count} obsolete/amended penalized)"
        )
        return results

    def _calculate_pagerank(
        self, doc_ids: list[int], tenant_id: int, kb_id: int
    ) -> dict[int, float]:
        try:
            query = "\n            MATCH (d:Document)\n            WHERE d.id IN $doc_ids\n              AND d.tenant_id = $tenant_id\n              AND d.kb_id = $kb_id\n            CALL gds.pageRank.stream({\n                nodeProjection: 'Document',\n                relationshipProjection: {\n                    REFERENCES: {\n                        type: 'REFERENCES',\n                        orientation: 'NATURAL'\n                    },\n                    SIMILAR_TO: {\n                        type: 'SIMILAR_TO',\n                        orientation: 'UNDIRECTED'\n                    }\n                },\n                maxIterations: 20,\n                dampingFactor: 0.85\n            })\n            YIELD nodeId, score\n            WITH gds.util.asNode(nodeId) AS doc, score\n            WHERE doc.id IN $doc_ids\n            RETURN doc.id as doc_id, score\n            "
            fallback_query = "\n            MATCH (d:Document)\n            WHERE d.id IN $doc_ids\n              AND d.tenant_id = $tenant_id\n              AND d.kb_id = $kb_id\n            OPTIONAL MATCH (d)-[r:REFERENCES|SIMILAR_TO]-(other:Document)\n            WHERE other.tenant_id = $tenant_id\n              AND other.kb_id = $kb_id\n            WITH d.id as doc_id, count(r) as degree\n            RETURN doc_id, degree\n            "
            try:
                results = self.neo4j_client.execute_query(
                    query, parameters={"doc_ids": doc_ids, "tenant_id": tenant_id, "kb_id": kb_id}
                )
            except Exception:
                results = self.neo4j_client.execute_query(
                    fallback_query,
                    parameters={"doc_ids": doc_ids, "tenant_id": tenant_id, "kb_id": kb_id},
                )
                max_degree = max([r["degree"] for r in results]) if results else 1
                results = [
                    {"doc_id": r["doc_id"], "score": r["degree"] / max_degree} for r in results
                ]
            scores = {r["doc_id"]: r["score"] for r in results}
            if scores:
                max_score = max(scores.values())
                if max_score > 0:
                    scores = {k: v / max_score for k, v in scores.items()}
            return scores
        except Exception as e:
            logger.error(f"Failed to calculate PageRank: {e}")
            return dict.fromkeys(doc_ids, 0.0)

    def _calculate_entity_overlap(
        self, doc_ids: list[int], tenant_id: int, kb_id: int
    ) -> dict[int, float]:
        try:
            query = "\n            MATCH (d1:Document)\n            WHERE d1.id IN $doc_ids\n              AND d1.tenant_id = $tenant_id\n              AND d1.kb_id = $kb_id\n            MATCH (d1)-[:CONTAINS]->(e:Entity)<-[:CONTAINS]-(d2:Document)\n            WHERE d2.id IN $doc_ids\n              AND d2.tenant_id = $tenant_id\n              AND d2.kb_id = $kb_id\n              AND d1.id <> d2.id\n            WITH d1.id as doc_id, count(DISTINCT e) as shared_entities\n            RETURN doc_id, shared_entities\n            "
            results = self.neo4j_client.execute_query(
                query, parameters={"doc_ids": doc_ids, "tenant_id": tenant_id, "kb_id": kb_id}
            )
            scores = {r["doc_id"]: r["shared_entities"] for r in results}
            if scores:
                max_score = max(scores.values())
                if max_score > 0:
                    scores = {k: v / max_score for k, v in scores.items()}
            for doc_id in doc_ids:
                if doc_id not in scores:
                    scores[doc_id] = 0.0
            return scores
        except Exception as e:
            logger.error(f"Failed to calculate entity overlap: {e}")
            return dict.fromkeys(doc_ids, 0.0)

    def _calculate_path_distances(
        self, doc_ids: list[int], tenant_id: int, kb_id: int
    ) -> dict[int, float]:
        try:
            query = "\n            MATCH (d1:Document)\n            WHERE d1.id IN $doc_ids\n              AND d1.tenant_id = $tenant_id\n              AND d1.kb_id = $kb_id\n            MATCH (d2:Document)\n            WHERE d2.id IN $doc_ids\n              AND d2.tenant_id = $tenant_id\n              AND d2.kb_id = $kb_id\n              AND d1.id <> d2.id\n            MATCH path = shortestPath((d1)-[*..3]-(d2))\n            WITH d1.id as doc_id, avg(length(path)) as avg_distance\n            RETURN doc_id, avg_distance\n            "
            results = self.neo4j_client.execute_query(
                query, parameters={"doc_ids": doc_ids, "tenant_id": tenant_id, "kb_id": kb_id}
            )
            scores = {}
            for r in results:
                distance = r["avg_distance"]
                scores[r["doc_id"]] = math.exp(-distance)
            for doc_id in doc_ids:
                if doc_id not in scores:
                    scores[doc_id] = 0.1
            return scores
        except Exception as e:
            logger.error(f"Failed to calculate path distances: {e}")
            return dict.fromkeys(doc_ids, 0.5)


# Module-level GraphReranker instance; stays None until get_graph_reranker() first builds it.
graph_reranker: GraphReranker | None = None


def get_graph_reranker() -> GraphReranker:
    """Return the module-level ``GraphReranker``, building it on the first call.

    The instance is stored in the ``graph_reranker`` global and reused by
    every later call.
    """
    global graph_reranker
    if graph_reranker is not None:
        return graph_reranker

    # The shared client is imported on first use rather than at module import time.
    from app.services.graph.neo4j_client import neo4j_client as shared_client

    graph_reranker = GraphReranker(shared_client)
    return graph_reranker
