import asyncio

from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request
from sqlalchemy.orm import Session

from app.api.deps import get_db
from app.api.permissions import require_create, require_delete, require_read, require_update
from app.core.i18n import get_translator
from app.models import DPOTask, User
from app.schemas.dpo_task import DPOTaskCreate, DPOTaskResponse, DPOTaskUpdate
from app.services.training_manager import TrainingManager
from app.services.training_sync import sync_training_status

from common_logging import get_logger

# Module-level logger; the endpoints below use `logger.bind(...)` to attach
# task/job identifiers to individual log records.
logger = get_logger(__name__)

# Router for DPO task CRUD and training-control endpoints. No path prefix is
# set here — presumably it is applied where the router is included (confirm).
router = APIRouter(tags=["dpo_tasks"])

@router.get('/', response_model=list[DPOTaskResponse])
def get_dpo_tasks(skip: int=0, limit: int=100, db: Session=Depends(get_db), current_user: User=Depends(require_read('dpo_tasks'))):
    """List the non-deleted DPO tasks belonging to the caller's tenant.

    Args:
        skip: Number of rows to skip (pagination offset).
        limit: Maximum number of rows to return.
        db: Session provided by the ``get_db`` dependency.
        current_user: Caller, authorized via ``require_read('dpo_tasks')``;
            results are restricted to ``current_user.tenant_id``.

    Returns:
        A page of ``DPOTask`` rows ordered by ``id``.
    """
    tasks = (
        db.query(DPOTask)
        .filter(
            # Python's `not` cannot be overloaded, so `not DPOTask.is_deleted`
            # collapses to a constant rather than a SQL predicate. `.is_(False)`
            # emits the intended `is_deleted IS false` clause.
            DPOTask.is_deleted.is_(False),
            DPOTask.tenant_id == current_user.tenant_id,
        )
        # Deterministic ordering so offset/limit pagination returns stable pages.
        .order_by(DPOTask.id)
        .offset(skip)
        .limit(limit)
        .all()
    )
    return tasks

@router.get('/{task_id}', response_model=DPOTaskResponse)
def get_dpo_task(task_id: int, request: Request, db: Session=Depends(get_db), current_user: User=Depends(require_read('dpo_tasks'))):
    """Return a single non-deleted DPO task owned by the caller's tenant.

    Args:
        task_id: Primary key of the task.
        request: Incoming request, used to pick the translator for error text.
        db: Session provided by the ``get_db`` dependency.
        current_user: Caller, authorized via ``require_read('dpo_tasks')``.

    Raises:
        HTTPException: 404 if the task does not exist, is soft-deleted, or
            belongs to another tenant.
    """
    t = get_translator(request)
    task = (
        db.query(DPOTask)
        .filter(
            DPOTask.id == task_id,
            # `not DPOTask.is_deleted` would be evaluated by Python, not SQL;
            # `.is_(False)` builds the real column predicate.
            DPOTask.is_deleted.is_(False),
            DPOTask.tenant_id == current_user.tenant_id,
        )
        .first()
    )
    if not task:
        raise HTTPException(status_code=404, detail=t.t('dpo_task.not_found'))
    return task

@router.post('/', response_model=DPOTaskResponse)
def create_dpo_task(task_in: DPOTaskCreate, db: Session=Depends(get_db), current_user: User=Depends(require_create('dpo_tasks'))):
    """Create a DPO task owned by the caller's tenant and attributed to the caller.

    Args:
        task_in: Validated creation payload; all of its fields are passed to
            the ``DPOTask`` constructor.
        db: Session provided by the ``get_db`` dependency.
        current_user: Caller, authorized via ``require_create('dpo_tasks')``.

    Returns:
        The newly persisted ``DPOTask``.
    """
    # NOTE(review): assumes DPOTaskCreate declares neither `tenant_id` nor
    # `created_by` — a duplicate keyword here would raise TypeError. Confirm.
    payload = task_in.model_dump()
    new_task = DPOTask(**payload, tenant_id=current_user.tenant_id, created_by=current_user.id)

    db.add(new_task)
    db.commit()
    # Reload the row after commit so database-populated columns such as `id`
    # are available for logging and the response.
    db.refresh(new_task)

    logger.bind(task_id=new_task.id).info("DPO task created")
    return new_task

@router.put('/{task_id}', response_model=DPOTaskResponse)
def update_dpo_task(task_id: int, task_in: DPOTaskUpdate, request: Request, db: Session=Depends(get_db), current_user: User=Depends(require_update('dpo_tasks'))):
    """Apply a partial update to a non-deleted DPO task of the caller's tenant.

    Only fields the client explicitly sent are written
    (``model_dump(exclude_unset=True)``); omitted fields are left untouched.

    Args:
        task_id: Primary key of the task.
        task_in: Update payload.
        request: Incoming request, used to pick the translator for error text.
        db: Session provided by the ``get_db`` dependency.
        current_user: Caller, authorized via ``require_update('dpo_tasks')``.

    Returns:
        The updated ``DPOTask``.

    Raises:
        HTTPException: 404 if the task does not exist, is soft-deleted, or
            belongs to another tenant.
    """
    t = get_translator(request)
    task = (
        db.query(DPOTask)
        .filter(
            DPOTask.id == task_id,
            # `not DPOTask.is_deleted` would be evaluated by Python, not SQL;
            # `.is_(False)` builds the real column predicate.
            DPOTask.is_deleted.is_(False),
            DPOTask.tenant_id == current_user.tenant_id,
        )
        .first()
    )
    if not task:
        raise HTTPException(status_code=404, detail=t.t('dpo_task.not_found'))
    # NOTE(review): every field declared on DPOTaskUpdate is assignable here;
    # confirm the schema does not expose tenant_id / created_by / is_deleted.
    for field, value in task_in.model_dump(exclude_unset=True).items():
        setattr(task, field, value)
    db.commit()
    db.refresh(task)
    return task

@router.delete('/{task_id}')
def delete_dpo_task(task_id: int, request: Request, db: Session=Depends(get_db), current_user: User=Depends(require_delete('dpo_tasks'))):
    """Soft-delete a DPO task of the caller's tenant by setting ``is_deleted``.

    The row is kept in the database; only the flag is flipped. This endpoint
    does not cancel any training run associated with the task.

    Args:
        task_id: Primary key of the task.
        request: Incoming request, used to pick the translator for messages.
        db: Session provided by the ``get_db`` dependency.
        current_user: Caller, authorized via ``require_delete('dpo_tasks')``.

    Returns:
        A dict with a translated confirmation ``message``.

    Raises:
        HTTPException: 404 if the task does not exist, is already
            soft-deleted, or belongs to another tenant.
    """
    t = get_translator(request)
    task = (
        db.query(DPOTask)
        .filter(
            DPOTask.id == task_id,
            # `not DPOTask.is_deleted` would be evaluated by Python, not SQL;
            # `.is_(False)` builds the real column predicate.
            DPOTask.is_deleted.is_(False),
            DPOTask.tenant_id == current_user.tenant_id,
        )
        .first()
    )
    if not task:
        raise HTTPException(status_code=404, detail=t.t('dpo_task.not_found'))
    task.is_deleted = True
    db.commit()
    logger.bind(task_id=task_id).info("DPO task deleted")
    return {'message': t.t('dpo_task.deleted')}

# Strong references to in-flight status-sync tasks. The asyncio event loop
# only keeps weak references to tasks, so a task created by
# `asyncio.create_task` whose result is discarded may be garbage-collected
# before it finishes. Each task removes itself from this set when done.
_dpo_status_sync_tasks: set[asyncio.Task] = set()


@router.post('/{task_id}/start')
async def start_dpo_training(task_id: int, background_tasks: BackgroundTasks, request: Request, db: Session=Depends(get_db), current_user: User=Depends(require_update('dpo_tasks'))):
    """Start DPO training for a non-deleted task of the caller's tenant.

    Delegates to ``TrainingManager.start_dpo_training`` and then schedules
    ``sync_training_status`` to run concurrently on the event loop.

    Args:
        task_id: Primary key of the task.
        background_tasks: Unused; kept for interface compatibility.
        request: Incoming request, used to pick the translator for error text.
        db: Session provided by the ``get_db`` dependency.
        current_user: Caller, authorized via ``require_update('dpo_tasks')``.

    Returns:
        A dict with the ``job_id`` returned by the manager and a fixed message.

    Raises:
        HTTPException: 404 if the task does not exist, is soft-deleted, or
            belongs to another tenant.
    """
    t = get_translator(request)
    # NOTE(review): synchronous ORM queries inside an `async def` block the
    # event loop while they run.
    task = (
        db.query(DPOTask)
        .filter(
            DPOTask.id == task_id,
            # A soft-deleted task must not be startable.
            DPOTask.is_deleted.is_(False),
            DPOTask.tenant_id == current_user.tenant_id,
        )
        .first()
    )
    if not task:
        raise HTTPException(status_code=404, detail=t.t('dpo_task.not_found'))
    manager = TrainingManager(db)
    job_id = await manager.start_dpo_training(task_id)
    logger.bind(task_id=task_id, job_id=job_id).info("DPO training started")
    # NOTE(review): `db` comes from the `get_db` dependency, but this task can
    # outlive the request. If `get_db` closes the session on teardown, the
    # sync will run against a closed session. Confirm, and consider letting
    # sync_training_status open its own session.
    sync_task = asyncio.create_task(sync_training_status(db, task_id, 'dpo'))
    _dpo_status_sync_tasks.add(sync_task)
    sync_task.add_done_callback(_dpo_status_sync_tasks.discard)
    return {'job_id': job_id, 'message': 'Training started'}

@router.get('/{task_id}/status')
def get_dpo_training_status(task_id: int, request: Request, db: Session=Depends(get_db), current_user: User=Depends(require_read('dpo_tasks'))):
    """Return the stored training state of a non-deleted task of the caller's tenant.

    Args:
        task_id: Primary key of the task.
        request: Incoming request, used to pick the translator for error text.
        db: Session provided by the ``get_db`` dependency.
        current_user: Caller, authorized via ``require_read('dpo_tasks')``.

    Returns:
        A dict with the task's ``status``, ``progress``, ``logs`` and
        ``error_message`` columns, as currently stored on the row.

    Raises:
        HTTPException: 404 if the task does not exist, is soft-deleted, or
            belongs to another tenant.
    """
    t = get_translator(request)
    task = (
        db.query(DPOTask)
        .filter(
            DPOTask.id == task_id,
            # Soft-deleted tasks are hidden here, matching the single-task GET.
            DPOTask.is_deleted.is_(False),
            DPOTask.tenant_id == current_user.tenant_id,
        )
        .first()
    )
    if not task:
        raise HTTPException(status_code=404, detail=t.t('dpo_task.not_found'))
    return {'status': task.status, 'progress': task.progress, 'logs': task.logs, 'error_message': task.error_message}

@router.get('/{task_id}/logs')
def get_dpo_training_logs(task_id: int, request: Request, db: Session=Depends(get_db), current_user: User=Depends(require_read('dpo_tasks'))):
    """Return the stored training logs of a non-deleted task of the caller's tenant.

    Args:
        task_id: Primary key of the task.
        request: Incoming request, used to pick the translator for error text.
        db: Session provided by the ``get_db`` dependency.
        current_user: Caller, authorized via ``require_read('dpo_tasks')``.

    Returns:
        A dict ``{'logs': ...}``; a falsy ``task.logs`` (e.g. ``None``) is
        returned as an empty string.

    Raises:
        HTTPException: 404 if the task does not exist, is soft-deleted, or
            belongs to another tenant.
    """
    t = get_translator(request)
    task = (
        db.query(DPOTask)
        .filter(
            DPOTask.id == task_id,
            # Soft-deleted tasks are hidden here, matching the single-task GET.
            DPOTask.is_deleted.is_(False),
            DPOTask.tenant_id == current_user.tenant_id,
        )
        .first()
    )
    if not task:
        raise HTTPException(status_code=404, detail=t.t('dpo_task.not_found'))
    return {'logs': task.logs or ''}

@router.post('/{task_id}/cancel')
async def cancel_dpo_training(task_id: int, request: Request, db: Session=Depends(get_db), current_user: User=Depends(require_update('dpo_tasks'))):
    """Ask the training manager to cancel the DPO run of a task in the caller's tenant.

    This lookup does not filter on ``is_deleted``, so a run can still be
    cancelled after its task has been soft-deleted.

    Args:
        task_id: Primary key of the task.
        request: Incoming request, used to pick the translator for error text.
        db: Session provided by the ``get_db`` dependency.
        current_user: Caller, authorized via ``require_update('dpo_tasks')``.

    Returns:
        A dict with the manager's ``success`` result and a fixed message.

    Raises:
        HTTPException: 404 if the task does not exist or belongs to another tenant.
    """
    translator = get_translator(request)
    owned_task = (
        db.query(DPOTask)
        .filter(DPOTask.id == task_id, DPOTask.tenant_id == current_user.tenant_id)
        .first()
    )
    if not owned_task:
        raise HTTPException(status_code=404, detail=translator.t('dpo_task.not_found'))

    cancelled = await TrainingManager(db).cancel_training(task_id, 'dpo')
    logger.bind(task_id=task_id, success=cancelled).info("DPO training cancelled")

    if cancelled:
        outcome = 'Training cancelled'
    else:
        outcome = 'Failed to cancel'
    return {'success': cancelled, 'message': outcome}
