from fastapi import APIRouter, Depends, HTTPException, Request, status
from sqlalchemy import text
from sqlalchemy.orm import Session

from app.api.deps import get_db
from app.api.permissions import require_create, require_delete, require_read, require_update
from app.config import settings
from app.core.exceptions import KnowledgeBaseNotFoundError
from common_logging import get_logger

logger = get_logger(__name__)
from app.core.i18n import get_translator
from app.models import KnowledgeBase, KnowledgeCategory, KnowledgeDocument, User
from app.models.knowledge_base import (
    DocumentMetadataValue,
    DocumentTag,
    DocumentVector,
    DocumentVersion,
    KnowledgeBaseMetadata,
    KnowledgeBaseMetadataField,
    KnowledgeMetadataField,
    MetadataField,
)
from app.schemas import KnowledgeBaseCreate, KnowledgeBaseResponse, KnowledgeBaseUpdate
from app.schemas.knowledge_document_list import KnowledgeBaseListResponse
from app.services.graph.neo4j_client import get_neo4j_client
from app.services.storage.milvus import get_vector_store
from app.services.storage.minio import get_minio_service

router = APIRouter(tags=["knowledge-bases"])


@router.get("/", response_model=list[KnowledgeBaseListResponse])
def get_knowledge_bases(
    search: str | None = None,
    code: str | None = None,
    status: str | None = None,
    type: str | None = None,
    page: int = 1,
    page_size: int = 10,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_read("knowledge_bases")),
):
    """List knowledge bases with optional filters and pagination.

    Args:
        search: Substring match on the knowledge base name.
        code: Exact match on the knowledge base code.
        status: Exact match on the status field.
        type: Exact match on the type field.
        page: 1-based page number.
        page_size: Number of rows per page.

    Returns:
        One page of knowledge bases, each annotated with its document count.
    """
    from sqlalchemy import func
    from sqlalchemy import text as _text

    from app.core.tenant_context import get_current_tenant_id

    query = db.query(KnowledgeBase)
    # Tenant-aware deployments keep KB rows in the public schema; reset the
    # search_path so the query does not hit a tenant schema by mistake.
    if get_current_tenant_id():
        db.execute(_text("SET search_path TO public"))
    if search:
        query = query.filter(KnowledgeBase.name.contains(search))
    if code:
        query = query.filter(KnowledgeBase.code == code)
    if status:
        query = query.filter(KnowledgeBase.status == status)
    if type:
        query = query.filter(KnowledgeBase.type == type)
    # NOTE: a `query.count()` whose result was discarded has been removed —
    # it issued a useless extra COUNT round-trip to the database.
    knowledge_bases = (
        query.order_by(KnowledgeBase.created_at.desc())
        .offset((page - 1) * page_size)
        .limit(page_size)
        .all()
    )
    # Aggregate document counts with two bulk queries instead of one per KB.
    kb_ids = [kb.id for kb in knowledge_bases]
    category_data = (
        db.query(KnowledgeCategory.knowledge_base_id, KnowledgeCategory.id)
        .filter(KnowledgeCategory.knowledge_base_id.in_(kb_ids))
        .all()
    )
    kb_category_map: dict[int, list[int]] = {}
    for kb_id, cat_id in category_data:
        kb_category_map.setdefault(kb_id, []).append(cat_id)
    all_category_ids = [cat_id for _, cat_id in category_data]
    doc_counts: dict[int, int] = {}
    if all_category_ids:
        doc_count_data = (
            db.query(KnowledgeDocument.category_id, func.count(KnowledgeDocument.id))
            .filter(KnowledgeDocument.category_id.in_(all_category_ids))
            .group_by(KnowledgeDocument.category_id)
            .all()
        )
        category_doc_counts = dict(doc_count_data)
        for kb_id, cat_ids in kb_category_map.items():
            doc_counts[kb_id] = sum(category_doc_counts.get(cat_id, 0) for cat_id in cat_ids)
    return [
        KnowledgeBaseListResponse(
            id=kb.id,
            name=kb.name,
            code=kb.code,
            description=kb.description,
            icon=kb.icon,
            type=kb.type,
            status=kb.status,
            is_public=kb.is_public,
            doc_count=doc_counts.get(kb.id, 0),
            created_at=kb.created_at,
            updated_at=kb.updated_at,
        )
        for kb in knowledge_bases
    ]


@router.get("/{base_id}", response_model=KnowledgeBaseResponse)
def get_knowledge_base(
    base_id: int,
    request: Request,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_read("knowledge_bases")),
):
    """Return a single knowledge base together with a live document count.

    Raises:
        KnowledgeBaseNotFoundError: If the KB does not exist, or belongs to a
            different tenant (reported as not-found to avoid leaking existence).
    """
    from sqlalchemy import func, select

    get_translator(request)  # invoked for parity with other handlers; result unused
    kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == base_id).first()
    if kb is None:
        raise KnowledgeBaseNotFoundError(base_id)
    user_tenant = getattr(current_user, "tenant_id", None)
    if user_tenant and getattr(kb, "tenant_id", None) != user_tenant:
        raise KnowledgeBaseNotFoundError(base_id)
    # Count documents through the KB's categories via a correlated IN subquery.
    category_subquery = select(KnowledgeCategory.id).where(
        KnowledgeCategory.knowledge_base_id == kb.id
    )
    doc_count = (
        db.query(func.count(KnowledgeDocument.id))
        .filter(KnowledgeDocument.category_id.in_(category_subquery))
        .scalar()
        or 0
    )
    return {
        "id": kb.id,
        "name": kb.name,
        "code": kb.code,
        "description": kb.description,
        "icon": kb.icon,
        "type": kb.type,
        "status": kb.status,
        "is_public": kb.is_public,
        "doc_count": doc_count,
        "qa_count": kb.qa_count,
        "entity_count": kb.entity_count,
        "created_at": kb.created_at,
        "updated_at": kb.updated_at,
    }


@router.post("/", response_model=KnowledgeBaseResponse, status_code=status.HTTP_201_CREATED)
def create_knowledge_base(
    kb: KnowledgeBaseCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_create("knowledge_bases")),
):
    """Create a knowledge base owned by the current user (and tenant, if any).

    Returns:
        The persisted knowledge base with zeroed counters.

    Raises:
        HTTPException: 500 on any persistence failure (transaction rolled back).
    """
    try:
        payload = kb.model_dump()
        payload["created_by"] = current_user.id
        # Users without a tenant (e.g. system-level accounts) fall back to tenant 0.
        payload["tenant_id"] = (
            current_user.tenant_id
            if hasattr(current_user, "tenant_id") and current_user.tenant_id
            else 0
        )
        if not payload.get("code"):
            payload["code"] = "qwen3-vl-embedding"
        db_kb = KnowledgeBase(**payload)
        db.add(db_kb)
        db.commit()
        db.refresh(db_kb)
        return {
            "id": db_kb.id,
            "name": db_kb.name,
            "code": db_kb.code,
            "description": db_kb.description,
            "icon": db_kb.icon,
            "type": db_kb.type,
            "status": db_kb.status,
            "is_public": db_kb.is_public,
            "doc_count": 0,
            "qa_count": 0,
            "entity_count": 0,
            "created_at": db_kb.created_at,
            "updated_at": db_kb.updated_at,
        }
    except Exception as e:
        db.rollback()
        logger.opt(exception=e).error("Failed to create knowledge base")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"创建知识库失败: {str(e)}"
        ) from e


@router.put("/{base_id}", response_model=KnowledgeBaseResponse)
def update_knowledge_base(
    base_id: int,
    kb: KnowledgeBaseUpdate,
    request: Request,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_update("knowledge_bases")),
):
    """Partially update a knowledge base (only fields set in the payload).

    Rejects name/description values that look like XSS payloads, enforces
    tenant isolation, and invalidates the KB's query cache after commit.
    """
    from app.core.security_utils import check_xss

    get_translator(request)
    db_kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == base_id).first()
    if db_kb is None:
        raise KnowledgeBaseNotFoundError(base_id)
    user_tenant = getattr(current_user, "tenant_id", None)
    if user_tenant and getattr(db_kb, "tenant_id", None) != user_tenant:
        raise KnowledgeBaseNotFoundError(base_id)
    # Screen user-supplied text fields before applying any changes.
    for value, message in (
        (kb.name, "Knowledge base name contains potentially malicious content"),
        (kb.description, "Knowledge base description contains potentially malicious content"),
    ):
        if value and check_xss(value):
            raise HTTPException(status_code=400, detail=message)
    for field, new_value in kb.model_dump(exclude_unset=True).items():
        setattr(db_kb, field, new_value)
    db.commit()
    db.refresh(db_kb)
    from app.services.cache.query_cache import get_query_cache_service

    get_query_cache_service().invalidate_knowledge_base_cache(base_id)
    return db_kb


@router.get("/{base_id}/categories")
def get_knowledge_base_categories(
    base_id: int,
    request: Request,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_read("knowledge_bases")),
):
    """List the categories of a knowledge base with per-category document counts.

    Raises:
        KnowledgeBaseNotFoundError: If the KB does not exist or belongs to a
            different tenant.
    """
    from sqlalchemy import func

    get_translator(request)
    kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == base_id).first()
    if not kb:
        raise KnowledgeBaseNotFoundError(base_id)
    if hasattr(current_user, "tenant_id") and current_user.tenant_id:
        if not hasattr(kb, "tenant_id") or kb.tenant_id != current_user.tenant_id:
            raise KnowledgeBaseNotFoundError(base_id)
    categories = (
        db.query(KnowledgeCategory)
        .filter(KnowledgeCategory.knowledge_base_id == base_id)
        .order_by(KnowledgeCategory.sort_order.asc(), KnowledgeCategory.created_at.desc())
        .all()
    )
    # Fix N+1: one grouped COUNT query instead of one COUNT per category.
    doc_counts: dict[int, int] = {}
    category_ids = [cat.id for cat in categories]
    if category_ids:
        doc_counts = dict(
            db.query(KnowledgeDocument.category_id, func.count(KnowledgeDocument.id))
            .filter(KnowledgeDocument.category_id.in_(category_ids))
            .group_by(KnowledgeDocument.category_id)
            .all()
        )
    result = []
    for cat in categories:
        doc_count = doc_counts.get(cat.id, 0)
        result.append(
            {
                "id": cat.id,
                "knowledge_base_id": cat.knowledge_base_id,
                "knowledge_base_name": kb.name,
                "name": cat.name,
                "description": cat.description,
                "parent_id": cat.parent_id,
                "icon": cat.icon,
                "color": cat.color,
                "sort_order": cat.sort_order,
                # Both keys kept for backward compatibility with existing clients.
                "doc_count": doc_count,
                "document_count": doc_count,
                "created_at": cat.created_at,
                "updated_at": cat.updated_at,
            }
        )
    return result


@router.get("/{base_id}/statistics")
def get_knowledge_base_statistics(
    base_id: int,
    request: Request,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_read("knowledge_bases")),
):
    """Aggregate statistics for a knowledge base.

    Returns:
        Dict with category/document/vector counts, the number of vectorized
        documents, the vectorization rate in percent, and active QA count.
        Vector counts come from Milvus when reachable, otherwise from the
        vector rows mirrored in PostgreSQL.
    """
    from sqlalchemy import func

    get_translator(request)
    kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == base_id).first()
    if not kb:
        raise KnowledgeBaseNotFoundError(base_id)
    if hasattr(current_user, "tenant_id") and current_user.tenant_id:
        if not hasattr(kb, "tenant_id") or kb.tenant_id != current_user.tenant_id:
            raise KnowledgeBaseNotFoundError(base_id)
    category_ids = [
        cat.id
        for cat in db.query(KnowledgeCategory)
        .filter(KnowledgeCategory.knowledge_base_id == base_id)
        .all()
    ]
    # Resolve document ids once; the original code ran this exact query twice
    # (for vector counting and again for QA counting).
    document_ids: list[int] = []
    if category_ids:
        document_ids = [
            doc.id
            for doc in db.query(KnowledgeDocument)
            .filter(KnowledgeDocument.category_id.in_(category_ids))
            .all()
        ]
    total_documents = 0
    vectorized_documents = 0
    if category_ids:
        total_documents = (
            db.query(func.count(KnowledgeDocument.id))
            .filter(KnowledgeDocument.category_id.in_(category_ids))
            .scalar()
            or 0
        )
        vectorized_documents = (
            db.query(func.count(KnowledgeDocument.id))
            .filter(
                KnowledgeDocument.category_id.in_(category_ids),
                KnowledgeDocument.vectorization_status == "completed",
            )
            .scalar()
            or 0
        )
    total_vectors = 0
    if document_ids:
        try:
            from app.services.vector.milvus_client import get_milvus_manager

            manager = get_milvus_manager()
            if not manager.connected:
                manager.connect(
                    host=settings.MILVUS_HOST,
                    port=settings.MILVUS_PORT,
                    use_lite=settings.USE_MILVUS_LITE,
                )
            collection = manager.get_collection("document_vectors")
            for doc_id in document_ids:
                expr = f"document_id == {doc_id}"
                results = collection.query(
                    expr=expr, output_fields=["document_id"], limit=10000
                )
                total_vectors += len(results)
        except Exception as e:
            # Milvus unreachable or query failed: fall back to the DocumentVector
            # rows stored in PostgreSQL (any partial Milvus tally is discarded).
            logger.warning(f"Failed to count vectors from Milvus: {e}")
            total_vectors = (
                db.query(func.count(DocumentVector.id))
                .filter(DocumentVector.document_id.in_(document_ids))
                .scalar()
                or 0
            )
    vectorization_rate = 0
    if total_documents > 0:
        vectorization_rate = vectorized_documents / total_documents * 100
    qa_count = 0
    try:
        from app.models.knowledge_base import KnowledgeQA

        if document_ids:
            qa_count = (
                db.query(func.count(KnowledgeQA.id))
                .filter(
                    KnowledgeQA.document_id.in_(document_ids), KnowledgeQA.status == "active"
                )
                .scalar()
                or 0
            )
    except Exception:
        # KnowledgeQA may not exist in every deployment; QA count is best-effort.
        pass
    return {
        "category_count": len(category_ids),
        "document_count": total_documents,
        "vector_count": total_vectors,
        "vectorized_document_count": vectorized_documents,
        "vectorization_rate": vectorization_rate,
        "qa_count": qa_count,
    }


@router.delete("/{base_id}", status_code=status.HTTP_200_OK)
def delete_knowledge_base(
    base_id: int,
    request: Request,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_delete("knowledge_bases")),
):
    """Delete a knowledge base and all of its data across every backing store.

    Cleanup order (deliberate — external stores first, DB commit last so a
    failed external cleanup aborts before any relational rows are removed):
    1. Neo4j graph data (only when ENABLE_KNOWLEDGE_GRAPH is set).
    2. The KB's Milvus vector partition.
    3. MinIO objects under the tenant/KB prefix (skipped without a tenant_id).
    4. All relational rows (documents and their child tables, categories,
       metadata tables, the KB row itself) in a single transaction.
    5. Best-effort query-cache invalidation (failure is non-fatal).

    Raises:
        KnowledgeBaseNotFoundError: Unknown KB or a different tenant's KB.
        HTTPException: 500 when any mandatory cleanup step fails.
    """
    t = get_translator(request)
    db_kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == base_id).first()
    if not db_kb:
        raise KnowledgeBaseNotFoundError(base_id)
    # Tenant isolation: report not-found rather than forbidden for foreign KBs.
    if hasattr(current_user, "tenant_id") and current_user.tenant_id:
        if not hasattr(db_kb, "tenant_id") or db_kb.tenant_id != current_user.tenant_id:
            raise KnowledgeBaseNotFoundError(base_id)
    tenant_id = db_kb.tenant_id
    if tenant_id is None:
        logger.warning(f"[delete_kb={base_id}] KB has no tenant_id; MinIO cleanup will be skipped")
    # Step 1: graph cleanup — a failure here aborts the whole delete.
    if settings.ENABLE_KNOWLEDGE_GRAPH:
        try:
            neo4j = get_neo4j_client()
            neo4j.delete_kb_data(kb_id=base_id, tenant_id=tenant_id)
            logger.info(f"[delete_kb={base_id}] Neo4j graph data cleared")
        except Exception as e:
            logger.error(f"[delete_kb={base_id}] Failed to clear Neo4j data: {e}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="删除知识库失败（图谱清理错误）",
            ) from e
    # Step 2: drop the KB's vector partition in Milvus.
    try:
        vector_store = get_vector_store(db, knowledge_base_id=base_id, tenant_id=tenant_id)
        vector_store.drop_partition()
        logger.info(f"[delete_kb={base_id}] Milvus partition dropped")
    except Exception as e:
        logger.error(f"[delete_kb={base_id}] Failed to drop Milvus partition: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="删除知识库失败（向量清理错误）",
        ) from e
    # Step 3: delete stored files; the object prefix requires a tenant_id.
    if tenant_id is not None:
        try:
            minio_service = get_minio_service()
            prefix = f"tenant_{tenant_id}/kb_{base_id}/"
            minio_service.delete_by_prefix(bucket_name=settings.MINIO_BUCKET, prefix=prefix)
            logger.info(f"[delete_kb={base_id}] MinIO files deleted with prefix: {prefix}")
        except Exception as e:
            logger.error(f"[delete_kb={base_id}] Failed to delete MinIO files: {e}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="删除知识库失败（文件清理错误）",
            ) from e
    else:
        logger.info(f"[delete_kb={base_id}] MinIO cleanup skipped (no tenant_id)")
    # Step 4: relational cleanup — everything below runs in one transaction
    # and is committed (or rolled back) as a unit.
    try:
        category_ids = [
            cat.id
            for cat in db.query(KnowledgeCategory)
            .filter(KnowledgeCategory.knowledge_base_id == base_id)
            .all()
        ]
        all_document_ids = []
        if category_ids:
            all_document_ids = [
                doc.id
                for doc in db.query(KnowledgeDocument)
                .filter(KnowledgeDocument.category_id.in_(category_ids))
                .all()
            ]
        # Also sweep up orphaned documents (no category) that still point at
        # this KB's storage prefix, identifiable only when tenant_id is known.
        if tenant_id is not None:
            kb_prefix = f"tenant_{tenant_id}/kb_{base_id}/"
            orphan_doc_ids = [
                doc.id
                for doc in db.query(KnowledgeDocument)
                .filter(
                    KnowledgeDocument.category_id.is_(None),
                    KnowledgeDocument.tenant_id == tenant_id,
                    KnowledgeDocument.file_path.like(f"{kb_prefix}%"),
                )
                .all()
            ]
        else:
            orphan_doc_ids = []
        all_document_ids = list(set(all_document_ids) | set(orphan_doc_ids))
        # Child tables first, then the documents themselves (manual cascade).
        if all_document_ids:
            db.query(DocumentVector).filter(
                DocumentVector.document_id.in_(all_document_ids)
            ).delete(synchronize_session=False)
            db.query(DocumentTag).filter(DocumentTag.document_id.in_(all_document_ids)).delete(
                synchronize_session=False
            )
            db.query(DocumentVersion).filter(
                DocumentVersion.document_id.in_(all_document_ids)
            ).delete(synchronize_session=False)
            db.query(DocumentMetadataValue).filter(
                DocumentMetadataValue.document_id.in_(all_document_ids)
            ).delete(synchronize_session=False)
            db.query(KnowledgeDocument).filter(KnowledgeDocument.id.in_(all_document_ids)).delete(
                synchronize_session=False
            )
        db.query(KnowledgeCategory).filter(KnowledgeCategory.knowledge_base_id == base_id).delete(
            synchronize_session=False
        )
        # The graph_* tables may not exist in every deployment; each raw DELETE
        # runs inside a savepoint so a missing table doesn't poison the outer
        # transaction. Other errors are re-raised.
        # NOTE(review): if begin_nested() itself raised, `savepoint` would be
        # unbound (or stale) in the except clause — confirm this is acceptable.
        try:
            savepoint = db.begin_nested()
            db.execute(
                text("DELETE FROM graph_build_status WHERE knowledge_base_id = :kb_id"),
                {"kb_id": base_id},
            )
            savepoint.commit()
        except Exception as e:
            savepoint.rollback()
            if "does not exist" not in str(e).lower() and "undefined" not in str(e).lower():
                raise
            logger.debug("GraphBuildStatus table does not exist, skipping")
        try:
            savepoint = db.begin_nested()
            db.execute(text("DELETE FROM graph_entities WHERE kb_id = :kb_id"), {"kb_id": base_id})
            savepoint.commit()
        except Exception as e:
            savepoint.rollback()
            if "does not exist" not in str(e).lower() and "undefined" not in str(e).lower():
                raise
            logger.debug("GraphEntity table does not exist, skipping")
        # KB-level metadata tables, then the KB row itself.
        db.query(KnowledgeBaseMetadata).filter(
            KnowledgeBaseMetadata.knowledge_base_id == base_id
        ).delete(synchronize_session=False)
        db.query(KnowledgeBaseMetadataField).filter(
            KnowledgeBaseMetadataField.knowledge_base_id == base_id
        ).delete(synchronize_session=False)
        db.query(KnowledgeMetadataField).filter(
            KnowledgeMetadataField.knowledge_base_id == base_id
        ).delete(synchronize_session=False)
        db.query(MetadataField).filter(MetadataField.knowledge_base_id == base_id).delete(
            synchronize_session=False
        )
        db.delete(db_kb)
        db.commit()
        logger.info(f"[delete_kb={base_id}] PostgreSQL records deleted and committed")
    except Exception as e:
        db.rollback()
        logger.error(f"[delete_kb={base_id}] PG transaction failed, rolled back: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="删除知识库失败（数据库错误）"
        ) from e
    # Step 5: best-effort cache invalidation; never fails the request.
    try:
        from app.services.cache.query_cache import get_query_cache_service

        cache_service = get_query_cache_service()
        cache_service.invalidate_knowledge_base_cache(base_id)
    except Exception as e:
        logger.warning(f"[delete_kb={base_id}] Cache invalidation failed (non-fatal): {e}")
    return {"success": True, "message": t.t("knowledge_base.deleted")}
