from fastapi import APIRouter, Depends, Request
from sqlalchemy.orm import Session

from app.api.permissions import require_read
from app.core.exceptions import DocumentNotFoundError
from app.core.i18n import get_translator
from app.db.session import get_db
from app.models import KnowledgeDocument, User
from common_logging import get_logger

# Module-level logger; common_logging's get_logger also accepts structured
# keyword fields (see the "Chunks retrieved" call below).
logger = get_logger(__name__)
# Router mounted by the application's API package; paths below are relative.
router = APIRouter()


@router.get("/documents/{document_id}/chunks")
def get_document_chunks(
    request: Request,
    document_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_read("knowledge_bases")),
):
    """Return the vectorized text chunks stored in Milvus for one document.

    Looks the document up in the relational DB (raises ``DocumentNotFoundError``
    when missing), then queries the ``document_vectors`` Milvus collection for
    its chunks. Chunks are ordered parents-first (each group by original
    ``chunk_index``), de-duplicated on ``(chunk_index, stripped text)``, and
    re-indexed for display. Any vector-store failure is logged and degraded to
    an empty chunk list rather than a 500.

    Returns:
        dict with ``document_id``, ``is_vectorized``, ``chunks`` and ``total``.

    Raises:
        DocumentNotFoundError: when no document with ``document_id`` exists.
    """
    from app.models.knowledge_base import KnowledgeCategory

    # Called for its request-scoped side effect only; the translator object
    # itself is unused here. NOTE(review): confirm the side effect is needed.
    get_translator(request)

    doc = db.query(KnowledgeDocument).filter(KnowledgeDocument.id == document_id).first()
    if not doc:
        raise DocumentNotFoundError(document_id)

    # Resolve the owning knowledge base; only used by the fallback query below.
    knowledge_base_id = None
    if doc.category_id:
        category = (
            db.query(KnowledgeCategory).filter(KnowledgeCategory.id == doc.category_id).first()
        )
        if category:
            knowledge_base_id = category.knowledge_base_id
            logger.info(f"Document {document_id} belongs to knowledge base ID: {knowledge_base_id}")

    # Single source of truth for the fields fetched from Milvus (the original
    # duplicated this list in the primary and the fallback query).
    output_fields = [
        "chunk_index",
        "chunk_text",
        "knowledge_base_id",
        "document_id",
        "parent_chunk_id",
        "is_parent",
    ]

    try:
        from app.config import settings
        from app.services.vector.milvus_client import get_milvus_manager

        manager = get_milvus_manager()
        if not manager.connected:
            manager.connect(
                host=settings.MILVUS_HOST,
                port=settings.MILVUS_PORT,
                use_lite=settings.USE_MILVUS_LITE,
            )
        collection = manager.get_collection("document_vectors")

        expr = f"document_id == {document_id}"
        logger.info(
            f"Querying chunk data - Document ID: {document_id}, Knowledge base ID: {knowledge_base_id}, Expression: {expr}"
        )
        results = collection.query(expr=expr, output_fields=output_fields, limit=1000)
        logger.info(f"Query result: Found {len(results)} chunks for document_id={document_id}")

        if results:
            logger.info(
                f"First chunk example: document_id={results[0].get('document_id')}, knowledge_base_id={results[0].get('knowledge_base_id')}, chunk_index={results[0].get('chunk_index')}"
            )
            logger.info(f"Available fields in result: {list(results[0].keys())}")
        else:
            # Diagnostics: sample the collection so logs distinguish "empty
            # collection" from "filter matched nothing".
            logger.warning(
                f"No chunks found for document_id={document_id}, trying to check collection stats..."
            )
            all_results = collection.query(
                expr="document_id > 0", output_fields=["document_id"], limit=10
            )
            logger.info(
                f"Collection has {len(all_results)} total records, sample document_ids: {[r.get('document_id') for r in all_results]}"
            )
            # Fallback: retry with the knowledge-base filter added.
            if knowledge_base_id:
                expr = f"document_id == {document_id} && knowledge_base_id == {knowledge_base_id}"
                logger.info(f"Retrying query - Expression: {expr}")
                results = collection.query(expr=expr, output_fields=output_fields, limit=1000)
                logger.info(f"Retry found {len(results)} chunks")

        if results:
            is_parent_values = [r.get("is_parent") for r in results[:5]]
            logger.info(f"is_parent values for first 5 chunks: {is_parent_values}")
            logger.info(
                f"chunk sizes for first 5 chunks: {[len(r.get('chunk_text', '')) for r in results[:5]]}"
            )

        # Parents first, each group ordered by its original chunk_index.
        parent_chunks = [r for r in results if r.get("is_parent", False)]
        child_chunks = [r for r in results if not r.get("is_parent", False)]
        logger.info(
            f"Sorting: {len(parent_chunks)} parent chunks, {len(child_chunks)} child chunks"
        )
        parent_chunks.sort(key=lambda x: x.get("chunk_index", 0))
        child_chunks.sort(key=lambda x: x.get("chunk_index", 0))
        results = parent_chunks + child_chunks

        # Drop whitespace-only chunks and exact duplicates of
        # (original chunk_index, stripped text).
        deduplicated_results = []
        seen = set()
        for result in results:
            chunk_text = (result.get("chunk_text") or "").strip()
            if not chunk_text:
                continue
            signature = (result.get("chunk_index", 0), chunk_text)
            if signature in seen:
                continue
            seen.add(signature)
            deduplicated_results.append(result)

        # Re-index sequentially for display while keeping the original index.
        chunks = []
        for display_index, result in enumerate(deduplicated_results):
            chunk_text = result.get("chunk_text", "")
            chunks.append(
                {
                    "chunk_index": display_index,
                    "original_chunk_index": result.get("chunk_index", 0),
                    "text": chunk_text,
                    "character_count": len(chunk_text),
                    "is_parent": result.get("is_parent", False),
                    "parent_chunk_id": result.get("parent_chunk_id", ""),
                }
            )
        logger.info("Chunks retrieved", doc_id=document_id, chunk_count=len(chunks))
        return {
            "document_id": document_id,
            "is_vectorized": doc.is_vectorized,
            "chunks": chunks,
            "total": len(chunks),
        }
    except Exception as e:
        # Best-effort endpoint: vector-store failures degrade to an empty list.
        # BUG FIX: the original called ``.opt(exception=True)`` on the formatted
        # *string* (str has no ``opt`` method — it is a loguru logger method),
        # so this handler itself raised AttributeError. ``logger.exception``
        # logs the message with the active traceback instead.
        logger.exception(f"Failed to get document chunks: {e}")
        return {
            "document_id": document_id,
            "is_vectorized": doc.is_vectorized,
            "chunks": [],
            "total": 0,
        }
