import asyncio

from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request
from sqlalchemy.orm import Session

from app.api.deps import get_db
from app.api.permissions import require_create, require_delete, require_read, require_update
from app.core.i18n import get_translator
from app.models import SFTTask, User
from app.schemas.sft_task import SFTTaskCreate, SFTTaskResponse, SFTTaskUpdate
from app.services.training_manager import TrainingManager
from app.services.training_sync import sync_training_status

from common_logging import get_logger

# Router for the SFT task endpoints, grouped under the "sft_tasks" OpenAPI tag.
router = APIRouter(tags=['sft_tasks'])
# Module logger; handlers attach task/job ids to records via ``bind``.
logger = get_logger(__name__)

@router.get('/', response_model=list[SFTTaskResponse])
def get_sft_tasks(skip: int=0, limit: int=100, db: Session=Depends(get_db), current_user: User=Depends(require_read('sft_tasks'))):
    """List the caller's tenant's SFT tasks that have not been soft-deleted.

    Args:
        skip: Number of rows to skip (offset).
        limit: Maximum number of rows to return.
        db: Database session from ``get_db``.
        current_user: Caller; must hold read permission on ``sft_tasks``.

    Returns:
        A page of ``SFTTask`` rows, ordered by id.
    """
    tasks = (
        db.query(SFTTask)
        .filter(
            # ``not SFTTask.is_deleted`` is evaluated by Python, not SQL: it
            # collapses to a constant instead of a predicate on the column.
            # Build a real SQL expression instead.
            SFTTask.is_deleted.is_(False),
            SFTTask.tenant_id == current_user.tenant_id,
        )
        # offset/limit pagination is only deterministic with a stable order.
        .order_by(SFTTask.id)
        .offset(skip)
        .limit(limit)
        .all()
    )
    return tasks

@router.get('/{task_id}', response_model=SFTTaskResponse)
def get_sft_task(task_id: int, request: Request, db: Session=Depends(get_db), current_user: User=Depends(require_read('sft_tasks'))):
    """Fetch one SFT task belonging to the caller's tenant.

    Args:
        task_id: Id of the task to fetch.
        request: Used to select the translator for error messages.
        db: Database session from ``get_db``.
        current_user: Caller; must hold read permission on ``sft_tasks``.

    Returns:
        The matching ``SFTTask``.

    Raises:
        HTTPException: 404 if the task does not exist, is soft-deleted, or
            belongs to another tenant.
    """
    t = get_translator(request)
    task = db.query(SFTTask).filter(
        SFTTask.id == task_id,
        # SQL-side predicate; ``not SFTTask.is_deleted`` would be evaluated by
        # Python and would not filter on the column at all.
        SFTTask.is_deleted.is_(False),
        SFTTask.tenant_id == current_user.tenant_id,
    ).first()
    if not task:
        raise HTTPException(status_code=404, detail=t.t('sft_task.not_found'))
    return task

@router.post('/', response_model=SFTTaskResponse)
def create_sft_task(task_in: SFTTaskCreate, db: Session=Depends(get_db), current_user: User=Depends(require_create('sft_tasks'))):
    """Create a new SFT task owned by the caller's tenant.

    The request payload is copied onto the new row, and the tenant and creator
    are taken from the authenticated user rather than from the payload.

    Args:
        task_in: Validated creation payload.
        db: Database session from ``get_db``.
        current_user: Caller; must hold create permission on ``sft_tasks``.

    Returns:
        The persisted ``SFTTask``, refreshed so database-generated fields
        (such as ``id``) are populated.
    """
    payload = task_in.model_dump()
    new_task = SFTTask(**payload, tenant_id=current_user.tenant_id, created_by=current_user.id)

    db.add(new_task)
    db.commit()
    # Reload so the generated primary key is available for logging/response.
    db.refresh(new_task)

    logger.bind(task_id=new_task.id).info("SFT task created")
    return new_task

@router.put('/{task_id}', response_model=SFTTaskResponse)
def update_sft_task(task_id: int, task_in: SFTTaskUpdate, request: Request, db: Session=Depends(get_db), current_user: User=Depends(require_update('sft_tasks'))):
    """Partially update an SFT task belonging to the caller's tenant.

    Only fields explicitly present in the request body are written
    (``exclude_unset=True``); omitted fields keep their current values.

    Args:
        task_id: Id of the task to update.
        task_in: Validated update payload.
        request: Used to select the translator for error messages.
        db: Database session from ``get_db``.
        current_user: Caller; must hold update permission on ``sft_tasks``.

    Returns:
        The updated, refreshed ``SFTTask``.

    Raises:
        HTTPException: 404 if the task does not exist, is soft-deleted, or
            belongs to another tenant.
    """
    t = get_translator(request)
    task = db.query(SFTTask).filter(
        SFTTask.id == task_id,
        # SQL-side predicate; ``not SFTTask.is_deleted`` would be evaluated by
        # Python and would not filter on the column at all.
        SFTTask.is_deleted.is_(False),
        SFTTask.tenant_id == current_user.tenant_id,
    ).first()
    if not task:
        raise HTTPException(status_code=404, detail=t.t('sft_task.not_found'))
    for field, value in task_in.model_dump(exclude_unset=True).items():
        setattr(task, field, value)
    db.commit()
    db.refresh(task)
    logger.bind(task_id=task_id).info("SFT task updated")
    return task

@router.delete('/{task_id}')
def delete_sft_task(task_id: int, request: Request, db: Session=Depends(get_db), current_user: User=Depends(require_delete('sft_tasks'))):
    """Soft-delete an SFT task belonging to the caller's tenant.

    The row is kept and only flagged with ``is_deleted = True``. Training that
    may already be running for the task is not cancelled here.

    Args:
        task_id: Id of the task to delete.
        request: Used to select the translator for messages.
        db: Database session from ``get_db``.
        current_user: Caller; must hold delete permission on ``sft_tasks``.

    Returns:
        A dict with a translated confirmation ``message``.

    Raises:
        HTTPException: 404 if the task does not exist, is already
            soft-deleted, or belongs to another tenant.
    """
    t = get_translator(request)
    task = db.query(SFTTask).filter(
        SFTTask.id == task_id,
        # SQL-side predicate; ``not SFTTask.is_deleted`` would be evaluated by
        # Python and would not filter on the column at all.
        SFTTask.is_deleted.is_(False),
        SFTTask.tenant_id == current_user.tenant_id,
    ).first()
    if not task:
        raise HTTPException(status_code=404, detail=t.t('sft_task.not_found'))
    task.is_deleted = True
    db.commit()
    logger.bind(task_id=task_id).info("SFT task deleted")
    return {'message': t.t('sft_task.deleted')}

# Strong references to in-flight status-sync tasks. The event loop keeps only
# weak references to tasks, so a task from ``asyncio.create_task`` that is not
# stored anywhere can be garbage-collected before it finishes.
_status_sync_tasks: set[asyncio.Task] = set()

@router.post('/{task_id}/start')
async def start_sft_training(task_id: int, background_tasks: BackgroundTasks, request: Request, db: Session=Depends(get_db), current_user: User=Depends(require_update('sft_tasks'))):
    """Start training for an SFT task belonging to the caller's tenant.

    Hands the task to ``TrainingManager.start_sft_training``, then schedules
    ``sync_training_status`` on the running event loop without awaiting it.

    Args:
        task_id: Id of the task to train.
        background_tasks: Injected by FastAPI; currently unused.
        request: Used to select the translator for error messages.
        db: Database session from ``get_db``.
        current_user: Caller; must hold update permission on ``sft_tasks``.

    Returns:
        ``{'job_id': <id from the manager>, 'message': 'Training started'}``.

    Raises:
        HTTPException: 404 if the task does not exist, is soft-deleted, or
            belongs to another tenant.
    """
    t = get_translator(request)
    task = db.query(SFTTask).filter(
        SFTTask.id == task_id,
        # Soft-deleted tasks must not be trainable.
        SFTTask.is_deleted.is_(False),
        SFTTask.tenant_id == current_user.tenant_id,
    ).first()
    if not task:
        raise HTTPException(status_code=404, detail=t.t('sft_task.not_found'))
    manager = TrainingManager(db)
    job_id = await manager.start_sft_training(task_id)
    logger.bind(task_id=task_id, job_id=job_id).info("SFT training started")
    # NOTE(review): ``db`` is the session injected for this request. If
    # ``get_db`` closes it once the request finishes, the sync task below will
    # keep using a closed session. Confirm this, and consider giving
    # ``sync_training_status`` its own session.
    sync_task = asyncio.create_task(sync_training_status(db, task_id, 'sft'))
    _status_sync_tasks.add(sync_task)
    sync_task.add_done_callback(_status_sync_tasks.discard)
    return {'job_id': job_id, 'message': 'Training started'}

@router.get('/{task_id}/status')
def get_sft_training_status(task_id: int, request: Request, db: Session=Depends(get_db), current_user: User=Depends(require_read('sft_tasks'))):
    """Report the training state of an SFT task in the caller's tenant.

    NOTE(review): unlike the get/update/delete endpoints, this lookup does not
    exclude soft-deleted tasks. Confirm whether that is intentional.

    Args:
        task_id: Id of the task to inspect.
        request: Used to select the translator for error messages.
        db: Database session from ``get_db``.
        current_user: Caller; must hold read permission on ``sft_tasks``.

    Returns:
        A dict with the task's ``status``, ``progress``, ``logs`` and
        ``error_message`` attributes.

    Raises:
        HTTPException: 404 if no task with this id exists for the tenant.
    """
    translator = get_translator(request)
    same_tenant = SFTTask.tenant_id == current_user.tenant_id
    found = db.query(SFTTask).filter(SFTTask.id == task_id, same_tenant).first()
    if not found:
        raise HTTPException(status_code=404, detail=translator.t('sft_task.not_found'))
    return dict(
        status=found.status,
        progress=found.progress,
        logs=found.logs,
        error_message=found.error_message,
    )

@router.get('/{task_id}/logs')
def get_sft_training_logs(task_id: int, request: Request, db: Session=Depends(get_db), current_user: User=Depends(require_read('sft_tasks'))):
    """Return the stored training logs of an SFT task in the caller's tenant.

    NOTE(review): like the status endpoint, this lookup does not exclude
    soft-deleted tasks. Confirm whether that is intentional.

    Args:
        task_id: Id of the task whose logs are requested.
        request: Used to select the translator for error messages.
        db: Database session from ``get_db``.
        current_user: Caller; must hold read permission on ``sft_tasks``.

    Returns:
        ``{'logs': <text>}``, where a missing or empty log becomes ``''``.

    Raises:
        HTTPException: 404 if no task with this id exists for the tenant.
    """
    translator = get_translator(request)
    found = (
        db.query(SFTTask)
        .filter(SFTTask.id == task_id, SFTTask.tenant_id == current_user.tenant_id)
        .first()
    )
    if not found:
        raise HTTPException(status_code=404, detail=translator.t('sft_task.not_found'))
    log_text = found.logs
    # Normalise a null/empty log to an empty string for the client.
    return {'logs': log_text if log_text else ''}

@router.post('/{task_id}/cancel')
async def cancel_sft_training(task_id: int, request: Request, db: Session=Depends(get_db), current_user: User=Depends(require_update('sft_tasks'))):
    """Ask the training manager to cancel training for an SFT task.

    NOTE(review): this lookup does not exclude soft-deleted tasks. Because
    ``delete_sft_task`` only sets ``is_deleted`` and does not cancel training,
    this may be intentional, so that a deleted task's job can still be stopped.
    Confirm before adding a soft-delete filter here.

    Args:
        task_id: Id of the task whose training should be cancelled.
        request: Used to select the translator for error messages.
        db: Database session from ``get_db``.
        current_user: Caller; must hold update permission on ``sft_tasks``.

    Returns:
        ``{'success': <manager result>, 'message': <English status text>}``.

    Raises:
        HTTPException: 404 if no task with this id exists for the tenant.
    """
    translator = get_translator(request)
    found = (
        db.query(SFTTask)
        .filter(SFTTask.id == task_id, SFTTask.tenant_id == current_user.tenant_id)
        .first()
    )
    if not found:
        raise HTTPException(status_code=404, detail=translator.t('sft_task.not_found'))

    cancelled = await TrainingManager(db).cancel_training(task_id, 'sft')
    logger.bind(task_id=task_id, success=cancelled).info("SFT training cancelled")

    if cancelled:
        outcome = 'Training cancelled'
    else:
        outcome = 'Failed to cancel'
    return {'success': cancelled, 'message': outcome}
