diff --git a/atst/models/mixins/state_machines.py b/atst/models/mixins/state_machines.py index bc35209d..b2eda399 100644 --- a/atst/models/mixins/state_machines.py +++ b/atst/models/mixins/state_machines.py @@ -10,7 +10,6 @@ class StageStates(Enum): class AzureStages(Enum): TENANT = "tenant" BILLING_PROFILE = "billing profile" - ADMIN_SUBSCRIPTION = "admin subscription" def _build_csp_states(csp_stages): @@ -31,14 +30,14 @@ def _build_csp_states(csp_stages): FSMStates = Enum("FSMStates", _build_csp_states(AzureStages)) +compose_state = lambda csp_stage, state: getattr( + FSMStates, "_".join([csp_stage.name, state.name]) +) + def _build_transitions(csp_stages): transitions = [] states = [] - compose_state = lambda csp_stage, state: getattr( - FSMStates, "_".join([csp_stage.name, state.name]) - ) - for stage_i, csp_stage in enumerate(csp_stages): for state in StageStates: states.append( @@ -99,6 +98,24 @@ class FSMMixin: {"trigger": "fail", "source": "*", "dest": FSMStates.FAILED,}, ] + def fail_stage(self, stage): + fail_trigger = "fail" + stage + if fail_trigger in self.machine.get_triggers(self.current_state.name): + self.trigger(fail_trigger) + + def finish_stage(self, stage): + finish_trigger = "finish_" + stage + if finish_trigger in self.machine.get_triggers(self.current_state.name): + self.trigger(finish_trigger) + + def _get_first_stage_create_trigger(self): + return list( + filter( + lambda trigger: trigger.startswith("create_"), + self.machine.get_triggers(FSMStates.STARTED.name), + ) + )[0] + def prepare_init(self, event): pass @@ -125,13 +142,3 @@ class FSMMixin: def after_reset(self, event): pass - - def fail_stage(self, stage): - fail_trigger = "fail" + stage - if fail_trigger in self.machine.get_triggers(self.current_state.name): - self.trigger(fail_trigger) - - def finish_stage(self, stage): - finish_trigger = "finish_" + stage - if finish_trigger in self.machine.get_triggers(self.current_state.name): - self.trigger(finish_trigger) diff --git a/atst/models/portfolio_state_machine.py b/atst/models/portfolio_state_machine.py index 3c934197..d7b6a36e 100644 --- a/atst/models/portfolio_state_machine.py +++ b/atst/models/portfolio_state_machine.py @@ -84,13 +84,11 @@ class PortfolioStateMachine( elif self.current_state == FSMStates.STARTED: # get the first trigger that starts with 'create_' - create_trigger = list( - filter( - lambda trigger: trigger.startswith("create_"), - self.machine.get_triggers(FSMStates.STARTED.name), - ) - )[0] - self.trigger(create_trigger) + create_trigger = self._get_first_stage_create_trigger() + if create_trigger: + self.trigger(create_trigger) + else: + self.fail_stage(stage) elif state_obj.is_IN_PROGRESS: pass diff --git a/tests/domain/test_portfolio_state_machine.py b/tests/domain/test_portfolio_state_machine.py index 0aa90867..d0a78fa0 100644 --- a/tests/domain/test_portfolio_state_machine.py +++ b/tests/domain/test_portfolio_state_machine.py @@ -6,6 +6,8 @@ from tests.factories import ( ) from atst.models import FSMStates +from atst.models.mixins.state_machines import AzureStages, StageStates, compose_state +from atst.domain.csp import get_stage_csp_class @pytest.fixture(scope="function") @@ -19,14 +21,67 @@ def test_fsm_creation(portfolio): assert sm.portfolio +def test_state_machine_trigger_next_transition(portfolio): + sm = PortfolioStateMachineFactory.create(portfolio=portfolio) + + sm.trigger_next_transition() + assert sm.current_state == FSMStates.STARTING + + sm.trigger_next_transition() + assert sm.current_state == FSMStates.STARTED + + +def test_state_machine_compose_state(portfolio): + sm = PortfolioStateMachineFactory.create(portfolio=portfolio) + assert ( + compose_state(AzureStages.TENANT, StageStates.CREATED) + == FSMStates.TENANT_CREATED + ) + + +def test_state_machine_first_stage_create_trigger(portfolio): + sm = PortfolioStateMachineFactory.create(portfolio=portfolio) + first_stage_create_trigger = sm._get_first_stage_create_trigger() + first_stage_name = list(AzureStages)[0].name.lower() + assert "create_" + first_stage_name == first_stage_create_trigger + + +def test_state_machine_valid_data_classes_for_stages(portfolio): + sm = PortfolioStateMachineFactory.create(portfolio=portfolio) + for stage in AzureStages: + assert get_stage_csp_class(stage.name.lower(), "payload") is not None + assert get_stage_csp_class(stage.name.lower(), "result") is not None + + +def test_state_machine_initialization(portfolio): + + sm = PortfolioStateMachineFactory.create(portfolio=portfolio) + for stage in AzureStages: + + # check that all stages have a 'create' and 'fail' triggers + stage_name = stage.name.lower() + for trigger_prefix in ["create", "fail"]: + assert hasattr(sm, trigger_prefix + "_" + stage_name) + + # check that machine + in_progress_triggers = sm.machine.get_triggers(stage.name + "_IN_PROGRESS") + assert [ + "reset", + "fail", + "finish_" + stage_name, + "fail_" + stage_name, + ] == in_progress_triggers + + started_triggers = sm.machine.get_triggers("STARTED") + first_stage_create_trigger = sm._get_first_stage_create_trigger() + assert ["reset", "fail", first_stage_create_trigger] == started_triggers + + def test_fsm_transition_start(portfolio): sm = PortfolioStateMachineFactory.create(portfolio=portfolio) assert sm.portfolio assert sm.state == FSMStates.UNSTARTED - # next_state does not create the trigger callbacks !!! - # sm.next_state() - sm.init() assert sm.state == FSMStates.STARTING