diff --git a/atst/domain/csp/__init__.py b/atst/domain/csp/__init__.py index 62b28f94..d886f8a2 100644 --- a/atst/domain/csp/__init__.py +++ b/atst/domain/csp/__init__.py @@ -1,5 +1,3 @@ -import importlib - from .cloud import MockCloudProvider from .file_uploads import AzureUploader, MockUploader from .reports import MockReportingProvider @@ -31,22 +29,3 @@ def make_csp_provider(app, csp=None): app.csp = MockCSP(app, test_mode=True) else: app.csp = MockCSP(app) - - -def _stage_to_classname(stage): - return "".join(map(lambda word: word.capitalize(), stage.split("_"))) - - -def get_stage_csp_class(stage, class_type): - """ - given a stage name and class_type return the class - class_type is either 'payload' or 'result' - - """ - cls_name = f"{_stage_to_classname(stage)}CSP{class_type.capitalize()}" - try: - return getattr( - importlib.import_module("atst.domain.csp.cloud.models"), cls_name - ) - except AttributeError: - print("could not import CSP Result class <%s>" % cls_name) diff --git a/atst/domain/csp/cloud/azure_cloud_provider.py b/atst/domain/csp/cloud/azure_cloud_provider.py index 5bbcdd54..0ed18d9d 100644 --- a/atst/domain/csp/cloud/azure_cloud_provider.py +++ b/atst/domain/csp/cloud/azure_cloud_provider.py @@ -1,4 +1,5 @@ import re +from secrets import token_urlsafe from typing import Dict from uuid import uuid4 @@ -25,7 +26,6 @@ from .models import ( ) from .policy import AzurePolicyManager - AZURE_ENVIRONMENT = "AZURE_PUBLIC_CLOUD" # TBD AZURE_SKU_ID = "?" # probably a static sku specific to ATAT/JEDI SUBSCRIPTION_ID_REGEX = re.compile( @@ -295,6 +295,7 @@ class AzureCloudProvider(CloudProviderInterface): if sp_token is None: raise AuthenticationException("Could not resolve token for tenant creation") + payload.password = token_urlsafe(16) create_tenant_body = payload.dict(by_alias=True) create_tenant_headers = { @@ -513,7 +514,6 @@ class AzureCloudProvider(CloudProviderInterface): # we likely only want the budget ID, can be updated or replaced? response = {"id": "id"} - return self._ok({"budget_id": response["id"]}) def _get_management_service_principal(self): diff --git a/atst/domain/csp/cloud/mock_cloud_provider.py b/atst/domain/csp/cloud/mock_cloud_provider.py index 39bc6da3..6df61003 100644 --- a/atst/domain/csp/cloud/mock_cloud_provider.py +++ b/atst/domain/csp/cloud/mock_cloud_provider.py @@ -67,8 +67,8 @@ class MockCloudProvider(CloudProviderInterface): def set_secret(self, secret_key: str, secret_value: str): pass - def get_secret(self, secret_key: str): - return {} + def get_secret(self, secret_key: str, default=dict()): + return default def create_environment(self, auth_credentials, user, environment): self._authorize(auth_credentials) @@ -136,7 +136,7 @@ class MockCloudProvider(CloudProviderInterface): "tenant_admin_username": "test", "tenant_admin_password": "test", } - ).dict() + ) def create_billing_profile_creation( self, payload: BillingProfileCreationCSPPayload @@ -151,7 +151,7 @@ class MockCloudProvider(CloudProviderInterface): billing_profile_verify_url="https://zombo.com", billing_profile_retry_after=10, ) - ).dict() + ) def create_billing_profile_verification( self, payload: BillingProfileVerificationCSPPayload @@ -189,7 +189,7 @@ class MockCloudProvider(CloudProviderInterface): }, "type": "Microsoft.Billing/billingAccounts/billingProfiles", } - ).dict() + ) def create_billing_profile_tenant_access(self, payload): self._maybe_raise(self.NETWORK_FAILURE_PCT, self.NETWORK_EXCEPTION) @@ -210,7 +210,7 @@ class MockCloudProvider(CloudProviderInterface): }, "type": "Microsoft.Billing/billingRoleAssignments", } - ).dict() + ) def create_task_order_billing_creation( self, payload: TaskOrderBillingCreationCSPPayload @@ -221,7 +221,7 @@ class MockCloudProvider(CloudProviderInterface): return TaskOrderBillingCreationCSPResult( **{"Location": "https://somelocation", "Retry-After": "10"} - ).dict() + ) def create_task_order_billing_verification( self, payload: TaskOrderBillingVerificationCSPPayload @@ -258,7 +258,7 @@ class MockCloudProvider(CloudProviderInterface): }, "type": "Microsoft.Billing/billingAccounts/billingProfiles", } - ).dict() + ) def create_billing_instruction(self, payload: BillingInstructionCSPPayload): self._maybe_raise(self.NETWORK_FAILURE_PCT, self.NETWORK_EXCEPTION) @@ -275,7 +275,7 @@ class MockCloudProvider(CloudProviderInterface): }, "type": "Microsoft.Billing/billingAccounts/billingProfiles/billingInstructions", } - ).dict() + ) def create_or_update_user(self, auth_credentials, user_info, csp_role_id): self._authorize(auth_credentials) diff --git a/atst/domain/csp/cloud/models.py b/atst/domain/csp/cloud/models.py index 93ac7d8d..369bed31 100644 --- a/atst/domain/csp/cloud/models.py +++ b/atst/domain/csp/cloud/models.py @@ -37,7 +37,7 @@ class BaseCSPPayload(AliasModel): class TenantCSPPayload(BaseCSPPayload): user_id: str - password: str + password: Optional[str] domain_name: str first_name: str last_name: str diff --git a/atst/models/portfolio_state_machine.py b/atst/models/portfolio_state_machine.py index cdd82da9..cf42710b 100644 --- a/atst/models/portfolio_state_machine.py +++ b/atst/models/portfolio_state_machine.py @@ -1,3 +1,5 @@ +import importlib + from sqlalchemy import Column, ForeignKey, Enum as SQLAEnum from sqlalchemy.orm import relationship, reconstructor from sqlalchemy.dialects.postgresql import UUID @@ -9,7 +11,6 @@ from transitions.extensions.states import add_state_features, Tags from flask import current_app as app from atst.domain.csp.cloud.exceptions import ConnectionException, UnknownServerException -from atst.domain.csp import MockCSP, AzureCSP, get_stage_csp_class from atst.database import db from atst.models.types import Id from atst.models.base import Base @@ -17,6 +18,25 @@ import atst.models.mixins as mixins from atst.models.mixins.state_machines import FSMStates, AzureStages, _build_transitions +def _stage_to_classname(stage): + return "".join(map(lambda word: word.capitalize(), stage.split("_"))) + + +def get_stage_csp_class(stage, class_type): + """ + given a stage name and class_type return the class + class_type is either 'payload' or 'result' + + """ + cls_name = f"{_stage_to_classname(stage)}CSP{class_type.capitalize()}" + try: + return getattr( + importlib.import_module("atst.domain.csp.cloud.models"), cls_name + ) + except AttributeError: + print("could not import CSP Result class <%s>" % cls_name) + + @add_state_features(Tags) class StateMachineWithTags(Machine): pass @@ -138,76 +158,53 @@ class PortfolioStateMachine( self.fail_stage(stage) # TODO: Determine best place to do this, maybe @reconstructor - csp = event.kwargs.get("csp") - if csp is not None: - self.csp = AzureCSP(app).cloud - else: - self.csp = MockCSP(app).cloud + self.csp = app.csp.cloud - attempts_count = 5 - for attempt in range(attempts_count): - try: - func_name = f"create_{stage}" - response = getattr(self.csp, func_name)(payload_data) - except (ConnectionException, UnknownServerException) as exc: - app.logger.error( - f"CSP api call. Caught exception for {self.__repr__()}. Retry attempt {attempt}", - exc_info=1, - ) - continue - else: - break - else: - # failed all attempts - logger.info(f"CSP api call failed after {attempts_count} attempts.") + try: + func_name = f"create_{stage}" + response = getattr(self.csp, func_name)(payload_data) + if self.portfolio.csp_data is None: + self.portfolio.csp_data = {} + self.portfolio.csp_data.update(response.dict()) + db.session.add(self.portfolio) + db.session.commit() + + if getattr(response, "get_creds", None) is not None: + new_creds = response.get_creds() + # TODO: one way salted hash of tenant_id to use as kv key name? + tenant_id = new_creds.get("tenant_id") + secret = self.csp.get_secret(tenant_id, new_creds) + secret.update(new_creds) + self.csp.set_secret(tenant_id, secret) + except PydanticValidationError as exc: + app.logger.error( + f"Failed to cast response to valid result class {self.__repr__()}:", + exc_info=1, + ) + app.logger.info(exc.json()) + print(exc.json()) + app.logger.info(payload_data) + self.fail_stage(stage) + except (ConnectionException, UnknownServerException) as exc: + app.logger.error( + f"CSP api call. Caught exception for {self.__repr__()}.", exc_info=1, + ) self.fail_stage(stage) - - if self.portfolio.csp_data is None: - self.portfolio.csp_data = {} - self.portfolio.csp_data.update(response) - db.session.add(self.portfolio) - db.session.commit() - - # store any updated creds, if necessary self.finish_stage(stage) def is_csp_data_valid(self, event): - # check portfolio csp details json field for fields + """ + This function guards advancing states from *_IN_PROGRESS to *_COMPLETED. + """ if self.portfolio.csp_data is None or not isinstance( self.portfolio.csp_data, dict ): print("no csp data") return False - stage = self.current_state.name.split("_IN_PROGRESS")[0].lower() - stage_data = self.portfolio.csp_data - cls = get_stage_csp_class(stage, "result") - if not cls: - return False - - try: - dc = cls(**stage_data) - if getattr(dc, "get_creds", None) is not None: - new_creds = dc.get_creds() - tenant_id = new_creds.get("tenant_id") - secret = self.csp.get_secret(tenant_id) - secret.update(new_creds) - self.csp.set_secret(tenant_id, secret) - - except PydanticValidationError as exc: - app.logger.error( - f"Payload Validation Error in {self.__repr__()}:", exc_info=1 - ) - app.logger.info(exc.json()) - app.logger.info(payload) - - return False - return True - # print('failed condition', self.portfolio.csp_data) - @property def application_id(self): return None diff --git a/tests/domain/cloud/test_azure_csp.py b/tests/domain/cloud/test_azure_csp.py index 228b78e4..0d23d6c0 100644 --- a/tests/domain/cloud/test_azure_csp.py +++ b/tests/domain/cloud/test_azure_csp.py @@ -16,8 +16,6 @@ from atst.domain.csp.cloud.models import ( BillingProfileTenantAccessCSPResult, BillingProfileVerificationCSPPayload, BillingProfileVerificationCSPResult, - BillingInstructionCSPPayload, - BillingInstructionCSPResult, TaskOrderBillingCreationCSPPayload, TaskOrderBillingCreationCSPResult, TaskOrderBillingVerificationCSPPayload, @@ -26,7 +24,6 @@ from atst.domain.csp.cloud.models import ( TenantCSPResult, ) - creds = { "home_tenant_id": "tenant_id", "client_id": "client_id", diff --git a/tests/domain/test_portfolio_state_machine.py b/tests/domain/test_portfolio_state_machine.py index 330d5195..2e412653 100644 --- a/tests/domain/test_portfolio_state_machine.py +++ b/tests/domain/test_portfolio_state_machine.py @@ -1,5 +1,6 @@ import pytest import re +from unittest import mock from tests.factories import ( PortfolioStateMachineFactory, @@ -9,7 +10,9 @@ from tests.factories import ( from atst.models import FSMStates, PortfolioStateMachine, TaskOrder from atst.models.mixins.state_machines import AzureStages, StageStates, compose_state from atst.models.portfolio import Portfolio -from atst.domain.csp import get_stage_csp_class +from atst.models.portfolio_state_machine import get_stage_csp_class + +# TODO: Write failure case tests @pytest.fixture(scope="function") @@ -79,7 +82,10 @@ def test_state_machine_initialization(portfolio): assert ["reset", "fail", create_trigger] == started_triggers -def test_fsm_transition_start(portfolio: Portfolio): +@mock.patch("atst.domain.csp.cloud.MockCloudProvider") +def test_fsm_transition_start(mock_cloud_provider, portfolio: Portfolio): + mock_cloud_provider._authorize.return_value = None + mock_cloud_provider._maybe_raise.return_value = None sm: PortfolioStateMachine = PortfolioStateMachineFactory.create(portfolio=portfolio) assert sm.portfolio assert sm.state == FSMStates.UNSTARTED @@ -101,7 +107,7 @@ def test_fsm_transition_start(portfolio: Portfolio): ] # Should source all creds for portfolio? might be easier to manage than per-step specific ones - creds = {"username": "mock-cloud", "password": "shh"} + creds = {"username": "mock-cloud", "password": "shh"} # pragma: allowlist secret if portfolio.csp_data is not None: csp_data = portfolio.csp_data else: @@ -116,13 +122,13 @@ def test_fsm_transition_start(portfolio: Portfolio): portfolio_data = { "user_id": user_id, - "password": "jklfsdNCVD83nklds2#202", + "password": "jklfsdNCVD83nklds2#202", # pragma: allowlist secret "domain_name": domain_name, "first_name": ppoc.first_name, "last_name": ppoc.last_name, "country_code": "US", "password_recovery_email_address": ppoc.email, - "address": { + "address": { # TODO: TBD if we're sourcing this from data or config "company_name": "", "address_line_1": "", "city": "",