Store and pull tenant creds from Key Vault.
The tenant ID should be hashed and used as the key for the JSON blob of relevant creds for any given tenant. Azure CSP interface methods that need to source creds should call the internal `_source_creds` method, either with a `tenant_id` or no parameters. That method will source the creds. If a tenant ID is provided, it will source them from the Key Vault. If not provided, it will return the default creds for the app registration in the home tenant.
This commit is contained in:
parent
a10d733fb7
commit
abd03be806
@ -1,3 +1,4 @@
|
||||
import json
|
||||
import re
|
||||
from secrets import token_urlsafe
|
||||
from typing import Dict
|
||||
@ -16,6 +17,7 @@ from .models import (
|
||||
BillingProfileTenantAccessCSPResult,
|
||||
BillingProfileVerificationCSPPayload,
|
||||
BillingProfileVerificationCSPResult,
|
||||
KeyVaultCredentials,
|
||||
ManagementGroupCSPResponse,
|
||||
TaskOrderBillingCreationCSPPayload,
|
||||
TaskOrderBillingCreationCSPResult,
|
||||
@ -25,6 +27,7 @@ from .models import (
|
||||
TenantCSPResult,
|
||||
)
|
||||
from .policy import AzurePolicyManager
|
||||
from atst.utils import sha256_hex
|
||||
|
||||
AZURE_ENVIRONMENT = "AZURE_PUBLIC_CLOUD" # TBD
|
||||
AZURE_SKU_ID = "?" # probably a static sku specific to ATAT/JEDI
|
||||
@ -85,7 +88,7 @@ class AzureCloudProvider(CloudProviderInterface):
|
||||
|
||||
def set_secret(self, secret_key, secret_value):
|
||||
credential = self._get_client_secret_credential_obj({})
|
||||
secret_client = self.secrets.SecretClient(
|
||||
secret_client = self.sdk.secrets.SecretClient(
|
||||
vault_url=self.vault_url, credential=credential,
|
||||
)
|
||||
try:
|
||||
@ -98,7 +101,7 @@ class AzureCloudProvider(CloudProviderInterface):
|
||||
|
||||
def get_secret(self, secret_key):
|
||||
credential = self._get_client_secret_credential_obj({})
|
||||
secret_client = self.secrets.SecretClient(
|
||||
secret_client = self.sdk.secrets.SecretClient(
|
||||
vault_url=self.vault_url, credential=credential,
|
||||
)
|
||||
try:
|
||||
@ -166,8 +169,15 @@ class AzureCloudProvider(CloudProviderInterface):
|
||||
}
|
||||
|
||||
def create_application(self, payload: ApplicationCSPPayload):
|
||||
creds = payload.creds
|
||||
credentials = self._get_credential_obj(creds, resource=AZURE_MANAGEMENT_API)
|
||||
creds = self._source_creds(payload.tenant_id)
|
||||
credentials = self._get_credential_obj(
|
||||
{
|
||||
"client_id": creds.root_sp_client_id,
|
||||
"secret_key": creds.root_sp_key,
|
||||
"tenant_id": creds.root_tenant_id,
|
||||
},
|
||||
resource=AZURE_MANAGEMENT_API,
|
||||
)
|
||||
|
||||
response = self._create_management_group(
|
||||
credentials,
|
||||
@ -632,26 +642,23 @@ class AzureCloudProvider(CloudProviderInterface):
|
||||
"tenant_id": self.tenant_id,
|
||||
}
|
||||
|
||||
def get_credentials(self, scope="portfolio", tenant_id=None):
|
||||
"""
|
||||
This could be implemented to determine, based on type, whether to return creds for:
|
||||
- scope="atat": the ATAT main app registration in ATAT's home tenant
|
||||
- scope="tenantadmin": the tenant administrator credentials
|
||||
- scope="portfolio": the credentials for the ATAT SP in the portfolio tenant
|
||||
"""
|
||||
if scope == "atat":
|
||||
return self._root_creds
|
||||
elif scope == "tenantadmin":
|
||||
# magic with key vault happens
|
||||
return {
|
||||
"client_id": "some id",
|
||||
"secret_key": "very secret",
|
||||
"tenant_id": tenant_id,
|
||||
}
|
||||
elif scope == "portfolio":
|
||||
# magic with key vault happens
|
||||
return {
|
||||
"client_id": "some id",
|
||||
"secret_key": "very secret",
|
||||
"tenant_id": tenant_id,
|
||||
}
|
||||
def _source_creds(self, tenant_id=None) -> KeyVaultCredentials:
|
||||
if tenant_id:
|
||||
return self._source_tenant_creds(tenant_id)
|
||||
else:
|
||||
return KeyVaultCredentials(
|
||||
root_tenant_id=self._root_creds.get("tenant_id"),
|
||||
root_sp_client_id=self._root_creds.get("client_id"),
|
||||
root_sp_key=self._root_creds.get("secret_key"),
|
||||
)
|
||||
|
||||
def update_tenant_creds(self, tenant_id, secret):
|
||||
hashed = sha256_hex(tenant_id)
|
||||
self.set_secret(hashed, json.dumps(secret))
|
||||
|
||||
return secret
|
||||
|
||||
def _source_tenant_creds(self, tenant_id):
|
||||
hashed = sha256_hex(tenant_id)
|
||||
raw_creds = self.get_secret(hashed)
|
||||
return KeyVaultCredentials(**json.loads(raw_creds))
|
||||
|
@ -347,8 +347,12 @@ class MockCloudProvider(CloudProviderInterface):
|
||||
def create_application(self, payload: ApplicationCSPPayload):
|
||||
self._maybe_raise(self.UNAUTHORIZED_RATE, GeneralCSPException)
|
||||
|
||||
id_ = f"{AZURE_MGMNT_PATH}{payload.management_group_name}"
|
||||
return ApplicationCSPResult(id=id_)
|
||||
return ApplicationCSPResult(
|
||||
id=f"{AZURE_MGMNT_PATH}{payload.management_group_name}"
|
||||
)
|
||||
|
||||
def get_credentials(self, scope="portfolio", tenant_id=None):
|
||||
return self.root_creds()
|
||||
|
||||
def update_tenant_creds(self, tenant_id, secret):
|
||||
return secret
|
||||
|
@ -2,7 +2,7 @@ from typing import Dict, List, Optional
|
||||
import re
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
from pydantic import BaseModel, validator, root_validator
|
||||
|
||||
from atst.utils import snake_to_camel
|
||||
|
||||
@ -241,7 +241,7 @@ AZURE_MGMNT_PATH = "/providers/Microsoft.Management/managementGroups/"
|
||||
MANAGEMENT_GROUP_NAME_REGEX = "^[a-zA-Z0-9\-_\(\)\.]+$"
|
||||
|
||||
|
||||
class ManagementGroupCSPPayload(BaseCSPPayload):
|
||||
class ManagementGroupCSPPayload(AliasModel):
|
||||
"""
|
||||
:param: management_group_name: Just pass a UUID for this.
|
||||
:param: display_name: This can contain any character and
|
||||
@ -250,6 +250,7 @@ class ManagementGroupCSPPayload(BaseCSPPayload):
|
||||
i.e. /providers/Microsoft.Management/managementGroups/[management group ID]
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
management_group_name: Optional[str]
|
||||
display_name: str
|
||||
parent_id: str
|
||||
@ -288,3 +289,55 @@ class ApplicationCSPPayload(ManagementGroupCSPPayload):
|
||||
|
||||
class ApplicationCSPResult(ManagementGroupCSPResponse):
|
||||
pass
|
||||
|
||||
|
||||
class KeyVaultCredentials(BaseModel):
|
||||
root_sp_client_id: Optional[str]
|
||||
root_sp_key: Optional[str]
|
||||
root_tenant_id: Optional[str]
|
||||
|
||||
tenant_id: Optional[str]
|
||||
|
||||
tenant_admin_username: Optional[str]
|
||||
tenant_admin_password: Optional[str]
|
||||
|
||||
tenant_sp_client_id: Optional[str]
|
||||
tenant_sp_key: Optional[str]
|
||||
|
||||
@root_validator(pre=True)
|
||||
def enforce_admin_creds(cls, values):
|
||||
tenant_id = values.get("tenant_id")
|
||||
username = values.get("tenant_admin_username")
|
||||
password = values.get("tenant_admin_password")
|
||||
if any([username, password]) and not all([tenant_id, username, password]):
|
||||
raise ValueError(
|
||||
"tenant_id, tenant_admin_username, and tenant_admin_password must all be set if any one is"
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
@root_validator(pre=True)
|
||||
def enforce_sp_creds(cls, values):
|
||||
tenant_id = values.get("tenant_id")
|
||||
client_id = values.get("tenant_sp_client_id")
|
||||
key = values.get("tenant_sp_key")
|
||||
if any([client_id, key]) and not all([tenant_id, client_id, key]):
|
||||
raise ValueError(
|
||||
"tenant_id, tenant_sp_client_id, and tenant_sp_key must all be set if any one is"
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
@root_validator(pre=True)
|
||||
def enforce_root_creds(cls, values):
|
||||
sp_creds = [
|
||||
values.get("root_tenant_id"),
|
||||
values.get("root_sp_client_id"),
|
||||
values.get("root_sp_key"),
|
||||
]
|
||||
if any(sp_creds) and not all(sp_creds):
|
||||
raise ValueError(
|
||||
"root_tenant_id, root_sp_client_id, and root_sp_key must all be set if any one is"
|
||||
)
|
||||
|
||||
return values
|
||||
|
@ -59,15 +59,14 @@ def do_create_application(csp: CloudProviderInterface, application_id=None):
|
||||
|
||||
with claim_for_update(application) as application:
|
||||
|
||||
if application.cloud_id is not None:
|
||||
if application.cloud_id:
|
||||
return
|
||||
|
||||
csp_details = application.portfolio.csp_data
|
||||
parent_id = csp_details.get("root_management_group_id")
|
||||
tenant_id = csp_details.get("tenant_id")
|
||||
creds = csp.get_credentials(tenant_id)
|
||||
payload = ApplicationCSPPayload(
|
||||
creds=creds, display_name=application.name, parent_id=parent_id
|
||||
tenant_id=tenant_id, display_name=application.name, parent_id=parent_id
|
||||
)
|
||||
|
||||
app_result = csp.create_application(payload)
|
||||
|
@ -175,7 +175,7 @@ class PortfolioStateMachine(
|
||||
tenant_id = new_creds.get("tenant_id")
|
||||
secret = self.csp.get_secret(tenant_id, new_creds)
|
||||
secret.update(new_creds)
|
||||
self.csp.set_secret(tenant_id, secret)
|
||||
self.csp.update_tenant_creds(tenant_id, secret)
|
||||
except PydanticValidationError as exc:
|
||||
app.logger.error(
|
||||
f"Failed to cast response to valid result class {self.__repr__()}:",
|
||||
|
@ -1,3 +1,4 @@
|
||||
import hashlib
|
||||
import re
|
||||
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
@ -41,3 +42,8 @@ def commit_or_raise_already_exists_error(message):
|
||||
except IntegrityError:
|
||||
db.session.rollback()
|
||||
raise AlreadyExistsError(message)
|
||||
|
||||
|
||||
def sha256_hex(string):
|
||||
hsh = hashlib.sha256(string.encode())
|
||||
return hsh.digest().hex()
|
||||
|
@ -1,5 +1,7 @@
|
||||
import pytest
|
||||
import json
|
||||
from uuid import uuid4
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from tests.factories import ApplicationFactory, EnvironmentFactory
|
||||
from tests.mock_azure import AUTH_CREDENTIALS, mock_azure
|
||||
@ -84,13 +86,28 @@ def test_create_environment_succeeds(mock_azure: AzureCloudProvider):
|
||||
assert result.id == "Test Id"
|
||||
|
||||
|
||||
# mock the get_secret so it returns a JSON string
|
||||
MOCK_CREDS = {
|
||||
"tenant_id": str(uuid4()),
|
||||
"tenant_sp_client_id": str(uuid4()),
|
||||
"tenant_sp_key": "1234",
|
||||
}
|
||||
|
||||
|
||||
def mock_get_secret(azure, func):
|
||||
azure.get_secret = func
|
||||
|
||||
return azure
|
||||
|
||||
|
||||
def test_create_application_succeeds(mock_azure: AzureCloudProvider):
|
||||
application = ApplicationFactory.create()
|
||||
|
||||
mock_management_group_create(mock_azure, {"id": "Test Id"})
|
||||
|
||||
mock_azure = mock_get_secret(mock_azure, lambda *a, **k: json.dumps(MOCK_CREDS))
|
||||
|
||||
payload = ApplicationCSPPayload(
|
||||
creds={}, display_name=application.name, parent_id=str(uuid4())
|
||||
tenant_id="1234", display_name=application.name, parent_id=str(uuid4())
|
||||
)
|
||||
result = mock_azure.create_application(payload)
|
||||
|
||||
|
@ -4,6 +4,7 @@ from pydantic import ValidationError
|
||||
|
||||
from atst.domain.csp.cloud.models import (
|
||||
AZURE_MGMNT_PATH,
|
||||
KeyVaultCredentials,
|
||||
ManagementGroupCSPPayload,
|
||||
ManagementGroupCSPResponse,
|
||||
)
|
||||
@ -12,25 +13,25 @@ from atst.domain.csp.cloud.models import (
|
||||
def test_ManagementGroupCSPPayload_management_group_name():
|
||||
# supplies management_group_name when absent
|
||||
payload = ManagementGroupCSPPayload(
|
||||
creds={}, display_name="Council of Naboo", parent_id="Galactic_Senate"
|
||||
tenant_id="any-old-id",
|
||||
display_name="Council of Naboo",
|
||||
parent_id="Galactic_Senate",
|
||||
)
|
||||
assert payload.management_group_name
|
||||
# validates management_group_name
|
||||
with pytest.raises(ValidationError):
|
||||
payload = ManagementGroupCSPPayload(
|
||||
creds={},
|
||||
tenant_id="any-old-id",
|
||||
management_group_name="council of Naboo 1%^&",
|
||||
display_name="Council of Naboo",
|
||||
parent_id="Galactic_Senate",
|
||||
)
|
||||
# shortens management_group_name to fit
|
||||
name = "council_of_naboo"
|
||||
for _ in range(90):
|
||||
name = f"{name}1"
|
||||
name = "council_of_naboo".ljust(95, "1")
|
||||
|
||||
assert len(name) > 90
|
||||
payload = ManagementGroupCSPPayload(
|
||||
creds={},
|
||||
tenant_id="any-old-id",
|
||||
management_group_name=name,
|
||||
display_name="Council of Naboo",
|
||||
parent_id="Galactic_Senate",
|
||||
@ -40,12 +41,10 @@ def test_ManagementGroupCSPPayload_management_group_name():
|
||||
|
||||
def test_ManagementGroupCSPPayload_display_name():
|
||||
# shortens display_name to fit
|
||||
name = "Council of Naboo"
|
||||
for _ in range(90):
|
||||
name = f"{name}1"
|
||||
name = "Council of Naboo".ljust(95, "1")
|
||||
assert len(name) > 90
|
||||
payload = ManagementGroupCSPPayload(
|
||||
creds={}, display_name=name, parent_id="Galactic_Senate"
|
||||
tenant_id="any-old-id", display_name=name, parent_id="Galactic_Senate"
|
||||
)
|
||||
assert len(payload.display_name) == 90
|
||||
|
||||
@ -54,12 +53,14 @@ def test_ManagementGroupCSPPayload_parent_id():
|
||||
full_path = f"{AZURE_MGMNT_PATH}Galactic_Senate"
|
||||
# adds full path
|
||||
payload = ManagementGroupCSPPayload(
|
||||
creds={}, display_name="Council of Naboo", parent_id="Galactic_Senate"
|
||||
tenant_id="any-old-id",
|
||||
display_name="Council of Naboo",
|
||||
parent_id="Galactic_Senate",
|
||||
)
|
||||
assert payload.parent_id == full_path
|
||||
# keeps full path
|
||||
payload = ManagementGroupCSPPayload(
|
||||
creds={}, display_name="Council of Naboo", parent_id=full_path
|
||||
tenant_id="any-old-id", display_name="Council of Naboo", parent_id=full_path
|
||||
)
|
||||
assert payload.parent_id == full_path
|
||||
|
||||
@ -70,3 +71,29 @@ def test_ManagementGroupCSPResponse_id():
|
||||
**{"id": "/path/to/naboo-123", "other": "stuff"}
|
||||
)
|
||||
assert response.id == full_id
|
||||
|
||||
|
||||
def test_KeyVaultCredentials_enforce_admin_creds():
|
||||
with pytest.raises(ValidationError):
|
||||
KeyVaultCredentials(tenant_id="an id", tenant_admin_username="C3PO")
|
||||
assert KeyVaultCredentials(
|
||||
tenant_id="an id",
|
||||
tenant_admin_username="C3PO",
|
||||
tenant_admin_password="beep boop",
|
||||
)
|
||||
|
||||
|
||||
def test_KeyVaultCredentials_enforce_sp_creds():
|
||||
with pytest.raises(ValidationError):
|
||||
KeyVaultCredentials(tenant_id="an id", tenant_sp_client_id="C3PO")
|
||||
assert KeyVaultCredentials(
|
||||
tenant_id="an id", tenant_sp_client_id="C3PO", tenant_sp_key="beep boop"
|
||||
)
|
||||
|
||||
|
||||
def test_KeyVaultCredentials_enforce_root_creds():
|
||||
with pytest.raises(ValidationError):
|
||||
KeyVaultCredentials(root_tenant_id="an id", root_sp_client_id="C3PO")
|
||||
assert KeyVaultCredentials(
|
||||
root_tenant_id="an id", root_sp_client_id="C3PO", root_sp_key="beep boop"
|
||||
)
|
@ -72,6 +72,12 @@ def mock_secrets():
|
||||
return Mock(spec=secrets)
|
||||
|
||||
|
||||
def mock_identity():
|
||||
import azure.identity as identity
|
||||
|
||||
return Mock(spec=identity)
|
||||
|
||||
|
||||
class MockAzureSDK(object):
|
||||
def __init__(self):
|
||||
from msrestazure.azure_cloud import AZURE_PUBLIC_CLOUD
|
||||
@ -88,6 +94,7 @@ class MockAzureSDK(object):
|
||||
self.requests = mock_requests()
|
||||
# may change to a JEDI cloud
|
||||
self.cloud = AZURE_PUBLIC_CLOUD
|
||||
self.identity = mock_identity()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
|
16
tests/utils/test_hash.py
Normal file
16
tests/utils/test_hash.py
Normal file
@ -0,0 +1,16 @@
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
|
||||
from atst.utils import sha256_hex
|
||||
|
||||
|
||||
def test_sha256_hex():
|
||||
sample = "".join(
|
||||
random.choices(string.ascii_uppercase + string.digits, k=random.randrange(200))
|
||||
)
|
||||
hashed = sha256_hex(sample)
|
||||
assert re.match("^[a-zA-Z0-9]+$", hashed)
|
||||
assert len(hashed) == 64
|
||||
hashed_again = sha256_hex(sample)
|
||||
assert hashed == hashed_again
|
Loading…
x
Reference in New Issue
Block a user