from typing import Any

from sqlalchemy.orm import Session

from common_logging import get_logger

logger = get_logger(__name__)




class HybridRetriever:
    """Hybrid retrieval pipeline over a knowledge base.

    Stages: proposition (vector) recall -> optional knowledge-graph expansion
    (neighbor documents + cross-document chunk references) -> optional
    rerank -> sliding-window context widening when the top score is weak.

    Every expansion stage is best-effort: failures are logged and the
    pipeline continues with the results it already has, so callers never
    see graph/rerank errors.
    """

    # Top-1 scores below this trigger sliding-window context expansion.
    LOW_SCORE_THRESHOLD = 0.6

    def __init__(self, db: Session, knowledge_base_id: int, tenant_id: int):
        """Bind the retriever to a DB session, knowledge base and tenant."""
        self.db = db
        self.knowledge_base_id = knowledge_base_id
        self.tenant_id = tenant_id
        self._init_services()

    def _init_services(self) -> None:
        """Instantiate the underlying services.

        Imports are local to avoid circular imports at module load time.
        """
        from app.services.graph.graph_query import get_graph_query_service
        from app.services.rag.proposition_retriever import get_proposition_retriever
        from app.services.rag.sliding_window_retriever import get_sliding_window_retriever

        self.proposition_retriever = get_proposition_retriever(
            self.db, self.knowledge_base_id, self.tenant_id
        )
        self.sliding_window_retriever = get_sliding_window_retriever(
            self.db, self.knowledge_base_id, self.tenant_id
        )
        self.graph_query = get_graph_query_service()

    def retrieve(
        self,
        query: str,
        k: int = 5,
        expand_graph: bool = True,
        expand_parent: bool = True,
        rerank: bool = True,
        **kwargs,
    ) -> list[dict[str, Any]]:
        """Run the full hybrid pipeline and return up to ``k`` result dicts.

        Args:
            query: Natural-language query text.
            k: Number of results to return. Vector recall over-fetches
                ``3 * k`` candidates so expansion/rerank have room to work.
            expand_graph: Add graph-neighbor and reference-linked chunks.
            expand_parent: Ask the proposition retriever for parent context.
            rerank: Apply the graph reranker; otherwise keep score order.
            **kwargs: Accepted for interface compatibility; currently unused.

        Returns:
            Ranked result dicts. If the pipeline fails for any reason it
            degrades to plain proposition retrieval instead of raising.
        """
        try:
            vector_results = self.proposition_retriever.retrieve(
                query, k=k * 3, return_parent=expand_parent
            )
            logger.info("Proposition recall: %d results", len(vector_results))
            if not vector_results:
                return []
            if expand_graph:
                vector_results = self._graph_expand(query, vector_results)
                logger.info("After graph expand: %d results", len(vector_results))
            if rerank:
                vector_results = self._rerank(query, vector_results, k)
            else:
                vector_results = vector_results[:k]
            # Weak top hit: widen its context with neighboring chunks so
            # downstream consumers get more surrounding text to work with.
            if vector_results:
                top_score = vector_results[0].get("score", 1.0)
                if top_score < self.LOW_SCORE_THRESHOLD:
                    top_chunk_id = vector_results[0].get("id") or vector_results[0].get("milvus_id")
                    if top_chunk_id:
                        expanded_text = self.sliding_window_retriever.get_expanded_context(
                            top_chunk_id, window_size=2
                        )
                        if expanded_text:
                            # Copy before mutating so any upstream-cached dict
                            # is left untouched.
                            vector_results[0] = dict(vector_results[0])
                            vector_results[0]["text"] = expanded_text
                            vector_results[0]["is_window_expanded"] = True
                            logger.info(
                                "Sliding window fallback triggered (score=%.3f)", top_score
                            )
            return vector_results
        except Exception as e:
            logger.error("Hybrid retrieval failed: %s", e)
            # Degrade to plain vector retrieval rather than surfacing the error.
            return self.proposition_retriever.retrieve(query, k=k)

    def _graph_expand(self, query: str, results: list[dict]) -> list[dict]:
        """Append graph-neighbor documents and reference-linked chunks.

        Neighbor documents get their score discounted (x0.8, since they were
        not matched by the query itself) and are tagged ``is_graph_expanded``.
        Non-fatal: on failure the incoming results are returned as-is.
        """
        try:
            doc_ids = list({r.get("document_id") for r in results if r.get("document_id")})
            if doc_ids:
                expanded = self.graph_query.expand_neighbors(
                    document_ids=doc_ids,
                    tenant_id=self.tenant_id,
                    kb_id=self.knowledge_base_id,
                    depth=1,
                    relation_types=["REFERENCES", "AMENDS"],
                    limit=20,
                )
                existing_ids = {r.get("id") for r in results}
                for doc in expanded:
                    if doc.get("id") not in existing_ids:
                        doc["score"] = doc.get("score", 0.5) * 0.8
                        doc["is_graph_expanded"] = True
                        results.append(doc)
                        existing_ids.add(doc.get("id"))
            results = self._expand_chunk_references(results)
            return results
        except Exception as e:
            logger.warning("Graph expand failed (non-fatal): %s", e)
            return results

    def _expand_chunk_references(self, results: list[dict]) -> list[dict]:
        """Follow chunk-level REFERENCES edges in Neo4j and append the targets.

        Falls back to :meth:`_expand_chunk_references_fallback` when the
        knowledge graph is disabled, Neo4j is unreachable, or anything fails.
        At most 10 unique referenced chunks are added; each is scored by the
        mean confidence along its reference path (0.7 default) and tagged
        ``is_reference_expanded``.
        """
        try:
            from app.config import settings
            from app.models.knowledge_base import DocumentVector
            from app.services.graph.neo4j_client import get_neo4j_client

            if not settings.ENABLE_KNOWLEDGE_GRAPH:
                return self._expand_chunk_references_fallback(results)
            neo4j_client = get_neo4j_client()
            if not neo4j_client.driver:
                logger.warning("Neo4j not available, using fallback reference expansion")
                return self._expand_chunk_references_fallback(results)
            chunk_ids = [r.get("chunk_id") for r in results if r.get("chunk_id")]
            if not chunk_ids:
                return results
            referenced_chunks_data = []
            for chunk_id in chunk_ids:
                try:
                    refs = neo4j_client.get_referenced_chunks(
                        chunk_id=chunk_id,
                        tenant_id=self.tenant_id,
                        kb_id=self.knowledge_base_id,
                        max_depth=1,
                        min_confidence=0.5,
                    )
                    referenced_chunks_data.extend(refs)
                except Exception as e:
                    logger.warning("Failed to get references for chunk %s: %s", chunk_id, e)
            if not referenced_chunks_data:
                return results
            # De-duplicate against chunks already present, capped at 10 extras.
            seen_chunk_ids = {r.get("chunk_id") for r in results}
            unique_refs = []
            for ref in referenced_chunks_data:
                ref_chunk_id = ref.get("chunk_id")
                if ref_chunk_id and ref_chunk_id not in seen_chunk_ids:
                    unique_refs.append(ref)
                    seen_chunk_ids.add(ref_chunk_id)
                    if len(unique_refs) >= 10:
                        break
            ref_chunk_ids = [ref.get("chunk_id") for ref in unique_refs]
            if ref_chunk_ids:
                referenced_chunks = (
                    self.db.query(DocumentVector)
                    .filter(DocumentVector.chunk_id.in_(ref_chunk_ids))
                    .all()
                )
                for chunk in referenced_chunks:
                    graph_data = next(
                        (r for r in unique_refs if r.get("chunk_id") == chunk.chunk_id), None
                    )
                    confidence_path = (
                        graph_data.get("confidence_path", [0.7]) if graph_data else [0.7]
                    )
                    # Score by mean edge confidence along the reference path.
                    avg_confidence = (
                        sum(confidence_path) / len(confidence_path) if confidence_path else 0.7
                    )
                    results.append(
                        {
                            "id": chunk.milvus_id,
                            "document_id": chunk.document_id,
                            "chunk_text": chunk.chunk_text,
                            "text": chunk.chunk_text,
                            "chunk_id": chunk.chunk_id,
                            "chunk_level": chunk.chunk_level,
                            "parent_chunk_id": chunk.parent_chunk_id,
                            "is_parent": chunk.is_parent,
                            "references": chunk.references,
                            "score": avg_confidence,
                            "is_reference_expanded": True,
                            "reference_depth": graph_data.get("depth", 1) if graph_data else 1,
                        }
                    )
                logger.info(
                    "Graph-based reference expansion added %d chunks", len(referenced_chunks)
                )
            return results
        except Exception as e:
            logger.warning("Graph-based chunk reference expansion failed, using fallback: %s", e)
            return self._expand_chunk_references_fallback(results)

    def _expand_chunk_references_fallback(self, results: list[dict]) -> list[dict]:
        """Graph-free reference expansion via the relational DB.

        Collects ``target_doc_number`` values from each result's
        ``references`` payload (JSON string or list of dicts), resolves them
        to non-obsolete documents for this tenant, and appends up to 10 of
        their child chunks with a fixed 0.7 score. Non-fatal on any failure.
        """
        try:
            import json

            from app.models.knowledge_base import DocumentVector, KnowledgeDocument

            referenced_doc_numbers = set()
            for result in results:
                references = result.get("references", [])
                if isinstance(references, str):
                    # References may be stored as a JSON-encoded string.
                    try:
                        references = json.loads(references)
                    except Exception:
                        references = []
                for ref in references:
                    if isinstance(ref, dict):
                        target_doc_number = ref.get("target_doc_number")
                        if target_doc_number:
                            referenced_doc_numbers.add(target_doc_number)
            if not referenced_doc_numbers:
                return results
            referenced_docs = (
                self.db.query(KnowledgeDocument)
                .filter(
                    KnowledgeDocument.doc_number.in_(list(referenced_doc_numbers)),
                    KnowledgeDocument.tenant_id == self.tenant_id,
                    KnowledgeDocument.doc_status.notin_(["obsolete", "expired"]),
                )
                .all()
            )
            if not referenced_docs:
                return results
            referenced_doc_ids = [doc.id for doc in referenced_docs]
            # BUG FIX: the original used `not DocumentVector.is_parent`, which
            # calls bool() on a SQLAlchemy clause element and raises TypeError
            # (silently swallowed below), so this expansion never ran. Use the
            # SQL-level negation instead to select child chunks only.
            referenced_chunks = (
                self.db.query(DocumentVector)
                .filter(
                    DocumentVector.document_id.in_(referenced_doc_ids),
                    DocumentVector.is_parent.is_(False),
                )
                .limit(10)
                .all()
            )
            existing_ids = {r.get("id") for r in results}
            for chunk in referenced_chunks:
                if chunk.milvus_id not in existing_ids:
                    results.append(
                        {
                            "id": chunk.milvus_id,
                            "document_id": chunk.document_id,
                            "chunk_text": chunk.chunk_text,
                            "text": chunk.chunk_text,
                            "chunk_id": chunk.chunk_id,
                            "chunk_level": chunk.chunk_level,
                            "parent_chunk_id": chunk.parent_chunk_id,
                            "is_parent": chunk.is_parent,
                            "references": chunk.references,
                            "score": 0.7,
                            "is_reference_expanded": True,
                        }
                    )
                    existing_ids.add(chunk.milvus_id)
            logger.info(
                "Fallback reference expansion added %d chunks from %d documents",
                len(referenced_chunks),
                len(referenced_docs),
            )
            return results
        except Exception as e:
            logger.warning("Chunk reference expansion failed (non-fatal): %s", e)
            return results

    def _parent_expand(self, results: list[dict]) -> list[dict]:
        """Swap each child chunk's text for its parent chunk's text.

        The child's original text is kept under ``child_text`` and the entry
        is flagged ``is_expanded``. Parents are batch-loaded in one query
        (the original issued one query per result). Non-fatal on failure.
        """
        try:
            from app.models.knowledge_base import DocumentVector

            parent_ids = {r.get("parent_chunk_id") for r in results if r.get("parent_chunk_id")}
            parent_by_id: dict = {}
            if parent_ids:
                parents = (
                    self.db.query(DocumentVector)
                    .filter(DocumentVector.chunk_id.in_(list(parent_ids)))
                    .all()
                )
                parent_by_id = {p.chunk_id: p for p in parents}
            expanded = []
            for result in results:
                parent = parent_by_id.get(result.get("parent_chunk_id"))
                if parent:
                    result = dict(result)
                    result["child_text"] = result.get("chunk_text", result.get("text", ""))
                    result["text"] = parent.chunk_text
                    result["chunk_text"] = parent.chunk_text
                    result["is_expanded"] = True
                expanded.append(result)
            return expanded
        except Exception as e:
            logger.warning("Parent expand failed (non-fatal): %s", e)
            return results

    def _rerank(self, query: str, results: list[dict], k: int) -> list[dict]:
        """Rerank with the graph reranker; fall back to score order on failure."""
        try:
            from app.services.graph.graph_reranker import get_graph_reranker

            reranker = get_graph_reranker()
            return reranker.rerank(
                results, query, self.tenant_id, self.knowledge_base_id, db=self.db
            )[:k]
        except Exception as e:
            logger.warning("Rerank failed, using score order: %s", e)
            return sorted(results, key=lambda x: x.get("score", 0), reverse=True)[:k]


def get_hybrid_retriever(db: Session, knowledge_base_id: int, tenant_id: int) -> HybridRetriever:
    """Factory: build a HybridRetriever bound to the given session, KB and tenant."""
    retriever = HybridRetriever(db, knowledge_base_id, tenant_id)
    return retriever
