diff --git a/atst/domain/csp/cloud.py b/atst/domain/csp/cloud.py index 2257ace1..b46b03b4 100644 --- a/atst/domain/csp/cloud.py +++ b/atst/domain/csp/cloud.py @@ -380,7 +380,7 @@ class CloudProviderInterface: def set_secret(self, secret_key: str, secret_value: str): raise NotImplementedError() - def get_secret(self, secret_key: str, secret_value: str): + def get_secret(self, secret_key: str): raise NotImplementedError() def root_creds(self) -> Dict: @@ -520,6 +520,12 @@ class MockCloudProvider(CloudProviderInterface): def root_creds(self): return self._auth_credentials + def set_secret(self, secret_key: str, secret_value: str): + pass + + def get_secret(self, secret_key: str): + return {} + def create_environment(self, auth_credentials, user, environment): self._authorize(auth_credentials) @@ -846,7 +852,7 @@ class AzureCloudProvider(CloudProviderInterface): self.policy_manager = AzurePolicyManager(config["AZURE_POLICY_LOCATION"]) - def set_secret(secret_key, secret_value): + def set_secret(self, secret_key, secret_value): credential = self._get_client_secret_credential_obj() secret_client = self.secrets.SecretClient( vault_url=self.vault_url, credential=credential, @@ -859,7 +865,7 @@ class AzureCloudProvider(CloudProviderInterface): exc_info=1, ) - def get_secret(secret_key): + def get_secret(self, secret_key): credential = self._get_client_secret_credential_obj() secret_client = self.secrets.SecretClient( vault_url=self.vault_url, credential=credential, diff --git a/atst/models/portfolio_state_machine.py b/atst/models/portfolio_state_machine.py index 3f0e4d7d..6e3015c4 100644 --- a/atst/models/portfolio_state_machine.py +++ b/atst/models/portfolio_state_machine.py @@ -1,5 +1,4 @@ from random import choice, choices -import re import string from sqlalchemy import Column, ForeignKey, Enum as SQLAEnum @@ -119,7 +118,6 @@ class PortfolioStateMachine( elif state_obj.is_CREATED: # the create trigger for the next stage should be in the available # triggers for the current state - triggers = self.machine.get_triggers(state_obj.name) create_trigger = next( filter( lambda trigger: trigger.startswith("create_"), @@ -205,11 +203,10 @@ class PortfolioStateMachine( dc = cls(**stage_data) if getattr(dc, "get_creds", None) is not None: new_creds = dc.get_creds() - print("creds to report") - print(new_creds) - # TODO: how/where to store these - # TODO: credential schema - # self.store_creds(self.portfolio, new_creds) + tenant_id = new_creds.get("tenant_id") + secret = self.csp.get_secret(tenant_id) + secret.update(new_creds) + self.csp.set_secret(tenant_id, secret) except PydanticValidationError as exc: app.logger.error( diff --git a/tests/domain/test_portfolio_state_machine.py b/tests/domain/test_portfolio_state_machine.py index 82a6d086..c5ca68cb 100644 --- a/tests/domain/test_portfolio_state_machine.py +++ b/tests/domain/test_portfolio_state_machine.py @@ -2,7 +2,6 @@ import pytest import re from tests.factories import ( - PortfolioFactory, PortfolioStateMachineFactory, TaskOrderFactory, CLINFactory, @@ -36,7 +35,7 @@ def test_state_machine_trigger_next_transition(portfolio): def test_state_machine_compose_state(portfolio): - sm = PortfolioStateMachineFactory.create(portfolio=portfolio) + PortfolioStateMachineFactory.create(portfolio=portfolio) assert ( compose_state(AzureStages.TENANT, StageStates.CREATED) == FSMStates.TENANT_CREATED @@ -44,7 +43,7 @@ def test_state_machine_compose_state(portfolio): def test_state_machine_valid_data_classes_for_stages(portfolio): - sm = PortfolioStateMachineFactory.create(portfolio=portfolio) + PortfolioStateMachineFactory.create(portfolio=portfolio) for stage in AzureStages: assert get_stage_csp_class(stage.name.lower(), "payload") is not None assert get_stage_csp_class(stage.name.lower(), "result") is not None