from typing import Any

from sqlalchemy.orm import Session

from common_logging import get_logger

logger = get_logger(__name__)




class PropositionRetriever:
    """Search a knowledge base for child chunks and optionally expand them to parents.

    A query is embedded, the knowledge base's vector store is searched for
    results whose ``is_parent`` flag is ``False``, and each hit can then have
    its text replaced by the text of the chunk named in ``parent_chunk_id``.
    """

    def __init__(self, db: Session, knowledge_base_id: int, tenant_id: int):
        """Store the request context and build the backing services.

        Args:
            db: Session used for every database lookup made by this object.
            knowledge_base_id: ID of the knowledge base to search.
            tenant_id: Kept on the instance; no method of this class reads it.
        """
        self.db = db
        self.knowledge_base_id = knowledge_base_id
        self.tenant_id = tenant_id
        self._init_services()

    def _init_services(self):
        """Create the vector store and embedding factory used by ``retrieve``."""
        # Imported at call time rather than module level
        # (presumably to avoid an import cycle -- TODO confirm).
        from app.services.llm.backends.embedding_backend_factory import get_embedding_factory
        from app.services.storage.vector_store_factory import get_vector_store

        self.vector_store = get_vector_store(self.db, self.knowledge_base_id)
        self.embedding_factory = get_embedding_factory()

    def retrieve(
        self, query: str, k: int = 5, return_parent: bool = True, model_id: int = None
    ) -> list[dict[str, Any]]:
        """Return up to ``k`` child chunks similar to ``query``.

        Args:
            query: Text to embed and search for. A blank query returns ``[]``
                without calling the embedding backend.
            k: Number of results requested from the vector store.
            return_parent: When true, pass the hits through
                ``_expand_with_parent`` before returning them.
            model_id: Embedding model to use. When ``None`` it is read from the
                knowledge base record's ``code`` field.

        Returns:
            The vector store's result dicts (expanded when ``return_parent`` is
            true). An empty list is returned when the query is blank, no model
            id can be determined, the embedding is empty, nothing matches, or
            any exception is raised (the exception is logged, not propagated).
        """
        try:
            # A blank query has nothing to embed; skip the model lookup and the
            # embedding / vector-store round trips.
            if not query or not query.strip():
                return []
            if model_id is None:
                model_id = self._resolve_model_id()
                if model_id is None:
                    return []
            embedding = self.embedding_factory.generate_embedding(
                text=query, model_id=model_id, db=self.db
            )
            if not embedding:
                logger.error("Failed to generate embedding")
                return []
            child_results = self.vector_store.similarity_search(
                query_embedding=embedding, k=k, filter_dict={"is_parent": False}
            )
            if not child_results:
                return []
            if return_parent:
                return self._expand_with_parent(child_results)
            return child_results
        except Exception as e:
            # Deliberate best-effort: callers always get a list back.
            logger.error(f"Proposition retrieval failed: {e}")
            return []

    def _resolve_model_id(self):
        """Read the embedding model id from the knowledge base's ``code`` field.

        Returns:
            The model id as an ``int``, or ``None`` (after logging an error)
            when the record is missing, has no ``code``, or the ``code`` cannot
            be converted to an integer.
        """
        from app.models.knowledge_base import KnowledgeBase

        kb = (
            self.db.query(KnowledgeBase)
            .filter(KnowledgeBase.id == self.knowledge_base_id)
            .first()
        )
        if not (kb and kb.code):
            logger.error(
                f"No model_id configured for knowledge base {self.knowledge_base_id}"
            )
            return None
        try:
            return int(kb.code)
        except (TypeError, ValueError):
            logger.error(
                f"Knowledge base {self.knowledge_base_id} has a non-integer "
                f"model code: {kb.code!r}"
            )
            return None

    def _load_parent(self, parent_chunk_id):
        """Fetch the first ``DocumentVector`` row whose ``chunk_id`` matches, or ``None``."""
        # Imported here so that expanding results without any parent id never
        # needs the model module or the database.
        from app.models.knowledge_base import DocumentVector

        return (
            self.db.query(DocumentVector)
            .filter(DocumentVector.chunk_id == parent_chunk_id)
            .first()
        )

    def _expand_with_parent(self, child_results: list[dict]) -> list[dict]:
        """Return copies of ``child_results`` with parent chunk text swapped in.

        For a result whose ``parent_chunk_id`` resolves to a row, the copy keeps
        the original text under ``child_text``, gets the parent's ``chunk_text``
        under both ``text`` and ``chunk_text``, has ``is_expanded`` set to
        ``True`` and ``chunk_level`` defaulted to ``"unknown"`` if absent.
        Every other result is copied with ``is_expanded`` set to ``False``.
        The input dicts are never modified.
        """
        # Several children commonly share one parent; remember each lookup
        # (including misses) so every parent id is queried at most once.
        parent_cache = {}
        expanded = []
        for child in child_results:
            enriched = dict(child)
            parent_chunk_id = enriched.get("parent_chunk_id")
            parent = None
            if parent_chunk_id:
                if parent_chunk_id not in parent_cache:
                    parent_cache[parent_chunk_id] = self._load_parent(parent_chunk_id)
                parent = parent_cache[parent_chunk_id]
            if parent:
                enriched["child_text"] = enriched.get("chunk_text", enriched.get("text", ""))
                enriched["text"] = parent.chunk_text
                enriched["chunk_text"] = parent.chunk_text
                enriched["is_expanded"] = True
                enriched["chunk_level"] = enriched.get("chunk_level", "unknown")
            else:
                enriched["is_expanded"] = False
            expanded.append(enriched)
        return expanded


def get_proposition_retriever(
    db: Session, knowledge_base_id: int, tenant_id: int
) -> PropositionRetriever:
    """Build a ``PropositionRetriever`` for the given session, knowledge base and tenant.

    Args:
        db: Session handed to the retriever.
        knowledge_base_id: ID of the knowledge base the retriever will search.
        tenant_id: Tenant identifier handed to the retriever.

    Returns:
        A newly constructed ``PropositionRetriever``.
    """
    retriever = PropositionRetriever(
        db=db,
        knowledge_base_id=knowledge_base_id,
        tenant_id=tenant_id,
    )
    return retriever
