From f08d53d7a03bbcbbb37a084040fb4052ad9d0cef Mon Sep 17 00:00:00 2001 From: tomdds Date: Fri, 24 Jan 2020 15:42:23 -0500 Subject: [PATCH] Transition all Cloud Interface Methods to use Dataclasses --- atst/domain/csp/cloud.py | 18 ++--- atst/models/portfolio_state_machine.py | 83 ++++++++------------ tests/domain/test_portfolio_state_machine.py | 4 +- 3 files changed, 44 insertions(+), 61 deletions(-) diff --git a/atst/domain/csp/cloud.py b/atst/domain/csp/cloud.py index d22f9475..c7134431 100644 --- a/atst/domain/csp/cloud.py +++ b/atst/domain/csp/cloud.py @@ -529,8 +529,8 @@ class MockCloudProvider(CloudProviderInterface): def set_secret(self, secret_key: str, secret_value: str): pass - def get_secret(self, secret_key: str): - return {} + def get_secret(self, secret_key: str, default=dict()): + return default def create_environment(self, auth_credentials, user, environment): self._authorize(auth_credentials) @@ -598,7 +598,7 @@ class MockCloudProvider(CloudProviderInterface): "tenant_admin_username": "test", "tenant_admin_password": "test", } - ).dict() + ) def create_billing_profile_creation( self, payload: BillingProfileCreationCSPPayload @@ -613,7 +613,7 @@ class MockCloudProvider(CloudProviderInterface): billing_profile_verify_url="https://zombo.com", billing_profile_retry_after=10, ) - ).dict() + ) def create_billing_profile_verification( self, payload: BillingProfileVerificationCSPPayload @@ -651,7 +651,7 @@ class MockCloudProvider(CloudProviderInterface): }, "type": "Microsoft.Billing/billingAccounts/billingProfiles", } - ).dict() + ) def create_billing_profile_tenant_access(self, payload): self._maybe_raise(self.NETWORK_FAILURE_PCT, self.NETWORK_EXCEPTION) @@ -672,7 +672,7 @@ class MockCloudProvider(CloudProviderInterface): }, "type": "Microsoft.Billing/billingRoleAssignments", } - ).dict() + ) def create_task_order_billing_creation( self, payload: TaskOrderBillingCreationCSPPayload @@ -683,7 +683,7 @@ class MockCloudProvider(CloudProviderInterface): return TaskOrderBillingCreationCSPResult( **{"Location": "https://somelocation", "Retry-After": "10"} - ).dict() + ) def create_task_order_billing_verification( self, payload: TaskOrderBillingVerificationCSPPayload @@ -720,7 +720,7 @@ class MockCloudProvider(CloudProviderInterface): }, "type": "Microsoft.Billing/billingAccounts/billingProfiles", } - ).dict() + ) def create_billing_instruction(self, payload: BillingInstructionCSPPayload): self._maybe_raise(self.NETWORK_FAILURE_PCT, self.NETWORK_EXCEPTION) @@ -737,7 +737,7 @@ class MockCloudProvider(CloudProviderInterface): }, "type": "Microsoft.Billing/billingAccounts/billingProfiles/billingInstructions", } - ).dict() + ) def create_or_update_user(self, auth_credentials, user_info, csp_role_id): self._authorize(auth_credentials) diff --git a/atst/models/portfolio_state_machine.py b/atst/models/portfolio_state_machine.py index a0cc77cd..1390ceb1 100644 --- a/atst/models/portfolio_state_machine.py +++ b/atst/models/portfolio_state_machine.py @@ -144,70 +144,51 @@ class PortfolioStateMachine( else: self.csp = MockCSP(app).cloud - attempts_count = 5 - for attempt in range(attempts_count): - try: - func_name = f"create_{stage}" - response = getattr(self.csp, func_name)(payload_data) - except (ConnectionException, UnknownServerException) as exc: - app.logger.error( - f"CSP api call. Caught exception for {self.__repr__()}. Retry attempt {attempt}", - exc_info=1, - ) - continue - else: - break - else: - # failed all attempts - logger.info(f"CSP api call failed after {attempts_count} attempts.") + try: + func_name = f"create_{stage}" + response = getattr(self.csp, func_name)(payload_data) + if self.portfolio.csp_data is None: + self.portfolio.csp_data = {} + self.portfolio.csp_data.update(response.dict()) + db.session.add(self.portfolio) + db.session.commit() + + if getattr(response, "get_creds", None) is not None: + new_creds = response.get_creds() + # TODO: one way salted hash of tenant_id to use as kv key name? + tenant_id = new_creds.get("tenant_id") + secret = self.csp.get_secret(tenant_id, new_creds) + secret.update(new_creds) + self.csp.set_secret(tenant_id, secret) + except PydanticValidationError as exc: + app.logger.error( + f"Failed to cast response to valid result class {self.__repr__()}:", + exc_info=1, + ) + app.logger.info(exc.json()) + print(exc.json()) + app.logger.info(payload_data) + self.fail_stage(stage) + except (ConnectionException, UnknownServerException) as exc: + app.logger.error( + f"CSP api call. Caught exception for {self.__repr__()}.", exc_info=1, + ) self.fail_stage(stage) - - if self.portfolio.csp_data is None: - self.portfolio.csp_data = {} - self.portfolio.csp_data.update(response) - db.session.add(self.portfolio) - db.session.commit() - - # store any updated creds, if necessary self.finish_stage(stage) def is_csp_data_valid(self, event): - # check portfolio csp details json field for fields + """ + This function guards advancing states from *_IN_PROGRESS to *_COMPLETED. + """ if self.portfolio.csp_data is None or not isinstance( self.portfolio.csp_data, dict ): print("no csp data") return False - stage = self.current_state.name.split("_IN_PROGRESS")[0].lower() - stage_data = self.portfolio.csp_data - cls = get_stage_csp_class(stage, "result") - if not cls: - return False - - try: - dc = cls(**stage_data) - if getattr(dc, "get_creds", None) is not None: - new_creds = dc.get_creds() - tenant_id = new_creds.get("tenant_id") - secret = self.csp.get_secret(tenant_id) - secret.update(new_creds) - self.csp.set_secret(tenant_id, secret) - - except PydanticValidationError as exc: - app.logger.error( - f"Payload Validation Error in {self.__repr__()}:", exc_info=1 - ) - app.logger.info(exc.json()) - app.logger.info(payload) - - return False - return True - # print('failed condition', self.portfolio.csp_data) - @property def application_id(self): return None diff --git a/tests/domain/test_portfolio_state_machine.py b/tests/domain/test_portfolio_state_machine.py index 0d37c9c7..eeebf6e9 100644 --- a/tests/domain/test_portfolio_state_machine.py +++ b/tests/domain/test_portfolio_state_machine.py @@ -11,6 +11,8 @@ from atst.models.mixins.state_machines import AzureStages, StageStates, compose_ from atst.models.portfolio import Portfolio from atst.domain.csp import get_stage_csp_class +# TODO: Write failure case tests + @pytest.fixture(scope="function") def portfolio(): @@ -122,7 +124,7 @@ def test_fsm_transition_start(portfolio: Portfolio): "last_name": ppoc.last_name, "country_code": "US", "password_recovery_email_address": ppoc.email, - "address": { + "address": { # TODO: TBD if we're sourcing this from data or config "company_name": "", "address_line_1": "", "city": "",