
from sqlalchemy.orm import Session

from app.models.role import Role
from app.models.user_role import UserRole
from common_logging import get_logger

# Module-level logger; the CRUD methods below attach context via
# ``logger.bind(...)`` before emitting their info-level messages.
logger = get_logger(__name__)


class CRUDUserRole:
    """Data-access helpers for the user <-> role association (``UserRole``).

    Assignments are soft-deleted: rows are flagged with ``is_deleted = True``
    instead of being removed, and every read filters those rows out.
    """

    def _active_user_roles(self, db: Session, *, user_id: int):
        """Return a query over ``user_id``'s ``UserRole`` rows that are not soft-deleted."""
        # SQL predicates must be built with column operators: Python's ``not``
        # always produces a plain bool, so ``not UserRole.is_deleted`` can never
        # express "is_deleted is false" in the generated WHERE clause.
        return db.query(UserRole).filter(
            UserRole.user_id == user_id,
            UserRole.is_deleted.is_(False),
        )

    def get_user_roles(self, db: Session, *, user_id: int) -> list[Role]:
        """Return the roles currently assigned to a user.

        Args:
            db: Active SQLAlchemy session.
            user_id: ID of the user whose roles are looked up.

        Returns:
            ``Role`` objects referenced by the user's non-deleted assignments,
            excluding roles that are themselves soft-deleted. An empty list
            when the user has no active assignments.
        """
        role_ids = [
            ur.role_id for ur in self._active_user_roles(db, user_id=user_id).all()
        ]
        if not role_ids:
            return []
        return (
            db.query(Role)
            .filter(Role.id.in_(role_ids), Role.is_deleted.is_(False))
            .all()
        )

    def update_user_roles(
        self, db: Session, *, user_id: int, role_ids: list[int], tenant_id: int
    ) -> None:
        """Replace a user's role assignments and sync the permission backend.

        All of the user's active ``UserRole`` rows are soft-deleted, one new row
        is inserted per distinct id in ``role_ids``, and the change is
        committed. Afterwards the permission helpers are told to drop the old
        ``role:<code>`` grants, add the new ones, and invalidate the user's
        cached permissions for ``tenant_id``.

        Args:
            db: Active SQLAlchemy session; committed by this method.
            user_id: ID of the user whose roles are replaced.
            role_ids: IDs of the roles the user should end up with. Duplicates
                are ignored; an empty list removes every role.
            tenant_id: Tenant passed through to the permission helpers.
        """
        # Imported lazily, presumably to avoid a circular import with
        # app.core.permissions — confirm before hoisting to module level.
        from app.core.permissions import (
            add_role_for_user,
            invalidate_user_permissions,
            remove_role_for_user,
        )

        # Capture codes now: after commit() the Role instances may be expired
        # and each attribute access would trigger another query.
        old_codes = [role.code for role in self.get_user_roles(db, user_id=user_id)]
        # De-duplicate while keeping the caller's order so no role is inserted twice.
        wanted_ids = list(dict.fromkeys(role_ids))

        for ur in self._active_user_roles(db, user_id=user_id).all():
            ur.is_deleted = True
        for role_id in wanted_ids:
            db.add(UserRole(user_id=user_id, role_id=role_id))
        db.commit()

        for code in old_codes:
            remove_role_for_user(user_id, f"role:{code}", tenant_id)
        # Grant only roles that still exist, matching what get_user_roles reports.
        new_roles = (
            db.query(Role)
            .filter(Role.id.in_(wanted_ids), Role.is_deleted.is_(False))
            .all()
            if wanted_ids
            else []
        )
        for role in new_roles:
            add_role_for_user(user_id, f"role:{role.code}", tenant_id)
        invalidate_user_permissions(user_id, tenant_id)
        logger.bind(user_id=user_id).info("User roles updated")

    def add_user_role(self, db: Session, *, user_id: int, role_id: int) -> UserRole:
        """Create and commit a new assignment of ``role_id`` to ``user_id``.

        Returns:
            The persisted ``UserRole`` row, refreshed from the database.
        """
        # NOTE(review): unlike update_user_roles, this does not call the
        # permission helpers (add_role_for_user / invalidate_user_permissions),
        # and it does not check for an existing active assignment — confirm
        # callers handle both.
        user_role = UserRole(user_id=user_id, role_id=role_id)
        db.add(user_role)
        db.commit()
        db.refresh(user_role)
        logger.bind(user_id=user_id, role_id=role_id).info("User role added")
        return user_role

    def remove_user_role(self, db: Session, *, user_id: int, role_id: int) -> None:
        """Soft-delete the user's active assignment(s) of ``role_id``.

        Does nothing (and does not commit) when no active assignment exists.
        """
        # NOTE(review): the permission backend is not updated here, unlike
        # update_user_roles — confirm callers handle that.
        matches = (
            self._active_user_roles(db, user_id=user_id)
            .filter(UserRole.role_id == role_id)
            .all()
        )
        if not matches:
            return
        # Clear every matching row so a duplicate assignment cannot keep the
        # role active after removal.
        for ur in matches:
            ur.is_deleted = True
        db.commit()
        logger.bind(user_id=user_id, role_id=role_id).info("User role removed")


# Shared module-level instance. CRUDUserRole defines no __init__ and stores no
# per-instance attributes, so one instance can be reused by all callers.
user_role = CRUDUserRole()
