#!/usr/bin/env bash
# CLI training with LLaMA-Factory (SFT or DPO)
# Usage: bash run_training.sh [sft|dpo] [config_file]
#   e.g. bash run_training.sh sft
#   e.g. bash run_training.sh dpo /path/to/custom_dpo.yaml

set -euo pipefail

LLM_SERVICE_DIR="/lsinfo/ai/hellotax_ai/llm_service"
INSTALL_DIR="${LLM_SERVICE_DIR}/LLaMA-Factory"
LOG_DIR="${LLM_SERVICE_DIR}/logs"
MODEL_DIR="${LLM_SERVICE_DIR}/base_models"
OUTPUT_DIR="${LLM_SERVICE_DIR}/trained_models"
CONFIG_DIR="${LLM_SERVICE_DIR}/training_configs"

TRAIN_TYPE="${1:-sft}"
CONFIG_FILE="${2:-}"

mkdir -p "${LOG_DIR}" "${OUTPUT_DIR}" "${CONFIG_DIR}"

if [[ ! -d "${INSTALL_DIR}" ]]; then
  echo "ERROR: LLaMA-Factory not found. Run install_llamafactory.sh first."
  exit 1
fi

cd "${INSTALL_DIR}"

# Generate default config if none provided
if [[ -z "${CONFIG_FILE}" ]]; then
  if [[ "${TRAIN_TYPE}" == "sft" ]]; then
    CONFIG_FILE="${CONFIG_DIR}/qwen_sft.yaml"
    if [[ ! -f "${CONFIG_FILE}" ]]; then
      cat > "${CONFIG_FILE}" <<EOF
### Model
model_name_or_path: ${MODEL_DIR}/Qwen3.6-27B
trust_remote_code: true

### Method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 16
lora_target: all
gradient_checkpointing: true
optim: adamw_8bit

### Dataset
dataset: identity,alpaca_en_demo
template: qwen
cutoff_len: 4096
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 4

### Output
output_dir: ${OUTPUT_DIR}/qwen3.6-27B-sft-lora
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### Train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
ddp_find_unused_parameters: false
EOF
      echo "Default SFT config generated: ${CONFIG_FILE}"
    fi
  elif [[ "${TRAIN_TYPE}" == "dpo" ]]; then
    CONFIG_FILE="${CONFIG_DIR}/qwen_dpo.yaml"
    if [[ ! -f "${CONFIG_FILE}" ]]; then
      cat > "${CONFIG_FILE}" <<EOF
### Model
model_name_or_path: ${MODEL_DIR}/Qwen3.6-27B
trust_remote_code: true

### Method
stage: dpo
do_train: true
finetuning_type: lora
lora_rank: 16
lora_target: all
gradient_checkpointing: true
optim: adamw_8bit
pref_beta: 0.1
pref_loss: sigmoid

### Dataset
dataset: dpo_mix_en
template: qwen
cutoff_len: 4096
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 4

### Output
output_dir: ${OUTPUT_DIR}/qwen3.6-27B-dpo-lora
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### Train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 5.0e-5
num_train_epochs: 3
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
ddp_find_unused_parameters: false
EOF
      echo "Default DPO config generated: ${CONFIG_FILE}"
    fi
  else
    echo "ERROR: Unknown train type '${TRAIN_TYPE}'. Use: sft or dpo"
    exit 1
  fi
fi

TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
LOG_FILE="${LOG_DIR}/training_${TRAIN_TYPE}_${TIMESTAMP}.log"

echo "=========================================="
echo "Starting ${TRAIN_TYPE^^} training..."
echo "Config: ${CONFIG_FILE}"
echo "Log:    ${LOG_FILE}"
echo "=========================================="

torchrun --nproc_per_node=2 \
  "${INSTALL_DIR}/src/llamafactory/launcher.py" "${CONFIG_FILE}" 2>&1 | tee "${LOG_FILE}"

echo ""
echo "Training complete. Output: ${OUTPUT_DIR}"
