import sys
from pathlib import Path
from typing import Any

from app.core.config import settings
from common_logging import get_logger, log_execution

from .config import TrainingConfig
from .data_loader import load_dpo_dataset, load_sft_dataset
from .dpo_trainer import DPOTrainer
from .sft_trainer import SFTTrainer
from .utils import get_memory_usage, log_message, update_status

# Module-level logger obtained from the shared logging helper for this module's name.
logger = get_logger(__name__)

# 'Qwen3.5-4B-MLX' directory under settings.QWEN_LOCAL_MODEL_CACHE_DIR.
# run_training uses this path as the model when the job config names no model
# or asks for 'Qwen/Qwen3.5-4B'.
DEFAULT_LOCAL_MODEL_DIR = Path(settings.QWEN_LOCAL_MODEL_CACHE_DIR) / 'Qwen3.5-4B-MLX'

@log_execution(logger)
def run_training(job_id: str, job_dir: str, config: dict[str, Any], db_url: str) -> None:
    """Run one SFT or DPO training job and record its outcome in ``job_dir``.

    The job directory is created if needed, the dataset named by the config is
    loaded and checked for emptiness, and the matching trainer is then built
    and run. Progress and errors are written through ``log_message`` and
    ``update_status``.

    Args:
        job_id: Identifier of the job; bound to the logger and included in
            the job log.
        job_dir: Directory for the job's log/status files (created with
            parents if it does not exist).
        config: Job configuration. Keys read here:
            ``'dataset_id'`` (required), ``'task_type'`` (default ``'sft'``;
            any other value selects the DPO path), ``'model_name'``
            (missing/empty or ``'Qwen/Qwen3.5-4B'`` resolves to
            ``DEFAULT_LOCAL_MODEL_DIR``) and ``'hyperparameters'`` (passed to
            ``TrainingConfig.from_dict``; missing or null means ``{}``).
        db_url: Database URL forwarded to the dataset loaders.

    Raises:
        SystemExit: With exit code 1 when any ``Exception`` occurs inside the
            job; the error and traceback are logged and the status is set to
            ``'failed'`` first.

    NOTE(review): this function only sets the ``'running'`` and ``'failed'``
    statuses itself; presumably the trainer reports further progress and
    completion -- confirm against the trainer implementations.
    """
    # Use a separate name for the Path so the ``str`` parameter is not rebound.
    job_path = Path(job_dir)
    job_path.mkdir(parents=True, exist_ok=True)
    job_logger = logger.bind(job_id=job_id)
    try:
        job_logger.info("Training job started")
        log_message(job_path, f'Starting training job {job_id}')
        log_message(job_path, f'Memory usage: {get_memory_usage():.2f} GB')

        task_type = config.get('task_type', 'sft')

        # Fail with a readable message; a bare KeyError would be recorded in
        # the job status as just "'dataset_id'".
        dataset_id = config.get('dataset_id')
        if dataset_id is None:
            raise ValueError("Training config is missing required 'dataset_id'")

        model_name = config.get('model_name')
        if not model_name or model_name == 'Qwen/Qwen3.5-4B':
            # The hub identifier is redirected to the locally cached model.
            model_name = str(DEFAULT_LOCAL_MODEL_DIR)

        # ``or {}`` also covers an explicit null value for the key.
        training_config = TrainingConfig.from_dict(config.get('hyperparameters') or {})
        update_status(job_path, 'running', 0.05)

        if task_type == 'sft':
            log_message(job_path, 'Loading SFT dataset')
            dataset = load_sft_dataset(dataset_id, db_url)
            trainer_cls = SFTTrainer
        else:
            # Any task type other than 'sft' is handled as DPO.
            log_message(job_path, 'Loading DPO dataset')
            dataset = load_dpo_dataset(dataset_id, db_url)
            trainer_cls = DPOTrainer

        # Validate before calling len() (which raises TypeError on None) and
        # before constructing the trainer, so an unusable dataset fails early
        # with the intended message.
        if not dataset:
            raise ValueError(f'Dataset {dataset_id} is empty or unreadable')
        log_message(job_path, f'Loaded {len(dataset)} samples')

        trainer = trainer_cls(model_name, training_config, job_path)
        trainer.train(dataset)

        job_logger.info("Training job completed")
        log_message(job_path, 'Training completed successfully')
    except Exception as e:
        import traceback
        job_logger.error(f"Training job failed: {e}")
        log_message(job_path, f'Error: {str(e)}')
        log_message(job_path, f'Traceback: {traceback.format_exc()}')
        update_status(job_path, 'failed', error=str(e))
        # Non-zero exit code signals the failure to whatever launched the job.
        sys.exit(1)
