diff --git a/atst/domain/csp/cloud/azure_cloud_provider.py b/atst/domain/csp/cloud/azure_cloud_provider.py index af893df2..1ab619b2 100644 --- a/atst/domain/csp/cloud/azure_cloud_provider.py +++ b/atst/domain/csp/cloud/azure_cloud_provider.py @@ -1,5 +1,5 @@ import json -from secrets import token_urlsafe +from secrets import token_hex, token_urlsafe from uuid import uuid4 from flask import current_app as app @@ -12,6 +12,7 @@ from .exceptions import ( ConnectionException, UnknownServerException, SecretException, + DomainNameException, ) from .models import ( AdminRoleDefinitionCSPPayload, @@ -346,12 +347,32 @@ class AzureCloudProvider(CloudProviderInterface): result.status_code, f"azure application error disable user. {str(exc)}", ) + def validate_domain_name(self, name): + response = self.sdk.requests.get( + f"{self.sdk.cloud.endpoints.active_directory}/{name}.onmicrosoft.com/.well-known/openid-configuration", + timeout=30, + ) + # If an existing tenant with name cannot be found, 'error' will be in the response + return "error" in response.json() + + def generate_valid_domain_name(self, name, suffix="", max_tries=6): + if max_tries > 0: + domain_name = name + suffix + if self.validate_domain_name(domain_name): + return domain_name + else: + suffix = token_hex(3) + return self.generate_valid_domain_name(name, suffix, max_tries - 1) + else: + raise DomainNameException(name) + def create_tenant(self, payload: TenantCSPPayload): sp_token = self._get_root_provisioning_token() if sp_token is None: raise AuthenticationException("Could not resolve token for tenant creation") payload.password = token_urlsafe(16) + payload.domain_name = self.generate_valid_domain_name(payload.domain_name) create_tenant_body = payload.dict(by_alias=True) create_tenant_headers = { diff --git a/atst/domain/csp/cloud/exceptions.py b/atst/domain/csp/cloud/exceptions.py index 57eddd5d..98e74468 100644 --- a/atst/domain/csp/cloud/exceptions.py +++ b/atst/domain/csp/cloud/exceptions.py @@ -133,3 +133,14 @@ class SecretException(GeneralCSPException): return "Could not get or set secret for ({}): {}".format( self.tenant_id, self.reason ) + + +class DomainNameException(GeneralCSPException): + """A problem occured when generating the domain name for a tenant""" + + def __init__(self, name): + self.name = name + + @property + def message(self): + return f"Could not generate unique tenant name for {self.name}" diff --git a/atst/models/portfolio.py b/atst/models/portfolio.py index fdf5613b..8131d939 100644 --- a/atst/models/portfolio.py +++ b/atst/models/portfolio.py @@ -1,7 +1,5 @@ import re from itertools import chain -from random import choices -from string import ascii_lowercase, digits from typing import Dict from sqlalchemy import Column, String @@ -191,9 +189,7 @@ class Portfolio( CSP domain name associated with portfolio. If a domain name is not set, generate one. """ - domain_name = re.sub("[^0-9a-zA-Z]+", "", self.name).lower() + "".join( - choices(ascii_lowercase + digits, k=4) - ) + domain_name = re.sub("[^0-9a-zA-Z]+", "", self.name).lower() if self.csp_data: return self.csp_data.get("domain_name", domain_name) else: diff --git a/tests/domain/cloud/test_azure_csp.py b/tests/domain/cloud/test_azure_csp.py index de085100..62f2055e 100644 --- a/tests/domain/cloud/test_azure_csp.py +++ b/tests/domain/cloud/test_azure_csp.py @@ -14,6 +14,7 @@ from atst.domain.csp.cloud.exceptions import ( ConnectionException, UnknownServerException, SecretException, + DomainNameException, ) from atst.domain.csp.cloud import AzureCloudProvider from atst.domain.csp.cloud.models import ( @@ -1588,3 +1589,44 @@ def test_get_calculator_url(mock_azure: AzureCloudProvider): mock_azure.get_calculator_url() == f"{mock_azure.config.get('AZURE_CALC_URL')}?access_token=TOKEN" ) + + +class TestGenerateValidDomainName: + def test_success(self, monkeypatch, mock_azure: AzureCloudProvider): + tenant_name = "tenant" + + def _validate_domain_name(mock_azure, name): + return True + + monkeypatch.setattr( + "atst.domain.csp.cloud.AzureCloudProvider.validate_domain_name", + _validate_domain_name, + ) + assert mock_azure.generate_valid_domain_name(tenant_name) == tenant_name + + def test_failure_after_max_tries(self, monkeypatch, mock_azure: AzureCloudProvider): + def _validate_domain_name(mock_azure, name): + return False + + monkeypatch.setattr( + "atst.domain.csp.cloud.AzureCloudProvider.validate_domain_name", + _validate_domain_name, + ) + with pytest.raises(DomainNameException): + mock_azure.generate_valid_domain_name(name="test", max_tries=3) + + def test_unique(self, monkeypatch, mock_azure: AzureCloudProvider): + # mock that a tenant exists with the name tenant_name + tenant_name = "tenant" + + def _validate_domain_name(mock_azure, name): + if name == tenant_name: + return False + else: + return True + + monkeypatch.setattr( + "atst.domain.csp.cloud.AzureCloudProvider.validate_domain_name", + _validate_domain_name, + ) + assert mock_azure.generate_valid_domain_name(tenant_name) != tenant_name