diff --git a/atst/domain/csp/__init__.py b/atst/domain/csp/__init__.py index a7dfd6ee..fc452935 100644 --- a/atst/domain/csp/__init__.py +++ b/atst/domain/csp/__init__.py @@ -32,8 +32,12 @@ def make_csp_provider(app, csp=None): else: app.csp = MockCSP(app) + def _stage_to_classname(stage): - return "".join(map(lambda word: word.capitalize(), stage.replace('_', ' ').split(" "))) + return "".join( + map(lambda word: word.capitalize(), stage.replace("_", " ").split(" ")) + ) + def get_stage_csp_class(stage, class_type): """ @@ -46,4 +50,3 @@ def get_stage_csp_class(stage, class_type): return getattr(importlib.import_module("atst.domain.csp.cloud"), cls_name) except AttributeError: print("could not import CSP Result class <%s>" % cls_name) - diff --git a/atst/domain/csp/cloud.py b/atst/domain/csp/cloud.py index 5f68045b..9f764da0 100644 --- a/atst/domain/csp/cloud.py +++ b/atst/domain/csp/cloud.py @@ -143,7 +143,7 @@ class BaselineProvisionException(GeneralCSPException): class BaseCSPPayload(BaseModel): - #{"username": "mock-cloud", "pass": "shh"} + # {"username": "mock-cloud", "pass": "shh"} creds: Dict @@ -179,6 +179,8 @@ class BillingProfileAddress(BaseModel): "postalCode": "string" }, """ + + class BillingProfileCLINBudget(BaseModel): clinBudget: Dict """ @@ -190,7 +192,10 @@ class BillingProfileCLINBudget(BaseModel): } """ -class BillingProfileCSPPayload(BaseCSPPayload, BillingProfileAddress, BillingProfileCLINBudget): + +class BillingProfileCSPPayload( + BaseCSPPayload, BillingProfileAddress, BillingProfileCLINBudget +): displayName: str poNumber: str invoiceEmailOptIn: str @@ -411,7 +416,6 @@ class MockCloudProvider(CloudProviderInterface): return {"id": self._id(), "credentials": self._auth_credentials} - def create_tenant(self, payload): """ payload is an instance of TenantCSPPayload data class @@ -432,7 +436,6 @@ class MockCloudProvider(CloudProviderInterface): "user_object_id": response["objectId"], } - def create_billing_profile(self, creds, tenant_admin_details, billing_owner_id): # call billing profile creation endpoint, specifying owner # Payload: @@ -475,7 +478,6 @@ class MockCloudProvider(CloudProviderInterface): response = {"id": "string"} return {"billing_profile_id": response["id"]} - def create_or_update_user(self, auth_credentials, user_info, csp_role_id): self._authorize(auth_credentials) @@ -655,7 +657,6 @@ class AzureCloudProvider(CloudProviderInterface): "role_name": role_assignment_id, } - def create_tenant(self, payload): # auth as SP that is allowed to create tenant? (tenant creation sp creds) # create tenant with owner details (populated from portfolio point of contact, pw is generated) diff --git a/atst/domain/portfolios/portfolios.py b/atst/domain/portfolios/portfolios.py index e7f82a0d..1254ac71 100644 --- a/atst/domain/portfolios/portfolios.py +++ b/atst/domain/portfolios/portfolios.py @@ -1,3 +1,4 @@ +from sqlalchemy import or_ from typing import List from uuid import UUID @@ -7,7 +8,14 @@ from atst.domain.authz import Authorization from atst.domain.portfolio_roles import PortfolioRoles from atst.domain.invitations import PortfolioInvitations -from atst.models import Portfolio, PortfolioStateMachine, FSMStates, Permissions, PortfolioRole, PortfolioRoleStatus +from atst.models import ( + Portfolio, + PortfolioStateMachine, + FSMStates, + Permissions, + PortfolioRole, + PortfolioRoleStatus, +) from .query import PortfoliosQuery, PortfolioStateMachinesQuery from .scopes import ScopedPortfolio @@ -21,17 +29,15 @@ class PortfolioDeletionApplicationsExistError(Exception): pass - class PortfolioStateMachines(object): - @classmethod def create(cls, portfolio, **sm_attrs): - sm_attrs.update({'portfolio': portfolio}) + sm_attrs.update({"portfolio": portfolio}) sm = PortfolioStateMachinesQuery.create(**sm_attrs) return sm -class Portfolios(object): +class Portfolios(object): @classmethod def get_or_create_state_machine(cls, portfolio): """ @@ -133,12 +139,9 @@ class Portfolios(object): PortfoliosQuery.add_and_commit(portfolio) - @classmethod def base_provision_query(cls): - return ( - db.session.query(Portfolio.id) - ) + return db.session.query(Portfolio.id) @classmethod def get_portfolios_pending_provisioning(cls) -> List[UUID]: @@ -150,19 +153,19 @@ class Portfolios(object): """ results = ( - cls.base_provision_query().\ - join(PortfolioStateMachine).\ - filter( - or_( - PortfolioStateMachine.state == FSMStates.UNSTARTED, - PortfolioStateMachine.state == FSMStates.FAILED, - PortfolioStateMachine.state == FSMStates.TENANT_CREATION_FAILED, - ) + cls.base_provision_query() + .join(PortfolioStateMachine) + .filter( + or_( + PortfolioStateMachine.state == FSMStates.UNSTARTED, + PortfolioStateMachine.state == FSMStates.FAILED, + PortfolioStateMachine.state == FSMStates.TENANT_FAILED, ) + ) ) return [id_ for id_, in results] - #db.session.query(PortfolioStateMachine).\ + # db.session.query(PortfolioStateMachine).\ # filter( # or_( # PortfolioStateMachine.state==FSMStates.UNSTARTED, diff --git a/atst/domain/portfolios/query.py b/atst/domain/portfolios/query.py index b098d9cc..0fa8f5ac 100644 --- a/atst/domain/portfolios/query.py +++ b/atst/domain/portfolios/query.py @@ -9,7 +9,8 @@ from atst.models.application_role import ( ) from atst.models.application import Application from atst.models.portfolio_state_machine import PortfolioStateMachine -#from atst.models.application import Application + +# from atst.models.application import Application class PortfolioStateMachinesQuery(Query): diff --git a/atst/jobs.py b/atst/jobs.py index 68c2fdc5..f4611a9a 100644 --- a/atst/jobs.py +++ b/atst/jobs.py @@ -8,12 +8,10 @@ from atst.models import ( EnvironmentRoleJobFailure, EnvironmentRole, PortfolioJobFailure, - FSMStates, ) from atst.domain.csp.cloud import CloudProviderInterface, GeneralCSPException from atst.domain.environments import Environments from atst.domain.portfolios import Portfolios -from atst.domain.portfolios.query import PortfolioStateMachinesQuery from atst.domain.environment_roles import EnvironmentRoles from atst.models.utils import claim_for_update @@ -29,6 +27,7 @@ class RecordPortfolioFailure(celery.Task): db.session.add(failure) db.session.commit() + class RecordEnvironmentFailure(celery.Task): def on_failure(self, exc, task_id, args, kwargs, einfo): if "environment_id" in kwargs: @@ -64,7 +63,6 @@ def send_notification_mail(recipients, subject, body): app.mailer.send(recipients, subject, body) - def do_create_environment(csp: CloudProviderInterface, environment_id=None): environment = Environments.get(environment_id) @@ -150,6 +148,7 @@ def do_provision_portfolio(csp: CloudProviderInterface, portfolio_id=None): def provision_portfolio(self, portfolio_id=None): do_work(do_provision_portfolio, self, app.csp.cloud, portfolio_id=portfolio_id) + @celery.task(bind=True, base=RecordEnvironmentFailure) def create_environment(self, environment_id=None): do_work(do_create_environment, self, app.csp.cloud, environment_id=environment_id) diff --git a/atst/models/__init__.py b/atst/models/__init__.py index 7fa7c3f8..f6c48306 100644 --- a/atst/models/__init__.py +++ b/atst/models/__init__.py @@ -7,7 +7,11 @@ from .audit_event import AuditEvent from .clin import CLIN, JEDICLINType from .environment import Environment from .environment_role import EnvironmentRole, CSPRole -from .job_failure import EnvironmentJobFailure, EnvironmentRoleJobFailure, PortfolioJobFailure +from .job_failure import ( + EnvironmentJobFailure, + EnvironmentRoleJobFailure, + PortfolioJobFailure, +) from .notification_recipient import NotificationRecipient from .permissions import Permissions from .permission_set import PermissionSet diff --git a/atst/models/job_failure.py b/atst/models/job_failure.py index 7c358f0e..7a7f010a 100644 --- a/atst/models/job_failure.py +++ b/atst/models/job_failure.py @@ -15,8 +15,8 @@ class EnvironmentRoleJobFailure(Base, mixins.JobFailureMixin): environment_role_id = Column(ForeignKey("environment_roles.id"), nullable=False) + class PortfolioJobFailure(Base, mixins.JobFailureMixin): __tablename__ = "portfolio_job_failures" portfolio_id = Column(ForeignKey("portfolios.id"), nullable=False) - diff --git a/atst/models/mixins/state_machines.py b/atst/models/mixins/state_machines.py index 493843df..bc35209d 100644 --- a/atst/models/mixins/state_machines.py +++ b/atst/models/mixins/state_machines.py @@ -1,109 +1,137 @@ from enum import Enum -from atst.database import db class StageStates(Enum): CREATED = "created" IN_PROGRESS = "in progress" FAILED = "failed" + class AzureStages(Enum): TENANT = "tenant" BILLING_PROFILE = "billing profile" ADMIN_SUBSCRIPTION = "admin subscription" + def _build_csp_states(csp_stages): states = { - 'UNSTARTED' : "unstarted", - 'STARTING' : "starting", - 'STARTED' : "started", - 'COMPLETED' : "completed", - 'FAILED' : "failed", + "UNSTARTED": "unstarted", + "STARTING": "starting", + "STARTED": "started", + "COMPLETED": "completed", + "FAILED": "failed", } for csp_stage in csp_stages: for state in StageStates: - states[csp_stage.name+"_"+state.name] = csp_stage.value+" "+state.value + states[csp_stage.name + "_" + state.name] = ( + csp_stage.value + " " + state.value + ) return states -FSMStates = Enum('FSMStates', _build_csp_states(AzureStages)) + +FSMStates = Enum("FSMStates", _build_csp_states(AzureStages)) def _build_transitions(csp_stages): transitions = [] states = [] - compose_state = lambda csp_stage, state: getattr(FSMStates, "_".join([csp_stage.name, state.name])) + compose_state = lambda csp_stage, state: getattr( + FSMStates, "_".join([csp_stage.name, state.name]) + ) for stage_i, csp_stage in enumerate(csp_stages): for state in StageStates: - states.append(dict(name=compose_state(csp_stage, state), tags=[csp_stage.name, state.name])) + states.append( + dict( + name=compose_state(csp_stage, state), + tags=[csp_stage.name, state.name], + ) + ) if state == StageStates.CREATED: if stage_i > 0: - src = compose_state(list(csp_stages)[stage_i-1] , StageStates.CREATED) + src = compose_state( + list(csp_stages)[stage_i - 1], StageStates.CREATED + ) else: src = FSMStates.STARTED transitions.append( dict( - trigger='create_'+csp_stage.name.lower(), + trigger="create_" + csp_stage.name.lower(), source=src, dest=compose_state(csp_stage, StageStates.IN_PROGRESS), - after='after_in_progress_callback', + after="after_in_progress_callback", ) ) if state == StageStates.IN_PROGRESS: transitions.append( dict( - trigger='finish_'+csp_stage.name.lower(), + trigger="finish_" + csp_stage.name.lower(), source=compose_state(csp_stage, state), dest=compose_state(csp_stage, StageStates.CREATED), - conditions=['is_csp_data_valid'], + conditions=["is_csp_data_valid"], ) ) if state == StageStates.FAILED: transitions.append( dict( - trigger='fail_'+csp_stage.name.lower(), + trigger="fail_" + csp_stage.name.lower(), source=compose_state(csp_stage, StageStates.IN_PROGRESS), dest=compose_state(csp_stage, StageStates.FAILED), ) ) return states, transitions -class FSMMixin(): + +class FSMMixin: system_states = [ - {'name': FSMStates.UNSTARTED.name, 'tags': ['system']}, - {'name': FSMStates.STARTING.name, 'tags': ['system']}, - {'name': FSMStates.STARTED.name, 'tags': ['system']}, - {'name': FSMStates.FAILED.name, 'tags': ['system']}, - {'name': FSMStates.COMPLETED.name, 'tags': ['system']}, + {"name": FSMStates.UNSTARTED.name, "tags": ["system"]}, + {"name": FSMStates.STARTING.name, "tags": ["system"]}, + {"name": FSMStates.STARTED.name, "tags": ["system"]}, + {"name": FSMStates.FAILED.name, "tags": ["system"]}, + {"name": FSMStates.COMPLETED.name, "tags": ["system"]}, ] system_transitions = [ - {'trigger': 'init', 'source': FSMStates.UNSTARTED, 'dest': FSMStates.STARTING}, - {'trigger': 'start', 'source': FSMStates.STARTING, 'dest': FSMStates.STARTED}, - {'trigger': 'reset', 'source': '*', 'dest': FSMStates.UNSTARTED}, - {'trigger': 'fail', 'source': '*', 'dest': FSMStates.FAILED,} + {"trigger": "init", "source": FSMStates.UNSTARTED, "dest": FSMStates.STARTING}, + {"trigger": "start", "source": FSMStates.STARTING, "dest": FSMStates.STARTED}, + {"trigger": "reset", "source": "*", "dest": FSMStates.UNSTARTED}, + {"trigger": "fail", "source": "*", "dest": FSMStates.FAILED,}, ] - def prepare_init(self, event): pass - def before_init(self, event): pass - def after_init(self, event): pass + def prepare_init(self, event): + pass - def prepare_start(self, event): pass - def before_start(self, event): pass - def after_start(self, event): pass + def before_init(self, event): + pass - def prepare_reset(self, event): pass - def before_reset(self, event): pass - def after_reset(self, event): pass + def after_init(self, event): + pass + + def prepare_start(self, event): + pass + + def before_start(self, event): + pass + + def after_start(self, event): + pass + + def prepare_reset(self, event): + pass + + def before_reset(self, event): + pass + + def after_reset(self, event): + pass def fail_stage(self, stage): - fail_trigger = 'fail'+stage + fail_trigger = "fail" + stage if fail_trigger in self.machine.get_triggers(self.current_state.name): self.trigger(fail_trigger) def finish_stage(self, stage): - finish_trigger = 'finish_'+stage + finish_trigger = "finish_" + stage if finish_trigger in self.machine.get_triggers(self.current_state.name): self.trigger(finish_trigger) - diff --git a/atst/models/portfolio.py b/atst/models/portfolio.py index 948bef19..f60ed8de 100644 --- a/atst/models/portfolio.py +++ b/atst/models/portfolio.py @@ -14,7 +14,6 @@ from atst.database import db from sqlalchemy_json import NestedMutableJson - class Portfolio( Base, mixins.TimestampsMixin, mixins.AuditableMixin, mixins.DeletableMixin ): @@ -43,8 +42,9 @@ class Portfolio( primaryjoin="and_(Application.portfolio_id == Portfolio.id, Application.deleted == False)", ) - state_machine = relationship("PortfolioStateMachine", - uselist=False, back_populates="portfolio") + state_machine = relationship( + "PortfolioStateMachine", uselist=False, back_populates="portfolio" + ) roles = relationship("PortfolioRole") diff --git a/atst/models/portfolio_state_machine.py b/atst/models/portfolio_state_machine.py index 13b614b3..3c934197 100644 --- a/atst/models/portfolio_state_machine.py +++ b/atst/models/portfolio_state_machine.py @@ -1,5 +1,3 @@ -import importlib - from sqlalchemy import Column, ForeignKey, Enum as SQLAEnum from sqlalchemy.orm import relationship, reconstructor from sqlalchemy.dialects.postgresql import UUID @@ -13,36 +11,35 @@ from flask import current_app as app from atst.domain.csp.cloud import ConnectionException, UnknownServerException from atst.domain.csp import MockCSP, AzureCSP, get_stage_csp_class from atst.database import db -from atst.queue import celery from atst.models.types import Id from atst.models.base import Base import atst.models.mixins as mixins -from atst.models.mixins.state_machines import ( - FSMStates, AzureStages, _build_transitions -) - +from atst.models.mixins.state_machines import FSMStates, AzureStages, _build_transitions @add_state_features(Tags) class StateMachineWithTags(Machine): pass + class PortfolioStateMachine( - Base, mixins.TimestampsMixin, mixins.AuditableMixin, mixins.DeletableMixin, mixins.FSMMixin, + Base, + mixins.TimestampsMixin, + mixins.AuditableMixin, + mixins.DeletableMixin, + mixins.FSMMixin, ): __tablename__ = "portfolio_state_machines" id = Id() - portfolio_id = Column( - UUID(as_uuid=True), - ForeignKey("portfolios.id"), - ) + portfolio_id = Column(UUID(as_uuid=True), ForeignKey("portfolios.id"),) portfolio = relationship("Portfolio", back_populates="state_machine") state = Column( SQLAEnum(FSMStates, native_enum=False, create_constraint=False), - default=FSMStates.UNSTARTED, nullable=False + default=FSMStates.UNSTARTED, + nullable=False, ) def __init__(self, portfolio, csp=None, **kwargs): @@ -60,15 +57,15 @@ class PortfolioStateMachine( Attach a machine depending on the current state. """ self.machine = StateMachineWithTags( - model = self, - send_event=True, - initial=self.current_state if self.state else FSMStates.UNSTARTED, - auto_transitions=False, - after_state_change='after_state_change', + model=self, + send_event=True, + initial=self.current_state if self.state else FSMStates.UNSTARTED, + auto_transitions=False, + after_state_change="after_state_change", ) states, transitions = _build_transitions(AzureStages) - self.machine.add_states(self.system_states+states) - self.machine.add_transitions(self.system_transitions+transitions) + self.machine.add_states(self.system_states + states) + self.machine.add_transitions(self.system_transitions + transitions) @property def current_state(self): @@ -87,37 +84,38 @@ class PortfolioStateMachine( elif self.current_state == FSMStates.STARTED: # get the first trigger that starts with 'create_' - create_trigger = list(filter(lambda trigger: trigger.startswith('create_'), - self.machine.get_triggers(FSMStates.STARTED.name)))[0] + create_trigger = list( + filter( + lambda trigger: trigger.startswith("create_"), + self.machine.get_triggers(FSMStates.STARTED.name), + ) + )[0] self.trigger(create_trigger) elif state_obj.is_IN_PROGRESS: pass - #elif state_obj.is_TENANT: + # elif state_obj.is_TENANT: # pass - #elif state_obj.is_BILLING_PROFILE: + # elif state_obj.is_BILLING_PROFILE: # pass - - #@with_payload + # @with_payload def after_in_progress_callback(self, event): - stage = self.current_state.name.split('_IN_PROGRESS')[0].lower() - if stage == 'tenant': - payload = dict( - creds={"username": "mock-cloud", "pass": "shh"}, - user_id='123', - password='123', - domain_name='123', - first_name='john', - last_name='doe', - country_code='US', - password_recovery_email_address='password@email.com' - ) - elif stage == 'billing_profile': - payload = dict( + stage = self.current_state.name.split("_IN_PROGRESS")[0].lower() + if stage == "tenant": + payload = dict( # nosec creds={"username": "mock-cloud", "pass": "shh"}, + user_id="123", + password="123", + domain_name="123", + first_name="john", + last_name="doe", + country_code="US", + password_recovery_email_address="password@email.com", ) + elif stage == "billing_profile": + payload = dict(creds={"username": "mock-cloud", "pass": "shh"},) payload_data_cls = get_stage_csp_class(stage, "payload") if not payload_data_cls: @@ -128,7 +126,7 @@ class PortfolioStateMachine( print(exc.json()) self.fail_stage(stage) - csp = event.kwargs.get('csp') + csp = event.kwargs.get("csp") if csp is not None: self.csp = AzureCSP(app).cloud else: @@ -136,18 +134,19 @@ class PortfolioStateMachine( for attempt in range(5): try: - response = getattr(self.csp, 'create_'+stage)(payload_data) + response = getattr(self.csp, "create_" + stage)(payload_data) except (ConnectionException, UnknownServerException) as exc: - print('caught exception. retry', attempt) + print("caught exception. retry", attempt) continue - else: break + else: + break else: # failed all attempts self.fail_stage(stage) if self.portfolio.csp_data is None: self.portfolio.csp_data = {} - self.portfolio.csp_data[stage+"_data"] = response + self.portfolio.csp_data[stage + "_data"] = response db.session.add(self.portfolio) db.session.commit() @@ -156,12 +155,13 @@ class PortfolioStateMachine( def is_csp_data_valid(self, event): # check portfolio csp details json field for fields - if self.portfolio.csp_data is None or \ - not isinstance(self.portfolio.csp_data, dict): + if self.portfolio.csp_data is None or not isinstance( + self.portfolio.csp_data, dict + ): return False - stage = self.current_state.name.split('_IN_PROGRESS')[0].lower() - stage_data = self.portfolio.csp_data.get(stage+"_data") + stage = self.current_state.name.split("_IN_PROGRESS")[0].lower() + stage_data = self.portfolio.csp_data.get(stage + "_data") cls = get_stage_csp_class(stage, "result") if not cls: return False @@ -174,8 +174,7 @@ class PortfolioStateMachine( return True - #print('failed condition', self.portfolio.csp_data) - + # print('failed condition', self.portfolio.csp_data) @property def application_id(self): diff --git a/tests/domain/test_portfolio_state_machine.py b/tests/domain/test_portfolio_state_machine.py index 405cfc50..0aa90867 100644 --- a/tests/domain/test_portfolio_state_machine.py +++ b/tests/domain/test_portfolio_state_machine.py @@ -13,26 +13,24 @@ def portfolio(): portfolio = PortfolioFactory.create() return portfolio + def test_fsm_creation(portfolio): sm = PortfolioStateMachineFactory.create(portfolio=portfolio) assert sm.portfolio + def test_fsm_transition_start(portfolio): sm = PortfolioStateMachineFactory.create(portfolio=portfolio) assert sm.portfolio assert sm.state == FSMStates.UNSTARTED # next_state does not create the trigger callbacks !!! - #sm.next_state() + # sm.next_state() sm.init() assert sm.state == FSMStates.STARTING sm.start() assert sm.state == FSMStates.STARTED - #import ipdb;ipdb.set_trace() sm.create_tenant(a=1, b=2) assert sm.state == FSMStates.TENANT_CREATED - - - diff --git a/tests/domain/test_portfolios.py b/tests/domain/test_portfolios.py index aaabbbc2..4a61e2cd 100644 --- a/tests/domain/test_portfolios.py +++ b/tests/domain/test_portfolios.py @@ -6,6 +6,7 @@ from atst.domain.portfolios import ( Portfolios, PortfolioError, PortfolioDeletionApplicationsExistError, + PortfolioStateMachines, ) from atst.domain.portfolio_roles import PortfolioRoles from atst.domain.applications import Applications @@ -256,16 +257,16 @@ def test_for_user_does_not_include_deleted_application_roles(): ) assert len(Portfolios.for_user(user2)) == 0 + def test_create_state_machine(portfolio): - fsm = Portfolios.create_state_machine(portfolio) + fsm = PortfolioStateMachines.create(portfolio) assert fsm + def test_get_portfolios_pending_provisioning(session): for x in range(5): portfolio = PortfolioFactory.create() sm = PortfolioStateMachineFactory.create(portfolio=portfolio) - if x == 2: sm.state = FSMStates.COMPLETED + if x == 2: + sm.state = FSMStates.COMPLETED assert len(Portfolios.get_portfolios_pending_provisioning()) == 4 - - - diff --git a/tests/factories.py b/tests/factories.py index c4b48d20..e47aa897 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -343,6 +343,7 @@ class NotificationRecipientFactory(Base): email = factory.Faker("email") + class PortfolioStateMachineFactory(Base): class Meta: model = PortfolioStateMachine @@ -352,6 +353,6 @@ class PortfolioStateMachineFactory(Base): @classmethod def _create(cls, model_class, *args, **kwargs): portfolio = kwargs.pop("portfolio", PortfolioFactory.create()) - kwargs.update({'portfolio': portfolio}) + kwargs.update({"portfolio": portfolio}) fsm = super()._create(model_class, *args, **kwargs) return fsm diff --git a/tests/test_jobs.py b/tests/test_jobs.py index dcecdc06..ff8e4602 100644 --- a/tests/test_jobs.py +++ b/tests/test_jobs.py @@ -17,6 +17,8 @@ from atst.jobs import ( create_environment, do_provision_user, do_provision_portfolio, + do_create_environment, + do_create_atat_admin_user, ) from atst.models.utils import claim_for_update from atst.domain.exceptions import ClaimFailedException @@ -34,6 +36,7 @@ from atst.models import CSPRole, EnvironmentRole, ApplicationRoleStatus def csp(): return Mock(wraps=MockCloudProvider({}, with_delay=False, with_failure=False)) + @pytest.fixture(scope="function") def portfolio(): portfolio = PortfolioFactory.create() @@ -316,21 +319,28 @@ def test_do_provision_user(csp, session): # I expect that the EnvironmentRole now has a csp_user_id assert environment_role.csp_user_id -def test_dispatch_provision_portfolio(csp, session, portfolio, celery_app, celery_worker, monkeypatch): + +def test_dispatch_provision_portfolio( + csp, session, portfolio, celery_app, celery_worker, monkeypatch +): sm = PortfolioStateMachineFactory.create(portfolio=portfolio) mock = Mock() monkeypatch.setattr("atst.jobs.provision_portfolio", mock) dispatch_provision_portfolio.run() mock.delay.assert_called_once_with(portfolio_id=portfolio.id) + def test_do_provision_portfolio(csp, session, portfolio): do_provision_portfolio(csp=csp, portfolio_id=portfolio.id) session.refresh(portfolio) assert portfolio.state_machine -def test_provision_portfolio_create_tenant(csp, session, portfolio, celery_app, celery_worker, monkeypatch): + +def test_provision_portfolio_create_tenant( + csp, session, portfolio, celery_app, celery_worker, monkeypatch +): sm = PortfolioStateMachineFactory.create(portfolio=portfolio) - #mock = Mock() - #monkeypatch.setattr("atst.jobs.provision_portfolio", mock) - #dispatch_provision_portfolio.run() - #mock.delay.assert_called_once_with(portfolio_id=portfolio.id) + # mock = Mock() + # monkeypatch.setattr("atst.jobs.provision_portfolio", mock) + # dispatch_provision_portfolio.run() + # mock.delay.assert_called_once_with(portfolio_id=portfolio.id)