From abd03be806e8ca2f61373c147eb6019ca6f5c0fd Mon Sep 17 00:00:00 2001 From: dandds Date: Mon, 27 Jan 2020 15:00:20 -0500 Subject: [PATCH] Store and pull tenant creds from Key Vault. The tenant ID should be hashed and used as the key for the JSON blob of relevant creds for any given tenant. Azure CSP interface methods that need to source creds should call the internal `_source_creds` method, either with a `tenant_id` or no parameters. That method will source the creds. If a tenant ID is provided, it will source them from the Key Vault. If not provided, it will return the default creds for the app registration in the home tenant. --- atst/domain/csp/cloud/azure_cloud_provider.py | 61 +++++++++++-------- atst/domain/csp/cloud/mock_cloud_provider.py | 8 ++- atst/domain/csp/cloud/models.py | 57 ++++++++++++++++- atst/jobs.py | 5 +- atst/models/portfolio_state_machine.py | 2 +- atst/utils/__init__.py | 6 ++ tests/domain/cloud/test_azure_csp.py | 23 ++++++- .../{test_payloads.py => test_models.py} | 51 ++++++++++++---- tests/mock_azure.py | 7 +++ tests/utils/test_hash.py | 16 +++++ 10 files changed, 186 insertions(+), 50 deletions(-) rename tests/domain/cloud/{test_payloads.py => test_models.py} (54%) create mode 100644 tests/utils/test_hash.py diff --git a/atst/domain/csp/cloud/azure_cloud_provider.py b/atst/domain/csp/cloud/azure_cloud_provider.py index 0ed18d9d..84a9238c 100644 --- a/atst/domain/csp/cloud/azure_cloud_provider.py +++ b/atst/domain/csp/cloud/azure_cloud_provider.py @@ -1,3 +1,4 @@ +import json import re from secrets import token_urlsafe from typing import Dict @@ -16,6 +17,7 @@ from .models import ( BillingProfileTenantAccessCSPResult, BillingProfileVerificationCSPPayload, BillingProfileVerificationCSPResult, + KeyVaultCredentials, ManagementGroupCSPResponse, TaskOrderBillingCreationCSPPayload, TaskOrderBillingCreationCSPResult, @@ -25,6 +27,7 @@ from .models import ( TenantCSPResult, ) from .policy import AzurePolicyManager +from atst.utils import sha256_hex AZURE_ENVIRONMENT = "AZURE_PUBLIC_CLOUD" # TBD AZURE_SKU_ID = "?" # probably a static sku specific to ATAT/JEDI @@ -85,7 +88,7 @@ class AzureCloudProvider(CloudProviderInterface): def set_secret(self, secret_key, secret_value): credential = self._get_client_secret_credential_obj({}) - secret_client = self.secrets.SecretClient( + secret_client = self.sdk.secrets.SecretClient( vault_url=self.vault_url, credential=credential, ) try: @@ -98,7 +101,7 @@ class AzureCloudProvider(CloudProviderInterface): def get_secret(self, secret_key): credential = self._get_client_secret_credential_obj({}) - secret_client = self.secrets.SecretClient( + secret_client = self.sdk.secrets.SecretClient( vault_url=self.vault_url, credential=credential, ) try: @@ -166,8 +169,15 @@ class AzureCloudProvider(CloudProviderInterface): } def create_application(self, payload: ApplicationCSPPayload): - creds = payload.creds - credentials = self._get_credential_obj(creds, resource=AZURE_MANAGEMENT_API) + creds = self._source_creds(payload.tenant_id) + credentials = self._get_credential_obj( + { + "client_id": creds.root_sp_client_id, + "secret_key": creds.root_sp_key, + "tenant_id": creds.root_tenant_id, + }, + resource=AZURE_MANAGEMENT_API, + ) response = self._create_management_group( credentials, @@ -632,26 +642,23 @@ class AzureCloudProvider(CloudProviderInterface): "tenant_id": self.tenant_id, } - def get_credentials(self, scope="portfolio", tenant_id=None): - """ - This could be implemented to determine, based on type, whether to return creds for: - - scope="atat": the ATAT main app registration in ATAT's home tenant - - scope="tenantadmin": the tenant administrator credentials - - scope="portfolio": the credentials for the ATAT SP in the portfolio tenant - """ - if scope == "atat": - return self._root_creds - elif scope == "tenantadmin": - # magic with key vault happens - return { - "client_id": "some id", - "secret_key": "very secret", - "tenant_id": tenant_id, - } - elif scope == "portfolio": - # magic with key vault happens - return { - "client_id": "some id", - "secret_key": "very secret", - "tenant_id": tenant_id, - } + def _source_creds(self, tenant_id=None) -> KeyVaultCredentials: + if tenant_id: + return self._source_tenant_creds(tenant_id) + else: + return KeyVaultCredentials( + root_tenant_id=self._root_creds.get("tenant_id"), + root_sp_client_id=self._root_creds.get("client_id"), + root_sp_key=self._root_creds.get("secret_key"), + ) + + def update_tenant_creds(self, tenant_id, secret): + hashed = sha256_hex(tenant_id) + self.set_secret(hashed, json.dumps(secret)) + + return secret + + def _source_tenant_creds(self, tenant_id): + hashed = sha256_hex(tenant_id) + raw_creds = self.get_secret(hashed) + return KeyVaultCredentials(**json.loads(raw_creds)) diff --git a/atst/domain/csp/cloud/mock_cloud_provider.py b/atst/domain/csp/cloud/mock_cloud_provider.py index 6df61003..10d62e15 100644 --- a/atst/domain/csp/cloud/mock_cloud_provider.py +++ b/atst/domain/csp/cloud/mock_cloud_provider.py @@ -347,8 +347,12 @@ class MockCloudProvider(CloudProviderInterface): def create_application(self, payload: ApplicationCSPPayload): self._maybe_raise(self.UNAUTHORIZED_RATE, GeneralCSPException) - id_ = f"{AZURE_MGMNT_PATH}{payload.management_group_name}" - return ApplicationCSPResult(id=id_) + return ApplicationCSPResult( + id=f"{AZURE_MGMNT_PATH}{payload.management_group_name}" + ) def get_credentials(self, scope="portfolio", tenant_id=None): return self.root_creds() + + def update_tenant_creds(self, tenant_id, secret): + return secret diff --git a/atst/domain/csp/cloud/models.py b/atst/domain/csp/cloud/models.py index 369bed31..b4ff9232 100644 --- a/atst/domain/csp/cloud/models.py +++ b/atst/domain/csp/cloud/models.py @@ -2,7 +2,7 @@ from typing import Dict, List, Optional import re from uuid import uuid4 -from pydantic import BaseModel, validator +from pydantic import BaseModel, validator, root_validator from atst.utils import snake_to_camel @@ -241,7 +241,7 @@ AZURE_MGMNT_PATH = "/providers/Microsoft.Management/managementGroups/" MANAGEMENT_GROUP_NAME_REGEX = "^[a-zA-Z0-9\-_\(\)\.]+$" -class ManagementGroupCSPPayload(BaseCSPPayload): +class ManagementGroupCSPPayload(AliasModel): """ :param: management_group_name: Just pass a UUID for this. :param: display_name: This can contain any character and @@ -250,6 +250,7 @@ class ManagementGroupCSPPayload(BaseCSPPayload): i.e. /providers/Microsoft.Management/managementGroups/[management group ID] """ + tenant_id: str management_group_name: Optional[str] display_name: str parent_id: str @@ -288,3 +289,55 @@ class ApplicationCSPPayload(ManagementGroupCSPPayload): class ApplicationCSPResult(ManagementGroupCSPResponse): pass + + +class KeyVaultCredentials(BaseModel): + root_sp_client_id: Optional[str] + root_sp_key: Optional[str] + root_tenant_id: Optional[str] + + tenant_id: Optional[str] + + tenant_admin_username: Optional[str] + tenant_admin_password: Optional[str] + + tenant_sp_client_id: Optional[str] + tenant_sp_key: Optional[str] + + @root_validator(pre=True) + def enforce_admin_creds(cls, values): + tenant_id = values.get("tenant_id") + username = values.get("tenant_admin_username") + password = values.get("tenant_admin_password") + if any([username, password]) and not all([tenant_id, username, password]): + raise ValueError( + "tenant_id, tenant_admin_username, and tenant_admin_password must all be set if any one is" + ) + + return values + + @root_validator(pre=True) + def enforce_sp_creds(cls, values): + tenant_id = values.get("tenant_id") + client_id = values.get("tenant_sp_client_id") + key = values.get("tenant_sp_key") + if any([client_id, key]) and not all([tenant_id, client_id, key]): + raise ValueError( + "tenant_id, tenant_sp_client_id, and tenant_sp_key must all be set if any one is" + ) + + return values + + @root_validator(pre=True) + def enforce_root_creds(cls, values): + sp_creds = [ + values.get("root_tenant_id"), + values.get("root_sp_client_id"), + values.get("root_sp_key"), + ] + if any(sp_creds) and not all(sp_creds): + raise ValueError( + "root_tenant_id, root_sp_client_id, and root_sp_key must all be set if any one is" + ) + + return values diff --git a/atst/jobs.py b/atst/jobs.py index 7a4a3792..14256336 100644 --- a/atst/jobs.py +++ b/atst/jobs.py @@ -59,15 +59,14 @@ def do_create_application(csp: CloudProviderInterface, application_id=None): with claim_for_update(application) as application: - if application.cloud_id is not None: + if application.cloud_id: return csp_details = application.portfolio.csp_data parent_id = csp_details.get("root_management_group_id") tenant_id = csp_details.get("tenant_id") - creds = csp.get_credentials(tenant_id) payload = ApplicationCSPPayload( - creds=creds, display_name=application.name, parent_id=parent_id + tenant_id=tenant_id, display_name=application.name, parent_id=parent_id ) app_result = csp.create_application(payload) diff --git a/atst/models/portfolio_state_machine.py b/atst/models/portfolio_state_machine.py index cf42710b..be9324b1 100644 --- a/atst/models/portfolio_state_machine.py +++ b/atst/models/portfolio_state_machine.py @@ -175,7 +175,7 @@ class PortfolioStateMachine( tenant_id = new_creds.get("tenant_id") secret = self.csp.get_secret(tenant_id, new_creds) secret.update(new_creds) - self.csp.set_secret(tenant_id, secret) + self.csp.update_tenant_creds(tenant_id, secret) except PydanticValidationError as exc: app.logger.error( f"Failed to cast response to valid result class {self.__repr__()}:", diff --git a/atst/utils/__init__.py b/atst/utils/__init__.py index 09c63dea..79d5362a 100644 --- a/atst/utils/__init__.py +++ b/atst/utils/__init__.py @@ -1,3 +1,4 @@ +import hashlib import re from sqlalchemy.exc import IntegrityError @@ -41,3 +42,8 @@ def commit_or_raise_already_exists_error(message): except IntegrityError: db.session.rollback() raise AlreadyExistsError(message) + + +def sha256_hex(string): + hsh = hashlib.sha256(string.encode()) + return hsh.digest().hex() diff --git a/tests/domain/cloud/test_azure_csp.py b/tests/domain/cloud/test_azure_csp.py index 0d23d6c0..39fa2f77 100644 --- a/tests/domain/cloud/test_azure_csp.py +++ b/tests/domain/cloud/test_azure_csp.py @@ -1,5 +1,7 @@ +import pytest +import json from uuid import uuid4 -from unittest.mock import Mock +from unittest.mock import Mock, patch from tests.factories import ApplicationFactory, EnvironmentFactory from tests.mock_azure import AUTH_CREDENTIALS, mock_azure @@ -84,13 +86,28 @@ def test_create_environment_succeeds(mock_azure: AzureCloudProvider): assert result.id == "Test Id" +# mock the get_secret so it returns a JSON string +MOCK_CREDS = { + "tenant_id": str(uuid4()), + "tenant_sp_client_id": str(uuid4()), + "tenant_sp_key": "1234", +} + + +def mock_get_secret(azure, func): + azure.get_secret = func + + return azure + + def test_create_application_succeeds(mock_azure: AzureCloudProvider): application = ApplicationFactory.create() - mock_management_group_create(mock_azure, {"id": "Test Id"}) + mock_azure = mock_get_secret(mock_azure, lambda *a, **k: json.dumps(MOCK_CREDS)) + payload = ApplicationCSPPayload( - creds={}, display_name=application.name, parent_id=str(uuid4()) + tenant_id="1234", display_name=application.name, parent_id=str(uuid4()) ) result = mock_azure.create_application(payload) diff --git a/tests/domain/cloud/test_payloads.py b/tests/domain/cloud/test_models.py similarity index 54% rename from tests/domain/cloud/test_payloads.py rename to tests/domain/cloud/test_models.py index d92a4840..d9fc963d 100644 --- a/tests/domain/cloud/test_payloads.py +++ b/tests/domain/cloud/test_models.py @@ -4,6 +4,7 @@ from pydantic import ValidationError from atst.domain.csp.cloud.models import ( AZURE_MGMNT_PATH, + KeyVaultCredentials, ManagementGroupCSPPayload, ManagementGroupCSPResponse, ) @@ -12,25 +13,25 @@ from atst.domain.csp.cloud.models import ( def test_ManagementGroupCSPPayload_management_group_name(): # supplies management_group_name when absent payload = ManagementGroupCSPPayload( - creds={}, display_name="Council of Naboo", parent_id="Galactic_Senate" + tenant_id="any-old-id", + display_name="Council of Naboo", + parent_id="Galactic_Senate", ) assert payload.management_group_name # validates management_group_name with pytest.raises(ValidationError): payload = ManagementGroupCSPPayload( - creds={}, + tenant_id="any-old-id", management_group_name="council of Naboo 1%^&", display_name="Council of Naboo", parent_id="Galactic_Senate", ) # shortens management_group_name to fit - name = "council_of_naboo" - for _ in range(90): - name = f"{name}1" + name = "council_of_naboo".ljust(95, "1") assert len(name) > 90 payload = ManagementGroupCSPPayload( - creds={}, + tenant_id="any-old-id", management_group_name=name, display_name="Council of Naboo", parent_id="Galactic_Senate", @@ -40,12 +41,10 @@ def test_ManagementGroupCSPPayload_management_group_name(): def test_ManagementGroupCSPPayload_display_name(): # shortens display_name to fit - name = "Council of Naboo" - for _ in range(90): - name = f"{name}1" + name = "Council of Naboo".ljust(95, "1") assert len(name) > 90 payload = ManagementGroupCSPPayload( - creds={}, display_name=name, parent_id="Galactic_Senate" + tenant_id="any-old-id", display_name=name, parent_id="Galactic_Senate" ) assert len(payload.display_name) == 90 @@ -54,12 +53,14 @@ def test_ManagementGroupCSPPayload_parent_id(): full_path = f"{AZURE_MGMNT_PATH}Galactic_Senate" # adds full path payload = ManagementGroupCSPPayload( - creds={}, display_name="Council of Naboo", parent_id="Galactic_Senate" + tenant_id="any-old-id", + display_name="Council of Naboo", + parent_id="Galactic_Senate", ) assert payload.parent_id == full_path # keeps full path payload = ManagementGroupCSPPayload( - creds={}, display_name="Council of Naboo", parent_id=full_path + tenant_id="any-old-id", display_name="Council of Naboo", parent_id=full_path ) assert payload.parent_id == full_path @@ -70,3 +71,29 @@ def test_ManagementGroupCSPResponse_id(): **{"id": "/path/to/naboo-123", "other": "stuff"} ) assert response.id == full_id + + +def test_KeyVaultCredentials_enforce_admin_creds(): + with pytest.raises(ValidationError): + KeyVaultCredentials(tenant_id="an id", tenant_admin_username="C3PO") + assert KeyVaultCredentials( + tenant_id="an id", + tenant_admin_username="C3PO", + tenant_admin_password="beep boop", + ) + + +def test_KeyVaultCredentials_enforce_sp_creds(): + with pytest.raises(ValidationError): + KeyVaultCredentials(tenant_id="an id", tenant_sp_client_id="C3PO") + assert KeyVaultCredentials( + tenant_id="an id", tenant_sp_client_id="C3PO", tenant_sp_key="beep boop" + ) + + +def test_KeyVaultCredentials_enforce_root_creds(): + with pytest.raises(ValidationError): + KeyVaultCredentials(root_tenant_id="an id", root_sp_client_id="C3PO") + assert KeyVaultCredentials( + root_tenant_id="an id", root_sp_client_id="C3PO", root_sp_key="beep boop" + ) diff --git a/tests/mock_azure.py b/tests/mock_azure.py index 7fa67667..4f37848e 100644 --- a/tests/mock_azure.py +++ b/tests/mock_azure.py @@ -72,6 +72,12 @@ def mock_secrets(): return Mock(spec=secrets) +def mock_identity(): + import azure.identity as identity + + return Mock(spec=identity) + + class MockAzureSDK(object): def __init__(self): from msrestazure.azure_cloud import AZURE_PUBLIC_CLOUD @@ -88,6 +94,7 @@ class MockAzureSDK(object): self.requests = mock_requests() # may change to a JEDI cloud self.cloud = AZURE_PUBLIC_CLOUD + self.identity = mock_identity() @pytest.fixture(scope="function") diff --git a/tests/utils/test_hash.py b/tests/utils/test_hash.py new file mode 100644 index 00000000..5cfb8489 --- /dev/null +++ b/tests/utils/test_hash.py @@ -0,0 +1,16 @@ +import random +import re +import string + +from atst.utils import sha256_hex + + +def test_sha256_hex(): + sample = "".join( + random.choices(string.ascii_uppercase + string.digits, k=random.randrange(200)) + ) + hashed = sha256_hex(sample) + assert re.match("^[a-zA-Z0-9]+$", hashed) + assert len(hashed) == 64 + hashed_again = sha256_hex(sample) + assert hashed == hashed_again