atst/atat/models/mixins/state_machines.py
2020-03-04 11:51:15 -05:00

159 lines
5.9 KiB
Python

from enum import Enum
from flask import current_app as app
class StageStates(Enum):
    """Status a single CSP stage can be in.

    This enum is iterated in definition order when the per-stage members of
    ``FSMStates`` and the per-stage state dicts are generated, so the order
    of the members below is significant.
    """

    CREATED = "created"
    IN_PROGRESS = "in progress"
    FAILED = "failed"
class AzureStages(Enum):
    """Named stages for the Azure CSP, listed in definition order.

    ``FSMStates`` is generated from this enum: every member's ``name`` and
    ``value`` are combined with each ``StageStates`` member to form one state
    (``<NAME>_<STATUS>`` = ``"<value> <status value>"``).  Enum iteration
    follows definition order, so reordering members changes the order of the
    generated states.
    """

    TENANT = "tenant"
    BILLING_PROFILE_CREATION = "billing profile creation"
    BILLING_PROFILE_VERIFICATION = "billing profile verification"
    BILLING_PROFILE_TENANT_ACCESS = "billing profile tenant access"
    TASK_ORDER_BILLING_CREATION = "task order billing creation"
    TASK_ORDER_BILLING_VERIFICATION = "task order billing verification"
    BILLING_INSTRUCTION = "billing instruction"
    PRODUCT_PURCHASE = "purchase aad premium product"
    PRODUCT_PURCHASE_VERIFICATION = "purchase aad premium product verification"
    TENANT_PRINCIPAL_APP = "tenant principal application"
    TENANT_PRINCIPAL = "tenant principal"
    TENANT_PRINCIPAL_CREDENTIAL = "tenant principal credential"
    ADMIN_ROLE_DEFINITION = "admin role definition"
    PRINCIPAL_ADMIN_ROLE = "tenant principal admin"
    INITIAL_MGMT_GROUP = "initial management group"
    INITIAL_MGMT_GROUP_VERIFICATION = "initial management group verification"
    TENANT_ADMIN_OWNERSHIP = "tenant admin ownership"
    # Value previously read "tenant principial ownership" (typo). Only the
    # value string is corrected; the member name is unchanged.
    # NOTE(review): confirm nothing outside this file matches on the old
    # misspelled value before merging.
    TENANT_PRINCIPAL_OWNERSHIP = "tenant principal ownership"
    BILLING_OWNER = "billing owner"
def _build_csp_states(csp_stages):
    """Return the name -> value mapping used to create the ``FSMStates`` enum.

    The mapping starts with the five system states and is followed by one
    entry per (stage, ``StageStates`` member) pair, in iteration order:
    the key is ``"<stage name>_<status name>"`` and the value is
    ``"<stage value> <status value>"``.

    Args:
        csp_stages: Iterable of stage enum members (each exposing ``name``
            and ``value``).

    Returns:
        dict: State names mapped to their human-readable values.
    """
    system_states = {
        "UNSTARTED": "unstarted",
        "STARTING": "starting",
        "STARTED": "started",
        "COMPLETED": "completed",
        "FAILED": "failed",
    }
    stage_states = {
        "_".join((stage.name, status.name)): " ".join((stage.value, status.value))
        for stage in csp_stages
        for status in StageStates
    }
    # System states come first so their insertion order is preserved.
    return {**system_states, **stage_states}
# Enum of every FSM state, created with the functional Enum API from the
# mapping returned by _build_csp_states: the five system states plus one
# member per (Azure stage, StageStates member) pair.
FSMStates = Enum("FSMStates", _build_csp_states(AzureStages))
def compose_state(csp_stage, state):
    """Return the ``FSMStates`` member for ``csp_stage`` in status ``state``.

    Args:
        csp_stage: Stage enum member; only its ``name`` is used.
        state: Status enum member (e.g. a ``StageStates`` member); only its
            ``name`` is used.

    Returns:
        The ``FSMStates`` member named ``"<csp_stage.name>_<state.name>"``.

    Raises:
        AttributeError: If ``FSMStates`` has no member with that name.
    """
    return getattr(FSMStates, "_".join([csp_stage.name, state.name]))
def _build_transitions(csp_stages):
    """Build the per-stage state and transition definitions.

    For every stage one state dict is produced per ``StageStates`` member,
    plus four transition dicts (``<stage>`` is the lower-cased stage name),
    appended in this order:

    * ``create_<stage>``: CREATED state of the previous stage (or
      ``FSMStates.STARTED`` for the first stage) -> IN_PROGRESS, with
      ``after="after_in_progress_callback"``
    * ``finish_<stage>``: IN_PROGRESS -> CREATED, guarded by
      ``is_csp_data_valid``
    * ``fail_<stage>``: IN_PROGRESS -> FAILED
    * ``resume_progress_<stage>``: FAILED -> IN_PROGRESS, guarded by
      ``is_ready_resume_progress``

    The last stage additionally gets a ``complete`` transition from its
    CREATED state to ``FSMStates.COMPLETED``; it is appended before that
    stage's other transitions.

    Args:
        csp_stages: Iterable of stage enum members in execution order.  It is
            materialised once, so any iterable (not only a sized one) works.

    Returns:
        tuple: ``(states, transitions)``, two lists of dicts.  State dicts
        have ``name`` and ``tags`` keys; transition dicts have ``trigger``,
        ``source``, ``dest`` and optionally ``after`` / ``conditions``.
        NOTE(review): the key names match the ``transitions`` library's
        dict format; confirm against the consumer of this function.
    """
    # Materialise once instead of rebuilding the list on every iteration.
    stages = list(csp_stages)
    last_index = len(stages) - 1

    states = []
    transitions = []
    # Source of the next stage's "create_*" trigger: the machine-level
    # STARTED state for the first stage, then the previous stage's CREATED.
    create_source = FSMStates.STARTED

    for stage_i, csp_stage in enumerate(stages):
        stage_key = csp_stage.name.lower()
        created = compose_state(csp_stage, StageStates.CREATED)
        in_progress = compose_state(csp_stage, StageStates.IN_PROGRESS)
        failed = compose_state(csp_stage, StageStates.FAILED)

        # The last stage's CREATED state has a transition to COMPLETED.
        if stage_i == last_index:
            transitions.append(
                dict(trigger="complete", source=created, dest=FSMStates.COMPLETED)
            )

        states.extend(
            dict(
                name=compose_state(csp_stage, state),
                tags=[csp_stage.name, state.name],
            )
            for state in StageStates
        )

        transitions.extend(
            [
                dict(
                    trigger="create_" + stage_key,
                    source=create_source,
                    dest=in_progress,
                    after="after_in_progress_callback",
                ),
                dict(
                    trigger="finish_" + stage_key,
                    source=in_progress,
                    dest=created,
                    conditions=["is_csp_data_valid"],
                ),
                dict(
                    trigger="fail_" + stage_key,
                    source=in_progress,
                    dest=failed,
                ),
                dict(
                    trigger="resume_progress_" + stage_key,
                    source=failed,
                    dest=in_progress,
                    conditions=["is_ready_resume_progress"],
                ),
            ]
        )

        create_source = created

    return states, transitions
class FSMMixin:
    """Mixin with the machine-level states/transitions and stage helpers.

    The methods rely on attributes this mixin does not define and the host
    class must supply: ``machine`` (providing ``get_triggers``),
    ``current_state`` (providing ``name``) and ``trigger``.
    """

    # States that are not tied to any CSP stage.
    system_states = [
        {"name": FSMStates.UNSTARTED.name, "tags": ["system"]},
        {"name": FSMStates.STARTING.name, "tags": ["system"]},
        {"name": FSMStates.STARTED.name, "tags": ["system"]},
        {"name": FSMStates.FAILED.name, "tags": ["system"]},
        {"name": FSMStates.COMPLETED.name, "tags": ["system"]},
    ]

    # Transitions between system states.  "reset" and "fail" use the "*"
    # source (NOTE(review): presumably the transitions-library wildcard for
    # "any state" -- confirm against the machine configuration).
    system_transitions = [
        {"trigger": "init", "source": FSMStates.UNSTARTED, "dest": FSMStates.STARTING},
        {"trigger": "start", "source": FSMStates.STARTING, "dest": FSMStates.STARTED},
        {"trigger": "reset", "source": "*", "dest": FSMStates.UNSTARTED},
        {"trigger": "fail", "source": "*", "dest": FSMStates.FAILED},
    ]

    def fail_stage(self, stage):
        """Fire ``fail_<stage>`` if it is available from the current state.

        The attempt is logged before the trigger fires so the log line is
        emitted even if the trigger raises.  A missing trigger is logged and
        otherwise ignored.

        Args:
            stage: Lower-cased stage name used to build the trigger name.
        """
        fail_trigger = f"fail_{stage}"
        if fail_trigger in self.machine.get_triggers(self.current_state.name):
            app.logger.info(
                f"calling fail trigger '{fail_trigger}' for '{self!r}'"
            )
            self.trigger(fail_trigger)
        else:
            app.logger.info(
                f"could not locate fail trigger '{fail_trigger}' for '{self!r}'"
            )

    def finish_stage(self, stage):
        """Fire ``finish_<stage>`` if it is available from the current state.

        Mirrors ``fail_stage``: the attempt is logged before firing, and a
        missing trigger is logged rather than silently ignored.

        Args:
            stage: Lower-cased stage name used to build the trigger name.
        """
        finish_trigger = f"finish_{stage}"
        if finish_trigger in self.machine.get_triggers(self.current_state.name):
            app.logger.info(
                f"calling finish trigger '{finish_trigger}' for '{self!r}'"
            )
            self.trigger(finish_trigger)
        else:
            app.logger.info(
                f"could not locate finish trigger '{finish_trigger}' for '{self!r}'"
            )