import threading
from datetime import datetime

from common_logging import get_logger

# Module-level logger named after this module, obtained from the project's
# common_logging helper. The code below calls ``logger.bind(...)`` on it to
# attach a ``task_id`` to each log record.
logger = get_logger(__name__)


class VectorizationTask:
    """Track the state and progress of vectorizing a single document.

    The ``status`` attribute moves through the string values ``"pending"``
    (initial), ``"processing"`` (after :meth:`start`), and then either
    ``"completed"`` (after :meth:`complete`) or ``"failed"`` (after
    :meth:`fail`). Timestamps are taken with ``datetime.now()`` and are
    therefore naive local times.
    """

    def __init__(self, document_id: int, total_chunks: int = 1):
        """Create a pending task.

        Args:
            document_id: Identifier of the document being vectorized.
            total_chunks: Number of chunks the document is split into.
        """
        self.document_id = document_id
        self.total_chunks = total_chunks
        self.completed_chunks = 0
        self.status = "pending"
        self.error_message = None
        self.started_at = None
        self.completed_at = None

    @property
    def progress(self) -> float:
        """Return completion as a percentage in the range 0.0 to 100.0.

        When there are no chunks to process the ratio is undefined, so the
        value is 100.0 for a completed task and 0.0 otherwise. The result is
        clamped so that an out-of-range ``completed_chunks`` value can never
        be reported as less than 0% or more than 100%.
        """
        if self.total_chunks <= 0:
            # Nothing to divide by: only a finished task counts as fully done.
            return 100.0 if self.status == "completed" else 0.0
        percent = self.completed_chunks / self.total_chunks * 100
        return min(max(percent, 0.0), 100.0)

    def start(self):
        """Mark the task as processing and record the start time."""
        self.status = "processing"
        self.started_at = datetime.now()

    def update_progress(self, completed: int):
        """Set the number of chunks finished so far.

        Args:
            completed: Absolute count of completed chunks (not an increment).
        """
        self.completed_chunks = completed

    def complete(self):
        """Mark the task as completed and record the completion time.

        ``completed_chunks`` is set to ``total_chunks`` so the counters agree
        with the final status.
        """
        self.status = "completed"
        self.completed_at = datetime.now()
        self.completed_chunks = self.total_chunks

    def fail(self, error: str):
        """Mark the task as failed and record the failure time.

        Args:
            error: Description of the failure, stored in ``error_message``.
        """
        self.status = "failed"
        self.error_message = error
        self.completed_at = datetime.now()

    def to_dict(self) -> dict:
        """Return a plain-dict snapshot of the task.

        Timestamps are rendered as ISO 8601 strings, or ``None`` when unset.
        """
        return {
            "document_id": self.document_id,
            "status": self.status,
            "progress": self.progress,
            "total_chunks": self.total_chunks,
            "completed_chunks": self.completed_chunks,
            "error_message": self.error_message,
            "started_at": self.started_at.isoformat() if self.started_at else None,
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
        }


class VectorizationTaskManager:
    """Registry of :class:`VectorizationTask` objects keyed by document id.

    Each method holds ``self.lock`` while it reads or modifies ``self.tasks``.
    Task objects handed out by :meth:`create_task` and :meth:`get_task` are
    the stored instances themselves, not copies.
    """

    def __init__(self):
        """Create an empty registry with its own lock."""
        self.lock = threading.Lock()
        self.tasks: dict[int, VectorizationTask] = {}

    def create_task(self, document_id: int, total_chunks: int = 1) -> VectorizationTask:
        """Register a new task for ``document_id`` and return it.

        An existing entry under the same id is replaced.
        """
        with self.lock:
            new_task = VectorizationTask(document_id, total_chunks)
            self.tasks[document_id] = new_task
            logger.bind(task_id=document_id).info("vectorization task created")
        return new_task

    def get_task(self, document_id: int) -> VectorizationTask | None:
        """Return the task stored for ``document_id``, or ``None`` if absent."""
        with self.lock:
            found = self.tasks.get(document_id)
        return found

    def update_task(self, document_id: int, completed: int):
        """Set the completed-chunk count of a task; unknown ids are ignored."""
        with self.lock:
            entry = self.tasks.get(document_id)
            if not entry:
                return
            entry.update_progress(completed)

    def complete_task(self, document_id: int):
        """Mark a task as completed and log it; unknown ids are ignored."""
        with self.lock:
            entry = self.tasks.get(document_id)
            if not entry:
                return
            entry.complete()
            logger.bind(task_id=document_id).info("vectorization task completed")

    def fail_task(self, document_id: int, error: str):
        """Mark a task as failed with ``error`` and log a warning.

        Unknown ids are ignored.
        """
        with self.lock:
            entry = self.tasks.get(document_id)
            if not entry:
                return
            entry.fail(error)
            logger.bind(task_id=document_id).warning(
                "vectorization task failed: {error}", error=error
            )

    def remove_task(self, document_id: int):
        """Delete the task for ``document_id`` and log it, if one exists."""
        with self.lock:
            if document_id not in self.tasks:
                return
            del self.tasks[document_id]
            logger.bind(task_id=document_id).info("vectorization task removed")

    def get_all_tasks(self) -> list:
        """Return a list with the ``to_dict()`` snapshot of every task."""
        with self.lock:
            snapshot = []
            for entry in self.tasks.values():
                snapshot.append(entry.to_dict())
            return snapshot

    def cleanup_old_tasks(self, max_age_seconds: int = 3600):
        """Drop tasks that finished more than ``max_age_seconds`` ago.

        Only tasks with ``completed_at`` set (completed or failed) are
        considered; tasks that have not finished are kept.
        """
        with self.lock:
            reference_time = datetime.now()
            # Collect ids first so the dict is not mutated while iterating it.
            expired_ids = [
                doc_id
                for doc_id, entry in self.tasks.items()
                if entry.completed_at
                and (reference_time - entry.completed_at).total_seconds() > max_age_seconds
            ]
            for doc_id in expired_ids:
                del self.tasks[doc_id]


# Lazily created process-wide manager instance; see get_task_manager().
_task_manager = None
# Serializes first-time construction so concurrent callers share one manager.
_task_manager_lock = threading.Lock()


def get_task_manager() -> VectorizationTaskManager:
    """Return the shared :class:`VectorizationTaskManager`, creating it on first use.

    Creation is guarded by a module-level lock so that two threads calling
    this function at the same time cannot each build their own manager
    (which would leave tasks registered on an instance nobody else sees).
    """
    global _task_manager
    # Fast path: once initialized, no lock is needed to read the reference.
    if _task_manager is None:
        with _task_manager_lock:
            # Re-check under the lock; another thread may have won the race.
            if _task_manager is None:
                _task_manager = VectorizationTaskManager()
    return _task_manager
