import json
from enum import Enum

import redis

from common_logging import get_logger

logger = get_logger(__name__)

# Redis key holding the current ServiceMode, stored as the enum's string value.
REDIS_KEY_STATUS: str = "llm:service:status"
# Redis key holding JSON-encoded training metadata written by set_mode.
REDIS_KEY_TRAINING_INFO: str = "llm:training:info"


class ServiceMode(str, Enum):
    """Operating modes of the service; persisted in Redis by their string value.

    Mixing in ``str`` makes each member compare equal to its raw value
    (e.g. ``ServiceMode.TRAINING == "training"``), and ``ServiceMode("training")``
    maps a stored string back to its member.
    """

    # Also the fallback mode used when the stored status is missing/unreadable.
    INFERENCE = "inference"
    TRAINING = "training"
    MEDIA_PROCESSING = "media_processing"
    SWITCHING = "switching"


class ServiceStatusManager:
    """Read and write the service's operating mode and training metadata in Redis.

    The mode is stored under ``REDIS_KEY_STATUS`` as the enum's string value;
    optional training metadata is stored JSON-encoded under
    ``REDIS_KEY_TRAINING_INFO``. Reads fail open (errors are logged and a
    default is returned); writes log and re-raise Redis errors.
    """

    def __init__(self, redis_client: redis.Redis):
        """Wrap ``redis_client``; its replies may be ``bytes`` or ``str``."""
        self._redis = redis_client

    @staticmethod
    def _decode(value: bytes | str) -> str:
        """Return ``value`` as text, UTF-8 decoding it when it is ``bytes``."""
        return value.decode() if isinstance(value, bytes) else value

    def get_current_mode(self) -> ServiceMode:
        """Return the stored mode, or ``ServiceMode.INFERENCE`` as a fallback.

        The fallback is used when the key is absent, when Redis raises, or when
        the stored value is not valid UTF-8 / not a known ``ServiceMode``.
        """
        try:
            value = self._redis.get(REDIS_KEY_STATUS)
            if value is None:
                return ServiceMode.INFERENCE
            return ServiceMode(self._decode(value))
        except (redis.RedisError, ValueError) as e:
            # ValueError covers both an unknown enum value and
            # UnicodeDecodeError (a ValueError subclass).
            logger.warning(f"Failed to read service status from Redis: {e}")
            return ServiceMode.INFERENCE

    def set_mode(self, mode: ServiceMode, training_info: dict | None = None) -> None:
        """Persist ``mode`` and keep the training-info key in step with it.

        * ``TRAINING`` with truthy ``training_info``: store it as JSON.
        * ``INFERENCE``: delete any stored training info.
        * Otherwise: leave the training-info key untouched.

        All writes are queued on a single MULTI/EXEC transaction, so Redis
        applies them together; a failure before EXEC changes neither key.

        Args:
            mode: Mode to store.
            training_info: JSON-serializable metadata; only used for TRAINING.

        Raises:
            redis.RedisError: If the transaction fails (logged, then re-raised).
            TypeError: If ``training_info`` is not JSON-serializable; raised
                before anything is sent to Redis.
        """
        # NOTE(review): TRAINING without training_info keeps whatever info a
        # previous run stored - confirm that is intended.
        store_info = mode == ServiceMode.TRAINING and bool(training_info)
        # Serialize up front so a bad payload fails before any key changes.
        payload = json.dumps(training_info) if store_info else None
        try:
            with self._redis.pipeline(transaction=True) as pipe:
                pipe.set(REDIS_KEY_STATUS, mode.value)
                if store_info:
                    pipe.set(REDIS_KEY_TRAINING_INFO, payload)
                elif mode == ServiceMode.INFERENCE:
                    pipe.delete(REDIS_KEY_TRAINING_INFO)
                pipe.execute()
        except redis.RedisError as e:
            logger.error(f"Failed to set service mode in Redis: {e}")
            raise
        logger.info(f"Service mode set to: {mode.value}")

    def get_training_info(self) -> dict | None:
        """Return the stored training metadata, or ``None``.

        ``None`` is returned when the key is absent, when Redis raises, when the
        payload is not valid UTF-8/JSON, or when it decodes to something other
        than a JSON object.
        """
        try:
            value = self._redis.get(REDIS_KEY_TRAINING_INFO)
            if value is None:
                return None
            info = json.loads(self._decode(value))
        except (redis.RedisError, ValueError) as e:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
            logger.warning(f"Failed to read training info from Redis: {e}")
            return None
        if not isinstance(info, dict):
            # The key is meant to hold a dict; honour the declared return type.
            logger.warning(
                "Failed to read training info from Redis: "
                f"expected a JSON object, got {type(info).__name__}"
            )
            return None
        return info

    def is_inference_available(self) -> bool:
        """Return True when the stored mode is INFERENCE.

        Fails open: because ``get_current_mode`` falls back to INFERENCE, this
        also returns True when Redis is unreachable or the status is unreadable.
        """
        return self.get_current_mode() == ServiceMode.INFERENCE


# Lazily created process-wide manager; see get_status_manager().
_status_manager: ServiceStatusManager | None = None


def get_status_manager() -> ServiceStatusManager:
    """Return the shared ServiceStatusManager, creating it on first call.

    The Redis client is built from ``app.config.settings``: ``REDIS_URL`` takes
    precedence; otherwise host/port/db/password are used. Replies are left as
    bytes (``decode_responses=False``); the manager decodes them itself.

    NOTE(review): the lazy init is not lock-guarded; concurrent first calls may
    each build a client, with the last assignment winning.
    """
    global _status_manager
    if _status_manager is None:
        # Imported here rather than at module top - presumably to avoid an
        # import cycle or import-time config loading; confirm.
        from app.config import settings

        # Bound both the TCP connect and every subsequent socket read/write.
        # Without socket_timeout a stalled Redis blocks callers indefinitely,
        # so the RedisError fallbacks in ServiceStatusManager never trigger
        # (redis.TimeoutError is a RedisError subclass).
        client_options = {
            "decode_responses": False,
            "socket_connect_timeout": 5,
            "socket_timeout": 5,
        }
        if settings.REDIS_URL:
            client = redis.from_url(settings.REDIS_URL, **client_options)
        else:
            client = redis.Redis(
                host=settings.REDIS_HOST,
                port=settings.REDIS_PORT,
                db=settings.REDIS_DB,
                password=settings.REDIS_PASSWORD,
                **client_options,
            )
        _status_manager = ServiceStatusManager(client)
    return _status_manager
