from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy.orm import Session

from app.api.permissions import require_read
from app.core.exceptions import EmbeddingGenerationError, VectorSearchError
from app.core.i18n import get_translator
from app.db.session import get_db
from app.models import KnowledgeBase, KnowledgeDocument, Model
from app.models.knowledge_base import KnowledgeCategory
from app.models.user import User
from app.schemas.knowledge_document import SearchRequest
from common_logging import get_logger, log_performance

# Module-level structured logger named after this module.
logger = get_logger(__name__)
# Router for the knowledge-search endpoints; included by the parent API package.
router = APIRouter()


@router.post("/search")
@log_performance(logger, threshold_ms=1000)
def search_knowledge(
    request: Request,
    search: SearchRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_read("knowledge_bases")),
):
    """Search published, public knowledge documents.

    Dispatches on ``search.search_type``:

    - ``"keyword"``: SQL LIKE search over title/content, scoped to the
      current tenant's knowledge-base categories.
    - ``"semantic"``: embedding-based similarity search via the vector store.

    Returns:
        dict with ``results`` (list of document summaries) and ``total``;
        the semantic path may instead return an empty result with a
        ``message`` when no embedding model is available.

    Raises:
        HTTPException: 400 for an unsupported search type; 404 when the
            requested knowledge base belongs to another tenant.
        EmbeddingGenerationError: the query embedding could not be produced.
        VectorSearchError: unexpected failure during semantic search.
    """
    t = get_translator(request)
    if search.search_type == "keyword":
        return _keyword_search(search, db)
    if search.search_type == "semantic":
        return _semantic_search(search, db)
    raise HTTPException(status_code=400, detail=t.t("knowledge.unsupported_search_type"))


def _keyword_search(search: SearchRequest, db: Session) -> dict:
    """LIKE-based search over published public documents, tenant-scoped."""
    from app.core.tenant_context import get_current_tenant_id

    query = db.query(KnowledgeDocument).filter(
        KnowledgeDocument.status == "published", KnowledgeDocument.is_public
    )
    tenant_id = get_current_tenant_id()
    if tenant_id:
        # Restrict to categories belonging to this tenant's knowledge bases.
        tenant_category_ids = [
            cid
            for (cid,) in db.query(KnowledgeCategory.id)
            .join(KnowledgeBase, KnowledgeCategory.knowledge_base_id == KnowledgeBase.id)
            .filter(KnowledgeBase.tenant_id == tenant_id)
            .all()
        ]
        if tenant_category_ids:
            query = query.filter(KnowledgeDocument.category_id.in_(tenant_category_ids))
        else:
            # Tenant owns no categories: force an empty result set.
            query = query.filter(KnowledgeDocument.id == -1)
    search_filter = KnowledgeDocument.title.contains(
        search.query
    ) | KnowledgeDocument.content.contains(search.query)
    query = query.filter(search_filter)
    if search.category_id:
        query = query.filter(KnowledgeDocument.category_id == search.category_id)
    documents = query.limit(search.limit).all()
    return {
        "results": [
            {
                "id": doc.id,
                "title": doc.title,
                "summary": doc.summary or doc.content[:200],
                "category_id": doc.category_id,
                "created_at": doc.created_at,
            }
            for doc in documents
        ],
        "total": len(documents),
    }


def _resolve_embedding_model_id(search: SearchRequest, db: Session):
    """Return the embedding model id for this search, or None if unavailable.

    Prefers the model configured on the requested knowledge base (``kb.code``
    holds either a numeric model id or a model code), falling back to the
    first active embedding model.

    Raises:
        HTTPException: 404 when the knowledge base belongs to another tenant.
    """
    from app.core.tenant_context import get_current_tenant_id

    if search.knowledge_base_id:
        kb = (
            db.query(KnowledgeBase)
            .filter(KnowledgeBase.id == search.knowledge_base_id)
            .first()
        )
        tenant_id = get_current_tenant_id()
        if kb and tenant_id and hasattr(kb, "tenant_id") and kb.tenant_id != tenant_id:
            # Hide other tenants' knowledge bases rather than revealing they exist.
            raise HTTPException(status_code=404, detail="Knowledge base not found")
        if kb and kb.code:
            if kb.code.isdigit():
                vector_model = db.query(Model).filter(Model.id == int(kb.code)).first()
            else:
                vector_model = db.query(Model).filter(Model.code == kb.code).first()
            if vector_model:
                return vector_model.id
    default_model = (
        db.query(Model)
        .filter(Model.type == "embedding", Model.is_active)
        .first()
    )
    return default_model.id if default_model else None


def _semantic_search(search: SearchRequest, db: Session) -> dict:
    """Embedding-based similarity search over the vector store."""
    try:
        from app.services.llm.backends.embedding_backend_factory import get_embedding_factory
        from app.services.storage.vector_store_factory import get_vector_store

        embedding_factory = get_embedding_factory()
        vector_model_id = _resolve_embedding_model_id(search, db)
        if not vector_model_id:
            # No embedding model configured anywhere: return an empty,
            # user-explainable result instead of erroring.
            from app.core.i18n import get_translator as _gt  # translator key reused below

            return {
                "results": [],
                "total": 0,
                "message": "knowledge.no_embedding_model_available",
            }
        query_embedding = embedding_factory.generate_embedding(
            text=search.query, db=db, model_id=vector_model_id
        )
        if not query_embedding:
            raise EmbeddingGenerationError("Failed to generate query embedding")
        vector_store = get_vector_store(db, knowledge_base_id=search.knowledge_base_id)
        # Over-fetch (2x limit) because category filtering and per-document
        # de-duplication below can discard hits.
        similar_results = vector_store.similarity_search(
            query_embedding=query_embedding,
            k=search.limit * 2,
            threshold=0.5,
            filter_dict={},
        )
        hits = []
        seen_doc_ids = set()
        for doc_dict, score in similar_results:
            doc_id = doc_dict.get("document_id")
            # Category filtering is applied post-search on chunk metadata.
            if search.category_id and doc_dict.get("category_id") != search.category_id:
                continue
            # Keep only the best-scoring chunk per document.
            if doc_id in seen_doc_ids:
                continue
            seen_doc_ids.add(doc_id)
            hits.append((doc_id, doc_dict, score))
            if len(hits) >= search.limit:
                break
        results = []
        if hits:
            doc_ids = [doc_id for doc_id, _, _ in hits]
            docs = db.query(KnowledgeDocument).filter(KnowledgeDocument.id.in_(doc_ids)).all()
            doc_map = {doc.id: doc for doc in docs}
            for doc_id, doc_dict, score in hits:
                doc = doc_map.get(doc_id)
                if doc:
                    results.append(
                        {
                            "id": doc.id,
                            "title": doc.title,
                            "summary": doc.summary or doc.content[:200],
                            "content": doc_dict.get("text", ""),
                            "category_id": doc.category_id,
                            "score": score,
                            "created_at": doc.created_at,
                        }
                    )
        # Log before returning (original logged after the return, so this
        # line was unreachable).
        logger.bind(query=search.query, result_count=len(results)).info(
            "Semantic search completed"
        )
        return {"results": results, "total": len(results)}
    except HTTPException:
        # Let the tenant-isolation 404 propagate instead of being swallowed
        # into a VectorSearchError by the generic handler below.
        raise
    except (VectorSearchError, EmbeddingGenerationError):
        raise
    except Exception as e:
        import traceback

        logger.bind(query=search.query).error(f"Vector search failed: {e}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        raise VectorSearchError(f"向量搜索失败: {str(e)}") from None


@router.post("/search/vector")
@log_performance(logger, threshold_ms=1000)
def vector_search(
    request: Request,
    search: SearchRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_read("knowledge_bases")),
):
    """Semantic-only search endpoint.

    Forces ``search.search_type`` to ``"semantic"`` and delegates to
    ``search_knowledge``, returning its result unchanged.
    """
    # Force semantic mode regardless of what the caller supplied.
    search.search_type = "semantic"
    # Resolve the translator for parity with the main endpoint; the
    # delegate obtains its own, so the value is intentionally discarded.
    get_translator(request)
    return search_knowledge(request, search, db, current_user)
