
from sqlalchemy.orm import Session

from app.core.exceptions import (

    DocumentNotFoundError,
    EmbeddingGenerationError,
    KnowledgeBaseNotFoundError,
)
from app.models import DocumentVector, KnowledgeBase, KnowledgeDocument
from app.services.knowledge.text_splitter_service import SplitterType, TextSplitterService
from app.services.knowledge.vectorization_task_manager import get_task_manager
from app.services.llm.backends.embedding_backend_factory import get_embedding_factory
from app.services.rag.langchain_document_processor import get_document_processor
from app.services.storage.vector_store_factory import get_vector_store

from common_logging import get_logger

logger = get_logger(__name__)



class DocumentVectorizationService:
    """Vectorize knowledge documents into a vector store.

    Pipeline per document: load + tenant check -> chunk (tax-adaptive,
    file-based, or plain-text splitting) -> batch-embed -> write vectors to
    the vector store -> persist per-chunk metadata rows via the SQLAlchemy
    session -> optionally build a chunk graph -> mark the document as
    vectorized. Progress is reported through the shared vectorization task
    manager so callers can poll status by document id.
    """

    def __init__(self, db: Session):
        # db: caller-owned SQLAlchemy session; this service commits/rolls back
        # on it but never closes it.
        self.db = db
        self.doc_processor = get_document_processor()
        self.text_splitter_service = TextSplitterService()
        self.embedding_factory = get_embedding_factory()
        self.task_manager = get_task_manager()

    def vectorize_document(
        self,
        document_id: int,
        chunk_strategy: str = "recursive",
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        window_size: int = 4,
        model_id: int | None = None,
        tenant_id: int | None = None,
        user_id: int | None = None,
        cleanup_model: bool = True,
    ) -> dict:
        """Chunk, embed, and store a single document.

        Args:
            document_id: Primary key of the KnowledgeDocument to vectorize.
            chunk_strategy: Splitting strategy. "tax_adaptive" uses the
                metadata-rich splitter (parent/child chunks, references);
                anything else is delegated to the generic document processor.
            chunk_size: Target chunk size passed to the splitter.
            chunk_overlap: Overlap between adjacent chunks.
            window_size: Sliding-window size, only used for "tax_adaptive".
            model_id: Embedding model id; None selects the factory default.
            tenant_id: Tenant scope. None means platform-level access, in
                which case the document itself must be platform-level
                (tenant_id is NULL).
            user_id: Required caller id; must not be None.
            cleanup_model: If True, unload the local embedding model after
                the run (callers doing batches pass False and clean up once).

        Returns:
            Summary dict with chunk/vector counts and the settings used.

        Raises:
            ValueError: If user_id is None.
            DocumentNotFoundError: Unknown id, or tenant mismatch — mismatch
                is deliberately reported as "not found" so callers cannot
                probe for cross-tenant document existence.
            EmbeddingGenerationError: If any chunk fails to embed.
        """
        if user_id is None:
            raise ValueError("user_id is required for vectorization")
        try:
            document = (
                self.db.query(KnowledgeDocument).filter(KnowledgeDocument.id == document_id).first()
            )
            if not document:
                raise DocumentNotFoundError(document_id)
            # Tenant isolation: platform access (tenant_id=None) may only see
            # platform documents; a tenant may only see its own documents.
            if tenant_id is None:
                if document.tenant_id is not None:
                    logger.warning(
                        f"Tenant mismatch: document {document_id} belongs to tenant {document.tenant_id}, but platform-level access (tenant_id=None) was requested"
                    )
                    raise DocumentNotFoundError(document_id)
            elif document.tenant_id != tenant_id:
                logger.warning(
                    f"Tenant mismatch: document {document_id} belongs to tenant {document.tenant_id}, but tenant {tenant_id} was requested"
                )
                raise DocumentNotFoundError(document_id)
            logger.info(f"Starting document vectorization: {document.title} (ID: {document_id})")

            # --- 1. Chunking ---------------------------------------------
            if chunk_strategy == "tax_adaptive":
                chunks_with_metadata = self.text_splitter_service.create_chunks_with_metadata(
                    text=document.content,
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap,
                    splitter_type=SplitterType.TAX_ADAPTIVE,
                    document_id=document_id,
                    document_title=document.title,
                    document_number=document.doc_number,
                    window_size=window_size,
                )
                # Normalize splitter output into the common chunk dict shape
                # the rest of the pipeline expects.
                chunks = [
                    {
                        "text": chunk["text"],
                        "metadata": chunk.get("metadata", {}),
                        "chunk_id": chunk.get("chunk_id"),
                        "chunk_index": chunk.get("chunk_index", i),
                        "is_parent": chunk.get("is_parent", False),
                        "parent_chunk_id": chunk.get("parent_chunk_id"),
                        "chunk_level": chunk.get("chunk_level"),
                        "references": chunk.get("references", []),
                    }
                    for i, chunk in enumerate(chunks_with_metadata)
                ]
            elif document.file_path:
                chunks = self.doc_processor.process_file(
                    file_path=document.file_path,
                    strategy=chunk_strategy,
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap,
                    metadata={
                        "document_id": document_id,
                        "title": document.title,
                        "source": document.source,
                    },
                )
            else:
                chunks = self.doc_processor.process_text(
                    text=document.content,
                    strategy=chunk_strategy,
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap,
                    metadata={"document_id": document_id, "title": document.title},
                )
            logger.info(f"Document chunking completed: {len(chunks)} chunks")

            # --- 2. Embedding (batched, with progress reporting) ----------
            task = self.task_manager.create_task(document_id, len(chunks))
            task.start()
            embeddings = []
            chunk_texts = [chunk["text"] for chunk in chunks]
            batch_embeddings = self.embedding_factory.generate_embeddings_batch(
                texts=chunk_texts, db=self.db, model_id=model_id
            )
            for i, embedding in enumerate(batch_embeddings):
                if embedding:
                    embeddings.append(embedding)
                    self.task_manager.update_task(document_id, i + 1)
                else:
                    # A single missing embedding aborts the whole document:
                    # partial vector sets would corrupt retrieval.
                    raise EmbeddingGenerationError(f"Failed to generate embedding for chunk {i}")
            logger.info(f"Embedding generation completed: {len(embeddings)} vectors")

            # Resolve a human-readable model name for metadata rows.
            if model_id:
                from app.models.provider import Model

                model = self.db.query(Model).filter(Model.id == model_id).first()
                model_name = model.name if model else "unknown"
            else:
                model_name = "default"

            # Resolve the knowledge base through the document's category so
            # the vector store can be scoped correctly.
            knowledge_base_id = None
            if document.category_id:
                from app.models.knowledge_base import KnowledgeCategory

                category = (
                    self.db.query(KnowledgeCategory)
                    .filter(KnowledgeCategory.id == document.category_id)
                    .first()
                )
                if category:
                    knowledge_base_id = category.knowledge_base_id
                    logger.info(f"Document belongs to knowledge base ID: {knowledge_base_id}")

            # Computed once and reused for both the vector store payload and
            # the per-chunk metadata rows (previously duplicated).
            doc_status = document.doc_status or "effective"
            # Compact YYYYMMDD integer; 0 means "no issue date".
            issue_date_int = (
                int(document.issue_date.strftime("%Y%m%d")) if document.issue_date else 0
            )

            # --- 3. Vector storage ----------------------------------------
            vector_store = get_vector_store(
                self.db, knowledge_base_id=knowledge_base_id, tenant_id=tenant_id
            )
            vector_ids = vector_store.add_documents(
                document_id=document_id,
                chunks=chunks,
                embeddings=embeddings,
                model_name=model_name,
                doc_status_list=[doc_status] * len(chunks),
                issue_date_int_list=[issue_date_int] * len(chunks),
            )
            logger.info(f"Vector storage completed: {len(vector_ids)} vectors")

            # --- 4. Per-chunk metadata rows -------------------------------
            for i, chunk in enumerate(chunks):
                doc_vector = DocumentVector(
                    document_id=document_id,
                    chunk_index=chunk.get("chunk_index", i),
                    chunk_text=chunk["text"],
                    # Defensive: tolerate a store returning fewer ids than chunks.
                    milvus_id=str(vector_ids[i]) if i < len(vector_ids) else None,
                    model_name=model_name,
                    parent_chunk_id=chunk.get("parent_chunk_id"),
                    is_parent=chunk.get("is_parent", False),
                    chunk_level=chunk.get("chunk_level"),
                    doc_type=document.doc_type,
                    doc_number=document.doc_number,
                    issuing_authority=document.issuing_authority,
                    references=chunk.get("references"),
                    chunk_id=chunk.get("chunk_id"),
                    doc_status=doc_status,
                    issue_date_int=issue_date_int,
                )
                self.db.add(doc_vector)
            self.db.flush()
            logger.info(f"PostgreSQL metadata storage completed: {len(chunks)} records")

            # --- 5. Optional chunk graph (best-effort, never fatal) -------
            if chunk_strategy == "tax_adaptive" and knowledge_base_id:
                try:
                    from app.config import settings
                    from app.services.graph.graph_builder import get_graph_builder

                    if settings.ENABLE_KNOWLEDGE_GRAPH:
                        graph_builder = get_graph_builder()
                        graph_result = graph_builder.build_chunk_graph(
                            document_id=document_id,
                            chunks=chunks,
                            tenant_id=tenant_id or 0,
                            kb_id=knowledge_base_id,
                        )
                        if graph_result.get("success"):
                            logger.info(
                                f"Chunk graph built: {graph_result.get('chunk_count')} chunks, {graph_result.get('reference_count')} references"
                            )
                        else:
                            logger.warning(f"Chunk graph build failed: {graph_result.get('error')}")
                    else:
                        logger.info("Knowledge graph disabled, skipping chunk graph build")
                except Exception as e:
                    logger.warning(f"Failed to build chunk graph (non-fatal): {e}")

            # --- 6. Mark the document and commit --------------------------
            document.is_vectorized = True
            document.vector_model = model_name
            document.enable_parent_child = chunk_strategy == "tax_adaptive"
            document.window_size = window_size if chunk_strategy == "tax_adaptive" else None
            self.db.commit()
            self.task_manager.complete_task(document_id)
            if cleanup_model:
                try:
                    self.embedding_factory.unload_local_embedding_model()
                    logger.info("Local embedding model unloaded after vectorization")
                except Exception as e:
                    logger.warning(f"Failed to unload embedding model: {e}")
            return {
                "success": True,
                "document_id": document_id,
                "chunks_count": len(chunks),
                "vectors_count": len(vector_ids),
                "model_name": model_name,
                "strategy": chunk_strategy,
                "chunk_size": chunk_size,
                "chunk_overlap": chunk_overlap,
            }
        except Exception as e:
            logger.error(f"Document vectorization failed: {e}")
            # Roll back any partially-added DocumentVector rows / failed
            # transaction state; otherwise the shared session is left dirty
            # and poisons subsequent documents in batch_vectorize_documents.
            self.db.rollback()
            self.task_manager.fail_task(document_id, str(e))
            raise

    def batch_vectorize_documents(
        self,
        document_ids: list,
        chunk_strategy: str = "recursive",
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        model_id: int | None = None,
        tenant_id: int | None = None,
        user_id: int | None = None,
        cleanup_model: bool = True,
    ) -> dict:
        """Vectorize several documents, continuing past per-document failures.

        Returns a dict with "success" / "failed" result lists and "total".
        Raises ValueError if user_id is None.
        """
        if user_id is None:
            raise ValueError("user_id is required for batch vectorization")
        results = {"success": [], "failed": [], "total": len(document_ids)}
        for doc_id in document_ids:
            try:
                result = self.vectorize_document(
                    document_id=doc_id,
                    chunk_strategy=chunk_strategy,
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap,
                    model_id=model_id,
                    tenant_id=tenant_id,
                    user_id=user_id,
                    # Defer model unloading: done once after the whole batch.
                    cleanup_model=False,
                )
                results["success"].append({"document_id": doc_id, "result": result})
            except Exception as e:
                results["failed"].append({"document_id": doc_id, "error": str(e)})
        if cleanup_model:
            try:
                self.embedding_factory.unload_local_embedding_model()
                logger.info("Local embedding model unloaded after batch vectorization")
            except Exception as e:
                logger.warning(f"Failed to unload embedding model after batch vectorization: {e}")
        return results

    def re_vectorize_knowledge_base(
        self,
        knowledge_base_id: int,
        model_id: int | None = None,
        tenant_id: int | None = None,
        user_id: int | None = None,
        cleanup_model: bool = True,
    ) -> dict:
        """Re-vectorize every document in a knowledge base.

        Documents are discovered via the KB's categories and processed with
        default chunking settings through batch_vectorize_documents.

        Raises:
            ValueError: If user_id is None.
            KnowledgeBaseNotFoundError: Unknown id or tenant mismatch
                (reported as "not found" to avoid leaking KB existence
                across tenants).
        """
        if user_id is None:
            raise ValueError("user_id is required for knowledge base re-vectorization")
        try:
            kb = self.db.query(KnowledgeBase).filter(KnowledgeBase.id == knowledge_base_id).first()
            if not kb:
                raise KnowledgeBaseNotFoundError(knowledge_base_id)
            # Same tenant-isolation rules as vectorize_document.
            if tenant_id is None:
                if kb.tenant_id is not None:
                    logger.warning(
                        f"Tenant mismatch: knowledge base {knowledge_base_id} belongs to tenant {kb.tenant_id}, but platform-level access was requested"
                    )
                    raise KnowledgeBaseNotFoundError(knowledge_base_id)
            elif kb.tenant_id != tenant_id:
                logger.warning(
                    f"Tenant mismatch: knowledge base {knowledge_base_id} belongs to tenant {kb.tenant_id}, but tenant {tenant_id} was requested"
                )
                raise KnowledgeBaseNotFoundError(knowledge_base_id)
            from app.models import KnowledgeCategory

            category_ids = [
                cat.id
                for cat in self.db.query(KnowledgeCategory)
                .filter(KnowledgeCategory.knowledge_base_id == knowledge_base_id)
                .all()
            ]
            documents = (
                self.db.query(KnowledgeDocument)
                .filter(KnowledgeDocument.category_id.in_(category_ids))
                .all()
                if category_ids
                else []
            )
            document_ids = [doc.id for doc in documents]
            logger.info(
                f"Starting knowledge base re-vectorization: {kb.name}, document count: {len(document_ids)}"
            )
            results = self.batch_vectorize_documents(
                document_ids=document_ids,
                model_id=model_id,
                tenant_id=tenant_id,
                user_id=user_id,
                cleanup_model=cleanup_model,
            )
            return results
        except Exception as e:
            logger.error(f"Knowledge base re-vectorization failed: {e}")
            raise


def get_vectorization_service(db: Session) -> DocumentVectorizationService:
    """Build a DocumentVectorizationService bound to the given session."""
    service = DocumentVectorizationService(db)
    return service


if __name__ == "__main__":
    from app.db.session import SessionLocal

    # Demo driver: exercise single, batch, and knowledge-base re-vectorization
    # against a locally created session.
    db = SessionLocal()
    try:
        service = DocumentVectorizationService(db)

        single_result = service.vectorize_document(
            document_id=123,
            chunk_strategy="recursive",
            chunk_size=1000,
            chunk_overlap=200,
            model_id=1,
            user_id=1,
        )
        logger.info(f"Vectorization result: {single_result}")

        batch_summary = service.batch_vectorize_documents(
            document_ids=[123, 124, 125],
            chunk_strategy="recursive",
            chunk_size=1000,
            chunk_overlap=200,
            model_id=1,
            user_id=1,
        )
        logger.info(f"Batch processing result: {batch_summary}")

        kb_summary = service.re_vectorize_knowledge_base(
            knowledge_base_id=1, model_id=2, user_id=1
        )
        logger.info(f"Knowledge base re-vectorization result: {kb_summary}")
    finally:
        db.close()
