from typing import Any

from .base_workflow import BaseWorkflow, WorkflowStatus, WorkflowStep

from common_logging import get_logger

logger = get_logger(__name__)


class PlanningWorkflow(BaseWorkflow):
    """Planning workflow: scheme design, simulation, risk assessment, implementation.

    The four steps are registered in order at construction time and are run
    one at a time through :meth:`execute_step`. The per-step handlers at the
    bottom of the class currently return fixed placeholder payloads.
    """

    # Single source of truth for the workflow's steps, in registration order:
    # (step name, display label, name of the handler method on this class).
    # Labels are user-facing Chinese strings meaning "scheme design",
    # "simulation", "risk assessment" and "implementation".
    # Both _init_steps and execute_step read this table, so a step cannot be
    # registered without a handler (or vice versa) by editing only one place.
    _STEP_SPECS: tuple[tuple[str, str, str], ...] = (
        ('scheme_design', '方案设计', '_design_scheme'),
        ('simulation', '模拟计算', '_simulate'),
        ('risk_assessment', '风险评估', '_assess_risk'),
        ('implementation', '实施', '_implement'),
    )

    def __init__(self, workflow_id: str, tenant_id: int, user_id: int):
        """Create the workflow and register its steps.

        Args:
            workflow_id: Forwarded unchanged to ``BaseWorkflow``.
            tenant_id: Forwarded unchanged to ``BaseWorkflow``.
            user_id: Forwarded unchanged to ``BaseWorkflow``.
        """
        super().__init__(workflow_id, tenant_id, user_id)
        self._init_steps()

    def _init_steps(self):
        """Register one ``WorkflowStep`` per entry of ``_STEP_SPECS``, in order."""
        for name, label, _handler_name in self._STEP_SPECS:
            self.add_step(WorkflowStep(name, label))

    def _step_handler(self, step_name: str):
        """Return the bound handler method for ``step_name``.

        Looked up by attribute name so subclasses may override individual
        handlers.

        Raises:
            ValueError: If ``step_name`` has no entry in ``_STEP_SPECS``.
        """
        for name, _label, handler_name in self._STEP_SPECS:
            if name == step_name:
                return getattr(self, handler_name)
        raise ValueError(f'Unknown step: {step_name}')

    def execute_step(self, step_name: str, data: dict[str, Any]) -> dict[str, Any]:
        """Run a single registered step and record its outcome on the step.

        The step is marked IN_PROGRESS before running. On success it is
        marked COMPLETED and its ``result`` is stored; on any exception it
        is marked FAILED, ``error`` is set to the exception text, the
        failure is logged, and the exception is re-raised.

        Args:
            step_name: Name of a step previously registered via ``add_step``.
            data: Input payload passed to the step's handler.

        Returns:
            The handler's result dict.

        Raises:
            ValueError: If no registered step has this name (raised before
                any status change), or if the step is registered but has no
                handler (raised after it is marked FAILED).
            Exception: Any error raised by the handler is re-raised as-is.
        """
        step = next((s for s in self.steps if s.name == step_name), None)
        # Compare against the sentinel explicitly: a step object's truthiness
        # is not a reliable "found" signal.
        if step is None:
            raise ValueError(f'Step {step_name} not found')
        step.status = WorkflowStatus.IN_PROGRESS
        try:
            # Handler lookup is inside the try so a registered-but-unhandled
            # step is marked FAILED, as before.
            result = self._step_handler(step_name)(data)
            step.status = WorkflowStatus.COMPLETED
            step.result = result
            return result
        except Exception as e:
            step.status = WorkflowStatus.FAILED
            step.error = str(e)
            # logger.exception logs at error level and keeps the traceback.
            logger.exception(f'Step {step_name} failed: {e}')
            raise

    def _design_scheme(self, data: dict[str, Any]) -> dict[str, Any]:
        """Placeholder: ignores ``data`` and returns a fixed scheme id."""
        return {'status': 'designed', 'scheme_id': 'scheme_001'}

    def _simulate(self, data: dict[str, Any]) -> dict[str, Any]:
        """Placeholder: echoes ``data['expected_savings']`` (default 0) as ``savings``."""
        return {'status': 'simulated', 'savings': data.get('expected_savings', 0)}

    def _assess_risk(self, data: dict[str, Any]) -> dict[str, Any]:
        """Placeholder: ignores ``data`` and always reports ``'low'`` risk."""
        return {'status': 'assessed', 'risk_level': 'low'}

    def _implement(self, data: dict[str, Any]) -> dict[str, Any]:
        """Placeholder: ignores ``data`` and returns a fixed implementation id."""
        return {'status': 'implemented', 'implementation_id': 'impl_001'}
