Merge pull request #1433 from dod-ccpo/state-machine-unit-tests

state machine unit tests
This commit is contained in:
tomdds 2020-02-20 10:39:23 -05:00 committed by GitHub
commit 08ca8eac79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 187 additions and 111 deletions

View File

@ -139,7 +139,6 @@ class FSMMixin:
def fail_stage(self, stage): def fail_stage(self, stage):
fail_trigger = f"fail_{stage}" fail_trigger = f"fail_{stage}"
if fail_trigger in self.machine.get_triggers(self.current_state.name): if fail_trigger in self.machine.get_triggers(self.current_state.name):
self.trigger(fail_trigger) self.trigger(fail_trigger)
app.logger.info( app.logger.info(
@ -157,30 +156,3 @@ class FSMMixin:
f"calling finish trigger '{finish_trigger}' for '{self.__repr__()}'" f"calling finish trigger '{finish_trigger}' for '{self.__repr__()}'"
) )
self.trigger(finish_trigger) self.trigger(finish_trigger)
def prepare_init(self, event):
pass
def before_init(self, event):
pass
def after_init(self, event):
pass
def prepare_start(self, event):
pass
def before_start(self, event):
pass
def after_start(self, event):
pass
def prepare_reset(self, event):
pass
def before_reset(self, event):
pass
def after_reset(self, event):
pass

View File

@ -1,33 +1,39 @@
import pytest import pytest
import pydantic import pydantic
import re from unittest.mock import Mock, patch
from unittest import mock
from enum import Enum from enum import Enum
from unittest.mock import Mock, patch
import pendulum
import pydantic
import pytest
from tests.factories import ( from tests.factories import (
PortfolioStateMachineFactory, ApplicationFactory,
CLINFactory, CLINFactory,
PortfolioFactory,
PortfolioStateMachineFactory,
TaskOrderFactory,
UserFactory,
get_portfolio_csp_data,
) )
from atst.models import FSMStates, PortfolioStateMachine, TaskOrder from atst.models import FSMStates, PortfolioStateMachine, TaskOrder
from atst.domain.csp.cloud.models import BillingProfileCreationCSPPayload
from atst.models.mixins.state_machines import ( from atst.models.mixins.state_machines import (
AzureStages, AzureStages,
StageStates, StageStates,
compose_state,
_build_transitions,
_build_csp_states, _build_csp_states,
_build_transitions,
compose_state,
) )
from atst.models.portfolio import Portfolio from atst.models.portfolio import Portfolio
from atst.models.portfolio_state_machine import ( from atst.models.portfolio_state_machine import (
get_stage_csp_class,
_stage_to_classname,
_stage_state_to_stage_name,
StateMachineMisconfiguredError, StateMachineMisconfiguredError,
_stage_state_to_stage_name,
_stage_to_classname,
get_stage_csp_class,
) )
# TODO: Write failure case tests # TODO: Write failure case tests
@ -37,26 +43,51 @@ class AzureStagesTest(Enum):
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def portfolio(): def portfolio():
# TODO: setup clin/to as active/funded/ready today = pendulum.today()
portfolio = CLINFactory.create().task_order.portfolio yesterday = today.subtract(days=1)
future = today.add(days=100)
owner = UserFactory.create()
portfolio = PortfolioFactory.create(owner=owner)
ApplicationFactory.create(portfolio=portfolio, environments=[{"name": "dev"}])
TaskOrderFactory.create(
portfolio=portfolio,
signed_at=yesterday,
clins=[CLINFactory.create(start_date=yesterday, end_date=future)],
)
return portfolio return portfolio
@pytest.fixture(scope="function")
def state_machine(portfolio):
return PortfolioStateMachineFactory.create(portfolio=portfolio)
@pytest.mark.state_machine
def test_fsm_creation(portfolio): def test_fsm_creation(portfolio):
sm = PortfolioStateMachineFactory.create(portfolio=portfolio) sm = PortfolioStateMachineFactory.create(portfolio=portfolio)
assert sm.portfolio assert sm.portfolio
def test_state_machine_trigger_next_transition(portfolio): @pytest.mark.state_machine
sm = PortfolioStateMachineFactory.create(portfolio=portfolio) def test_state_machine_trigger_next_transition(state_machine):
sm.trigger_next_transition() state_machine.trigger_next_transition()
assert sm.current_state == FSMStates.STARTING assert state_machine.current_state == FSMStates.STARTING
sm.trigger_next_transition() state_machine.trigger_next_transition()
assert sm.current_state == FSMStates.STARTED assert state_machine.current_state == FSMStates.STARTED
state_machine.state = FSMStates.STARTED
state_machine.trigger_next_transition(
csp_data=get_portfolio_csp_data(state_machine.portfolio)
)
assert state_machine.current_state == FSMStates.TENANT_CREATED
@pytest.mark.state_machine
def test_state_machine_compose_state(): def test_state_machine_compose_state():
assert ( assert (
compose_state(AzureStages.TENANT, StageStates.CREATED) compose_state(AzureStages.TENANT, StageStates.CREATED)
@ -64,6 +95,7 @@ def test_state_machine_compose_state():
) )
@pytest.mark.state_machine
def test_stage_to_classname(): def test_stage_to_classname():
assert ( assert (
_stage_to_classname(AzureStages.BILLING_PROFILE_CREATION.name) _stage_to_classname(AzureStages.BILLING_PROFILE_CREATION.name)
@ -71,16 +103,19 @@ def test_stage_to_classname():
) )
@pytest.mark.state_machine
def test_get_stage_csp_class(): def test_get_stage_csp_class():
csp_class = get_stage_csp_class(list(AzureStages)[0].name.lower(), "payload") csp_class = get_stage_csp_class(list(AzureStages)[0].name.lower(), "payload")
assert isinstance(csp_class, pydantic.main.ModelMetaclass) assert isinstance(csp_class, pydantic.main.ModelMetaclass)
@pytest.mark.state_machine
def test_get_stage_csp_class_import_fail(): def test_get_stage_csp_class_import_fail():
with pytest.raises(StateMachineMisconfiguredError): with pytest.raises(StateMachineMisconfiguredError):
csp_class = get_stage_csp_class("doesnotexist", "payload") csp_class = get_stage_csp_class("doesnotexist", "payload")
@pytest.mark.state_machine
def test_build_transitions(): def test_build_transitions():
states, transitions = _build_transitions(AzureStagesTest) states, transitions = _build_transitions(AzureStagesTest)
assert [s.get("name").name for s in states] == [ assert [s.get("name").name for s in states] == [
@ -102,6 +137,7 @@ def test_build_transitions():
] ]
@pytest.mark.state_machine
def test_build_csp_states(): def test_build_csp_states():
states = _build_csp_states(AzureStagesTest) states = _build_csp_states(AzureStagesTest)
assert list(states) == [ assert list(states) == [
@ -116,18 +152,18 @@ def test_build_csp_states():
] ]
def test_state_machine_valid_data_classes_for_stages(portfolio): @pytest.mark.state_machine
PortfolioStateMachineFactory.create(portfolio=portfolio) def test_state_machine_valid_data_classes_for_stages():
for stage in AzureStages: for stage in AzureStages:
assert get_stage_csp_class(stage.name.lower(), "payload") is not None assert get_stage_csp_class(stage.name.lower(), "payload") is not None
assert get_stage_csp_class(stage.name.lower(), "result") is not None assert get_stage_csp_class(stage.name.lower(), "result") is not None
def test_attach_machine(portfolio): @pytest.mark.state_machine
sm = PortfolioStateMachineFactory.create(portfolio=portfolio) def test_attach_machine(state_machine):
sm.machine = None state_machine.machine = None
sm.attach_machine(stages=AzureStagesTest) state_machine.attach_machine(stages=AzureStagesTest)
assert list(sm.machine.events) == [ assert list(state_machine.machine.events) == [
"init", "init",
"start", "start",
"reset", "reset",
@ -140,13 +176,69 @@ def test_attach_machine(portfolio):
] ]
def test_fail_stage(portfolio): @pytest.mark.state_machine
sm = PortfolioStateMachineFactory.create(portfolio=portfolio) def test_current_state_property(state_machine):
sm.state = FSMStates.TENANT_IN_PROGRESS assert state_machine.current_state == FSMStates.UNSTARTED
sm.fail_stage("tenant") state_machine.state = FSMStates.TENANT_IN_PROGRESS
assert sm.state == FSMStates.TENANT_FAILED assert state_machine.current_state == FSMStates.TENANT_IN_PROGRESS
state_machine.state = "UNSTARTED"
assert state_machine.current_state == FSMStates.UNSTARTED
@pytest.mark.state_machine
def test_fail_stage(state_machine):
state_machine.state = FSMStates.TENANT_IN_PROGRESS
state_machine.portfolio.csp_data = {}
state_machine.fail_stage("tenant")
assert state_machine.state == FSMStates.TENANT_FAILED
@pytest.mark.state_machine
def test_fail_stage_invalid_triggers(state_machine):
state_machine.state = FSMStates.TENANT_IN_PROGRESS
state_machine.portfolio.csp_data = {}
state_machine.machine.get_triggers = Mock(return_value=["some", "triggers", "here"])
state_machine.fail_stage("tenant")
assert state_machine.state == FSMStates.TENANT_IN_PROGRESS
@pytest.mark.state_machine
def test_fail_stage_invalid_stage(state_machine):
state_machine.state = FSMStates.TENANT_IN_PROGRESS
state_machine.portfolio.csp_data = {}
portfolio.csp_data = {}
state_machine.fail_stage("invalid stage")
assert state_machine.state == FSMStates.TENANT_IN_PROGRESS
@pytest.mark.state_machine
def test_finish_stage(state_machine):
state_machine.state = FSMStates.TENANT_IN_PROGRESS
state_machine.portfolio.csp_data = {}
state_machine.finish_stage("tenant")
assert state_machine.state == FSMStates.TENANT_CREATED
@pytest.mark.state_machine
def test_finish_stage_invalid_triggers(state_machine):
state_machine.state = FSMStates.TENANT_IN_PROGRESS
state_machine.portfolio.csp_data = {}
state_machine.machine.get_triggers = Mock(return_value=["some", "triggers", "here"])
state_machine.finish_stage("tenant")
assert state_machine.state == FSMStates.TENANT_IN_PROGRESS
@pytest.mark.state_machine
def test_finish_stage_invalid_stage(state_machine):
state_machine.state = FSMStates.TENANT_IN_PROGRESS
state_machine.portfolio.csp_data = {}
portfolio.csp_data = {}
state_machine.finish_stage("invalid stage")
assert state_machine.state == FSMStates.TENANT_IN_PROGRESS
@pytest.mark.state_machine
def test_stage_state_to_stage_name(): def test_stage_state_to_stage_name():
stage = _stage_state_to_stage_name( stage = _stage_state_to_stage_name(
FSMStates.TENANT_IN_PROGRESS, StageStates.IN_PROGRESS FSMStates.TENANT_IN_PROGRESS, StageStates.IN_PROGRESS
@ -154,18 +246,19 @@ def test_stage_state_to_stage_name():
assert stage == "tenant" assert stage == "tenant"
def test_state_machine_initialization(portfolio): @pytest.mark.state_machine
def test_state_machine_initialization(state_machine):
sm = PortfolioStateMachineFactory.create(portfolio=portfolio)
for stage in AzureStages: for stage in AzureStages:
# check that all stages have a 'create' and 'fail' triggers # check that all stages have a 'create' and 'fail' triggers
stage_name = stage.name.lower() stage_name = stage.name.lower()
for trigger_prefix in ["create", "fail"]: for trigger_prefix in ["create", "fail"]:
assert hasattr(sm, trigger_prefix + "_" + stage_name) assert hasattr(state_machine, trigger_prefix + "_" + stage_name)
# check that machine # check that machine
in_progress_triggers = sm.machine.get_triggers(stage.name + "_IN_PROGRESS") in_progress_triggers = state_machine.machine.get_triggers(
stage.name + "_IN_PROGRESS"
)
assert [ assert [
"reset", "reset",
"fail", "fail",
@ -173,32 +266,29 @@ def test_state_machine_initialization(portfolio):
"fail_" + stage_name, "fail_" + stage_name,
] == in_progress_triggers ] == in_progress_triggers
started_triggers = sm.machine.get_triggers("STARTED") started_triggers = state_machine.machine.get_triggers("STARTED")
create_trigger = next( create_trigger = next(
filter( filter(
lambda trigger: trigger.startswith("create_"), lambda trigger: trigger.startswith("create_"),
sm.machine.get_triggers(FSMStates.STARTED.name), state_machine.machine.get_triggers(FSMStates.STARTED.name),
), ),
None, None,
) )
assert ["reset", "fail", create_trigger] == started_triggers assert ["reset", "fail", create_trigger] == started_triggers
@mock.patch("atst.domain.csp.cloud.MockCloudProvider") @pytest.mark.state_machine
def test_fsm_transition_start(mock_cloud_provider, portfolio: Portfolio): @patch("atst.domain.csp.cloud.MockCloudProvider")
def test_fsm_transition_start(
mock_cloud_provider, state_machine: PortfolioStateMachine
):
mock_cloud_provider._authorize.return_value = None mock_cloud_provider._authorize.return_value = None
mock_cloud_provider._maybe_raise.return_value = None mock_cloud_provider._maybe_raise.return_value = None
sm: PortfolioStateMachine = PortfolioStateMachineFactory.create(portfolio=portfolio) portfolio = state_machine.portfolio
assert sm.portfolio
assert sm.state == FSMStates.UNSTARTED
sm.init()
assert sm.state == FSMStates.STARTING
sm.start()
assert sm.state == FSMStates.STARTED
expected_states = [ expected_states = [
FSMStates.STARTING,
FSMStates.STARTED,
FSMStates.TENANT_CREATED, FSMStates.TENANT_CREATED,
FSMStates.BILLING_PROFILE_CREATION_CREATED, FSMStates.BILLING_PROFILE_CREATION_CREATED,
FSMStates.BILLING_PROFILE_VERIFICATION_CREATED, FSMStates.BILLING_PROFILE_VERIFICATION_CREATED,
@ -226,47 +316,17 @@ def test_fsm_transition_start(mock_cloud_provider, portfolio: Portfolio):
else: else:
csp_data = {} csp_data = {}
ppoc = portfolio.owner
user_id = f"{ppoc.first_name[0]}{ppoc.last_name}".lower()
domain_name = re.sub("[^0-9a-zA-Z]+", "", portfolio.name).lower()
initial_task_order: TaskOrder = portfolio.task_orders[0]
initial_clin = initial_task_order.sorted_clins[0]
portfolio_data = {
"user_id": user_id,
"password": "jklfsdNCVD83nklds2#202", # pragma: allowlist secret
"domain_name": domain_name,
"display_name": "mgmt group display name",
"management_group_name": "mgmt-group-uuid",
"first_name": ppoc.first_name,
"last_name": ppoc.last_name,
"country_code": "US",
"password_recovery_email_address": ppoc.email,
"address": { # TODO: TBD if we're sourcing this from data or config
"company_name": "",
"address_line_1": "",
"city": "",
"region": "",
"country": "",
"postal_code": "",
},
"billing_profile_display_name": "My Billing Profile",
"initial_clin_amount": initial_clin.obligated_amount,
"initial_clin_start_date": initial_clin.start_date.strftime("%Y/%m/%d"),
"initial_clin_end_date": initial_clin.end_date.strftime("%Y/%m/%d"),
"initial_clin_type": initial_clin.number,
"initial_task_order_id": initial_task_order.number,
}
config = {"billing_account_name": "billing_account_name"} config = {"billing_account_name": "billing_account_name"}
assert state_machine.state == FSMStates.UNSTARTED
portfolio_data = get_portfolio_csp_data(portfolio)
for expected_state in expected_states: for expected_state in expected_states:
collected_data = dict( collected_data = dict(
list(csp_data.items()) + list(portfolio_data.items()) + list(config.items()) list(csp_data.items()) + list(portfolio_data.items()) + list(config.items())
) )
sm.trigger_next_transition(csp_data=collected_data) state_machine.trigger_next_transition(csp_data=collected_data)
assert sm.state == expected_state assert state_machine.state == expected_state
if portfolio.csp_data is not None: if portfolio.csp_data is not None:
csp_data = portfolio.csp_data csp_data = portfolio.csp_data
else: else:

View File

@ -1,3 +1,4 @@
import re
import operator import operator
import random import random
import string import string
@ -78,6 +79,49 @@ def get_all_portfolio_permission_sets():
return PermissionSets.get_many(PortfolioRoles.PORTFOLIO_PERMISSION_SETS) return PermissionSets.get_many(PortfolioRoles.PORTFOLIO_PERMISSION_SETS)
def get_portfolio_csp_data(portfolio):
ppoc = portfolio.owner
if not ppoc:
class ppoc:
first_name = "John"
last_name = "Doe"
email = "email@example.com"
user_id = f"{ppoc.first_name[0]}{ppoc.last_name}".lower()
domain_name = re.sub("[^0-9a-zA-Z]+", "", portfolio.name).lower()
initial_task_order: TaskOrder = portfolio.task_orders[0]
initial_clin = initial_task_order.sorted_clins[0]
return {
"user_id": user_id,
"password": "jklfsdNCVD83nklds2#202", # pragma: allowlist secret
"domain_name": domain_name,
"display_name": "mgmt group display name",
"management_group_name": "mgmt-group-uuid",
"first_name": ppoc.first_name,
"last_name": ppoc.last_name,
"country_code": "US",
"password_recovery_email_address": ppoc.email,
"address": { # TODO: TBD if we're sourcing this from data or config
"company_name": "",
"address_line_1": "",
"city": "",
"region": "",
"country": "",
"postal_code": "",
},
"billing_profile_display_name": "My Billing Profile",
"initial_clin_amount": initial_clin.obligated_amount,
"initial_clin_start_date": initial_clin.start_date.strftime("%Y/%m/%d"),
"initial_clin_end_date": initial_clin.end_date.strftime("%Y/%m/%d"),
"initial_clin_type": initial_clin.number,
"initial_task_order_id": initial_task_order.number,
}
class Base(factory.alchemy.SQLAlchemyModelFactory): class Base(factory.alchemy.SQLAlchemyModelFactory):
@classmethod @classmethod
def dictionary(cls, **attrs): def dictionary(cls, **attrs):