from datetime import datetime

from common_logging import get_logger, log_execution
from sqlalchemy.orm import Session

from app.core.websocket_manager import manager as ws_manager
from app.models.dpo_task import DPOTask
from app.models.sft_task import SFTTask
from app.services.training_platform import TrainingConfig, get_training_platform

logger = get_logger(__name__)


class TrainingManager:
    """Drive SFT/DPO training tasks stored in the database on a training platform.

    Each public coroutine looks up a task row, talks to the platform returned
    by ``get_training_platform(task.platform)``, and writes the outcome back
    through the SQLAlchemy session supplied at construction time.

    NOTE(review): the platform calls and the session calls are synchronous
    and are invoked directly from ``async`` methods, so they run on the
    calling event loop — confirm this is acceptable for the platform client
    in use.
    """

    def __init__(self, db: Session):
        """Store the SQLAlchemy session used for all queries and commits."""
        self.db = db

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _get_task(self, task_id: int, task_type: str):
        """Return the task row with ``task_id``, or ``None`` if it is absent.

        ``'sft'`` selects ``SFTTask``; every other ``task_type`` value selects
        ``DPOTask`` (this fallback is kept for backward compatibility).
        """
        model = SFTTask if task_type == 'sft' else DPOTask
        return self.db.query(model).filter(model.id == task_id).first()

    def _commit(self):
        """Commit the session, rolling it back before re-raising on failure.

        Rolling back keeps the shared session usable for later calls instead
        of leaving it in a failed-transaction state.
        """
        try:
            self.db.commit()
        except Exception:
            self.db.rollback()
            raise

    def _launch(self, task_id: int, task_type: str):
        """Create a platform job for the task and mark the task as running.

        Shared implementation behind ``start_sft_training`` and
        ``start_dpo_training``; ``task_type`` is ``'sft'`` or ``'dpo'``.
        """
        label = task_type.upper()
        task = self._get_task(task_id, task_type)
        if not task:
            logger.warning(f"{label} task {task_id} not found")
            raise ValueError(f'Task {task_id} not found')

        platform = get_training_platform(task.platform)

        # Copy before adding 'task_type' so the dict held by the ORM object
        # is not mutated in place as a side effect of starting a job.
        hyperparameters = dict(task.hyperparameters or {})
        hyperparameters['task_type'] = task_type

        config = TrainingConfig(
            name=task.name,
            description=task.description or '',
            dataset_id=task.dataset_id,
            model_name=task.model_name,
            hyperparameters=hyperparameters,
        )
        job_id = platform.create_training_job(config)

        task.platform_job_id = job_id
        task.status = 'running'
        task.started_at = datetime.utcnow()
        try:
            self._commit()
        except Exception:
            # The platform job already exists at this point; surface its id
            # so it can be reconciled or cancelled manually.
            logger.bind(task_id=task_id, job_id=job_id).error(
                f"{label} training job created but task update could not be committed"
            )
            raise

        logger.bind(task_id=task_id, job_id=job_id).info(f"{label} training started")
        return job_id

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    @log_execution(logger)
    async def start_sft_training(self, task_id: int):
        """Start the SFT task ``task_id`` on its platform and return the job id.

        Raises:
            ValueError: if no ``SFTTask`` with ``task_id`` exists.
        """
        return self._launch(task_id, 'sft')

    @log_execution(logger)
    async def start_dpo_training(self, task_id: int):
        """Start the DPO task ``task_id`` on its platform and return the job id.

        Raises:
            ValueError: if no ``DPOTask`` with ``task_id`` exists.
        """
        return self._launch(task_id, 'dpo')

    async def update_task_status(self, task_id: int, task_type: str):
        """Refresh a task from the platform, persist it, and broadcast the update.

        Returns without doing anything when the task does not exist or has no
        platform job id yet. When the platform reports ``'completed'`` the job
        result is stored on the task as well.
        """
        task = self._get_task(task_id, task_type)
        if not task or not task.platform_job_id:
            return

        platform = get_training_platform(task.platform)
        status = platform.get_job_status(task.platform_job_id)

        previous_status = task.status
        task.status = status.status
        task.progress = status.progress
        task.logs = status.logs

        if status.status == 'completed':
            result = platform.get_job_result(task.platform_job_id)
            task.result = {
                'model_id': result.model_id,
                'metrics': result.metrics,
                'artifacts': result.artifacts,
            }
            # Stamp the completion time only on the transition into
            # 'completed' (or if it was never recorded), so repeated polling
            # of an already-finished job does not keep moving the timestamp.
            if previous_status != 'completed' or task.completed_at is None:
                task.completed_at = datetime.utcnow()

        self._commit()

        await ws_manager.broadcast(
            str(task_id),
            {
                'type': 'status_update',
                'task_id': task_id,
                'status': status.status,
                'progress': status.progress,
                'message': status.message,
            },
        )

    @log_execution(logger)
    async def cancel_training(self, task_id: int, task_type: str):
        """Ask the platform to cancel the task's job.

        Returns ``False`` when the task does not exist or has no platform job
        id; otherwise returns whatever ``platform.cancel_job`` returned. The
        task is marked ``'cancelled'`` only when that value is truthy.
        """
        task = self._get_task(task_id, task_type)
        if not task or not task.platform_job_id:
            return False

        platform = get_training_platform(task.platform)
        success = platform.cancel_job(task.platform_job_id)
        if success:
            task.status = 'cancelled'
            self._commit()
        return success
