diff --git a/atst/domain/csp/cloud/azure_cloud_provider.py b/atst/domain/csp/cloud/azure_cloud_provider.py index 0ed18d9d..84a9238c 100644 --- a/atst/domain/csp/cloud/azure_cloud_provider.py +++ b/atst/domain/csp/cloud/azure_cloud_provider.py @@ -1,3 +1,4 @@ +import json import re from secrets import token_urlsafe from typing import Dict @@ -16,6 +17,7 @@ from .models import ( BillingProfileTenantAccessCSPResult, BillingProfileVerificationCSPPayload, BillingProfileVerificationCSPResult, + KeyVaultCredentials, ManagementGroupCSPResponse, TaskOrderBillingCreationCSPPayload, TaskOrderBillingCreationCSPResult, @@ -25,6 +27,7 @@ from .models import ( TenantCSPResult, ) from .policy import AzurePolicyManager +from atst.utils import sha256_hex AZURE_ENVIRONMENT = "AZURE_PUBLIC_CLOUD" # TBD AZURE_SKU_ID = "?" # probably a static sku specific to ATAT/JEDI @@ -85,7 +88,7 @@ class AzureCloudProvider(CloudProviderInterface): def set_secret(self, secret_key, secret_value): credential = self._get_client_secret_credential_obj({}) - secret_client = self.secrets.SecretClient( + secret_client = self.sdk.secrets.SecretClient( vault_url=self.vault_url, credential=credential, ) try: @@ -98,7 +101,7 @@ class AzureCloudProvider(CloudProviderInterface): def get_secret(self, secret_key): credential = self._get_client_secret_credential_obj({}) - secret_client = self.secrets.SecretClient( + secret_client = self.sdk.secrets.SecretClient( vault_url=self.vault_url, credential=credential, ) try: @@ -166,8 +169,15 @@ class AzureCloudProvider(CloudProviderInterface): } def create_application(self, payload: ApplicationCSPPayload): - creds = payload.creds - credentials = self._get_credential_obj(creds, resource=AZURE_MANAGEMENT_API) + creds = self._source_creds(payload.tenant_id) + credentials = self._get_credential_obj( + { + "client_id": creds.root_sp_client_id, + "secret_key": creds.root_sp_key, + "tenant_id": creds.root_tenant_id, + }, + resource=AZURE_MANAGEMENT_API, + ) response = self._create_management_group( credentials, @@ -632,26 +642,23 @@ class AzureCloudProvider(CloudProviderInterface): "tenant_id": self.tenant_id, } - def get_credentials(self, scope="portfolio", tenant_id=None): - """ - This could be implemented to determine, based on type, whether to return creds for: - - scope="atat": the ATAT main app registration in ATAT's home tenant - - scope="tenantadmin": the tenant administrator credentials - - scope="portfolio": the credentials for the ATAT SP in the portfolio tenant - """ - if scope == "atat": - return self._root_creds - elif scope == "tenantadmin": - # magic with key vault happens - return { - "client_id": "some id", - "secret_key": "very secret", - "tenant_id": tenant_id, - } - elif scope == "portfolio": - # magic with key vault happens - return { - "client_id": "some id", - "secret_key": "very secret", - "tenant_id": tenant_id, - } + def _source_creds(self, tenant_id=None) -> KeyVaultCredentials: + if tenant_id: + return self._source_tenant_creds(tenant_id) + else: + return KeyVaultCredentials( + root_tenant_id=self._root_creds.get("tenant_id"), + root_sp_client_id=self._root_creds.get("client_id"), + root_sp_key=self._root_creds.get("secret_key"), + ) + + def update_tenant_creds(self, tenant_id, secret): + hashed = sha256_hex(tenant_id) + self.set_secret(hashed, json.dumps(secret)) + + return secret + + def _source_tenant_creds(self, tenant_id): + hashed = sha256_hex(tenant_id) + raw_creds = self.get_secret(hashed) + return KeyVaultCredentials(**json.loads(raw_creds)) diff --git a/atst/domain/csp/cloud/mock_cloud_provider.py b/atst/domain/csp/cloud/mock_cloud_provider.py index 6df61003..10d62e15 100644 --- a/atst/domain/csp/cloud/mock_cloud_provider.py +++ b/atst/domain/csp/cloud/mock_cloud_provider.py @@ -347,8 +347,12 @@ class MockCloudProvider(CloudProviderInterface): def create_application(self, payload: ApplicationCSPPayload): self._maybe_raise(self.UNAUTHORIZED_RATE, GeneralCSPException) - id_ = f"{AZURE_MGMNT_PATH}{payload.management_group_name}" - return ApplicationCSPResult(id=id_) + return ApplicationCSPResult( + id=f"{AZURE_MGMNT_PATH}{payload.management_group_name}" + ) def get_credentials(self, scope="portfolio", tenant_id=None): return self.root_creds() + + def update_tenant_creds(self, tenant_id, secret): + return secret diff --git a/atst/domain/csp/cloud/models.py b/atst/domain/csp/cloud/models.py index 369bed31..b4ff9232 100644 --- a/atst/domain/csp/cloud/models.py +++ b/atst/domain/csp/cloud/models.py @@ -2,7 +2,7 @@ from typing import Dict, List, Optional import re from uuid import uuid4 -from pydantic import BaseModel, validator +from pydantic import BaseModel, validator, root_validator from atst.utils import snake_to_camel @@ -241,7 +241,7 @@ AZURE_MGMNT_PATH = "/providers/Microsoft.Management/managementGroups/" MANAGEMENT_GROUP_NAME_REGEX = "^[a-zA-Z0-9\-_\(\)\.]+$" -class ManagementGroupCSPPayload(BaseCSPPayload): +class ManagementGroupCSPPayload(AliasModel): """ :param: management_group_name: Just pass a UUID for this. :param: display_name: This can contain any character and @@ -250,6 +250,7 @@ class ManagementGroupCSPPayload(BaseCSPPayload): i.e. /providers/Microsoft.Management/managementGroups/[management group ID] """ + tenant_id: str management_group_name: Optional[str] display_name: str parent_id: str @@ -288,3 +289,55 @@ class ApplicationCSPPayload(ManagementGroupCSPPayload): class ApplicationCSPResult(ManagementGroupCSPResponse): pass + + +class KeyVaultCredentials(BaseModel): + root_sp_client_id: Optional[str] + root_sp_key: Optional[str] + root_tenant_id: Optional[str] + + tenant_id: Optional[str] + + tenant_admin_username: Optional[str] + tenant_admin_password: Optional[str] + + tenant_sp_client_id: Optional[str] + tenant_sp_key: Optional[str] + + @root_validator(pre=True) + def enforce_admin_creds(cls, values): + tenant_id = values.get("tenant_id") + username = values.get("tenant_admin_username") + password = values.get("tenant_admin_password") + if any([username, password]) and not all([tenant_id, username, password]): + raise ValueError( + "tenant_id, tenant_admin_username, and tenant_admin_password must all be set if any one is" + ) + + return values + + @root_validator(pre=True) + def enforce_sp_creds(cls, values): + tenant_id = values.get("tenant_id") + client_id = values.get("tenant_sp_client_id") + key = values.get("tenant_sp_key") + if any([client_id, key]) and not all([tenant_id, client_id, key]): + raise ValueError( + "tenant_id, tenant_sp_client_id, and tenant_sp_key must all be set if any one is" + ) + + return values + + @root_validator(pre=True) + def enforce_root_creds(cls, values): + sp_creds = [ + values.get("root_tenant_id"), + values.get("root_sp_client_id"), + values.get("root_sp_key"), + ] + if any(sp_creds) and not all(sp_creds): + raise ValueError( + "root_tenant_id, root_sp_client_id, and root_sp_key must all be set if any one is" + ) + + return values diff --git a/atst/jobs.py b/atst/jobs.py index 7a4a3792..14256336 100644 --- a/atst/jobs.py +++ b/atst/jobs.py @@ -59,15 +59,14 @@ def do_create_application(csp: CloudProviderInterface, application_id=None): with claim_for_update(application) as application: - if application.cloud_id is not None: + if application.cloud_id: return csp_details = application.portfolio.csp_data parent_id = csp_details.get("root_management_group_id") tenant_id = csp_details.get("tenant_id") - creds = csp.get_credentials(tenant_id) payload = ApplicationCSPPayload( - creds=creds, display_name=application.name, parent_id=parent_id + tenant_id=tenant_id, display_name=application.name, parent_id=parent_id ) app_result = csp.create_application(payload) diff --git a/atst/models/portfolio_state_machine.py b/atst/models/portfolio_state_machine.py index cf42710b..be9324b1 100644 --- a/atst/models/portfolio_state_machine.py +++ b/atst/models/portfolio_state_machine.py @@ -175,7 +175,7 @@ class PortfolioStateMachine( tenant_id = new_creds.get("tenant_id") secret = self.csp.get_secret(tenant_id, new_creds) secret.update(new_creds) - self.csp.set_secret(tenant_id, secret) + self.csp.update_tenant_creds(tenant_id, secret) except PydanticValidationError as exc: app.logger.error( f"Failed to cast response to valid result class {self.__repr__()}:", diff --git a/atst/utils/__init__.py b/atst/utils/__init__.py index 09c63dea..79d5362a 100644 --- a/atst/utils/__init__.py +++ b/atst/utils/__init__.py @@ -1,3 +1,4 @@ +import hashlib import re from sqlalchemy.exc import IntegrityError @@ -41,3 +42,8 @@ def commit_or_raise_already_exists_error(message): except IntegrityError: db.session.rollback() raise AlreadyExistsError(message) + + +def sha256_hex(string): + hsh = hashlib.sha256(string.encode()) + return hsh.digest().hex() diff --git a/tests/domain/cloud/test_azure_csp.py b/tests/domain/cloud/test_azure_csp.py index 0d23d6c0..39fa2f77 100644 --- a/tests/domain/cloud/test_azure_csp.py +++ b/tests/domain/cloud/test_azure_csp.py @@ -1,5 +1,7 @@ +import pytest +import json from uuid import uuid4 -from unittest.mock import Mock +from unittest.mock import Mock, patch from tests.factories import ApplicationFactory, EnvironmentFactory from tests.mock_azure import AUTH_CREDENTIALS, mock_azure @@ -84,13 +86,28 @@ def test_create_environment_succeeds(mock_azure: AzureCloudProvider): assert result.id == "Test Id" +# mock the get_secret so it returns a JSON string +MOCK_CREDS = { + "tenant_id": str(uuid4()), + "tenant_sp_client_id": str(uuid4()), + "tenant_sp_key": "1234", +} + + +def mock_get_secret(azure, func): + azure.get_secret = func + + return azure + + def test_create_application_succeeds(mock_azure: AzureCloudProvider): application = ApplicationFactory.create() - mock_management_group_create(mock_azure, {"id": "Test Id"}) + mock_azure = mock_get_secret(mock_azure, lambda *a, **k: json.dumps(MOCK_CREDS)) + payload = ApplicationCSPPayload( - creds={}, display_name=application.name, parent_id=str(uuid4()) + tenant_id="1234", display_name=application.name, parent_id=str(uuid4()) ) result = mock_azure.create_application(payload) diff --git a/tests/domain/cloud/test_payloads.py b/tests/domain/cloud/test_models.py similarity index 54% rename from tests/domain/cloud/test_payloads.py rename to tests/domain/cloud/test_models.py index d92a4840..d9fc963d 100644 --- a/tests/domain/cloud/test_payloads.py +++ b/tests/domain/cloud/test_models.py @@ -4,6 +4,7 @@ from pydantic import ValidationError from atst.domain.csp.cloud.models import ( AZURE_MGMNT_PATH, + KeyVaultCredentials, ManagementGroupCSPPayload, ManagementGroupCSPResponse, ) @@ -12,25 +13,25 @@ from atst.domain.csp.cloud.models import ( def test_ManagementGroupCSPPayload_management_group_name(): # supplies management_group_name when absent payload = ManagementGroupCSPPayload( - creds={}, display_name="Council of Naboo", parent_id="Galactic_Senate" + tenant_id="any-old-id", + display_name="Council of Naboo", + parent_id="Galactic_Senate", ) assert payload.management_group_name # validates management_group_name with pytest.raises(ValidationError): payload = ManagementGroupCSPPayload( - creds={}, + tenant_id="any-old-id", management_group_name="council of Naboo 1%^&", display_name="Council of Naboo", parent_id="Galactic_Senate", ) # shortens management_group_name to fit - name = "council_of_naboo" - for _ in range(90): - name = f"{name}1" + name = "council_of_naboo".ljust(95, "1") assert len(name) > 90 payload = ManagementGroupCSPPayload( - creds={}, + tenant_id="any-old-id", management_group_name=name, display_name="Council of Naboo", parent_id="Galactic_Senate", @@ -40,12 +41,10 @@ def test_ManagementGroupCSPPayload_management_group_name(): def test_ManagementGroupCSPPayload_display_name(): # shortens display_name to fit - name = "Council of Naboo" - for _ in range(90): - name = f"{name}1" + name = "Council of Naboo".ljust(95, "1") assert len(name) > 90 payload = ManagementGroupCSPPayload( - creds={}, display_name=name, parent_id="Galactic_Senate" + tenant_id="any-old-id", display_name=name, parent_id="Galactic_Senate" ) assert len(payload.display_name) == 90 @@ -54,12 +53,14 @@ def test_ManagementGroupCSPPayload_parent_id(): full_path = f"{AZURE_MGMNT_PATH}Galactic_Senate" # adds full path payload = ManagementGroupCSPPayload( - creds={}, display_name="Council of Naboo", parent_id="Galactic_Senate" + tenant_id="any-old-id", + display_name="Council of Naboo", + parent_id="Galactic_Senate", ) assert payload.parent_id == full_path # keeps full path payload = ManagementGroupCSPPayload( - creds={}, display_name="Council of Naboo", parent_id=full_path + tenant_id="any-old-id", display_name="Council of Naboo", parent_id=full_path ) assert payload.parent_id == full_path @@ -70,3 +71,29 @@ def test_ManagementGroupCSPResponse_id(): **{"id": "/path/to/naboo-123", "other": "stuff"} ) assert response.id == full_id + + +def test_KeyVaultCredentials_enforce_admin_creds(): + with pytest.raises(ValidationError): + KeyVaultCredentials(tenant_id="an id", tenant_admin_username="C3PO") + assert KeyVaultCredentials( + tenant_id="an id", + tenant_admin_username="C3PO", + tenant_admin_password="beep boop", + ) + + +def test_KeyVaultCredentials_enforce_sp_creds(): + with pytest.raises(ValidationError): + KeyVaultCredentials(tenant_id="an id", tenant_sp_client_id="C3PO") + assert KeyVaultCredentials( + tenant_id="an id", tenant_sp_client_id="C3PO", tenant_sp_key="beep boop" + ) + + +def test_KeyVaultCredentials_enforce_root_creds(): + with pytest.raises(ValidationError): + KeyVaultCredentials(root_tenant_id="an id", root_sp_client_id="C3PO") + assert KeyVaultCredentials( + root_tenant_id="an id", root_sp_client_id="C3PO", root_sp_key="beep boop" + ) diff --git a/tests/mock_azure.py b/tests/mock_azure.py index 7fa67667..4f37848e 100644 --- a/tests/mock_azure.py +++ b/tests/mock_azure.py @@ -72,6 +72,12 @@ def mock_secrets(): return Mock(spec=secrets) +def mock_identity(): + import azure.identity as identity + + return Mock(spec=identity) + + class MockAzureSDK(object): def __init__(self): from msrestazure.azure_cloud import AZURE_PUBLIC_CLOUD @@ -88,6 +94,7 @@ class MockAzureSDK(object): self.requests = mock_requests() # may change to a JEDI cloud self.cloud = AZURE_PUBLIC_CLOUD + self.identity = mock_identity() @pytest.fixture(scope="function") diff --git a/tests/utils/test_hash.py b/tests/utils/test_hash.py new file mode 100644 index 00000000..5cfb8489 --- /dev/null +++ b/tests/utils/test_hash.py @@ -0,0 +1,16 @@ +import random +import re +import string + +from atst.utils import sha256_hex + + +def test_sha256_hex(): + sample = "".join( + random.choices(string.ascii_uppercase + string.digits, k=random.randrange(200)) + ) + hashed = sha256_hex(sample) + assert re.match("^[a-zA-Z0-9]+$", hashed) + assert len(hashed) == 64 + hashed_again = sha256_hex(sample) + assert hashed == hashed_again