diff --git a/atst/domain/csp/cloud/azure_cloud_provider.py b/atst/domain/csp/cloud/azure_cloud_provider.py index 2d00dcf4..425fe649 100644 --- a/atst/domain/csp/cloud/azure_cloud_provider.py +++ b/atst/domain/csp/cloud/azure_cloud_provider.py @@ -1,6 +1,5 @@ import json from secrets import token_urlsafe -from typing import Any, Dict from uuid import uuid4 from atst.utils import sha256_hex @@ -1026,12 +1025,10 @@ class AzureCloudProvider(CloudProviderInterface): def update_tenant_creds(self, tenant_id, secret: KeyVaultCredentials): hashed = sha256_hex(tenant_id) - new_secrets = secret.dict() curr_secrets = self._source_tenant_creds(tenant_id) - updated_secrets: Dict[str, Any] = {**curr_secrets.dict(), **new_secrets} - us = KeyVaultCredentials(**updated_secrets) - self.set_secret(hashed, json.dumps(us.dict())) - return us + updated_secrets = curr_secrets.merge_credentials(secret) + self.set_secret(hashed, json.dumps(updated_secrets.dict())) + return updated_secrets def _source_tenant_creds(self, tenant_id) -> KeyVaultCredentials: hashed = sha256_hex(tenant_id) diff --git a/atst/domain/csp/cloud/models.py b/atst/domain/csp/cloud/models.py index 358f7934..25521bb9 100644 --- a/atst/domain/csp/cloud/models.py +++ b/atst/domain/csp/cloud/models.py @@ -417,6 +417,15 @@ class KeyVaultCredentials(BaseModel): return values + def merge_credentials( + self, new_creds: "KeyVaultCredentials" + ) -> "KeyVaultCredentials": + updated_creds = {k: v for k, v in new_creds.dict().items() if v} + old_creds = self.dict() + old_creds.update(updated_creds) + + return KeyVaultCredentials(**old_creds) + class SubscriptionCreationCSPPayload(BaseCSPPayload): display_name: str diff --git a/tests/domain/cloud/test_azure_csp.py b/tests/domain/cloud/test_azure_csp.py index 3a25f849..c89e6a77 100644 --- a/tests/domain/cloud/test_azure_csp.py +++ b/tests/domain/cloud/test_azure_csp.py @@ -25,6 +25,7 @@ from atst.domain.csp.cloud.models import ( CostManagementQueryCSPResult, EnvironmentCSPPayload, EnvironmentCSPResult, + KeyVaultCredentials, PrincipalAdminRoleCSPPayload, PrincipalAdminRoleCSPResult, ProductPurchaseCSPPayload, @@ -938,3 +939,23 @@ def test_create_user(mock_azure: AzureCloudProvider): result = mock_azure.create_user(payload) assert result.id == "id" + + +def test_update_tenant_creds(mock_azure: AzureCloudProvider): + with patch.object( + AzureCloudProvider, "set_secret", wraps=mock_azure.set_secret, + ) as set_secret: + set_secret.return_value = None + existing_secrets = { + "tenant_id": "mytenant", + "tenant_admin_username": "admin", + "tenant_admin_password": "foo", # pragma: allowlist secret + } + mock_azure = mock_get_secret(mock_azure, json.dumps(existing_secrets)) + + mock_new_secrets = KeyVaultCredentials(**MOCK_CREDS) + updated_secret = mock_azure.update_tenant_creds("mytenant", mock_new_secrets) + + assert updated_secret == KeyVaultCredentials( + **{**existing_secrets, **MOCK_CREDS} + ) diff --git a/tests/domain/cloud/test_models.py b/tests/domain/cloud/test_models.py index 10c81293..29bc60cb 100644 --- a/tests/domain/cloud/test_models.py +++ b/tests/domain/cloud/test_models.py @@ -100,6 +100,26 @@ def test_KeyVaultCredentials_enforce_root_creds(): ) +def test_KeyVaultCredentials_merge_credentials(): + old_secret = KeyVaultCredentials( + tenant_id="foo", + tenant_admin_username="bar", + tenant_admin_password="baz", # pragma: allowlist secret + ) + new_secret = KeyVaultCredentials( + tenant_id="foo", tenant_sp_client_id="bip", tenant_sp_key="bop" + ) + + expected_update = KeyVaultCredentials( + tenant_id="foo", + tenant_admin_username="bar", + tenant_admin_password="baz", # pragma: allowlist secret + tenant_sp_client_id="bip", + tenant_sp_key="bop", + ) + assert old_secret.merge_credentials(new_secret) == expected_update + + user_payload = { "tenant_id": "123", "display_name": "Han Solo",