from datetime import datetime

from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session

from app.api.deps import get_db
from app.api.permissions import require_read
from app.config import settings
from app.core.exceptions import KnowledgeBaseNotFoundError
from app.core.i18n import get_translator
from app.models import KnowledgeBase, User
from app.services.rag.graph_enhanced_retrieval import get_graph_enhanced_retrieval
from app.services.rag.langchain_retrieval import get_retrieval_service

from common_logging import get_logger

logger = get_logger(__name__)

router = APIRouter(tags=["knowledge-recall"])


class RecallTestRequest(BaseModel):
    """Request payload for a knowledge-base recall (retrieval) test.

    The ``description`` strings feed the generated OpenAPI schema and are
    kept verbatim.
    """

    # Required free-text query to retrieve against the knowledge base.
    query: str = Field(..., description="查询文本")
    # Retrieval strategy; "graph_enhanced" (or enable_graph=True) routes to
    # the graph-enhanced retriever when the knowledge-graph feature is on.
    mode: str = Field(
        default="semantic", description="检索模式: semantic, keyword, hybrid, graph_enhanced"
    )
    top_k: int = Field(default=5, ge=1, le=50, description="返回结果数量")
    threshold: float = Field(default=0.7, ge=0.0, le=1.0, description="相似度阈值")
    model_id: int | None = Field(default=None, description="向量模型ID")
    # The two weights are only consulted in hybrid mode; the code does not
    # require them to sum to 1.
    keyword_weight: float | None = Field(
        default=0.3, ge=0.0, le=1.0, description="关键词权重（混合模式）"
    )
    semantic_weight: float | None = Field(
        default=0.7, ge=0.0, le=1.0, description="语义权重（混合模式）"
    )
    enable_graph: bool = Field(default=False, description="启用图增强检索")
    graph_expand_depth: int = Field(default=1, ge=1, le=2, description="图扩展深度（1-2跳）")
    graph_relation_types: list[str] | None = Field(default=None, description="图关系类型过滤")
    # NOTE: these two fields were previously declared twice with identical
    # definitions; the redundant redeclarations were removed (no behavior
    # change — a later class-body assignment simply overwrote the earlier one).
    enable_parent_expansion: bool = Field(
        default=True, description="启用父块扩展（提供完整上下文）"
    )
    enable_reference_expansion: bool = Field(
        default=True, description="启用引用扩展（自动关联引用条款）"
    )


class RecallTestResponse(BaseModel):
    """Response envelope for a recall-test run.

    Mirrors the dicts returned by the branches of ``recall_test``; the exact
    shape of ``params`` and ``statistics`` varies with the retrieval mode.
    """

    query: str
    mode: str
    params: dict
    results: list[dict]
    statistics: dict
    # Server-local creation time of the response (naive datetime).
    timestamp: datetime = Field(default_factory=datetime.now)


def _recall_score_stats(scores: list[float]) -> tuple[float, float, float]:
    """Return ``(avg, max, min)`` over *scores*; all zeros for an empty list."""
    if not scores:
        return 0, 0, 0
    return sum(scores) / len(scores), max(scores), min(scores)


@router.post("/{base_id}/test", response_model=RecallTestResponse)
@router.post("/{base_id}/recall-test", response_model=RecallTestResponse)
def recall_test(
    base_id: int,
    test_request: RecallTestRequest,
    request: Request,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_read("knowledge_bases")),
):
    """Run a retrieval test against knowledge base ``base_id``.

    Dispatches to one of three retrieval paths, in priority order:
    1. graph-enhanced retrieval (``mode == "graph_enhanced"`` or
       ``enable_graph`` is set) — requires ``ENABLE_KNOWLEDGE_GRAPH``;
    2. hybrid retrieval when parent-chunk or reference expansion is enabled;
    3. the plain retrieval service (semantic / keyword / hybrid modes).

    Raises:
        KnowledgeBaseNotFoundError: unknown base id, or base belongs to a
            different tenant.
        HTTPException: 400 when graph retrieval is requested but the feature
            is disabled; 500 for any unexpected retrieval failure.
    """
    import time

    t = get_translator(request)
    kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == base_id).first()
    if not kb:
        raise KnowledgeBaseNotFoundError(base_id)

    # Tenant isolation: present a foreign tenant's base as "not found"
    # rather than leaking its existence.
    from app.core.tenant_context import get_current_tenant_id

    tenant_id = get_current_tenant_id()
    if tenant_id and hasattr(kb, "tenant_id") and kb.tenant_id != tenant_id:
        raise KnowledgeBaseNotFoundError(base_id)

    try:
        # --- Path 1: graph-enhanced retrieval -----------------------------
        if test_request.mode == "graph_enhanced" or test_request.enable_graph:
            if not settings.ENABLE_KNOWLEDGE_GRAPH:
                raise HTTPException(
                    status_code=400,
                    detail="Knowledge graph is not enabled. Set ENABLE_KNOWLEDGE_GRAPH=true in settings.",
                )
            graph_retrieval = get_graph_enhanced_retrieval(
                db=db, knowledge_base_id=base_id, tenant_id=current_user.tenant_id
            )
            start_time = time.time()
            results = graph_retrieval.retrieve(
                query=test_request.query,
                k=test_request.top_k,
                threshold=test_request.threshold,
                model_id=test_request.model_id,
                expand_depth=test_request.graph_expand_depth,
                relation_types=test_request.graph_relation_types,
            )
            elapsed_time = time.time() - start_time
            avg_score, max_score, min_score = _recall_score_stats(
                [r["score"] for r in results]
            )
            return {
                "query": test_request.query,
                "mode": "graph_enhanced",
                "params": {
                    "k": test_request.top_k,
                    "threshold": test_request.threshold,
                    "model_id": test_request.model_id,
                    "graph_expand_depth": test_request.graph_expand_depth,
                    "graph_relation_types": test_request.graph_relation_types,
                    "enable_parent_expansion": test_request.enable_parent_expansion,
                    "enable_reference_expansion": test_request.enable_reference_expansion,
                },
                "results": results,
                "statistics": {
                    "total_results": len(results),
                    "avg_score": avg_score,
                    "max_score": max_score,
                    "min_score": min_score,
                    "elapsed_time": elapsed_time,
                    "parent_expanded_count": sum(1 for r in results if r.get("is_expanded")),
                    "reference_expanded_count": sum(
                        1 for r in results if r.get("is_reference_expanded")
                    ),
                    "graph_expanded_count": sum(1 for r in results if r.get("is_graph_expanded")),
                },
            }

        # --- Path 2: hybrid retrieval with parent/reference expansion -----
        if test_request.enable_parent_expansion or test_request.enable_reference_expansion:
            from app.services.rag.hybrid_retrieval import get_hybrid_retriever

            hybrid_retriever = get_hybrid_retriever(
                db=db, knowledge_base_id=base_id, tenant_id=current_user.tenant_id
            )
            start_time = time.time()
            results = hybrid_retriever.retrieve(
                query=test_request.query,
                k=test_request.top_k,
                expand_graph=test_request.enable_reference_expansion,
                expand_parent=test_request.enable_parent_expansion,
                rerank=True,
            )
            elapsed_time = time.time() - start_time
            avg_score, max_score, min_score = _recall_score_stats(
                [r.get("score", 0) for r in results]
            )
            return {
                "query": test_request.query,
                "mode": "hybrid_with_parent_child",
                "params": {
                    "k": test_request.top_k,
                    "enable_parent_expansion": test_request.enable_parent_expansion,
                    "enable_reference_expansion": test_request.enable_reference_expansion,
                },
                "results": results,
                "statistics": {
                    "total_results": len(results),
                    "avg_score": avg_score,
                    "max_score": max_score,
                    "min_score": min_score,
                    "elapsed_time": elapsed_time,
                    "parent_expanded_count": sum(1 for r in results if r.get("is_expanded")),
                    "reference_expanded_count": sum(
                        1 for r in results if r.get("is_reference_expanded")
                    ),
                },
            }

        # --- Path 3: plain retrieval service ------------------------------
        retrieval_service = get_retrieval_service(db, base_id)
        model_id = test_request.model_id
        # Fall back to the base's default vector model (kb.code holds either
        # a numeric model id or a model code).
        if not model_id and kb.code:
            from app.models import Model

            if kb.code.isdigit():
                vector_model = db.query(Model).filter(Model.id == int(kb.code)).first()
            else:
                vector_model = db.query(Model).filter(Model.code == kb.code).first()
            if vector_model:
                model_id = vector_model.id
                logger.info(f"Using knowledge base default model: {kb.code} (ID: {model_id})")
        result = retrieval_service.recall_test(
            query=test_request.query,
            mode=test_request.mode,
            k=test_request.top_k,
            threshold=test_request.threshold,
            model_id=model_id,
            keyword_weight=test_request.keyword_weight,
            semantic_weight=test_request.semantic_weight,
        )
        return result
    except HTTPException:
        # BUGFIX: the generic handler below used to swallow the 400
        # "graph not enabled" HTTPException and re-wrap it as a 500;
        # let deliberate HTTP errors propagate unchanged.
        raise
    except Exception as e:
        import traceback

        logger.error(f"Recall test error: {str(e)}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"{t.t('recall.test_failed')}: {str(e)}") from None


@router.get("/{base_id}/recall-history")
def get_recall_history(
    base_id: int,
    request: Request,
    limit: int = 10,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_read("knowledge_bases")),
):
    """Return the recall-test history for a knowledge base.

    History persistence is not implemented yet: once the base is confirmed
    to exist, an empty page is returned. ``limit`` is accepted for forward
    compatibility but currently unused.
    """
    # Resolve the translator for its request-scoped side effects only.
    get_translator(request)
    base = db.query(KnowledgeBase).filter(KnowledgeBase.id == base_id).first()
    if not base:
        raise KnowledgeBaseNotFoundError(base_id)
    return {"history": [], "total": 0}


@router.post("/{base_id}/recall-test/export")
def export_recall_results(
    base_id: int,
    test_request: RecallTestRequest,
    request: Request,
    format: str = "json",
    db: Session = Depends(get_db),
    current_user: User = Depends(require_read("knowledge_bases")),
):
    """Run a recall test and return its results for export.

    Args:
        format: "csv" returns the rows rendered as CSV text in ``data``;
            anything else returns the raw result dict as JSON.

    Raises:
        KnowledgeBaseNotFoundError: unknown base id.
        HTTPException: 500 when the retrieval or rendering fails.
    """
    t = get_translator(request)
    kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == base_id).first()
    if not kb:
        raise KnowledgeBaseNotFoundError(base_id)
    try:
        retrieval_service = get_retrieval_service(db, base_id)
        result = retrieval_service.recall_test(
            query=test_request.query,
            mode=test_request.mode,
            k=test_request.top_k,
            threshold=test_request.threshold,
            model_id=test_request.model_id,
            keyword_weight=test_request.keyword_weight,
            semantic_weight=test_request.semantic_weight,
        )
        if format == "csv":
            import csv
            import io

            output = io.StringIO()
            writer = csv.writer(output)
            writer.writerow(["ID", "Document ID", "Title", "Text", "Score", "Chunk Index"])
            for item in result["results"]:
                text = item["text"]
                # BUGFIX: previously "..." was appended unconditionally, even
                # when the text was not actually truncated.
                preview = text[:100] + "..." if len(text) > 100 else text
                writer.writerow(
                    [
                        item["id"],
                        item["document_id"],
                        item["title"],
                        preview,
                        item["score"],
                        item["chunk_index"],
                    ]
                )
            return {"format": "csv", "data": output.getvalue()}
        return {"format": "json", "data": result}
    except Exception as e:
        # Log before converting to HTTP 500 (consistent with recall_test,
        # which previously was the only handler that logged failures).
        logger.error(f"Recall export error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"{t.t('recall.export_failed')}: {str(e)}") from None


@router.get("/{base_id}/statistics")
def get_knowledge_base_statistics(
    base_id: int,
    request: Request,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_read("knowledge_bases")),
):
    """Aggregate category / document / vector statistics for one knowledge base.

    Raises:
        KnowledgeBaseNotFoundError: unknown base id.
    """
    get_translator(request)
    kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == base_id).first()
    if not kb:
        raise KnowledgeBaseNotFoundError(base_id)

    from sqlalchemy import func

    from app.models import DocumentVector, KnowledgeCategory, KnowledgeDocument

    total_categories = (
        db.query(func.count(KnowledgeCategory.id))
        .filter(KnowledgeCategory.knowledge_base_id == base_id)
        .scalar()
    )
    cat_ids = [
        cid
        for (cid,) in db.query(KnowledgeCategory.id)
        .filter(KnowledgeCategory.knowledge_base_id == base_id)
        .all()
    ]
    if not cat_ids:
        # Without categories there can be no documents or vectors attached.
        return {
            "knowledge_base_id": base_id,
            "knowledge_base_name": kb.name,
            "category_count": 0,
            "document_count": 0,
            "vector_count": 0,
            "vectorized_document_count": 0,
            "vectorization_rate": 0,
            "qa_count": kb.qa_count,
            "entity_count": kb.entity_count,
        }

    # One aggregate pass over documents: total count plus the count of rows
    # whose is_vectorized is truthy (nullif(False) -> NULL is skipped by count).
    counts = (
        db.query(
            func.count(KnowledgeDocument.id).label("total"),
            func.count(func.nullif(KnowledgeDocument.is_vectorized, False)).label("vectorized"),
        )
        .filter(KnowledgeDocument.category_id.in_(cat_ids))
        .first()
    )
    total_docs = counts.total or 0
    vectorized_docs = counts.vectorized or 0

    total_vectors = (
        db.query(func.count(DocumentVector.id))
        .join(KnowledgeDocument, DocumentVector.document_id == KnowledgeDocument.id)
        .filter(KnowledgeDocument.category_id.in_(cat_ids))
        .scalar()
        or 0
    )

    return {
        "knowledge_base_id": base_id,
        "knowledge_base_name": kb.name,
        "category_count": total_categories,
        "document_count": total_docs,
        "vector_count": total_vectors,
        "vectorized_document_count": vectorized_docs,
        "vectorization_rate": vectorized_docs / total_docs * 100 if total_docs > 0 else 0,
        "qa_count": kb.qa_count,
        "entity_count": kb.entity_count,
    }
