from typing import List, Optional

from sqlalchemy.orm import Session
from common_logging import get_logger, log_execution

from app.models.training_template import TrainingTemplate

# Module-level logger keyed by this module's name; TemplateManager uses it both
# for the @log_execution decorator and for bound per-template log lines.
logger = get_logger(__name__)


class TemplateManager:
    """CRUD operations for ``TrainingTemplate`` rows on a caller-supplied session.

    Every write method commits immediately. If a commit fails, the session is
    rolled back before the exception propagates, so the session can still be
    used afterwards.
    """

    def __init__(self, db: Session):
        """Wrap an existing SQLAlchemy session.

        Args:
            db: Session used for every query and commit issued by this manager.
        """
        self.db = db

    def _commit(self) -> None:
        """Commit the session, rolling it back if the commit raises."""
        try:
            self.db.commit()
        except Exception:
            # A failed flush/commit leaves the session in a pending-rollback
            # state; restore it, then surface the original error unchanged.
            self.db.rollback()
            raise

    @log_execution(logger)
    def create_template(
        self,
        name: str,
        description: str,
        task_type: str,
        platform: str,
        config: dict,
        is_public: bool,
        tenant_id: int,
        created_by: int,
    ) -> TrainingTemplate:
        """Insert a new template and return it with DB-generated fields loaded.

        Args:
            name: Template name.
            description: Free-text description.
            task_type: Task type the template targets.
            platform: Platform the template targets.
            config: Template configuration, stored as given.
            is_public: Whether other tenants can see the template (see
                ``get_templates``).
            tenant_id: Owning tenant.
            created_by: Id of the creating user.

        Returns:
            The persisted, refreshed ``TrainingTemplate``.
        """
        template = TrainingTemplate(
            name=name,
            description=description,
            task_type=task_type,
            platform=platform,
            config=config,
            is_public=is_public,
            tenant_id=tenant_id,
            created_by=created_by,
        )
        self.db.add(template)
        self._commit()
        # Reload so DB-assigned values (e.g. the primary key) are populated.
        self.db.refresh(template)
        logger.bind(template_id=template.id).info("Template created")
        return template

    def get_templates(self, tenant_id: int, task_type: Optional[str] = None) -> List[TrainingTemplate]:
        """List templates visible to a tenant.

        Args:
            tenant_id: Tenant whose own templates are included.
            task_type: If truthy, only templates with this task type are returned.

        Returns:
            Templates owned by ``tenant_id`` plus all templates flagged public.
        """
        query = self.db.query(TrainingTemplate).filter(
            (TrainingTemplate.tenant_id == tenant_id) | (TrainingTemplate.is_public)
        )
        if task_type:
            query = query.filter(TrainingTemplate.task_type == task_type)
        return query.all()

    def get_template(self, template_id: int) -> Optional[TrainingTemplate]:
        """Fetch a single template by primary key, or ``None`` if absent.

        NOTE(review): the lookup filters on id only, with no tenant check.
        Confirm that callers enforce ownership before exposing or mutating
        the result.
        """
        return self.db.query(TrainingTemplate).filter(TrainingTemplate.id == template_id).first()

    def update_template(self, template_id: int, **kwargs) -> Optional[TrainingTemplate]:
        """Apply attribute updates to a template and commit them.

        Args:
            template_id: Primary key of the template to update.
            **kwargs: Attribute name -> new value. Names that are not mapped
                attributes of ``TrainingTemplate`` are skipped with a warning,
                because assigning them would never be persisted.

        Returns:
            The refreshed template, or ``None`` if no template has that id.
        """
        # Local import: only needed here, and keeps module imports unchanged.
        from sqlalchemy import inspect as sa_inspect

        template = self.get_template(template_id)
        if template is None:
            return None

        mapped_attrs = set(sa_inspect(TrainingTemplate).attrs.keys())
        for key, value in kwargs.items():
            if key not in mapped_attrs:
                logger.bind(template_id=template_id, field=key).warning("Ignoring unknown template field")
                continue
            setattr(template, key, value)

        self._commit()
        self.db.refresh(template)
        logger.bind(template_id=template_id).info("Template updated")
        return template

    def delete_template(self, template_id: int) -> bool:
        """Delete a template by primary key.

        Returns:
            ``True`` if a template was found and deleted, ``False`` otherwise.
        """
        template = self.get_template(template_id)
        if template is None:
            return False
        self.db.delete(template)
        self._commit()
        logger.bind(template_id=template_id).info("Template deleted")
        return True
