import importlib
from typing import Optional

from app.core.config import settings
from common_logging import get_logger

from .base import TrainingPlatform

# Module-level logger, named after this module via the shared project helper.
logger = get_logger(__name__)

# Registry of selectable training platforms.
#
# Each value has the form "<module path>.<ClassName>". get_training_platform()
# splits off the class name at the last dot and imports the module path
# relative to ``app.services.training_platform`` only when that platform is
# requested. Key order is preserved and shows up in the "Available: ..."
# error message.
#
# NOTE(review): the first three entries resolve under a ``platforms``
# subpackage while the rest resolve at the package top level. Confirm this
# layout is intentional.
PLATFORMS = {
    'auto': 'platforms.auto.AutoPlatform',
    'mlx': 'platforms.mlx.MLXPlatform',
    'llamafactory': 'platforms.llamafactory.LlamaFactoryPlatform',
    'qwen_local': 'qwen_local.QwenLocalPlatform',
    'volcengine': 'volcengine.VolcenginePlatform',
    'aliyun': 'aliyun.AliyunPlatform',
    'mock': 'mock.MockPlatform',
}

def get_training_platform(platform_name: Optional[str] = None) -> TrainingPlatform:
    """Instantiate the training platform registered under *platform_name*.

    The platform's module is imported only inside this call, so a backend's
    module (and whatever it imports) is loaded only when that backend is
    selected.

    Args:
        platform_name: A key of ``PLATFORMS``. When ``None`` or empty, the
            value of ``settings.DEFAULT_TRAINING_PLATFORM`` is used instead.

    Returns:
        A new instance of the registered platform class, constructed with
        no arguments.

    Raises:
        ValueError: If the resolved name is not a key of ``PLATFORMS``.
        ImportError: If the platform's module cannot be imported. The
            original exception is logged and re-raised unchanged.
        AttributeError: If the imported module lacks the registered class.
    """
    # Falsy names ('' as well as None) fall back to the configured default.
    name = platform_name or settings.DEFAULT_TRAINING_PLATFORM
    if name not in PLATFORMS:
        raise ValueError(
            f'Unknown platform: {name}. Available: {list(PLATFORMS.keys())}'
        )

    # Registry values are "<module path>.<ClassName>"; split at the last dot.
    module_path, class_name = PLATFORMS[name].rsplit('.', 1)

    # Log before importing so the message precedes any import failure.
    logger.info(f'Loading training platform: {name}')
    try:
        module = importlib.import_module(f'app.services.training_platform.{module_path}')
    except ImportError:
        # Record which platform failed, then propagate the original
        # exception (same type, same traceback) to the caller.
        logger.error(f'Failed to import training platform {name!r} from module {module_path!r}')
        raise
    return getattr(module, class_name)()
