state machine integration wip

This commit is contained in:
tomdds 2020-01-16 13:44:10 -05:00
parent 187ee0033e
commit b1adaf771d
6 changed files with 204 additions and 59 deletions

View File

@ -34,9 +34,7 @@ def make_csp_provider(app, csp=None):
def _stage_to_classname(stage):
return "".join(
map(lambda word: word.capitalize(), stage.replace("_", " ").split(" "))
)
return "".join(map(lambda word: word.capitalize(), stage.split("_")))
def get_stage_csp_class(stage, class_type):
@ -45,7 +43,7 @@ def get_stage_csp_class(stage, class_type):
class_type is either 'payload' or 'result'
"""
cls_name = "".join([_stage_to_classname(stage), "CSP", class_type.capitalize()])
cls_name = f"{_stage_to_classname(stage)}CSP{class_type.capitalize()}"
try:
return getattr(importlib.import_module("atst.domain.csp.cloud"), cls_name)
except AttributeError:

View File

@ -186,11 +186,29 @@ class TenantCSPResult(AliasModel):
tenant_id: str
user_object_id: str
tenant_admin_username: str
tenant_admin_password: str
class Config:
fields = {
"user_object_id": "objectId",
}
def dict(self, *args, **kwargs):
exclude = {"tenant_admin_username", "tenant_admin_password"}
if "exclude" not in kwargs:
kwargs["exclude"] = exclude
else:
kwargs["exclude"].update(exclude)
return super().dict(*args, **kwargs)
def get_creds(self):
return {
"tenant_admin_username": self.tenant_admin_username,
"tenant_admin_password": self.tenant_admin_password,
"tenant_id": self.tenant_id
}
class BillingProfileAddress(AliasModel):
company_name: str
@ -215,7 +233,7 @@ class BillingProfileCLINBudget(AliasModel):
class BillingProfileCSPPayload(BaseCSPPayload):
tenant_id: str
display_name: str
billing_profile_display_name: str
enabled_azure_plans: Optional[List[str]]
address: BillingProfileAddress
@ -229,6 +247,11 @@ class BillingProfileCSPPayload(BaseCSPPayload):
"""
return v or []
class Config:
fields = {
"billing_profile_display_name": "displayName"
}
class BillingProfileCreateCSPResult(AliasModel):
billing_profile_validate_url: str
@ -252,9 +275,14 @@ class BillingInvoiceSection(AliasModel):
class BillingProfileProperties(AliasModel):
address: BillingProfileAddress
display_name: str
billing_profile_display_name: str
invoice_sections: List[BillingInvoiceSection]
class Config:
fields = {
"billing_profile_display_name": "displayName"
}
class BillingProfileCSPResult(AliasModel):
billing_profile_id: str
@ -269,14 +297,14 @@ class BillingProfileCSPResult(AliasModel):
}
class BillingRoleAssignmentCSPPayload(BaseCSPPayload):
class BillingProfileTenantAccessCSPPayload(BaseCSPPayload):
tenant_id: str
user_object_id: str
billing_account_name: str
billing_profile_name: str
class BillingRoleAssignmentCSPResult(AliasModel):
class BillingProfileTenantAccessCSPResult(AliasModel):
billing_role_assignment_id: str
billing_role_assignment_name: str
@ -286,7 +314,7 @@ class BillingRoleAssignmentCSPResult(AliasModel):
"billing_role_assignment_name": "name",
}
class EnableTaskOrderBillingCSPPayload(BaseCSPPayload):
class TaskOrderBillingCSPPayload(BaseCSPPayload):
billing_account_name: str
billing_profile_name: str
@ -297,14 +325,14 @@ class EnableTaskOrderBillingCSPResult(AliasModel):
class Config:
fields = {"task_order_billing_validation_url": "Location", "retry_after": "Retry-After"}
class VerifyTaskOrderBillingCSPPayload(BaseCSPPayload):
class TaskOrderBillingCSPResult(BaseCSPPayload):
task_order_billing_validation_url: str
class BillingProfileEnabledPlanDetails(AliasModel):
enabled_azure_plans: List[Dict]
class BillingProfileEnabledCSPResult(AliasModel):
class TaskOrderBillingCSPResult(AliasModel):
billing_profile_id: str
billing_profile_name: str
billing_profile_enabled_plan_details: BillingProfileEnabledPlanDetails
@ -534,9 +562,11 @@ class MockCloudProvider(CloudProviderInterface):
"tenant_id": response["tenantId"],
"user_id": response["userId"],
"user_object_id": response["objectId"],
"tenant_admin_username": "test",
"tenant_admin_password": "test"
}
def create_billing_profile(self, creds, tenant_admin_details, billing_owner_id):
def create_billing_profile(self, payload):
# call billing profile creation endpoint, specifying owner
# Payload:
"""
@ -576,7 +606,55 @@ class MockCloudProvider(CloudProviderInterface):
self._maybe_raise(self.UNAUTHORIZED_RATE, self.AUTHORIZATION_EXCEPTION)
response = {"id": "string"}
return {"billing_profile_id": response["id"]}
# return {"billing_profile_id": response["id"]}
return {
'id': '/providers/Microsoft.Billing/billingAccounts/7c89b735-b22b-55c0-ab5a-c624843e8bf6:de4416ce-acc6-44b1-8122-c87c4e903c91_2019-05-31/billingProfiles/KQWI-W2SU-BG7-TGB',
'name': 'KQWI-W2SU-BG7-TGB',
'properties': {
'address': {
'addressLine1': '123 S Broad Street, Suite 2400',
'city': 'Philadelphia',
'companyName': 'Promptworks',
'country': 'US',
'postalCode': '19109',
'region': 'PA'
},
'currency': 'USD',
'displayName': 'Test Billing Profile',
'enabledAzurePlans': [],
'hasReadAccess': True,
'invoiceDay': 5,
'invoiceEmailOptIn': False,
'invoiceSections': [{
'id': '/providers/Microsoft.Billing/billingAccounts/7c89b735-b22b-55c0-ab5a-c624843e8bf6:de4416ce-acc6-44b1-8122-c87c4e903c91_2019-05-31/billingProfiles/KQWI-W2SU-BG7-TGB/invoiceSections/CHCO-BAAR-PJA-TGB',
'name': 'CHCO-BAAR-PJA-TGB',
'properties': {
'displayName': 'Test Billing Profile'
},
'type': 'Microsoft.Billing/billingAccounts/billingProfiles/invoiceSections'
}]
},
'type': 'Microsoft.Billing/billingAccounts/billingProfiles'
}
def create_billing_profile_tenant_access(self, payload):
self._maybe_raise(self.NETWORK_FAILURE_PCT, self.NETWORK_EXCEPTION)
self._maybe_raise(self.SERVER_FAILURE_PCT, self.SERVER_EXCEPTION)
self._maybe_raise(self.UNAUTHORIZED_RATE, self.AUTHORIZATION_EXCEPTION)
return {
"id": "/providers/Microsoft.Billing/billingAccounts/7c89b735-b22b-55c0-ab5a-c624843e8bf6:de4416ce-acc6-44b1-8122-c87c4e903c91_2019-05-31/billingProfiles/KQWI-W2SU-BG7-TGB/billingRoleAssignments/40000000-aaaa-bbbb-cccc-100000000000_0a5f4926-e3ee-4f47-a6e3-8b0a30a40e3d",
"name": "40000000-aaaa-bbbb-cccc-100000000000_0a5f4926-e3ee-4f47-a6e3-8b0a30a40e3d",
"properties": {
"createdOn": "2020-01-14T14:39:26.3342192+00:00",
"createdByPrincipalId": "82e2b376-3297-4096-8743-ed65b3be0b03",
"principalId": "0a5f4926-e3ee-4f47-a6e3-8b0a30a40e3d",
"principalTenantId": "60ff9d34-82bf-4f21-b565-308ef0533435",
"roleDefinitionId": "/providers/Microsoft.Billing/billingAccounts/7c89b735-b22b-55c0-ab5a-c624843e8bf6:de4416ce-acc6-44b1-8122-c87c4e903c91_2019-05-31/billingProfiles/KQWI-W2SU-BG7-TGB/billingRoleDefinitions/40000000-aaaa-bbbb-cccc-100000000000",
"scope": "/providers/Microsoft.Billing/billingAccounts/7c89b735-b22b-55c0-ab5a-c624843e8bf6:de4416ce-acc6-44b1-8122-c87c4e903c91_2019-05-31/billingProfiles/KQWI-W2SU-BG7-TGB"
},
"type": "Microsoft.Billing/billingRoleAssignments"
}
def create_or_update_user(self, auth_credentials, user_info, csp_role_id):
self._authorize(auth_credentials)
@ -633,7 +711,7 @@ class MockCloudProvider(CloudProviderInterface):
@property
def _auth_credentials(self):
return {"username": "mock-cloud", "pass": "shh"}
return {"username": "mock-cloud", "password": "shh"}
def _authorize(self, credentials):
self._delay(1, 5)
@ -778,6 +856,9 @@ class AzureCloudProvider(CloudProviderInterface):
headers=create_tenant_headers,
)
print('create tenant result')
print(result.json())
if result.status_code == 200:
return self._ok(TenantCSPResult(**result.json()))
else:
@ -836,7 +917,7 @@ class AzureCloudProvider(CloudProviderInterface):
else:
return self._error(result.json())
def grant_billing_profile_tenant_access(self, payload: BillingRoleAssignmentCSPPayload):
def create_billing_profile_tenant_access(self, payload: BillingProfileTenantAccessCSPPayload):
sp_token = self._get_sp_token(payload.creds)
request_body = {
"properties": {
@ -854,11 +935,11 @@ class AzureCloudProvider(CloudProviderInterface):
result = self.sdk.requests.post(url, headers=headers, json=request_body)
if result.status_code == 201:
return self._ok(BillingRoleAssignmentCSPResult(**result.json()))
return self._ok(BillingProfileTenantAccessCSPResult(**result.json()))
else:
return self._error(result.json())
def enable_task_order_billing(self, payload: EnableTaskOrderBillingCSPPayload):
def enable_task_order_billing(self, payload: TaskOrderBillingCSPPayload):
sp_token = self._get_sp_token(payload.creds)
request_body = [
{
@ -884,7 +965,7 @@ class AzureCloudProvider(CloudProviderInterface):
# 202 has location/retry after headers
return self._ok(BillingProfileCreateCSPResult(**result.headers))
elif result.status_code == 200:
return self._ok(BillingProfileEnabledCSPResult(**result.json()))
return self._ok(TaskOrderBillingCSPResult(**result.json()))
else:
return self._error(result.json())
@ -903,13 +984,13 @@ class AzureCloudProvider(CloudProviderInterface):
if result.status_code == 202:
# 202 has location/retry after headers
return self._ok(EnableTaskOrderBillingCSPResult(**result.headers))
return self._ok(TaskOrderBillingCSPResult(**result.headers))
elif result.status_code == 200:
return self._ok(BillingProfileEnabledCSPResult(**result.json()))
return self._ok(TaskOrderBillingCSPResult(**result.json()))
else:
return self._error(result.json())
def report_clin(self, payload: ReportCLINCSPPayload):
def create_billing_instruction(self, payload: ReportCLINCSPPayload):
sp_token = self._get_sp_token(payload.creds)
if sp_token is None:
raise AuthenticationException(
@ -1022,7 +1103,9 @@ class AzureCloudProvider(CloudProviderInterface):
return sub_id_match.group(1)
def _get_sp_token(self, creds):
home_tenant_id = creds.get("home_tenant_id")
home_tenant_id = creds.get(
"home_tenant_id"
)
client_id = creds.get("client_id")
secret_key = creds.get("secret_key")

View File

@ -10,6 +10,9 @@ class StageStates(Enum):
class AzureStages(Enum):
TENANT = "tenant"
BILLING_PROFILE = "billing profile"
BILLING_PROFILE_TENANT_ACCESS = "billing profile tenant access"
TASK_ORDER_BILLING = "task order billing"
BILLING_INSTRUCTION = "billing instruction"
def _build_csp_states(csp_stages):

View File

@ -1,3 +1,7 @@
from random import choice, choices
import re
import string
from sqlalchemy import Column, ForeignKey, Enum as SQLAEnum
from sqlalchemy.orm import relationship, reconstructor
from sqlalchemy.dialects.postgresql import UUID
@ -17,6 +21,16 @@ import atst.models.mixins as mixins
from atst.models.mixins.state_machines import FSMStates, AzureStages, _build_transitions
def make_password():
return choice(string.ascii_letters) + "".join(
choices(string.ascii_letters + string.digits + string.punctuation, k=15)
)
def fetch_portfolio_creds(portfolio):
return dict(username="mock-cloud", password="shh")
@add_state_features(Tags)
class StateMachineWithTags(Machine):
pass
@ -73,57 +87,49 @@ class PortfolioStateMachine(
return getattr(FSMStates, self.state)
return self.state
def trigger_next_transition(self):
def trigger_next_transition(self, **kwargs):
state_obj = self.machine.get_state(self.state)
if state_obj.is_system:
if self.current_state in (FSMStates.UNSTARTED, FSMStates.STARTING):
# call the first trigger availabe for these two system states
trigger_name = self.machine.get_triggers(self.current_state.name)[0]
self.trigger(trigger_name)
self.trigger(trigger_name, **kwargs)
elif self.current_state == FSMStates.STARTED:
# get the first trigger that starts with 'create_'
create_trigger = self._get_first_stage_create_trigger()
if create_trigger:
self.trigger(create_trigger)
self.trigger(create_trigger, **kwargs)
else:
self.fail_stage(stage)
elif state_obj.is_IN_PROGRESS:
pass
# elif state_obj.is_TENANT:
# pass
# elif state_obj.is_BILLING_PROFILE:
# pass
elif state_obj.is_CREATED:
triggers = self.machine.get_triggers(state_obj.name)
self.trigger(triggers[-1], **kwargs)
# @with_payload
def after_in_progress_callback(self, event):
stage = self.current_state.name.split("_IN_PROGRESS")[0].lower()
if stage == "tenant":
payload = dict( # nosec
creds={"username": "mock-cloud", "pass": "shh"},
user_id="123",
password="123",
domain_name="123",
first_name="john",
last_name="doe",
country_code="US",
password_recovery_email_address="password@email.com",
)
elif stage == "billing_profile":
payload = dict(creds={"username": "mock-cloud", "pass": "shh"},)
# Accumulate payload w/ creds
payload = event.kwargs.get("csp_data")
payload["creds"] = event.kwargs.get("creds")
payload_data_cls = get_stage_csp_class(stage, "payload")
if not payload_data_cls:
print("could not resolve payload data class")
self.fail_stage(stage)
try:
payload_data = payload_data_cls(**payload)
except PydanticValidationError as exc:
print("Payload Validation Error:")
print(exc.json())
print("got")
print(payload)
self.fail_stage(stage)
# TODO: Determine best place to do this, maybe @reconstructor
csp = event.kwargs.get("csp")
if csp is not None:
self.csp = AzureCSP(app).cloud
@ -132,7 +138,8 @@ class PortfolioStateMachine(
for attempt in range(5):
try:
response = getattr(self.csp, "create_" + stage)(payload_data)
func_name = f"create_{stage}"
response = getattr(self.csp, func_name)(payload_data)
except (ConnectionException, UnknownServerException) as exc:
print("caught exception. retry", attempt)
continue
@ -140,14 +147,17 @@ class PortfolioStateMachine(
break
else:
# failed all attempts
print("failed")
self.fail_stage(stage)
if self.portfolio.csp_data is None:
self.portfolio.csp_data = {}
self.portfolio.csp_data[stage + "_data"] = response
self.portfolio.csp_data.update(response)
db.session.add(self.portfolio)
db.session.commit()
# store any updated creds, if necessary
self.finish_stage(stage)
def is_csp_data_valid(self, event):
@ -156,16 +166,23 @@ class PortfolioStateMachine(
if self.portfolio.csp_data is None or not isinstance(
self.portfolio.csp_data, dict
):
print("no csp data")
return False
stage = self.current_state.name.split("_IN_PROGRESS")[0].lower()
stage_data = self.portfolio.csp_data.get(stage + "_data")
stage_data = self.portfolio.csp_data
cls = get_stage_csp_class(stage, "result")
if not cls:
return False
try:
cls(**stage_data)
dc = cls(**stage_data)
if getattr(dc, "get_creds", None) is not None:
new_creds = dc.get_creds()
# TODO: how/where to store these
# TODO: credential schema
# self.store_creds(self.portfolio, new_creds)
except PydanticValidationError as exc:
print(exc.json())
return False

View File

@ -153,7 +153,7 @@ def test_create_tenant(mock_azure: AzureCloudProvider):
mock_azure.sdk.requests.post.return_value = mock_result
payload = TenantCSPPayload(
**dict(
creds={"username": "mock-cloud", "pass": "shh"},
creds={"username": "mock-cloud", "password": "shh"},
user_id="admin",
password="JediJan13$coot",
domain_name="jediccpospawnedtenant2",
@ -190,7 +190,7 @@ def test_create_billing_profile(mock_azure: AzureCloudProvider):
country="US",
postal_code="19109",
),
creds={"username": "mock-cloud", "pass": "shh"},
creds={"username": "mock-cloud", "password": "shh"},
tenant_id="60ff9d34-82bf-4f21-b565-308ef0533435",
display_name="Test Billing Profile",
)
@ -258,7 +258,7 @@ def test_validate_billing_profile_creation(mock_azure: AzureCloudProvider):
)
def test_grant_billing_profile_tenant_access(mock_azure: AzureCloudProvider):
def test_create_billing_profile_tenant_access(mock_azure: AzureCloudProvider):
mock_azure.sdk.adal.AuthenticationContext.return_value.context.acquire_token_with_client_credentials.return_value = {
"accessToken": "TOKEN"
}
@ -295,7 +295,7 @@ def test_grant_billing_profile_tenant_access(mock_azure: AzureCloudProvider):
)
)
result = mock_azure.grant_billing_profile_tenant_access(payload)
result = mock_azure.create_billing_profile_tenant_access(payload)
body: BillingRoleAssignmentCSPResult = result.get("body")
assert (
body.billing_role_assignment_name
@ -303,7 +303,7 @@ def test_grant_billing_profile_tenant_access(mock_azure: AzureCloudProvider):
)
def test_enable_task_order_billing(mock_azure: AzureCloudProvider):
def test_create_task_order_billing(mock_azure: AzureCloudProvider):
mock_azure.sdk.adal.AuthenticationContext.return_value.context.acquire_token_with_client_credentials.return_value = {
"accessToken": "TOKEN"
}
@ -401,7 +401,7 @@ def test_validate_task_order_billing_enabled(mock_azure):
)
def test_report_clin(mock_azure: AzureCloudProvider):
def test_create_billing_instruction(mock_azure: AzureCloudProvider):
mock_azure.sdk.adal.AuthenticationContext.return_value.context.acquire_token_with_client_credentials.return_value = {
"accessToken": "TOKEN"
}
@ -432,7 +432,7 @@ def test_report_clin(mock_azure: AzureCloudProvider):
billing_profile_name="KQWI-W2SU-BG7-TGB",
)
)
result = mock_azure.report_clin(payload)
result = mock_azure.create_billing_instruction(payload)
body: ReportCLINCSPResult = result.get("body")
assert body.reported_clin_name == "TO1:CLIN001"

View File

@ -1,11 +1,12 @@
import pytest
import re
from tests.factories import (
PortfolioFactory,
PortfolioStateMachineFactory,
)
from atst.models import FSMStates
from atst.models import FSMStates, PortfolioStateMachine
from atst.models.mixins.state_machines import AzureStages, StageStates, compose_state
from atst.domain.csp import get_stage_csp_class
@ -78,7 +79,7 @@ def test_state_machine_initialization(portfolio):
def test_fsm_transition_start(portfolio):
sm = PortfolioStateMachineFactory.create(portfolio=portfolio)
sm: PortfolioStateMachine = PortfolioStateMachineFactory.create(portfolio=portfolio)
assert sm.portfolio
assert sm.state == FSMStates.UNSTARTED
@ -87,5 +88,48 @@ def test_fsm_transition_start(portfolio):
sm.start()
assert sm.state == FSMStates.STARTED
sm.create_tenant(a=1, b=2)
# Should source all creds for portfolio? might be easier to manage than per-step specific ones
creds = {"username": "mock-cloud", "password": "shh"}
if portfolio.csp_data is not None:
csp_data = portfolio.csp_data
else:
csp_data = {}
ppoc = portfolio.owner
user_id = f"{ppoc.first_name[0]}{ppoc.last_name}".lower()
domain_name = re.sub("[^0-9a-zA-Z]+", "", portfolio.name).lower()
portfolio_data = {
"user_id": user_id,
"password": "jklfsdNCVD83nklds2#202",
"domain_name": domain_name,
"first_name": ppoc.first_name,
"last_name": ppoc.last_name,
"country_code": "US",
"password_recovery_email_address": ppoc.email,
"address": {
"company_name": "",
"address_line_1": "",
"city": "",
"region": "",
"country": "",
"postal_code": "",
},
"billing_profile_display_name": "My Billing Profile",
}
collected_data = dict(list(csp_data.items()) + list(portfolio_data.items()))
sm.trigger_next_transition(creds=creds, csp_data=collected_data)
assert sm.state == FSMStates.TENANT_CREATED
assert portfolio.csp_data.get("tenant_id", None) is not None
if portfolio.csp_data is not None:
csp_data = portfolio.csp_data
else:
csp_data = {}
collected_data = dict(list(csp_data.items()) + list(portfolio_data.items()))
sm.trigger_next_transition(creds=creds, csp_data=collected_data)
assert sm.state == FSMStates.BILLING_PROFILE_CREATED