Merge pull request #1380 from dod-ccpo/azure-user-creation
Azure user creation
This commit is contained in:
commit
8b0c28b09f
@ -0,0 +1,29 @@
|
|||||||
|
"""add application_role.cloud_id
|
||||||
|
|
||||||
|
Revision ID: 17da2a475429
|
||||||
|
Revises: 50979d8ef680
|
||||||
|
Create Date: 2020-02-01 10:43:03.073539
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '17da2a475429' # pragma: allowlist secret
|
||||||
|
down_revision = '50979d8ef680' # pragma: allowlist secret
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.add_column('application_roles', sa.Column('cloud_id', sa.String(), nullable=True))
|
||||||
|
op.add_column('application_roles', sa.Column('claimed_until', sa.TIMESTAMP(timezone=True), nullable=True))
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_column('application_roles', 'cloud_id')
|
||||||
|
op.drop_column('application_roles', 'claimed_until')
|
||||||
|
# ### end Alembic commands ###
|
@ -1,8 +1,12 @@
|
|||||||
|
from itertools import groupby
|
||||||
|
from typing import List
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from sqlalchemy.orm.exc import NoResultFound
|
from sqlalchemy.orm.exc import NoResultFound
|
||||||
|
|
||||||
from atst.database import db
|
from atst.database import db
|
||||||
from atst.domain.environment_roles import EnvironmentRoles
|
from atst.domain.environment_roles import EnvironmentRoles
|
||||||
from atst.models import ApplicationRole, ApplicationRoleStatus
|
from atst.models import Application, ApplicationRole, ApplicationRoleStatus, Portfolio
|
||||||
from .permission_sets import PermissionSets
|
from .permission_sets import PermissionSets
|
||||||
from .exceptions import NotFoundError
|
from .exceptions import NotFoundError
|
||||||
|
|
||||||
@ -61,6 +65,15 @@ class ApplicationRoles(object):
|
|||||||
except NoResultFound:
|
except NoResultFound:
|
||||||
raise NotFoundError("application_role")
|
raise NotFoundError("application_role")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_many(cls, ids):
|
||||||
|
return (
|
||||||
|
db.session.query(ApplicationRole)
|
||||||
|
.filter(ApplicationRole.id.in_(ids))
|
||||||
|
.filter(ApplicationRole.status != ApplicationRoleStatus.DISABLED)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def update_permission_sets(cls, application_role, new_perm_sets_names):
|
def update_permission_sets(cls, application_role, new_perm_sets_names):
|
||||||
application_role.permission_sets = ApplicationRoles._permission_sets_for_names(
|
application_role.permission_sets = ApplicationRoles._permission_sets_for_names(
|
||||||
@ -92,3 +105,29 @@ class ApplicationRoles(object):
|
|||||||
|
|
||||||
db.session.add(application_role)
|
db.session.add(application_role)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_pending_creation(cls) -> List[List[UUID]]:
|
||||||
|
"""
|
||||||
|
Returns a list of lists of ApplicationRole IDs. The IDs
|
||||||
|
should be grouped by user and portfolio.
|
||||||
|
"""
|
||||||
|
results = (
|
||||||
|
db.session.query(ApplicationRole.id, ApplicationRole.user_id, Portfolio.id)
|
||||||
|
.join(Application, Application.id == ApplicationRole.application_id)
|
||||||
|
.join(Portfolio, Portfolio.id == Application.portfolio_id)
|
||||||
|
.filter(Application.cloud_id.isnot(None))
|
||||||
|
.filter(ApplicationRole.deleted == False)
|
||||||
|
.filter(ApplicationRole.cloud_id.is_(None))
|
||||||
|
.filter(ApplicationRole.user_id.isnot(None))
|
||||||
|
.filter(ApplicationRole.status == ApplicationRoleStatus.ACTIVE)
|
||||||
|
).all()
|
||||||
|
|
||||||
|
groups = []
|
||||||
|
keyfunc = lambda pair: (pair[1], pair[2])
|
||||||
|
sorted_results = sorted(results, key=keyfunc)
|
||||||
|
for _, g in groupby(sorted_results, keyfunc):
|
||||||
|
group = [pair[0] for pair in list(g)]
|
||||||
|
groups.append(group)
|
||||||
|
|
||||||
|
return groups
|
||||||
|
@ -6,7 +6,7 @@ from uuid import uuid4
|
|||||||
from atst.utils import sha256_hex
|
from atst.utils import sha256_hex
|
||||||
|
|
||||||
from .cloud_provider_interface import CloudProviderInterface
|
from .cloud_provider_interface import CloudProviderInterface
|
||||||
from .exceptions import AuthenticationException
|
from .exceptions import AuthenticationException, UserProvisioningException
|
||||||
from .models import (
|
from .models import (
|
||||||
SubscriptionCreationCSPPayload,
|
SubscriptionCreationCSPPayload,
|
||||||
SubscriptionCreationCSPResult,
|
SubscriptionCreationCSPResult,
|
||||||
@ -48,6 +48,8 @@ from .models import (
|
|||||||
TenantPrincipalCSPResult,
|
TenantPrincipalCSPResult,
|
||||||
TenantPrincipalOwnershipCSPPayload,
|
TenantPrincipalOwnershipCSPPayload,
|
||||||
TenantPrincipalOwnershipCSPResult,
|
TenantPrincipalOwnershipCSPResult,
|
||||||
|
UserCSPPayload,
|
||||||
|
UserCSPResult,
|
||||||
)
|
)
|
||||||
from .policy import AzurePolicyManager
|
from .policy import AzurePolicyManager
|
||||||
|
|
||||||
@ -193,9 +195,9 @@ class AzureCloudProvider(CloudProviderInterface):
|
|||||||
creds = self._source_creds(payload.tenant_id)
|
creds = self._source_creds(payload.tenant_id)
|
||||||
credentials = self._get_credential_obj(
|
credentials = self._get_credential_obj(
|
||||||
{
|
{
|
||||||
"client_id": creds.root_sp_client_id,
|
"client_id": creds.tenant_sp_client_id,
|
||||||
"secret_key": creds.root_sp_key,
|
"secret_key": creds.tenant_sp_key,
|
||||||
"tenant_id": creds.root_tenant_id,
|
"tenant_id": creds.tenant_id,
|
||||||
},
|
},
|
||||||
resource=self.sdk.cloud.endpoints.resource_manager,
|
resource=self.sdk.cloud.endpoints.resource_manager,
|
||||||
)
|
)
|
||||||
@ -310,7 +312,9 @@ class AzureCloudProvider(CloudProviderInterface):
|
|||||||
tenant_admin_password=payload.password,
|
tenant_admin_password=payload.password,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return self._ok(TenantCSPResult(**result_dict))
|
return self._ok(
|
||||||
|
TenantCSPResult(domain_name=payload.domain_name, **result_dict)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return self._error(result.json())
|
return self._error(result.json())
|
||||||
|
|
||||||
@ -850,6 +854,80 @@ class AzureCloudProvider(CloudProviderInterface):
|
|||||||
|
|
||||||
return service_principal
|
return service_principal
|
||||||
|
|
||||||
|
def create_user(self, payload: UserCSPPayload) -> UserCSPResult:
|
||||||
|
"""Create a user in an Azure Active Directory instance.
|
||||||
|
Unlike most of the methods on this interface, this requires
|
||||||
|
two API calls: one POST to create the user and one PATCH to
|
||||||
|
set the alternate email address. The email address cannot
|
||||||
|
be set on the first API call. The email address is
|
||||||
|
necessary so that users can do Self-Service Password
|
||||||
|
Recovery.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
payload {UserCSPPayload} -- a payload object with the
|
||||||
|
data necessary for both calls
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserCSPResult -- a result object containing the AAD ID.
|
||||||
|
"""
|
||||||
|
graph_token = self._get_tenant_principal_token(
|
||||||
|
payload.tenant_id, resource=self.graph_resource
|
||||||
|
)
|
||||||
|
if graph_token is None:
|
||||||
|
raise AuthenticationException(
|
||||||
|
"Could not resolve graph token for tenant admin"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = self._create_active_directory_user(graph_token, payload)
|
||||||
|
self._update_active_directory_user_email(graph_token, result.id, payload)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _create_active_directory_user(self, graph_token, payload: UserCSPPayload):
|
||||||
|
request_body = {
|
||||||
|
"accountEnabled": True,
|
||||||
|
"displayName": payload.display_name,
|
||||||
|
"mailNickname": payload.mail_nickname,
|
||||||
|
"userPrincipalName": payload.user_principal_name,
|
||||||
|
"passwordProfile": {
|
||||||
|
"forceChangePasswordNextSignIn": True,
|
||||||
|
"password": payload.password,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
auth_header = {
|
||||||
|
"Authorization": f"Bearer {graph_token}",
|
||||||
|
}
|
||||||
|
|
||||||
|
url = f"{self.graph_resource}v1.0/users"
|
||||||
|
|
||||||
|
response = self.sdk.requests.post(url, headers=auth_header, json=request_body)
|
||||||
|
|
||||||
|
if response.ok:
|
||||||
|
return UserCSPResult(**response.json())
|
||||||
|
else:
|
||||||
|
raise UserProvisioningException(f"Failed to create user: {response.json()}")
|
||||||
|
|
||||||
|
def _update_active_directory_user_email(
|
||||||
|
self, graph_token, user_id, payload: UserCSPPayload
|
||||||
|
):
|
||||||
|
request_body = {"otherMails": [payload.email]}
|
||||||
|
|
||||||
|
auth_header = {
|
||||||
|
"Authorization": f"Bearer {graph_token}",
|
||||||
|
}
|
||||||
|
|
||||||
|
url = f"{self.graph_resource}v1.0/users/{user_id}"
|
||||||
|
|
||||||
|
response = self.sdk.requests.patch(url, headers=auth_header, json=request_body)
|
||||||
|
|
||||||
|
if response.ok:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
raise UserProvisioningException(
|
||||||
|
f"Failed update user email: {response.json()}"
|
||||||
|
)
|
||||||
|
|
||||||
def _extract_subscription_id(self, subscription_url):
|
def _extract_subscription_id(self, subscription_url):
|
||||||
sub_id_match = SUBSCRIPTION_ID_REGEX.match(subscription_url)
|
sub_id_match = SUBSCRIPTION_ID_REGEX.match(subscription_url)
|
||||||
|
|
||||||
@ -871,14 +949,15 @@ class AzureCloudProvider(CloudProviderInterface):
|
|||||||
creds.root_tenant_id, creds.root_sp_client_id, creds.root_sp_key
|
creds.root_tenant_id, creds.root_sp_client_id, creds.root_sp_key
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_sp_token(self, tenant_id, client_id, secret_key):
|
def _get_sp_token(self, tenant_id, client_id, secret_key, resource=None):
|
||||||
context = self.sdk.adal.AuthenticationContext(
|
context = self.sdk.adal.AuthenticationContext(
|
||||||
f"{self.sdk.cloud.endpoints.active_directory}/{tenant_id}"
|
f"{self.sdk.cloud.endpoints.active_directory}/{tenant_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
resource = resource or self.sdk.cloud.endpoints.resource_manager
|
||||||
# TODO: handle failure states here
|
# TODO: handle failure states here
|
||||||
token_response = context.acquire_token_with_client_credentials(
|
token_response = context.acquire_token_with_client_credentials(
|
||||||
self.sdk.cloud.endpoints.resource_manager, client_id, secret_key
|
resource, client_id, secret_key
|
||||||
)
|
)
|
||||||
|
|
||||||
return token_response.get("accessToken", None)
|
return token_response.get("accessToken", None)
|
||||||
@ -939,10 +1018,13 @@ class AzureCloudProvider(CloudProviderInterface):
|
|||||||
"tenant_id": self.tenant_id,
|
"tenant_id": self.tenant_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _get_tenant_principal_token(self, tenant_id):
|
def _get_tenant_principal_token(self, tenant_id, resource=None):
|
||||||
creds = self._source_creds(tenant_id)
|
creds = self._source_creds(tenant_id)
|
||||||
return self._get_sp_token(
|
return self._get_sp_token(
|
||||||
creds.tenant_id, creds.tenant_sp_client_id, creds.tenant_sp_key
|
creds.tenant_id,
|
||||||
|
creds.tenant_sp_client_id,
|
||||||
|
creds.tenant_sp_key,
|
||||||
|
resource=resource,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_elevated_management_token(self, tenant_id):
|
def _get_elevated_management_token(self, tenant_id):
|
||||||
|
@ -88,17 +88,6 @@ class UserProvisioningException(GeneralCSPException):
|
|||||||
"""Failed to provision a user
|
"""Failed to provision a user
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env_identifier, user_identifier, reason):
|
|
||||||
self.env_identifier = env_identifier
|
|
||||||
self.user_identifier = user_identifier
|
|
||||||
self.reason = reason
|
|
||||||
|
|
||||||
@property
|
|
||||||
def message(self):
|
|
||||||
return "Failed to create user {} for environment {}: {}".format(
|
|
||||||
self.user_identifier, self.env_identifier, self.reason
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class UserRemovalException(GeneralCSPException):
|
class UserRemovalException(GeneralCSPException):
|
||||||
"""Failed to remove a user
|
"""Failed to remove a user
|
||||||
|
@ -51,6 +51,8 @@ from .models import (
|
|||||||
TenantPrincipalCSPResult,
|
TenantPrincipalCSPResult,
|
||||||
TenantPrincipalOwnershipCSPPayload,
|
TenantPrincipalOwnershipCSPPayload,
|
||||||
TenantPrincipalOwnershipCSPResult,
|
TenantPrincipalOwnershipCSPResult,
|
||||||
|
UserCSPPayload,
|
||||||
|
UserCSPResult,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -175,6 +177,7 @@ class MockCloudProvider(CloudProviderInterface):
|
|||||||
"tenant_id": "",
|
"tenant_id": "",
|
||||||
"user_id": "",
|
"user_id": "",
|
||||||
"user_object_id": "",
|
"user_object_id": "",
|
||||||
|
"domain_name": "",
|
||||||
"tenant_admin_username": "test",
|
"tenant_admin_username": "test",
|
||||||
"tenant_admin_password": "test",
|
"tenant_admin_password": "test",
|
||||||
}
|
}
|
||||||
@ -474,6 +477,11 @@ class MockCloudProvider(CloudProviderInterface):
|
|||||||
id=f"{AZURE_MGMNT_PATH}{payload.management_group_name}"
|
id=f"{AZURE_MGMNT_PATH}{payload.management_group_name}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def create_user(self, payload: UserCSPPayload):
|
||||||
|
self._maybe_raise(self.UNAUTHORIZED_RATE, GeneralCSPException)
|
||||||
|
|
||||||
|
return UserCSPResult(id=str(uuid4()))
|
||||||
|
|
||||||
def get_credentials(self, scope="portfolio", tenant_id=None):
|
def get_credentials(self, scope="portfolio", tenant_id=None):
|
||||||
return self.root_creds()
|
return self.root_creds()
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
|
from secrets import token_urlsafe
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
import re
|
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
import re
|
||||||
|
|
||||||
from pydantic import BaseModel, validator, root_validator
|
from pydantic import BaseModel, validator, root_validator
|
||||||
|
|
||||||
@ -39,6 +40,7 @@ class TenantCSPResult(AliasModel):
|
|||||||
user_id: str
|
user_id: str
|
||||||
tenant_id: str
|
tenant_id: str
|
||||||
user_object_id: str
|
user_object_id: str
|
||||||
|
domain_name: str
|
||||||
|
|
||||||
tenant_admin_username: Optional[str]
|
tenant_admin_username: Optional[str]
|
||||||
tenant_admin_password: Optional[str]
|
tenant_admin_password: Optional[str]
|
||||||
@ -474,3 +476,26 @@ class ProductPurchaseVerificationCSPPayload(BaseCSPPayload):
|
|||||||
|
|
||||||
class ProductPurchaseVerificationCSPResult(AliasModel):
|
class ProductPurchaseVerificationCSPResult(AliasModel):
|
||||||
premium_purchase_date: str
|
premium_purchase_date: str
|
||||||
|
|
||||||
|
|
||||||
|
class UserCSPPayload(BaseCSPPayload):
|
||||||
|
display_name: str
|
||||||
|
tenant_host_name: str
|
||||||
|
email: str
|
||||||
|
password: Optional[str]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def user_principal_name(self):
|
||||||
|
return f"{self.mail_nickname}@{self.tenant_host_name}.onmicrosoft.com"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mail_nickname(self):
|
||||||
|
return self.display_name.replace(" ", ".").lower()
|
||||||
|
|
||||||
|
@validator("password", pre=True, always=True)
|
||||||
|
def supply_password_default(cls, password):
|
||||||
|
return password or token_urlsafe(16)
|
||||||
|
|
||||||
|
|
||||||
|
class UserCSPResult(AliasModel):
|
||||||
|
id: str
|
||||||
|
79
atst/jobs.py
79
atst/jobs.py
@ -3,16 +3,16 @@ import pendulum
|
|||||||
|
|
||||||
from atst.database import db
|
from atst.database import db
|
||||||
from atst.queue import celery
|
from atst.queue import celery
|
||||||
from atst.models import EnvironmentRole, JobFailure
|
from atst.models import JobFailure
|
||||||
from atst.domain.csp.cloud.exceptions import GeneralCSPException
|
from atst.domain.csp.cloud.exceptions import GeneralCSPException
|
||||||
from atst.domain.csp.cloud import CloudProviderInterface
|
from atst.domain.csp.cloud import CloudProviderInterface
|
||||||
from atst.domain.applications import Applications
|
from atst.domain.applications import Applications
|
||||||
from atst.domain.environments import Environments
|
from atst.domain.environments import Environments
|
||||||
from atst.domain.portfolios import Portfolios
|
from atst.domain.portfolios import Portfolios
|
||||||
from atst.domain.environment_roles import EnvironmentRoles
|
from atst.domain.application_roles import ApplicationRoles
|
||||||
from atst.models.utils import claim_for_update
|
from atst.models.utils import claim_for_update, claim_many_for_update
|
||||||
from atst.utils.localization import translate
|
from atst.utils.localization import translate
|
||||||
from atst.domain.csp.cloud.models import ApplicationCSPPayload
|
from atst.domain.csp.cloud.models import ApplicationCSPPayload, UserCSPPayload
|
||||||
|
|
||||||
|
|
||||||
class RecordFailure(celery.Task):
|
class RecordFailure(celery.Task):
|
||||||
@ -75,6 +75,34 @@ def do_create_application(csp: CloudProviderInterface, application_id=None):
|
|||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def do_create_user(csp: CloudProviderInterface, application_role_ids=None):
|
||||||
|
if not application_role_ids:
|
||||||
|
return
|
||||||
|
|
||||||
|
app_roles = ApplicationRoles.get_many(application_role_ids)
|
||||||
|
|
||||||
|
with claim_many_for_update(app_roles) as app_roles:
|
||||||
|
|
||||||
|
if any([ar.cloud_id for ar in app_roles]):
|
||||||
|
return
|
||||||
|
|
||||||
|
csp_details = app_roles[0].application.portfolio.csp_data
|
||||||
|
user = app_roles[0].user
|
||||||
|
|
||||||
|
payload = UserCSPPayload(
|
||||||
|
tenant_id=csp_details.get("tenant_id"),
|
||||||
|
tenant_host_name=csp_details.get("domain_name"),
|
||||||
|
display_name=user.full_name,
|
||||||
|
email=user.email,
|
||||||
|
)
|
||||||
|
result = csp.create_user(payload)
|
||||||
|
for app_role in app_roles:
|
||||||
|
app_role.cloud_id = result.id
|
||||||
|
db.session.add(app_role)
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
|
||||||
def do_create_environment(csp: CloudProviderInterface, environment_id=None):
|
def do_create_environment(csp: CloudProviderInterface, environment_id=None):
|
||||||
environment = Environments.get(environment_id)
|
environment = Environments.get(environment_id)
|
||||||
|
|
||||||
@ -128,21 +156,6 @@ def render_email(template_path, context):
|
|||||||
return app.jinja_env.get_template(template_path).render(context)
|
return app.jinja_env.get_template(template_path).render(context)
|
||||||
|
|
||||||
|
|
||||||
def do_provision_user(csp: CloudProviderInterface, environment_role_id=None):
|
|
||||||
environment_role = EnvironmentRoles.get_by_id(environment_role_id)
|
|
||||||
|
|
||||||
with claim_for_update(environment_role) as environment_role:
|
|
||||||
credentials = environment_role.environment.csp_credentials
|
|
||||||
|
|
||||||
csp_user_id = csp.create_or_update_user(
|
|
||||||
credentials, environment_role, environment_role.role
|
|
||||||
)
|
|
||||||
environment_role.csp_user_id = csp_user_id
|
|
||||||
environment_role.status = EnvironmentRole.Status.COMPLETED
|
|
||||||
db.session.add(environment_role)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
|
|
||||||
def do_work(fn, task, csp, **kwargs):
|
def do_work(fn, task, csp, **kwargs):
|
||||||
try:
|
try:
|
||||||
fn(csp, **kwargs)
|
fn(csp, **kwargs)
|
||||||
@ -166,6 +179,13 @@ def create_application(self, application_id=None):
|
|||||||
do_work(do_create_application, self, app.csp.cloud, application_id=application_id)
|
do_work(do_create_application, self, app.csp.cloud, application_id=application_id)
|
||||||
|
|
||||||
|
|
||||||
|
@celery.task(bind=True, base=RecordFailure)
|
||||||
|
def create_user(self, application_role_ids=None):
|
||||||
|
do_work(
|
||||||
|
do_create_user, self, app.csp.cloud, application_role_ids=application_role_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@celery.task(bind=True, base=RecordFailure)
|
@celery.task(bind=True, base=RecordFailure)
|
||||||
def create_environment(self, environment_id=None):
|
def create_environment(self, environment_id=None):
|
||||||
do_work(do_create_environment, self, app.csp.cloud, environment_id=environment_id)
|
do_work(do_create_environment, self, app.csp.cloud, environment_id=environment_id)
|
||||||
@ -178,13 +198,6 @@ def create_atat_admin_user(self, environment_id=None):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@celery.task(bind=True)
|
|
||||||
def provision_user(self, environment_role_id=None):
|
|
||||||
do_work(
|
|
||||||
do_provision_user, self, app.csp.cloud, environment_role_id=environment_role_id
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@celery.task(bind=True)
|
@celery.task(bind=True)
|
||||||
def dispatch_provision_portfolio(self):
|
def dispatch_provision_portfolio(self):
|
||||||
"""
|
"""
|
||||||
@ -200,6 +213,12 @@ def dispatch_create_application(self):
|
|||||||
create_application.delay(application_id=application_id)
|
create_application.delay(application_id=application_id)
|
||||||
|
|
||||||
|
|
||||||
|
@celery.task(bind=True)
|
||||||
|
def dispatch_create_user(self):
|
||||||
|
for application_role_ids in ApplicationRoles.get_pending_creation():
|
||||||
|
create_user.delay(application_role_ids=application_role_ids)
|
||||||
|
|
||||||
|
|
||||||
@celery.task(bind=True)
|
@celery.task(bind=True)
|
||||||
def dispatch_create_environment(self):
|
def dispatch_create_environment(self):
|
||||||
for environment_id in Environments.get_environments_pending_creation(
|
for environment_id in Environments.get_environments_pending_creation(
|
||||||
@ -214,11 +233,3 @@ def dispatch_create_atat_admin_user(self):
|
|||||||
pendulum.now()
|
pendulum.now()
|
||||||
):
|
):
|
||||||
create_atat_admin_user.delay(environment_id=environment_id)
|
create_atat_admin_user.delay(environment_id=environment_id)
|
||||||
|
|
||||||
|
|
||||||
@celery.task(bind=True)
|
|
||||||
def dispatch_provision_user(self):
|
|
||||||
for (
|
|
||||||
environment_role_id
|
|
||||||
) in EnvironmentRoles.get_environment_roles_pending_creation():
|
|
||||||
provision_user.delay(environment_role_id=environment_role_id)
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from sqlalchemy import and_, Column, ForeignKey, String, UniqueConstraint, TIMESTAMP
|
from sqlalchemy import and_, Column, ForeignKey, String, UniqueConstraint
|
||||||
from sqlalchemy.orm import relationship, synonym
|
from sqlalchemy.orm import relationship, synonym
|
||||||
|
|
||||||
from atst.models.base import Base
|
from atst.models.base import Base
|
||||||
@ -9,7 +9,11 @@ from atst.models.types import Id
|
|||||||
|
|
||||||
|
|
||||||
class Application(
|
class Application(
|
||||||
Base, mixins.TimestampsMixin, mixins.AuditableMixin, mixins.DeletableMixin
|
Base,
|
||||||
|
mixins.TimestampsMixin,
|
||||||
|
mixins.AuditableMixin,
|
||||||
|
mixins.DeletableMixin,
|
||||||
|
mixins.ClaimableMixin,
|
||||||
):
|
):
|
||||||
__tablename__ = "applications"
|
__tablename__ = "applications"
|
||||||
|
|
||||||
@ -41,7 +45,6 @@ class Application(
|
|||||||
)
|
)
|
||||||
|
|
||||||
cloud_id = Column(String)
|
cloud_id = Column(String)
|
||||||
claimed_until = Column(TIMESTAMP(timezone=True))
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def users(self):
|
def users(self):
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from sqlalchemy import Index, ForeignKey, Column, Enum as SQLAEnum, Table
|
from sqlalchemy import Index, ForeignKey, Column, Enum as SQLAEnum, Table, String
|
||||||
from sqlalchemy.dialects.postgresql import UUID
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
from sqlalchemy.event import listen
|
from sqlalchemy.event import listen
|
||||||
@ -33,6 +33,7 @@ class ApplicationRole(
|
|||||||
mixins.AuditableMixin,
|
mixins.AuditableMixin,
|
||||||
mixins.PermissionsMixin,
|
mixins.PermissionsMixin,
|
||||||
mixins.DeletableMixin,
|
mixins.DeletableMixin,
|
||||||
|
mixins.ClaimableMixin,
|
||||||
):
|
):
|
||||||
__tablename__ = "application_roles"
|
__tablename__ = "application_roles"
|
||||||
|
|
||||||
@ -59,6 +60,8 @@ class ApplicationRole(
|
|||||||
primaryjoin="and_(EnvironmentRole.application_role_id == ApplicationRole.id, EnvironmentRole.deleted == False)",
|
primaryjoin="and_(EnvironmentRole.application_role_id == ApplicationRole.id, EnvironmentRole.deleted == False)",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cloud_id = Column(String)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def latest_invitation(self):
|
def latest_invitation(self):
|
||||||
if self.invitations:
|
if self.invitations:
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from sqlalchemy import Column, ForeignKey, String, TIMESTAMP, UniqueConstraint
|
from sqlalchemy import Column, ForeignKey, String, UniqueConstraint
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
from sqlalchemy.dialects.postgresql import JSONB
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@ -9,7 +9,11 @@ import atst.models.types as types
|
|||||||
|
|
||||||
|
|
||||||
class Environment(
|
class Environment(
|
||||||
Base, mixins.TimestampsMixin, mixins.AuditableMixin, mixins.DeletableMixin
|
Base,
|
||||||
|
mixins.TimestampsMixin,
|
||||||
|
mixins.AuditableMixin,
|
||||||
|
mixins.DeletableMixin,
|
||||||
|
mixins.ClaimableMixin,
|
||||||
):
|
):
|
||||||
__tablename__ = "environments"
|
__tablename__ = "environments"
|
||||||
|
|
||||||
@ -28,8 +32,6 @@ class Environment(
|
|||||||
cloud_id = Column(String)
|
cloud_id = Column(String)
|
||||||
root_user_info = Column(JSONB(none_as_null=True))
|
root_user_info = Column(JSONB(none_as_null=True))
|
||||||
|
|
||||||
claimed_until = Column(TIMESTAMP(timezone=True))
|
|
||||||
|
|
||||||
roles = relationship(
|
roles = relationship(
|
||||||
"EnvironmentRole",
|
"EnvironmentRole",
|
||||||
back_populates="environment",
|
back_populates="environment",
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from sqlalchemy import Index, ForeignKey, Column, String, TIMESTAMP, Enum as SQLAEnum
|
from sqlalchemy import Index, ForeignKey, Column, String, Enum as SQLAEnum
|
||||||
from sqlalchemy.dialects.postgresql import UUID
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
@ -15,7 +15,11 @@ class CSPRole(Enum):
|
|||||||
|
|
||||||
|
|
||||||
class EnvironmentRole(
|
class EnvironmentRole(
|
||||||
Base, mixins.TimestampsMixin, mixins.AuditableMixin, mixins.DeletableMixin
|
Base,
|
||||||
|
mixins.TimestampsMixin,
|
||||||
|
mixins.AuditableMixin,
|
||||||
|
mixins.DeletableMixin,
|
||||||
|
mixins.ClaimableMixin,
|
||||||
):
|
):
|
||||||
__tablename__ = "environment_roles"
|
__tablename__ = "environment_roles"
|
||||||
|
|
||||||
@ -33,7 +37,6 @@ class EnvironmentRole(
|
|||||||
application_role = relationship("ApplicationRole")
|
application_role = relationship("ApplicationRole")
|
||||||
|
|
||||||
csp_user_id = Column(String())
|
csp_user_id = Column(String())
|
||||||
claimed_until = Column(TIMESTAMP(timezone=True))
|
|
||||||
|
|
||||||
class Status(Enum):
|
class Status(Enum):
|
||||||
PENDING = "pending"
|
PENDING = "pending"
|
||||||
|
@ -4,3 +4,4 @@ from .permissions import PermissionsMixin
|
|||||||
from .deletable import DeletableMixin
|
from .deletable import DeletableMixin
|
||||||
from .invites import InvitesMixin
|
from .invites import InvitesMixin
|
||||||
from .state_machines import FSMMixin
|
from .state_machines import FSMMixin
|
||||||
|
from .claimable import ClaimableMixin
|
||||||
|
5
atst/models/mixins/claimable.py
Normal file
5
atst/models/mixins/claimable.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
from sqlalchemy import Column, TIMESTAMP
|
||||||
|
|
||||||
|
|
||||||
|
class ClaimableMixin(object):
|
||||||
|
claimed_until = Column(TIMESTAMP(timezone=True))
|
@ -1,3 +1,5 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
from sqlalchemy import func, sql, Interval, and_, or_
|
from sqlalchemy import func, sql, Interval, and_, or_
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
@ -28,7 +30,7 @@ def claim_for_update(resource, minutes=30):
|
|||||||
.filter(
|
.filter(
|
||||||
and_(
|
and_(
|
||||||
Model.id == resource.id,
|
Model.id == resource.id,
|
||||||
or_(Model.claimed_until == None, Model.claimed_until <= func.now()),
|
or_(Model.claimed_until.is_(None), Model.claimed_until <= func.now()),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
.update({"claimed_until": claim_until}, synchronize_session="fetch")
|
.update({"claimed_until": claim_until}, synchronize_session="fetch")
|
||||||
@ -48,3 +50,51 @@ def claim_for_update(resource, minutes=30):
|
|||||||
Model.claimed_until != None
|
Model.claimed_until != None
|
||||||
).update({"claimed_until": None}, synchronize_session="fetch")
|
).update({"claimed_until": None}, synchronize_session="fetch")
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def claim_many_for_update(resources: List, minutes=30):
|
||||||
|
"""
|
||||||
|
Claim a mutually exclusive expiring hold on a group of resources.
|
||||||
|
Uses the database as a central source of time in case the server clocks have drifted.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
resources: A list of SQLAlchemy model instances with a `claimed_until` attribute.
|
||||||
|
minutes: The maximum amount of time, in minutes, to hold the claim.
|
||||||
|
"""
|
||||||
|
Model = resources[0].__class__
|
||||||
|
|
||||||
|
claim_until = func.now() + func.cast(
|
||||||
|
sql.functions.concat(minutes, " MINUTES"), Interval
|
||||||
|
)
|
||||||
|
|
||||||
|
ids = tuple(r.id for r in resources)
|
||||||
|
|
||||||
|
# Optimistically query for and update the resources in question. If they're
|
||||||
|
# already claimed, `rows_updated` will be 0 and we can give up.
|
||||||
|
rows_updated = (
|
||||||
|
db.session.query(Model)
|
||||||
|
.filter(
|
||||||
|
and_(
|
||||||
|
Model.id.in_(ids),
|
||||||
|
or_(Model.claimed_until.is_(None), Model.claimed_until <= func.now()),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.update({"claimed_until": claim_until}, synchronize_session="fetch")
|
||||||
|
)
|
||||||
|
if rows_updated < 1:
|
||||||
|
# TODO: Generalize this exception class so it can take multiple resources
|
||||||
|
raise ClaimFailedException(resources[0])
|
||||||
|
|
||||||
|
# Fetch the claimed resources
|
||||||
|
claimed = db.session.query(Model).filter(Model.id.in_(ids)).all()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Give the resource to the caller.
|
||||||
|
yield claimed
|
||||||
|
finally:
|
||||||
|
# Release the claim.
|
||||||
|
db.session.query(Model).filter(Model.id.in_(ids)).filter(
|
||||||
|
Model.claimed_until != None
|
||||||
|
).update({"claimed_until": None}, synchronize_session="fetch")
|
||||||
|
db.session.commit()
|
||||||
|
@ -23,8 +23,8 @@ def update_celery(celery, app):
|
|||||||
"task": "atst.jobs.dispatch_create_atat_admin_user",
|
"task": "atst.jobs.dispatch_create_atat_admin_user",
|
||||||
"schedule": 60,
|
"schedule": 60,
|
||||||
},
|
},
|
||||||
"beat-dispatch_provision_user": {
|
"beat-dispatch_create_user": {
|
||||||
"task": "atst.jobs.dispatch_provision_user",
|
"task": "atst.jobs.dispatch_create_user",
|
||||||
"schedule": 60,
|
"schedule": 60,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -7,6 +7,7 @@ from atst.domain.csp.cloud.models import (
|
|||||||
KeyVaultCredentials,
|
KeyVaultCredentials,
|
||||||
ManagementGroupCSPPayload,
|
ManagementGroupCSPPayload,
|
||||||
ManagementGroupCSPResponse,
|
ManagementGroupCSPResponse,
|
||||||
|
UserCSPPayload,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -97,3 +98,26 @@ def test_KeyVaultCredentials_enforce_root_creds():
|
|||||||
assert KeyVaultCredentials(
|
assert KeyVaultCredentials(
|
||||||
root_tenant_id="an id", root_sp_client_id="C3PO", root_sp_key="beep boop"
|
root_tenant_id="an id", root_sp_client_id="C3PO", root_sp_key="beep boop"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
user_payload = {
|
||||||
|
"tenant_id": "123",
|
||||||
|
"display_name": "Han Solo",
|
||||||
|
"tenant_host_name": "rebelalliance",
|
||||||
|
"email": "han@moseisley.cantina",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_UserCSPPayload_mail_nickname():
|
||||||
|
payload = UserCSPPayload(**user_payload)
|
||||||
|
assert payload.mail_nickname == f"han.solo"
|
||||||
|
|
||||||
|
|
||||||
|
def test_UserCSPPayload_user_principal_name():
|
||||||
|
payload = UserCSPPayload(**user_payload)
|
||||||
|
assert payload.user_principal_name == f"han.solo@rebelalliance.onmicrosoft.com"
|
||||||
|
|
||||||
|
|
||||||
|
def test_UserCSPPayload_password():
|
||||||
|
payload = UserCSPPayload(**user_payload)
|
||||||
|
assert payload.password
|
||||||
|
@ -86,3 +86,79 @@ def test_disable(session):
|
|||||||
session.refresh(environment_role)
|
session.refresh(environment_role)
|
||||||
assert member_role.status == ApplicationRoleStatus.DISABLED
|
assert member_role.status == ApplicationRoleStatus.DISABLED
|
||||||
assert environment_role.deleted
|
assert environment_role.deleted
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_pending_creation():
|
||||||
|
|
||||||
|
# ready Applications belonging to the same Portfolio
|
||||||
|
portfolio_one = PortfolioFactory.create()
|
||||||
|
ready_app = ApplicationFactory.create(cloud_id="123", portfolio=portfolio_one)
|
||||||
|
ready_app2 = ApplicationFactory.create(cloud_id="321", portfolio=portfolio_one)
|
||||||
|
|
||||||
|
# ready Application belonging to a new Portfolio
|
||||||
|
ready_app3 = ApplicationFactory.create(cloud_id="567")
|
||||||
|
unready_app = ApplicationFactory.create()
|
||||||
|
|
||||||
|
# two distinct Users
|
||||||
|
user_one = UserFactory.create()
|
||||||
|
user_two = UserFactory.create()
|
||||||
|
|
||||||
|
# Two ApplicationRoles belonging to the same User and
|
||||||
|
# different Applications. These should sort together because
|
||||||
|
# they are all under the same portfolio (portfolio_one).
|
||||||
|
role_one = ApplicationRoleFactory.create(
|
||||||
|
user=user_one, application=ready_app, status=ApplicationRoleStatus.ACTIVE
|
||||||
|
)
|
||||||
|
role_two = ApplicationRoleFactory.create(
|
||||||
|
user=user_one, application=ready_app2, status=ApplicationRoleStatus.ACTIVE
|
||||||
|
)
|
||||||
|
|
||||||
|
# An ApplicationRole belonging to a different User. This will
|
||||||
|
# be included but sort separately because it belongs to a
|
||||||
|
# different user.
|
||||||
|
role_three = ApplicationRoleFactory.create(
|
||||||
|
user=user_two, application=ready_app, status=ApplicationRoleStatus.ACTIVE
|
||||||
|
)
|
||||||
|
|
||||||
|
# An ApplicationRole belonging to one of the existing users
|
||||||
|
# but under a different portfolio. It will sort separately.
|
||||||
|
role_four = ApplicationRoleFactory.create(
|
||||||
|
user=user_one, application=ready_app3, status=ApplicationRoleStatus.ACTIVE
|
||||||
|
)
|
||||||
|
|
||||||
|
# This ApplicationRole will not be in the results because its
|
||||||
|
# application is not ready (implicitly, its cloud_id is not
|
||||||
|
# set.)
|
||||||
|
ApplicationRoleFactory.create(
|
||||||
|
user=UserFactory.create(),
|
||||||
|
application=unready_app,
|
||||||
|
status=ApplicationRoleStatus.ACTIVE,
|
||||||
|
)
|
||||||
|
|
||||||
|
# This ApplicationRole will not be in the results because it
|
||||||
|
# does not have a user associated.
|
||||||
|
ApplicationRoleFactory.create(
|
||||||
|
user=None, application=ready_app, status=ApplicationRoleStatus.ACTIVE,
|
||||||
|
)
|
||||||
|
|
||||||
|
# This ApplicationRole will not be in the results because its
|
||||||
|
# status is not ACTIVE.
|
||||||
|
ApplicationRoleFactory.create(
|
||||||
|
user=UserFactory.create(),
|
||||||
|
application=unready_app,
|
||||||
|
status=ApplicationRoleStatus.DISABLED,
|
||||||
|
)
|
||||||
|
|
||||||
|
app_ids = ApplicationRoles.get_pending_creation()
|
||||||
|
expected_ids = [[role_one.id, role_two.id], [role_three.id], [role_four.id]]
|
||||||
|
# Sort them to produce the same order.
|
||||||
|
assert sorted(app_ids) == sorted(expected_ids)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_many():
|
||||||
|
ar1 = ApplicationRoleFactory.create()
|
||||||
|
ar2 = ApplicationRoleFactory.create()
|
||||||
|
ApplicationRoleFactory.create()
|
||||||
|
|
||||||
|
result = ApplicationRoles.get_many([ar1.id, ar2.id])
|
||||||
|
assert result == [ar1, ar2]
|
||||||
|
104
tests/models/test_utils.py
Normal file
104
tests/models/test_utils.py
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
from threading import Thread
|
||||||
|
|
||||||
|
from atst.domain.exceptions import ClaimFailedException
|
||||||
|
from atst.models.utils import claim_for_update, claim_many_for_update
|
||||||
|
|
||||||
|
from tests.factories import EnvironmentFactory
|
||||||
|
|
||||||
|
|
||||||
|
def test_claim_for_update(session):
|
||||||
|
environment = EnvironmentFactory.create()
|
||||||
|
|
||||||
|
satisfied_claims = []
|
||||||
|
exceptions = []
|
||||||
|
|
||||||
|
# Two threads race to do work on environment and check out the lock
|
||||||
|
class FirstThread(Thread):
|
||||||
|
def run(self):
|
||||||
|
try:
|
||||||
|
with claim_for_update(environment) as env:
|
||||||
|
assert env.claimed_until
|
||||||
|
satisfied_claims.append("FirstThread")
|
||||||
|
except ClaimFailedException:
|
||||||
|
exceptions.append("FirstThread")
|
||||||
|
|
||||||
|
class SecondThread(Thread):
|
||||||
|
def run(self):
|
||||||
|
try:
|
||||||
|
with claim_for_update(environment) as env:
|
||||||
|
assert env.claimed_until
|
||||||
|
satisfied_claims.append("SecondThread")
|
||||||
|
except ClaimFailedException:
|
||||||
|
exceptions.append("SecondThread")
|
||||||
|
|
||||||
|
t1 = FirstThread()
|
||||||
|
t2 = SecondThread()
|
||||||
|
t1.start()
|
||||||
|
t2.start()
|
||||||
|
t1.join()
|
||||||
|
t2.join()
|
||||||
|
|
||||||
|
session.refresh(environment)
|
||||||
|
|
||||||
|
assert len(satisfied_claims) == 1
|
||||||
|
assert len(exceptions) == 1
|
||||||
|
|
||||||
|
if satisfied_claims == ["FirstThread"]:
|
||||||
|
assert exceptions == ["SecondThread"]
|
||||||
|
else:
|
||||||
|
assert satisfied_claims == ["SecondThread"]
|
||||||
|
assert exceptions == ["FirstThread"]
|
||||||
|
|
||||||
|
# The claim is released
|
||||||
|
assert environment.claimed_until is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_claim_many_for_update(session):
|
||||||
|
environments = [
|
||||||
|
EnvironmentFactory.create(),
|
||||||
|
EnvironmentFactory.create(),
|
||||||
|
]
|
||||||
|
|
||||||
|
satisfied_claims = []
|
||||||
|
exceptions = []
|
||||||
|
|
||||||
|
# Two threads race to do work on environment and check out the lock
|
||||||
|
class FirstThread(Thread):
|
||||||
|
def run(self):
|
||||||
|
try:
|
||||||
|
with claim_many_for_update(environments) as envs:
|
||||||
|
assert all([e.claimed_until for e in envs])
|
||||||
|
satisfied_claims.append("FirstThread")
|
||||||
|
except ClaimFailedException:
|
||||||
|
exceptions.append("FirstThread")
|
||||||
|
|
||||||
|
class SecondThread(Thread):
|
||||||
|
def run(self):
|
||||||
|
try:
|
||||||
|
with claim_many_for_update(environments) as envs:
|
||||||
|
assert all([e.claimed_until for e in envs])
|
||||||
|
satisfied_claims.append("SecondThread")
|
||||||
|
except ClaimFailedException:
|
||||||
|
exceptions.append("SecondThread")
|
||||||
|
|
||||||
|
t1 = FirstThread()
|
||||||
|
t2 = SecondThread()
|
||||||
|
t1.start()
|
||||||
|
t2.start()
|
||||||
|
t1.join()
|
||||||
|
t2.join()
|
||||||
|
|
||||||
|
for env in environments:
|
||||||
|
session.refresh(env)
|
||||||
|
|
||||||
|
assert len(satisfied_claims) == 1
|
||||||
|
assert len(exceptions) == 1
|
||||||
|
|
||||||
|
if satisfied_claims == ["FirstThread"]:
|
||||||
|
assert exceptions == ["SecondThread"]
|
||||||
|
else:
|
||||||
|
assert satisfied_claims == ["SecondThread"]
|
||||||
|
assert exceptions == ["FirstThread"]
|
||||||
|
|
||||||
|
# The claim is released
|
||||||
|
# assert environment.claimed_until is None
|
@ -2,27 +2,25 @@ import pendulum
|
|||||||
import pytest
|
import pytest
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
from threading import Thread
|
|
||||||
|
|
||||||
from atst.domain.csp.cloud import MockCloudProvider
|
from atst.domain.csp.cloud import MockCloudProvider
|
||||||
from atst.domain.portfolios import Portfolios
|
from atst.domain.portfolios import Portfolios
|
||||||
|
from atst.models import ApplicationRoleStatus
|
||||||
|
|
||||||
from atst.jobs import (
|
from atst.jobs import (
|
||||||
RecordFailure,
|
RecordFailure,
|
||||||
dispatch_create_environment,
|
dispatch_create_environment,
|
||||||
dispatch_create_application,
|
dispatch_create_application,
|
||||||
|
dispatch_create_user,
|
||||||
dispatch_create_atat_admin_user,
|
dispatch_create_atat_admin_user,
|
||||||
dispatch_provision_portfolio,
|
dispatch_provision_portfolio,
|
||||||
dispatch_provision_user,
|
|
||||||
create_environment,
|
create_environment,
|
||||||
do_provision_user,
|
do_create_user,
|
||||||
do_provision_portfolio,
|
do_provision_portfolio,
|
||||||
do_create_environment,
|
do_create_environment,
|
||||||
do_create_application,
|
do_create_application,
|
||||||
do_create_atat_admin_user,
|
do_create_atat_admin_user,
|
||||||
)
|
)
|
||||||
from atst.models.utils import claim_for_update
|
|
||||||
from atst.domain.exceptions import ClaimFailedException
|
|
||||||
from tests.factories import (
|
from tests.factories import (
|
||||||
EnvironmentFactory,
|
EnvironmentFactory,
|
||||||
EnvironmentRoleFactory,
|
EnvironmentRoleFactory,
|
||||||
@ -30,6 +28,7 @@ from tests.factories import (
|
|||||||
PortfolioStateMachineFactory,
|
PortfolioStateMachineFactory,
|
||||||
ApplicationFactory,
|
ApplicationFactory,
|
||||||
ApplicationRoleFactory,
|
ApplicationRoleFactory,
|
||||||
|
UserFactory,
|
||||||
)
|
)
|
||||||
from atst.models import CSPRole, EnvironmentRole, ApplicationRoleStatus, JobFailure
|
from atst.models import CSPRole, EnvironmentRole, ApplicationRoleStatus, JobFailure
|
||||||
|
|
||||||
@ -126,6 +125,30 @@ def test_create_application_job_is_idempotent(csp):
|
|||||||
csp.create_application.assert_not_called()
|
csp.create_application.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_user_job(session, csp):
|
||||||
|
portfolio = PortfolioFactory.create(
|
||||||
|
csp_data={
|
||||||
|
"tenant_id": str(uuid4()),
|
||||||
|
"domain_name": "rebelalliance.onmicrosoft.com",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
application = ApplicationFactory.create(portfolio=portfolio, cloud_id="321")
|
||||||
|
user = UserFactory.create(
|
||||||
|
first_name="Han", last_name="Solo", email="han@example.com"
|
||||||
|
)
|
||||||
|
app_role = ApplicationRoleFactory.create(
|
||||||
|
application=application,
|
||||||
|
user=user,
|
||||||
|
status=ApplicationRoleStatus.ACTIVE,
|
||||||
|
cloud_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
do_create_user(csp, [app_role.id])
|
||||||
|
session.refresh(app_role)
|
||||||
|
|
||||||
|
assert app_role.cloud_id
|
||||||
|
|
||||||
|
|
||||||
def test_create_atat_admin_user(csp, session):
|
def test_create_atat_admin_user(csp, session):
|
||||||
environment = EnvironmentFactory.create(cloud_id="something")
|
environment = EnvironmentFactory.create(cloud_id="something")
|
||||||
do_create_atat_admin_user(csp, environment.id)
|
do_create_atat_admin_user(csp, environment.id)
|
||||||
@ -181,6 +204,29 @@ def test_dispatch_create_application(monkeypatch):
|
|||||||
mock.delay.assert_called_once_with(application_id=app.id)
|
mock.delay.assert_called_once_with(application_id=app.id)
|
||||||
|
|
||||||
|
|
||||||
|
def test_dispatch_create_user(monkeypatch):
|
||||||
|
application = ApplicationFactory.create(cloud_id="123")
|
||||||
|
user = UserFactory.create(
|
||||||
|
first_name="Han", last_name="Solo", email="han@example.com"
|
||||||
|
)
|
||||||
|
app_role = ApplicationRoleFactory.create(
|
||||||
|
application=application,
|
||||||
|
user=user,
|
||||||
|
status=ApplicationRoleStatus.ACTIVE,
|
||||||
|
cloud_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock = Mock()
|
||||||
|
monkeypatch.setattr("atst.jobs.create_user", mock)
|
||||||
|
|
||||||
|
# When dispatch_create_user is called
|
||||||
|
dispatch_create_user.run()
|
||||||
|
|
||||||
|
# It should cause the create_user task to be called once
|
||||||
|
# with the application id
|
||||||
|
mock.delay.assert_called_once_with(application_role_ids=[app_role.id])
|
||||||
|
|
||||||
|
|
||||||
def test_dispatch_create_atat_admin_user(session, monkeypatch):
|
def test_dispatch_create_atat_admin_user(session, monkeypatch):
|
||||||
portfolio = PortfolioFactory.create(
|
portfolio = PortfolioFactory.create(
|
||||||
applications=[
|
applications=[
|
||||||
@ -240,128 +286,6 @@ def test_create_environment_no_dupes(session, celery_app, celery_worker):
|
|||||||
assert environment.claimed_until == None
|
assert environment.claimed_until == None
|
||||||
|
|
||||||
|
|
||||||
def test_claim_for_update(session):
|
|
||||||
portfolio = PortfolioFactory.create(
|
|
||||||
applications=[
|
|
||||||
{"environments": [{"cloud_id": uuid4().hex, "root_user_info": {}}]}
|
|
||||||
],
|
|
||||||
task_orders=[
|
|
||||||
{
|
|
||||||
"create_clins": [
|
|
||||||
{
|
|
||||||
"start_date": pendulum.now().subtract(days=1),
|
|
||||||
"end_date": pendulum.now().add(days=1),
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
environment = portfolio.applications[0].environments[0]
|
|
||||||
|
|
||||||
satisfied_claims = []
|
|
||||||
exceptions = []
|
|
||||||
|
|
||||||
# Two threads race to do work on environment and check out the lock
|
|
||||||
class FirstThread(Thread):
|
|
||||||
def run(self):
|
|
||||||
try:
|
|
||||||
with claim_for_update(environment):
|
|
||||||
satisfied_claims.append("FirstThread")
|
|
||||||
except ClaimFailedException:
|
|
||||||
exceptions.append("FirstThread")
|
|
||||||
|
|
||||||
class SecondThread(Thread):
|
|
||||||
def run(self):
|
|
||||||
try:
|
|
||||||
with claim_for_update(environment):
|
|
||||||
satisfied_claims.append("SecondThread")
|
|
||||||
except ClaimFailedException:
|
|
||||||
exceptions.append("SecondThread")
|
|
||||||
|
|
||||||
t1 = FirstThread()
|
|
||||||
t2 = SecondThread()
|
|
||||||
t1.start()
|
|
||||||
t2.start()
|
|
||||||
t1.join()
|
|
||||||
t2.join()
|
|
||||||
|
|
||||||
session.refresh(environment)
|
|
||||||
|
|
||||||
assert len(satisfied_claims) == 1
|
|
||||||
assert len(exceptions) == 1
|
|
||||||
|
|
||||||
if satisfied_claims == ["FirstThread"]:
|
|
||||||
assert exceptions == ["SecondThread"]
|
|
||||||
else:
|
|
||||||
assert satisfied_claims == ["SecondThread"]
|
|
||||||
assert exceptions == ["FirstThread"]
|
|
||||||
|
|
||||||
# The claim is released
|
|
||||||
assert environment.claimed_until is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_dispatch_provision_user(csp, session, celery_app, celery_worker, monkeypatch):
|
|
||||||
|
|
||||||
# Given that I have four environment roles:
|
|
||||||
# (A) one of which has a completed status
|
|
||||||
# (B) one of which has an environment that has not been provisioned
|
|
||||||
# (C) one of which is pending, has a provisioned environment but an inactive application role
|
|
||||||
# (D) one of which is pending, has a provisioned environment and has an active application role
|
|
||||||
provisioned_environment = EnvironmentFactory.create(
|
|
||||||
cloud_id="cloud_id", root_user_info={}
|
|
||||||
)
|
|
||||||
unprovisioned_environment = EnvironmentFactory.create()
|
|
||||||
_er_a = EnvironmentRoleFactory.create(
|
|
||||||
environment=provisioned_environment, status=EnvironmentRole.Status.COMPLETED
|
|
||||||
)
|
|
||||||
_er_b = EnvironmentRoleFactory.create(
|
|
||||||
environment=unprovisioned_environment, status=EnvironmentRole.Status.PENDING
|
|
||||||
)
|
|
||||||
_er_c = EnvironmentRoleFactory.create(
|
|
||||||
environment=unprovisioned_environment,
|
|
||||||
status=EnvironmentRole.Status.PENDING,
|
|
||||||
application_role=ApplicationRoleFactory(status=ApplicationRoleStatus.PENDING),
|
|
||||||
)
|
|
||||||
er_d = EnvironmentRoleFactory.create(
|
|
||||||
environment=provisioned_environment,
|
|
||||||
status=EnvironmentRole.Status.PENDING,
|
|
||||||
application_role=ApplicationRoleFactory(status=ApplicationRoleStatus.ACTIVE),
|
|
||||||
)
|
|
||||||
|
|
||||||
mock = Mock()
|
|
||||||
monkeypatch.setattr("atst.jobs.provision_user", mock)
|
|
||||||
|
|
||||||
# When I dispatch the user provisioning task
|
|
||||||
dispatch_provision_user.run()
|
|
||||||
|
|
||||||
# I expect it to dispatch only one call, to EnvironmentRole D
|
|
||||||
mock.delay.assert_called_once_with(environment_role_id=er_d.id)
|
|
||||||
|
|
||||||
|
|
||||||
def test_do_provision_user(csp, session):
|
|
||||||
# Given that I have an EnvironmentRole with a provisioned environment
|
|
||||||
credentials = MockCloudProvider(())._auth_credentials
|
|
||||||
provisioned_environment = EnvironmentFactory.create(
|
|
||||||
cloud_id="cloud_id", root_user_info={"credentials": credentials}
|
|
||||||
)
|
|
||||||
environment_role = EnvironmentRoleFactory.create(
|
|
||||||
environment=provisioned_environment,
|
|
||||||
status=EnvironmentRole.Status.PENDING,
|
|
||||||
role="ADMIN",
|
|
||||||
)
|
|
||||||
|
|
||||||
# When I call the user provisoning task
|
|
||||||
do_provision_user(csp=csp, environment_role_id=environment_role.id)
|
|
||||||
|
|
||||||
session.refresh(environment_role)
|
|
||||||
# I expect that the CSP create_or_update_user method will be called
|
|
||||||
csp.create_or_update_user.assert_called_once_with(
|
|
||||||
credentials, environment_role, CSPRole.ADMIN
|
|
||||||
)
|
|
||||||
# I expect that the EnvironmentRole now has a csp_user_id
|
|
||||||
assert environment_role.csp_user_id
|
|
||||||
|
|
||||||
|
|
||||||
def test_dispatch_provision_portfolio(
|
def test_dispatch_provision_portfolio(
|
||||||
csp, session, portfolio, celery_app, celery_worker, monkeypatch
|
csp, session, portfolio, celery_app, celery_worker, monkeypatch
|
||||||
):
|
):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user