state machine unit tests

This commit is contained in:
Philip Kalinsky 2020-02-18 15:36:15 -05:00
parent e3397390d3
commit daf4a0ba68
2 changed files with 92 additions and 57 deletions

View File

@ -139,7 +139,6 @@ class FSMMixin:
def fail_stage(self, stage): def fail_stage(self, stage):
fail_trigger = f"fail_{stage}" fail_trigger = f"fail_{stage}"
if fail_trigger in self.machine.get_triggers(self.current_state.name): if fail_trigger in self.machine.get_triggers(self.current_state.name):
self.trigger(fail_trigger) self.trigger(fail_trigger)
app.logger.info( app.logger.info(
@ -158,29 +157,3 @@ class FSMMixin:
) )
self.trigger(finish_trigger) self.trigger(finish_trigger)
def prepare_init(self, event):
pass
def before_init(self, event):
pass
def after_init(self, event):
pass
def prepare_start(self, event):
pass
def before_start(self, event):
pass
def after_start(self, event):
pass
def prepare_reset(self, event):
pass
def before_reset(self, event):
pass
def after_reset(self, event):
pass

View File

@ -2,7 +2,7 @@ import pytest
import pydantic import pydantic
import re import re
from unittest import mock from unittest.mock import Mock, patch
from enum import Enum from enum import Enum
from tests.factories import ( from tests.factories import (
@ -41,12 +41,19 @@ def portfolio():
portfolio = CLINFactory.create().task_order.portfolio portfolio = CLINFactory.create().task_order.portfolio
return portfolio return portfolio
@pytest.fixture(scope="function")
def state_machine():
# TODO: setup clin/to as active/funded/ready
portfolio = CLINFactory.create().task_order.portfolio
return PortfolioStateMachineFactory.create(portfolio=portfolio)
@pytest.mark.state_machine
def test_fsm_creation(portfolio): def test_fsm_creation(portfolio):
sm = PortfolioStateMachineFactory.create(portfolio=portfolio) sm = PortfolioStateMachineFactory.create(portfolio=portfolio)
assert sm.portfolio assert sm.portfolio
@pytest.mark.state_machine
def test_state_machine_trigger_next_transition(portfolio): def test_state_machine_trigger_next_transition(portfolio):
sm = PortfolioStateMachineFactory.create(portfolio=portfolio) sm = PortfolioStateMachineFactory.create(portfolio=portfolio)
@ -57,6 +64,7 @@ def test_state_machine_trigger_next_transition(portfolio):
assert sm.current_state == FSMStates.STARTED assert sm.current_state == FSMStates.STARTED
@pytest.mark.state_machine
def test_state_machine_compose_state(): def test_state_machine_compose_state():
assert ( assert (
compose_state(AzureStages.TENANT, StageStates.CREATED) compose_state(AzureStages.TENANT, StageStates.CREATED)
@ -64,6 +72,7 @@ def test_state_machine_compose_state():
) )
@pytest.mark.state_machine
def test_stage_to_classname(): def test_stage_to_classname():
assert ( assert (
_stage_to_classname(AzureStages.BILLING_PROFILE_CREATION.name) _stage_to_classname(AzureStages.BILLING_PROFILE_CREATION.name)
@ -71,16 +80,19 @@ def test_stage_to_classname():
) )
@pytest.mark.state_machine
def test_get_stage_csp_class(): def test_get_stage_csp_class():
csp_class = get_stage_csp_class(list(AzureStages)[0].name.lower(), "payload") csp_class = get_stage_csp_class(list(AzureStages)[0].name.lower(), "payload")
assert isinstance(csp_class, pydantic.main.ModelMetaclass) assert isinstance(csp_class, pydantic.main.ModelMetaclass)
@pytest.mark.state_machine
def test_get_stage_csp_class_import_fail(): def test_get_stage_csp_class_import_fail():
with pytest.raises(StateMachineMisconfiguredError): with pytest.raises(StateMachineMisconfiguredError):
csp_class = get_stage_csp_class("doesnotexist", "payload") csp_class = get_stage_csp_class("doesnotexist", "payload")
@pytest.mark.state_machine
def test_build_transitions(): def test_build_transitions():
states, transitions = _build_transitions(AzureStagesTest) states, transitions = _build_transitions(AzureStagesTest)
assert [s.get("name").name for s in states] == [ assert [s.get("name").name for s in states] == [
@ -102,6 +114,7 @@ def test_build_transitions():
] ]
@pytest.mark.state_machine
def test_build_csp_states(): def test_build_csp_states():
states = _build_csp_states(AzureStagesTest) states = _build_csp_states(AzureStagesTest)
assert list(states) == [ assert list(states) == [
@ -116,18 +129,18 @@ def test_build_csp_states():
] ]
def test_state_machine_valid_data_classes_for_stages(portfolio): @pytest.mark.state_machine
PortfolioStateMachineFactory.create(portfolio=portfolio) def test_state_machine_valid_data_classes_for_stages():
for stage in AzureStages: for stage in AzureStages:
assert get_stage_csp_class(stage.name.lower(), "payload") is not None assert get_stage_csp_class(stage.name.lower(), "payload") is not None
assert get_stage_csp_class(stage.name.lower(), "result") is not None assert get_stage_csp_class(stage.name.lower(), "result") is not None
def test_attach_machine(portfolio): @pytest.mark.state_machine
sm = PortfolioStateMachineFactory.create(portfolio=portfolio) def test_attach_machine(state_machine):
sm.machine = None state_machine.machine = None
sm.attach_machine(stages=AzureStagesTest) state_machine.attach_machine(stages=AzureStagesTest)
assert list(sm.machine.events) == [ assert list(state_machine.machine.events) == [
"init", "init",
"start", "start",
"reset", "reset",
@ -139,14 +152,60 @@ def test_attach_machine(portfolio):
"resume_progress_tenant", "resume_progress_tenant",
] ]
@pytest.mark.state_machine
def test_current_state_property(state_machine):
assert state_machine.current_state == FSMStates.UNSTARTED
state_machine.state = FSMStates.TENANT_IN_PROGRESS
assert state_machine.current_state == FSMStates.TENANT_IN_PROGRESS
def test_fail_stage(portfolio): @pytest.mark.state_machine
sm = PortfolioStateMachineFactory.create(portfolio=portfolio) def test_fail_stage(state_machine):
sm.state = FSMStates.TENANT_IN_PROGRESS state_machine.state = FSMStates.TENANT_IN_PROGRESS
sm.fail_stage("tenant") state_machine.portfolio.csp_data = {}
assert sm.state == FSMStates.TENANT_FAILED state_machine.fail_stage("tenant")
assert state_machine.state == FSMStates.TENANT_FAILED
@pytest.mark.state_machine
def test_fail_stage_invalid_triggers(state_machine):
state_machine.state = FSMStates.TENANT_IN_PROGRESS
state_machine.portfolio.csp_data = {}
state_machine.machine.get_triggers = Mock(return_value=['some', 'triggers', 'here'])
state_machine.fail_stage("tenant")
assert state_machine.state == FSMStates.TENANT_IN_PROGRESS
@pytest.mark.state_machine
def test_fail_stage_invalid_stage(state_machine):
state_machine.state = FSMStates.TENANT_IN_PROGRESS
state_machine.portfolio.csp_data = {}
portfolio.csp_data = {}
state_machine.fail_stage("invalid stage")
assert state_machine.state == FSMStates.TENANT_IN_PROGRESS
@pytest.mark.state_machine
def test_finish_stage(state_machine):
state_machine.state = FSMStates.TENANT_IN_PROGRESS
state_machine.portfolio.csp_data = {}
state_machine.finish_stage("tenant")
assert state_machine.state == FSMStates.TENANT_CREATED
@pytest.mark.state_machine
def test_finish_stage_invalid_triggers(state_machine):
state_machine.state = FSMStates.TENANT_IN_PROGRESS
state_machine.portfolio.csp_data = {}
state_machine.machine.get_triggers = Mock(return_value=['some', 'triggers', 'here'])
state_machine.finish_stage("tenant")
assert state_machine.state == FSMStates.TENANT_IN_PROGRESS
@pytest.mark.state_machine
def test_finish_stage_invalid_stage(state_machine):
state_machine.state = FSMStates.TENANT_IN_PROGRESS
state_machine.portfolio.csp_data = {}
portfolio.csp_data = {}
state_machine.finish_stage("invalid stage")
assert state_machine.state == FSMStates.TENANT_IN_PROGRESS
@pytest.mark.state_machine
def test_stage_state_to_stage_name(): def test_stage_state_to_stage_name():
stage = _stage_state_to_stage_name( stage = _stage_state_to_stage_name(
FSMStates.TENANT_IN_PROGRESS, StageStates.IN_PROGRESS FSMStates.TENANT_IN_PROGRESS, StageStates.IN_PROGRESS
@ -154,18 +213,17 @@ def test_stage_state_to_stage_name():
assert stage == "tenant" assert stage == "tenant"
def test_state_machine_initialization(portfolio): @pytest.mark.state_machine
def test_state_machine_initialization(state_machine):
sm = PortfolioStateMachineFactory.create(portfolio=portfolio)
for stage in AzureStages: for stage in AzureStages:
# check that all stages have a 'create' and 'fail' triggers # check that all stages have a 'create' and 'fail' triggers
stage_name = stage.name.lower() stage_name = stage.name.lower()
for trigger_prefix in ["create", "fail"]: for trigger_prefix in ["create", "fail"]:
assert hasattr(sm, trigger_prefix + "_" + stage_name) assert hasattr(state_machine, trigger_prefix + "_" + stage_name)
# check that machine # check that machine
in_progress_triggers = sm.machine.get_triggers(stage.name + "_IN_PROGRESS") in_progress_triggers = state_machine.machine.get_triggers(stage.name + "_IN_PROGRESS")
assert [ assert [
"reset", "reset",
"fail", "fail",
@ -173,32 +231,26 @@ def test_state_machine_initialization(portfolio):
"fail_" + stage_name, "fail_" + stage_name,
] == in_progress_triggers ] == in_progress_triggers
started_triggers = sm.machine.get_triggers("STARTED") started_triggers = state_machine.machine.get_triggers("STARTED")
create_trigger = next( create_trigger = next(
filter( filter(
lambda trigger: trigger.startswith("create_"), lambda trigger: trigger.startswith("create_"),
sm.machine.get_triggers(FSMStates.STARTED.name), state_machine.machine.get_triggers(FSMStates.STARTED.name),
), ),
None, None,
) )
assert ["reset", "fail", create_trigger] == started_triggers assert ["reset", "fail", create_trigger] == started_triggers
@mock.patch("atst.domain.csp.cloud.MockCloudProvider") @pytest.mark.state_machine
@patch("atst.domain.csp.cloud.MockCloudProvider")
def test_fsm_transition_start(mock_cloud_provider, portfolio: Portfolio): def test_fsm_transition_start(mock_cloud_provider, portfolio: Portfolio):
mock_cloud_provider._authorize.return_value = None mock_cloud_provider._authorize.return_value = None
mock_cloud_provider._maybe_raise.return_value = None mock_cloud_provider._maybe_raise.return_value = None
sm: PortfolioStateMachine = PortfolioStateMachineFactory.create(portfolio=portfolio)
assert sm.portfolio
assert sm.state == FSMStates.UNSTARTED
sm.init()
assert sm.state == FSMStates.STARTING
sm.start()
assert sm.state == FSMStates.STARTED
expected_states = [ expected_states = [
FSMStates.STARTING,
FSMStates.STARTED,
FSMStates.TENANT_CREATED, FSMStates.TENANT_CREATED,
FSMStates.BILLING_PROFILE_CREATION_CREATED, FSMStates.BILLING_PROFILE_CREATION_CREATED,
FSMStates.BILLING_PROFILE_VERIFICATION_CREATED, FSMStates.BILLING_PROFILE_VERIFICATION_CREATED,
@ -227,6 +279,12 @@ def test_fsm_transition_start(mock_cloud_provider, portfolio: Portfolio):
csp_data = {} csp_data = {}
ppoc = portfolio.owner ppoc = portfolio.owner
if not ppoc:
class ppoc:
first_name = "John"
last_name = "Doe"
email = "email@example.com"
user_id = f"{ppoc.first_name[0]}{ppoc.last_name}".lower() user_id = f"{ppoc.first_name[0]}{ppoc.last_name}".lower()
domain_name = re.sub("[^0-9a-zA-Z]+", "", portfolio.name).lower() domain_name = re.sub("[^0-9a-zA-Z]+", "", portfolio.name).lower()
@ -261,6 +319,10 @@ def test_fsm_transition_start(mock_cloud_provider, portfolio: Portfolio):
config = {"billing_account_name": "billing_account_name"} config = {"billing_account_name": "billing_account_name"}
sm: PortfolioStateMachine = PortfolioStateMachineFactory.create(portfolio=portfolio)
assert sm.portfolio
assert sm.state == FSMStates.UNSTARTED
for expected_state in expected_states: for expected_state in expected_states:
collected_data = dict( collected_data = dict(
list(csp_data.items()) + list(portfolio_data.items()) + list(config.items()) list(csp_data.items()) + list(portfolio_data.items()) + list(config.items())