azure csp unit tests WIP
This commit is contained in:
@@ -123,15 +123,21 @@ class FSMMixin:
|
||||
]
|
||||
|
||||
def fail_stage(self, stage):
|
||||
fail_trigger = "fail" + stage
|
||||
fail_trigger = f"fail_{stage}"
|
||||
|
||||
if fail_trigger in self.machine.get_triggers(self.current_state.name):
|
||||
self.trigger(fail_trigger)
|
||||
app.logger.info(
|
||||
f"calling fail trigger '{fail_trigger}' for '{self.__repr__()}'"
|
||||
)
|
||||
else:
|
||||
app.logger.info(
|
||||
f"could not locate fail trigger '{fail_trigger}' for '{self.__repr__()}'"
|
||||
)
|
||||
|
||||
|
||||
def finish_stage(self, stage):
|
||||
finish_trigger = "finish_" + stage
|
||||
finish_trigger = f"finish_{stage}"
|
||||
if finish_trigger in self.machine.get_triggers(self.current_state.name):
|
||||
app.logger.info(
|
||||
f"calling finish trigger '{finish_trigger}' for '{self.__repr__()}'"
|
||||
|
||||
@@ -15,13 +15,31 @@ from atst.database import db
|
||||
from atst.models.types import Id
|
||||
from atst.models.base import Base
|
||||
import atst.models.mixins as mixins
|
||||
from atst.models.mixins.state_machines import FSMStates, AzureStages, _build_transitions
|
||||
from atst.models.mixins.state_machines import (
|
||||
FSMStates,
|
||||
AzureStages,
|
||||
StageStates,
|
||||
_build_transitions
|
||||
)
|
||||
|
||||
|
||||
class StateMachineMisconfiguredError(Exception):
|
||||
def __init__(self, class_details):
|
||||
self.class_details = class_details
|
||||
|
||||
@property
|
||||
def message(self):
|
||||
return self.class_details
|
||||
|
||||
|
||||
def _stage_to_classname(stage):
|
||||
return "".join(map(lambda word: word.capitalize(), stage.split("_")))
|
||||
|
||||
|
||||
def _stage_state_to_stage_name(state, stage_state):
|
||||
return state.name.split(f"_{stage_state.name}")[0].lower()
|
||||
|
||||
|
||||
def get_stage_csp_class(stage, class_type):
|
||||
"""
|
||||
given a stage name and class_type return the class
|
||||
@@ -34,7 +52,7 @@ def get_stage_csp_class(stage, class_type):
|
||||
importlib.import_module("atst.domain.csp.cloud.models"), cls_name
|
||||
)
|
||||
except AttributeError:
|
||||
print("could not import CSP Result class <%s>" % cls_name)
|
||||
raise StateMachineMisconfiguredError(f"could not import CSP Payload/Result class {cls_name}")
|
||||
|
||||
|
||||
@add_state_features(Tags)
|
||||
@@ -74,7 +92,7 @@ class PortfolioStateMachine(
|
||||
return f"<PortfolioStateMachine(state='{self.current_state.name}', portfolio='{self.portfolio.name}'"
|
||||
|
||||
@reconstructor
|
||||
def attach_machine(self):
|
||||
def attach_machine(self, stages=AzureStages):
|
||||
"""
|
||||
This is called as a result of a sqlalchemy query.
|
||||
Attach a machine depending on the current state.
|
||||
@@ -86,7 +104,7 @@ class PortfolioStateMachine(
|
||||
auto_transitions=False,
|
||||
after_state_change="after_state_change",
|
||||
)
|
||||
states, transitions = _build_transitions(AzureStages)
|
||||
states, transitions = _build_transitions(stages)
|
||||
self.machine.add_states(self.system_states + states)
|
||||
self.machine.add_transitions(self.system_transitions + transitions)
|
||||
|
||||
@@ -120,7 +138,7 @@ class PortfolioStateMachine(
|
||||
app.logger.info(
|
||||
f"could not locate 'create trigger' for {self.__repr__()}"
|
||||
)
|
||||
self.fail_stage(stage)
|
||||
self.trigger('fail')
|
||||
|
||||
elif self.current_state == FSMStates.FAILED:
|
||||
# get the first trigger that starts with 'create_'
|
||||
@@ -151,17 +169,16 @@ class PortfolioStateMachine(
|
||||
if create_trigger is not None:
|
||||
self.trigger(create_trigger, **kwargs)
|
||||
|
||||
def after_in_progress_callback(self, event):
|
||||
stage = self.current_state.name.split("_IN_PROGRESS")[0].lower()
|
||||
|
||||
def after_in_progress_callback(self, event):
|
||||
# Accumulate payload w/ creds
|
||||
payload = event.kwargs.get("csp_data")
|
||||
|
||||
payload_data_cls = get_stage_csp_class(stage, "payload")
|
||||
current_stage = _stage_state_to_stage_name(self.current_state, StageStates.IN_PROGRESS)
|
||||
payload_data_cls = get_stage_csp_class(current_stage, "payload")
|
||||
|
||||
if not payload_data_cls:
|
||||
app.logger.info(f"could not resolve payload data class for stage {stage}")
|
||||
self.fail_stage(stage)
|
||||
app.logger.info(f"could not resolve payload data class for stage {current_stage}")
|
||||
self.fail_stage(current_stage)
|
||||
try:
|
||||
payload_data = payload_data_cls(**payload)
|
||||
except PydanticValidationError as exc:
|
||||
@@ -171,13 +188,13 @@ class PortfolioStateMachine(
|
||||
app.logger.info(exc.json())
|
||||
print(exc.json())
|
||||
app.logger.info(payload)
|
||||
self.fail_stage(stage)
|
||||
self.fail_stage(current_stage)
|
||||
|
||||
# TODO: Determine best place to do this, maybe @reconstructor
|
||||
self.csp = app.csp.cloud
|
||||
|
||||
try:
|
||||
func_name = f"create_{stage}"
|
||||
func_name = f"create_{current_stage}"
|
||||
response = getattr(self.csp, func_name)(payload_data)
|
||||
if self.portfolio.csp_data is None:
|
||||
self.portfolio.csp_data = {}
|
||||
@@ -193,16 +210,16 @@ class PortfolioStateMachine(
|
||||
print(exc.json())
|
||||
app.logger.info(payload_data)
|
||||
# TODO: Ensure that failing the stage does not preclude a Celery retry
|
||||
self.fail_stage(stage)
|
||||
self.fail_stage(current_stage)
|
||||
# TODO: catch and handle general CSP exception here
|
||||
except (ConnectionException, UnknownServerException) as exc:
|
||||
app.logger.error(
|
||||
f"CSP api call. Caught exception for {self.__repr__()}.", exc_info=1,
|
||||
)
|
||||
# TODO: Ensure that failing the stage does not preclude a Celery retry
|
||||
self.fail_stage(stage)
|
||||
self.fail_stage(current_stage)
|
||||
|
||||
self.finish_stage(stage)
|
||||
self.finish_stage(current_stage)
|
||||
|
||||
def is_csp_data_valid(self, event):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user