import asyncio
import json
import math
import re
import time
from functools import partial
from typing import Any

from sqlalchemy.orm import Session

from app.config import settings
from app.models.agent import Agent
from app.models.knowledge_base import KnowledgeBase
from app.services.rag.langchain_retrieval import RetrievalService

from common_logging import get_logger
from common_metrics import rag_retrieval_duration, rag_rerank_duration

# Module-level logger, namespaced by this module's dotted path.
logger = get_logger(__name__)




class AgentRAGService:
    """Retrieval-augmented generation (RAG) service bound to a single agent.

    Responsibilities:
      * decide whether a user message warrants knowledge-base retrieval
        (:meth:`should_retrieve`),
      * optionally expand the query into several retrieval queries via an
        LLM intent-analysis step ("agentic" RAG),
      * retrieve from every knowledge base bound to the agent, then merge,
        deduplicate, optionally graph-expand and rerank the results,
      * format the retrieved chunks into an LLM context block plus a list
        of source citations.
    """

    # CJK ideographs (Extension A block + URO block) — used to detect Chinese text.
    _CJK_PATTERN = re.compile(r"[\u3400-\u4dbf\u4e00-\u9fff]")
    # Runs of ASCII letters/digits (acronyms, law codes, English words).
    _ALNUM_TOKEN_PATTERN = re.compile(r"[A-Za-z0-9]+")
    # Greetings / acknowledgements (English and Chinese) that never need retrieval.
    _SMALL_TALK_MESSAGES = {
        "hi",
        "hello",
        "hey",
        "ok",
        "okay",
        "你好",
        "您好",
        "嗨",
        "哈喽",
        "在吗",
        "在么",
        "嗯",
        "嗯嗯",
        "哦",
        "哦哦",
        "好的",
        "收到",
        "谢谢",
        "多谢",
    }
    # Colloquial tax abbreviations rewritten to their formal terms before retrieval.
    _QUERY_REPLACEMENTS = {
        "个税": "个人所得税",
        "个所税": "个人所得税",
        "企税": "企业所得税",
        "研发加计减扣": "研发费用加计扣除",
        "加计减扣": "加计扣除",
        "专票": "增值税专用发票",
        "普票": "增值税普通发票",
    }

    def __init__(self, db: Session, agent: Agent):
        """Bind the service to a DB session and the agent whose knowledge bases are searched.

        Args:
            db: SQLAlchemy session used for all metadata lookups.
            agent: Agent whose ``knowledge_bases`` list scopes retrieval.
        """
        self.db = db
        self.agent = agent
        # Agent.knowledge_bases may be NULL/empty in the DB; normalize to [].
        self.knowledge_base_ids = agent.knowledge_bases or []
        logger.info(
            f"AgentRAGService initialized - Agent: {agent.name}, Knowledge base IDs: {self.knowledge_base_ids}"
        )

    def _normalize_message(self, user_message: str) -> str:
        """Collapse runs of whitespace to single spaces and trim the edges."""
        return re.sub(r"\s+", " ", user_message or "").strip()

    def _normalize_query(self, query: str) -> str:
        """Normalize whitespace, then expand colloquial tax abbreviations."""
        normalized_query = self._normalize_message(query)
        for source, target in self._QUERY_REPLACEMENTS.items():
            normalized_query = normalized_query.replace(source, target)
        return normalized_query

    def _is_small_talk(self, message: str) -> bool:
        """Return True when the message (whitespace removed, lowercased) is a known greeting."""
        compact_message = re.sub(r"\s+", "", message).lower()
        return compact_message in self._SMALL_TALK_MESSAGES

    def should_retrieve(self, user_message: str, min_length: int = 5) -> bool:
        """Heuristically decide whether the message warrants knowledge-base retrieval.

        Retrieval is skipped for agents without bound knowledge bases, empty
        messages, and small talk. Otherwise it is triggered when the message
        is at least ``min_length`` chars, contains >= 2 CJK characters, or
        contains an alphanumeric token of >= 3 chars (e.g. an acronym).
        """
        if not self.knowledge_base_ids:
            logger.info("should_retrieve=False: Agent has no bound knowledge bases")
            return False
        normalized_message = self._normalize_message(user_message)
        if not normalized_message:
            logger.info("should_retrieve=False: Message is empty after normalization")
            return False
        if self._is_small_talk(normalized_message):
            logger.info(
                f"should_retrieve=False: Message '{normalized_message}' identified as small talk"
            )
            return False
        cjk_char_count = len(self._CJK_PATTERN.findall(normalized_message))
        alnum_tokens = self._ALNUM_TOKEN_PATTERN.findall(normalized_message)
        has_meaningful_acronym = any(len(token) >= 3 for token in alnum_tokens)
        should_retrieve = (
            len(normalized_message) >= min_length or cjk_char_count >= 2 or has_meaningful_acronym
        )
        if not should_retrieve:
            logger.info(
                f"should_retrieve=False: message too short after heuristics (message='{normalized_message}', length={len(normalized_message)}, cjk_chars={cjk_char_count}, alnum_tokens={alnum_tokens})"
            )
            return False
        logger.info(
            f"should_retrieve=True: Message='{normalized_message}', length={len(normalized_message)}, cjk_chars={cjk_char_count}, knowledge_base_ids={self.knowledge_base_ids}"
        )
        return True

    async def _analyze_intent(
        self, user_query: str, chat_client, model: str, max_queries: int = 3
    ) -> list[str]:
        """Use the chat model to expand a user query into up to ``max_queries`` retrieval queries.

        Falls back to ``[user_query]`` on any failure (LLM error, malformed
        JSON, or an empty ``queries`` list in the response).
        """
        prompt = f'你是一个税务知识库检索专家。请分析用户的查询意图，生成1-{max_queries}个优化的检索查询。\n\n要求：\n1. 标准化税务术语（如"研发加计减扣"→"研发费用加计扣除"，"个税"→"个人所得税"）\n2. 拆解复合问题（如"高新企业优惠和申报条件"→两个独立查询）\n3. 扩展同义词和相关概念\n4. 保持查询简洁，每个查询聚焦单一主题\n\n用户查询：{user_query}\n\n请以JSON格式返回，包含intent（意图分析）和queries（查询列表）字段。\n示例：{{"intent": "咨询研发费用加计扣除", "queries": ["研发费用加计扣除政策", "研发费用加计扣除计算方法"]}}'
        try:
            response = await chat_client.chat_completion(
                messages=[{"role": "user", "content": prompt}],
                model=model,
                temperature=settings.AGENTIC_RAG_INTENT_TEMPERATURE,
                max_tokens=settings.AGENTIC_RAG_INTENT_MAX_TOKENS,
            )
            text = chat_client.extract_response_text(response).strip()
            # Some models wrap the JSON payload in markdown code fences; strip
            # them so json.loads does not fail spuriously.
            if text.startswith("```"):
                text = re.sub(r"^```[a-zA-Z]*\s*|\s*```$", "", text).strip()
            result = json.loads(text)
            # `or` (not a .get default) also covers an explicit empty list.
            queries = result.get("queries") or [user_query]
            logger.info(
                f"Intent analysis: {result.get('intent', 'N/A')}, Generated {len(queries)} queries"
            )
            return queries[:max_queries]
        except Exception as e:
            logger.warning(f"Intent analysis failed, using original query: {e}")
            return [user_query]

    def _multi_query_retrieve(
        self,
        queries: list[str],
        kb_id: int,
        kb,
        top_k: int,
        threshold: float,
        mode: str,
        model_id: int | None,
    ) -> list[dict[str, Any]]:
        """Run every query against one knowledge base; tag and collect the hits.

        Each result dict is annotated with the KB id/name and the query that
        produced it. Per-query failures are logged and skipped so one bad
        query cannot sink the whole batch.
        """
        all_results: list[dict[str, Any]] = []
        for query in queries:
            try:
                retrieval_service = RetrievalService(self.db, kb_id)
                started = time.perf_counter()
                results = retrieval_service.retrieve(
                    query=query, mode=mode, k=top_k, threshold=threshold, model_id=model_id
                )
                # Metric handles may be None when metrics are disabled.
                if rag_retrieval_duration:
                    rag_retrieval_duration.labels(mode=mode).observe(time.perf_counter() - started)
                for result in results:
                    result["knowledge_base_id"] = kb_id
                    result["knowledge_base_name"] = kb.name
                    result["source_query"] = query
                    all_results.append(result)
                logger.info(f"Query '{query}' retrieved {len(results)} results from {kb.name}")
            except Exception as e:
                logger.error(f"Failed to retrieve for query '{query}': {e}")
                continue
        return all_results

    def _merge_and_deduplicate(
        self, all_results: list[dict[str, Any]], top_k: int
    ) -> list[dict[str, Any]]:
        """Deduplicate by (document_id, chunk_index), keeping the highest score.

        Returns the top ``top_k`` unique results sorted by descending score.
        """
        seen: dict[tuple, dict[str, Any]] = {}
        for result in all_results:
            key = (result.get("document_id"), result.get("chunk_index"))
            if key not in seen:
                seen[key] = result
            elif result.get("score", 0) > seen[key].get("score", 0):
                seen[key] = result
        unique_results = list(seen.values())
        unique_results.sort(key=lambda x: x.get("score", 0), reverse=True)
        logger.info(f"Deduplication: {len(all_results)} → {len(unique_results)} unique results")
        return unique_results[:top_k]

    def _get_accessible_kb(self, kb_id: int, tenant_id: int | None, user_role: str | None):
        """Load an enabled knowledge base the caller may access, or None.

        Platform admins may read any enabled KB; other roles are restricted
        to their own tenant and must supply ``tenant_id``. Returns None (with
        a warning logged) when the KB is missing, disabled, or inaccessible.
        """
        if user_role == "platform_admin":
            kb_query = self.db.query(KnowledgeBase).filter(
                KnowledgeBase.id == kb_id, KnowledgeBase.status == "enabled"
            )
        else:
            if tenant_id is None:
                logger.warning("tenant_id is required for non-admin users")
                return None
            kb_query = self.db.query(KnowledgeBase).filter(
                KnowledgeBase.id == kb_id,
                KnowledgeBase.status == "enabled",
                KnowledgeBase.tenant_id == tenant_id,
            )
        kb = kb_query.first()
        if not kb:
            logger.warning(
                f"Knowledge base {kb_id} does not exist, is disabled, or access denied, skipping"
            )
            return None
        return kb

    def _resolve_model_id(self, kb) -> int | None:
        """Resolve the embedding model id configured on the KB via its ``code``.

        ``kb.code`` may hold either a numeric model id or a model code; both
        are looked up in the provider registry. Returns None when unset or
        unresolved, in which case the retriever uses its default model.
        """
        if not kb.code:
            return None
        from app.models.provider import Model

        if kb.code.isdigit():
            model = self.db.query(Model).filter(Model.id == int(kb.code)).first()
        else:
            model = self.db.query(Model).filter(Model.code == kb.code).first()
        if model:
            logger.info(
                f"Knowledge base {kb.name} using vector model: {kb.code} (ID: {model.id})"
            )
            return model.id
        logger.warning(
            f"Vector model {kb.code} not found, will use default configuration"
        )
        return None

    def _expand_with_graph(
        self, all_results: list[dict[str, Any]], tenant_id: int | None
    ) -> None:
        """Best-effort expansion: append knowledge-graph neighbor documents.

        Follows REFERENCES/SUPERSEDES edges (depth 1, limit 5 per KB) from
        the already-retrieved documents and appends neighbors not yet in
        ``all_results``. Mutates the list in place; any failure is logged as
        a warning and ignored (graph data is an optional enrichment).
        """
        try:
            from app.services.graph.graph_query import get_graph_query_service

            doc_ids = list({r["document_id"] for r in all_results if r.get("document_id")})
            if not (doc_ids and self.knowledge_base_ids):
                return
            graph_query = get_graph_query_service()
            for kb_id in self.knowledge_base_ids:
                graph_results = graph_query.expand_neighbors(
                    document_ids=doc_ids,
                    tenant_id=tenant_id,
                    kb_id=kb_id,
                    relation_types=["REFERENCES", "SUPERSEDES"],
                    depth=1,
                    limit=5,
                )
                # Recomputed per KB: earlier iterations may have appended docs.
                existing_doc_ids = {
                    r["document_id"] for r in all_results if r.get("document_id")
                }
                for gr in graph_results:
                    gr_doc_id = gr.get("id")
                    if gr_doc_id and gr_doc_id not in existing_doc_ids:
                        all_results.append(
                            {
                                "document_id": gr_doc_id,
                                "title": gr.get("title", ""),
                                "text": gr.get("summary", ""),
                                "chunk_text": gr.get("summary", ""),
                                # Graph hits have no vector score; 0.5 is a neutral default.
                                "score": gr.get("score", 0.5),
                                "source": "graph",
                            }
                        )
                        existing_doc_ids.add(gr_doc_id)
        except Exception as e:
            logger.warning(f"Graph expansion failed (non-fatal): {e}")

    async def _apply_reranker(
        self,
        query: str,
        all_results: list[dict[str, Any]],
        top_k: int,
        enable_reranker: bool,
        loop: asyncio.AbstractEventLoop,
    ) -> list[dict[str, Any]]:
        """Rerank candidates down to ``top_k`` via BGE reranker, or truncate.

        Runs the (blocking) rerank call in the default executor. On any
        failure, or when the rerank model cannot be resolved, falls back to
        simple score-order truncation.
        """
        if enable_reranker and len(all_results) > top_k:
            try:
                from app.models.provider import Model
                from app.services.llm.backends.rerank_backend_factory import get_rerank_factory

                logger.info(f"Starting BGE Reranker: {len(all_results)} candidates -> top {top_k}")
                rerank_model_key = settings.BGE_RERANKER_MODEL
                # The setting may be a numeric model id or a model code.
                if str(rerank_model_key).isdigit():
                    rerank_model = (
                        self.db.query(Model).filter(Model.id == int(rerank_model_key)).first()
                    )
                else:
                    rerank_model = (
                        self.db.query(Model).filter(Model.code == rerank_model_key).first()
                    )
                if rerank_model:
                    rerank_factory = get_rerank_factory()
                    started = time.perf_counter()
                    all_results = await loop.run_in_executor(
                        None,
                        partial(
                            rerank_factory.rerank,
                            query=query,
                            documents=all_results,
                            model_id=rerank_model.id,
                            db=self.db,
                            top_k=top_k,
                        ),
                    )
                    if rag_rerank_duration:
                        rag_rerank_duration.observe(time.perf_counter() - started)
                else:
                    logger.warning(
                        f"Rerank model not found in provider registry: {settings.BGE_RERANKER_MODEL}"
                    )
                    all_results = all_results[:top_k]
                logger.info(f"Reranking completed: {len(all_results)} results returned")
            except Exception as e:
                logger.error(f"Reranker failed, falling back to original results: {e}")
                all_results = all_results[:top_k]
        elif len(all_results) > top_k:
            all_results = all_results[:top_k]
        return all_results

    async def retrieve_context(
        self,
        query: str,
        top_k: int = 5,
        threshold: float = 0.5,
        mode: str = "hybrid",
        tenant_id: int | None = None,
        user_role: str | None = None,
        enable_agentic: bool = True,
        chat_client=None,
        chat_model: str | None = None,
        use_reranker: bool = True,
    ) -> dict[str, Any]:
        """Retrieve, dedupe, graph-expand, and rerank context for ``query``.

        Args:
            query: Raw user query; normalized before retrieval.
            top_k: Number of results to return after reranking/truncation.
            threshold: Minimum retrieval score passed to the retriever.
            mode: Retrieval mode (e.g. "hybrid").
            tenant_id: Required for non-admin users to scope KB access.
            user_role: "platform_admin" bypasses the tenant restriction.
            enable_agentic: Allow LLM intent analysis (also gated by settings
                and the presence of ``chat_client``/``chat_model``).
            chat_client: LLM client used for intent analysis.
            chat_model: Model name used for intent analysis.
            use_reranker: Allow BGE reranking (also gated by settings).

        Returns:
            Dict with ``context_text``, ``sources``, ``retrieved_count``, and
            ``knowledge_bases`` (names of the KBs actually queried).
        """
        use_agentic = (
            enable_agentic
            and settings.ENABLE_AGENTIC_RAG
            and (chat_client is not None)
            and (chat_model is not None)
        )
        enable_reranker = use_reranker and settings.ENABLE_BGE_RERANKER
        # Over-fetch when reranking so the reranker has a candidate pool,
        # capped by the configured initial-K.
        initial_k = (
            min(settings.BGE_RERANKER_INITIAL_K, max(top_k * 3, top_k))
            if enable_reranker
            else top_k
        )
        logger.info(
            f"Retrieval strategy: initial_k={initial_k}, final_k={top_k}, reranker_enabled={enable_reranker}"
        )
        normalized_query = self._normalize_query(query)
        if use_agentic and settings.AGENTIC_RAG_MAX_QUERIES > 1:
            queries = await self._analyze_intent(
                normalized_query, chat_client, chat_model, settings.AGENTIC_RAG_MAX_QUERIES
            )
        else:
            queries = [normalized_query]
        all_results: list[dict[str, Any]] = []
        knowledge_base_names: dict[int, str] = {}
        # We are inside a coroutine, so the running loop is the right handle
        # (get_event_loop() is deprecated here since Python 3.10).
        loop = asyncio.get_running_loop()
        for kb_id in self.knowledge_base_ids:
            try:
                kb = self._get_accessible_kb(kb_id, tenant_id, user_role)
                if not kb:
                    continue
                knowledge_base_names[kb_id] = kb.name
                model_id = self._resolve_model_id(kb)
                # Retrieval is blocking (DB + vector search); off-load it.
                kb_results = await loop.run_in_executor(
                    None,
                    partial(
                        self._multi_query_retrieve,
                        queries=queries,
                        kb_id=kb_id,
                        kb=kb,
                        top_k=initial_k,
                        threshold=threshold,
                        mode=mode,
                        model_id=model_id,
                    ),
                )
                all_results.extend(kb_results)
                logger.info(f"Retrieved {len(kb_results)} results from knowledge base {kb.name}")
            except Exception as e:
                logger.error(f"Failed to retrieve from knowledge base {kb_id}: {e}")
                continue
        all_results = self._merge_and_deduplicate(all_results, initial_k)
        if settings.ENABLE_KNOWLEDGE_GRAPH and len(all_results) > 0:
            self._expand_with_graph(all_results, tenant_id)
        # NOTE: reranking scores against the raw user query, not the
        # normalized/expanded one, matching previous behavior.
        all_results = await self._apply_reranker(
            query, all_results, top_k, enable_reranker, loop
        )
        context_text = self._format_context(all_results)
        sources = self._format_sources(all_results)
        return {
            "context_text": context_text,
            "sources": sources,
            "retrieved_count": len(all_results),
            "knowledge_bases": list(knowledge_base_names.values()),
        }

    def _format_context(self, results: list[dict[str, Any]]) -> str:
        """Render retrieved chunks as a numbered reference block for the LLM prompt."""
        if not results:
            return ""
        context_parts = ["以下是从知识库中检索到的相关信息：\n"]
        for i, result in enumerate(results, 1):
            kb_name = result.get("knowledge_base_name", "未知知识库")
            title = result.get("title", "无标题")
            text = result.get("text", "")
            score = result.get("score", 0)
            context_parts.append(
                f"【参考资料 {i}】\n来源：{kb_name} - {title}\n相关度：{score:.2f}\n内容：{text}\n"
            )
        context_parts.append(
            "\n请基于以上参考资料回答用户问题。如果参考资料中没有相关信息，请使用你的通用知识回答。"
        )
        return "\n".join(context_parts)

    def _format_sources(self, results: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Build citation dicts for the API response, one per unique document.

        Enriches each source with the data-center document id and, when the
        data-center DB is reachable, the original ``source_url``. Rerank
        scores (raw logits) are squashed through a sigmoid for display.
        """
        from app.db.session import get_data_center_db
        from app.models.knowledge_base import KnowledgeDocument

        doc_ids = [r.get("document_id") for r in results if r.get("document_id")]
        dc_id_map: dict[Any, Any] = {}
        if doc_ids and self.db:
            rows = (
                self.db.query(KnowledgeDocument.id, KnowledgeDocument.data_center_doc_id)
                .filter(KnowledgeDocument.id.in_(doc_ids))
                .all()
            )
            dc_id_map = {row.id: row.data_center_doc_id for row in rows}
        source_url_map: dict[Any, Any] = {}
        dc_ids = [v for v in dc_id_map.values() if v]
        if dc_ids:
            # Best-effort enrichment — a missing data-center DB must not
            # break source formatting.
            try:
                with get_data_center_db() as dc_db:
                    if dc_db:
                        from sqlalchemy import text

                        rows = dc_db.execute(
                            text("SELECT id, source_url FROM tax_documents WHERE id = ANY(:ids)"),
                            {"ids": dc_ids},
                        ).fetchall()
                        source_url_map = {row.id: row.source_url for row in rows}
            except Exception as e:
                logger.warning(f"Failed to query data_center source_url: {e}")
        sources: list[dict[str, Any]] = []
        seen_doc_ids: set = set()
        for result in results:
            doc_id = result.get("document_id")
            # Emit at most one citation per document.
            if doc_id and doc_id in seen_doc_ids:
                continue
            if doc_id:
                seen_doc_ids.add(doc_id)
            rerank_score = result.get("rerank_score")
            # Sigmoid maps the reranker's unbounded logit into (0, 1).
            display_score = (
                1 / (1 + math.exp(-rerank_score))
                if rerank_score is not None
                else result.get("score")
            )
            sources.append(
                {
                    "document_id": doc_id,
                    "data_center_doc_id": dc_id_map.get(doc_id),
                    "title": result.get("title"),
                    "score": display_score,
                    "raw_score": result.get("score"),
                    "rerank_score": rerank_score,
                    "knowledge_base_id": result.get("knowledge_base_id"),
                    "knowledge_base_name": result.get("knowledge_base_name"),
                    "chunk_index": result.get("chunk_index"),
                    "source_query": result.get("source_query"),
                    "reference_url": result.get("reference_url"),
                    "reference_source_url": source_url_map.get(dc_id_map.get(doc_id)),
                }
            )
        return sources

    def enhance_system_prompt(self, original_prompt: str, has_context: bool = False) -> str:
        """Append knowledge-base usage instructions to the system prompt.

        Returns the prompt unchanged when no retrieved context is present.
        """
        if not has_context:
            return original_prompt
        rag_instruction = "\n\n【知识库增强模式】\n你已接入专业知识库。在回答问题时：\n1. 优先使用知识库中提供的参考资料\n2. 如果参考资料充分，请基于资料给出准确回答\n3. 如果参考资料不足，可以结合你的通用知识补充\n4. 回答时保持专业、准确、简洁\n5. 必要时可以引用具体的参考资料来源"
        return original_prompt + rag_instruction


def get_agent_rag_service(db: Session, agent: Agent) -> AgentRAGService:
    """Factory helper: construct an ``AgentRAGService`` for the given agent."""
    service = AgentRAGService(db=db, agent=agent)
    return service
