from collections.abc import Callable

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from common_logging import get_logger

# Module logger obtained from the project's common_logging helper.
logger = get_logger(__name__)


class NetworkTrafficMiddleware(BaseHTTPMiddleware):
    """Record HTTP request/response sizes as Prometheus metrics.

    For every request whose path is not excluded, this middleware:

    * observes the inbound body size (POST/PUT/PATCH only) in
      ``http_request_size_bytes`` and the outbound body size in
      ``http_response_size_bytes``;
    * adds both to ``network_bandwidth_bytes_total`` under the
      ``direction`` label ``inbound`` / ``outbound``;
    * echoes the sizes in ``X-Request-Size``, ``X-Response-Size`` and
      ``X-Bandwidth-Total`` response headers.

    Endpoints are labelled with a normalized path (see ``_normalize_endpoint``).
    The middleware disables itself when ``prometheus_client`` is not installed,
    and measurement failures are logged rather than propagated to the client.
    """

    # Paths that bypass measurement. A request path matches an entry when it is
    # equal to it or lies beneath it ("/docs/oauth2-redirect"), so "/health"
    # does NOT exclude unrelated routes such as "/healthcare/...".
    EXCLUDED_PATHS = ["/docs", "/redoc", "/openapi.json", "/health", "/favicon.ico"]

    # Histogram bucket upper bounds, in bytes, shared by both size histograms.
    _SIZE_BUCKETS = (100, 1000, 10000, 100000, 1000000, 10000000)

    # (request_size_bytes, response_size_bytes, bandwidth_bytes_total), created
    # at most once per process. prometheus_client registers collectors in a
    # process-global registry and raises ValueError on a duplicate name, while
    # this middleware may be constructed more than once per process (several
    # apps or test fixtures; some Starlette versions rebuild the middleware
    # stack on every add_middleware call). All instances therefore share them.
    _shared_metrics = None

    def __init__(self, app, enabled: bool = True):
        """Wrap *app*; when *enabled* is false every request passes through unmeasured."""
        super().__init__(app)
        self.enabled = enabled
        self._initialize_metrics()

    def _initialize_metrics(self):
        """Bind the shared Prometheus collectors to this instance, creating them once.

        Does nothing when the middleware is disabled. Sets ``self.enabled`` to
        False (logging a warning) when prometheus_client cannot be imported.
        """
        if not self.enabled:
            return
        # Use the base class explicitly, not type(self): metric names are
        # process-global, so subclasses must share the same collectors.
        owner = NetworkTrafficMiddleware
        if owner._shared_metrics is None:
            try:
                from prometheus_client import Counter, Histogram
            except ImportError:
                logger.warning("prometheus_client not installed. Traffic metrics disabled.")
                self.enabled = False
                return
            owner._shared_metrics = (
                Histogram(
                    "http_request_size_bytes",
                    "HTTP request size in bytes",
                    ["method", "endpoint"],
                    buckets=owner._SIZE_BUCKETS,
                ),
                Histogram(
                    "http_response_size_bytes",
                    "HTTP response size in bytes",
                    ["method", "endpoint", "status"],
                    buckets=owner._SIZE_BUCKETS,
                ),
                Counter(
                    "network_bandwidth_bytes_total",
                    "Total network bandwidth in bytes",
                    ["direction", "endpoint"],
                ),
            )
            logger.info("Network traffic metrics initialized")
        (
            self.request_size_bytes,
            self.response_size_bytes,
            self.bandwidth_bytes_total,
        ) = owner._shared_metrics

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Measure one HTTP exchange, record metrics and annotate the response.

        Disabled middleware and excluded paths pass straight through.
        ``text/event-stream`` responses are returned unmeasured, because
        draining a long-lived event stream here would hold back delivery to
        the client. Any measurement error is logged and the (possibly
        unannotated) response is still returned.
        """
        path = request.url.path
        if not self.enabled or self._is_excluded(path):
            return await call_next(request)
        endpoint = self._normalize_endpoint(path)
        request_size = await self._measure_request(request, endpoint)

        response = await call_next(request)
        content_type = (response.headers.get("content-type") or "").lower()
        if content_type.startswith("text/event-stream"):
            return response

        try:
            response_size = await self._response_size(response)
            request.state.response_size = response_size
            self.response_size_bytes.labels(
                method=request.method, endpoint=endpoint, status=str(response.status_code)
            ).observe(response_size)
            self.bandwidth_bytes_total.labels(direction="outbound", endpoint=endpoint).inc(
                response_size
            )
            response.headers["X-Request-Size"] = str(request_size)
            response.headers["X-Response-Size"] = str(response_size)
            response.headers["X-Bandwidth-Total"] = str(request_size + response_size)
        except Exception as e:
            logger.error(f"Failed to measure response size: {e}")
        return response

    def _is_excluded(self, path: str) -> bool:
        """Return True if *path* equals an EXCLUDED_PATHS entry or lies beneath one."""
        return any(
            path == excluded or path.startswith(excluded.rstrip("/") + "/")
            for excluded in self.EXCLUDED_PATHS
        )

    async def _measure_request(self, request: Request, endpoint: str) -> int:
        """Record the inbound body size for POST/PUT/PATCH and return it.

        Returns 0 for other methods, or when the size could not be determined.
        The size is also stored on ``request.state.request_size``.
        """
        if request.method not in ("POST", "PUT", "PATCH"):
            return 0
        request_size = 0
        try:
            request_size = await self._request_body_size(request)
            request.state.request_size = request_size
            self.request_size_bytes.labels(method=request.method, endpoint=endpoint).observe(
                request_size
            )
            self.bandwidth_bytes_total.labels(direction="inbound", endpoint=endpoint).inc(
                request_size
            )
        except Exception as e:
            logger.warning(f"Failed to measure request size: {e}")
        return request_size

    @staticmethod
    async def _request_body_size(request: Request) -> int:
        """Return the request body size, preferring Content-Length over reading the body.

        Trusting the header avoids buffering whole uploads in memory just to
        measure them; the body is read only when the header is absent or
        malformed (e.g. chunked transfer encoding).
        """
        declared = request.headers.get("content-length")
        if declared is not None and declared.isascii() and declared.isdigit():
            return int(declared)
        return len(await request.body())

    async def _response_size(self, response: Response) -> int:
        """Return the response body size from Content-Length, buffering the body if absent.

        Raises ValueError if a Content-Length header is present but not an integer.
        """
        content_length = response.headers.get("content-length")
        if content_length:
            return int(content_length)
        return await self._buffer_response_body(response)

    @staticmethod
    async def _buffer_response_body(response: Response) -> int:
        """Drain ``response.body_iterator`` into memory, replay it, and return its length.

        The same response object is kept rather than rebuilt. Copying headers
        through a dict would collapse repeated headers such as multiple
        ``Set-Cookie`` lines. Content-Length is then populated under the same
        status rule Starlette's ``Response`` applies (not for 1xx/204/304).
        """
        chunks = []
        async for chunk in response.body_iterator:
            if isinstance(chunk, str):
                chunk = chunk.encode(response.charset)
            chunks.append(chunk)
        body = b"".join(chunks)

        async def replay():
            yield body

        response.body_iterator = replay()
        status = response.status_code
        if not (status < 200 or status in (204, 304)):
            response.headers["content-length"] = str(len(body))
        return len(body)

    def _normalize_endpoint(self, path: str) -> str:
        """Map a URL path to a metric label, collapsing likely identifiers.

        Purely numeric segments become ``{id}``. 36-character segments
        containing four dashes (UUID-shaped) become ``{uuid}``. Leading and
        trailing slashes are dropped, e.g. ``/users/42/`` -> ``/users/{id}``.

        NOTE(review): any other dynamic segment (slugs, hashes, paths probed by
        scanners) is kept verbatim, so label cardinality is not bounded here.
        """
        normalized = []
        for part in path.strip("/").split("/"):
            if part.isdigit():
                normalized.append("{id}")
            elif len(part) == 36 and part.count("-") == 4:
                normalized.append("{uuid}")
            else:
                normalized.append(part)
        return "/" + "/".join(normalized)


def setup_network_traffic_monitoring(app, enabled: bool = True):
    """Install NetworkTrafficMiddleware on *app* when *enabled* is truthy.

    When disabled, the app is left untouched. Either way, one info line
    reports whether monitoring is on or off.
    """
    if not enabled:
        logger.info("Network traffic monitoring disabled")
        return
    app.add_middleware(NetworkTrafficMiddleware, enabled=enabled)
    logger.info("Network traffic monitoring enabled")
