import time
from typing import Any

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
from sqlalchemy import text
from sqlalchemy.orm import Session

from app.services.llm.backends.embedding_backend_factory import get_embedding_factory
from app.services.storage.vector_store_factory import get_vector_store
from common_logging import get_logger

logger = get_logger(__name__)


class RetrievalService:
    """Retrieves knowledge-base chunks via semantic, keyword, or hybrid search."""

    def __init__(self, db: Session, knowledge_base_id: int | None = None):
        self.db = db
        self.knowledge_base_id = knowledge_base_id
        # Both factories resolve their concrete backend from configuration.
        self.vector_store = get_vector_store(db, knowledge_base_id)
        self.embedding_factory = get_embedding_factory()

    def retrieve(
        self,
        query: str,
        mode: str = "semantic",
        k: int = 5,
        threshold: float = 0.7,
        model_id: int | None = None,
        **kwargs,
    ) -> list[dict[str, Any]]:
        """Dispatch a retrieval request to the strategy selected by ``mode``.

        Supported modes: ``semantic`` (vector similarity), ``keyword``
        (full-text search), and ``hybrid`` (a weighted combination of both,
        configured via the ``keyword_weight`` / ``semantic_weight`` kwargs).
        """
        try:
            if mode == "semantic":
                return self._semantic_search(query, k, threshold, model_id)
            elif mode == "keyword":
                return self._keyword_search(query, k)
            elif mode == "hybrid":
                keyword_weight = kwargs.get("keyword_weight", 0.3)
                semantic_weight = kwargs.get("semantic_weight", 0.7)
                return self._hybrid_search(
                    query, k, threshold, model_id, keyword_weight, semantic_weight
                )
            else:
                raise ValueError(f"Unsupported retrieval mode: {mode}")
        except Exception as e:
            logger.error(f"Retrieval failed: {e}")
            raise
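
    # Illustrative usage (assumes a SQLAlchemy session and a populated
    # knowledge base; the id and weights below are made up for the example):
    #
    #   service = RetrievalService(db, knowledge_base_id=1)
    #   docs = service.retrieve(
    #       "how do I reset my password?",
    #       mode="hybrid",
    #       k=10,
    #       keyword_weight=0.4,
    #       semantic_weight=0.6,
    #   )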

    def _semantic_search(
        self, query: str, k: int, threshold: float, model_id: int | None
    ) -> list[dict[str, Any]]:
        """Embed the query and run a vector similarity search."""
        query_embedding = self.embedding_factory.generate_embedding(
            text=query, db=self.db, model_id=model_id
        )
        if not query_embedding:
            raise ValueError("Failed to generate query embedding")
        results = self.vector_store.similarity_search(
            query_embedding=query_embedding, k=k, threshold=threshold
        )
        return self._format_vector_results(results)

    @staticmethod
    def _format_vector_results(
        results: list[tuple[dict[str, Any], float]],
    ) -> list[dict[str, Any]]:
        """Normalize (doc, score) pairs from the vector store into result dicts."""
        return [
            {
                "id": doc["id"],
                "document_id": doc["document_id"],
                "title": doc["title"],
                "text": doc["text"],
                "chunk_index": doc["chunk_index"],
                "score": score,
                "metadata": {
                    "category_id": doc["category_id"],
                    "model_name": doc["model_name"],
                },
            }
            for doc, score in results
        ]

    def _keyword_search(self, query: str, k: int) -> list[dict[str, Any]]:
        """Run a PostgreSQL full-text search over titles (weight A) and chunk text (weight B)."""
        match_clause = (
            "(setweight(to_tsvector('simple', COALESCE(kd.title, '')), 'A') || "
            "setweight(to_tsvector('simple', COALESCE(dv.chunk_text, '')), 'B')) "
            "@@ plainto_tsquery('simple', :query)"
        )
        where_clauses = [match_clause]
        params: dict[str, Any] = {"query": query, "k": k}
        if self.knowledge_base_id is not None:
            where_clauses.append("dv.knowledge_base_id = :knowledge_base_id")
            params["knowledge_base_id"] = self.knowledge_base_id
        keyword_query = text(
            f"""
            SELECT
                dv.id,
                dv.document_id,
                dv.chunk_index,
                dv.chunk_text,
                dv.model_name,
                kd.title,
                kd.category_id,
                ts_rank(
                    setweight(to_tsvector('simple', COALESCE(kd.title, '')), 'A') ||
                    setweight(to_tsvector('simple', COALESCE(dv.chunk_text, '')), 'B'),
                    plainto_tsquery('simple', :query)
                ) AS score
            FROM document_vectors dv
            JOIN knowledge_documents kd ON dv.document_id = kd.id
            WHERE {' AND '.join(where_clauses)}
            ORDER BY score DESC
            LIMIT :k
            """
        )
        result = self.db.execute(keyword_query, params)
        return [
            {
                "id": row.id,
                "document_id": row.document_id,
                "title": row.title,
                "text": row.chunk_text,
                "chunk_index": row.chunk_index,
                "score": float(row.score),
                "metadata": {"category_id": row.category_id, "model_name": row.model_name},
            }
            for row in result
        ]
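
    # Note: the query above relies on PostgreSQL full-text search. On large
    # tables it typically needs expression GIN indexes to stay fast; the DDL
    # below is an illustrative sketch (index names are made up), not something
    # this service creates:
    #
    #   CREATE INDEX ix_kd_title_fts ON knowledge_documents
    #       USING gin (to_tsvector('simple', COALESCE(title, '')));
    #   CREATE INDEX ix_dv_chunk_text_fts ON document_vectors
    #       USING gin (to_tsvector('simple', COALESCE(chunk_text, '')));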

    def _hybrid_search(
        self,
        query: str,
        k: int,
        threshold: float,
        model_id: int | None,
        keyword_weight: float,
        semantic_weight: float,
    ) -> list[dict[str, Any]]:
        """Embed the query and let the vector store combine keyword and vector scores."""
        query_embedding = self.embedding_factory.generate_embedding(
            text=query, db=self.db, model_id=model_id
        )
        if not query_embedding:
            raise ValueError("Failed to generate query embedding")
        results = self.vector_store.hybrid_search(
            query=query,
            query_embedding=query_embedding,
            k=k,
            threshold=threshold,
            keyword_weight=keyword_weight,
            semantic_weight=semantic_weight,
        )
        return self._format_vector_results(results)
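
    # The actual score fusion happens inside the vector store backend; a common
    # scheme (an assumption here, not necessarily what the backend implements)
    # is a weighted sum of normalized scores:
    #
    #   combined = semantic_weight * semantic_score + keyword_weight * keyword_score
    #
    # With the defaults (0.7 / 0.3) semantic similarity dominates the ranking.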

    def recall_test(
        self,
        query: str,
        mode: str = "semantic",
        k: int = 5,
        threshold: float = 0.7,
        model_id: int | None = None,
        **kwargs,
    ) -> dict[str, Any]:
        """Run a single retrieval and report score statistics plus elapsed time."""
        try:
            # perf_counter is monotonic, so it is the right clock for durations.
            start_time = time.perf_counter()
            results = self.retrieve(
                query=query, mode=mode, k=k, threshold=threshold, model_id=model_id, **kwargs
            )
            elapsed_time = time.perf_counter() - start_time
            scores = [r["score"] for r in results]
            avg_score = sum(scores) / len(scores) if scores else 0.0
            max_score = max(scores) if scores else 0.0
            min_score = min(scores) if scores else 0.0
            return {
                "query": query,
                "mode": mode,
                "params": {"k": k, "threshold": threshold, "model_id": model_id, **kwargs},
                "results": results,
                "statistics": {
                    "total_results": len(results),
                    "avg_score": avg_score,
                    "max_score": max_score,
                    "min_score": min_score,
                    "elapsed_time": elapsed_time,
                },
            }
        except Exception as e:
            logger.error(f"Recall test failed: {e}")
            raise
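
    # Illustrative recall-test sweep (the query and thresholds are arbitrary
    # examples) for comparing retrieval settings against a known-good query:
    #
    #   service = RetrievalService(db, knowledge_base_id=1)
    #   for threshold in (0.5, 0.6, 0.7, 0.8):
    #       report = service.recall_test("refund policy", threshold=threshold)
    #       stats = report["statistics"]
    #       print(threshold, stats["total_results"], stats["avg_score"])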


class RAGService:
    """Answers questions with retrieval-augmented generation over a knowledge base."""

    def __init__(
        self,
        db: Session,
        knowledge_base_id: int | None = None,
        llm_api_key: str | None = None,
        llm_base_url: str | None = None,
        llm_model: str = "gpt-3.5-turbo",
    ):
        self.db = db
        self.knowledge_base_id = knowledge_base_id
        self.retrieval_service = RetrievalService(db, knowledge_base_id)
        # Without an API key the service can still retrieve, but not generate.
        if llm_api_key:
            self.llm = ChatOpenAI(
                api_key=llm_api_key, base_url=llm_base_url, model=llm_model, temperature=0.7
            )
        else:
            self.llm = None

    def answer_question(
        self,
        question: str,
        mode: str = "semantic",
        k: int = 5,
        threshold: float = 0.7,
        model_id: int | None = None,
        system_prompt: str | None = None,
    ) -> dict[str, Any]:
        """Retrieve relevant chunks and generate a grounded answer with the LLM."""
        try:
            if not self.llm:
                raise ValueError("LLM not configured, cannot generate answer")
            retrieved_docs = self.retrieval_service.retrieve(
                query=question, mode=mode, k=k, threshold=threshold, model_id=model_id
            )
            if not retrieved_docs:
                return {
                    "answer": "Sorry, I could not find relevant information in the knowledge base.",
                    "sources": [],
                    "retrieved_docs": [],
                }
            context = "\n\n".join(
                f"Document {i + 1} (Source: {doc['title']}):\n{doc['text']}"
                for i, doc in enumerate(retrieved_docs)
            )
            default_system_prompt = (
                "You are a professional knowledge base assistant. Please answer "
                "the user's question based on the provided document content.\n\n"
                "Requirements:\n"
                "1. Only use the provided document content to answer questions\n"
                "2. If there is no relevant information in the documents, clearly inform the user\n"
                "3. Answers should be accurate, concise, and professional\n"
                "4. If possible, cite specific document sources"
            )
            prompt = ChatPromptTemplate.from_messages(
                [
                    ("system", system_prompt or default_system_prompt),
                    ("human", "Reference documents:\n{context}\n\nUser question: {question}"),
                ]
            )
            # The context is already materialized above, so the chain only needs
            # to thread the question through; the lambda ignores its input.
            rag_chain = (
                {"context": lambda _: context, "question": RunnablePassthrough()}
                | prompt
                | self.llm
                | StrOutputParser()
            )
            answer = rag_chain.invoke(question)
            return {
                "answer": answer,
                "sources": [
                    {
                        "document_id": doc["document_id"],
                        "title": doc["title"],
                        # Truncate long chunks for the source preview only.
                        "text": doc["text"][:200] + ("..." if len(doc["text"]) > 200 else ""),
                        "score": doc["score"],
                    }
                    for doc in retrieved_docs
                ],
                "retrieved_docs": retrieved_docs,
            }
        except Exception as e:
            logger.error(f"RAG Q&A failed: {e}")
            raise


def get_retrieval_service(db: Session, knowledge_base_id: int | None = None) -> RetrievalService:
    """Factory for a RetrievalService bound to a session and optional knowledge base."""
    return RetrievalService(db, knowledge_base_id)


def get_rag_service(
    db: Session,
    knowledge_base_id: int | None = None,
    llm_api_key: str | None = None,
    llm_base_url: str | None = None,
    llm_model: str = "gpt-3.5-turbo",
) -> RAGService:
    """Factory for a RAGService; without LLM credentials it is retrieval-only."""
    return RAGService(db, knowledge_base_id, llm_api_key, llm_base_url, llm_model)
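
# Illustrative end-to-end wiring (the session helper `get_session` and the
# credentials below are placeholders, not part of this module):
#
#   from app.db.session import get_session  # hypothetical import
#
#   with get_session() as db:
#       rag = get_rag_service(
#           db,
#           knowledge_base_id=1,
#           llm_api_key="sk-...",
#           llm_base_url=None,  # default OpenAI endpoint
#           llm_model="gpt-3.5-turbo",
#       )
#       result = rag.answer_question("How do I request a refund?", mode="hybrid")
#       print(result["answer"])
#       for source in result["sources"]:
#           print(source["title"], source["score"])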
