import json
import multiprocessing as mp
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any

import psutil
from common_logging import get_logger, log_execution, log_performance

from .base import JobStatus, TrainingConfig, TrainingPlatform, TrainingResult

logger = get_logger(__name__)


class QwenLocalPlatform(TrainingPlatform):
    """Training platform that runs Qwen fine-tuning jobs as local subprocesses.

    Each job lives on disk under ``base_dir/<model_dir>/jobs/<job_id>``.
    The spawned training process reports progress by writing a
    ``status.json`` (and ``training.log``) into its job directory, which
    :meth:`get_job_status` reads back.
    """

    # Directory name used when no model name (or an empty path) is given.
    DEFAULT_MODEL_NAME = 'Qwen3.5-4B-MLX'
    # Maps hub-style model identifiers to local cache directory names.
    MODEL_ALIASES = {'Qwen/Qwen3.5-4B': 'Qwen3.5-4B-MLX'}

    def __init__(self):
        # Imported here (not at module top) — presumably to avoid a circular
        # import with app.core.config; keep it function-scoped.
        from app.core.config import settings
        self.base_dir = Path(settings.QWEN_LOCAL_TRAINING_BASE_DIR)
        self.model_cache_dir = Path(settings.QWEN_LOCAL_MODEL_CACHE_DIR)
        self.max_concurrent = settings.QWEN_LOCAL_MAX_CONCURRENT_JOBS
        self.max_memory_gb = settings.QWEN_LOCAL_MAX_MEMORY_GB
        self.min_free_memory_gb = settings.QWEN_LOCAL_MIN_FREE_MEMORY_GB
        self.base_dir.mkdir(parents=True, exist_ok=True)
        # Training processes keyed by job_id; dead entries are reaped lazily
        # in _check_concurrent_limit so the dict cannot grow without bound.
        self.processes: dict[str, mp.Process] = {}

    def _normalize_model_dir_name(self, model_name: str | None) -> str:
        """Map a model name (None, alias, or path-like) to a directory name."""
        if not model_name:
            return self.DEFAULT_MODEL_NAME
        if model_name in self.MODEL_ALIASES:
            return self.MODEL_ALIASES[model_name]
        # For path-like names use only the final component; an empty final
        # component (e.g. "/" or "") falls back to the default.
        model_path = Path(model_name)
        return model_path.name or self.DEFAULT_MODEL_NAME

    def _resolve_model_path(self, model_name: str | None) -> Path:
        """Return ``model_name`` as a path if it exists locally, else its
        expected location inside the model cache directory."""
        model_dir_name = self._normalize_model_dir_name(model_name)
        candidate = Path(model_name) if model_name else None
        if candidate and candidate.exists():
            return candidate
        return self.model_cache_dir / model_dir_name

    def _get_job_dir(self, job_id: str, model_dir_name: str | None = None) -> Path:
        """Locate (or derive) the on-disk directory for ``job_id``.

        When the model directory is unknown, search every model subtree for
        an existing job dir; fall back to the default model's tree.
        """
        if model_dir_name:
            return self.base_dir / model_dir_name / 'jobs' / job_id
        for candidate in self.base_dir.glob(f'*/jobs/{job_id}'):
            return candidate
        return self.base_dir / self.DEFAULT_MODEL_NAME / 'jobs' / job_id

    @log_execution(logger)
    @log_performance(logger, threshold_ms=5000)
    def create_training_job(self, config: TrainingConfig) -> str:
        """Start a training subprocess for ``config`` and return its job id.

        Raises:
            RuntimeError: if memory headroom is insufficient or the
                concurrent-job limit is already reached.
        """
        if not self._check_memory():
            raise RuntimeError(f'Insufficient memory (limit: {self.max_memory_gb}GB)')
        if not self._check_concurrent_limit():
            raise RuntimeError(f'Max concurrent jobs reached ({self.max_concurrent})')
        # A second-resolution timestamp alone collides when two jobs are
        # created within the same second (both would share one job dir), so
        # append a short random suffix to guarantee uniqueness.
        job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
        model_dir_name = self._normalize_model_dir_name(config.model_name)
        job_dir = self._get_job_dir(job_id, model_dir_name)
        job_dir.mkdir(parents=True, exist_ok=True)
        logger.bind(job_id=job_id).info("Creating local training job")
        task_type = (config.hyperparameters or {}).get('task_type', 'sft')
        model_name = str(self._resolve_model_path(config.model_name))
        job_config = {
            'task_type': task_type,
            'dataset_id': int(config.dataset_id),
            'model_name': model_name,
            'model_dir_name': model_dir_name,
            'hyperparameters': config.hyperparameters or {},
        }
        # Function-scoped import, matching __init__ (circular-import guard).
        from app.core.config import settings
        db_url = settings.DATABASE_URL
        process = mp.Process(
            target=self._run_training_process,
            args=(job_id, str(job_dir), job_config, db_url),
        )
        process.start()
        self.processes[job_id] = process
        logger.bind(job_id=job_id).info("Training process started")
        return job_id

    def get_job_status(self, job_id: str) -> JobStatus:
        """Read the job's ``status.json`` (plus log file, if any) into a JobStatus."""
        job_dir = self._get_job_dir(job_id)
        status_file = job_dir / 'status.json'
        if not status_file.exists():
            if job_dir.exists():
                return JobStatus(status='running', progress=0.0, message='Training starting')
            return JobStatus(status='pending', progress=0.0, message='Job not found')
        try:
            data = json.loads(status_file.read_text())
        except (OSError, json.JSONDecodeError):
            # The trainer process writes this file concurrently; a mid-write
            # (truncated/invalid) read must not crash the status query.
            return JobStatus(status='running', progress=0.0, message='Status file unavailable')
        log_file = job_dir / 'training.log'
        logs = log_file.read_text() if log_file.exists() else None
        return JobStatus(
            status=data.get('status', 'running'),
            progress=data.get('progress', 0.0),
            message=data.get('error', 'Training in progress'),
            logs=logs,
        )

    def get_job_result(self, job_id: str) -> TrainingResult:
        """Return the artifact location for a finished job.

        Raises:
            ValueError: if the final model has not been written yet.
        """
        job_dir = self._get_job_dir(job_id)
        model_dir = job_dir / 'final_model'
        if not model_dir.exists():
            raise ValueError('Training not completed')
        return TrainingResult(model_id=job_id, metrics={}, artifacts={'model_path': str(model_dir)})

    def cancel_job(self, job_id: str) -> bool:
        """Terminate a tracked training process; True if one was tracked."""
        process = self.processes.get(job_id)
        if process is None:
            return False
        if process.is_alive():
            process.terminate()
            process.join(timeout=5)
            if process.is_alive():
                process.kill()
                # Join after SIGKILL so the child is reaped, not left a zombie.
                process.join(timeout=5)
        del self.processes[job_id]
        logger.bind(job_id=job_id).info("Training job cancelled")
        return True

    def _check_memory(self) -> bool:
        """True if system-wide memory usage leaves headroom for a new job.

        NOTE(review): ``memory.used`` is machine-wide, not per-job — so
        other processes' memory counts against ``max_memory_gb``; confirm
        this is intended.
        """
        memory = psutil.virtual_memory()
        used_gb = memory.used / 1024 ** 3
        available_gb = memory.available / 1024 ** 3
        return used_gb < self.max_memory_gb * 0.9 and available_gb >= self.min_free_memory_gb

    def _check_concurrent_limit(self) -> bool:
        """Reap finished processes, then check active count against the limit."""
        for jid, proc in list(self.processes.items()):
            if not proc.is_alive():
                # Finished processes were previously never joined or removed,
                # leaking zombies and growing self.processes forever.
                proc.join(timeout=0)
                del self.processes[jid]
        return len(self.processes) < self.max_concurrent

    @staticmethod
    def _run_training_process(job_id: str, job_dir: str, config: dict[str, Any], db_url: str):
        """Subprocess entry point: delegate to the trainer module.

        Static and self-contained (import inside) so it pickles cleanly
        under multiprocessing's spawn start method.
        """
        from .qwen.trainer import run_training
        run_training(job_id, job_dir, config, db_url)
