from typing import Any

from sqlalchemy.orm import Session

from app.config import settings
from app.services.graph.graph_query import get_graph_query_service
from app.services.graph.graph_reranker import get_graph_reranker
from app.services.llm.backends.embedding_backend_factory import get_embedding_factory
from app.services.storage.vector_store_factory import get_vector_store

from common_logging import get_logger

logger = get_logger(__name__)



class GraphEnhancedRetrieval:
    """Retriever that combines vector similarity recall with knowledge-graph expansion.

    ``retrieve`` runs this pipeline:

    1. Embed the query and run a similarity search on the knowledge base's vector store.
    2. Expand the hit documents to their graph neighbours. This only happens when
       ``settings.ENABLE_KNOWLEDGE_GRAPH`` is enabled.
    3. Merge both result sets into one entry per ``document_id``.
    4. Optionally rerank, then truncate to ``k``.
    """

    def __init__(self, db: Session, knowledge_base_id: int, tenant_id: int):
        """Bind the retriever to one knowledge base / tenant and resolve its services.

        Args:
            db: Database session, forwarded to the vector store, embedding
                factory and reranker.
            knowledge_base_id: Knowledge base to search.
            tenant_id: Tenant scope passed to the graph query service and reranker.
        """
        self.db = db
        self.knowledge_base_id = knowledge_base_id
        self.tenant_id = tenant_id
        self.vector_store = get_vector_store(db, knowledge_base_id)
        self.embedding_factory = get_embedding_factory()
        self.graph_query_service = get_graph_query_service()
        self.graph_reranker = get_graph_reranker()

    def retrieve(
        self,
        query: str,
        k: int = 5,
        threshold: float = 0.7,
        model_id: int | None = None,
        expand_depth: int = 1,
        relation_types: list[str] | None = None,
        enable_rerank: bool = True,
    ) -> list[dict[str, Any]]:
        """Return up to ``k`` results for ``query``, enriched with graph neighbours.

        Args:
            query: Free-text query to embed and search for.
            k: Maximum number of results to return. Values <= 0 yield ``[]``.
            threshold: Similarity threshold forwarded to the vector store.
            model_id: Optional embedding model id forwarded to the embedding factory.
            expand_depth: Requested graph expansion depth. It is capped at
                ``settings.GRAPH_EXPAND_DEPTH``.
            relation_types: Optional relation-type filter for graph expansion.
            enable_rerank: Rerank the merged results when there are more than ``k``.

        Returns:
            Result dicts sorted by score (or by reranker order), at most ``k`` long.
            If any step after vector recall raises, the plain vector-recall results
            (truncated to ``k``) are returned instead.
        """
        if k <= 0:
            return []
        vector_results: list[dict[str, Any]] | None = None
        try:
            vector_results = self._vector_recall(query, k, threshold, model_id)
            if not vector_results:
                logger.info("No vector results found")
                return []
            doc_ids = list({r["document_id"] for r in vector_results})
            logger.info(f"Vector recall returned {len(doc_ids)} unique documents")
            expanded_docs = self._graph_expansion(doc_ids, expand_depth, relation_types)
            logger.info(f"Graph expansion added {len(expanded_docs)} neighbor documents")
            all_results = self._merge_results(vector_results, expanded_docs)
            if enable_rerank and len(all_results) > k:
                all_results = self._rerank(query, all_results, model_id)
            return all_results[:k]
        except Exception as e:
            logger.error(f"Graph-enhanced retrieval failed: {e}")
            if vector_results is None:
                # Recall itself failed, so retry it once.
                # If recall had succeeded, we reuse its results below instead of
                # re-embedding and re-searching.
                vector_results = self._vector_recall(query, k, threshold, model_id)
            # Recall fetches k * 2 candidates; honour the "at most k" contract.
            return vector_results[:k]

    def _vector_recall(
        self, query: str, k: int, threshold: float, model_id: int | None
    ) -> list[dict[str, Any]]:
        """Embed ``query`` and return formatted similarity-search hits.

        Fetches ``k * 2`` candidates, presumably so that per-document
        deduplication and reranking still leave ``k`` results. Returns ``[]``
        when no embedding could be generated.
        """
        query_embedding = self.embedding_factory.generate_embedding(
            text=query, db=self.db, model_id=model_id
        )
        # Explicit check: `not embedding` is ambiguous for array-like embeddings.
        if query_embedding is None or len(query_embedding) == 0:
            return []
        hits = self.vector_store.similarity_search(
            query_embedding=query_embedding, k=k * 2, threshold=threshold
        )
        # Each hit is a (chunk_record, score) pair.
        return [
            {
                "id": doc["id"],
                "document_id": doc["document_id"],
                "title": doc["title"],
                "text": doc["text"],
                "chunk_index": doc["chunk_index"],
                "score": score,
                "source": "vector",
                "metadata": {
                    "category_id": doc.get("category_id"),
                    "model_name": doc.get("model_name"),
                },
            }
            for doc, score in hits
        ]

    def _graph_expansion(
        self, doc_ids: list[int], depth: int, relation_types: list[str] | None
    ) -> list[dict[str, Any]]:
        """Return graph neighbours of ``doc_ids`` formatted like retrieval results.

        This step is best-effort. It returns ``[]`` when the knowledge graph is
        disabled or when the graph query raises; the error is logged.
        """
        if not settings.ENABLE_KNOWLEDGE_GRAPH:
            return []
        try:
            neighbors = self.graph_query_service.expand_neighbors(
                document_ids=doc_ids,
                tenant_id=self.tenant_id,
                kb_id=self.knowledge_base_id,
                depth=min(depth, settings.GRAPH_EXPAND_DEPTH),
                relation_types=relation_types,
                limit=20,
            )
            formatted_neighbors = []
            for neighbor in neighbors:
                score = neighbor.get("score")
                formatted_neighbors.append(
                    {
                        "document_id": neighbor["id"],
                        "title": neighbor["title"],
                        # Explicit None values fall back to defaults too.
                        # A None score would break the score sort in
                        # _merge_results.
                        "text": neighbor.get("summary") or "",
                        "score": score if score is not None else 0.5,
                        "source": "graph",
                        "relation_type": neighbor.get("relation_type", "unknown"),
                        "metadata": {},
                    }
                )
            return formatted_neighbors
        except Exception as e:
            logger.error(f"Graph expansion failed: {e}")
            return []

    def _merge_results(
        self, vector_results: list[dict[str, Any]], graph_results: list[dict[str, Any]]
    ) -> list[dict[str, Any]]:
        """Merge vector and graph results into one entry per ``document_id``.

        If several vector chunks share a document, only the best-scoring chunk
        is kept. Graph results only add documents that vector recall did not
        return. The output is sorted by score, descending.
        """
        result_map: dict[Any, dict[str, Any]] = {}
        for result in vector_results:
            doc_id = result["document_id"]
            current = result_map.get(doc_id)
            # Keep the strongest chunk per document.
            # Don't let a later, weaker chunk overwrite it.
            if current is None or result["score"] > current["score"]:
                result_map[doc_id] = result
        for result in graph_results:
            result_map.setdefault(result["document_id"], result)
        return sorted(result_map.values(), key=lambda r: r["score"], reverse=True)

    def _rerank(
        self, query: str, results: list[dict[str, Any]], model_id: int | None
    ) -> list[dict[str, Any]]:
        """Rerank ``results`` with the graph reranker.

        If the reranker raises, fall back to a heuristic: multiply the score of
        graph-sourced results by 1.1 (capped at 1.0) and re-sort. The fallback
        works on copies and does not mutate the caller's dicts. ``model_id`` is
        currently unused.
        """
        try:
            return self.graph_reranker.rerank(
                results=results,
                query=query,
                tenant_id=self.tenant_id,
                kb_id=self.knowledge_base_id,
                db=self.db,
            )
        except Exception as e:
            logger.error(f"Graph reranking failed: {e}")
            boosted = []
            for result in results:
                if result.get("source") == "graph":
                    result = {**result, "score": min(result["score"] * 1.1, 1.0)}
                boosted.append(result)
            boosted.sort(key=lambda r: r["score"], reverse=True)
            return boosted


def get_graph_enhanced_retrieval(
    db: Session, knowledge_base_id: int, tenant_id: int
) -> GraphEnhancedRetrieval:
    """Create a ``GraphEnhancedRetrieval`` for one knowledge base and tenant.

    A fresh instance is constructed on every call; nothing is cached here.
    """
    retriever = GraphEnhancedRetrieval(
        db=db,
        knowledge_base_id=knowledge_base_id,
        tenant_id=tenant_id,
    )
    return retriever
