state machine integration wip

This commit is contained in:
tomdds 2020-01-16 13:44:10 -05:00
parent 187ee0033e
commit b1adaf771d
6 changed files with 204 additions and 59 deletions

View File

@ -34,9 +34,7 @@ def make_csp_provider(app, csp=None):
def _stage_to_classname(stage): def _stage_to_classname(stage):
return "".join( return "".join(map(lambda word: word.capitalize(), stage.split("_")))
map(lambda word: word.capitalize(), stage.replace("_", " ").split(" "))
)
def get_stage_csp_class(stage, class_type): def get_stage_csp_class(stage, class_type):
@ -45,7 +43,7 @@ def get_stage_csp_class(stage, class_type):
class_type is either 'payload' or 'result' class_type is either 'payload' or 'result'
""" """
cls_name = "".join([_stage_to_classname(stage), "CSP", class_type.capitalize()]) cls_name = f"{_stage_to_classname(stage)}CSP{class_type.capitalize()}"
try: try:
return getattr(importlib.import_module("atst.domain.csp.cloud"), cls_name) return getattr(importlib.import_module("atst.domain.csp.cloud"), cls_name)
except AttributeError: except AttributeError:

View File

@ -186,11 +186,29 @@ class TenantCSPResult(AliasModel):
tenant_id: str tenant_id: str
user_object_id: str user_object_id: str
tenant_admin_username: str
tenant_admin_password: str
class Config: class Config:
fields = { fields = {
"user_object_id": "objectId", "user_object_id": "objectId",
} }
def dict(self, *args, **kwargs):
exclude = {"tenant_admin_username", "tenant_admin_password"}
if "exclude" not in kwargs:
kwargs["exclude"] = exclude
else:
kwargs["exclude"].update(exclude)
return super().dict(*args, **kwargs)
def get_creds(self):
return {
"tenant_admin_username": self.tenant_admin_username,
"tenant_admin_password": self.tenant_admin_password,
"tenant_id": self.tenant_id
}
class BillingProfileAddress(AliasModel): class BillingProfileAddress(AliasModel):
company_name: str company_name: str
@ -215,7 +233,7 @@ class BillingProfileCLINBudget(AliasModel):
class BillingProfileCSPPayload(BaseCSPPayload): class BillingProfileCSPPayload(BaseCSPPayload):
tenant_id: str tenant_id: str
display_name: str billing_profile_display_name: str
enabled_azure_plans: Optional[List[str]] enabled_azure_plans: Optional[List[str]]
address: BillingProfileAddress address: BillingProfileAddress
@ -229,6 +247,11 @@ class BillingProfileCSPPayload(BaseCSPPayload):
""" """
return v or [] return v or []
class Config:
fields = {
"billing_profile_display_name": "displayName"
}
class BillingProfileCreateCSPResult(AliasModel): class BillingProfileCreateCSPResult(AliasModel):
billing_profile_validate_url: str billing_profile_validate_url: str
@ -252,9 +275,14 @@ class BillingInvoiceSection(AliasModel):
class BillingProfileProperties(AliasModel): class BillingProfileProperties(AliasModel):
address: BillingProfileAddress address: BillingProfileAddress
display_name: str billing_profile_display_name: str
invoice_sections: List[BillingInvoiceSection] invoice_sections: List[BillingInvoiceSection]
class Config:
fields = {
"billing_profile_display_name": "displayName"
}
class BillingProfileCSPResult(AliasModel): class BillingProfileCSPResult(AliasModel):
billing_profile_id: str billing_profile_id: str
@ -269,14 +297,14 @@ class BillingProfileCSPResult(AliasModel):
} }
class BillingRoleAssignmentCSPPayload(BaseCSPPayload): class BillingProfileTenantAccessCSPPayload(BaseCSPPayload):
tenant_id: str tenant_id: str
user_object_id: str user_object_id: str
billing_account_name: str billing_account_name: str
billing_profile_name: str billing_profile_name: str
class BillingRoleAssignmentCSPResult(AliasModel): class BillingProfileTenantAccessCSPResult(AliasModel):
billing_role_assignment_id: str billing_role_assignment_id: str
billing_role_assignment_name: str billing_role_assignment_name: str
@ -286,7 +314,7 @@ class BillingRoleAssignmentCSPResult(AliasModel):
"billing_role_assignment_name": "name", "billing_role_assignment_name": "name",
} }
class EnableTaskOrderBillingCSPPayload(BaseCSPPayload): class TaskOrderBillingCSPPayload(BaseCSPPayload):
billing_account_name: str billing_account_name: str
billing_profile_name: str billing_profile_name: str
@ -297,14 +325,14 @@ class EnableTaskOrderBillingCSPResult(AliasModel):
class Config: class Config:
fields = {"task_order_billing_validation_url": "Location", "retry_after": "Retry-After"} fields = {"task_order_billing_validation_url": "Location", "retry_after": "Retry-After"}
class VerifyTaskOrderBillingCSPPayload(BaseCSPPayload): class TaskOrderBillingCSPResult(BaseCSPPayload):
task_order_billing_validation_url: str task_order_billing_validation_url: str
class BillingProfileEnabledPlanDetails(AliasModel): class BillingProfileEnabledPlanDetails(AliasModel):
enabled_azure_plans: List[Dict] enabled_azure_plans: List[Dict]
class BillingProfileEnabledCSPResult(AliasModel): class TaskOrderBillingCSPResult(AliasModel):
billing_profile_id: str billing_profile_id: str
billing_profile_name: str billing_profile_name: str
billing_profile_enabled_plan_details: BillingProfileEnabledPlanDetails billing_profile_enabled_plan_details: BillingProfileEnabledPlanDetails
@ -534,9 +562,11 @@ class MockCloudProvider(CloudProviderInterface):
"tenant_id": response["tenantId"], "tenant_id": response["tenantId"],
"user_id": response["userId"], "user_id": response["userId"],
"user_object_id": response["objectId"], "user_object_id": response["objectId"],
"tenant_admin_username": "test",
"tenant_admin_password": "test"
} }
def create_billing_profile(self, creds, tenant_admin_details, billing_owner_id): def create_billing_profile(self, payload):
# call billing profile creation endpoint, specifying owner # call billing profile creation endpoint, specifying owner
# Payload: # Payload:
""" """
@ -576,7 +606,55 @@ class MockCloudProvider(CloudProviderInterface):
self._maybe_raise(self.UNAUTHORIZED_RATE, self.AUTHORIZATION_EXCEPTION) self._maybe_raise(self.UNAUTHORIZED_RATE, self.AUTHORIZATION_EXCEPTION)
response = {"id": "string"} response = {"id": "string"}
return {"billing_profile_id": response["id"]} # return {"billing_profile_id": response["id"]}
return {
'id': '/providers/Microsoft.Billing/billingAccounts/7c89b735-b22b-55c0-ab5a-c624843e8bf6:de4416ce-acc6-44b1-8122-c87c4e903c91_2019-05-31/billingProfiles/KQWI-W2SU-BG7-TGB',
'name': 'KQWI-W2SU-BG7-TGB',
'properties': {
'address': {
'addressLine1': '123 S Broad Street, Suite 2400',
'city': 'Philadelphia',
'companyName': 'Promptworks',
'country': 'US',
'postalCode': '19109',
'region': 'PA'
},
'currency': 'USD',
'displayName': 'Test Billing Profile',
'enabledAzurePlans': [],
'hasReadAccess': True,
'invoiceDay': 5,
'invoiceEmailOptIn': False,
'invoiceSections': [{
'id': '/providers/Microsoft.Billing/billingAccounts/7c89b735-b22b-55c0-ab5a-c624843e8bf6:de4416ce-acc6-44b1-8122-c87c4e903c91_2019-05-31/billingProfiles/KQWI-W2SU-BG7-TGB/invoiceSections/CHCO-BAAR-PJA-TGB',
'name': 'CHCO-BAAR-PJA-TGB',
'properties': {
'displayName': 'Test Billing Profile'
},
'type': 'Microsoft.Billing/billingAccounts/billingProfiles/invoiceSections'
}]
},
'type': 'Microsoft.Billing/billingAccounts/billingProfiles'
}
def create_billing_profile_tenant_access(self, payload):
self._maybe_raise(self.NETWORK_FAILURE_PCT, self.NETWORK_EXCEPTION)
self._maybe_raise(self.SERVER_FAILURE_PCT, self.SERVER_EXCEPTION)
self._maybe_raise(self.UNAUTHORIZED_RATE, self.AUTHORIZATION_EXCEPTION)
return {
"id": "/providers/Microsoft.Billing/billingAccounts/7c89b735-b22b-55c0-ab5a-c624843e8bf6:de4416ce-acc6-44b1-8122-c87c4e903c91_2019-05-31/billingProfiles/KQWI-W2SU-BG7-TGB/billingRoleAssignments/40000000-aaaa-bbbb-cccc-100000000000_0a5f4926-e3ee-4f47-a6e3-8b0a30a40e3d",
"name": "40000000-aaaa-bbbb-cccc-100000000000_0a5f4926-e3ee-4f47-a6e3-8b0a30a40e3d",
"properties": {
"createdOn": "2020-01-14T14:39:26.3342192+00:00",
"createdByPrincipalId": "82e2b376-3297-4096-8743-ed65b3be0b03",
"principalId": "0a5f4926-e3ee-4f47-a6e3-8b0a30a40e3d",
"principalTenantId": "60ff9d34-82bf-4f21-b565-308ef0533435",
"roleDefinitionId": "/providers/Microsoft.Billing/billingAccounts/7c89b735-b22b-55c0-ab5a-c624843e8bf6:de4416ce-acc6-44b1-8122-c87c4e903c91_2019-05-31/billingProfiles/KQWI-W2SU-BG7-TGB/billingRoleDefinitions/40000000-aaaa-bbbb-cccc-100000000000",
"scope": "/providers/Microsoft.Billing/billingAccounts/7c89b735-b22b-55c0-ab5a-c624843e8bf6:de4416ce-acc6-44b1-8122-c87c4e903c91_2019-05-31/billingProfiles/KQWI-W2SU-BG7-TGB"
},
"type": "Microsoft.Billing/billingRoleAssignments"
}
def create_or_update_user(self, auth_credentials, user_info, csp_role_id): def create_or_update_user(self, auth_credentials, user_info, csp_role_id):
self._authorize(auth_credentials) self._authorize(auth_credentials)
@ -633,7 +711,7 @@ class MockCloudProvider(CloudProviderInterface):
@property @property
def _auth_credentials(self): def _auth_credentials(self):
return {"username": "mock-cloud", "pass": "shh"} return {"username": "mock-cloud", "password": "shh"}
def _authorize(self, credentials): def _authorize(self, credentials):
self._delay(1, 5) self._delay(1, 5)
@ -778,6 +856,9 @@ class AzureCloudProvider(CloudProviderInterface):
headers=create_tenant_headers, headers=create_tenant_headers,
) )
print('create tenant result')
print(result.json())
if result.status_code == 200: if result.status_code == 200:
return self._ok(TenantCSPResult(**result.json())) return self._ok(TenantCSPResult(**result.json()))
else: else:
@ -836,7 +917,7 @@ class AzureCloudProvider(CloudProviderInterface):
else: else:
return self._error(result.json()) return self._error(result.json())
def grant_billing_profile_tenant_access(self, payload: BillingRoleAssignmentCSPPayload): def create_billing_profile_tenant_access(self, payload: BillingProfileTenantAccessCSPPayload):
sp_token = self._get_sp_token(payload.creds) sp_token = self._get_sp_token(payload.creds)
request_body = { request_body = {
"properties": { "properties": {
@ -854,11 +935,11 @@ class AzureCloudProvider(CloudProviderInterface):
result = self.sdk.requests.post(url, headers=headers, json=request_body) result = self.sdk.requests.post(url, headers=headers, json=request_body)
if result.status_code == 201: if result.status_code == 201:
return self._ok(BillingRoleAssignmentCSPResult(**result.json())) return self._ok(BillingProfileTenantAccessCSPResult(**result.json()))
else: else:
return self._error(result.json()) return self._error(result.json())
def enable_task_order_billing(self, payload: EnableTaskOrderBillingCSPPayload): def enable_task_order_billing(self, payload: TaskOrderBillingCSPPayload):
sp_token = self._get_sp_token(payload.creds) sp_token = self._get_sp_token(payload.creds)
request_body = [ request_body = [
{ {
@ -884,7 +965,7 @@ class AzureCloudProvider(CloudProviderInterface):
# 202 has location/retry after headers # 202 has location/retry after headers
return self._ok(BillingProfileCreateCSPResult(**result.headers)) return self._ok(BillingProfileCreateCSPResult(**result.headers))
elif result.status_code == 200: elif result.status_code == 200:
return self._ok(BillingProfileEnabledCSPResult(**result.json())) return self._ok(TaskOrderBillingCSPResult(**result.json()))
else: else:
return self._error(result.json()) return self._error(result.json())
@ -903,13 +984,13 @@ class AzureCloudProvider(CloudProviderInterface):
if result.status_code == 202: if result.status_code == 202:
# 202 has location/retry after headers # 202 has location/retry after headers
return self._ok(EnableTaskOrderBillingCSPResult(**result.headers)) return self._ok(TaskOrderBillingCSPResult(**result.headers))
elif result.status_code == 200: elif result.status_code == 200:
return self._ok(BillingProfileEnabledCSPResult(**result.json())) return self._ok(TaskOrderBillingCSPResult(**result.json()))
else: else:
return self._error(result.json()) return self._error(result.json())
def report_clin(self, payload: ReportCLINCSPPayload): def create_billing_instruction(self, payload: ReportCLINCSPPayload):
sp_token = self._get_sp_token(payload.creds) sp_token = self._get_sp_token(payload.creds)
if sp_token is None: if sp_token is None:
raise AuthenticationException( raise AuthenticationException(
@ -1022,7 +1103,9 @@ class AzureCloudProvider(CloudProviderInterface):
return sub_id_match.group(1) return sub_id_match.group(1)
def _get_sp_token(self, creds): def _get_sp_token(self, creds):
home_tenant_id = creds.get("home_tenant_id") home_tenant_id = creds.get(
"home_tenant_id"
)
client_id = creds.get("client_id") client_id = creds.get("client_id")
secret_key = creds.get("secret_key") secret_key = creds.get("secret_key")

View File

@ -10,6 +10,9 @@ class StageStates(Enum):
class AzureStages(Enum): class AzureStages(Enum):
TENANT = "tenant" TENANT = "tenant"
BILLING_PROFILE = "billing profile" BILLING_PROFILE = "billing profile"
BILLING_PROFILE_TENANT_ACCESS = "billing profile tenant access"
TASK_ORDER_BILLING = "task order billing"
BILLING_INSTRUCTION = "billing instruction"
def _build_csp_states(csp_stages): def _build_csp_states(csp_stages):

View File

@ -1,3 +1,7 @@
from random import choice, choices
import re
import string
from sqlalchemy import Column, ForeignKey, Enum as SQLAEnum from sqlalchemy import Column, ForeignKey, Enum as SQLAEnum
from sqlalchemy.orm import relationship, reconstructor from sqlalchemy.orm import relationship, reconstructor
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
@ -17,6 +21,16 @@ import atst.models.mixins as mixins
from atst.models.mixins.state_machines import FSMStates, AzureStages, _build_transitions from atst.models.mixins.state_machines import FSMStates, AzureStages, _build_transitions
def make_password():
return choice(string.ascii_letters) + "".join(
choices(string.ascii_letters + string.digits + string.punctuation, k=15)
)
def fetch_portfolio_creds(portfolio):
return dict(username="mock-cloud", password="shh")
@add_state_features(Tags) @add_state_features(Tags)
class StateMachineWithTags(Machine): class StateMachineWithTags(Machine):
pass pass
@ -73,57 +87,49 @@ class PortfolioStateMachine(
return getattr(FSMStates, self.state) return getattr(FSMStates, self.state)
return self.state return self.state
def trigger_next_transition(self): def trigger_next_transition(self, **kwargs):
state_obj = self.machine.get_state(self.state) state_obj = self.machine.get_state(self.state)
if state_obj.is_system: if state_obj.is_system:
if self.current_state in (FSMStates.UNSTARTED, FSMStates.STARTING): if self.current_state in (FSMStates.UNSTARTED, FSMStates.STARTING):
# call the first trigger availabe for these two system states # call the first trigger availabe for these two system states
trigger_name = self.machine.get_triggers(self.current_state.name)[0] trigger_name = self.machine.get_triggers(self.current_state.name)[0]
self.trigger(trigger_name) self.trigger(trigger_name, **kwargs)
elif self.current_state == FSMStates.STARTED: elif self.current_state == FSMStates.STARTED:
# get the first trigger that starts with 'create_' # get the first trigger that starts with 'create_'
create_trigger = self._get_first_stage_create_trigger() create_trigger = self._get_first_stage_create_trigger()
if create_trigger: if create_trigger:
self.trigger(create_trigger) self.trigger(create_trigger, **kwargs)
else: else:
self.fail_stage(stage) self.fail_stage(stage)
elif state_obj.is_IN_PROGRESS: elif state_obj.is_CREATED:
pass triggers = self.machine.get_triggers(state_obj.name)
self.trigger(triggers[-1], **kwargs)
# elif state_obj.is_TENANT:
# pass
# elif state_obj.is_BILLING_PROFILE:
# pass
# @with_payload # @with_payload
def after_in_progress_callback(self, event): def after_in_progress_callback(self, event):
stage = self.current_state.name.split("_IN_PROGRESS")[0].lower() stage = self.current_state.name.split("_IN_PROGRESS")[0].lower()
if stage == "tenant":
payload = dict( # nosec # Accumulate payload w/ creds
creds={"username": "mock-cloud", "pass": "shh"}, payload = event.kwargs.get("csp_data")
user_id="123", payload["creds"] = event.kwargs.get("creds")
password="123",
domain_name="123",
first_name="john",
last_name="doe",
country_code="US",
password_recovery_email_address="password@email.com",
)
elif stage == "billing_profile":
payload = dict(creds={"username": "mock-cloud", "pass": "shh"},)
payload_data_cls = get_stage_csp_class(stage, "payload") payload_data_cls = get_stage_csp_class(stage, "payload")
if not payload_data_cls: if not payload_data_cls:
print("could not resolve payload data class")
self.fail_stage(stage) self.fail_stage(stage)
try: try:
payload_data = payload_data_cls(**payload) payload_data = payload_data_cls(**payload)
except PydanticValidationError as exc: except PydanticValidationError as exc:
print("Payload Validation Error:")
print(exc.json()) print(exc.json())
print("got")
print(payload)
self.fail_stage(stage) self.fail_stage(stage)
# TODO: Determine best place to do this, maybe @reconstructor
csp = event.kwargs.get("csp") csp = event.kwargs.get("csp")
if csp is not None: if csp is not None:
self.csp = AzureCSP(app).cloud self.csp = AzureCSP(app).cloud
@ -132,7 +138,8 @@ class PortfolioStateMachine(
for attempt in range(5): for attempt in range(5):
try: try:
response = getattr(self.csp, "create_" + stage)(payload_data) func_name = f"create_{stage}"
response = getattr(self.csp, func_name)(payload_data)
except (ConnectionException, UnknownServerException) as exc: except (ConnectionException, UnknownServerException) as exc:
print("caught exception. retry", attempt) print("caught exception. retry", attempt)
continue continue
@ -140,14 +147,17 @@ class PortfolioStateMachine(
break break
else: else:
# failed all attempts # failed all attempts
print("failed")
self.fail_stage(stage) self.fail_stage(stage)
if self.portfolio.csp_data is None: if self.portfolio.csp_data is None:
self.portfolio.csp_data = {} self.portfolio.csp_data = {}
self.portfolio.csp_data[stage + "_data"] = response self.portfolio.csp_data.update(response)
db.session.add(self.portfolio) db.session.add(self.portfolio)
db.session.commit() db.session.commit()
# store any updated creds, if necessary
self.finish_stage(stage) self.finish_stage(stage)
def is_csp_data_valid(self, event): def is_csp_data_valid(self, event):
@ -156,16 +166,23 @@ class PortfolioStateMachine(
if self.portfolio.csp_data is None or not isinstance( if self.portfolio.csp_data is None or not isinstance(
self.portfolio.csp_data, dict self.portfolio.csp_data, dict
): ):
print("no csp data")
return False return False
stage = self.current_state.name.split("_IN_PROGRESS")[0].lower() stage = self.current_state.name.split("_IN_PROGRESS")[0].lower()
stage_data = self.portfolio.csp_data.get(stage + "_data") stage_data = self.portfolio.csp_data
cls = get_stage_csp_class(stage, "result") cls = get_stage_csp_class(stage, "result")
if not cls: if not cls:
return False return False
try: try:
cls(**stage_data) dc = cls(**stage_data)
if getattr(dc, "get_creds", None) is not None:
new_creds = dc.get_creds()
# TODO: how/where to store these
# TODO: credential schema
# self.store_creds(self.portfolio, new_creds)
except PydanticValidationError as exc: except PydanticValidationError as exc:
print(exc.json()) print(exc.json())
return False return False

View File

@ -153,7 +153,7 @@ def test_create_tenant(mock_azure: AzureCloudProvider):
mock_azure.sdk.requests.post.return_value = mock_result mock_azure.sdk.requests.post.return_value = mock_result
payload = TenantCSPPayload( payload = TenantCSPPayload(
**dict( **dict(
creds={"username": "mock-cloud", "pass": "shh"}, creds={"username": "mock-cloud", "password": "shh"},
user_id="admin", user_id="admin",
password="JediJan13$coot", password="JediJan13$coot",
domain_name="jediccpospawnedtenant2", domain_name="jediccpospawnedtenant2",
@ -190,7 +190,7 @@ def test_create_billing_profile(mock_azure: AzureCloudProvider):
country="US", country="US",
postal_code="19109", postal_code="19109",
), ),
creds={"username": "mock-cloud", "pass": "shh"}, creds={"username": "mock-cloud", "password": "shh"},
tenant_id="60ff9d34-82bf-4f21-b565-308ef0533435", tenant_id="60ff9d34-82bf-4f21-b565-308ef0533435",
display_name="Test Billing Profile", display_name="Test Billing Profile",
) )
@ -258,7 +258,7 @@ def test_validate_billing_profile_creation(mock_azure: AzureCloudProvider):
) )
def test_grant_billing_profile_tenant_access(mock_azure: AzureCloudProvider): def test_create_billing_profile_tenant_access(mock_azure: AzureCloudProvider):
mock_azure.sdk.adal.AuthenticationContext.return_value.context.acquire_token_with_client_credentials.return_value = { mock_azure.sdk.adal.AuthenticationContext.return_value.context.acquire_token_with_client_credentials.return_value = {
"accessToken": "TOKEN" "accessToken": "TOKEN"
} }
@ -295,7 +295,7 @@ def test_grant_billing_profile_tenant_access(mock_azure: AzureCloudProvider):
) )
) )
result = mock_azure.grant_billing_profile_tenant_access(payload) result = mock_azure.create_billing_profile_tenant_access(payload)
body: BillingRoleAssignmentCSPResult = result.get("body") body: BillingRoleAssignmentCSPResult = result.get("body")
assert ( assert (
body.billing_role_assignment_name body.billing_role_assignment_name
@ -303,7 +303,7 @@ def test_grant_billing_profile_tenant_access(mock_azure: AzureCloudProvider):
) )
def test_enable_task_order_billing(mock_azure: AzureCloudProvider): def test_create_task_order_billing(mock_azure: AzureCloudProvider):
mock_azure.sdk.adal.AuthenticationContext.return_value.context.acquire_token_with_client_credentials.return_value = { mock_azure.sdk.adal.AuthenticationContext.return_value.context.acquire_token_with_client_credentials.return_value = {
"accessToken": "TOKEN" "accessToken": "TOKEN"
} }
@ -401,7 +401,7 @@ def test_validate_task_order_billing_enabled(mock_azure):
) )
def test_report_clin(mock_azure: AzureCloudProvider): def test_create_billing_instruction(mock_azure: AzureCloudProvider):
mock_azure.sdk.adal.AuthenticationContext.return_value.context.acquire_token_with_client_credentials.return_value = { mock_azure.sdk.adal.AuthenticationContext.return_value.context.acquire_token_with_client_credentials.return_value = {
"accessToken": "TOKEN" "accessToken": "TOKEN"
} }
@ -432,7 +432,7 @@ def test_report_clin(mock_azure: AzureCloudProvider):
billing_profile_name="KQWI-W2SU-BG7-TGB", billing_profile_name="KQWI-W2SU-BG7-TGB",
) )
) )
result = mock_azure.report_clin(payload) result = mock_azure.create_billing_instruction(payload)
body: ReportCLINCSPResult = result.get("body") body: ReportCLINCSPResult = result.get("body")
assert body.reported_clin_name == "TO1:CLIN001" assert body.reported_clin_name == "TO1:CLIN001"

View File

@ -1,11 +1,12 @@
import pytest import pytest
import re
from tests.factories import ( from tests.factories import (
PortfolioFactory, PortfolioFactory,
PortfolioStateMachineFactory, PortfolioStateMachineFactory,
) )
from atst.models import FSMStates from atst.models import FSMStates, PortfolioStateMachine
from atst.models.mixins.state_machines import AzureStages, StageStates, compose_state from atst.models.mixins.state_machines import AzureStages, StageStates, compose_state
from atst.domain.csp import get_stage_csp_class from atst.domain.csp import get_stage_csp_class
@ -78,7 +79,7 @@ def test_state_machine_initialization(portfolio):
def test_fsm_transition_start(portfolio): def test_fsm_transition_start(portfolio):
sm = PortfolioStateMachineFactory.create(portfolio=portfolio) sm: PortfolioStateMachine = PortfolioStateMachineFactory.create(portfolio=portfolio)
assert sm.portfolio assert sm.portfolio
assert sm.state == FSMStates.UNSTARTED assert sm.state == FSMStates.UNSTARTED
@ -87,5 +88,48 @@ def test_fsm_transition_start(portfolio):
sm.start() sm.start()
assert sm.state == FSMStates.STARTED assert sm.state == FSMStates.STARTED
sm.create_tenant(a=1, b=2)
# Should source all creds for portfolio? might be easier to manage than per-step specific ones
creds = {"username": "mock-cloud", "password": "shh"}
if portfolio.csp_data is not None:
csp_data = portfolio.csp_data
else:
csp_data = {}
ppoc = portfolio.owner
user_id = f"{ppoc.first_name[0]}{ppoc.last_name}".lower()
domain_name = re.sub("[^0-9a-zA-Z]+", "", portfolio.name).lower()
portfolio_data = {
"user_id": user_id,
"password": "jklfsdNCVD83nklds2#202",
"domain_name": domain_name,
"first_name": ppoc.first_name,
"last_name": ppoc.last_name,
"country_code": "US",
"password_recovery_email_address": ppoc.email,
"address": {
"company_name": "",
"address_line_1": "",
"city": "",
"region": "",
"country": "",
"postal_code": "",
},
"billing_profile_display_name": "My Billing Profile",
}
collected_data = dict(list(csp_data.items()) + list(portfolio_data.items()))
sm.trigger_next_transition(creds=creds, csp_data=collected_data)
assert sm.state == FSMStates.TENANT_CREATED assert sm.state == FSMStates.TENANT_CREATED
assert portfolio.csp_data.get("tenant_id", None) is not None
if portfolio.csp_data is not None:
csp_data = portfolio.csp_data
else:
csp_data = {}
collected_data = dict(list(csp_data.items()) + list(portfolio_data.items()))
sm.trigger_next_transition(creds=creds, csp_data=collected_data)
assert sm.state == FSMStates.BILLING_PROFILE_CREATED