import json
from collections.abc import Callable
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten
from mlx_lm import load, lora
from common_logging import get_logger, log_execution

from .config import TrainingConfig
from .utils import get_memory_usage, log_message, update_status

logger = get_logger(__name__)


class SFTTrainer:
    """Supervised fine-tuning (SFT) of a causal language model with LoRA adapters via MLX / mlx_lm."""

    def __init__(self, model_name: str, config: TrainingConfig, job_dir: Path):
        self.model_name = model_name
        self.config = config
        self.job_dir = job_dir
        self.model = None
        self.tokenizer = None

    @log_execution(logger)
    def train(self, dataset: list[dict[str, str]], progress_callback: Callable[[float], None] | None = None):
        """Fine-tune on ``dataset`` and save LoRA adapter weights under the job directory.

        Each sample should be a dict with ``prompt`` and ``completion`` keys;
        ``progress_callback``, if provided, receives the overall progress fraction.
        """
        try:
            logger.bind(job_id=self.job_dir.name).info("SFT training started")
            log_message(self.job_dir, f'Loading model: {self.model_name}')
            log_message(self.job_dir, f'Memory: {get_memory_usage():.2f}GB')
            update_status(self.job_dir, 'running', 0.1)
            self.model, self.tokenizer = load(self.model_name)
            log_message(self.job_dir, f'Applying LoRA (rank={self.config.lora_rank}, alpha={self.config.lora_alpha})')
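            # Freeze the base weights so only the injected LoRA matrices are trained;
            # scale = alpha / rank is the standard LoRA scaling convention.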
            self.model.freeze()
            lora.linear_to_lora_layers(
                self.model,
                num_layers=self.config.num_layers,
                config={
                    'rank': self.config.lora_rank,
                    'alpha': self.config.lora_alpha,
                    'scale': self.config.lora_alpha / self.config.lora_rank,
                    'dropout': self.config.lora_dropout,
                },
            )
            trainable_params = sum(v.size for _, v in tree_flatten(self.model.trainable_parameters()))
            log_message(self.job_dir, f'Trainable parameters: {trainable_params:,}')
            log_message(self.job_dir, f'Model loaded. Training on {len(dataset)} samples')
            optimizer = optim.Adam(learning_rate=self.config.learning_rate)
            total_steps = len(dataset) * self.config.num_epochs
            current_step = 0
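
            # Loss for the per-sample updates below: a minimal sketch assuming
            # batch size 1 and a next-token objective over the full
            # prompt+completion text (no prompt masking, no gradient accumulation).
            # mlx_lm causal LMs return logits of shape (batch, seq_len, vocab_size).
            def loss_fn(model, inputs, targets):
                logits = model(inputs)
                return nn.losses.cross_entropy(logits, targets, reduction='mean')

            loss_and_grad_fn = nn.value_and_grad(self.model, loss_fn)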
            for epoch in range(self.config.num_epochs):
                log_message(self.job_dir, f'Epoch {epoch + 1}/{self.config.num_epochs}')
                for sample in dataset:
                    current_step += 1
                    if not isinstance(sample, dict):
                        log_message(self.job_dir, f'Error: Sample is {type(sample)}, skipping')
                        continue
                    prompt = sample.get('prompt', '')
                    completion = sample.get('completion', '')
                    text = f'{prompt}\n{completion}'
                    tokens = self.tokenizer.encode(text)
                    if len(tokens) > self.config.max_seq_length:
                        tokens = tokens[:self.config.max_seq_length]
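                    # Skip samples too short to form an (input, target) pair.
                    if len(tokens) < 2:
                        log_message(self.job_dir, 'Skipping sample with fewer than 2 tokens')
                        continue
                    # Forward/backward pass and optimizer update for this sample.
                    inputs = mx.array(tokens[:-1])[None]
                    targets = mx.array(tokens[1:])[None]
                    loss, grads = loss_and_grad_fn(self.model, inputs, targets)
                    optimizer.update(self.model, grads)
                    mx.eval(self.model.parameters(), optimizer.state, loss)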
                    progress = 0.1 + current_step / total_steps * 0.8
                    if current_step % self.config.logging_steps == 0:
                        log_message(self.job_dir, f'Step {current_step}/{total_steps}')
                        log_message(self.job_dir, f'Memory: {get_memory_usage():.2f}GB')
                        update_status(self.job_dir, 'running', progress)
                        if progress_callback:
                            progress_callback(progress)
            output_dir = self.job_dir / 'final_model'
            output_dir.mkdir(exist_ok=True)
            all_params = self.model.trainable_parameters()
            log_message(self.job_dir, f'Total trainable param groups: {len(all_params)}')
            flat_params = dict(tree_flatten(all_params))
            log_message(self.job_dir, f'Flattened to {len(flat_params)} arrays')
            param_names = list(flat_params.keys())[:10]
            log_message(self.job_dir, f'Sample param names: {param_names}')
            if flat_params:
                mx.save_safetensors(str(output_dir / 'adapters.safetensors'), flat_params)
                log_message(self.job_dir, f'Saved {len(flat_params)} adapter weights')
            else:
                log_message(self.job_dir, 'Warning: No trainable weights found to save')
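            # Adapter metadata written next to the weights; 'fine_tune_type',
            # 'num_layers', and 'lora_parameters' follow the layout mlx_lm reads
            # when re-applying adapters, and the top-level lora_* fields duplicate
            # the same values.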
            adapter_config = {
                'fine_tune_type': 'lora',
                'model': self.model_name,
                'base_model': self.model_name,
                # prefer the loaded model's own model_type; fall back to the fixed value
                'model_type': getattr(self.model, 'model_type', 'qwen3_5'),
                'num_layers': self.config.num_layers,
                'lora_parameters': {
                    'rank': self.config.lora_rank,
                    'dropout': self.config.lora_dropout,
                    'scale': self.config.lora_alpha / self.config.lora_rank,
                },
                'lora_rank': self.config.lora_rank,
                'lora_alpha': self.config.lora_alpha,
                'lora_dropout': self.config.lora_dropout,
            }
            with open(output_dir / 'adapter_config.json', 'w') as f:
                json.dump(adapter_config, f, indent=2)
            log_message(self.job_dir, f'Training completed. Model saved to {output_dir}')
            update_status(self.job_dir, 'completed', 1.0)
            logger.bind(job_id=self.job_dir.name).info("SFT training completed")
        except Exception as e:
            logger.bind(job_id=self.job_dir.name).error(f"SFT training failed: {e}")
            log_message(self.job_dir, f'Training failed: {str(e)}')
            update_status(self.job_dir, 'failed', error=str(e))
            raise
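

# Illustrative usage (hypothetical values; assumes TrainingConfig's defaults are
# acceptable and that the package is imported so the relative imports resolve):
#
#     trainer = SFTTrainer(
#         model_name='mlx-community/Example-Model-4bit',  # placeholder model id
#         config=TrainingConfig(),
#         job_dir=Path('jobs/example-job'),
#     )
#     trainer.train([{'prompt': 'Q: What is 2 + 2?', 'completion': 'A: 4'}])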
