import time
from datetime import datetime
from typing import Any

from common_logging import get_logger

logger = get_logger(__name__)




class TrafficStatsService:

    def __init__(self, redis_client=None, key_prefix: str = "traffic:"):
        self.redis_client = redis_client
        self.key_prefix = key_prefix
        self.enabled = redis_client is not None
        if not self.enabled:
            logger.warning("Redis client not provided. Traffic stats will be disabled.")

    def record_request(
        self,
        endpoint: str,
        method: str,
        request_size: int,
        response_size: int,
        status_code: int,
        duration_ms: int,
        tenant_id: int | None = None,
        user_id: int | None = None,
    ):
        if not self.enabled:
            return
        try:
            current_time = int(time.time())
            total_bandwidth = request_size + response_size
            self._increment_counter("requests:total", 1)
            self._increment_counter("bandwidth:total", total_bandwidth)
            for window in ["1min", "5min", "15min", "1hour"]:
                window_key = f"{self.key_prefix}window:{window}:{current_time // self._window_seconds(window)}"
                self.redis_client.hincrby(window_key, "requests", 1)
                self.redis_client.hincrby(window_key, "bandwidth", total_bandwidth)
                self.redis_client.hincrby(window_key, "request_bytes", request_size)
                self.redis_client.hincrby(window_key, "response_bytes", response_size)
                self.redis_client.expire(window_key, self._window_seconds(window) * 2)
            endpoint_key = f"{self.key_prefix}endpoint:{endpoint}"
            self.redis_client.hincrby(endpoint_key, "requests", 1)
            self.redis_client.hincrby(endpoint_key, "bandwidth", total_bandwidth)
            self.redis_client.hincrby(endpoint_key, "request_bytes", request_size)
            self.redis_client.hincrby(endpoint_key, "response_bytes", response_size)
            self.redis_client.expire(endpoint_key, 86400)
            if tenant_id:
                tenant_key = f"{self.key_prefix}tenant:{tenant_id}"
                self.redis_client.hincrby(tenant_key, "requests", 1)
                self.redis_client.hincrby(tenant_key, "bandwidth", total_bandwidth)
                self.redis_client.expire(tenant_key, 86400)
            if user_id:
                user_key = f"{self.key_prefix}user:{user_id}"
                self.redis_client.hincrby(user_key, "requests", 1)
                self.redis_client.hincrby(user_key, "bandwidth", total_bandwidth)
                self.redis_client.expire(user_key, 86400)
            self.redis_client.zadd(
                f"{self.key_prefix}endpoints:by_bandwidth", {endpoint: total_bandwidth}, incr=True
            )
        except Exception as e:
            logger.error(f"Failed to record traffic stats: {e}")

    def get_current_stats(self) -> dict[str, Any]:
        if not self.enabled:
            return self._empty_stats()
        try:
            current_time = int(time.time())
            window_key = f"{self.key_prefix}window:1min:{current_time // 60}"
            window_data = self.redis_client.hgetall(window_key)
            requests = int(window_data.get(b"requests", 0))
            bandwidth = int(window_data.get(b"bandwidth", 0))
            request_bytes = int(window_data.get(b"request_bytes", 0))
            response_bytes = int(window_data.get(b"response_bytes", 0))
            requests_per_sec = requests / 60.0
            bytes_per_sec = bandwidth / 60.0
            return {
                "current": {
                    "requests_per_sec": round(requests_per_sec, 2),
                    "bytes_per_sec": round(bytes_per_sec, 2),
                    "timestamp": datetime.utcnow().isoformat(),
                },
                "last_minute": {
                    "total_requests": requests,
                    "total_bytes": bandwidth,
                    "request_bytes": request_bytes,
                    "response_bytes": response_bytes,
                    "avg_request_size": round(request_bytes / requests, 2) if requests > 0 else 0,
                    "avg_response_size": round(response_bytes / requests, 2) if requests > 0 else 0,
                },
            }
        except Exception as e:
            logger.error(f"Failed to get current stats: {e}")
            return self._empty_stats()

    def get_bandwidth_usage(self, period: str = "1hour") -> dict[str, Any]:
        if not self.enabled:
            return {"period": period, "total_bytes": 0, "total_requests": 0}
        try:
            current_time = int(time.time())
            window_seconds = self._window_seconds(period)
            window_id = current_time // window_seconds
            window_key = f"{self.key_prefix}window:{period}:{window_id}"
            window_data = self.redis_client.hgetall(window_key)
            return {
                "period": period,
                "total_bytes": int(window_data.get(b"bandwidth", 0)),
                "total_requests": int(window_data.get(b"requests", 0)),
                "request_bytes": int(window_data.get(b"request_bytes", 0)),
                "response_bytes": int(window_data.get(b"response_bytes", 0)),
                "timestamp": datetime.utcnow().isoformat(),
            }
        except Exception as e:
            logger.error(f"Failed to get bandwidth usage: {e}")
            return {"period": period, "total_bytes": 0, "total_requests": 0}

    def get_traffic_by_endpoint(self, limit: int = 10) -> list[dict[str, Any]]:
        if not self.enabled:
            return []
        try:
            top_endpoints = self.redis_client.zrevrange(
                f"{self.key_prefix}endpoints:by_bandwidth", 0, limit - 1, withscores=True
            )
            results = []
            for endpoint_bytes, _total_bandwidth in top_endpoints:
                endpoint = endpoint_bytes.decode("utf-8")
                endpoint_key = f"{self.key_prefix}endpoint:{endpoint}"
                endpoint_data = self.redis_client.hgetall(endpoint_key)
                requests = int(endpoint_data.get(b"requests", 0))
                bandwidth = int(endpoint_data.get(b"bandwidth", 0))
                request_bytes = int(endpoint_data.get(b"request_bytes", 0))
                response_bytes = int(endpoint_data.get(b"response_bytes", 0))
                results.append(
                    {
                        "endpoint": endpoint,
                        "requests": requests,
                        "total_bytes": bandwidth,
                        "request_bytes": request_bytes,
                        "response_bytes": response_bytes,
                        "avg_request_size": (
                            round(request_bytes / requests, 2) if requests > 0 else 0
                        ),
                        "avg_response_size": (
                            round(response_bytes / requests, 2) if requests > 0 else 0
                        ),
                    }
                )
            return results
        except Exception as e:
            logger.error(f"Failed to get traffic by endpoint: {e}")
            return []

    def get_traffic_by_tenant(self, limit: int = 10) -> list[dict[str, Any]]:
        if not self.enabled:
            return []
        try:
            tenant_keys = []
            cursor = 0
            while True:
                cursor, keys = self.redis_client.scan(
                    cursor, match=f"{self.key_prefix}tenant:*", count=100
                )
                tenant_keys.extend(keys)
                if cursor == 0:
                    break
            results = []
            for tenant_key in tenant_keys[:limit]:
                tenant_id = tenant_key.decode("utf-8").split(":")[-1]
                tenant_data = self.redis_client.hgetall(tenant_key)
                requests = int(tenant_data.get(b"requests", 0))
                bandwidth = int(tenant_data.get(b"bandwidth", 0))
                results.append(
                    {"tenant_id": int(tenant_id), "requests": requests, "total_bytes": bandwidth}
                )
            results.sort(key=lambda x: x["total_bytes"], reverse=True)
            return results[:limit]
        except Exception as e:
            logger.error(f"Failed to get traffic by tenant: {e}")
            return []

    def get_top_consumers(self, limit: int = 10, by: str = "bandwidth") -> list[dict[str, Any]]:
        if not self.enabled:
            return []
        try:
            return self.get_traffic_by_endpoint(limit)
        except Exception as e:
            logger.error(f"Failed to get top consumers: {e}")
            return []

    def _increment_counter(self, key: str, value: int):
        full_key = f"{self.key_prefix}{key}"
        self.redis_client.incrby(full_key, value)

    def _window_seconds(self, window: str) -> int:
        windows = {"1min": 60, "5min": 300, "15min": 900, "1hour": 3600}
        return windows.get(window, 60)

    def _empty_stats(self) -> dict[str, Any]:
        return {
            "current": {
                "requests_per_sec": 0,
                "bytes_per_sec": 0,
                "timestamp": datetime.utcnow().isoformat(),
            },
            "last_minute": {
                "total_requests": 0,
                "total_bytes": 0,
                "request_bytes": 0,
                "response_bytes": 0,
                "avg_request_size": 0,
                "avg_response_size": 0,
            },
        }


_traffic_stats_service: TrafficStatsService | None = None


def get_traffic_stats_service() -> TrafficStatsService:
    """Return the process-wide ``TrafficStatsService``, building it on first call.

    The Redis client is configured from ``app.config.settings``: ``REDIS_URL``
    when it is set, otherwise ``REDIS_HOST``/``REDIS_PORT``/``REDIS_DB``/
    ``REDIS_PASSWORD``. If importing or constructing anything fails, a disabled
    service (``redis_client=None``) is cached instead and returned from then on.
    """
    global _traffic_stats_service
    if _traffic_stats_service is not None:
        return _traffic_stats_service

    try:
        import redis

        from app.config import settings

        if settings.REDIS_URL:
            client = redis.from_url(settings.REDIS_URL, decode_responses=False)
        else:
            client = redis.Redis(
                host=settings.REDIS_HOST,
                port=settings.REDIS_PORT,
                db=settings.REDIS_DB,
                password=settings.REDIS_PASSWORD,
                decode_responses=False,
            )
        _traffic_stats_service = TrafficStatsService(client)
        logger.info("Traffic stats service initialized with Redis")
    except Exception as e:
        logger.warning(f"Failed to initialize Redis for traffic stats: {e}")
        _traffic_stats_service = TrafficStatsService(redis_client=None)
    return _traffic_stats_service
