from datetime import datetime
from enum import Enum
from typing import Any

from common_logging import get_logger

logger = get_logger(__name__)


class WorkflowStatus(str, Enum):
    """Lifecycle states used for workflows and for their individual steps.

    The ``str`` mixin makes every member compare equal to its lowercase
    string value (for example ``WorkflowStatus.PENDING == "pending"``).
    """

    # Initial state assigned by the WorkflowStep and BaseWorkflow constructors.
    PENDING = "pending"

    # Active state.
    IN_PROGRESS = "in_progress"

    # End states.
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"

class WorkflowStep:
    """One named step of a workflow together with its execution state.

    A new step is ``PENDING`` and carries no result, error, or timestamps;
    those attributes are filled in later by whoever drives the step.
    """

    def __init__(self, name: str, description: str, required: bool=True):
        """Create a pending step.

        Args:
            name: Name of the step.
            description: Description of the step.
            required: Flag stored on the step as-is; defaults to ``True``.
        """
        # Definition supplied by the caller.
        self.name, self.description, self.required = name, description, required

        # Execution state: nothing has run yet, so every outcome field is empty.
        self.status = WorkflowStatus.PENDING
        self.result: dict[str, Any] | None = None
        self.error: str | None = None
        self.started_at: datetime | None = None
        self.completed_at: datetime | None = None

class BaseWorkflow:
    """Base class for a multi-step workflow tied to a tenant and a user.

    The workflow keeps an ordered list of ``WorkflowStep`` objects, a
    free-form ``context`` dict, and a cached status snapshot that is rebuilt
    lazily by ``get_status``. ``execute_step`` is a hook: the implementation
    here always raises ``NotImplementedError``.

    Cache caveat: only ``add_step``, ``execute_step`` and
    ``update_step_status`` invalidate the snapshot. Assigning ``self.status``
    or mutating a step directly leaves a previously built snapshot in place
    until ``_invalidate_cache`` is called.
    """

    def __init__(self, workflow_id: str, tenant_id: int, user_id: int):
        """Create an empty, pending workflow.

        Args:
            workflow_id: Identifier reported in status snapshots and log records.
            tenant_id: Tenant identifier stored on the instance.
            user_id: User identifier stored on the instance.
        """
        self.workflow_id = workflow_id
        self.tenant_id = tenant_id
        self.user_id = user_id
        self.status = WorkflowStatus.PENDING
        self.steps: list[WorkflowStep] = []
        self.context: dict[str, Any] = {}
        # Read the clock once so a fresh workflow has created_at == updated_at.
        now = datetime.now()
        self.created_at = now
        self.updated_at = now
        self._status_cache: dict[str, Any] | None = None
        self._cache_dirty = True

    def add_step(self, step: WorkflowStep) -> None:
        """Append ``step`` to the workflow and invalidate the status cache."""
        self.steps.append(step)
        self._invalidate_cache()

    def get_current_step(self) -> WorkflowStep | None:
        """Return the first step whose status is ``PENDING``, or ``None``.

        Steps are examined in insertion order; steps in any other state
        (including ``IN_PROGRESS``) are skipped.
        """
        return next(
            (step for step in self.steps if step.status == WorkflowStatus.PENDING),
            None,
        )

    def execute_step(self, step_name: str, data: dict[str, Any]) -> dict[str, Any]:
        """Execute the named step; not implemented in the base class.

        Args:
            step_name: Name of the step to run.
            data: Input payload for the step.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        # Kept from the original behavior: the cache is invalidated (and
        # updated_at bumped) before the error is raised.
        self._invalidate_cache()
        raise NotImplementedError

    def get_status(self) -> dict[str, Any]:
        """Return a status snapshot of the workflow and its steps.

        The snapshot contains ``workflow_id``, ``status``, a ``steps`` list
        (``name``, ``status``, ``result``, ``error`` per step) and ISO-formatted
        ``created_at`` / ``updated_at`` strings.

        The snapshot is cached until the next invalidation, and the cached
        dict itself is returned, so callers should treat it as read-only.
        """
        if not self._cache_dirty and self._status_cache is not None:
            return self._status_cache

        step_summaries = [
            {
                'name': step.name,
                'status': step.status,
                'result': step.result,
                'error': step.error,
            }
            for step in self.steps
        ]
        snapshot = {
            'workflow_id': self.workflow_id,
            'status': self.status,
            'steps': step_summaries,
            'created_at': self.created_at.isoformat(),
            'updated_at': self.updated_at.isoformat(),
        }
        self._status_cache = snapshot
        self._cache_dirty = False
        return snapshot

    def _invalidate_cache(self) -> None:
        """Mark the status snapshot stale and bump ``updated_at``."""
        self._cache_dirty = True
        self.updated_at = datetime.now()

    def update_step_status(
        self,
        step_name: str,
        status: WorkflowStatus,
        result: dict[str, Any] | None = None,
        error: str | None = None,
    ) -> None:
        """Set the status of the first step named ``step_name``.

        Args:
            step_name: Name of the step to update.
            status: New status for the step.
            result: Result to store on the step; ``None`` leaves the current
                result untouched. An empty dict is stored as given.
            error: Error text to store on the step; ``None`` leaves the
                current error untouched.

        If no step has that name, a warning is logged and nothing else changes.
        """
        step = next((s for s in self.steps if s.name == step_name), None)
        if step is None:
            # Surface typos and stale step names instead of silently ignoring them.
            logger.bind(workflow_id=self.workflow_id).warning("workflow step not found", step=step_name)
            return

        step.status = status
        # Compare against None so explicitly supplied empty values are kept.
        if result is not None:
            step.result = result
        if error is not None:
            step.error = error

        if status == WorkflowStatus.IN_PROGRESS:
            step.started_at = datetime.now()
        elif status in (WorkflowStatus.COMPLETED, WorkflowStatus.FAILED):
            # NOTE(review): CANCELLED does not stamp completed_at — confirm intended.
            step.completed_at = datetime.now()

        self._invalidate_cache()

        if status == WorkflowStatus.FAILED:
            logger.bind(workflow_id=self.workflow_id).warning("workflow step failed", step=step_name, error=error)
        else:
            logger.bind(workflow_id=self.workflow_id).info("workflow step status updated", step=step_name, status=status.value)
