
from fastapi import APIRouter, Depends, Query
from pydantic import BaseModel, Field
from common_logging import get_logger

from app.services.monitoring.traffic_stats import TrafficStatsService, get_traffic_stats_service

# Module-level logger obtained from common_logging for this module's name; it is
# not referenced anywhere else in the visible part of this file.
logger = get_logger(__name__)
# Router collecting the traffic-monitoring endpoints declared below. No prefix or
# tags are configured here, so the mount path is decided by whoever includes it.
router = APIRouter()


class CurrentTrafficStats(BaseModel):
    # Schema for the instantaneous traffic snapshot; used as the ``current``
    # member of ``TrafficStatsResponse``. Every field is required (``...``).
    requests_per_sec: float = Field(..., description="Requests per second")
    bytes_per_sec: float = Field(..., description="Bytes per second")
    # Carried as a plain string; the exact format is not enforced here
    # (presumably ISO 8601 — TODO confirm against TrafficStatsService).
    timestamp: str = Field(..., description="Timestamp of measurement")


class WindowTrafficStats(BaseModel):
    # Schema for traffic aggregated over a time window; used as the
    # ``last_minute`` member of ``TrafficStatsResponse``. Every field is
    # required (``...``).
    total_requests: int = Field(..., description="Total requests in window")
    # NOTE(review): presumably request_bytes + response_bytes; this model does
    # not enforce that relationship — verify against the service.
    total_bytes: int = Field(..., description="Total bytes in window")
    request_bytes: int = Field(..., description="Total request bytes")
    response_bytes: int = Field(..., description="Total response bytes")
    # Averages are floats; units are presumably bytes per request — TODO confirm.
    avg_request_size: float = Field(..., description="Average request size")
    avg_response_size: float = Field(..., description="Average response size")


class TrafficStatsResponse(BaseModel):
    # Response body of GET /stats: the instantaneous snapshot plus the
    # aggregate for the window named ``last_minute``. Both members are required
    # because neither declares a default.
    current: CurrentTrafficStats
    last_minute: WindowTrafficStats


class BandwidthUsageResponse(BaseModel):
    # Response body of GET /bandwidth. Every field is required (``...``).
    # ``period`` is typed as a free-form string here; the endpoint restricts
    # the *query parameter* to 1min/5min/15min/1hour, but this model itself
    # does not validate the echoed value.
    period: str = Field(..., description="Time period")
    total_bytes: int = Field(..., description="Total bytes transferred")
    total_requests: int = Field(..., description="Total requests")
    request_bytes: int = Field(..., description="Total request bytes")
    response_bytes: int = Field(..., description="Total response bytes")
    # Plain string; format not enforced here (presumably ISO 8601 — TODO confirm).
    timestamp: str = Field(..., description="Timestamp")


class EndpointTrafficStats(BaseModel):
    # One row of per-endpoint traffic; used as the list item type for both
    # GET /endpoints and GET /top-consumers. Every field is required (``...``).
    endpoint: str = Field(..., description="Endpoint path")
    requests: int = Field(..., description="Total requests")
    # NOTE(review): presumably request_bytes + response_bytes; not enforced
    # by this model — verify against the service.
    total_bytes: int = Field(..., description="Total bytes")
    request_bytes: int = Field(..., description="Request bytes")
    response_bytes: int = Field(..., description="Response bytes")
    # Averages are floats; units are presumably bytes per request — TODO confirm.
    avg_request_size: float = Field(..., description="Average request size")
    avg_response_size: float = Field(..., description="Average response size")


class TenantTrafficStats(BaseModel):
    # One row of per-tenant traffic; list item type for GET /tenants.
    # Every field is required (``...``). Unlike EndpointTrafficStats, there is
    # no request/response byte split and there are no averages.
    tenant_id: int = Field(..., description="Tenant ID")
    requests: int = Field(..., description="Total requests")
    total_bytes: int = Field(..., description="Total bytes")


@router.get(
    "/stats",
    response_model=TrafficStatsResponse,
    summary="获取实时流量统计",
    description="获取当前实时流量统计信息，包括每秒请求数、带宽使用等",
)
async def get_network_stats(
    traffic_service: TrafficStatsService = Depends(get_traffic_stats_service),
):
    """Return the real-time traffic statistics reported by the traffic service.

    Args:
        traffic_service: Service instance injected through
            ``get_traffic_stats_service``.

    Returns:
        The value produced by ``traffic_service.get_current_stats()``, passed
        through unchanged; the route declares ``TrafficStatsResponse`` as its
        response model.
    """
    # The service method is invoked without ``await`` and its result is
    # returned as-is.
    return traffic_service.get_current_stats()


@router.get(
    "/bandwidth",
    response_model=BandwidthUsageResponse,
    summary="获取带宽使用情况",
    description="获取指定时间段内的带宽使用统计",
)
async def get_bandwidth_usage(
    # NOTE(review): newer FastAPI releases deprecate ``regex=`` in favour of
    # ``pattern=`` — confirm the pinned FastAPI version before migrating.
    period: str = Query(
        "1hour",
        description="时间段: 1min, 5min, 15min, 1hour",
        regex="^(1min|5min|15min|1hour)$",
    ),
    traffic_service: TrafficStatsService = Depends(get_traffic_stats_service),
):
    """Return bandwidth usage for the requested time period.

    Args:
        period: Query parameter constrained by the pattern to one of ``1min``,
            ``5min``, ``15min`` or ``1hour``; defaults to ``1hour``.
        traffic_service: Service instance injected through
            ``get_traffic_stats_service``.

    Returns:
        The value produced by ``traffic_service.get_bandwidth_usage(period)``,
        passed through unchanged; the route declares
        ``BandwidthUsageResponse`` as its response model.
    """
    return traffic_service.get_bandwidth_usage(period)


@router.get(
    "/endpoints",
    response_model=list[EndpointTrafficStats],
    summary="获取各端点流量统计",
    description="获取按端点分组的流量统计信息",
)
async def get_traffic_by_endpoint(
    limit: int = Query(
        10,
        ge=1,
        le=100,
        description="返回的端点数量限制",
    ),
    traffic_service: TrafficStatsService = Depends(get_traffic_stats_service),
):
    """Return traffic statistics grouped by endpoint.

    Args:
        limit: Query parameter bounded to 1..100 (default 10); forwarded to
            the service as its only argument.
        traffic_service: Service instance injected through
            ``get_traffic_stats_service``.

    Returns:
        The value produced by ``traffic_service.get_traffic_by_endpoint(limit)``,
        passed through unchanged; the route declares a list of
        ``EndpointTrafficStats`` as its response model.
    """
    return traffic_service.get_traffic_by_endpoint(limit)


@router.get(
    "/tenants",
    response_model=list[TenantTrafficStats],
    summary="获取各租户流量统计",
    description="获取按租户分组的流量统计信息",
)
async def get_traffic_by_tenant(
    limit: int = Query(
        10,
        ge=1,
        le=100,
        description="返回的租户数量限制",
    ),
    traffic_service: TrafficStatsService = Depends(get_traffic_stats_service),
):
    """Return traffic statistics grouped by tenant.

    Args:
        limit: Query parameter bounded to 1..100 (default 10); forwarded to
            the service as its only argument.
        traffic_service: Service instance injected through
            ``get_traffic_stats_service``.

    Returns:
        The value produced by ``traffic_service.get_traffic_by_tenant(limit)``,
        passed through unchanged; the route declares a list of
        ``TenantTrafficStats`` as its response model.
    """
    return traffic_service.get_traffic_by_tenant(limit)


@router.get(
    "/top-consumers",
    response_model=list[EndpointTrafficStats],
    summary="获取流量消耗排行",
    description="获取流量消耗最高的端点或用户",
)
async def get_top_consumers(
    limit: int = Query(
        10,
        ge=1,
        le=100,
        description="返回数量限制",
    ),
    # NOTE(review): newer FastAPI releases deprecate ``regex=`` in favour of
    # ``pattern=`` — confirm the pinned FastAPI version before migrating.
    by: str = Query(
        "bandwidth",
        description="排序依据: bandwidth 或 requests",
        regex="^(bandwidth|requests)$",
    ),
    traffic_service: TrafficStatsService = Depends(get_traffic_stats_service),
):
    """Return the top traffic consumers ranked by the chosen criterion.

    Args:
        limit: Query parameter bounded to 1..100 (default 10).
        by: Ranking criterion, constrained by the pattern to ``bandwidth`` or
            ``requests``; defaults to ``bandwidth``.
        traffic_service: Service instance injected through
            ``get_traffic_stats_service``.

    Returns:
        The value produced by ``traffic_service.get_top_consumers(limit, by)``,
        passed through unchanged; the route declares a list of
        ``EndpointTrafficStats`` as its response model.
    """
    # NOTE(review): the route description mentions "endpoints or users", yet
    # the response model only describes endpoint rows — confirm the service
    # always returns endpoint-shaped items.
    return traffic_service.get_top_consumers(limit, by)
