Merge pull request #1433 from dod-ccpo/state-machine-unit-tests
state machine unit tests
This commit is contained in:
commit
08ca8eac79
@ -139,7 +139,6 @@ class FSMMixin:
|
||||
|
||||
def fail_stage(self, stage):
|
||||
fail_trigger = f"fail_{stage}"
|
||||
|
||||
if fail_trigger in self.machine.get_triggers(self.current_state.name):
|
||||
self.trigger(fail_trigger)
|
||||
app.logger.info(
|
||||
@ -157,30 +156,3 @@ class FSMMixin:
|
||||
f"calling finish trigger '{finish_trigger}' for '{self.__repr__()}'"
|
||||
)
|
||||
self.trigger(finish_trigger)
|
||||
|
||||
def prepare_init(self, event):
|
||||
pass
|
||||
|
||||
def before_init(self, event):
|
||||
pass
|
||||
|
||||
def after_init(self, event):
|
||||
pass
|
||||
|
||||
def prepare_start(self, event):
|
||||
pass
|
||||
|
||||
def before_start(self, event):
|
||||
pass
|
||||
|
||||
def after_start(self, event):
|
||||
pass
|
||||
|
||||
def prepare_reset(self, event):
|
||||
pass
|
||||
|
||||
def before_reset(self, event):
|
||||
pass
|
||||
|
||||
def after_reset(self, event):
|
||||
pass
|
||||
|
@ -1,33 +1,39 @@
|
||||
import pytest
|
||||
import pydantic
|
||||
|
||||
import re
|
||||
from unittest import mock
|
||||
from unittest.mock import Mock, patch
|
||||
from enum import Enum
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pendulum
|
||||
import pydantic
|
||||
import pytest
|
||||
from tests.factories import (
|
||||
PortfolioStateMachineFactory,
|
||||
ApplicationFactory,
|
||||
CLINFactory,
|
||||
PortfolioFactory,
|
||||
PortfolioStateMachineFactory,
|
||||
TaskOrderFactory,
|
||||
UserFactory,
|
||||
get_portfolio_csp_data,
|
||||
)
|
||||
|
||||
from atst.models import FSMStates, PortfolioStateMachine, TaskOrder
|
||||
from atst.domain.csp.cloud.models import BillingProfileCreationCSPPayload
|
||||
from atst.models.mixins.state_machines import (
|
||||
AzureStages,
|
||||
StageStates,
|
||||
compose_state,
|
||||
_build_transitions,
|
||||
_build_csp_states,
|
||||
_build_transitions,
|
||||
compose_state,
|
||||
)
|
||||
from atst.models.portfolio import Portfolio
|
||||
from atst.models.portfolio_state_machine import (
|
||||
get_stage_csp_class,
|
||||
_stage_to_classname,
|
||||
_stage_state_to_stage_name,
|
||||
StateMachineMisconfiguredError,
|
||||
_stage_state_to_stage_name,
|
||||
_stage_to_classname,
|
||||
get_stage_csp_class,
|
||||
)
|
||||
|
||||
|
||||
# TODO: Write failure case tests
|
||||
|
||||
|
||||
@ -37,26 +43,51 @@ class AzureStagesTest(Enum):
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def portfolio():
|
||||
# TODO: setup clin/to as active/funded/ready
|
||||
portfolio = CLINFactory.create().task_order.portfolio
|
||||
today = pendulum.today()
|
||||
yesterday = today.subtract(days=1)
|
||||
future = today.add(days=100)
|
||||
|
||||
owner = UserFactory.create()
|
||||
portfolio = PortfolioFactory.create(owner=owner)
|
||||
ApplicationFactory.create(portfolio=portfolio, environments=[{"name": "dev"}])
|
||||
|
||||
TaskOrderFactory.create(
|
||||
portfolio=portfolio,
|
||||
signed_at=yesterday,
|
||||
clins=[CLINFactory.create(start_date=yesterday, end_date=future)],
|
||||
)
|
||||
|
||||
return portfolio
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def state_machine(portfolio):
|
||||
return PortfolioStateMachineFactory.create(portfolio=portfolio)
|
||||
|
||||
|
||||
@pytest.mark.state_machine
|
||||
def test_fsm_creation(portfolio):
|
||||
sm = PortfolioStateMachineFactory.create(portfolio=portfolio)
|
||||
assert sm.portfolio
|
||||
|
||||
|
||||
def test_state_machine_trigger_next_transition(portfolio):
|
||||
sm = PortfolioStateMachineFactory.create(portfolio=portfolio)
|
||||
@pytest.mark.state_machine
|
||||
def test_state_machine_trigger_next_transition(state_machine):
|
||||
|
||||
sm.trigger_next_transition()
|
||||
assert sm.current_state == FSMStates.STARTING
|
||||
state_machine.trigger_next_transition()
|
||||
assert state_machine.current_state == FSMStates.STARTING
|
||||
|
||||
sm.trigger_next_transition()
|
||||
assert sm.current_state == FSMStates.STARTED
|
||||
state_machine.trigger_next_transition()
|
||||
assert state_machine.current_state == FSMStates.STARTED
|
||||
|
||||
state_machine.state = FSMStates.STARTED
|
||||
state_machine.trigger_next_transition(
|
||||
csp_data=get_portfolio_csp_data(state_machine.portfolio)
|
||||
)
|
||||
assert state_machine.current_state == FSMStates.TENANT_CREATED
|
||||
|
||||
|
||||
@pytest.mark.state_machine
|
||||
def test_state_machine_compose_state():
|
||||
assert (
|
||||
compose_state(AzureStages.TENANT, StageStates.CREATED)
|
||||
@ -64,6 +95,7 @@ def test_state_machine_compose_state():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.state_machine
|
||||
def test_stage_to_classname():
|
||||
assert (
|
||||
_stage_to_classname(AzureStages.BILLING_PROFILE_CREATION.name)
|
||||
@ -71,16 +103,19 @@ def test_stage_to_classname():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.state_machine
|
||||
def test_get_stage_csp_class():
|
||||
csp_class = get_stage_csp_class(list(AzureStages)[0].name.lower(), "payload")
|
||||
assert isinstance(csp_class, pydantic.main.ModelMetaclass)
|
||||
|
||||
|
||||
@pytest.mark.state_machine
|
||||
def test_get_stage_csp_class_import_fail():
|
||||
with pytest.raises(StateMachineMisconfiguredError):
|
||||
csp_class = get_stage_csp_class("doesnotexist", "payload")
|
||||
|
||||
|
||||
@pytest.mark.state_machine
|
||||
def test_build_transitions():
|
||||
states, transitions = _build_transitions(AzureStagesTest)
|
||||
assert [s.get("name").name for s in states] == [
|
||||
@ -102,6 +137,7 @@ def test_build_transitions():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.state_machine
|
||||
def test_build_csp_states():
|
||||
states = _build_csp_states(AzureStagesTest)
|
||||
assert list(states) == [
|
||||
@ -116,18 +152,18 @@ def test_build_csp_states():
|
||||
]
|
||||
|
||||
|
||||
def test_state_machine_valid_data_classes_for_stages(portfolio):
|
||||
PortfolioStateMachineFactory.create(portfolio=portfolio)
|
||||
@pytest.mark.state_machine
|
||||
def test_state_machine_valid_data_classes_for_stages():
|
||||
for stage in AzureStages:
|
||||
assert get_stage_csp_class(stage.name.lower(), "payload") is not None
|
||||
assert get_stage_csp_class(stage.name.lower(), "result") is not None
|
||||
|
||||
|
||||
def test_attach_machine(portfolio):
|
||||
sm = PortfolioStateMachineFactory.create(portfolio=portfolio)
|
||||
sm.machine = None
|
||||
sm.attach_machine(stages=AzureStagesTest)
|
||||
assert list(sm.machine.events) == [
|
||||
@pytest.mark.state_machine
|
||||
def test_attach_machine(state_machine):
|
||||
state_machine.machine = None
|
||||
state_machine.attach_machine(stages=AzureStagesTest)
|
||||
assert list(state_machine.machine.events) == [
|
||||
"init",
|
||||
"start",
|
||||
"reset",
|
||||
@ -140,13 +176,69 @@ def test_attach_machine(portfolio):
|
||||
]
|
||||
|
||||
|
||||
def test_fail_stage(portfolio):
|
||||
sm = PortfolioStateMachineFactory.create(portfolio=portfolio)
|
||||
sm.state = FSMStates.TENANT_IN_PROGRESS
|
||||
sm.fail_stage("tenant")
|
||||
assert sm.state == FSMStates.TENANT_FAILED
|
||||
@pytest.mark.state_machine
|
||||
def test_current_state_property(state_machine):
|
||||
assert state_machine.current_state == FSMStates.UNSTARTED
|
||||
state_machine.state = FSMStates.TENANT_IN_PROGRESS
|
||||
assert state_machine.current_state == FSMStates.TENANT_IN_PROGRESS
|
||||
state_machine.state = "UNSTARTED"
|
||||
assert state_machine.current_state == FSMStates.UNSTARTED
|
||||
|
||||
|
||||
@pytest.mark.state_machine
|
||||
def test_fail_stage(state_machine):
|
||||
state_machine.state = FSMStates.TENANT_IN_PROGRESS
|
||||
state_machine.portfolio.csp_data = {}
|
||||
state_machine.fail_stage("tenant")
|
||||
assert state_machine.state == FSMStates.TENANT_FAILED
|
||||
|
||||
|
||||
@pytest.mark.state_machine
|
||||
def test_fail_stage_invalid_triggers(state_machine):
|
||||
state_machine.state = FSMStates.TENANT_IN_PROGRESS
|
||||
state_machine.portfolio.csp_data = {}
|
||||
state_machine.machine.get_triggers = Mock(return_value=["some", "triggers", "here"])
|
||||
state_machine.fail_stage("tenant")
|
||||
assert state_machine.state == FSMStates.TENANT_IN_PROGRESS
|
||||
|
||||
|
||||
@pytest.mark.state_machine
|
||||
def test_fail_stage_invalid_stage(state_machine):
|
||||
state_machine.state = FSMStates.TENANT_IN_PROGRESS
|
||||
state_machine.portfolio.csp_data = {}
|
||||
portfolio.csp_data = {}
|
||||
state_machine.fail_stage("invalid stage")
|
||||
assert state_machine.state == FSMStates.TENANT_IN_PROGRESS
|
||||
|
||||
|
||||
@pytest.mark.state_machine
|
||||
def test_finish_stage(state_machine):
|
||||
state_machine.state = FSMStates.TENANT_IN_PROGRESS
|
||||
state_machine.portfolio.csp_data = {}
|
||||
state_machine.finish_stage("tenant")
|
||||
assert state_machine.state == FSMStates.TENANT_CREATED
|
||||
|
||||
|
||||
@pytest.mark.state_machine
|
||||
def test_finish_stage_invalid_triggers(state_machine):
|
||||
state_machine.state = FSMStates.TENANT_IN_PROGRESS
|
||||
state_machine.portfolio.csp_data = {}
|
||||
|
||||
state_machine.machine.get_triggers = Mock(return_value=["some", "triggers", "here"])
|
||||
state_machine.finish_stage("tenant")
|
||||
assert state_machine.state == FSMStates.TENANT_IN_PROGRESS
|
||||
|
||||
|
||||
@pytest.mark.state_machine
|
||||
def test_finish_stage_invalid_stage(state_machine):
|
||||
state_machine.state = FSMStates.TENANT_IN_PROGRESS
|
||||
state_machine.portfolio.csp_data = {}
|
||||
portfolio.csp_data = {}
|
||||
state_machine.finish_stage("invalid stage")
|
||||
assert state_machine.state == FSMStates.TENANT_IN_PROGRESS
|
||||
|
||||
|
||||
@pytest.mark.state_machine
|
||||
def test_stage_state_to_stage_name():
|
||||
stage = _stage_state_to_stage_name(
|
||||
FSMStates.TENANT_IN_PROGRESS, StageStates.IN_PROGRESS
|
||||
@ -154,18 +246,19 @@ def test_stage_state_to_stage_name():
|
||||
assert stage == "tenant"
|
||||
|
||||
|
||||
def test_state_machine_initialization(portfolio):
|
||||
|
||||
sm = PortfolioStateMachineFactory.create(portfolio=portfolio)
|
||||
@pytest.mark.state_machine
|
||||
def test_state_machine_initialization(state_machine):
|
||||
for stage in AzureStages:
|
||||
|
||||
# check that all stages have a 'create' and 'fail' triggers
|
||||
stage_name = stage.name.lower()
|
||||
for trigger_prefix in ["create", "fail"]:
|
||||
assert hasattr(sm, trigger_prefix + "_" + stage_name)
|
||||
assert hasattr(state_machine, trigger_prefix + "_" + stage_name)
|
||||
|
||||
# check that machine
|
||||
in_progress_triggers = sm.machine.get_triggers(stage.name + "_IN_PROGRESS")
|
||||
in_progress_triggers = state_machine.machine.get_triggers(
|
||||
stage.name + "_IN_PROGRESS"
|
||||
)
|
||||
assert [
|
||||
"reset",
|
||||
"fail",
|
||||
@ -173,32 +266,29 @@ def test_state_machine_initialization(portfolio):
|
||||
"fail_" + stage_name,
|
||||
] == in_progress_triggers
|
||||
|
||||
started_triggers = sm.machine.get_triggers("STARTED")
|
||||
started_triggers = state_machine.machine.get_triggers("STARTED")
|
||||
create_trigger = next(
|
||||
filter(
|
||||
lambda trigger: trigger.startswith("create_"),
|
||||
sm.machine.get_triggers(FSMStates.STARTED.name),
|
||||
state_machine.machine.get_triggers(FSMStates.STARTED.name),
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert ["reset", "fail", create_trigger] == started_triggers
|
||||
|
||||
|
||||
@mock.patch("atst.domain.csp.cloud.MockCloudProvider")
|
||||
def test_fsm_transition_start(mock_cloud_provider, portfolio: Portfolio):
|
||||
@pytest.mark.state_machine
|
||||
@patch("atst.domain.csp.cloud.MockCloudProvider")
|
||||
def test_fsm_transition_start(
|
||||
mock_cloud_provider, state_machine: PortfolioStateMachine
|
||||
):
|
||||
mock_cloud_provider._authorize.return_value = None
|
||||
mock_cloud_provider._maybe_raise.return_value = None
|
||||
sm: PortfolioStateMachine = PortfolioStateMachineFactory.create(portfolio=portfolio)
|
||||
assert sm.portfolio
|
||||
assert sm.state == FSMStates.UNSTARTED
|
||||
|
||||
sm.init()
|
||||
assert sm.state == FSMStates.STARTING
|
||||
|
||||
sm.start()
|
||||
assert sm.state == FSMStates.STARTED
|
||||
portfolio = state_machine.portfolio
|
||||
|
||||
expected_states = [
|
||||
FSMStates.STARTING,
|
||||
FSMStates.STARTED,
|
||||
FSMStates.TENANT_CREATED,
|
||||
FSMStates.BILLING_PROFILE_CREATION_CREATED,
|
||||
FSMStates.BILLING_PROFILE_VERIFICATION_CREATED,
|
||||
@ -226,47 +316,17 @@ def test_fsm_transition_start(mock_cloud_provider, portfolio: Portfolio):
|
||||
else:
|
||||
csp_data = {}
|
||||
|
||||
ppoc = portfolio.owner
|
||||
user_id = f"{ppoc.first_name[0]}{ppoc.last_name}".lower()
|
||||
domain_name = re.sub("[^0-9a-zA-Z]+", "", portfolio.name).lower()
|
||||
|
||||
initial_task_order: TaskOrder = portfolio.task_orders[0]
|
||||
initial_clin = initial_task_order.sorted_clins[0]
|
||||
|
||||
portfolio_data = {
|
||||
"user_id": user_id,
|
||||
"password": "jklfsdNCVD83nklds2#202", # pragma: allowlist secret
|
||||
"domain_name": domain_name,
|
||||
"display_name": "mgmt group display name",
|
||||
"management_group_name": "mgmt-group-uuid",
|
||||
"first_name": ppoc.first_name,
|
||||
"last_name": ppoc.last_name,
|
||||
"country_code": "US",
|
||||
"password_recovery_email_address": ppoc.email,
|
||||
"address": { # TODO: TBD if we're sourcing this from data or config
|
||||
"company_name": "",
|
||||
"address_line_1": "",
|
||||
"city": "",
|
||||
"region": "",
|
||||
"country": "",
|
||||
"postal_code": "",
|
||||
},
|
||||
"billing_profile_display_name": "My Billing Profile",
|
||||
"initial_clin_amount": initial_clin.obligated_amount,
|
||||
"initial_clin_start_date": initial_clin.start_date.strftime("%Y/%m/%d"),
|
||||
"initial_clin_end_date": initial_clin.end_date.strftime("%Y/%m/%d"),
|
||||
"initial_clin_type": initial_clin.number,
|
||||
"initial_task_order_id": initial_task_order.number,
|
||||
}
|
||||
|
||||
config = {"billing_account_name": "billing_account_name"}
|
||||
|
||||
assert state_machine.state == FSMStates.UNSTARTED
|
||||
portfolio_data = get_portfolio_csp_data(portfolio)
|
||||
|
||||
for expected_state in expected_states:
|
||||
collected_data = dict(
|
||||
list(csp_data.items()) + list(portfolio_data.items()) + list(config.items())
|
||||
)
|
||||
sm.trigger_next_transition(csp_data=collected_data)
|
||||
assert sm.state == expected_state
|
||||
state_machine.trigger_next_transition(csp_data=collected_data)
|
||||
assert state_machine.state == expected_state
|
||||
if portfolio.csp_data is not None:
|
||||
csp_data = portfolio.csp_data
|
||||
else:
|
||||
|
@ -1,3 +1,4 @@
|
||||
import re
|
||||
import operator
|
||||
import random
|
||||
import string
|
||||
@ -78,6 +79,49 @@ def get_all_portfolio_permission_sets():
|
||||
return PermissionSets.get_many(PortfolioRoles.PORTFOLIO_PERMISSION_SETS)
|
||||
|
||||
|
||||
def get_portfolio_csp_data(portfolio):
|
||||
|
||||
ppoc = portfolio.owner
|
||||
if not ppoc:
|
||||
|
||||
class ppoc:
|
||||
first_name = "John"
|
||||
last_name = "Doe"
|
||||
email = "email@example.com"
|
||||
|
||||
user_id = f"{ppoc.first_name[0]}{ppoc.last_name}".lower()
|
||||
domain_name = re.sub("[^0-9a-zA-Z]+", "", portfolio.name).lower()
|
||||
|
||||
initial_task_order: TaskOrder = portfolio.task_orders[0]
|
||||
initial_clin = initial_task_order.sorted_clins[0]
|
||||
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"password": "jklfsdNCVD83nklds2#202", # pragma: allowlist secret
|
||||
"domain_name": domain_name,
|
||||
"display_name": "mgmt group display name",
|
||||
"management_group_name": "mgmt-group-uuid",
|
||||
"first_name": ppoc.first_name,
|
||||
"last_name": ppoc.last_name,
|
||||
"country_code": "US",
|
||||
"password_recovery_email_address": ppoc.email,
|
||||
"address": { # TODO: TBD if we're sourcing this from data or config
|
||||
"company_name": "",
|
||||
"address_line_1": "",
|
||||
"city": "",
|
||||
"region": "",
|
||||
"country": "",
|
||||
"postal_code": "",
|
||||
},
|
||||
"billing_profile_display_name": "My Billing Profile",
|
||||
"initial_clin_amount": initial_clin.obligated_amount,
|
||||
"initial_clin_start_date": initial_clin.start_date.strftime("%Y/%m/%d"),
|
||||
"initial_clin_end_date": initial_clin.end_date.strftime("%Y/%m/%d"),
|
||||
"initial_clin_type": initial_clin.number,
|
||||
"initial_task_order_id": initial_task_order.number,
|
||||
}
|
||||
|
||||
|
||||
class Base(factory.alchemy.SQLAlchemyModelFactory):
|
||||
@classmethod
|
||||
def dictionary(cls, **attrs):
|
||||
|
Loading…
x
Reference in New Issue
Block a user