
from sqlalchemy.orm import Session

from app.crud.base import CRUDBase
from app.models.provider import Model, ModelProvider
from app.schemas.provider import ModelCreate, ModelUpdate, ProviderCreate, ProviderUpdate
from common_logging import get_logger

logger = get_logger(__name__)


class CRUDModelProvider(CRUDBase[ModelProvider, ProviderCreate, ProviderUpdate]):
    """CRUD operations for ``ModelProvider`` rows, with optional tenant scoping."""

    @staticmethod
    def _scope_to_tenant(query, tenant_id: int | None):
        """Narrow *query* to *tenant_id* when one is given and the model supports it.

        The ``hasattr`` guard means that if ``ModelProvider`` has no ``tenant_id``
        attribute, the tenant argument is silently ignored and *query* is returned
        unchanged.
        """
        if tenant_id is None or not hasattr(ModelProvider, "tenant_id"):
            return query
        return query.filter(ModelProvider.tenant_id == tenant_id)

    def get_by_name(
        self, db: Session, *, name: str, tenant_id: int | None = None
    ) -> ModelProvider | None:
        """Return the first provider whose ``name`` equals *name*, or ``None``.

        Args:
            db: SQLAlchemy session used for the query.
            name: Exact provider name to match.
            tenant_id: Optional tenant restriction (see ``_scope_to_tenant``).
        """
        by_name = db.query(ModelProvider).filter(ModelProvider.name == name)
        return self._scope_to_tenant(by_name, tenant_id).first()

    def remove(self, db: Session, *, id: int) -> ModelProvider | None:
        """Delete the provider with primary key *id* via ``self.delete`` and log it.

        Returns whatever ``self.delete`` returns.
        """
        deleted = self.delete(db, id=id)
        logger.bind(provider_id=id).info("Provider deleted")
        return deleted

    def get_by_code(
        self, db: Session, *, code: str, tenant_id: int | None = None
    ) -> ModelProvider | None:
        """Return the first provider whose ``code`` equals *code*, or ``None``.

        Args:
            db: SQLAlchemy session used for the query.
            code: Exact provider code to match.
            tenant_id: Optional tenant restriction (see ``_scope_to_tenant``).
        """
        by_code = db.query(ModelProvider).filter(ModelProvider.code == code)
        return self._scope_to_tenant(by_code, tenant_id).first()

    def get_active_providers(
        self, db: Session, *, skip: int = 0, limit: int = 100
    ) -> list[ModelProvider]:
        """Return one page of providers whose ``is_active`` column is true.

        Args:
            db: SQLAlchemy session used for the query.
            skip: Number of matching rows to skip (OFFSET).
            limit: Maximum number of rows to return (LIMIT).
        """
        active = db.query(ModelProvider).filter(ModelProvider.is_active)
        return active.offset(skip).limit(limit).all()


class CRUDModel(CRUDBase[Model, ModelCreate, ModelUpdate]):
    """CRUD operations for ``Model`` rows, plus provider/type/active lookups."""

    def get_by_name(
        self, db: Session, *, name: str, provider_id: int | None = None
    ) -> Model | None:
        """Return the first model whose ``name`` equals *name*, or ``None``.

        Args:
            db: SQLAlchemy session used for the query.
            name: Exact model name to match.
            provider_id: When given, only models belonging to this provider
                are considered.
        """
        # NOTE(review): assumes ``Model`` exposes a ``name`` column, as the method
        # signature implies — confirm against app.models.provider.
        query = db.query(Model).filter(Model.name == name)
        if provider_id is not None:
            query = query.filter(Model.provider_id == provider_id)
        return query.first()

    def remove(self, db: Session, *, id: int) -> Model | None:
        """Delete the model with primary key *id* via ``self.delete`` and log it.

        Returns whatever ``self.delete`` returns.
        """
        result = self.delete(db, id=id)
        logger.bind(model_id=id).info("Model deleted")
        return result

    def get_by_code(self, db: Session, *, code: str) -> Model | None:
        """Return the first model whose ``code`` equals *code*, or ``None``."""
        return db.query(Model).filter(Model.code == code).first()

    def get_by_provider(
        self, db: Session, *, provider_id: int, skip: int = 0, limit: int = 100
    ) -> list[Model]:
        """Return one page of models belonging to *provider_id*.

        Args:
            db: SQLAlchemy session used for the query.
            provider_id: Provider whose models are listed.
            skip: Number of matching rows to skip (OFFSET).
            limit: Maximum number of rows to return (LIMIT).
        """
        return (
            db.query(Model).filter(Model.provider_id == provider_id).offset(skip).limit(limit).all()
        )

    def get_by_type(
        self, db: Session, *, model_type: str, skip: int = 0, limit: int = 100
    ) -> list[Model]:
        """Return one page of models whose ``type`` column equals *model_type*.

        Args:
            db: SQLAlchemy session used for the query.
            model_type: Exact type value to match.
            skip: Number of matching rows to skip (OFFSET).
            limit: Maximum number of rows to return (LIMIT).
        """
        return db.query(Model).filter(Model.type == model_type).offset(skip).limit(limit).all()

    def get_active_models(
        self, db: Session, *, provider_id: int | None = None, skip: int = 0, limit: int = 100
    ) -> list[Model]:
        """Return one page of models whose ``is_active`` column is true.

        Args:
            db: SQLAlchemy session used for the query.
            provider_id: When given, only models of this provider are returned.
            skip: Number of matching rows to skip (OFFSET).
            limit: Maximum number of rows to return (LIMIT).
        """
        query = db.query(Model).filter(Model.is_active)
        # Explicit None check so a falsy-but-valid id (e.g. 0) still filters.
        if provider_id is not None:
            query = query.filter(Model.provider_id == provider_id)
        return query.offset(skip).limit(limit).all()


# Shared module-level CRUD instances, one per ORM model
# (constructed left to right: provider first, then model).
provider, model = CRUDModelProvider(ModelProvider), CRUDModel(Model)
