from typing import Any, Generic, TypeVar

from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy import inspect as sa_inspect
from sqlalchemy import or_
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from app.core.exceptions import ConstraintViolationError
from app.models.base import Base
from app.models.user import User
from common_logging import get_logger

# Module-level logger from the project's common_logging package; CRUDBase below
# logs write events through ``logger.bind(...).info(...)``.
logger = get_logger(__name__)

# Type parameters for CRUDBase:
#   ModelType        - instances of a model class deriving from app.models.base.Base
#   CreateSchemaType - the Pydantic model accepted by ``create``
#   UpdateSchemaType - the Pydantic model accepted by ``update``
ModelType = TypeVar("ModelType", bound=Base)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)


class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
    """Generic create/read/update/delete helpers for a single ORM model.

    Instantiate with the model class. The schema type parameters describe the
    Pydantic models accepted by ``create`` and ``update``.

    Read methods (``get``, ``get_multi``, ``get_count``, ``search``) apply
    tenant scoping and soft-delete exclusion only when ``current_user`` is
    passed. Without it they see every row, including soft-deleted ones.
    """

    def __init__(self, model: type[ModelType]):
        """Bind this helper to *model*, the ORM class every method queries."""
        self.model = model

    def get(self, db: Session, id: Any, current_user: User | None = None) -> ModelType | None:
        """Return the row whose ``id`` column equals *id*, or None.

        When *current_user* is given, the lookup also goes through
        ``_apply_permission_filter``. A row outside that user's scope (other
        tenant, soft-deleted) therefore also yields None.
        """
        query = db.query(self.model).filter(self.model.id == id)
        if current_user:
            query = self._apply_permission_filter(query, current_user, db)
        return query.first()

    def get_multi(
        self,
        db: Session,
        *,
        skip: int = 0,
        limit: int = 100,
        current_user: User | None = None,
        filters: dict[str, Any] | None = None,
        order_by: str | None = None,
        order_desc: bool = True,
    ) -> list[ModelType]:
        """Return one page of rows matching *filters*.

        Args:
            db: Active session.
            skip: Number of rows to skip (OFFSET).
            limit: Maximum number of rows to return (LIMIT).
            current_user: If given, results are scoped via ``_apply_permission_filter``.
            filters: Field -> value conditions; see ``_apply_filters`` for semantics.
            order_by: Model attribute name to sort on. Ignored if the model has no
                such attribute.
            order_desc: Sort direction for *order_by*.

        When no usable *order_by* is given and the model has ``created_at``,
        rows are returned newest first.
        """
        query = db.query(self.model)
        if current_user:
            query = self._apply_permission_filter(query, current_user, db)
        query = self._apply_filters(query, filters)
        if order_by and hasattr(self.model, order_by):
            order_field = getattr(self.model, order_by)
            query = query.order_by(order_field.desc() if order_desc else order_field.asc())
        elif hasattr(self.model, "created_at"):
            query = query.order_by(self.model.created_at.desc())
        return query.offset(skip).limit(limit).all()

    def get_count(
        self,
        db: Session,
        *,
        current_user: User | None = None,
        filters: dict[str, Any] | None = None,
    ) -> int:
        """Return how many rows ``get_multi`` would match, ignoring pagination.

        *current_user* and *filters* are interpreted the same way as in ``get_multi``.
        """
        query = db.query(self.model)
        if current_user:
            query = self._apply_permission_filter(query, current_user, db)
        query = self._apply_filters(query, filters)
        return query.count()

    def create(
        self,
        db: Session,
        *,
        obj_in: CreateSchemaType,
        created_by: int | None = None,
        tenant_id: int | None = None,
        commit: bool = True,
    ) -> ModelType:
        """Build a model instance from *obj_in*, add it to *db*, and optionally commit.

        Args:
            db: Active session.
            obj_in: Input schema. It is passed through ``jsonable_encoder``, so
                values reach the model constructor in JSON-compatible form.
                NOTE(review): e.g. datetimes become ISO strings; confirm the column
                types accept them.
            created_by: Stored on ``created_by`` when given and the model has that column.
            tenant_id: Stored on ``tenant_id`` when given and the model has that column.
            commit: When False the instance is only added to the session (not
                flushed), so database-generated values such as ``id`` are not
                populated yet.

        Returns:
            The new instance, refreshed from the database when *commit* is True.

        Raises:
            ConstraintViolationError: The commit raised ``IntegrityError``. The
                session has been rolled back.
        """
        obj_in_data = jsonable_encoder(obj_in)
        # ``is not None`` for both, so a falsy-but-valid id such as 0 is not dropped.
        if created_by is not None and hasattr(self.model, "created_by"):
            obj_in_data["created_by"] = created_by
        if tenant_id is not None and hasattr(self.model, "tenant_id"):
            obj_in_data["tenant_id"] = tenant_id
        db_obj = self.model(**obj_in_data)
        db.add(db_obj)
        if not commit:
            return db_obj
        try:
            self._commit_or_rollback(db)
        except IntegrityError as e:
            detail = str(e.orig) if hasattr(e, "orig") else str(e)
            raise ConstraintViolationError(detail) from None
        # NOTE(review): the expunge/re-add round trip looks redundant with a plain
        # ``refresh``. It is kept as-is until the reason it was added is confirmed.
        db.expunge(db_obj)
        db.add(db_obj)
        db.refresh(db_obj)
        logger.bind(model=self.model.__name__, id=db_obj.id).info("Record created")
        return db_obj

    def update(
        self,
        db: Session,
        *,
        db_obj: ModelType,
        obj_in: UpdateSchemaType | dict[str, Any],
        commit: bool = True,
    ) -> ModelType:
        """Copy values from *obj_in* onto *db_obj* and optionally commit.

        A dict is applied as-is. A Pydantic model contributes only fields that
        were explicitly set (``exclude_unset=True``). Keys that are not mapped
        attributes of the model are ignored.

        Mapped attribute names come from the SQLAlchemy mapper, not from the
        instance's loaded state. As a result:
        - an expired or partially loaded object (e.g. after an earlier commit)
          still receives every update;
        - the instance is never serialized, so loaded bidirectional
          relationships cannot cause unbounded recursion.

        Raises:
            Whatever ``Session.commit`` raises. The session is rolled back first.
        """
        if isinstance(obj_in, dict):
            update_data = obj_in
        else:
            update_data = obj_in.model_dump(exclude_unset=True)
        mapped_fields = set(sa_inspect(db_obj).mapper.attrs.keys())
        for field, value in update_data.items():
            if field in mapped_fields:
                setattr(db_obj, field, value)
        db.add(db_obj)
        if commit:
            self._commit_or_rollback(db)
            db.refresh(db_obj)
            logger.bind(model=self.model.__name__, id=db_obj.id).info("Record updated")
        return db_obj

    def delete(
        self, db: Session, *, id: int, soft: bool = False, commit: bool = True
    ) -> ModelType | None:
        """Delete the row with ``id`` == *id*; return it, or None if there was none.

        With *soft* set and an ``is_deleted`` column on the model, the row is
        flagged instead of removed. With *soft* set on a model without that
        column, the row is hard-deleted.

        No tenant or permission scoping is applied here.

        Raises:
            Whatever ``Session.commit`` raises. The session is rolled back first.
        """
        obj = db.query(self.model).filter(self.model.id == id).first()
        if obj is None:
            return None
        if soft and hasattr(self.model, "is_deleted"):
            obj.is_deleted = True
            db.add(obj)
        else:
            db.delete(obj)
        if commit:
            self._commit_or_rollback(db)
            logger.bind(model=self.model.__name__, id=id, soft=soft).info("Record deleted")
        return obj

    def delete_multi(
        self,
        db: Session,
        *,
        ids: list[int],
        tenant_id: int,
        current_user: User,
        soft: bool = False,
        commit: bool = True,
    ) -> int:
        """Bulk (soft-)delete rows whose ``id`` is in *ids*; return the affected count.

        Scoping:
        - Rows are limited to *tenant_id* when the model has ``tenant_id``.
        - For users whose role is not ``platform_admin``, rows are also limited
          to those they created, when the model has ``created_by``.

        Soft delete is used only when *soft* is set and the model has
        ``is_deleted``. The bulk statement runs with
        ``synchronize_session=False``, so objects already loaded in *db* are not
        updated in memory.

        Raises:
            Whatever ``Session.commit`` raises. The session is rolled back first.
        """
        query = db.query(self.model).filter(self.model.id.in_(ids))
        if hasattr(self.model, "tenant_id"):
            query = query.filter(self.model.tenant_id == tenant_id)
        if hasattr(self.model, "created_by") and current_user.role != "platform_admin":
            query = query.filter(self.model.created_by == current_user.id)
        if soft and hasattr(self.model, "is_deleted"):
            count = query.update({"is_deleted": True}, synchronize_session=False)
        else:
            count = query.delete(synchronize_session=False)
        if commit:
            self._commit_or_rollback(db)
            logger.bind(model=self.model.__name__, count=count, soft=soft).info("Records batch deleted")
        return count

    def _apply_permission_filter(self, query, current_user: User, db: Session):
        """Restrict *query* to rows *current_user* may see.

        Soft-deleted rows (``is_deleted`` true) are always excluded, even for
        ``platform_admin``. Every other role is additionally limited to its own
        ``tenant_id``. A user without a tenant only sees rows whose
        ``tenant_id`` is NULL.

        *db* is unused but kept for interface compatibility.
        """
        if hasattr(self.model, "is_deleted"):
            query = query.filter(self.model.is_deleted == False)  # noqa: E712 (SQL expression)
        if current_user.role == "platform_admin":
            return query
        if hasattr(self.model, "tenant_id"):
            if current_user.tenant_id is None:
                query = query.filter(self.model.tenant_id.is_(None))
            else:
                query = query.filter(self.model.tenant_id == current_user.tenant_id)
        return query

    def _apply_filters(self, query, filters: dict[str, Any] | None):
        """Apply *filters* to *query*; shared by ``get_multi`` and ``get_count``.

        For each key that names a model attribute (other keys are ignored):
        - a list value becomes ``IN``;
        - a string starting and ending with ``%`` becomes ``LIKE`` (used as-is);
        - any other value is compared with ``==``.
        """
        if not filters:
            return query
        for field, value in filters.items():
            if not hasattr(self.model, field):
                continue
            column = getattr(self.model, field)
            if isinstance(value, list):
                query = query.filter(column.in_(value))
            elif isinstance(value, str) and value.startswith("%") and value.endswith("%"):
                query = query.filter(column.like(value))
            else:
                query = query.filter(column == value)
        return query

    def _commit_or_rollback(self, db: Session) -> None:
        """Commit *db*; if the commit fails, roll back before re-raising.

        A Session whose flush/commit failed must be rolled back before it can
        be used again, so every write path goes through this helper. The
        original exception propagates unchanged.
        """
        try:
            db.commit()
        except Exception:
            db.rollback()
            raise

    def search(
        self,
        db: Session,
        *,
        search_term: str,
        search_fields: list[str],
        skip: int = 0,
        limit: int = 100,
        current_user: User | None = None,
    ) -> list[ModelType]:
        """Return rows where any of *search_fields* contains *search_term* as a substring.

        The term is matched literally. ``contains(..., autoescape=True)`` escapes
        ``%``, ``_`` and the escape character, so user input cannot inject LIKE
        wildcards.

        Other behavior:
        - Names in *search_fields* that are not model attributes are ignored.
        - If none remain, no search condition is applied.
        - Results are ordered newest first by ``created_at`` when the model has
          it, so pagination via *skip*/*limit* is stable.
        """
        query = db.query(self.model)
        if current_user:
            query = self._apply_permission_filter(query, current_user, db)
        conditions = [
            getattr(self.model, field).contains(search_term, autoescape=True)
            for field in search_fields
            if hasattr(self.model, field)
        ]
        if conditions:
            query = query.filter(or_(*conditions))
        if hasattr(self.model, "created_at"):
            query = query.order_by(self.model.created_at.desc())
        return query.offset(skip).limit(limit).all()
