Store and pull tenant creds from Key Vault.

The tenant ID should be hashed and used as the key for the JSON blob of
relevant creds for any given tenant. Azure CSP interface methods that
need to source creds should call the internal `_source_creds` method,
either with a `tenant_id` or no parameters. That method will source the
creds. If a tenant ID is provided, it will source them from the Key
Vault. If not provided, it will return the default creds for the app
registration in the home tenant.
This commit is contained in:
dandds 2020-01-27 15:00:20 -05:00
parent a10d733fb7
commit abd03be806
10 changed files with 186 additions and 50 deletions

View File

@ -1,3 +1,4 @@
import json
import re import re
from secrets import token_urlsafe from secrets import token_urlsafe
from typing import Dict from typing import Dict
@ -16,6 +17,7 @@ from .models import (
BillingProfileTenantAccessCSPResult, BillingProfileTenantAccessCSPResult,
BillingProfileVerificationCSPPayload, BillingProfileVerificationCSPPayload,
BillingProfileVerificationCSPResult, BillingProfileVerificationCSPResult,
KeyVaultCredentials,
ManagementGroupCSPResponse, ManagementGroupCSPResponse,
TaskOrderBillingCreationCSPPayload, TaskOrderBillingCreationCSPPayload,
TaskOrderBillingCreationCSPResult, TaskOrderBillingCreationCSPResult,
@ -25,6 +27,7 @@ from .models import (
TenantCSPResult, TenantCSPResult,
) )
from .policy import AzurePolicyManager from .policy import AzurePolicyManager
from atst.utils import sha256_hex
AZURE_ENVIRONMENT = "AZURE_PUBLIC_CLOUD" # TBD AZURE_ENVIRONMENT = "AZURE_PUBLIC_CLOUD" # TBD
AZURE_SKU_ID = "?" # probably a static sku specific to ATAT/JEDI AZURE_SKU_ID = "?" # probably a static sku specific to ATAT/JEDI
@ -85,7 +88,7 @@ class AzureCloudProvider(CloudProviderInterface):
def set_secret(self, secret_key, secret_value): def set_secret(self, secret_key, secret_value):
credential = self._get_client_secret_credential_obj({}) credential = self._get_client_secret_credential_obj({})
secret_client = self.secrets.SecretClient( secret_client = self.sdk.secrets.SecretClient(
vault_url=self.vault_url, credential=credential, vault_url=self.vault_url, credential=credential,
) )
try: try:
@ -98,7 +101,7 @@ class AzureCloudProvider(CloudProviderInterface):
def get_secret(self, secret_key): def get_secret(self, secret_key):
credential = self._get_client_secret_credential_obj({}) credential = self._get_client_secret_credential_obj({})
secret_client = self.secrets.SecretClient( secret_client = self.sdk.secrets.SecretClient(
vault_url=self.vault_url, credential=credential, vault_url=self.vault_url, credential=credential,
) )
try: try:
@ -166,8 +169,15 @@ class AzureCloudProvider(CloudProviderInterface):
} }
def create_application(self, payload: ApplicationCSPPayload): def create_application(self, payload: ApplicationCSPPayload):
creds = payload.creds creds = self._source_creds(payload.tenant_id)
credentials = self._get_credential_obj(creds, resource=AZURE_MANAGEMENT_API) credentials = self._get_credential_obj(
{
"client_id": creds.root_sp_client_id,
"secret_key": creds.root_sp_key,
"tenant_id": creds.root_tenant_id,
},
resource=AZURE_MANAGEMENT_API,
)
response = self._create_management_group( response = self._create_management_group(
credentials, credentials,
@ -632,26 +642,23 @@ class AzureCloudProvider(CloudProviderInterface):
"tenant_id": self.tenant_id, "tenant_id": self.tenant_id,
} }
def get_credentials(self, scope="portfolio", tenant_id=None): def _source_creds(self, tenant_id=None) -> KeyVaultCredentials:
""" if tenant_id:
This could be implemented to determine, based on type, whether to return creds for: return self._source_tenant_creds(tenant_id)
- scope="atat": the ATAT main app registration in ATAT's home tenant else:
- scope="tenantadmin": the tenant administrator credentials return KeyVaultCredentials(
- scope="portfolio": the credentials for the ATAT SP in the portfolio tenant root_tenant_id=self._root_creds.get("tenant_id"),
""" root_sp_client_id=self._root_creds.get("client_id"),
if scope == "atat": root_sp_key=self._root_creds.get("secret_key"),
return self._root_creds )
elif scope == "tenantadmin":
# magic with key vault happens def update_tenant_creds(self, tenant_id, secret):
return { hashed = sha256_hex(tenant_id)
"client_id": "some id", self.set_secret(hashed, json.dumps(secret))
"secret_key": "very secret",
"tenant_id": tenant_id, return secret
}
elif scope == "portfolio": def _source_tenant_creds(self, tenant_id):
# magic with key vault happens hashed = sha256_hex(tenant_id)
return { raw_creds = self.get_secret(hashed)
"client_id": "some id", return KeyVaultCredentials(**json.loads(raw_creds))
"secret_key": "very secret",
"tenant_id": tenant_id,
}

View File

@ -347,8 +347,12 @@ class MockCloudProvider(CloudProviderInterface):
def create_application(self, payload: ApplicationCSPPayload): def create_application(self, payload: ApplicationCSPPayload):
self._maybe_raise(self.UNAUTHORIZED_RATE, GeneralCSPException) self._maybe_raise(self.UNAUTHORIZED_RATE, GeneralCSPException)
id_ = f"{AZURE_MGMNT_PATH}{payload.management_group_name}" return ApplicationCSPResult(
return ApplicationCSPResult(id=id_) id=f"{AZURE_MGMNT_PATH}{payload.management_group_name}"
)
def get_credentials(self, scope="portfolio", tenant_id=None): def get_credentials(self, scope="portfolio", tenant_id=None):
return self.root_creds() return self.root_creds()
def update_tenant_creds(self, tenant_id, secret):
return secret

View File

@ -2,7 +2,7 @@ from typing import Dict, List, Optional
import re import re
from uuid import uuid4 from uuid import uuid4
from pydantic import BaseModel, validator from pydantic import BaseModel, validator, root_validator
from atst.utils import snake_to_camel from atst.utils import snake_to_camel
@ -241,7 +241,7 @@ AZURE_MGMNT_PATH = "/providers/Microsoft.Management/managementGroups/"
MANAGEMENT_GROUP_NAME_REGEX = "^[a-zA-Z0-9\-_\(\)\.]+$" MANAGEMENT_GROUP_NAME_REGEX = "^[a-zA-Z0-9\-_\(\)\.]+$"
class ManagementGroupCSPPayload(BaseCSPPayload): class ManagementGroupCSPPayload(AliasModel):
""" """
:param: management_group_name: Just pass a UUID for this. :param: management_group_name: Just pass a UUID for this.
:param: display_name: This can contain any character and :param: display_name: This can contain any character and
@ -250,6 +250,7 @@ class ManagementGroupCSPPayload(BaseCSPPayload):
i.e. /providers/Microsoft.Management/managementGroups/[management group ID] i.e. /providers/Microsoft.Management/managementGroups/[management group ID]
""" """
tenant_id: str
management_group_name: Optional[str] management_group_name: Optional[str]
display_name: str display_name: str
parent_id: str parent_id: str
@ -288,3 +289,55 @@ class ApplicationCSPPayload(ManagementGroupCSPPayload):
class ApplicationCSPResult(ManagementGroupCSPResponse): class ApplicationCSPResult(ManagementGroupCSPResponse):
pass pass
class KeyVaultCredentials(BaseModel):
root_sp_client_id: Optional[str]
root_sp_key: Optional[str]
root_tenant_id: Optional[str]
tenant_id: Optional[str]
tenant_admin_username: Optional[str]
tenant_admin_password: Optional[str]
tenant_sp_client_id: Optional[str]
tenant_sp_key: Optional[str]
@root_validator(pre=True)
def enforce_admin_creds(cls, values):
tenant_id = values.get("tenant_id")
username = values.get("tenant_admin_username")
password = values.get("tenant_admin_password")
if any([username, password]) and not all([tenant_id, username, password]):
raise ValueError(
"tenant_id, tenant_admin_username, and tenant_admin_password must all be set if any one is"
)
return values
@root_validator(pre=True)
def enforce_sp_creds(cls, values):
tenant_id = values.get("tenant_id")
client_id = values.get("tenant_sp_client_id")
key = values.get("tenant_sp_key")
if any([client_id, key]) and not all([tenant_id, client_id, key]):
raise ValueError(
"tenant_id, tenant_sp_client_id, and tenant_sp_key must all be set if any one is"
)
return values
@root_validator(pre=True)
def enforce_root_creds(cls, values):
sp_creds = [
values.get("root_tenant_id"),
values.get("root_sp_client_id"),
values.get("root_sp_key"),
]
if any(sp_creds) and not all(sp_creds):
raise ValueError(
"root_tenant_id, root_sp_client_id, and root_sp_key must all be set if any one is"
)
return values

View File

@ -59,15 +59,14 @@ def do_create_application(csp: CloudProviderInterface, application_id=None):
with claim_for_update(application) as application: with claim_for_update(application) as application:
if application.cloud_id is not None: if application.cloud_id:
return return
csp_details = application.portfolio.csp_data csp_details = application.portfolio.csp_data
parent_id = csp_details.get("root_management_group_id") parent_id = csp_details.get("root_management_group_id")
tenant_id = csp_details.get("tenant_id") tenant_id = csp_details.get("tenant_id")
creds = csp.get_credentials(tenant_id)
payload = ApplicationCSPPayload( payload = ApplicationCSPPayload(
creds=creds, display_name=application.name, parent_id=parent_id tenant_id=tenant_id, display_name=application.name, parent_id=parent_id
) )
app_result = csp.create_application(payload) app_result = csp.create_application(payload)

View File

@ -175,7 +175,7 @@ class PortfolioStateMachine(
tenant_id = new_creds.get("tenant_id") tenant_id = new_creds.get("tenant_id")
secret = self.csp.get_secret(tenant_id, new_creds) secret = self.csp.get_secret(tenant_id, new_creds)
secret.update(new_creds) secret.update(new_creds)
self.csp.set_secret(tenant_id, secret) self.csp.update_tenant_creds(tenant_id, secret)
except PydanticValidationError as exc: except PydanticValidationError as exc:
app.logger.error( app.logger.error(
f"Failed to cast response to valid result class {self.__repr__()}:", f"Failed to cast response to valid result class {self.__repr__()}:",

View File

@ -1,3 +1,4 @@
import hashlib
import re import re
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
@ -41,3 +42,8 @@ def commit_or_raise_already_exists_error(message):
except IntegrityError: except IntegrityError:
db.session.rollback() db.session.rollback()
raise AlreadyExistsError(message) raise AlreadyExistsError(message)
def sha256_hex(string):
hsh = hashlib.sha256(string.encode())
return hsh.digest().hex()

View File

@ -1,5 +1,7 @@
import pytest
import json
from uuid import uuid4 from uuid import uuid4
from unittest.mock import Mock from unittest.mock import Mock, patch
from tests.factories import ApplicationFactory, EnvironmentFactory from tests.factories import ApplicationFactory, EnvironmentFactory
from tests.mock_azure import AUTH_CREDENTIALS, mock_azure from tests.mock_azure import AUTH_CREDENTIALS, mock_azure
@ -84,13 +86,28 @@ def test_create_environment_succeeds(mock_azure: AzureCloudProvider):
assert result.id == "Test Id" assert result.id == "Test Id"
# mock the get_secret so it returns a JSON string
MOCK_CREDS = {
"tenant_id": str(uuid4()),
"tenant_sp_client_id": str(uuid4()),
"tenant_sp_key": "1234",
}
def mock_get_secret(azure, func):
azure.get_secret = func
return azure
def test_create_application_succeeds(mock_azure: AzureCloudProvider): def test_create_application_succeeds(mock_azure: AzureCloudProvider):
application = ApplicationFactory.create() application = ApplicationFactory.create()
mock_management_group_create(mock_azure, {"id": "Test Id"}) mock_management_group_create(mock_azure, {"id": "Test Id"})
mock_azure = mock_get_secret(mock_azure, lambda *a, **k: json.dumps(MOCK_CREDS))
payload = ApplicationCSPPayload( payload = ApplicationCSPPayload(
creds={}, display_name=application.name, parent_id=str(uuid4()) tenant_id="1234", display_name=application.name, parent_id=str(uuid4())
) )
result = mock_azure.create_application(payload) result = mock_azure.create_application(payload)

View File

@ -4,6 +4,7 @@ from pydantic import ValidationError
from atst.domain.csp.cloud.models import ( from atst.domain.csp.cloud.models import (
AZURE_MGMNT_PATH, AZURE_MGMNT_PATH,
KeyVaultCredentials,
ManagementGroupCSPPayload, ManagementGroupCSPPayload,
ManagementGroupCSPResponse, ManagementGroupCSPResponse,
) )
@ -12,25 +13,25 @@ from atst.domain.csp.cloud.models import (
def test_ManagementGroupCSPPayload_management_group_name(): def test_ManagementGroupCSPPayload_management_group_name():
# supplies management_group_name when absent # supplies management_group_name when absent
payload = ManagementGroupCSPPayload( payload = ManagementGroupCSPPayload(
creds={}, display_name="Council of Naboo", parent_id="Galactic_Senate" tenant_id="any-old-id",
display_name="Council of Naboo",
parent_id="Galactic_Senate",
) )
assert payload.management_group_name assert payload.management_group_name
# validates management_group_name # validates management_group_name
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
payload = ManagementGroupCSPPayload( payload = ManagementGroupCSPPayload(
creds={}, tenant_id="any-old-id",
management_group_name="council of Naboo 1%^&", management_group_name="council of Naboo 1%^&",
display_name="Council of Naboo", display_name="Council of Naboo",
parent_id="Galactic_Senate", parent_id="Galactic_Senate",
) )
# shortens management_group_name to fit # shortens management_group_name to fit
name = "council_of_naboo" name = "council_of_naboo".ljust(95, "1")
for _ in range(90):
name = f"{name}1"
assert len(name) > 90 assert len(name) > 90
payload = ManagementGroupCSPPayload( payload = ManagementGroupCSPPayload(
creds={}, tenant_id="any-old-id",
management_group_name=name, management_group_name=name,
display_name="Council of Naboo", display_name="Council of Naboo",
parent_id="Galactic_Senate", parent_id="Galactic_Senate",
@ -40,12 +41,10 @@ def test_ManagementGroupCSPPayload_management_group_name():
def test_ManagementGroupCSPPayload_display_name(): def test_ManagementGroupCSPPayload_display_name():
# shortens display_name to fit # shortens display_name to fit
name = "Council of Naboo" name = "Council of Naboo".ljust(95, "1")
for _ in range(90):
name = f"{name}1"
assert len(name) > 90 assert len(name) > 90
payload = ManagementGroupCSPPayload( payload = ManagementGroupCSPPayload(
creds={}, display_name=name, parent_id="Galactic_Senate" tenant_id="any-old-id", display_name=name, parent_id="Galactic_Senate"
) )
assert len(payload.display_name) == 90 assert len(payload.display_name) == 90
@ -54,12 +53,14 @@ def test_ManagementGroupCSPPayload_parent_id():
full_path = f"{AZURE_MGMNT_PATH}Galactic_Senate" full_path = f"{AZURE_MGMNT_PATH}Galactic_Senate"
# adds full path # adds full path
payload = ManagementGroupCSPPayload( payload = ManagementGroupCSPPayload(
creds={}, display_name="Council of Naboo", parent_id="Galactic_Senate" tenant_id="any-old-id",
display_name="Council of Naboo",
parent_id="Galactic_Senate",
) )
assert payload.parent_id == full_path assert payload.parent_id == full_path
# keeps full path # keeps full path
payload = ManagementGroupCSPPayload( payload = ManagementGroupCSPPayload(
creds={}, display_name="Council of Naboo", parent_id=full_path tenant_id="any-old-id", display_name="Council of Naboo", parent_id=full_path
) )
assert payload.parent_id == full_path assert payload.parent_id == full_path
@ -70,3 +71,29 @@ def test_ManagementGroupCSPResponse_id():
**{"id": "/path/to/naboo-123", "other": "stuff"} **{"id": "/path/to/naboo-123", "other": "stuff"}
) )
assert response.id == full_id assert response.id == full_id
def test_KeyVaultCredentials_enforce_admin_creds():
with pytest.raises(ValidationError):
KeyVaultCredentials(tenant_id="an id", tenant_admin_username="C3PO")
assert KeyVaultCredentials(
tenant_id="an id",
tenant_admin_username="C3PO",
tenant_admin_password="beep boop",
)
def test_KeyVaultCredentials_enforce_sp_creds():
with pytest.raises(ValidationError):
KeyVaultCredentials(tenant_id="an id", tenant_sp_client_id="C3PO")
assert KeyVaultCredentials(
tenant_id="an id", tenant_sp_client_id="C3PO", tenant_sp_key="beep boop"
)
def test_KeyVaultCredentials_enforce_root_creds():
with pytest.raises(ValidationError):
KeyVaultCredentials(root_tenant_id="an id", root_sp_client_id="C3PO")
assert KeyVaultCredentials(
root_tenant_id="an id", root_sp_client_id="C3PO", root_sp_key="beep boop"
)

View File

@ -72,6 +72,12 @@ def mock_secrets():
return Mock(spec=secrets) return Mock(spec=secrets)
def mock_identity():
import azure.identity as identity
return Mock(spec=identity)
class MockAzureSDK(object): class MockAzureSDK(object):
def __init__(self): def __init__(self):
from msrestazure.azure_cloud import AZURE_PUBLIC_CLOUD from msrestazure.azure_cloud import AZURE_PUBLIC_CLOUD
@ -88,6 +94,7 @@ class MockAzureSDK(object):
self.requests = mock_requests() self.requests = mock_requests()
# may change to a JEDI cloud # may change to a JEDI cloud
self.cloud = AZURE_PUBLIC_CLOUD self.cloud = AZURE_PUBLIC_CLOUD
self.identity = mock_identity()
@pytest.fixture(scope="function") @pytest.fixture(scope="function")

16
tests/utils/test_hash.py Normal file
View File

@ -0,0 +1,16 @@
import random
import re
import string
from atst.utils import sha256_hex
def test_sha256_hex():
sample = "".join(
random.choices(string.ascii_uppercase + string.digits, k=random.randrange(200))
)
hashed = sha256_hex(sample)
assert re.match("^[a-zA-Z0-9]+$", hashed)
assert len(hashed) == 64
hashed_again = sha256_hex(sample)
assert hashed == hashed_again