from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any

from common_logging import get_logger

logger = get_logger(__name__)


@dataclass
class TrainingConfig:
    """Parameters for a job submitted via ``TrainingPlatform.create_training_job``.

    Values are stored exactly as given; this container performs no validation.
    """

    # Name of the training job.
    name: str
    # Free-text description of the job.
    description: str
    # Identifier of the dataset to train on (format is backend-specific — confirm).
    dataset_id: str
    # Model to train; bound as the ``model`` log field when a job starts.
    model_name: str
    # Hyperparameter mapping; its schema is backend-specific (confirm with implementations).
    hyperparameters: dict[str, Any]

@dataclass
class JobStatus:
    status: str
    progress: float
    message: str
    logs: str | None = None

@dataclass
class TrainingResult:
    """Result of a training job, as returned by ``TrainingPlatform.get_job_result``."""

    # Identifier of the produced model.
    model_id: str
    # Metric name -> value; value types are unconstrained.
    metrics: dict[str, Any]
    # Artifact name -> string reference (presumably a path or URI — confirm).
    artifacts: dict[str, str]

class TrainingPlatform(ABC):
    """Abstract interface for a backend that manages training jobs.

    Concrete subclasses must implement the four abstract job-management
    methods before they can be instantiated. The ``_log_job_*`` helpers are
    shared conveniences for emitting uniform lifecycle log records.
    """

    @abstractmethod
    def create_training_job(self, config: TrainingConfig) -> str:
        """Create a training job described by *config*; return its id as a string."""

    @abstractmethod
    def get_job_status(self, job_id: str) -> JobStatus:
        """Return the current status of the job identified by *job_id*."""

    @abstractmethod
    def get_job_result(self, job_id: str) -> TrainingResult:
        """Return the result of the job identified by *job_id*."""

    @abstractmethod
    def cancel_job(self, job_id: str) -> bool:
        """Request cancellation of the job identified by *job_id*.

        Presumably returns True when cancellation succeeded — confirm against
        concrete implementations.
        """

    def _log_job_started(self, job_id: str, config: TrainingConfig):
        """Emit an info record for a started job, bound with its id and model name."""
        scoped = logger.bind(job_id=job_id, model=config.model_name)
        scoped.info("Training job started")

    def _log_job_completed(self, job_id: str):
        """Emit an info record for a completed job, bound with its id."""
        scoped = logger.bind(job_id=job_id)
        scoped.info("Training job completed")

    def _log_job_failed(self, job_id: str, error: str):
        """Emit an error record for a failed job, embedding *error* in the message."""
        scoped = logger.bind(job_id=job_id)
        scoped.error(f"Training job failed: {error}")
