diff --git a/atst/models/mixins/state_machines.py b/atst/models/mixins/state_machines.py index 8c941ffe..bd834c3d 100644 --- a/atst/models/mixins/state_machines.py +++ b/atst/models/mixins/state_machines.py @@ -139,7 +139,6 @@ class FSMMixin: def fail_stage(self, stage): fail_trigger = f"fail_{stage}" - if fail_trigger in self.machine.get_triggers(self.current_state.name): self.trigger(fail_trigger) app.logger.info( @@ -157,30 +156,3 @@ class FSMMixin: f"calling finish trigger '{finish_trigger}' for '{self.__repr__()}'" ) self.trigger(finish_trigger) - - def prepare_init(self, event): - pass - - def before_init(self, event): - pass - - def after_init(self, event): - pass - - def prepare_start(self, event): - pass - - def before_start(self, event): - pass - - def after_start(self, event): - pass - - def prepare_reset(self, event): - pass - - def before_reset(self, event): - pass - - def after_reset(self, event): - pass diff --git a/tests/domain/test_portfolio_state_machine.py b/tests/domain/test_portfolio_state_machine.py index a91a63c3..23664079 100644 --- a/tests/domain/test_portfolio_state_machine.py +++ b/tests/domain/test_portfolio_state_machine.py @@ -1,33 +1,39 @@ import pytest import pydantic -import re -from unittest import mock +from unittest.mock import Mock, patch from enum import Enum +from unittest.mock import Mock, patch +import pendulum +import pydantic +import pytest from tests.factories import ( - PortfolioStateMachineFactory, + ApplicationFactory, CLINFactory, + PortfolioFactory, + PortfolioStateMachineFactory, + TaskOrderFactory, + UserFactory, + get_portfolio_csp_data, ) from atst.models import FSMStates, PortfolioStateMachine, TaskOrder -from atst.domain.csp.cloud.models import BillingProfileCreationCSPPayload from atst.models.mixins.state_machines import ( AzureStages, StageStates, - compose_state, - _build_transitions, _build_csp_states, + _build_transitions, + compose_state, ) from atst.models.portfolio import Portfolio from atst.models.portfolio_state_machine import ( - get_stage_csp_class, - _stage_to_classname, - _stage_state_to_stage_name, StateMachineMisconfiguredError, + _stage_state_to_stage_name, + _stage_to_classname, + get_stage_csp_class, ) - # TODO: Write failure case tests @@ -37,26 +43,51 @@ class AzureStagesTest(Enum): @pytest.fixture(scope="function") def portfolio(): - # TODO: setup clin/to as active/funded/ready - portfolio = CLINFactory.create().task_order.portfolio + today = pendulum.today() + yesterday = today.subtract(days=1) + future = today.add(days=100) + + owner = UserFactory.create() + portfolio = PortfolioFactory.create(owner=owner) + ApplicationFactory.create(portfolio=portfolio, environments=[{"name": "dev"}]) + + TaskOrderFactory.create( + portfolio=portfolio, + signed_at=yesterday, + clins=[CLINFactory.create(start_date=yesterday, end_date=future)], + ) + return portfolio +@pytest.fixture(scope="function") +def state_machine(portfolio): + return PortfolioStateMachineFactory.create(portfolio=portfolio) + + +@pytest.mark.state_machine def test_fsm_creation(portfolio): sm = PortfolioStateMachineFactory.create(portfolio=portfolio) assert sm.portfolio -def test_state_machine_trigger_next_transition(portfolio): - sm = PortfolioStateMachineFactory.create(portfolio=portfolio) +@pytest.mark.state_machine +def test_state_machine_trigger_next_transition(state_machine): - sm.trigger_next_transition() - assert sm.current_state == FSMStates.STARTING + state_machine.trigger_next_transition() + assert state_machine.current_state == FSMStates.STARTING - sm.trigger_next_transition() - assert sm.current_state == FSMStates.STARTED + state_machine.trigger_next_transition() + assert state_machine.current_state == FSMStates.STARTED + + state_machine.state = FSMStates.STARTED + state_machine.trigger_next_transition( + csp_data=get_portfolio_csp_data(state_machine.portfolio) + ) + assert state_machine.current_state == FSMStates.TENANT_CREATED +@pytest.mark.state_machine def test_state_machine_compose_state(): assert ( compose_state(AzureStages.TENANT, StageStates.CREATED) @@ -64,6 +95,7 @@ def test_state_machine_compose_state(): ) +@pytest.mark.state_machine def test_stage_to_classname(): assert ( _stage_to_classname(AzureStages.BILLING_PROFILE_CREATION.name) @@ -71,16 +103,19 @@ def test_stage_to_classname(): ) +@pytest.mark.state_machine def test_get_stage_csp_class(): csp_class = get_stage_csp_class(list(AzureStages)[0].name.lower(), "payload") assert isinstance(csp_class, pydantic.main.ModelMetaclass) +@pytest.mark.state_machine def test_get_stage_csp_class_import_fail(): with pytest.raises(StateMachineMisconfiguredError): csp_class = get_stage_csp_class("doesnotexist", "payload") +@pytest.mark.state_machine def test_build_transitions(): states, transitions = _build_transitions(AzureStagesTest) assert [s.get("name").name for s in states] == [ @@ -102,6 +137,7 @@ def test_build_transitions(): ] +@pytest.mark.state_machine def test_build_csp_states(): states = _build_csp_states(AzureStagesTest) assert list(states) == [ @@ -116,18 +152,18 @@ def test_build_csp_states(): ] -def test_state_machine_valid_data_classes_for_stages(portfolio): - PortfolioStateMachineFactory.create(portfolio=portfolio) +@pytest.mark.state_machine +def test_state_machine_valid_data_classes_for_stages(): for stage in AzureStages: assert get_stage_csp_class(stage.name.lower(), "payload") is not None assert get_stage_csp_class(stage.name.lower(), "result") is not None -def test_attach_machine(portfolio): - sm = PortfolioStateMachineFactory.create(portfolio=portfolio) - sm.machine = None - sm.attach_machine(stages=AzureStagesTest) - assert list(sm.machine.events) == [ +@pytest.mark.state_machine +def test_attach_machine(state_machine): + state_machine.machine = None + state_machine.attach_machine(stages=AzureStagesTest) + assert list(state_machine.machine.events) == [ "init", "start", "reset", @@ -140,13 +176,69 @@ def test_attach_machine(portfolio): ] -def test_fail_stage(portfolio): - sm = PortfolioStateMachineFactory.create(portfolio=portfolio) - sm.state = FSMStates.TENANT_IN_PROGRESS - sm.fail_stage("tenant") - assert sm.state == FSMStates.TENANT_FAILED +@pytest.mark.state_machine +def test_current_state_property(state_machine): + assert state_machine.current_state == FSMStates.UNSTARTED + state_machine.state = FSMStates.TENANT_IN_PROGRESS + assert state_machine.current_state == FSMStates.TENANT_IN_PROGRESS + state_machine.state = "UNSTARTED" + assert state_machine.current_state == FSMStates.UNSTARTED +@pytest.mark.state_machine +def test_fail_stage(state_machine): + state_machine.state = FSMStates.TENANT_IN_PROGRESS + state_machine.portfolio.csp_data = {} + state_machine.fail_stage("tenant") + assert state_machine.state == FSMStates.TENANT_FAILED + + +@pytest.mark.state_machine +def test_fail_stage_invalid_triggers(state_machine): + state_machine.state = FSMStates.TENANT_IN_PROGRESS + state_machine.portfolio.csp_data = {} + state_machine.machine.get_triggers = Mock(return_value=["some", "triggers", "here"]) + state_machine.fail_stage("tenant") + assert state_machine.state == FSMStates.TENANT_IN_PROGRESS + + +@pytest.mark.state_machine +def test_fail_stage_invalid_stage(state_machine): + state_machine.state = FSMStates.TENANT_IN_PROGRESS + state_machine.portfolio.csp_data = {} + portfolio.csp_data = {} + state_machine.fail_stage("invalid stage") + assert state_machine.state == FSMStates.TENANT_IN_PROGRESS + + +@pytest.mark.state_machine +def test_finish_stage(state_machine): + state_machine.state = FSMStates.TENANT_IN_PROGRESS + state_machine.portfolio.csp_data = {} + state_machine.finish_stage("tenant") + assert state_machine.state == FSMStates.TENANT_CREATED + + +@pytest.mark.state_machine +def test_finish_stage_invalid_triggers(state_machine): + state_machine.state = FSMStates.TENANT_IN_PROGRESS + state_machine.portfolio.csp_data = {} + + state_machine.machine.get_triggers = Mock(return_value=["some", "triggers", "here"]) + state_machine.finish_stage("tenant") + assert state_machine.state == FSMStates.TENANT_IN_PROGRESS + + +@pytest.mark.state_machine +def test_finish_stage_invalid_stage(state_machine): + state_machine.state = FSMStates.TENANT_IN_PROGRESS + state_machine.portfolio.csp_data = {} + portfolio.csp_data = {} + state_machine.finish_stage("invalid stage") + assert state_machine.state == FSMStates.TENANT_IN_PROGRESS + + +@pytest.mark.state_machine def test_stage_state_to_stage_name(): stage = _stage_state_to_stage_name( FSMStates.TENANT_IN_PROGRESS, StageStates.IN_PROGRESS @@ -154,18 +246,19 @@ def test_stage_state_to_stage_name(): assert stage == "tenant" -def test_state_machine_initialization(portfolio): - - sm = PortfolioStateMachineFactory.create(portfolio=portfolio) +@pytest.mark.state_machine +def test_state_machine_initialization(state_machine): for stage in AzureStages: # check that all stages have a 'create' and 'fail' triggers stage_name = stage.name.lower() for trigger_prefix in ["create", "fail"]: - assert hasattr(sm, trigger_prefix + "_" + stage_name) + assert hasattr(state_machine, trigger_prefix + "_" + stage_name) # check that machine - in_progress_triggers = sm.machine.get_triggers(stage.name + "_IN_PROGRESS") + in_progress_triggers = state_machine.machine.get_triggers( + stage.name + "_IN_PROGRESS" + ) assert [ "reset", "fail", @@ -173,32 +266,29 @@ def test_state_machine_initialization(portfolio): "fail_" + stage_name, ] == in_progress_triggers - started_triggers = sm.machine.get_triggers("STARTED") + started_triggers = state_machine.machine.get_triggers("STARTED") create_trigger = next( filter( lambda trigger: trigger.startswith("create_"), - sm.machine.get_triggers(FSMStates.STARTED.name), + state_machine.machine.get_triggers(FSMStates.STARTED.name), ), None, ) assert ["reset", "fail", create_trigger] == started_triggers -@mock.patch("atst.domain.csp.cloud.MockCloudProvider") -def test_fsm_transition_start(mock_cloud_provider, portfolio: Portfolio): +@pytest.mark.state_machine +@patch("atst.domain.csp.cloud.MockCloudProvider") +def test_fsm_transition_start( + mock_cloud_provider, state_machine: PortfolioStateMachine +): mock_cloud_provider._authorize.return_value = None mock_cloud_provider._maybe_raise.return_value = None - sm: PortfolioStateMachine = PortfolioStateMachineFactory.create(portfolio=portfolio) - assert sm.portfolio - assert sm.state == FSMStates.UNSTARTED - - sm.init() - assert sm.state == FSMStates.STARTING - - sm.start() - assert sm.state == FSMStates.STARTED + portfolio = state_machine.portfolio expected_states = [ + FSMStates.STARTING, + FSMStates.STARTED, FSMStates.TENANT_CREATED, FSMStates.BILLING_PROFILE_CREATION_CREATED, FSMStates.BILLING_PROFILE_VERIFICATION_CREATED, @@ -226,47 +316,17 @@ def test_fsm_transition_start(mock_cloud_provider, portfolio: Portfolio): else: csp_data = {} - ppoc = portfolio.owner - user_id = f"{ppoc.first_name[0]}{ppoc.last_name}".lower() - domain_name = re.sub("[^0-9a-zA-Z]+", "", portfolio.name).lower() - - initial_task_order: TaskOrder = portfolio.task_orders[0] - initial_clin = initial_task_order.sorted_clins[0] - - portfolio_data = { - "user_id": user_id, - "password": "jklfsdNCVD83nklds2#202", # pragma: allowlist secret - "domain_name": domain_name, - "display_name": "mgmt group display name", - "management_group_name": "mgmt-group-uuid", - "first_name": ppoc.first_name, - "last_name": ppoc.last_name, - "country_code": "US", - "password_recovery_email_address": ppoc.email, - "address": { # TODO: TBD if we're sourcing this from data or config - "company_name": "", - "address_line_1": "", - "city": "", - "region": "", - "country": "", - "postal_code": "", - }, - "billing_profile_display_name": "My Billing Profile", - "initial_clin_amount": initial_clin.obligated_amount, - "initial_clin_start_date": initial_clin.start_date.strftime("%Y/%m/%d"), - "initial_clin_end_date": initial_clin.end_date.strftime("%Y/%m/%d"), - "initial_clin_type": initial_clin.number, - "initial_task_order_id": initial_task_order.number, - } - config = {"billing_account_name": "billing_account_name"} + assert state_machine.state == FSMStates.UNSTARTED + portfolio_data = get_portfolio_csp_data(portfolio) + for expected_state in expected_states: collected_data = dict( list(csp_data.items()) + list(portfolio_data.items()) + list(config.items()) ) - sm.trigger_next_transition(csp_data=collected_data) - assert sm.state == expected_state + state_machine.trigger_next_transition(csp_data=collected_data) + assert state_machine.state == expected_state if portfolio.csp_data is not None: csp_data = portfolio.csp_data else: diff --git a/tests/factories.py b/tests/factories.py index ee548d40..fa857a3c 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -1,3 +1,4 @@ +import re import operator import random import string @@ -78,6 +79,49 @@ def get_all_portfolio_permission_sets(): return PermissionSets.get_many(PortfolioRoles.PORTFOLIO_PERMISSION_SETS) +def get_portfolio_csp_data(portfolio): + + ppoc = portfolio.owner + if not ppoc: + + class ppoc: + first_name = "John" + last_name = "Doe" + email = "email@example.com" + + user_id = f"{ppoc.first_name[0]}{ppoc.last_name}".lower() + domain_name = re.sub("[^0-9a-zA-Z]+", "", portfolio.name).lower() + + initial_task_order: TaskOrder = portfolio.task_orders[0] + initial_clin = initial_task_order.sorted_clins[0] + + return { + "user_id": user_id, + "password": "jklfsdNCVD83nklds2#202", # pragma: allowlist secret + "domain_name": domain_name, + "display_name": "mgmt group display name", + "management_group_name": "mgmt-group-uuid", + "first_name": ppoc.first_name, + "last_name": ppoc.last_name, + "country_code": "US", + "password_recovery_email_address": ppoc.email, + "address": { # TODO: TBD if we're sourcing this from data or config + "company_name": "", + "address_line_1": "", + "city": "", + "region": "", + "country": "", + "postal_code": "", + }, + "billing_profile_display_name": "My Billing Profile", + "initial_clin_amount": initial_clin.obligated_amount, + "initial_clin_start_date": initial_clin.start_date.strftime("%Y/%m/%d"), + "initial_clin_end_date": initial_clin.end_date.strftime("%Y/%m/%d"), + "initial_clin_type": initial_clin.number, + "initial_task_order_id": initial_task_order.number, + } + + class Base(factory.alchemy.SQLAlchemyModelFactory): @classmethod def dictionary(cls, **attrs):