from collections.abc import Callable
from pathlib import Path

from mlx_lm import load
from common_logging import get_logger, log_execution

from .config import TrainingConfig
from .utils import get_memory_usage, log_message, update_status

# Module-level logger for this file. Call sites attach a per-job ``job_id``
# (and, inside the epoch loop, ``epoch``) via ``logger.bind(...)``.
logger = get_logger(__name__)


class DPOTrainer:
    """Drive a DPO (preference-pair) training job for a single model.

    Progress, memory usage and the final job status are reported through the
    project helpers ``log_message`` / ``update_status`` against ``job_dir``.

    NOTE(review): the visible ``train`` loop does not compute a loss, update
    any weights, or write model files. It only loads the model, iterates the
    dataset while reporting progress, and creates an empty ``final_model``
    directory. Treat this as a scaffold until a real DPO step is implemented.
    """

    def __init__(self, model_name: str, config: TrainingConfig, job_dir: Path):
        """Store job settings; the model itself is loaded lazily in ``train``.

        Args:
            model_name: Identifier handed to ``mlx_lm.load``.
            config: Training settings. ``train`` reads ``num_epochs`` and
                ``logging_steps`` from it.
            job_dir: Per-job directory. Its ``name`` is used as the ``job_id``
                in structured logs, and ``final_model`` is created inside it.
        """
        self.model_name = model_name
        self.config = config
        self.job_dir = job_dir
        # Filled in by train() from mlx_lm.load(); None until then.
        self.model = None
        self.tokenizer = None

    @log_execution(logger)
    def train(self, dataset: list[dict[str, str]], progress_callback: Callable | None=None):
        """Load the model and run the (placeholder) DPO epoch loop over ``dataset``.

        Job status goes through ``update_status``:
          - 0.1 after the start-up messages, before the model load.
          - ``0.1 + step / total_steps * 0.8`` every ``logging_steps`` steps.
          - 1.0 with status ``'completed'`` at the end.

        Args:
            dataset: Preference examples. Only its length and iteration order
                are used here; no keys of the individual dicts are read.
            progress_callback: Optional callable. It receives the same progress
                float passed to ``update_status`` at every logging step, and
                ``1.0`` once the loop has finished.

        Raises:
            Exception: Any error is logged, the job is marked ``'failed'``
                with the error text, and the exception is re-raised unchanged.
        """
        job_id = self.job_dir.name
        try:
            logger.bind(job_id=job_id).info("DPO training started")
            log_message(self.job_dir, f'Loading model: {self.model_name}')
            log_message(self.job_dir, f'Memory: {get_memory_usage():.2f}GB')
            update_status(self.job_dir, 'running', 0.1)
            self.model, self.tokenizer = load(self.model_name)
            log_message(self.job_dir, f'Training on {len(dataset)} preference pairs')

            total_steps = len(dataset) * self.config.num_epochs
            # Clamp the reporting interval. Otherwise logging_steps == 0 raises
            # ZeroDivisionError mid-job (after the model is loaded), and a
            # negative value reports progress at irregular steps.
            log_every = max(1, self.config.logging_steps)
            current_step = 0
            for epoch in range(self.config.num_epochs):
                logger.bind(job_id=job_id, epoch=epoch + 1).info("DPO epoch started")
                log_message(self.job_dir, f'Epoch {epoch + 1}/{self.config.num_epochs}')
                # NOTE(review): the sample is never used -- no forward/backward
                # pass or optimizer step happens here.
                for _sample in dataset:
                    current_step += 1
                    if current_step % log_every != 0:
                        continue
                    # The loop maps onto the 0.1..0.9 band; 1.0 is reserved for
                    # completion. total_steps > 0 whenever this line is reached.
                    progress = 0.1 + current_step / total_steps * 0.8
                    log_message(self.job_dir, f'Step {current_step}/{total_steps}')
                    log_message(self.job_dir, f'Memory: {get_memory_usage():.2f}GB')
                    update_status(self.job_dir, 'running', progress)
                    if progress_callback is not None:
                        progress_callback(progress)

            output_dir = self.job_dir / 'final_model'
            # NOTE(review): only the directory is created; nothing is written
            # into it, despite the "Model saved" message below.
            output_dir.mkdir(exist_ok=True)
            # Let callback-driven progress displays reach completion. This call
            # happens before the status is marked 'completed', so a failing
            # callback cannot flip a completed job to 'failed'.
            if progress_callback is not None:
                progress_callback(1.0)
            logger.bind(job_id=job_id).info("DPO training completed")
            log_message(self.job_dir, f'Training completed. Model saved to {output_dir}')
            update_status(self.job_dir, 'completed', 1.0)
        except Exception as e:
            # Top-level job boundary: record the failure, then propagate.
            logger.bind(job_id=job_id).error(f"DPO training failed: {e}")
            log_message(self.job_dir, f'Training failed: {str(e)}')
            update_status(self.job_dir, 'failed', error=str(e))
            raise
