from dataclasses import dataclass, fields

from app.core.config import settings
from common_logging import get_logger, log_execution

logger = get_logger(__name__)


@dataclass
class TrainingConfig:
    """Hyperparameters for a training run.

    Every field has a default, so ``TrainingConfig()`` is valid on its own.
    ``max_seq_length`` defaults to ``settings.QWEN_LOCAL_DEFAULT_MAX_SEQ_LENGTH``;
    that value is read once, when this class body is executed (import time),
    not each time an instance is created.
    """

    learning_rate: float = 0.0001
    num_epochs: int = 3
    batch_size: int = 1
    gradient_accumulation_steps: int = 4
    # Default resolved from application settings at import time.
    max_seq_length: int = settings.QWEN_LOCAL_DEFAULT_MAX_SEQ_LENGTH
    num_layers: int = 16
    lora_rank: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    warmup_steps: int = 100
    save_steps: int = 500
    logging_steps: int = 10
    # NOTE(review): presumably the bit width for model quantization — confirm with consumers.
    quantization_bits: int = 4

    @classmethod
    @log_execution(logger)
    def from_dict(cls, data: dict) -> 'TrainingConfig':
        """Build a config from a mapping, keeping only recognised fields.

        Args:
            data: Mapping of field name to value. Fields absent from ``data``
                keep their defaults. Keys that do not name an ``__init__``
                field of ``cls`` are ignored and reported in a warning log.

        Returns:
            A new instance of ``cls`` built from the recognised keys.
        """
        # dataclasses.fields() walks the MRO, so fields declared on a parent or
        # a subclass are both accepted. cls.__annotations__ only covers the
        # class's own body and also lists ClassVar names, which the generated
        # __init__ would reject with TypeError.
        init_names = {f.name for f in fields(cls) if f.init}

        kwargs = {k: v for k, v in data.items() if k in init_names}
        # Surface dropped keys so typos don't silently fall back to defaults.
        ignored = [k for k in data if k not in init_names]
        if ignored:
            logger.warning("Ignoring unknown training config keys: %s", ignored)

        config = cls(**kwargs)
        # NOTE: only key filtering happens here; field values are not range-checked.
        logger.info("Training config validated successfully")
        return config
