from urllib.parse import quote, urlsplit, urlunsplit

from fastapi import Request
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

from app.config import settings
from common_logging import get_logger

logger = get_logger(__name__)


def custom_key_func(request: Request) -> str:
    """Return the rate-limit bucket key for *request*.

    Regular requests are keyed by the client's remote address as reported by
    ``get_remote_address``. ``OPTIONS`` requests (typically CORS preflights)
    are counted in a separate, per-client bucket. They therefore do not consume
    the client's regular quota.

    Args:
        request: The incoming request.

    Returns:
        The string slowapi uses to group requests into a rate-limit bucket.
    """
    client_key = get_remote_address(request)
    if request.method == "OPTIONS":
        # Namespaced *per client*: a single constant key would put every
        # client's OPTIONS traffic into one shared bucket, letting one client
        # exhaust the limit for all others.
        return f"options:{client_key}"
    return client_key


_MEMORY_STORAGE_URI = "memory://"
_DEFAULT_REDIS_PORT = 6379


def _redact_uri(uri) -> str:
    """Return ``str(uri)`` with any userinfo password replaced by ``***``.

    Used so storage URIs can be logged without leaking credentials. If the
    URI cannot be parsed, a placeholder is returned instead of the raw value.
    """
    text = str(uri)
    try:
        parts = urlsplit(text)
    except ValueError:
        return "<unparseable URI>"
    userinfo, at, hostinfo = parts.netloc.rpartition("@")
    if not at or ":" not in userinfo:
        # No userinfo, or a username without a password: nothing to hide.
        return text
    username = userinfo.split(":", 1)[0]
    return urlunsplit(parts._replace(netloc=f"{username}:***@{hostinfo}"))


def _resolve_storage_uri():
    """Choose the rate-limit storage URI from ``settings``.

    Precedence: ``REDIS_URL`` (returned as-is), then ``REDIS_HOST`` combined
    with ``REDIS_PORT`` (default 6379), ``REDIS_PASSWORD`` and ``REDIS_DB``
    (default 0). Otherwise the in-process ``memory://`` store is used.
    """
    redis_url = getattr(settings, "REDIS_URL", None)
    if redis_url:
        # Never log the raw URL: it may embed the Redis password.
        logger.info(f"Using Redis for rate limiting: {_redact_uri(redis_url)}")
        return redis_url

    redis_host = getattr(settings, "REDIS_HOST", None)
    if not redis_host:
        logger.warning(
            "Redis not configured, using in-memory rate limiting (not suitable for production)"
        )
        return _MEMORY_STORAGE_URI

    try:
        redis_port = getattr(settings, "REDIS_PORT", _DEFAULT_REDIS_PORT)
        redis_db = getattr(settings, "REDIS_DB", 0)
        redis_password = getattr(settings, "REDIS_PASSWORD", None)
        # Percent-encode the password: characters such as "@", ":" or "/"
        # would otherwise corrupt the authority section of the URI.
        auth = f":{quote(str(redis_password), safe='')}@" if redis_password else ""
        storage_uri = f"redis://{auth}{redis_host}:{redis_port}/{redis_db}"
    except Exception as e:
        logger.warning(
            f"Failed to configure Redis for rate limiting, falling back to memory: {e}"
        )
        return _MEMORY_STORAGE_URI

    logger.info(f"Using Redis for rate limiting: {redis_host}:{redis_port}")
    return storage_uri


def get_limiter() -> Limiter:
    """Build the application's slowapi ``Limiter``.

    The storage backend is chosen by ``_resolve_storage_uri``. If creating the
    ``Limiter`` against a Redis URI raises, a warning is logged (with the URI
    redacted) and an in-memory limiter is returned instead.

    Returns:
        A ``Limiter`` keyed by ``custom_key_func``. Its default limits are
        ``settings.RATE_LIMIT_PER_MINUTE``/minute and
        ``settings.RATE_LIMIT_PER_HOUR``/hour. It is enabled according to
        ``settings.RATE_LIMIT_ENABLED``.
    """
    default_limits = [
        f"{settings.RATE_LIMIT_PER_MINUTE}/minute",
        f"{settings.RATE_LIMIT_PER_HOUR}/hour",
    ]

    def build(storage_uri) -> Limiter:
        return Limiter(
            key_func=custom_key_func,
            default_limits=default_limits,
            enabled=settings.RATE_LIMIT_ENABLED,
            storage_uri=storage_uri,
        )

    storage_uri = _resolve_storage_uri()
    if storage_uri != _MEMORY_STORAGE_URI:
        try:
            return build(storage_uri)
        except Exception as e:
            # NOTE(review): assumes Limiter() builds its storage eagerly (e.g.
            # raises when the redis client package is missing) — confirm
            # against slowapi/limits. The error text may echo the URI, so
            # redact it before logging.
            message = str(e).replace(str(storage_uri), _redact_uri(storage_uri))
            logger.warning(
                f"Failed to configure Redis for rate limiting, falling back to memory: {message}"
            )
    return build(_MEMORY_STORAGE_URI)


# Shared module-level limiter, built once when this module is first imported
# (so the Redis-vs-memory storage choice is made at import time).
# setup_rate_limiting() attaches this same instance to the application.
limiter = get_limiter()


def setup_rate_limiting(app):
    """Wire the shared ``limiter`` into *app*.

    When ``settings.RATE_LIMIT_ENABLED`` is falsy this is a no-op. Otherwise it
    stores the module-level limiter on ``app.state.limiter``. It also registers
    slowapi's ``_rate_limit_exceeded_handler`` for ``RateLimitExceeded``.
    """
    if not settings.RATE_LIMIT_ENABLED:
        return
    app.state.limiter = limiter
    app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
