diff --git a/atst/domain/csp/cloud/azure_cloud_provider.py b/atst/domain/csp/cloud/azure_cloud_provider.py index 735bf53a..7985f363 100644 --- a/atst/domain/csp/cloud/azure_cloud_provider.py +++ b/atst/domain/csp/cloud/azure_cloud_provider.py @@ -1,7 +1,7 @@ import json import re from secrets import token_urlsafe -from typing import Dict +from typing import Any, Dict from uuid import uuid4 from atst.utils import sha256_hex @@ -104,7 +104,7 @@ class AzureCloudProvider(CloudProviderInterface): self.policy_manager = AzurePolicyManager(config["AZURE_POLICY_LOCATION"]) def set_secret(self, secret_key, secret_value): - credential = self._get_client_secret_credential_obj({}) + credential = self._get_client_secret_credential_obj() secret_client = self.sdk.secrets.SecretClient( vault_url=self.vault_url, credential=credential, ) @@ -117,7 +117,7 @@ class AzureCloudProvider(CloudProviderInterface): ) def get_secret(self, secret_key): - credential = self._get_client_secret_credential_obj({}) + credential = self._get_client_secret_credential_obj() secret_client = self.sdk.secrets.SecretClient( vault_url=self.vault_url, credential=credential, ) @@ -318,7 +318,7 @@ class AzureCloudProvider(CloudProviderInterface): ) def create_tenant(self, payload: TenantCSPPayload): - sp_token = self.get_root_provisioning_token() + sp_token = self._get_root_provisioning_token() if sp_token is None: raise AuthenticationException("Could not resolve token for tenant creation") @@ -336,20 +336,27 @@ class AzureCloudProvider(CloudProviderInterface): ) if result.status_code == 200: - return self._ok( - TenantCSPResult( - **result.json(), - tenant_admin_password=payload.password, - tenant_admin_username=payload.user_id, - ) + result_dict = result.json() + tenant_id = result_dict.get("tenantId") + tenant_admin_username = ( + f"{payload.user_id}@{payload.domain_name}.onmicrosoft.com" ) + self.update_tenant_creds( + tenant_id, + KeyVaultCredentials( + tenant_id=tenant_id, + tenant_admin_username=tenant_admin_username, + tenant_admin_password=payload.password, + ), + ) + return self._ok(TenantCSPResult(**result_dict)) else: return self._error(result.json()) def create_billing_profile_creation( self, payload: BillingProfileCreationCSPPayload ): - sp_token = self.get_root_provisioning_token() + sp_token = self._get_root_provisioning_token() if sp_token is None: raise AuthenticationException( "Could not resolve token for billing profile creation" @@ -381,7 +388,7 @@ class AzureCloudProvider(CloudProviderInterface): def create_billing_profile_verification( self, payload: BillingProfileVerificationCSPPayload ): - sp_token = self.get_root_provisioning_token() + sp_token = self._get_root_provisioning_token() if sp_token is None: raise AuthenticationException( "Could not resolve token for billing profile validation" @@ -406,7 +413,7 @@ class AzureCloudProvider(CloudProviderInterface): def create_billing_profile_tenant_access( self, payload: BillingProfileTenantAccessCSPPayload ): - sp_token = self.get_root_provisioning_token() + sp_token = self._get_root_provisioning_token() request_body = { "properties": { "principalTenantId": payload.tenant_id, # from tenant creation @@ -430,7 +437,7 @@ class AzureCloudProvider(CloudProviderInterface): def create_task_order_billing_creation( self, payload: TaskOrderBillingCreationCSPPayload ): - sp_token = self.get_root_provisioning_token() + sp_token = self._get_root_provisioning_token() request_body = [ { "op": "replace", @@ -460,7 +467,7 @@ class AzureCloudProvider(CloudProviderInterface): def create_task_order_billing_verification( self, payload: TaskOrderBillingVerificationCSPPayload ): - sp_token = self.get_root_provisioning_token() + sp_token = self._get_root_provisioning_token() if sp_token is None: raise AuthenticationException( "Could not resolve token for task order billing validation" @@ -483,7 +490,7 @@ class AzureCloudProvider(CloudProviderInterface): return self._error(result.json()) def create_billing_instruction(self, payload: BillingInstructionCSPPayload): - sp_token = self.get_root_provisioning_token() + sp_token = self._get_root_provisioning_token() if sp_token is None: raise AuthenticationException( "Could not resolve token for task order billing validation" @@ -510,28 +517,8 @@ class AzureCloudProvider(CloudProviderInterface): else: return self._error(result.json()) - def get_elevated_management_token(self, tenant_id): - mgmt_token = self.get_tenant_admin_token( - tenant_id, self.sdk.cloud.endpoints.resource_manager - ) - if mgmt_token is None: - raise AuthenticationException( - "Failed to resolve management token for tenant admin" - ) - - auth_header = { - "Authorization": f"Bearer {mgmt_token}", - } - url = f"{self.sdk.cloud.endpoints.resource_manager}/providers/Microsoft.Authorization/elevateAccess?api-version=2016-07-01" - result = self.sdk.requests.post(url, headers=auth_header) - - if not result.ok: - raise AuthenticationException("Failed to elevate access") - - return mgmt_token - def create_tenant_admin_ownership(self, payload: TenantAdminOwnershipCSPPayload): - mgmt_token = self.get_elevated_management_token(payload.tenant_id) + mgmt_token = self._get_elevated_management_token(payload.tenant_id) role_definition_id = f"/providers/Microsoft.Management/managementGroups/{payload.tenant_id}/providers/Microsoft.Authorization/roleDefinitions/{self.owner_role_def_id}" @@ -558,7 +545,7 @@ class AzureCloudProvider(CloudProviderInterface): def create_tenant_principal_ownership( self, payload: TenantPrincipalOwnershipCSPPayload ): - mgmt_token = self.get_elevated_management_token(payload.tenant_id) + mgmt_token = self._get_elevated_management_token(payload.tenant_id) # NOTE: the tenant_id is also the id of the root management group, once it is created role_definition_id = f"/providers/Microsoft.Management/managementGroups/{payload.tenant_id}/providers/Microsoft.Authorization/roleDefinitions/{self.owner_role_def_id}" @@ -584,8 +571,7 @@ class AzureCloudProvider(CloudProviderInterface): return TenantPrincipalOwnershipCSPResult(**response.json()) def create_tenant_principal_app(self, payload: TenantPrincipalAppCSPPayload): - - graph_token = self.get_tenant_admin_token( + graph_token = self._get_tenant_admin_token( payload.tenant_id, self.graph_resource ) if graph_token is None: @@ -607,7 +593,7 @@ class AzureCloudProvider(CloudProviderInterface): return TenantPrincipalAppCSPResult(**response.json()) def create_tenant_principal(self, payload: TenantPrincipalCSPPayload): - graph_token = self.get_tenant_admin_token( + graph_token = self._get_tenant_admin_token( payload.tenant_id, self.graph_resource ) if graph_token is None: @@ -631,7 +617,7 @@ class AzureCloudProvider(CloudProviderInterface): def create_tenant_principal_credential( self, payload: TenantPrincipalCredentialCSPPayload ): - graph_token = self.get_tenant_admin_token( + graph_token = self._get_tenant_admin_token( payload.tenant_id, self.graph_resource ) if graph_token is None: @@ -652,12 +638,22 @@ class AzureCloudProvider(CloudProviderInterface): response = self.sdk.requests.post(url, json=request_body, headers=auth_header) if response.ok: + result = response.json() + self.update_tenant_creds( + payload.tenant_id, + KeyVaultCredentials( + tenant_id=payload.tenant_id, + tenant_sp_key=result.get("secretText"), + tenant_sp_client_id=payload.principal_app_id, + ), + ) return TenantPrincipalCredentialCSPResult( - principal_client_id=payload.principal_app_id, **response.json() + principal_client_id=payload.principal_app_id, + principal_creds_established=True, ) def create_admin_role_definition(self, payload: AdminRoleDefinitionCSPPayload): - graph_token = self.get_tenant_admin_token( + graph_token = self._get_tenant_admin_token( payload.tenant_id, self.graph_resource ) if graph_token is None: @@ -689,7 +685,7 @@ class AzureCloudProvider(CloudProviderInterface): return AdminRoleDefinitionCSPResult(admin_role_def_id=admin_role_def_id) def create_principal_admin_role(self, payload: PrincipalAdminRoleCSPPayload): - graph_token = self.get_tenant_admin_token( + graph_token = self._get_tenant_admin_token( payload.tenant_id, self.graph_resource ) if graph_token is None: @@ -782,33 +778,22 @@ class AzureCloudProvider(CloudProviderInterface): if sub_id_match: return sub_id_match.group(1) - def get_tenant_admin_token(self, tenant_id, resource): - creds = self.get_secret(tenant_id) + def _get_tenant_admin_token(self, tenant_id, resource): + creds = self._source_tenant_creds(tenant_id) return self._get_up_token_for_resource( - creds.get("admin_username"), - creds.get("admin_password"), + creds.tenant_admin_username, + creds.tenant_admin_password, tenant_id, resource, ) - def get_tenant_principal_token(self, tenant_id, resource): - # creds = self.get_secret(tenant_id) - # return self._get_up_token_for_resource( - # creds.get("admin_username"), - # creds.get("admin_password"), - # tenat_id, - # resource - # ) - pass - - def get_root_provisioning_token(self): - return self._get_sp_token(self._root_creds) - - def _get_sp_token(self, creds): - tenant_id = creds.get("tenant_id") - client_id = creds.get("client_id") - secret_key = creds.get("secret_key") + def _get_root_provisioning_token(self): + creds = self._source_creds() + return self._get_sp_token( + creds.tenant_id, creds.root_sp_client_id, creds.root_sp_key + ) + def _get_sp_token(self, tenant_id, client_id, secret_key): context = self.sdk.adal.AuthenticationContext( f"{self.sdk.cloud.endpoints.active_directory}/{tenant_id}" ) @@ -842,16 +827,14 @@ class AzureCloudProvider(CloudProviderInterface): cloud_environment=self.sdk.cloud, ) - def _get_client_secret_credential_obj(self, creds): + def _get_client_secret_credential_obj(self): + creds = self._source_creds() return self.sdk.identity.ClientSecretCredential( - tenant_id=creds.get("tenant_id"), - client_id=creds.get("client_id"), - client_secret=creds.get("secret_key"), + tenant_id=creds.tenant_id, + client_id=creds.root_sp_client_id, + client_secret=creds.root_sp_key, ) - def _make_tenant_admin_cred_obj(self, username, password): - return self.sdk.credentials.UserPassCredentials(username, password) - def _ok(self, body=None): return self._make_response("ok", body) @@ -878,6 +861,26 @@ class AzureCloudProvider(CloudProviderInterface): "tenant_id": self.tenant_id, } + def _get_elevated_management_token(self, tenant_id): + mgmt_token = self._get_tenant_admin_token( + tenant_id, self.sdk.cloud.endpoints.resource_manager + ) + if mgmt_token is None: + raise AuthenticationException( + "Failed to resolve management token for tenant admin" + ) + + auth_header = { + "Authorization": f"Bearer {mgmt_token}", + } + url = f"{self.sdk.cloud.endpoints.resource_manager}/providers/Microsoft.Authorization/elevateAccess?api-version=2016-07-01" + result = self.sdk.requests.post(url, headers=auth_header) + + if not result.ok: + raise AuthenticationException("Failed to elevate access") + + return mgmt_token + def _source_creds(self, tenant_id=None) -> KeyVaultCredentials: if tenant_id: return self._source_tenant_creds(tenant_id) @@ -888,13 +891,16 @@ class AzureCloudProvider(CloudProviderInterface): root_sp_key=self._root_creds.get("secret_key"), ) - def update_tenant_creds(self, tenant_id, secret): + def update_tenant_creds(self, tenant_id, secret: KeyVaultCredentials): hashed = sha256_hex(tenant_id) - self.set_secret(hashed, json.dumps(secret)) + new_secrets = secret.dict() + curr_secrets = self._source_tenant_creds(tenant_id) + updated_secrets: Dict[str, Any] = {**curr_secrets.dict(), **new_secrets} + us = KeyVaultCredentials(**updated_secrets) + self.set_secret(hashed, json.dumps(us.dict())) + return us - return secret - - def _source_tenant_creds(self, tenant_id): + def _source_tenant_creds(self, tenant_id) -> KeyVaultCredentials: hashed = sha256_hex(tenant_id) raw_creds = self.get_secret(hashed) return KeyVaultCredentials(**json.loads(raw_creds)) diff --git a/atst/domain/csp/cloud/mock_cloud_provider.py b/atst/domain/csp/cloud/mock_cloud_provider.py index 52e68f08..fcc9495a 100644 --- a/atst/domain/csp/cloud/mock_cloud_provider.py +++ b/atst/domain/csp/cloud/mock_cloud_provider.py @@ -331,8 +331,8 @@ class MockCloudProvider(CloudProviderInterface): return TenantPrincipalCredentialCSPResult( **dict( - secretText="principal_secret_key", principal_client_id="principal_client_id", + principal_creds_established=True, ) ) diff --git a/atst/domain/csp/cloud/models.py b/atst/domain/csp/cloud/models.py index ce7769a1..18c969d6 100644 --- a/atst/domain/csp/cloud/models.py +++ b/atst/domain/csp/cloud/models.py @@ -278,10 +278,7 @@ class TenantPrincipalCredentialCSPPayload(BaseCSPPayload): class TenantPrincipalCredentialCSPResult(AliasModel): principal_client_id: str - principal_secret_key: str - - class Config: - fields = {"principal_secret_key": "secretText"} + principal_creds_established: bool class AdminRoleDefinitionCSPPayload(BaseCSPPayload): diff --git a/atst/models/portfolio_state_machine.py b/atst/models/portfolio_state_machine.py index be9324b1..4b14a087 100644 --- a/atst/models/portfolio_state_machine.py +++ b/atst/models/portfolio_state_machine.py @@ -168,14 +168,6 @@ class PortfolioStateMachine( self.portfolio.csp_data.update(response.dict()) db.session.add(self.portfolio) db.session.commit() - - if getattr(response, "get_creds", None) is not None: - new_creds = response.get_creds() - # TODO: one way salted hash of tenant_id to use as kv key name? - tenant_id = new_creds.get("tenant_id") - secret = self.csp.get_secret(tenant_id, new_creds) - secret.update(new_creds) - self.csp.update_tenant_creds(tenant_id, secret) except PydanticValidationError as exc: app.logger.error( f"Failed to cast response to valid result class {self.__repr__()}:", diff --git a/tests/domain/cloud/test_azure_csp.py b/tests/domain/cloud/test_azure_csp.py index ac5f62a2..67c03c9f 100644 --- a/tests/domain/cloud/test_azure_csp.py +++ b/tests/domain/cloud/test_azure_csp.py @@ -102,8 +102,10 @@ MOCK_CREDS = { } -def mock_get_secret(azure, func): - azure.get_secret = func +def mock_get_secret(azure, val=None): + if val is None: + val = json.dumps(MOCK_CREDS) + azure.get_secret = lambda *a, **k: val return azure @@ -111,12 +113,12 @@ def mock_get_secret(azure, func): def test_create_application_succeeds(mock_azure: AzureCloudProvider): application = ApplicationFactory.create() mock_management_group_create(mock_azure, {"id": "Test Id"}) - - mock_azure = mock_get_secret(mock_azure, lambda *a, **k: json.dumps(MOCK_CREDS)) + mock_azure = mock_get_secret(mock_azure) payload = ApplicationCSPPayload( tenant_id="1234", display_name=application.name, parent_id=str(uuid4()) ) + result = mock_azure.create_application(payload) assert result.id == "Test Id" @@ -162,10 +164,6 @@ def test_create_policy_definition_succeeds(mock_azure: AzureCloudProvider): def test_create_tenant(mock_azure: AzureCloudProvider): - mock_azure.sdk.adal.AuthenticationContext.return_value.context.acquire_token_with_client_credentials.return_value = { - "accessToken": "TOKEN" - } - mock_result = Mock() mock_result.json.return_value = { "objectId": "0a5f4926-e3ee-4f47-a6e3-8b0a30a40e3d", @@ -176,7 +174,6 @@ def test_create_tenant(mock_azure: AzureCloudProvider): mock_azure.sdk.requests.post.return_value = mock_result payload = TenantCSPPayload( **dict( - tenant_id="60ff9d34-82bf-4f21-b565-308ef0533435", user_id="admin", password="JediJan13$coot", # pragma: allowlist secret domain_name="jediccpospawnedtenant2", @@ -186,6 +183,7 @@ def test_create_tenant(mock_azure: AzureCloudProvider): password_recovery_email_address="thomas@promptworks.com", ) ) + mock_azure = mock_get_secret(mock_azure) result = mock_azure.create_tenant(payload) body: TenantCSPResult = result.get("body") assert body.tenant_id == "60ff9d34-82bf-4f21-b565-308ef0533435" @@ -446,8 +444,8 @@ def test_create_billing_instruction(mock_azure: AzureCloudProvider): def test_create_tenant_principal_app(mock_azure: AzureCloudProvider): with patch.object( AzureCloudProvider, - "get_elevated_management_token", - wraps=mock_azure.get_elevated_management_token, + "_get_elevated_management_token", + wraps=mock_azure._get_elevated_management_token, ) as get_elevated_management_token: get_elevated_management_token.return_value = "my fake token" @@ -456,11 +454,11 @@ def test_create_tenant_principal_app(mock_azure: AzureCloudProvider): mock_result.json.return_value = {"appId": "appId", "id": "id"} mock_azure.sdk.requests.post.return_value = mock_result + mock_azure = mock_get_secret(mock_azure) payload = TenantPrincipalAppCSPPayload( **{"tenant_id": "6d2d2d6c-a6d6-41e1-8bb1-73d11475f8f4"} ) - result: TenantPrincipalAppCSPResult = mock_azure.create_tenant_principal_app( payload ) @@ -471,8 +469,8 @@ def test_create_tenant_principal_app(mock_azure: AzureCloudProvider): def test_create_tenant_principal(mock_azure: AzureCloudProvider): with patch.object( AzureCloudProvider, - "get_elevated_management_token", - wraps=mock_azure.get_elevated_management_token, + "_get_elevated_management_token", + wraps=mock_azure._get_elevated_management_token, ) as get_elevated_management_token: get_elevated_management_token.return_value = "my fake token" @@ -481,6 +479,7 @@ def test_create_tenant_principal(mock_azure: AzureCloudProvider): mock_result.json.return_value = {"id": "principal_id"} mock_azure.sdk.requests.post.return_value = mock_result + mock_azure = mock_get_secret(mock_azure) payload = TenantPrincipalCSPPayload( **{ @@ -497,8 +496,8 @@ def test_create_tenant_principal(mock_azure: AzureCloudProvider): def test_create_tenant_principal_credential(mock_azure: AzureCloudProvider): with patch.object( AzureCloudProvider, - "get_elevated_management_token", - wraps=mock_azure.get_elevated_management_token, + "_get_elevated_management_token", + wraps=mock_azure._get_elevated_management_token, ) as get_elevated_management_token: get_elevated_management_token.return_value = "my fake token" @@ -508,6 +507,8 @@ def test_create_tenant_principal_credential(mock_azure: AzureCloudProvider): mock_azure.sdk.requests.post.return_value = mock_result + mock_azure = mock_get_secret(mock_azure) + payload = TenantPrincipalCredentialCSPPayload( **{ "tenant_id": "6d2d2d6c-a6d6-41e1-8bb1-73d11475f8f4", @@ -520,14 +521,14 @@ def test_create_tenant_principal_credential(mock_azure: AzureCloudProvider): payload ) - assert result.principal_secret_key == "new secret key" + assert result.principal_creds_established == True def test_create_admin_role_definition(mock_azure: AzureCloudProvider): with patch.object( AzureCloudProvider, - "get_elevated_management_token", - wraps=mock_azure.get_elevated_management_token, + "_get_elevated_management_token", + wraps=mock_azure._get_elevated_management_token, ) as get_elevated_management_token: get_elevated_management_token.return_value = "my fake token" @@ -541,6 +542,7 @@ def test_create_admin_role_definition(mock_azure: AzureCloudProvider): } mock_azure.sdk.requests.get.return_value = mock_result + mock_azure = mock_get_secret(mock_azure) payload = AdminRoleDefinitionCSPPayload( **{"tenant_id": "6d2d2d6c-a6d6-41e1-8bb1-73d11475f8f4"} @@ -556,8 +558,8 @@ def test_create_admin_role_definition(mock_azure: AzureCloudProvider): def test_create_tenant_admin_ownership(mock_azure: AzureCloudProvider): with patch.object( AzureCloudProvider, - "get_elevated_management_token", - wraps=mock_azure.get_elevated_management_token, + "_get_elevated_management_token", + wraps=mock_azure._get_elevated_management_token, ) as get_elevated_management_token: get_elevated_management_token.return_value = "my fake token" @@ -584,8 +586,8 @@ def test_create_tenant_admin_ownership(mock_azure: AzureCloudProvider): def test_create_tenant_principal_ownership(mock_azure: AzureCloudProvider): with patch.object( AzureCloudProvider, - "get_elevated_management_token", - wraps=mock_azure.get_elevated_management_token, + "_get_elevated_management_token", + wraps=mock_azure._get_elevated_management_token, ) as get_elevated_management_token: get_elevated_management_token.return_value = "my fake token"