
from sqlalchemy.orm import Session

from app.crud.base import CRUDBase
from app.models.agent import Agent, ChatMessage
from app.schemas.agent import AgentCreate, AgentUpdate
from common_logging import get_logger

logger = get_logger(__name__)


class CRUDAgent(CRUDBase[Agent, AgentCreate, AgentUpdate]):
    """CRUD helper for ``Agent`` rows.

    Read helpers build queries with ``db.query(Agent)`` and return ORM
    instances. Mutating helpers accept ``commit``: when it is False the change
    is only staged on the session and the caller is responsible for
    committing it.
    """

    def remove(
        self, db: Session, *, id: int, tenant_id: int | None = None
    ) -> Agent | None:
        """Hard-delete an agent and commit (thin wrapper over ``delete``).

        Args:
            db: Active SQLAlchemy session.
            id: Primary key of the agent to delete.
            tenant_id: If given, only delete the agent when its ``tenant_id``
                matches this value.

        Returns:
            The deleted ``Agent`` instance, or ``None`` if no match was found.
        """
        return self.delete(db, id=id, tenant_id=tenant_id)

    def get_by_name(
        self,
        db: Session,
        *,
        name: str,
        created_by: int | None = None,
        tenant_id: int | None = None,
    ) -> Agent | None:
        """Return the first agent whose ``name`` equals ``name``, or ``None``.

        Args:
            db: Active SQLAlchemy session.
            name: Exact agent name to match.
            created_by: If given, restrict to agents whose ``created_by``
                equals this value.
            tenant_id: If given (and the model defines ``tenant_id``),
                restrict to agents of this tenant.
        """
        query = db.query(Agent).filter(Agent.name == name)
        # Explicit None check so a falsy id (0) still filters, matching how
        # tenant_id is handled below.
        if created_by is not None:
            query = query.filter(Agent.created_by == created_by)
        if tenant_id is not None and hasattr(Agent, "tenant_id"):
            query = query.filter(Agent.tenant_id == tenant_id)
        return query.first()

    def get_by_model(
        self, db: Session, *, model_id: int, tenant_id: int, skip: int = 0, limit: int = 100
    ) -> list[Agent]:
        """Return one page of a tenant's agents that reference ``model_id``.

        Rows are ordered by ``Agent.id`` so ``skip``/``limit`` pagination is
        stable: without an ORDER BY the database may return rows in any
        order, letting pages overlap or skip rows.
        """
        return (
            db.query(Agent)
            .filter(Agent.model_id == model_id, Agent.tenant_id == tenant_id)
            .order_by(Agent.id)
            .offset(skip)
            .limit(limit)
            .all()
        )

    def get_active_agents(
        self,
        db: Session,
        *,
        tenant_id: int,
        created_by: int | None = None,
        skip: int = 0,
        limit: int = 100,
    ) -> list[Agent]:
        """Return one page of a tenant's agents whose ``status`` is ``"active"``.

        Args:
            db: Active SQLAlchemy session.
            tenant_id: Tenant whose agents are listed.
            created_by: If given, restrict to agents whose ``created_by``
                equals this value.
            skip: Number of rows to skip (offset).
            limit: Maximum number of rows to return.

        Rows are ordered by ``Agent.id`` for stable pagination.
        NOTE(review): agents flagged ``is_deleted`` by a soft delete are not
        excluded here; confirm whether that is intended.
        """
        query = db.query(Agent).filter(Agent.status == "active", Agent.tenant_id == tenant_id)
        if created_by is not None:
            query = query.filter(Agent.created_by == created_by)
        return query.order_by(Agent.id).offset(skip).limit(limit).all()

    def update_status(
        self, db: Session, *, agent_id: int, tenant_id: int, status: str, commit: bool = True
    ) -> Agent | None:
        """Set ``status`` on a tenant's agent.

        Args:
            db: Active SQLAlchemy session.
            agent_id: Primary key of the agent to update.
            tenant_id: Tenant the agent must belong to.
            status: New status value, stored as-is.
            commit: When True, commit, refresh the instance and log; otherwise
                the change is only staged on ``db``.

        Returns:
            The updated ``Agent``, or ``None`` if no matching agent exists.
        """
        agent = db.query(Agent).filter(Agent.id == agent_id, Agent.tenant_id == tenant_id).first()
        if agent is None:
            return None
        agent.status = status
        db.add(agent)
        if commit:
            db.commit()
            db.refresh(agent)
            logger.bind(agent_id=agent.id).info("Agent status updated")
        return agent

    def delete(
        self,
        db: Session,
        *,
        id: int,
        soft: bool = False,
        commit: bool = True,
        tenant_id: int | None = None,
    ) -> Agent | None:
        """Delete an agent, either by flagging it or by removing the row.

        Args:
            db: Active SQLAlchemy session.
            id: Primary key of the agent to delete.
            soft: Set ``is_deleted = True`` instead of removing the row. The
                agent's chat messages are kept in that case. Only honoured
                when the ``Agent`` model has an ``is_deleted`` attribute;
                otherwise a warning is logged and a hard delete is performed.
            commit: When True, commit and log; otherwise the change is only
                staged on ``db``.
            tenant_id: If given, only delete the agent when its ``tenant_id``
                matches this value.

        Returns:
            The affected ``Agent`` instance, or ``None`` if no match was found.
        """
        query = db.query(Agent).filter(Agent.id == id)
        if tenant_id is not None:
            query = query.filter(Agent.tenant_id == tenant_id)
        agent = query.first()
        if agent is None:
            return None

        use_soft_delete = soft and hasattr(Agent, "is_deleted")
        if soft and not use_soft_delete:
            # Caller asked for a reversible delete that the model cannot
            # support; make the permanent fallback visible rather than silent.
            logger.bind(agent_id=id).warning(
                "Soft delete requested but Agent has no is_deleted column; performing hard delete"
            )

        if use_soft_delete:
            # Leave chat history intact so the soft delete stays reversible.
            agent.is_deleted = True
            db.add(agent)
        else:
            # Remove the agent's chat messages in one bulk DELETE before the
            # agent row itself (presumably to satisfy an FK to agents —
            # confirm). synchronize_session=False: ChatMessage objects already
            # loaded in this session are not updated.
            db.query(ChatMessage).filter(ChatMessage.agent_id == id).delete(
                synchronize_session=False
            )
            db.delete(agent)
        if commit:
            db.commit()
            logger.bind(agent_id=id).info("Agent deleted")
        return agent


# Module-level instance of the agent CRUD helper, bound to the Agent model.
agent: CRUDAgent = CRUDAgent(Agent)
