import asyncio

from common_logging import get_logger
from sqlalchemy.orm import Session

from app.models.dpo_task import DPOTask
from app.models.sft_task import SFTTask
from app.services.training_manager import TrainingManager

# Module-level logger, named after this module via common_logging.get_logger.
logger = get_logger(__name__)


# Task statuses after which there is nothing left to sync; polling stops once
# the task reports one of these.
_TERMINAL_STATUSES = frozenset({'completed', 'failed', 'cancelled'})


async def sync_training_status(
    db: Session,
    task_id: int,
    task_type: str,
    *,
    poll_interval: float = 5,
    error_backoff: float = 10,
) -> None:
    """Poll a training task and refresh its status until it finishes.

    Each iteration loads the task row, asks ``TrainingManager`` to update its
    status, and stops once the task's ``status`` is one of ``'completed'``,
    ``'failed'`` or ``'cancelled'``, or once the task row no longer exists.

    Args:
        db: SQLAlchemy session used to load the task; the same session is
            handed to ``TrainingManager``.
        task_id: Primary key of the task to track.
        task_type: ``'sft'`` selects ``SFTTask``; any other value selects
            ``DPOTask``.
        poll_interval: Seconds to wait between successful polls (default 5).
        error_backoff: Seconds to wait after a failed poll before retrying
            (default 10).

    Errors raised while polling are logged, the session is rolled back, and
    the poll is retried indefinitely. ``asyncio.CancelledError`` is not an
    ``Exception`` subclass (Python 3.8+), so cancelling the surrounding task
    still stops the loop.
    """
    # NOTE(review): ``db`` is a synchronous Session, so each query blocks the
    # event loop while it runs — confirm this is acceptable for the caller.
    model = SFTTask if task_type == 'sft' else DPOTask
    manager = TrainingManager(db)
    while True:
        try:
            task = db.query(model).filter(model.id == task_id).first()
            if task is None:
                # Bail out before asking the manager to update a row that does
                # not exist: if that call raised, the except branch would retry
                # forever and this check would never be reached.
                logger.error(f'Task {task_id} ({task_type}) not found; stopping status sync')
                break
            await manager.update_task_status(task_id, task_type)
            # Checked after the update so a transition to a terminal state made
            # during this iteration ends the loop right away.
            if task.status in _TERMINAL_STATUSES:
                break
            await asyncio.sleep(poll_interval)
        except Exception as e:
            logger.error(f'Error syncing task {task_id}: {e}')
            # A failed flush/commit leaves the Session unusable until it is
            # rolled back; without this, every later poll would fail as well.
            try:
                db.rollback()
            except Exception as rollback_error:
                logger.error(f'Error rolling back session for task {task_id}: {rollback_error}')
            await asyncio.sleep(error_backoff)
