from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse

from common_logging import get_logger

logger = get_logger(__name__)
GUARDED_PREFIXES = ("/api/v1/chat", "/api/v1/completions")


class ServiceGuardMiddleware(BaseHTTPMiddleware):
    """Block inference requests while the LLM service cannot serve them.

    Requests whose path starts with one of ``GUARDED_PREFIXES`` receive a
    503 JSON response when the service status manager reports TRAINING or
    SWITCHING mode. All other requests — and any request for which the
    guard itself fails — pass through to the application (fail-open).
    """

    async def dispatch(self, request: Request, call_next):
        """Return a 503 ``JSONResponse`` for guarded paths while the service
        is unavailable; otherwise forward the request to ``call_next``.
        """
        path = request.url.path
        # str.startswith accepts a tuple of prefixes — one call instead of
        # a Python-level any() loop over the same prefixes.
        if path.startswith(GUARDED_PREFIXES):
            try:
                # Imported lazily: importing at module load would create a
                # circular import with the service layer — TODO confirm.
                from app.services.llm.service_status_manager import ServiceMode, get_status_manager

                manager = get_status_manager()
                mode = manager.get_current_mode()
                if mode == ServiceMode.TRAINING:
                    # Lazy %-args: formatting is skipped if INFO is disabled.
                    logger.info("ServiceGuard blocked %s: training mode active", path)
                    return JSONResponse(
                        status_code=503,
                        content={
                            "detail": "LLM service is currently in training mode. Inference is unavailable.",
                            "mode": "training",
                            "training_info": manager.get_training_info(),
                        },
                    )
                if mode == ServiceMode.SWITCHING:
                    logger.info("ServiceGuard blocked %s: switching mode active", path)
                    return JSONResponse(
                        status_code=503,
                        content={
                            "detail": "LLM service is switching modes. Please retry shortly.",
                            "mode": "switching",
                        },
                    )
            except Exception:
                # Fail open: a broken guard must never take inference down.
                # logger.exception keeps the traceback the old f-string lost.
                logger.exception("ServiceGuardMiddleware error (failing open)")
        return await call_next(request)
