Store and pull tenant creds from Key Vault.

The tenant ID should be hashed and used as the key for the JSON blob of
relevant creds for any given tenant. Azure CSP interface methods that
need to source creds should call the internal `_source_creds` method,
either with a `tenant_id` or no parameters. That method will source the
creds. If a tenant ID is provided, it will source them from the Key
Vault. If not provided, it will return the default creds for the app
registration in the home tenant.
This commit is contained in:
dandds 2020-01-27 15:00:20 -05:00
parent a10d733fb7
commit abd03be806
10 changed files with 186 additions and 50 deletions

View File

@ -1,3 +1,4 @@
import json
import re
from secrets import token_urlsafe
from typing import Dict
@ -16,6 +17,7 @@ from .models import (
BillingProfileTenantAccessCSPResult,
BillingProfileVerificationCSPPayload,
BillingProfileVerificationCSPResult,
KeyVaultCredentials,
ManagementGroupCSPResponse,
TaskOrderBillingCreationCSPPayload,
TaskOrderBillingCreationCSPResult,
@ -25,6 +27,7 @@ from .models import (
TenantCSPResult,
)
from .policy import AzurePolicyManager
from atst.utils import sha256_hex
AZURE_ENVIRONMENT = "AZURE_PUBLIC_CLOUD" # TBD
AZURE_SKU_ID = "?" # probably a static sku specific to ATAT/JEDI
@ -85,7 +88,7 @@ class AzureCloudProvider(CloudProviderInterface):
def set_secret(self, secret_key, secret_value):
credential = self._get_client_secret_credential_obj({})
secret_client = self.secrets.SecretClient(
secret_client = self.sdk.secrets.SecretClient(
vault_url=self.vault_url, credential=credential,
)
try:
@ -98,7 +101,7 @@ class AzureCloudProvider(CloudProviderInterface):
def get_secret(self, secret_key):
credential = self._get_client_secret_credential_obj({})
secret_client = self.secrets.SecretClient(
secret_client = self.sdk.secrets.SecretClient(
vault_url=self.vault_url, credential=credential,
)
try:
@ -166,8 +169,15 @@ class AzureCloudProvider(CloudProviderInterface):
}
def create_application(self, payload: ApplicationCSPPayload):
creds = payload.creds
credentials = self._get_credential_obj(creds, resource=AZURE_MANAGEMENT_API)
creds = self._source_creds(payload.tenant_id)
credentials = self._get_credential_obj(
{
"client_id": creds.root_sp_client_id,
"secret_key": creds.root_sp_key,
"tenant_id": creds.root_tenant_id,
},
resource=AZURE_MANAGEMENT_API,
)
response = self._create_management_group(
credentials,
@ -632,26 +642,23 @@ class AzureCloudProvider(CloudProviderInterface):
"tenant_id": self.tenant_id,
}
def get_credentials(self, scope="portfolio", tenant_id=None):
"""
This could be implemented to determine, based on type, whether to return creds for:
- scope="atat": the ATAT main app registration in ATAT's home tenant
- scope="tenantadmin": the tenant administrator credentials
- scope="portfolio": the credentials for the ATAT SP in the portfolio tenant
"""
if scope == "atat":
return self._root_creds
elif scope == "tenantadmin":
# magic with key vault happens
return {
"client_id": "some id",
"secret_key": "very secret",
"tenant_id": tenant_id,
}
elif scope == "portfolio":
# magic with key vault happens
return {
"client_id": "some id",
"secret_key": "very secret",
"tenant_id": tenant_id,
}
def _source_creds(self, tenant_id=None) -> KeyVaultCredentials:
if tenant_id:
return self._source_tenant_creds(tenant_id)
else:
return KeyVaultCredentials(
root_tenant_id=self._root_creds.get("tenant_id"),
root_sp_client_id=self._root_creds.get("client_id"),
root_sp_key=self._root_creds.get("secret_key"),
)
def update_tenant_creds(self, tenant_id, secret):
hashed = sha256_hex(tenant_id)
self.set_secret(hashed, json.dumps(secret))
return secret
def _source_tenant_creds(self, tenant_id):
hashed = sha256_hex(tenant_id)
raw_creds = self.get_secret(hashed)
return KeyVaultCredentials(**json.loads(raw_creds))

View File

@ -347,8 +347,12 @@ class MockCloudProvider(CloudProviderInterface):
def create_application(self, payload: ApplicationCSPPayload):
self._maybe_raise(self.UNAUTHORIZED_RATE, GeneralCSPException)
id_ = f"{AZURE_MGMNT_PATH}{payload.management_group_name}"
return ApplicationCSPResult(id=id_)
return ApplicationCSPResult(
id=f"{AZURE_MGMNT_PATH}{payload.management_group_name}"
)
def get_credentials(self, scope="portfolio", tenant_id=None):
return self.root_creds()
def update_tenant_creds(self, tenant_id, secret):
return secret

View File

@ -2,7 +2,7 @@ from typing import Dict, List, Optional
import re
from uuid import uuid4
from pydantic import BaseModel, validator
from pydantic import BaseModel, validator, root_validator
from atst.utils import snake_to_camel
@ -241,7 +241,7 @@ AZURE_MGMNT_PATH = "/providers/Microsoft.Management/managementGroups/"
MANAGEMENT_GROUP_NAME_REGEX = "^[a-zA-Z0-9\-_\(\)\.]+$"
class ManagementGroupCSPPayload(BaseCSPPayload):
class ManagementGroupCSPPayload(AliasModel):
"""
:param: management_group_name: Just pass a UUID for this.
:param: display_name: This can contain any character and
@ -250,6 +250,7 @@ class ManagementGroupCSPPayload(BaseCSPPayload):
i.e. /providers/Microsoft.Management/managementGroups/[management group ID]
"""
tenant_id: str
management_group_name: Optional[str]
display_name: str
parent_id: str
@ -288,3 +289,55 @@ class ApplicationCSPPayload(ManagementGroupCSPPayload):
class ApplicationCSPResult(ManagementGroupCSPResponse):
pass
class KeyVaultCredentials(BaseModel):
root_sp_client_id: Optional[str]
root_sp_key: Optional[str]
root_tenant_id: Optional[str]
tenant_id: Optional[str]
tenant_admin_username: Optional[str]
tenant_admin_password: Optional[str]
tenant_sp_client_id: Optional[str]
tenant_sp_key: Optional[str]
@root_validator(pre=True)
def enforce_admin_creds(cls, values):
tenant_id = values.get("tenant_id")
username = values.get("tenant_admin_username")
password = values.get("tenant_admin_password")
if any([username, password]) and not all([tenant_id, username, password]):
raise ValueError(
"tenant_id, tenant_admin_username, and tenant_admin_password must all be set if any one is"
)
return values
@root_validator(pre=True)
def enforce_sp_creds(cls, values):
tenant_id = values.get("tenant_id")
client_id = values.get("tenant_sp_client_id")
key = values.get("tenant_sp_key")
if any([client_id, key]) and not all([tenant_id, client_id, key]):
raise ValueError(
"tenant_id, tenant_sp_client_id, and tenant_sp_key must all be set if any one is"
)
return values
@root_validator(pre=True)
def enforce_root_creds(cls, values):
sp_creds = [
values.get("root_tenant_id"),
values.get("root_sp_client_id"),
values.get("root_sp_key"),
]
if any(sp_creds) and not all(sp_creds):
raise ValueError(
"root_tenant_id, root_sp_client_id, and root_sp_key must all be set if any one is"
)
return values

View File

@ -59,15 +59,14 @@ def do_create_application(csp: CloudProviderInterface, application_id=None):
with claim_for_update(application) as application:
if application.cloud_id is not None:
if application.cloud_id:
return
csp_details = application.portfolio.csp_data
parent_id = csp_details.get("root_management_group_id")
tenant_id = csp_details.get("tenant_id")
creds = csp.get_credentials(tenant_id)
payload = ApplicationCSPPayload(
creds=creds, display_name=application.name, parent_id=parent_id
tenant_id=tenant_id, display_name=application.name, parent_id=parent_id
)
app_result = csp.create_application(payload)

View File

@ -175,7 +175,7 @@ class PortfolioStateMachine(
tenant_id = new_creds.get("tenant_id")
secret = self.csp.get_secret(tenant_id, new_creds)
secret.update(new_creds)
self.csp.set_secret(tenant_id, secret)
self.csp.update_tenant_creds(tenant_id, secret)
except PydanticValidationError as exc:
app.logger.error(
f"Failed to cast response to valid result class {self.__repr__()}:",

View File

@ -1,3 +1,4 @@
import hashlib
import re
from sqlalchemy.exc import IntegrityError
@ -41,3 +42,8 @@ def commit_or_raise_already_exists_error(message):
except IntegrityError:
db.session.rollback()
raise AlreadyExistsError(message)
def sha256_hex(string):
hsh = hashlib.sha256(string.encode())
return hsh.digest().hex()

View File

@ -1,5 +1,7 @@
import pytest
import json
from uuid import uuid4
from unittest.mock import Mock
from unittest.mock import Mock, patch
from tests.factories import ApplicationFactory, EnvironmentFactory
from tests.mock_azure import AUTH_CREDENTIALS, mock_azure
@ -84,13 +86,28 @@ def test_create_environment_succeeds(mock_azure: AzureCloudProvider):
assert result.id == "Test Id"
# mock the get_secret so it returns a JSON string
MOCK_CREDS = {
"tenant_id": str(uuid4()),
"tenant_sp_client_id": str(uuid4()),
"tenant_sp_key": "1234",
}
def mock_get_secret(azure, func):
azure.get_secret = func
return azure
def test_create_application_succeeds(mock_azure: AzureCloudProvider):
application = ApplicationFactory.create()
mock_management_group_create(mock_azure, {"id": "Test Id"})
mock_azure = mock_get_secret(mock_azure, lambda *a, **k: json.dumps(MOCK_CREDS))
payload = ApplicationCSPPayload(
creds={}, display_name=application.name, parent_id=str(uuid4())
tenant_id="1234", display_name=application.name, parent_id=str(uuid4())
)
result = mock_azure.create_application(payload)

View File

@ -4,6 +4,7 @@ from pydantic import ValidationError
from atst.domain.csp.cloud.models import (
AZURE_MGMNT_PATH,
KeyVaultCredentials,
ManagementGroupCSPPayload,
ManagementGroupCSPResponse,
)
@ -12,25 +13,25 @@ from atst.domain.csp.cloud.models import (
def test_ManagementGroupCSPPayload_management_group_name():
# supplies management_group_name when absent
payload = ManagementGroupCSPPayload(
creds={}, display_name="Council of Naboo", parent_id="Galactic_Senate"
tenant_id="any-old-id",
display_name="Council of Naboo",
parent_id="Galactic_Senate",
)
assert payload.management_group_name
# validates management_group_name
with pytest.raises(ValidationError):
payload = ManagementGroupCSPPayload(
creds={},
tenant_id="any-old-id",
management_group_name="council of Naboo 1%^&",
display_name="Council of Naboo",
parent_id="Galactic_Senate",
)
# shortens management_group_name to fit
name = "council_of_naboo"
for _ in range(90):
name = f"{name}1"
name = "council_of_naboo".ljust(95, "1")
assert len(name) > 90
payload = ManagementGroupCSPPayload(
creds={},
tenant_id="any-old-id",
management_group_name=name,
display_name="Council of Naboo",
parent_id="Galactic_Senate",
@ -40,12 +41,10 @@ def test_ManagementGroupCSPPayload_management_group_name():
def test_ManagementGroupCSPPayload_display_name():
# shortens display_name to fit
name = "Council of Naboo"
for _ in range(90):
name = f"{name}1"
name = "Council of Naboo".ljust(95, "1")
assert len(name) > 90
payload = ManagementGroupCSPPayload(
creds={}, display_name=name, parent_id="Galactic_Senate"
tenant_id="any-old-id", display_name=name, parent_id="Galactic_Senate"
)
assert len(payload.display_name) == 90
@ -54,12 +53,14 @@ def test_ManagementGroupCSPPayload_parent_id():
full_path = f"{AZURE_MGMNT_PATH}Galactic_Senate"
# adds full path
payload = ManagementGroupCSPPayload(
creds={}, display_name="Council of Naboo", parent_id="Galactic_Senate"
tenant_id="any-old-id",
display_name="Council of Naboo",
parent_id="Galactic_Senate",
)
assert payload.parent_id == full_path
# keeps full path
payload = ManagementGroupCSPPayload(
creds={}, display_name="Council of Naboo", parent_id=full_path
tenant_id="any-old-id", display_name="Council of Naboo", parent_id=full_path
)
assert payload.parent_id == full_path
@ -70,3 +71,29 @@ def test_ManagementGroupCSPResponse_id():
**{"id": "/path/to/naboo-123", "other": "stuff"}
)
assert response.id == full_id
def test_KeyVaultCredentials_enforce_admin_creds():
with pytest.raises(ValidationError):
KeyVaultCredentials(tenant_id="an id", tenant_admin_username="C3PO")
assert KeyVaultCredentials(
tenant_id="an id",
tenant_admin_username="C3PO",
tenant_admin_password="beep boop",
)
def test_KeyVaultCredentials_enforce_sp_creds():
with pytest.raises(ValidationError):
KeyVaultCredentials(tenant_id="an id", tenant_sp_client_id="C3PO")
assert KeyVaultCredentials(
tenant_id="an id", tenant_sp_client_id="C3PO", tenant_sp_key="beep boop"
)
def test_KeyVaultCredentials_enforce_root_creds():
with pytest.raises(ValidationError):
KeyVaultCredentials(root_tenant_id="an id", root_sp_client_id="C3PO")
assert KeyVaultCredentials(
root_tenant_id="an id", root_sp_client_id="C3PO", root_sp_key="beep boop"
)

View File

@ -72,6 +72,12 @@ def mock_secrets():
return Mock(spec=secrets)
def mock_identity():
import azure.identity as identity
return Mock(spec=identity)
class MockAzureSDK(object):
def __init__(self):
from msrestazure.azure_cloud import AZURE_PUBLIC_CLOUD
@ -88,6 +94,7 @@ class MockAzureSDK(object):
self.requests = mock_requests()
# may change to a JEDI cloud
self.cloud = AZURE_PUBLIC_CLOUD
self.identity = mock_identity()
@pytest.fixture(scope="function")

16
tests/utils/test_hash.py Normal file
View File

@ -0,0 +1,16 @@
import random
import re
import string
from atst.utils import sha256_hex
def test_sha256_hex():
sample = "".join(
random.choices(string.ascii_uppercase + string.digits, k=random.randrange(200))
)
hashed = sha256_hex(sample)
assert re.match("^[a-zA-Z0-9]+$", hashed)
assert len(hashed) == 64
hashed_again = sha256_hex(sample)
assert hashed == hashed_again