import json
import time
from collections.abc import Callable

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware

from app.core.tenant_context import get_tenant_context
from app.db.session import SessionLocal
from app.models.audit_log import AuditLog
from common_logging import get_logger, log_execution

# Module-level logger shared by everything in this file; obtained from the
# project's common_logging helper using this module's name.
logger = get_logger(__name__)


class AuditMiddleware(BaseHTTPMiddleware):
    """Write one ``AuditLog`` row for every HTTP request that is not excluded.

    For each audited request the middleware captures the tenant/user taken
    from ``get_tenant_context()``, the client address and user agent, a
    redacted copy of the JSON body (POST/PUT/PATCH only), the response status,
    the elapsed time and the request/response sizes found on
    ``request.state``.  Audit logging is best-effort: a failure while building
    or storing the row is logged and never changes the outcome of the request.
    """

    # Requests whose path starts with any of these prefixes bypass auditing.
    EXCLUDED_PATHS = ["/docs", "/redoc", "/openapi.json", "/health", "/metrics", "/favicon.ico"]
    # Case-insensitive substrings: a body key containing any of them has its
    # value replaced by a redaction marker before the body is stored.
    SENSITIVE_FIELDS = ["password", "api_key", "secret", "token", "authorization"]

    # Default action name derived from the HTTP method.
    _METHOD_ACTIONS = {
        "GET": "read",
        "POST": "create",
        "PUT": "update",
        "PATCH": "update",
        "DELETE": "delete",
    }

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Run the request through the rest of the stack and audit the outcome.

        Args:
            request: Incoming request.
            call_next: Callable that forwards the request downstream and
                returns the response.

        Returns:
            The downstream response, unmodified.

        Raises:
            Any exception raised by ``call_next`` is re-raised after the audit
            row has been attempted.
        """
        # str.startswith accepts a tuple of prefixes; an empty tuple never matches.
        if request.url.path.startswith(tuple(self.EXCLUDED_PATHS)):
            return await call_next(request)

        # Monotonic clock: the measured duration cannot be skewed by wall-clock
        # adjustments that happen while the request is in flight.
        start_time = time.perf_counter()

        tenant_context = get_tenant_context()
        tenant_id = tenant_context.get("tenant_id") if tenant_context else None
        user_id = tenant_context.get("user_id") if tenant_context else None
        ip_address = request.client.host if request.client else None
        user_agent = request.headers.get("user-agent", "")

        request_data = None
        if request.method in ("POST", "PUT", "PATCH"):
            # Deliberately broad: body capture is best-effort and must never
            # prevent the request from being served.
            try:
                body = await request.body()
                if body:
                    request_data = self._sanitize_data(json.loads(body.decode()))
            except Exception as e:
                logger.warning(f"Failed to parse request body: {e}")
                request_data = {"error": "Failed to parse body"}

        response = None
        error_message = None
        response_size = 0
        # Pessimistic default so the ``finally`` block always has a bound value,
        # including when ``call_next`` exits through a BaseException that the
        # ``except Exception`` clause below does not see (e.g. cancellation).
        response_status = 500
        try:
            response = await call_next(request)
            response_status = response.status_code
            response_size = getattr(request.state, "response_size", 0)
        except Exception as e:
            logger.error(f"Request failed: {e}")
            response_status = 500
            error_message = str(e)
            raise
        finally:
            duration_ms = int((time.perf_counter() - start_time) * 1000)
            # Everything needed for the audit row is computed inside the guard
            # so that an audit failure can never replace the in-flight
            # exception or the response being returned.
            # NOTE(review): _log_to_database is synchronous and runs on the
            # event loop here — consider offloading it if it shows up in latency.
            try:
                request_size = getattr(request.state, "request_size", 0)
                action, resource_type = self._parse_action_and_resource(
                    request.method, request.url.path
                )
                self._log_to_database(
                    tenant_id=tenant_id,
                    user_id=user_id,
                    action=action,
                    resource_type=resource_type,
                    method=request.method,
                    path=request.url.path,
                    ip_address=ip_address,
                    user_agent=user_agent,
                    request_data=request_data,
                    response_status=response_status,
                    error_message=error_message,
                    duration_ms=duration_ms,
                    request_size_bytes=request_size,
                    response_size_bytes=response_size,
                )
            except Exception as e:
                logger.error(f"Failed to write audit log: {e}")
        return response

    def _is_sensitive_key(self, key) -> bool:
        """Return True when *key* contains any ``SENSITIVE_FIELDS`` substring."""
        lowered = str(key).lower()
        return any(sensitive in lowered for sensitive in self.SENSITIVE_FIELDS)

    def _sanitize_data(self, data: dict) -> dict:
        """Return a copy of *data* with the values of sensitive keys redacted.

        Dicts and lists are walked recursively at any depth, so a body whose
        top level is a list, or that nests lists inside lists, is redacted as
        well.  Any other value is returned unchanged.  The input is not
        mutated.

        Args:
            data: Decoded JSON value (normally a dict).

        Returns:
            A value of the same shape with ``"***REDACTED***"`` in place of
            every value whose key matches ``SENSITIVE_FIELDS``.
        """
        if isinstance(data, dict):
            return {
                key: "***REDACTED***" if self._is_sensitive_key(key) else self._sanitize_data(value)
                for key, value in data.items()
            }
        if isinstance(data, list):
            return [self._sanitize_data(item) for item in data]
        return data

    def _parse_action_and_resource(self, method: str, path: str) -> tuple:
        """Derive the ``(action, resource_type)`` pair for an audit row.

        The action defaults to a name based on the HTTP method ("unknown" for
        unmapped methods) and is overridden when the path contains one of the
        special fragments checked below, in that priority order.  The resource
        type is the third path segment, or "unknown" when the path has fewer
        than three segments.

        Args:
            method: HTTP method of the request.
            path: URL path of the request.

        Returns:
            A 2-tuple ``(action, resource_type)`` of strings.
        """
        action = self._METHOD_ACTIONS.get(method, "unknown")
        if "/login" in path:
            action = "login"
        elif "/logout" in path:
            action = "logout"
        elif "/chat" in path:
            action = "execute"
        elif "/toggle" in path or "/status" in path:
            action = "update"

        # presumably paths look like /<prefix>/<version>/<resource>/... — verify
        # against the router configuration.
        path_parts = path.strip("/").split("/")
        resource_type = path_parts[2] if len(path_parts) >= 3 else "unknown"
        return (action, resource_type)

    @log_execution
    def _log_to_database(
        self,
        tenant_id: int,
        user_id: int,
        action: str,
        resource_type: str,
        method: str,
        path: str,
        ip_address: str,
        user_agent: str,
        request_data: dict,
        response_status: int,
        error_message: str,
        duration_ms: int,
        request_size_bytes: int = 0,
        response_size_bytes: int = 0,
    ):
        """Persist one ``AuditLog`` row using a dedicated database session.

        ``dispatch`` passes ``None`` for ``tenant_id``/``user_id`` when no
        tenant context is available, for ``ip_address`` when the client is
        unknown, for ``request_data`` when no body was captured, and for
        ``error_message`` when the request succeeded.

        A failure while adding or committing the row is logged and rolled
        back rather than raised; the session is always closed.

        Args:
            tenant_id: Tenant the request belongs to, if known.
            user_id: User that issued the request, if known.
            action: Action name from ``_parse_action_and_resource``.
            resource_type: Resource name from ``_parse_action_and_resource``.
            method: HTTP method.
            path: URL path.
            ip_address: Client host, if known.
            user_agent: Value of the ``user-agent`` header ("" when absent).
            request_data: Redacted request body, if captured.
            response_status: HTTP status code recorded for the request.
            error_message: Text of the exception raised downstream, if any.
            duration_ms: Elapsed handling time in milliseconds.
            request_size_bytes: Request size in bytes.
            response_size_bytes: Response size in bytes.
        """
        db = SessionLocal()
        try:
            # Stored alongside the individual sizes as their sum.
            bandwidth_bytes = request_size_bytes + response_size_bytes
            audit_log = AuditLog(
                tenant_id=tenant_id,
                user_id=user_id,
                action=action,
                resource_type=resource_type,
                method=method,
                path=path,
                ip_address=ip_address,
                user_agent=user_agent,
                request_data=request_data,
                response_status=response_status,
                error_message=error_message,
                duration_ms=duration_ms,
                request_size_bytes=request_size_bytes,
                response_size_bytes=response_size_bytes,
                bandwidth_bytes=bandwidth_bytes,
            )
            db.add(audit_log)
            db.commit()
            logger.info(f"Audit log created: action={action}, resource={resource_type}, status={response_status}", user_id=user_id, tenant_id=tenant_id)
        except Exception as e:
            logger.error(f"Failed to write audit log to database: {e}")
            db.rollback()
        finally:
            db.close()
