import hashlib
import hmac
import html
import math
import re
import secrets
import time
from collections import deque
from urllib.parse import urlsplit

from fastapi import HTTPException, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

from common_logging import get_logger

# Module-level logger shared by the middlewares and validators in this file.
logger = get_logger(__name__)


class RateLimitMiddleware(BaseHTTPMiddleware):
    """Per-client-IP sliding-window rate limiter.

    Each client IP may make at most ``requests_per_minute`` requests in any
    rolling 60-second window. ``/health`` and ``/metrics`` are exempt and are
    not counted. Over-limit requests receive a 429 JSON response (same
    ``{"detail": ...}`` shape FastAPI uses for HTTPException) with a
    ``Retry-After`` header.

    State lives in this instance's memory.
    NOTE(review): limits are therefore per worker process, and behind a
    reverse proxy ``request.client.host`` is presumably the proxy's address —
    confirm deployment before relying on this for abuse protection.
    """

    WINDOW_SECONDS = 60.0
    EXEMPT_PATHS = frozenset({"/health", "/metrics"})

    def __init__(self, app, requests_per_minute: int = 60):
        super().__init__(app)
        self.requests_per_minute = requests_per_minute
        # client IP -> monotonic timestamps of accepted requests, oldest first.
        self.requests: dict[str, deque[float]] = {}
        self._last_sweep = time.monotonic()

    def _sweep_idle_clients(self, now: float) -> None:
        """Forget clients with no request inside the window (runs at most once per window).

        Without this, every IP ever seen would keep a dict entry forever.
        """
        if now - self._last_sweep < self.WINDOW_SECONDS:
            return
        self._last_sweep = now
        idle = [
            ip
            for ip, stamps in self.requests.items()
            if not stamps or now - stamps[-1] >= self.WINDOW_SECONDS
        ]
        for ip in idle:
            del self.requests[ip]

    async def dispatch(self, request: Request, call_next):
        if request.url.path in self.EXEMPT_PATHS:
            return await call_next(request)
        client_ip = request.client.host if request.client else "unknown"
        # Monotonic clock: immune to wall-clock jumps (NTP, DST, manual changes).
        now = time.monotonic()
        self._sweep_idle_clients(now)
        stamps = self.requests.setdefault(client_ip, deque())
        # Timestamps are appended in order, so expired ones are always at the left.
        while stamps and now - stamps[0] >= self.WINDOW_SECONDS:
            stamps.popleft()
        if len(stamps) >= self.requests_per_minute:
            logger.warning(f"Rate limit exceeded for IP: {client_ip}")
            remaining = (
                self.WINDOW_SECONDS - (now - stamps[0]) if stamps else self.WINDOW_SECONDS
            )
            # Return instead of raising: an HTTPException raised inside
            # BaseHTTPMiddleware bypasses FastAPI's exception handlers and
            # reaches the client as a 500 Internal Server Error.
            return JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={"detail": "Too many requests. Please try again later."},
                headers={"Retry-After": str(max(1, math.ceil(remaining)))},
            )
        stamps.append(now)
        return await call_next(request)


class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Attach a fixed set of security-related response headers to every response.

    Headers are written after the downstream app has produced its response,
    so they overwrite any value a route set for the same header name.
    """

    SECURITY_HEADERS: dict[str, str] = {
        "X-Content-Type-Options": "nosniff",
        "X-Frame-Options": "DENY",
        # "0" turns the legacy browser XSS auditor off. The old
        # "1; mode=block" value is deprecated, and the auditor itself could be
        # abused to introduce leaks; OWASP now recommends "0" and relying on CSP.
        "X-XSS-Protection": "0",
        "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
        # NOTE(review): 'self'-only CSP blocks third-party assets (e.g. a
        # CDN-hosted interactive /docs page) — confirm that is intended.
        "Content-Security-Policy": "default-src 'self'",
        "Referrer-Policy": "strict-origin-when-cross-origin",
        "Permissions-Policy": "geolocation=(), microphone=(), camera=()",
    }

    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        for name, value in self.SECURITY_HEADERS.items():
            response.headers[name] = value
        return response


def configure_cors(app):
    app.add_middleware(
        CORSMiddleware,
        allow_origins=[
            "http://localhost:8888",
            "http://localhost:8889",
            "http://localhost:3000",
            "http://127.0.0.1:8888",
            "http://127.0.0.1:8889",
            "http://127.0.0.1:3000",
        ],
        allow_credentials=True,
        allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
        allow_headers=[
            "Authorization",
            "Content-Type",
            "Accept",
            "Accept-Language",
            "X-CSRF-Token",
            "X-Request-ID",
        ],
    )


def configure_trusted_hosts(app, allowed_hosts: list):
    """Register TrustedHostMiddleware on *app*, configured with *allowed_hosts*.

    The middleware validates each request's Host header against this list
    (see the Starlette TrustedHostMiddleware docs for its wildcard rules).
    """
    middleware_options = {"allowed_hosts": allowed_hosts}
    app.add_middleware(TrustedHostMiddleware, **middleware_options)


def hash_sensitive_data(data: str, key: bytes | None = None) -> str:
    """Return a hex SHA-256 digest of *data* so the raw value need not be stored.

    Without *key* this is a plain, unsalted SHA-256 (the historical behavior).
    It is deterministic, which makes it useful for lookups. However, low-entropy
    inputs such as emails, phone numbers or short IDs can be recovered from it
    by brute force or precomputed tables.

    Pass a secret *key* to get HMAC-SHA256 instead. The digest then cannot be
    reproduced without the key, while staying deterministic for equal inputs.

    Neither mode is suitable for password storage; use a dedicated password
    hash (e.g. ``hashlib.scrypt``) for that.

    Args:
        data: Text to hash; encoded as UTF-8.
        key: Optional secret for keyed hashing.

    Returns:
        64-character lowercase hex digest.
    """
    payload = data.encode("utf-8")
    if key is None:
        hashed = hashlib.sha256(payload).hexdigest()
    else:
        hashed = hmac.new(key, payload, hashlib.sha256).hexdigest()
    logger.info("Sensitive data hashed")
    return hashed


def validate_input_length(value: str, max_length: int, field_name: str):
    """Ensure *value* is at most *max_length* characters long.

    Returns None when the value fits. Otherwise raises an HTTP 400
    HTTPException whose detail names *field_name* and the limit.
    """
    if len(value) <= max_length:
        return
    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail=f"{field_name} exceeds maximum length of {max_length}",
    )


def sanitize_filename(filename: str) -> str:
    filename = filename.replace("/", "_").replace("\\", "_")
    filename = re.sub('[<>:"|?*]', "", filename)
    filename = filename.replace("..", "_")
    filename = filename.replace("\x00", "")
    filename = filename.lstrip(".")
    if len(filename) > 255:
        filename = filename[:255]
    return filename


def check_file_size(size: int, max_size_mb: int = 10) -> bool:
    if size <= 0:
        return False
    max_size_bytes = max_size_mb * 1024 * 1024
    return size <= max_size_bytes


def validate_mime_type(mime_type: str | None, allowed_types: list) -> bool:
    if not mime_type:
        return False
    return mime_type in allowed_types


def validate_file_upload(file, allowed_types: list, max_size_mb: int = 10) -> dict:
    """Validate an uploaded file's declared MIME type, size and filename.

    *file* must expose ``content_type``, ``size`` and ``filename``. It is
    presumably a Starlette/FastAPI ``UploadFile``, where ``size`` and
    ``filename`` may be None; confirm at call sites.

    Returns:
        On the first failing check, ``{"valid": False, "error": <message>}``.
        Otherwise, ``{"valid": True, "sanitized_filename": <safe name>}``.
    """
    if not validate_mime_type(file.content_type, allowed_types):
        return {"valid": False, "error": f"Invalid MIME type: {file.content_type}"}
    size = file.size
    # Checked here (not only in check_file_size) so an empty/unknown-size file
    # gets an accurate message instead of "exceeds maximum", and a None size
    # never reaches a numeric comparison.
    if size is None or size <= 0:
        return {"valid": False, "error": "File is empty or its size is unknown"}
    if not check_file_size(size, max_size_mb):
        return {
            "valid": False,
            "error": f"File size exceeds maximum allowed size of {max_size_mb}MB",
        }
    # `or ""` guards against a missing filename (None) reaching string methods.
    sanitized_name = sanitize_filename(file.filename or "")
    if not sanitized_name:
        return {"valid": False, "error": "Invalid filename"}
    return {"valid": True, "sanitized_filename": sanitized_name}


# Pragmatic shape check (local@domain.tld); intentionally not a full RFC 5322 parser.
_EMAIL_PATTERN = re.compile(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")


def validate_email(email: str) -> bool:
    """Return True when the whole of *email* looks like ``local@domain.tld``.

    ``fullmatch`` is used instead of ``re.match`` with ``^...$``, because
    ``$`` also matches before a trailing newline. That let values such as
    ``"a@b.com\\n"`` through, which is an injection risk once the address
    reaches mail headers.
    """
    return _EMAIL_PATTERN.fullmatch(email) is not None


def validate_email_format(email: str) -> bool:
    """Like validate_email, but a falsy *email* (empty string or None) yields False."""
    return bool(email) and validate_email(email)


def validate_url_format(url: str) -> bool:
    """Return True only for absolute ``http``/``https`` URLs that name a host.

    Uses an allowlist: any other scheme (``javascript:``, ``data:``,
    ``file:``, ...) is rejected. URLs containing whitespace or non-printable
    characters are rejected before parsing, because ``urlsplit`` silently
    strips some of them (tabs, newlines, leading control characters). The
    parsed result would then differ from what was validated.
    """
    if not url:
        return False
    if any(ch.isspace() or not ch.isprintable() for ch in url):
        return False
    try:
        parts = urlsplit(url)
    except ValueError:
        # e.g. an unterminated IPv6 literal such as "http://[::1"
        return False
    return parts.scheme in ("http", "https") and bool(parts.hostname)


# Known-bad fragments, upper-cased once so matching is case-insensitive.
_SQL_INJECTION_PATTERNS = tuple(
    pattern.upper()
    for pattern in (
        "' OR '1'='1",
        "'; DROP TABLE",
        "' OR 1=1",
        "UNION SELECT",
        "'; --",
        "' OR 'a'='a",
    )
)


def check_sql_injection(value: str) -> bool:
    """Heuristically flag *value* if it contains a well-known SQL-injection fragment.

    Matching is case-insensitive, and every run of whitespace (tabs,
    newlines, repeated spaces) is collapsed to a single space first, so
    ``UNION\\tSELECT`` cannot slip past ``UNION SELECT``.

    This is a blocklist for logging/alerting only. It cannot detect every
    payload, and parameterized queries remain the actual defense.
    """
    if not value:
        return False
    normalized = " ".join(value.upper().split())
    return any(pattern in normalized for pattern in _SQL_INJECTION_PATTERNS)


def sanitize_sql_input(value: str) -> str:
    """Remove single quotes, semicolons and ``--`` sequences from *value*.

    Falsy input (``""`` or None) is returned unchanged.

    WARNING: this does not make a value safe to interpolate into SQL. For
    example, numeric contexts such as ``id = 1 OR 1=1`` need no quotes at all.
    It also mangles legitimate data (``O'Brien`` becomes ``OBrien``). Use
    parameterized queries instead.
    """
    if not value:
        return value
    without_delimiters = value.translate(str.maketrans("", "", "';"))
    return without_delimiters.replace("--", "")


# Case-insensitive superset of the original substring list: whitespace is now
# tolerated before ':' / '=' / '(' (e.g. "onerror = ...", "eval (..)").
_XSS_PATTERN = re.compile(
    r"<script|<iframe|javascript\s*:|on(?:error|load)\s*=|eval\s*\(",
    re.IGNORECASE,
)


def check_xss(value: str) -> bool:
    """Heuristically flag *value* if it contains a common XSS payload fragment.

    Detects ``<script``, ``<iframe``, ``javascript:``, ``onerror=``,
    ``onload=`` and ``eval(``, in any letter case and with optional
    whitespace before the ``:``, ``=`` or ``(``.

    This is a blocklist for logging/alerting only. Context-aware output
    escaping (and CSP) is the actual defense.
    """
    if not value:
        return False
    return _XSS_PATTERN.search(value) is not None


def sanitize_html_input(value: str) -> str:
    """Return *value* HTML-escaped so it can be embedded in markup as text.

    ``&``, ``<``, ``>``, ``"`` and ``'`` are replaced by character
    references (``html.escape`` with ``quote=True``). Each call is logged
    at info level.
    """
    escaped = html.escape(value, quote=True)
    logger.info("HTML input sanitized")
    return escaped


# Number of random bytes per CSRF token; generate_csrf_token hex-encodes them,
# so the token string is twice this many characters.
CSRF_TOKEN_LENGTH = 32
# Cookie that carries the token to the browser (set by CSRFMiddleware on GET/HEAD/OPTIONS).
CSRF_COOKIE_NAME = "csrf_token"
# Header in which clients must echo the cookie value on other methods.
CSRF_HEADER_NAME = "X-CSRF-Token"


def generate_csrf_token() -> str:
    """Return a fresh CSRF token as a lowercase hex string.

    The token is built from ``CSRF_TOKEN_LENGTH`` bytes of the ``secrets``
    CSPRNG, so the string is twice that many characters long.
    """
    raw = secrets.token_bytes(CSRF_TOKEN_LENGTH)
    return raw.hex()


def verify_csrf_token(token: str | None, cookie_token: str | None) -> bool:
    if not token or not cookie_token:
        return False
    return secrets.compare_digest(token, cookie_token)


class CSRFMiddleware(BaseHTTPMiddleware):
    """Double-submit-cookie CSRF protection.

    GET/HEAD/OPTIONS requests pass through. If the client has no
    ``csrf_token`` cookie yet, the response sets a fresh one. The cookie is
    readable by JavaScript so the frontend can copy it into a header.

    Every other method must send the cookie value back in the
    ``X-CSRF-Token`` header, unless the path is exempt. Failures return a
    403 JSON response.
    """

    SAFE_METHODS = frozenset({"GET", "HEAD", "OPTIONS"})

    def __init__(self, app, exempt_paths: list | None = None):
        super().__init__(app)
        # `is None` (not `or`) so an explicit empty list disables all exemptions.
        if exempt_paths is None:
            exempt_paths = [
                "/health",
                "/metrics",
                "/docs",
                "/openapi.json",
                "/api/v1/auth/login",
                "/api/v1/auth/register",
                "/api/v1/agents",
            ]
        self.exempt_paths = exempt_paths

    def _is_exempt(self, path: str) -> bool:
        """Return True when *path* equals an exempt path or lies beneath it.

        Matching respects path-segment boundaries. ``/api/v1/auth/login``
        exempts ``/api/v1/auth/login/...`` but not ``/api/v1/auth/login-as-admin``,
        which a plain ``startswith`` would also have exempted.
        """
        return any(
            path == prefix or path.startswith(prefix.rstrip("/") + "/")
            for prefix in self.exempt_paths
        )

    @staticmethod
    def _forbidden(detail: str) -> JSONResponse:
        """Build the 403 response (same ``{"detail": ...}`` shape as HTTPException)."""
        return JSONResponse(status_code=status.HTTP_403_FORBIDDEN, content={"detail": detail})

    async def dispatch(self, request: Request, call_next):
        if request.method in self.SAFE_METHODS:
            response = await call_next(request)
            if CSRF_COOKIE_NAME not in request.cookies:
                response.set_cookie(
                    key=CSRF_COOKIE_NAME,
                    value=generate_csrf_token(),
                    httponly=False,  # the frontend must read it to echo it in the header
                    secure=request.url.scheme == "https",
                    samesite="lax",
                    max_age=3600,
                )
            return response
        if self._is_exempt(request.url.path):
            return await call_next(request)
        csrf_token = request.headers.get(CSRF_HEADER_NAME)
        csrf_cookie = request.cookies.get(CSRF_COOKIE_NAME)
        # Return (not raise): an HTTPException raised inside BaseHTTPMiddleware
        # bypasses FastAPI's exception handlers and surfaces as a 500.
        if not csrf_token or not csrf_cookie:
            logger.warning(f"CSRF token missing for {request.method} {request.url.path}")
            return self._forbidden("CSRF token missing")
        if not verify_csrf_token(csrf_token, csrf_cookie):
            logger.warning(f"CSRF token mismatch for {request.method} {request.url.path}")
            return self._forbidden("CSRF token invalid")
        return await call_next(request)
