import sys
from pathlib import Path

# Smoke test for the DPO training path: run one epoch over a single
# preference example and print a marker line when it finishes without raising.

# Project root: the fifth ancestor of this file (``parents[4]`` is equivalent
# to five chained ``.parent`` calls). Computing it once lets both the import
# path and the data paths below be independent of the current working
# directory.
PROJECT_ROOT = Path(__file__).resolve().parents[4]

# Put the project root first on sys.path so the ``app.*`` imports below resolve
# when this file is executed directly as a script.
sys.path.insert(0, str(PROJECT_ROOT))
from app.services.training_platform.qwen.config import TrainingConfig  # noqa: E402
from app.services.training_platform.qwen.dpo_trainer import DPOTrainer  # noqa: E402
from common_logging import get_logger  # noqa: E402

logger = get_logger(__name__)

# Both paths are anchored to PROJECT_ROOT rather than the CWD; previously they
# only resolved correctly when the script was launched from the project root.
job_dir = PROJECT_ROOT / 'app/services/training_platform/models/jobs/smoke_dpo'
# NOTE(review): "Qwen3___5-9B" looks like a hub-cache escaping of "Qwen3.5-9B"
# ('.' -> '___') — confirm against how the base model was downloaded.
model_dir = PROJECT_ROOT / 'app/services/training_platform/models/base_models/Qwen/Qwen3___5-9B'

# Fail fast with a clear message instead of letting the trainer fail later on
# a missing local model path.
if not model_dir.is_dir():
    raise FileNotFoundError(f'Base model directory not found: {model_dir}')

job_dir.mkdir(parents=True, exist_ok=True)

# Minimal settings: one epoch, log every step so progress is visible at once.
config = TrainingConfig(num_epochs=1, logging_steps=1)
# A single preference triple (prompt / preferred answer / rejected answer).
# Prompt: "Please write a welcome sentence"; chosen: "Welcome to the local
# training platform."; rejected: "I don't know."
dataset = [{'prompt': '请写一句欢迎词', 'chosen': '欢迎使用本地训练平台。', 'rejected': '不知道。'}]
trainer = DPOTrainer(str(model_dir), config, job_dir)
trainer.train(dataset)

# Emit the success marker to both the log and stdout. The stdout copy is
# presumably what an external harness greps for — keep both.
logger.info('SMOKE_DPO_OK')
print('SMOKE_DPO_OK')
