import uuid
from datetime import datetime

from sqlalchemy.orm import Session

from app.models.tax_data import DataProcessingTask, ProcessingLog
from app.services.tax_data_processor.category_processor import CategoryProcessor
from common_logging import get_logger
logger = get_logger(__name__)




class ProcessorScheduler:
    """Create and track data-processing tasks and order categories for processing.

    All persistence goes through the SQLAlchemy session passed to the
    constructor; every mutating method commits immediately and rolls the
    session back if that commit fails.
    """

    def __init__(self, db: Session, max_concurrent: int = 5):
        """
        Args:
            db: SQLAlchemy session used for all task and log reads/writes.
            max_concurrent: Stored on the instance; not read by the methods of
                this class (presumably a limit on concurrently running tasks
                enforced by callers — confirm).
        """
        self.db = db
        self.max_concurrent = max_concurrent
        self.category_processor = CategoryProcessor()

    def _commit(self) -> None:
        """Commit the session; on failure roll back so the session stays usable, then re-raise."""
        try:
            self.db.commit()
        except Exception:
            self.db.rollback()
            raise

    def _get_task(self, task_id: str) -> 'DataProcessingTask | None':
        """Return the task row whose ``task_id`` matches, or ``None`` if there is none."""
        return self.db.query(DataProcessingTask).filter(DataProcessingTask.task_id == task_id).first()

    def create_task(self, category_id: int | None, mode: str = 'full') -> DataProcessingTask:
        """Persist a new pending task with zeroed counters and return it.

        Args:
            category_id: Category the task processes, or ``None``.
            mode: Processing mode stored on the task (defaults to ``'full'``).

        Returns:
            The committed and refreshed ``DataProcessingTask`` with a fresh UUID4 ``task_id``.
        """
        task_id = str(uuid.uuid4())
        task = DataProcessingTask(
            task_id=task_id,
            category_id=category_id,
            mode=mode,
            status='pending',
            progress=0,
            total_count=0,
            success_count=0,
            failed_count=0,
        )
        self.db.add(task)
        self._commit()
        self.db.refresh(task)
        logger.info(f'创建任务: {task_id}, category_id={category_id}, mode={mode}')
        return task

    def update_task_status(self, task_id: str, status: str, progress: int | None = None, error_message: str | None = None):
        """Set a task's status and optionally its progress and error message.

        ``started_at`` is stamped the first time the status becomes
        ``'running'``; ``completed_at`` is stamped whenever the status becomes
        ``'completed'`` or ``'failed'``. Logs an error and returns if the task
        does not exist.

        Args:
            task_id: Identifier of the task to update.
            status: New status value.
            progress: New progress value, if given (stored as-is).
            error_message: Error text to store; ignored when ``None`` or empty.
        """
        task = self._get_task(task_id)
        if not task:
            logger.error(f'任务不存在: {task_id}')
            return
        task.status = status
        if progress is not None:
            task.progress = progress
        if error_message:
            task.error_message = error_message
        if status == 'running' and not task.started_at:
            task.started_at = datetime.now()
        if status in ('completed', 'failed'):
            task.completed_at = datetime.now()
        self._commit()
        logger.info(f'更新任务状态: {task_id}, status={status}, progress={progress}')

    def update_task_counts(self, task_id: str, total_count: int | None = None, success_count: int | None = None, failed_count: int | None = None):
        """Overwrite the given counters on a task and recompute its progress.

        Only counters passed as non-``None`` are changed. When the task's total
        is positive, progress is recomputed as the percentage of
        ``success_count + failed_count`` over ``total_count``, capped at 100.
        Logs an error and returns if the task does not exist.

        Args:
            task_id: Identifier of the task to update.
            total_count: New total number of items, if given.
            success_count: New number of successfully processed items, if given.
            failed_count: New number of failed items, if given.
        """
        task = self._get_task(task_id)
        if not task:
            logger.error(f'任务不存在: {task_id}')
            return
        if total_count is not None:
            task.total_count = total_count
        if success_count is not None:
            task.success_count = success_count
        if failed_count is not None:
            task.failed_count = failed_count
        # The column may be NULL; treat that as "no total known yet".
        total = task.total_count or 0
        if total > 0:
            processed = (task.success_count or 0) + (task.failed_count or 0)
            # Integer arithmetic avoids float truncation (int(29 / 100 * 100) == 28);
            # the cap keeps progress in range if counts overshoot the total.
            task.progress = min(100, processed * 100 // total)
        self._commit()

    def log_processing(self, task_id: str, document_url: str | None, log_level: str, message: str, exception: str | None = None):
        """Persist one ``ProcessingLog`` row for a task and commit it immediately.

        Args:
            task_id: Task the log entry belongs to.
            document_url: Document the entry concerns, or ``None``.
            log_level: Level label stored on the row.
            message: Log message.
            exception: Optional exception text.
        """
        log = ProcessingLog(task_id=task_id, document_url=document_url, log_level=log_level, message=message, exception=exception)
        self.db.add(log)
        self._commit()

    def get_category_priority(self, category_id: int) -> int:
        """Return the sort key used to order a category (lower sorts first).

        The key is the config's ``'count'`` value; unknown categories get 999
        so they sort after typical configured categories.
        """
        config = self.category_processor.get_category_config(category_id)
        if not config:
            return 999
        return config['count']

    def get_categories_by_priority(self, category_ids: list[int] | None = None) -> list[int]:
        """Return category ids ordered by ascending ``'count'``.

        Args:
            category_ids: Ids to order; ids with no config are dropped. When
                ``None`` or empty, all categories from the processor are used.

        Returns:
            Category ids sorted by each config's ``'count'`` (ascending).
        """
        if category_ids:
            configs = (self.category_processor.get_category_config(cid) for cid in category_ids)
            categories = [c for c in configs if c]
        else:
            categories = self.category_processor.get_all_categories()
        # sorted() rather than list.sort(): do not reorder the list handed back by
        # get_all_categories(), which the processor may keep and reuse.
        return [c['id'] for c in sorted(categories, key=lambda c: c['count'])]
