Fix formatting and some typos

This commit is contained in:
tomdds 2020-01-14 15:21:02 -05:00
parent 34546ecd94
commit d81d953c31
14 changed files with 189 additions and 141 deletions

View File

@ -32,8 +32,12 @@ def make_csp_provider(app, csp=None):
else: else:
app.csp = MockCSP(app) app.csp = MockCSP(app)
def _stage_to_classname(stage): def _stage_to_classname(stage):
return "".join(map(lambda word: word.capitalize(), stage.replace('_', ' ').split(" "))) return "".join(
map(lambda word: word.capitalize(), stage.replace("_", " ").split(" "))
)
def get_stage_csp_class(stage, class_type): def get_stage_csp_class(stage, class_type):
""" """
@ -46,4 +50,3 @@ def get_stage_csp_class(stage, class_type):
return getattr(importlib.import_module("atst.domain.csp.cloud"), cls_name) return getattr(importlib.import_module("atst.domain.csp.cloud"), cls_name)
except AttributeError: except AttributeError:
print("could not import CSP Result class <%s>" % cls_name) print("could not import CSP Result class <%s>" % cls_name)

View File

@ -143,7 +143,7 @@ class BaselineProvisionException(GeneralCSPException):
class BaseCSPPayload(BaseModel): class BaseCSPPayload(BaseModel):
#{"username": "mock-cloud", "pass": "shh"} # {"username": "mock-cloud", "pass": "shh"}
creds: Dict creds: Dict
@ -179,6 +179,8 @@ class BillingProfileAddress(BaseModel):
"postalCode": "string" "postalCode": "string"
}, },
""" """
class BillingProfileCLINBudget(BaseModel): class BillingProfileCLINBudget(BaseModel):
clinBudget: Dict clinBudget: Dict
""" """
@ -190,7 +192,10 @@ class BillingProfileCLINBudget(BaseModel):
} }
""" """
class BillingProfileCSPPayload(BaseCSPPayload, BillingProfileAddress, BillingProfileCLINBudget):
class BillingProfileCSPPayload(
BaseCSPPayload, BillingProfileAddress, BillingProfileCLINBudget
):
displayName: str displayName: str
poNumber: str poNumber: str
invoiceEmailOptIn: str invoiceEmailOptIn: str
@ -411,7 +416,6 @@ class MockCloudProvider(CloudProviderInterface):
return {"id": self._id(), "credentials": self._auth_credentials} return {"id": self._id(), "credentials": self._auth_credentials}
def create_tenant(self, payload): def create_tenant(self, payload):
""" """
payload is an instance of TenantCSPPayload data class payload is an instance of TenantCSPPayload data class
@ -432,7 +436,6 @@ class MockCloudProvider(CloudProviderInterface):
"user_object_id": response["objectId"], "user_object_id": response["objectId"],
} }
def create_billing_profile(self, creds, tenant_admin_details, billing_owner_id): def create_billing_profile(self, creds, tenant_admin_details, billing_owner_id):
# call billing profile creation endpoint, specifying owner # call billing profile creation endpoint, specifying owner
# Payload: # Payload:
@ -475,7 +478,6 @@ class MockCloudProvider(CloudProviderInterface):
response = {"id": "string"} response = {"id": "string"}
return {"billing_profile_id": response["id"]} return {"billing_profile_id": response["id"]}
def create_or_update_user(self, auth_credentials, user_info, csp_role_id): def create_or_update_user(self, auth_credentials, user_info, csp_role_id):
self._authorize(auth_credentials) self._authorize(auth_credentials)
@ -655,7 +657,6 @@ class AzureCloudProvider(CloudProviderInterface):
"role_name": role_assignment_id, "role_name": role_assignment_id,
} }
def create_tenant(self, payload): def create_tenant(self, payload):
# auth as SP that is allowed to create tenant? (tenant creation sp creds) # auth as SP that is allowed to create tenant? (tenant creation sp creds)
# create tenant with owner details (populated from portfolio point of contact, pw is generated) # create tenant with owner details (populated from portfolio point of contact, pw is generated)

View File

@ -1,3 +1,4 @@
from sqlalchemy import or_
from typing import List from typing import List
from uuid import UUID from uuid import UUID
@ -7,7 +8,14 @@ from atst.domain.authz import Authorization
from atst.domain.portfolio_roles import PortfolioRoles from atst.domain.portfolio_roles import PortfolioRoles
from atst.domain.invitations import PortfolioInvitations from atst.domain.invitations import PortfolioInvitations
from atst.models import Portfolio, PortfolioStateMachine, FSMStates, Permissions, PortfolioRole, PortfolioRoleStatus from atst.models import (
Portfolio,
PortfolioStateMachine,
FSMStates,
Permissions,
PortfolioRole,
PortfolioRoleStatus,
)
from .query import PortfoliosQuery, PortfolioStateMachinesQuery from .query import PortfoliosQuery, PortfolioStateMachinesQuery
from .scopes import ScopedPortfolio from .scopes import ScopedPortfolio
@ -21,17 +29,15 @@ class PortfolioDeletionApplicationsExistError(Exception):
pass pass
class PortfolioStateMachines(object): class PortfolioStateMachines(object):
@classmethod @classmethod
def create(cls, portfolio, **sm_attrs): def create(cls, portfolio, **sm_attrs):
sm_attrs.update({'portfolio': portfolio}) sm_attrs.update({"portfolio": portfolio})
sm = PortfolioStateMachinesQuery.create(**sm_attrs) sm = PortfolioStateMachinesQuery.create(**sm_attrs)
return sm return sm
class Portfolios(object):
class Portfolios(object):
@classmethod @classmethod
def get_or_create_state_machine(cls, portfolio): def get_or_create_state_machine(cls, portfolio):
""" """
@ -133,12 +139,9 @@ class Portfolios(object):
PortfoliosQuery.add_and_commit(portfolio) PortfoliosQuery.add_and_commit(portfolio)
@classmethod @classmethod
def base_provision_query(cls): def base_provision_query(cls):
return ( return db.session.query(Portfolio.id)
db.session.query(Portfolio.id)
)
@classmethod @classmethod
def get_portfolios_pending_provisioning(cls) -> List[UUID]: def get_portfolios_pending_provisioning(cls) -> List[UUID]:
@ -150,19 +153,19 @@ class Portfolios(object):
""" """
results = ( results = (
cls.base_provision_query().\ cls.base_provision_query()
join(PortfolioStateMachine).\ .join(PortfolioStateMachine)
filter( .filter(
or_( or_(
PortfolioStateMachine.state == FSMStates.UNSTARTED, PortfolioStateMachine.state == FSMStates.UNSTARTED,
PortfolioStateMachine.state == FSMStates.FAILED, PortfolioStateMachine.state == FSMStates.FAILED,
PortfolioStateMachine.state == FSMStates.TENANT_CREATION_FAILED, PortfolioStateMachine.state == FSMStates.TENANT_FAILED,
) )
) )
) )
return [id_ for id_, in results] return [id_ for id_, in results]
#db.session.query(PortfolioStateMachine).\ # db.session.query(PortfolioStateMachine).\
# filter( # filter(
# or_( # or_(
# PortfolioStateMachine.state==FSMStates.UNSTARTED, # PortfolioStateMachine.state==FSMStates.UNSTARTED,

View File

@ -9,7 +9,8 @@ from atst.models.application_role import (
) )
from atst.models.application import Application from atst.models.application import Application
from atst.models.portfolio_state_machine import PortfolioStateMachine from atst.models.portfolio_state_machine import PortfolioStateMachine
#from atst.models.application import Application
# from atst.models.application import Application
class PortfolioStateMachinesQuery(Query): class PortfolioStateMachinesQuery(Query):

View File

@ -8,12 +8,10 @@ from atst.models import (
EnvironmentRoleJobFailure, EnvironmentRoleJobFailure,
EnvironmentRole, EnvironmentRole,
PortfolioJobFailure, PortfolioJobFailure,
FSMStates,
) )
from atst.domain.csp.cloud import CloudProviderInterface, GeneralCSPException from atst.domain.csp.cloud import CloudProviderInterface, GeneralCSPException
from atst.domain.environments import Environments from atst.domain.environments import Environments
from atst.domain.portfolios import Portfolios from atst.domain.portfolios import Portfolios
from atst.domain.portfolios.query import PortfolioStateMachinesQuery
from atst.domain.environment_roles import EnvironmentRoles from atst.domain.environment_roles import EnvironmentRoles
from atst.models.utils import claim_for_update from atst.models.utils import claim_for_update
@ -29,6 +27,7 @@ class RecordPortfolioFailure(celery.Task):
db.session.add(failure) db.session.add(failure)
db.session.commit() db.session.commit()
class RecordEnvironmentFailure(celery.Task): class RecordEnvironmentFailure(celery.Task):
def on_failure(self, exc, task_id, args, kwargs, einfo): def on_failure(self, exc, task_id, args, kwargs, einfo):
if "environment_id" in kwargs: if "environment_id" in kwargs:
@ -64,7 +63,6 @@ def send_notification_mail(recipients, subject, body):
app.mailer.send(recipients, subject, body) app.mailer.send(recipients, subject, body)
def do_create_environment(csp: CloudProviderInterface, environment_id=None): def do_create_environment(csp: CloudProviderInterface, environment_id=None):
environment = Environments.get(environment_id) environment = Environments.get(environment_id)
@ -150,6 +148,7 @@ def do_provision_portfolio(csp: CloudProviderInterface, portfolio_id=None):
def provision_portfolio(self, portfolio_id=None): def provision_portfolio(self, portfolio_id=None):
do_work(do_provision_portfolio, self, app.csp.cloud, portfolio_id=portfolio_id) do_work(do_provision_portfolio, self, app.csp.cloud, portfolio_id=portfolio_id)
@celery.task(bind=True, base=RecordEnvironmentFailure) @celery.task(bind=True, base=RecordEnvironmentFailure)
def create_environment(self, environment_id=None): def create_environment(self, environment_id=None):
do_work(do_create_environment, self, app.csp.cloud, environment_id=environment_id) do_work(do_create_environment, self, app.csp.cloud, environment_id=environment_id)

View File

@ -7,7 +7,11 @@ from .audit_event import AuditEvent
from .clin import CLIN, JEDICLINType from .clin import CLIN, JEDICLINType
from .environment import Environment from .environment import Environment
from .environment_role import EnvironmentRole, CSPRole from .environment_role import EnvironmentRole, CSPRole
from .job_failure import EnvironmentJobFailure, EnvironmentRoleJobFailure, PortfolioJobFailure from .job_failure import (
EnvironmentJobFailure,
EnvironmentRoleJobFailure,
PortfolioJobFailure,
)
from .notification_recipient import NotificationRecipient from .notification_recipient import NotificationRecipient
from .permissions import Permissions from .permissions import Permissions
from .permission_set import PermissionSet from .permission_set import PermissionSet

View File

@ -15,8 +15,8 @@ class EnvironmentRoleJobFailure(Base, mixins.JobFailureMixin):
environment_role_id = Column(ForeignKey("environment_roles.id"), nullable=False) environment_role_id = Column(ForeignKey("environment_roles.id"), nullable=False)
class PortfolioJobFailure(Base, mixins.JobFailureMixin): class PortfolioJobFailure(Base, mixins.JobFailureMixin):
__tablename__ = "portfolio_job_failures" __tablename__ = "portfolio_job_failures"
portfolio_id = Column(ForeignKey("portfolios.id"), nullable=False) portfolio_id = Column(ForeignKey("portfolios.id"), nullable=False)

View File

@ -1,109 +1,137 @@
from enum import Enum from enum import Enum
from atst.database import db
class StageStates(Enum): class StageStates(Enum):
CREATED = "created" CREATED = "created"
IN_PROGRESS = "in progress" IN_PROGRESS = "in progress"
FAILED = "failed" FAILED = "failed"
class AzureStages(Enum): class AzureStages(Enum):
TENANT = "tenant" TENANT = "tenant"
BILLING_PROFILE = "billing profile" BILLING_PROFILE = "billing profile"
ADMIN_SUBSCRIPTION = "admin subscription" ADMIN_SUBSCRIPTION = "admin subscription"
def _build_csp_states(csp_stages): def _build_csp_states(csp_stages):
states = { states = {
'UNSTARTED' : "unstarted", "UNSTARTED": "unstarted",
'STARTING' : "starting", "STARTING": "starting",
'STARTED' : "started", "STARTED": "started",
'COMPLETED' : "completed", "COMPLETED": "completed",
'FAILED' : "failed", "FAILED": "failed",
} }
for csp_stage in csp_stages: for csp_stage in csp_stages:
for state in StageStates: for state in StageStates:
states[csp_stage.name+"_"+state.name] = csp_stage.value+" "+state.value states[csp_stage.name + "_" + state.name] = (
csp_stage.value + " " + state.value
)
return states return states
FSMStates = Enum('FSMStates', _build_csp_states(AzureStages))
FSMStates = Enum("FSMStates", _build_csp_states(AzureStages))
def _build_transitions(csp_stages): def _build_transitions(csp_stages):
transitions = [] transitions = []
states = [] states = []
compose_state = lambda csp_stage, state: getattr(FSMStates, "_".join([csp_stage.name, state.name])) compose_state = lambda csp_stage, state: getattr(
FSMStates, "_".join([csp_stage.name, state.name])
)
for stage_i, csp_stage in enumerate(csp_stages): for stage_i, csp_stage in enumerate(csp_stages):
for state in StageStates: for state in StageStates:
states.append(dict(name=compose_state(csp_stage, state), tags=[csp_stage.name, state.name])) states.append(
dict(
name=compose_state(csp_stage, state),
tags=[csp_stage.name, state.name],
)
)
if state == StageStates.CREATED: if state == StageStates.CREATED:
if stage_i > 0: if stage_i > 0:
src = compose_state(list(csp_stages)[stage_i-1] , StageStates.CREATED) src = compose_state(
list(csp_stages)[stage_i - 1], StageStates.CREATED
)
else: else:
src = FSMStates.STARTED src = FSMStates.STARTED
transitions.append( transitions.append(
dict( dict(
trigger='create_'+csp_stage.name.lower(), trigger="create_" + csp_stage.name.lower(),
source=src, source=src,
dest=compose_state(csp_stage, StageStates.IN_PROGRESS), dest=compose_state(csp_stage, StageStates.IN_PROGRESS),
after='after_in_progress_callback', after="after_in_progress_callback",
) )
) )
if state == StageStates.IN_PROGRESS: if state == StageStates.IN_PROGRESS:
transitions.append( transitions.append(
dict( dict(
trigger='finish_'+csp_stage.name.lower(), trigger="finish_" + csp_stage.name.lower(),
source=compose_state(csp_stage, state), source=compose_state(csp_stage, state),
dest=compose_state(csp_stage, StageStates.CREATED), dest=compose_state(csp_stage, StageStates.CREATED),
conditions=['is_csp_data_valid'], conditions=["is_csp_data_valid"],
) )
) )
if state == StageStates.FAILED: if state == StageStates.FAILED:
transitions.append( transitions.append(
dict( dict(
trigger='fail_'+csp_stage.name.lower(), trigger="fail_" + csp_stage.name.lower(),
source=compose_state(csp_stage, StageStates.IN_PROGRESS), source=compose_state(csp_stage, StageStates.IN_PROGRESS),
dest=compose_state(csp_stage, StageStates.FAILED), dest=compose_state(csp_stage, StageStates.FAILED),
) )
) )
return states, transitions return states, transitions
class FSMMixin():
class FSMMixin:
system_states = [ system_states = [
{'name': FSMStates.UNSTARTED.name, 'tags': ['system']}, {"name": FSMStates.UNSTARTED.name, "tags": ["system"]},
{'name': FSMStates.STARTING.name, 'tags': ['system']}, {"name": FSMStates.STARTING.name, "tags": ["system"]},
{'name': FSMStates.STARTED.name, 'tags': ['system']}, {"name": FSMStates.STARTED.name, "tags": ["system"]},
{'name': FSMStates.FAILED.name, 'tags': ['system']}, {"name": FSMStates.FAILED.name, "tags": ["system"]},
{'name': FSMStates.COMPLETED.name, 'tags': ['system']}, {"name": FSMStates.COMPLETED.name, "tags": ["system"]},
] ]
system_transitions = [ system_transitions = [
{'trigger': 'init', 'source': FSMStates.UNSTARTED, 'dest': FSMStates.STARTING}, {"trigger": "init", "source": FSMStates.UNSTARTED, "dest": FSMStates.STARTING},
{'trigger': 'start', 'source': FSMStates.STARTING, 'dest': FSMStates.STARTED}, {"trigger": "start", "source": FSMStates.STARTING, "dest": FSMStates.STARTED},
{'trigger': 'reset', 'source': '*', 'dest': FSMStates.UNSTARTED}, {"trigger": "reset", "source": "*", "dest": FSMStates.UNSTARTED},
{'trigger': 'fail', 'source': '*', 'dest': FSMStates.FAILED,} {"trigger": "fail", "source": "*", "dest": FSMStates.FAILED,},
] ]
def prepare_init(self, event): pass def prepare_init(self, event):
def before_init(self, event): pass pass
def after_init(self, event): pass
def prepare_start(self, event): pass def before_init(self, event):
def before_start(self, event): pass pass
def after_start(self, event): pass
def prepare_reset(self, event): pass def after_init(self, event):
def before_reset(self, event): pass pass
def after_reset(self, event): pass
def prepare_start(self, event):
pass
def before_start(self, event):
pass
def after_start(self, event):
pass
def prepare_reset(self, event):
pass
def before_reset(self, event):
pass
def after_reset(self, event):
pass
def fail_stage(self, stage): def fail_stage(self, stage):
fail_trigger = 'fail'+stage fail_trigger = "fail" + stage
if fail_trigger in self.machine.get_triggers(self.current_state.name): if fail_trigger in self.machine.get_triggers(self.current_state.name):
self.trigger(fail_trigger) self.trigger(fail_trigger)
def finish_stage(self, stage): def finish_stage(self, stage):
finish_trigger = 'finish_'+stage finish_trigger = "finish_" + stage
if finish_trigger in self.machine.get_triggers(self.current_state.name): if finish_trigger in self.machine.get_triggers(self.current_state.name):
self.trigger(finish_trigger) self.trigger(finish_trigger)

View File

@ -14,7 +14,6 @@ from atst.database import db
from sqlalchemy_json import NestedMutableJson from sqlalchemy_json import NestedMutableJson
class Portfolio( class Portfolio(
Base, mixins.TimestampsMixin, mixins.AuditableMixin, mixins.DeletableMixin Base, mixins.TimestampsMixin, mixins.AuditableMixin, mixins.DeletableMixin
): ):
@ -43,8 +42,9 @@ class Portfolio(
primaryjoin="and_(Application.portfolio_id == Portfolio.id, Application.deleted == False)", primaryjoin="and_(Application.portfolio_id == Portfolio.id, Application.deleted == False)",
) )
state_machine = relationship("PortfolioStateMachine", state_machine = relationship(
uselist=False, back_populates="portfolio") "PortfolioStateMachine", uselist=False, back_populates="portfolio"
)
roles = relationship("PortfolioRole") roles = relationship("PortfolioRole")

View File

@ -1,5 +1,3 @@
import importlib
from sqlalchemy import Column, ForeignKey, Enum as SQLAEnum from sqlalchemy import Column, ForeignKey, Enum as SQLAEnum
from sqlalchemy.orm import relationship, reconstructor from sqlalchemy.orm import relationship, reconstructor
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
@ -13,36 +11,35 @@ from flask import current_app as app
from atst.domain.csp.cloud import ConnectionException, UnknownServerException from atst.domain.csp.cloud import ConnectionException, UnknownServerException
from atst.domain.csp import MockCSP, AzureCSP, get_stage_csp_class from atst.domain.csp import MockCSP, AzureCSP, get_stage_csp_class
from atst.database import db from atst.database import db
from atst.queue import celery
from atst.models.types import Id from atst.models.types import Id
from atst.models.base import Base from atst.models.base import Base
import atst.models.mixins as mixins import atst.models.mixins as mixins
from atst.models.mixins.state_machines import ( from atst.models.mixins.state_machines import FSMStates, AzureStages, _build_transitions
FSMStates, AzureStages, _build_transitions
)
@add_state_features(Tags) @add_state_features(Tags)
class StateMachineWithTags(Machine): class StateMachineWithTags(Machine):
pass pass
class PortfolioStateMachine( class PortfolioStateMachine(
Base, mixins.TimestampsMixin, mixins.AuditableMixin, mixins.DeletableMixin, mixins.FSMMixin, Base,
mixins.TimestampsMixin,
mixins.AuditableMixin,
mixins.DeletableMixin,
mixins.FSMMixin,
): ):
__tablename__ = "portfolio_state_machines" __tablename__ = "portfolio_state_machines"
id = Id() id = Id()
portfolio_id = Column( portfolio_id = Column(UUID(as_uuid=True), ForeignKey("portfolios.id"),)
UUID(as_uuid=True),
ForeignKey("portfolios.id"),
)
portfolio = relationship("Portfolio", back_populates="state_machine") portfolio = relationship("Portfolio", back_populates="state_machine")
state = Column( state = Column(
SQLAEnum(FSMStates, native_enum=False, create_constraint=False), SQLAEnum(FSMStates, native_enum=False, create_constraint=False),
default=FSMStates.UNSTARTED, nullable=False default=FSMStates.UNSTARTED,
nullable=False,
) )
def __init__(self, portfolio, csp=None, **kwargs): def __init__(self, portfolio, csp=None, **kwargs):
@ -60,15 +57,15 @@ class PortfolioStateMachine(
Attach a machine depending on the current state. Attach a machine depending on the current state.
""" """
self.machine = StateMachineWithTags( self.machine = StateMachineWithTags(
model = self, model=self,
send_event=True, send_event=True,
initial=self.current_state if self.state else FSMStates.UNSTARTED, initial=self.current_state if self.state else FSMStates.UNSTARTED,
auto_transitions=False, auto_transitions=False,
after_state_change='after_state_change', after_state_change="after_state_change",
) )
states, transitions = _build_transitions(AzureStages) states, transitions = _build_transitions(AzureStages)
self.machine.add_states(self.system_states+states) self.machine.add_states(self.system_states + states)
self.machine.add_transitions(self.system_transitions+transitions) self.machine.add_transitions(self.system_transitions + transitions)
@property @property
def current_state(self): def current_state(self):
@ -87,37 +84,38 @@ class PortfolioStateMachine(
elif self.current_state == FSMStates.STARTED: elif self.current_state == FSMStates.STARTED:
# get the first trigger that starts with 'create_' # get the first trigger that starts with 'create_'
create_trigger = list(filter(lambda trigger: trigger.startswith('create_'), create_trigger = list(
self.machine.get_triggers(FSMStates.STARTED.name)))[0] filter(
lambda trigger: trigger.startswith("create_"),
self.machine.get_triggers(FSMStates.STARTED.name),
)
)[0]
self.trigger(create_trigger) self.trigger(create_trigger)
elif state_obj.is_IN_PROGRESS: elif state_obj.is_IN_PROGRESS:
pass pass
#elif state_obj.is_TENANT: # elif state_obj.is_TENANT:
# pass # pass
#elif state_obj.is_BILLING_PROFILE: # elif state_obj.is_BILLING_PROFILE:
# pass # pass
# @with_payload
#@with_payload
def after_in_progress_callback(self, event): def after_in_progress_callback(self, event):
stage = self.current_state.name.split('_IN_PROGRESS')[0].lower() stage = self.current_state.name.split("_IN_PROGRESS")[0].lower()
if stage == 'tenant': if stage == "tenant":
payload = dict( payload = dict( # nosec
creds={"username": "mock-cloud", "pass": "shh"},
user_id='123',
password='123',
domain_name='123',
first_name='john',
last_name='doe',
country_code='US',
password_recovery_email_address='password@email.com'
)
elif stage == 'billing_profile':
payload = dict(
creds={"username": "mock-cloud", "pass": "shh"}, creds={"username": "mock-cloud", "pass": "shh"},
user_id="123",
password="123",
domain_name="123",
first_name="john",
last_name="doe",
country_code="US",
password_recovery_email_address="password@email.com",
) )
elif stage == "billing_profile":
payload = dict(creds={"username": "mock-cloud", "pass": "shh"},)
payload_data_cls = get_stage_csp_class(stage, "payload") payload_data_cls = get_stage_csp_class(stage, "payload")
if not payload_data_cls: if not payload_data_cls:
@ -128,7 +126,7 @@ class PortfolioStateMachine(
print(exc.json()) print(exc.json())
self.fail_stage(stage) self.fail_stage(stage)
csp = event.kwargs.get('csp') csp = event.kwargs.get("csp")
if csp is not None: if csp is not None:
self.csp = AzureCSP(app).cloud self.csp = AzureCSP(app).cloud
else: else:
@ -136,18 +134,19 @@ class PortfolioStateMachine(
for attempt in range(5): for attempt in range(5):
try: try:
response = getattr(self.csp, 'create_'+stage)(payload_data) response = getattr(self.csp, "create_" + stage)(payload_data)
except (ConnectionException, UnknownServerException) as exc: except (ConnectionException, UnknownServerException) as exc:
print('caught exception. retry', attempt) print("caught exception. retry", attempt)
continue continue
else: break else:
break
else: else:
# failed all attempts # failed all attempts
self.fail_stage(stage) self.fail_stage(stage)
if self.portfolio.csp_data is None: if self.portfolio.csp_data is None:
self.portfolio.csp_data = {} self.portfolio.csp_data = {}
self.portfolio.csp_data[stage+"_data"] = response self.portfolio.csp_data[stage + "_data"] = response
db.session.add(self.portfolio) db.session.add(self.portfolio)
db.session.commit() db.session.commit()
@ -156,12 +155,13 @@ class PortfolioStateMachine(
def is_csp_data_valid(self, event): def is_csp_data_valid(self, event):
# check portfolio csp details json field for fields # check portfolio csp details json field for fields
if self.portfolio.csp_data is None or \ if self.portfolio.csp_data is None or not isinstance(
not isinstance(self.portfolio.csp_data, dict): self.portfolio.csp_data, dict
):
return False return False
stage = self.current_state.name.split('_IN_PROGRESS')[0].lower() stage = self.current_state.name.split("_IN_PROGRESS")[0].lower()
stage_data = self.portfolio.csp_data.get(stage+"_data") stage_data = self.portfolio.csp_data.get(stage + "_data")
cls = get_stage_csp_class(stage, "result") cls = get_stage_csp_class(stage, "result")
if not cls: if not cls:
return False return False
@ -174,8 +174,7 @@ class PortfolioStateMachine(
return True return True
#print('failed condition', self.portfolio.csp_data) # print('failed condition', self.portfolio.csp_data)
@property @property
def application_id(self): def application_id(self):

View File

@ -13,26 +13,24 @@ def portfolio():
portfolio = PortfolioFactory.create() portfolio = PortfolioFactory.create()
return portfolio return portfolio
def test_fsm_creation(portfolio): def test_fsm_creation(portfolio):
sm = PortfolioStateMachineFactory.create(portfolio=portfolio) sm = PortfolioStateMachineFactory.create(portfolio=portfolio)
assert sm.portfolio assert sm.portfolio
def test_fsm_transition_start(portfolio): def test_fsm_transition_start(portfolio):
sm = PortfolioStateMachineFactory.create(portfolio=portfolio) sm = PortfolioStateMachineFactory.create(portfolio=portfolio)
assert sm.portfolio assert sm.portfolio
assert sm.state == FSMStates.UNSTARTED assert sm.state == FSMStates.UNSTARTED
# next_state does not create the trigger callbacks !!! # next_state does not create the trigger callbacks !!!
#sm.next_state() # sm.next_state()
sm.init() sm.init()
assert sm.state == FSMStates.STARTING assert sm.state == FSMStates.STARTING
sm.start() sm.start()
assert sm.state == FSMStates.STARTED assert sm.state == FSMStates.STARTED
#import ipdb;ipdb.set_trace()
sm.create_tenant(a=1, b=2) sm.create_tenant(a=1, b=2)
assert sm.state == FSMStates.TENANT_CREATED assert sm.state == FSMStates.TENANT_CREATED

View File

@ -6,6 +6,7 @@ from atst.domain.portfolios import (
Portfolios, Portfolios,
PortfolioError, PortfolioError,
PortfolioDeletionApplicationsExistError, PortfolioDeletionApplicationsExistError,
PortfolioStateMachines,
) )
from atst.domain.portfolio_roles import PortfolioRoles from atst.domain.portfolio_roles import PortfolioRoles
from atst.domain.applications import Applications from atst.domain.applications import Applications
@ -256,16 +257,16 @@ def test_for_user_does_not_include_deleted_application_roles():
) )
assert len(Portfolios.for_user(user2)) == 0 assert len(Portfolios.for_user(user2)) == 0
def test_create_state_machine(portfolio): def test_create_state_machine(portfolio):
fsm = Portfolios.create_state_machine(portfolio) fsm = PortfolioStateMachines.create(portfolio)
assert fsm assert fsm
def test_get_portfolios_pending_provisioning(session): def test_get_portfolios_pending_provisioning(session):
for x in range(5): for x in range(5):
portfolio = PortfolioFactory.create() portfolio = PortfolioFactory.create()
sm = PortfolioStateMachineFactory.create(portfolio=portfolio) sm = PortfolioStateMachineFactory.create(portfolio=portfolio)
if x == 2: sm.state = FSMStates.COMPLETED if x == 2:
sm.state = FSMStates.COMPLETED
assert len(Portfolios.get_portfolios_pending_provisioning()) == 4 assert len(Portfolios.get_portfolios_pending_provisioning()) == 4

View File

@ -343,6 +343,7 @@ class NotificationRecipientFactory(Base):
email = factory.Faker("email") email = factory.Faker("email")
class PortfolioStateMachineFactory(Base): class PortfolioStateMachineFactory(Base):
class Meta: class Meta:
model = PortfolioStateMachine model = PortfolioStateMachine
@ -352,6 +353,6 @@ class PortfolioStateMachineFactory(Base):
@classmethod @classmethod
def _create(cls, model_class, *args, **kwargs): def _create(cls, model_class, *args, **kwargs):
portfolio = kwargs.pop("portfolio", PortfolioFactory.create()) portfolio = kwargs.pop("portfolio", PortfolioFactory.create())
kwargs.update({'portfolio': portfolio}) kwargs.update({"portfolio": portfolio})
fsm = super()._create(model_class, *args, **kwargs) fsm = super()._create(model_class, *args, **kwargs)
return fsm return fsm

View File

@ -17,6 +17,8 @@ from atst.jobs import (
create_environment, create_environment,
do_provision_user, do_provision_user,
do_provision_portfolio, do_provision_portfolio,
do_create_environment,
do_create_atat_admin_user,
) )
from atst.models.utils import claim_for_update from atst.models.utils import claim_for_update
from atst.domain.exceptions import ClaimFailedException from atst.domain.exceptions import ClaimFailedException
@ -34,6 +36,7 @@ from atst.models import CSPRole, EnvironmentRole, ApplicationRoleStatus
def csp(): def csp():
return Mock(wraps=MockCloudProvider({}, with_delay=False, with_failure=False)) return Mock(wraps=MockCloudProvider({}, with_delay=False, with_failure=False))
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def portfolio(): def portfolio():
portfolio = PortfolioFactory.create() portfolio = PortfolioFactory.create()
@ -316,21 +319,28 @@ def test_do_provision_user(csp, session):
# I expect that the EnvironmentRole now has a csp_user_id # I expect that the EnvironmentRole now has a csp_user_id
assert environment_role.csp_user_id assert environment_role.csp_user_id
def test_dispatch_provision_portfolio(csp, session, portfolio, celery_app, celery_worker, monkeypatch):
def test_dispatch_provision_portfolio(
csp, session, portfolio, celery_app, celery_worker, monkeypatch
):
sm = PortfolioStateMachineFactory.create(portfolio=portfolio) sm = PortfolioStateMachineFactory.create(portfolio=portfolio)
mock = Mock() mock = Mock()
monkeypatch.setattr("atst.jobs.provision_portfolio", mock) monkeypatch.setattr("atst.jobs.provision_portfolio", mock)
dispatch_provision_portfolio.run() dispatch_provision_portfolio.run()
mock.delay.assert_called_once_with(portfolio_id=portfolio.id) mock.delay.assert_called_once_with(portfolio_id=portfolio.id)
def test_do_provision_portfolio(csp, session, portfolio): def test_do_provision_portfolio(csp, session, portfolio):
do_provision_portfolio(csp=csp, portfolio_id=portfolio.id) do_provision_portfolio(csp=csp, portfolio_id=portfolio.id)
session.refresh(portfolio) session.refresh(portfolio)
assert portfolio.state_machine assert portfolio.state_machine
def test_provision_portfolio_create_tenant(csp, session, portfolio, celery_app, celery_worker, monkeypatch):
def test_provision_portfolio_create_tenant(
csp, session, portfolio, celery_app, celery_worker, monkeypatch
):
sm = PortfolioStateMachineFactory.create(portfolio=portfolio) sm = PortfolioStateMachineFactory.create(portfolio=portfolio)
#mock = Mock() # mock = Mock()
#monkeypatch.setattr("atst.jobs.provision_portfolio", mock) # monkeypatch.setattr("atst.jobs.provision_portfolio", mock)
#dispatch_provision_portfolio.run() # dispatch_provision_portfolio.run()
#mock.delay.assert_called_once_with(portfolio_id=portfolio.id) # mock.delay.assert_called_once_with(portfolio_id=portfolio.id)