from contextvars import ContextVar

import jwt
from fastapi import HTTPException, Request, status
from starlette.middleware.base import BaseHTTPMiddleware

from common_logging import get_logger

logger = get_logger(__name__)

# Holds the tenant/user dict for the request being handled, or None when no
# context is set. A ContextVar scopes the value to the current execution
# context, so as long as each request runs in its own task/context, concurrent
# requests do not observe each other's tenant.
_tenant_context: ContextVar[dict | None] = ContextVar(
    "tenant_context",
    default=None,
)


class TenantContextMiddleware(BaseHTTPMiddleware):
    """Populate the per-request tenant context from a Bearer JWT.

    For every non-public request, the middleware reads the
    ``Authorization: Bearer <token>`` header. It decodes the token with HS256
    and ``settings.SECRET_KEY``, with expiry verified. It then checks that the
    ``tenant_id`` claim refers to an existing, non-deleted ``Tenant`` row, and
    stores ``tenant_id`` / ``user_id`` via ``set_tenant_context``. The context
    is cleared once the downstream handler has produced a response.

    This is best-effort. A missing, malformed, invalid or expired token, or an
    unknown tenant, is only logged, and the request proceeds without a tenant
    context. Endpoints that need one must enforce it themselves (for example
    with ``require_tenant_context``).
    """

    # Exact paths that never need tenant resolution. A frozenset is built once
    # (not per request) and gives O(1) membership tests.
    _PUBLIC_PATHS = frozenset(
        {
            "/",
            "/health",
            "/docs",
            "/openapi.json",
            "/redoc",
            "/api/v1/auth/login",
            "/api/v1/auth/register",
        }
    )
    # Any path starting with this prefix is also treated as public.
    _PUBLIC_PREFIX = "/static"
    _BEARER_PREFIX = "Bearer "

    async def dispatch(self, request: Request, call_next):
        """Resolve the tenant context (if any), run the request, then clear it."""
        path = request.url.path
        if path in self._PUBLIC_PATHS or path.startswith(self._PUBLIC_PREFIX):
            return await call_next(request)
        # Population happens inside the try so the context is always cleared,
        # even if something escapes between setting it and calling the app.
        try:
            await self._populate_tenant_context(request.headers.get("Authorization"))
            return await call_next(request)
        finally:
            clear_tenant_context()

    async def _populate_tenant_context(self, auth_header: str | None) -> None:
        """Set tenant/user ids from a Bearer token; log and swallow any failure."""
        from starlette.concurrency import run_in_threadpool

        if not auth_header or not auth_header.startswith(self._BEARER_PREFIX):
            return
        # Strip only the leading scheme. str.replace would also remove any
        # later occurrence of "Bearer " inside the header value.
        token = auth_header[len(self._BEARER_PREFIX):]
        try:
            from app.config import settings

            payload = jwt.decode(
                token, settings.SECRET_KEY, algorithms=["HS256"], options={"verify_exp": True}
            )
            tenant_id = payload.get("tenant_id")
            user_id = payload.get("sub")
            if not tenant_id:
                return
            # The lookup uses a synchronous DB session. Run it in a worker
            # thread so it does not block the event loop. The context var is
            # still set below, in this coroutine's own context.
            if not await run_in_threadpool(self._tenant_is_active, int(tenant_id)):
                logger.warning(f"Invalid or deleted tenant_id in JWT: {tenant_id}")
                return
            set_tenant_context(
                tenant_id=int(tenant_id), user_id=int(user_id) if user_id else None
            )
            logger.debug(
                f"Tenant context set from JWT: tenant_id={tenant_id}, user_id={user_id}"
            )
        except jwt.InvalidTokenError as e:
            logger.warning(f"Invalid JWT token: {e}")
        except Exception as e:
            logger.error(f"Error extracting tenant context from JWT: {e}")

    @staticmethod
    def _tenant_is_active(tenant_id: int) -> bool:
        """Return True if a ``Tenant`` row with this id exists and is not deleted."""
        from app.db.session import SessionLocal
        from app.models.tenant import Tenant

        db = SessionLocal()
        try:
            tenant = (
                db.query(Tenant)
                # Must be a SQL expression. Python's `not Tenant.is_deleted` is
                # evaluated before SQLAlchemy sees it, so it yields a plain
                # bool (or raises) instead of a WHERE condition.
                .filter(Tenant.id == tenant_id, Tenant.is_deleted == False)  # noqa: E712
                .first()
            )
            return tenant is not None
        finally:
            db.close()


def set_tenant_context(tenant_id: int, user_id: int | None = None, **kwargs):
    """Store a new tenant context dict for the current execution context.

    Any previously stored context in the current context is replaced.

    Args:
        tenant_id: Value stored under the ``"tenant_id"`` key.
        user_id: Value stored under the ``"user_id"`` key (``None`` if unknown).
        **kwargs: Extra key/value pairs merged into the stored dict.
    """
    context = {"tenant_id": tenant_id, "user_id": user_id, **kwargs}
    _tenant_context.set(context)
    # Plain string: the f-string had no placeholders (ruff F541).
    logger.info("Tenant context set", tenant_id=tenant_id, user_id=user_id)


def get_tenant_context() -> dict | None:
    """Return the tenant context dict for the current execution context.

    Returns ``None`` when nothing has been set or the context was cleared.
    """
    current = _tenant_context.get()
    return current


def get_current_tenant_id() -> int | None:
    """Return the ``tenant_id`` of the current tenant context, or ``None`` if no context is set."""
    ctx = _tenant_context.get()
    if not ctx:
        return None
    return ctx.get("tenant_id")


def get_current_user_id() -> int | None:
    """Return the ``user_id`` of the current tenant context, or ``None`` if no context is set."""
    ctx = _tenant_context.get()
    if not ctx:
        return None
    return ctx.get("user_id")


def clear_tenant_context():
    """Reset the current tenant context to ``None`` (the ContextVar's default)."""
    cleared: dict | None = None
    _tenant_context.set(cleared)
    logger.debug("Tenant context cleared")


def require_tenant_context():
    """Return the current tenant context, insisting that it carries a tenant id.

    Raises:
        HTTPException: 401 Unauthorized when no context is stored, or when its
            ``tenant_id`` is missing or falsy.
    """
    ctx = _tenant_context.get()
    has_tenant = bool(ctx) and bool(ctx.get("tenant_id"))
    if has_tenant:
        return ctx
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Tenant context not set. Please authenticate.",
    )
