import os

import httpx

from ..base import TrainingConfig
from ..qwen_local import QwenLocalPlatform

from common_logging import get_logger

logger = get_logger(__name__)


class LlamaFactoryPlatform(QwenLocalPlatform):
    """Local training platform that toggles the serving process between modes.

    Before a training job is created, the internal service at
    ``SWITCH_MODE_URL`` is asked to switch into ``'training'`` mode. Once a job
    completes and no pending training tasks remain, it is switched back to
    ``'inference'`` mode. Mode switching is best-effort: failures are logged as
    warnings and never raised.
    """

    # Internal endpoint that switches the local service between modes.
    SWITCH_MODE_URL = 'http://localhost:8000/internal/switch_mode'
    # Training-task listing endpoint queried for pending jobs.
    TRAINING_TASKS_URL = 'http://localhost:8001/api/v1/training-tasks'
    # Request timeouts in seconds.
    SWITCH_MODE_TIMEOUT = 10.0
    PENDING_JOBS_TIMEOUT = 5.0

    def create_training_job(self, config: TrainingConfig) -> str:
        """Switch to training mode, then create the job via the parent platform.

        Args:
            config: Training configuration. ``name`` and ``dataset_id`` are
                forwarded with the mode switch; ``model_name`` is logged.

        Returns:
            The job id returned by ``QwenLocalPlatform.create_training_job``.

        Raises:
            Whatever the parent ``create_training_job`` raises. Before
            re-raising, the service is switched back to inference mode if no
            training tasks are pending. This keeps a failed creation from
            leaving the service stuck in training mode.
        """
        # Switch first so the service is already in training mode when the
        # parent platform starts the job.
        self._notify_switch_mode('training', {'task_name': config.name, 'dataset_id': config.dataset_id})
        try:
            job_id = super().create_training_job(config)
        except Exception:
            logger.warning(f"LlamaFactory training job creation failed for {config.name}; reverting mode if idle")
            self._switch_to_inference_if_idle()
            raise
        logger.bind(job_id=job_id, model=config.model_name).info("LlamaFactory training job created")
        return job_id

    def _on_training_complete(self, job_id: str):
        """Log completion of *job_id* and return to inference mode if idle."""
        logger.bind(job_id=job_id).info("LlamaFactory training completed")
        self._switch_to_inference_if_idle()

    def _switch_to_inference_if_idle(self):
        """Switch the service to inference mode when no training tasks are pending.

        NOTE(review): if the pending-task lookup fails, ``_get_pending_jobs``
        returns ``[]``, so this switches to inference anyway.
        """
        if not self._get_pending_jobs():
            self._notify_switch_mode('inference')

    def _notify_switch_mode(self, mode: str, training_info: "dict | None" = None):
        """Ask the internal service to switch to *mode* (best-effort).

        Does nothing when ``INTERNAL_API_TOKEN`` is unset or empty. Request
        errors and non-2xx responses are logged as warnings and never raised.

        Args:
            mode: Target mode, e.g. ``'training'`` or ``'inference'``.
            training_info: Optional details sent under ``training_info`` in
                the JSON payload; omitted when falsy.
        """
        try:
            token = os.getenv('INTERNAL_API_TOKEN', '')
            if not token:
                # Mode switching is disabled when no internal token is configured.
                return
            payload = {'mode': mode}
            if training_info:
                payload['training_info'] = training_info
            response = httpx.post(
                self.SWITCH_MODE_URL,
                json=payload,
                headers={'X-Internal-Token': token},
                timeout=self.SWITCH_MODE_TIMEOUT,
            )
            # Without this check, a rejected request (e.g. 401/500) would pass silently.
            response.raise_for_status()
        except Exception as e:
            logger.warning(f"Failed to switch mode to {mode}: {e}")

    def _get_pending_jobs(self) -> list:
        """Return pending training tasks, or ``[]`` if they cannot be fetched.

        Queries ``TRAINING_TASKS_URL`` with ``status=pending``. A non-200
        response, request error or JSON decode error is logged and yields
        ``[]``.

        NOTE(review): the decoded JSON body is returned as-is. This assumes the
        endpoint returns a list of tasks; confirm against the API.
        """
        try:
            response = httpx.get(
                self.TRAINING_TASKS_URL,
                params={'status': 'pending'},
                timeout=self.PENDING_JOBS_TIMEOUT,
            )
            if response.status_code == 200:
                return response.json()
            logger.warning(f"Failed to fetch pending jobs: HTTP {response.status_code}")
        except Exception as e:
            logger.warning(f"Failed to fetch pending jobs: {e}")
        return []
