import threading
from concurrent.futures import ThreadPoolExecutor

from app.db.session import SessionLocal
from app.models.knowledge_base import KnowledgeDocument
from app.services.knowledge.vectorization_service import DocumentVectorizationService
from app.services.knowledge.vectorization_task_manager import get_task_manager

from common_logging import get_logger

logger = get_logger(__name__)



class AsyncVectorizationService:
    _executor = None
    _lock = threading.Lock()
    _max_workers = 3

    @classmethod
    def get_executor(cls):
        if cls._executor is None:
            with cls._lock:
                if cls._executor is None:
                    cls._executor = ThreadPoolExecutor(
                        max_workers=cls._max_workers, thread_name_prefix="vectorization"
                    )
                    logger.info(f"Created ThreadPoolExecutor with {cls._max_workers} workers")
        return cls._executor

    @staticmethod
    def vectorize_document_async(
        document_id: int,
        model_id: int | None = None,
        chunk_size: int | None = None,
        chunk_overlap: int | None = None,
        splitter_type: str | None = None,
    ):
        executor = AsyncVectorizationService.get_executor()
        executor.submit(
            AsyncVectorizationService._vectorize_document_worker,
            document_id,
            model_id,
            chunk_size,
            chunk_overlap,
            splitter_type,
        )
        logger.info(f"Submitted async vectorization task to thread pool: document_id={document_id}")

    @staticmethod
    def _vectorize_document_worker(
        document_id: int,
        model_id: int | None,
        chunk_size: int | None,
        chunk_overlap: int | None,
        splitter_type: str | None,
    ):
        db = SessionLocal()
        task_manager = get_task_manager()
        try:
            document = (
                db.query(KnowledgeDocument).filter(KnowledgeDocument.id == document_id).first()
            )
            if not document:
                logger.error(f"Document does not exist: document_id={document_id}")
                return
            document.vectorization_status = "processing"
            document.vectorization_progress = 0
            db.commit()
            task = task_manager.create_task(document_id, total_chunks=1)
            task.start()
            vectorization_service = DocumentVectorizationService(db)
            if document.segmentation_mode != "none":
                actual_chunk_size = document.chunk_size or chunk_size or 1000
                actual_chunk_overlap = document.chunk_overlap or chunk_overlap or 200
                actual_chunk_strategy = document.splitter_type or splitter_type or "recursive"
                logger.info(
                    f"Starting document vectorization {document_id}: chunk_size={actual_chunk_size}, chunk_overlap={actual_chunk_overlap}, chunk_strategy={actual_chunk_strategy}"
                )
                result = vectorization_service.vectorize_document(
                    document_id=document_id,
                    model_id=model_id,
                    chunk_size=actual_chunk_size,
                    chunk_overlap=actual_chunk_overlap,
                    chunk_strategy=actual_chunk_strategy,
                    tenant_id=document.tenant_id,
                    user_id=document.author_id,
                )
                if result and "chunks_count" in result:
                    task_manager.update_task(document_id, result["chunks_count"])
            else:
                logger.info(f"Starting document vectorization {document_id} (without chunking)")
                result = vectorization_service.vectorize_document(
                    document_id=document_id,
                    model_id=model_id,
                    tenant_id=document.tenant_id,
                    user_id=document.author_id,
                )
            document.vectorization_status = "completed"
            document.vectorization_progress = 100
            document.is_vectorized = True
            db.commit()
            task_manager.complete_task(document_id)
            logger.info(f"Document vectorization completed: document_id={document_id}")
        except Exception as e:
            logger.bind(document_id=document_id, error=str(e)).opt(exception=True).error(
                "Document vectorization failed"
            )
            try:
                document = (
                    db.query(KnowledgeDocument).filter(KnowledgeDocument.id == document_id).first()
                )
                if document:
                    document.vectorization_status = "failed"
                    document.vectorization_error = str(e)
                    db.commit()
            except Exception as update_error:
                logger.error(f"Failed to update document status: {update_error}")
            task_manager.fail_task(document_id, str(e))
        finally:
            db.close()


def get_async_vectorization_service() -> AsyncVectorizationService:
    """Build and return a new ``AsyncVectorizationService`` instance."""
    service = AsyncVectorizationService()
    return service
