Merge pull request #1380 from dod-ccpo/azure-user-creation
Azure user creation
This commit is contained in:
commit
8b0c28b09f
@ -0,0 +1,29 @@
|
||||
"""add application_role.cloud_id
|
||||
|
||||
Revision ID: 17da2a475429
|
||||
Revises: 50979d8ef680
|
||||
Create Date: 2020-02-01 10:43:03.073539
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '17da2a475429' # pragma: allowlist secret
|
||||
down_revision = '50979d8ef680' # pragma: allowlist secret
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('application_roles', sa.Column('cloud_id', sa.String(), nullable=True))
|
||||
op.add_column('application_roles', sa.Column('claimed_until', sa.TIMESTAMP(timezone=True), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('application_roles', 'cloud_id')
|
||||
op.drop_column('application_roles', 'claimed_until')
|
||||
# ### end Alembic commands ###
|
@ -1,8 +1,12 @@
|
||||
from itertools import groupby
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm.exc import NoResultFound
|
||||
|
||||
from atst.database import db
|
||||
from atst.domain.environment_roles import EnvironmentRoles
|
||||
from atst.models import ApplicationRole, ApplicationRoleStatus
|
||||
from atst.models import Application, ApplicationRole, ApplicationRoleStatus, Portfolio
|
||||
from .permission_sets import PermissionSets
|
||||
from .exceptions import NotFoundError
|
||||
|
||||
@ -61,6 +65,15 @@ class ApplicationRoles(object):
|
||||
except NoResultFound:
|
||||
raise NotFoundError("application_role")
|
||||
|
||||
@classmethod
|
||||
def get_many(cls, ids):
|
||||
return (
|
||||
db.session.query(ApplicationRole)
|
||||
.filter(ApplicationRole.id.in_(ids))
|
||||
.filter(ApplicationRole.status != ApplicationRoleStatus.DISABLED)
|
||||
.all()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def update_permission_sets(cls, application_role, new_perm_sets_names):
|
||||
application_role.permission_sets = ApplicationRoles._permission_sets_for_names(
|
||||
@ -92,3 +105,29 @@ class ApplicationRoles(object):
|
||||
|
||||
db.session.add(application_role)
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def get_pending_creation(cls) -> List[List[UUID]]:
|
||||
"""
|
||||
Returns a list of lists of ApplicationRole IDs. The IDs
|
||||
should be grouped by user and portfolio.
|
||||
"""
|
||||
results = (
|
||||
db.session.query(ApplicationRole.id, ApplicationRole.user_id, Portfolio.id)
|
||||
.join(Application, Application.id == ApplicationRole.application_id)
|
||||
.join(Portfolio, Portfolio.id == Application.portfolio_id)
|
||||
.filter(Application.cloud_id.isnot(None))
|
||||
.filter(ApplicationRole.deleted == False)
|
||||
.filter(ApplicationRole.cloud_id.is_(None))
|
||||
.filter(ApplicationRole.user_id.isnot(None))
|
||||
.filter(ApplicationRole.status == ApplicationRoleStatus.ACTIVE)
|
||||
).all()
|
||||
|
||||
groups = []
|
||||
keyfunc = lambda pair: (pair[1], pair[2])
|
||||
sorted_results = sorted(results, key=keyfunc)
|
||||
for _, g in groupby(sorted_results, keyfunc):
|
||||
group = [pair[0] for pair in list(g)]
|
||||
groups.append(group)
|
||||
|
||||
return groups
|
||||
|
@ -6,7 +6,7 @@ from uuid import uuid4
|
||||
from atst.utils import sha256_hex
|
||||
|
||||
from .cloud_provider_interface import CloudProviderInterface
|
||||
from .exceptions import AuthenticationException
|
||||
from .exceptions import AuthenticationException, UserProvisioningException
|
||||
from .models import (
|
||||
SubscriptionCreationCSPPayload,
|
||||
SubscriptionCreationCSPResult,
|
||||
@ -48,6 +48,8 @@ from .models import (
|
||||
TenantPrincipalCSPResult,
|
||||
TenantPrincipalOwnershipCSPPayload,
|
||||
TenantPrincipalOwnershipCSPResult,
|
||||
UserCSPPayload,
|
||||
UserCSPResult,
|
||||
)
|
||||
from .policy import AzurePolicyManager
|
||||
|
||||
@ -193,9 +195,9 @@ class AzureCloudProvider(CloudProviderInterface):
|
||||
creds = self._source_creds(payload.tenant_id)
|
||||
credentials = self._get_credential_obj(
|
||||
{
|
||||
"client_id": creds.root_sp_client_id,
|
||||
"secret_key": creds.root_sp_key,
|
||||
"tenant_id": creds.root_tenant_id,
|
||||
"client_id": creds.tenant_sp_client_id,
|
||||
"secret_key": creds.tenant_sp_key,
|
||||
"tenant_id": creds.tenant_id,
|
||||
},
|
||||
resource=self.sdk.cloud.endpoints.resource_manager,
|
||||
)
|
||||
@ -310,7 +312,9 @@ class AzureCloudProvider(CloudProviderInterface):
|
||||
tenant_admin_password=payload.password,
|
||||
),
|
||||
)
|
||||
return self._ok(TenantCSPResult(**result_dict))
|
||||
return self._ok(
|
||||
TenantCSPResult(domain_name=payload.domain_name, **result_dict)
|
||||
)
|
||||
else:
|
||||
return self._error(result.json())
|
||||
|
||||
@ -850,6 +854,80 @@ class AzureCloudProvider(CloudProviderInterface):
|
||||
|
||||
return service_principal
|
||||
|
||||
def create_user(self, payload: UserCSPPayload) -> UserCSPResult:
|
||||
"""Create a user in an Azure Active Directory instance.
|
||||
Unlike most of the methods on this interface, this requires
|
||||
two API calls: one POST to create the user and one PATCH to
|
||||
set the alternate email address. The email address cannot
|
||||
be set on the first API call. The email address is
|
||||
necessary so that users can do Self-Service Password
|
||||
Recovery.
|
||||
|
||||
Arguments:
|
||||
payload {UserCSPPayload} -- a payload object with the
|
||||
data necessary for both calls
|
||||
|
||||
Returns:
|
||||
UserCSPResult -- a result object containing the AAD ID.
|
||||
"""
|
||||
graph_token = self._get_tenant_principal_token(
|
||||
payload.tenant_id, resource=self.graph_resource
|
||||
)
|
||||
if graph_token is None:
|
||||
raise AuthenticationException(
|
||||
"Could not resolve graph token for tenant admin"
|
||||
)
|
||||
|
||||
result = self._create_active_directory_user(graph_token, payload)
|
||||
self._update_active_directory_user_email(graph_token, result.id, payload)
|
||||
|
||||
return result
|
||||
|
||||
def _create_active_directory_user(self, graph_token, payload: UserCSPPayload):
|
||||
request_body = {
|
||||
"accountEnabled": True,
|
||||
"displayName": payload.display_name,
|
||||
"mailNickname": payload.mail_nickname,
|
||||
"userPrincipalName": payload.user_principal_name,
|
||||
"passwordProfile": {
|
||||
"forceChangePasswordNextSignIn": True,
|
||||
"password": payload.password,
|
||||
},
|
||||
}
|
||||
|
||||
auth_header = {
|
||||
"Authorization": f"Bearer {graph_token}",
|
||||
}
|
||||
|
||||
url = f"{self.graph_resource}v1.0/users"
|
||||
|
||||
response = self.sdk.requests.post(url, headers=auth_header, json=request_body)
|
||||
|
||||
if response.ok:
|
||||
return UserCSPResult(**response.json())
|
||||
else:
|
||||
raise UserProvisioningException(f"Failed to create user: {response.json()}")
|
||||
|
||||
def _update_active_directory_user_email(
|
||||
self, graph_token, user_id, payload: UserCSPPayload
|
||||
):
|
||||
request_body = {"otherMails": [payload.email]}
|
||||
|
||||
auth_header = {
|
||||
"Authorization": f"Bearer {graph_token}",
|
||||
}
|
||||
|
||||
url = f"{self.graph_resource}v1.0/users/{user_id}"
|
||||
|
||||
response = self.sdk.requests.patch(url, headers=auth_header, json=request_body)
|
||||
|
||||
if response.ok:
|
||||
return True
|
||||
else:
|
||||
raise UserProvisioningException(
|
||||
f"Failed update user email: {response.json()}"
|
||||
)
|
||||
|
||||
def _extract_subscription_id(self, subscription_url):
|
||||
sub_id_match = SUBSCRIPTION_ID_REGEX.match(subscription_url)
|
||||
|
||||
@ -871,14 +949,15 @@ class AzureCloudProvider(CloudProviderInterface):
|
||||
creds.root_tenant_id, creds.root_sp_client_id, creds.root_sp_key
|
||||
)
|
||||
|
||||
def _get_sp_token(self, tenant_id, client_id, secret_key):
|
||||
def _get_sp_token(self, tenant_id, client_id, secret_key, resource=None):
|
||||
context = self.sdk.adal.AuthenticationContext(
|
||||
f"{self.sdk.cloud.endpoints.active_directory}/{tenant_id}"
|
||||
)
|
||||
|
||||
resource = resource or self.sdk.cloud.endpoints.resource_manager
|
||||
# TODO: handle failure states here
|
||||
token_response = context.acquire_token_with_client_credentials(
|
||||
self.sdk.cloud.endpoints.resource_manager, client_id, secret_key
|
||||
resource, client_id, secret_key
|
||||
)
|
||||
|
||||
return token_response.get("accessToken", None)
|
||||
@ -939,10 +1018,13 @@ class AzureCloudProvider(CloudProviderInterface):
|
||||
"tenant_id": self.tenant_id,
|
||||
}
|
||||
|
||||
def _get_tenant_principal_token(self, tenant_id):
|
||||
def _get_tenant_principal_token(self, tenant_id, resource=None):
|
||||
creds = self._source_creds(tenant_id)
|
||||
return self._get_sp_token(
|
||||
creds.tenant_id, creds.tenant_sp_client_id, creds.tenant_sp_key
|
||||
creds.tenant_id,
|
||||
creds.tenant_sp_client_id,
|
||||
creds.tenant_sp_key,
|
||||
resource=resource,
|
||||
)
|
||||
|
||||
def _get_elevated_management_token(self, tenant_id):
|
||||
|
@ -88,17 +88,6 @@ class UserProvisioningException(GeneralCSPException):
|
||||
"""Failed to provision a user
|
||||
"""
|
||||
|
||||
def __init__(self, env_identifier, user_identifier, reason):
|
||||
self.env_identifier = env_identifier
|
||||
self.user_identifier = user_identifier
|
||||
self.reason = reason
|
||||
|
||||
@property
|
||||
def message(self):
|
||||
return "Failed to create user {} for environment {}: {}".format(
|
||||
self.user_identifier, self.env_identifier, self.reason
|
||||
)
|
||||
|
||||
|
||||
class UserRemovalException(GeneralCSPException):
|
||||
"""Failed to remove a user
|
||||
|
@ -51,6 +51,8 @@ from .models import (
|
||||
TenantPrincipalCSPResult,
|
||||
TenantPrincipalOwnershipCSPPayload,
|
||||
TenantPrincipalOwnershipCSPResult,
|
||||
UserCSPPayload,
|
||||
UserCSPResult,
|
||||
)
|
||||
|
||||
|
||||
@ -175,6 +177,7 @@ class MockCloudProvider(CloudProviderInterface):
|
||||
"tenant_id": "",
|
||||
"user_id": "",
|
||||
"user_object_id": "",
|
||||
"domain_name": "",
|
||||
"tenant_admin_username": "test",
|
||||
"tenant_admin_password": "test",
|
||||
}
|
||||
@ -474,6 +477,11 @@ class MockCloudProvider(CloudProviderInterface):
|
||||
id=f"{AZURE_MGMNT_PATH}{payload.management_group_name}"
|
||||
)
|
||||
|
||||
def create_user(self, payload: UserCSPPayload):
|
||||
self._maybe_raise(self.UNAUTHORIZED_RATE, GeneralCSPException)
|
||||
|
||||
return UserCSPResult(id=str(uuid4()))
|
||||
|
||||
def get_credentials(self, scope="portfolio", tenant_id=None):
|
||||
return self.root_creds()
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
from secrets import token_urlsafe
|
||||
from typing import Dict, List, Optional
|
||||
import re
|
||||
from uuid import uuid4
|
||||
import re
|
||||
|
||||
from pydantic import BaseModel, validator, root_validator
|
||||
|
||||
@ -39,6 +40,7 @@ class TenantCSPResult(AliasModel):
|
||||
user_id: str
|
||||
tenant_id: str
|
||||
user_object_id: str
|
||||
domain_name: str
|
||||
|
||||
tenant_admin_username: Optional[str]
|
||||
tenant_admin_password: Optional[str]
|
||||
@ -474,3 +476,26 @@ class ProductPurchaseVerificationCSPPayload(BaseCSPPayload):
|
||||
|
||||
class ProductPurchaseVerificationCSPResult(AliasModel):
|
||||
premium_purchase_date: str
|
||||
|
||||
|
||||
class UserCSPPayload(BaseCSPPayload):
|
||||
display_name: str
|
||||
tenant_host_name: str
|
||||
email: str
|
||||
password: Optional[str]
|
||||
|
||||
@property
|
||||
def user_principal_name(self):
|
||||
return f"{self.mail_nickname}@{self.tenant_host_name}.onmicrosoft.com"
|
||||
|
||||
@property
|
||||
def mail_nickname(self):
|
||||
return self.display_name.replace(" ", ".").lower()
|
||||
|
||||
@validator("password", pre=True, always=True)
|
||||
def supply_password_default(cls, password):
|
||||
return password or token_urlsafe(16)
|
||||
|
||||
|
||||
class UserCSPResult(AliasModel):
|
||||
id: str
|
||||
|
79
atst/jobs.py
79
atst/jobs.py
@ -3,16 +3,16 @@ import pendulum
|
||||
|
||||
from atst.database import db
|
||||
from atst.queue import celery
|
||||
from atst.models import EnvironmentRole, JobFailure
|
||||
from atst.models import JobFailure
|
||||
from atst.domain.csp.cloud.exceptions import GeneralCSPException
|
||||
from atst.domain.csp.cloud import CloudProviderInterface
|
||||
from atst.domain.applications import Applications
|
||||
from atst.domain.environments import Environments
|
||||
from atst.domain.portfolios import Portfolios
|
||||
from atst.domain.environment_roles import EnvironmentRoles
|
||||
from atst.models.utils import claim_for_update
|
||||
from atst.domain.application_roles import ApplicationRoles
|
||||
from atst.models.utils import claim_for_update, claim_many_for_update
|
||||
from atst.utils.localization import translate
|
||||
from atst.domain.csp.cloud.models import ApplicationCSPPayload
|
||||
from atst.domain.csp.cloud.models import ApplicationCSPPayload, UserCSPPayload
|
||||
|
||||
|
||||
class RecordFailure(celery.Task):
|
||||
@ -75,6 +75,34 @@ def do_create_application(csp: CloudProviderInterface, application_id=None):
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def do_create_user(csp: CloudProviderInterface, application_role_ids=None):
|
||||
if not application_role_ids:
|
||||
return
|
||||
|
||||
app_roles = ApplicationRoles.get_many(application_role_ids)
|
||||
|
||||
with claim_many_for_update(app_roles) as app_roles:
|
||||
|
||||
if any([ar.cloud_id for ar in app_roles]):
|
||||
return
|
||||
|
||||
csp_details = app_roles[0].application.portfolio.csp_data
|
||||
user = app_roles[0].user
|
||||
|
||||
payload = UserCSPPayload(
|
||||
tenant_id=csp_details.get("tenant_id"),
|
||||
tenant_host_name=csp_details.get("domain_name"),
|
||||
display_name=user.full_name,
|
||||
email=user.email,
|
||||
)
|
||||
result = csp.create_user(payload)
|
||||
for app_role in app_roles:
|
||||
app_role.cloud_id = result.id
|
||||
db.session.add(app_role)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def do_create_environment(csp: CloudProviderInterface, environment_id=None):
|
||||
environment = Environments.get(environment_id)
|
||||
|
||||
@ -128,21 +156,6 @@ def render_email(template_path, context):
|
||||
return app.jinja_env.get_template(template_path).render(context)
|
||||
|
||||
|
||||
def do_provision_user(csp: CloudProviderInterface, environment_role_id=None):
|
||||
environment_role = EnvironmentRoles.get_by_id(environment_role_id)
|
||||
|
||||
with claim_for_update(environment_role) as environment_role:
|
||||
credentials = environment_role.environment.csp_credentials
|
||||
|
||||
csp_user_id = csp.create_or_update_user(
|
||||
credentials, environment_role, environment_role.role
|
||||
)
|
||||
environment_role.csp_user_id = csp_user_id
|
||||
environment_role.status = EnvironmentRole.Status.COMPLETED
|
||||
db.session.add(environment_role)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def do_work(fn, task, csp, **kwargs):
|
||||
try:
|
||||
fn(csp, **kwargs)
|
||||
@ -166,6 +179,13 @@ def create_application(self, application_id=None):
|
||||
do_work(do_create_application, self, app.csp.cloud, application_id=application_id)
|
||||
|
||||
|
||||
@celery.task(bind=True, base=RecordFailure)
|
||||
def create_user(self, application_role_ids=None):
|
||||
do_work(
|
||||
do_create_user, self, app.csp.cloud, application_role_ids=application_role_ids
|
||||
)
|
||||
|
||||
|
||||
@celery.task(bind=True, base=RecordFailure)
|
||||
def create_environment(self, environment_id=None):
|
||||
do_work(do_create_environment, self, app.csp.cloud, environment_id=environment_id)
|
||||
@ -178,13 +198,6 @@ def create_atat_admin_user(self, environment_id=None):
|
||||
)
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def provision_user(self, environment_role_id=None):
|
||||
do_work(
|
||||
do_provision_user, self, app.csp.cloud, environment_role_id=environment_role_id
|
||||
)
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def dispatch_provision_portfolio(self):
|
||||
"""
|
||||
@ -200,6 +213,12 @@ def dispatch_create_application(self):
|
||||
create_application.delay(application_id=application_id)
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def dispatch_create_user(self):
|
||||
for application_role_ids in ApplicationRoles.get_pending_creation():
|
||||
create_user.delay(application_role_ids=application_role_ids)
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def dispatch_create_environment(self):
|
||||
for environment_id in Environments.get_environments_pending_creation(
|
||||
@ -214,11 +233,3 @@ def dispatch_create_atat_admin_user(self):
|
||||
pendulum.now()
|
||||
):
|
||||
create_atat_admin_user.delay(environment_id=environment_id)
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def dispatch_provision_user(self):
|
||||
for (
|
||||
environment_role_id
|
||||
) in EnvironmentRoles.get_environment_roles_pending_creation():
|
||||
provision_user.delay(environment_role_id=environment_role_id)
|
||||
|
@ -1,4 +1,4 @@
|
||||
from sqlalchemy import and_, Column, ForeignKey, String, UniqueConstraint, TIMESTAMP
|
||||
from sqlalchemy import and_, Column, ForeignKey, String, UniqueConstraint
|
||||
from sqlalchemy.orm import relationship, synonym
|
||||
|
||||
from atst.models.base import Base
|
||||
@ -9,7 +9,11 @@ from atst.models.types import Id
|
||||
|
||||
|
||||
class Application(
|
||||
Base, mixins.TimestampsMixin, mixins.AuditableMixin, mixins.DeletableMixin
|
||||
Base,
|
||||
mixins.TimestampsMixin,
|
||||
mixins.AuditableMixin,
|
||||
mixins.DeletableMixin,
|
||||
mixins.ClaimableMixin,
|
||||
):
|
||||
__tablename__ = "applications"
|
||||
|
||||
@ -41,7 +45,6 @@ class Application(
|
||||
)
|
||||
|
||||
cloud_id = Column(String)
|
||||
claimed_until = Column(TIMESTAMP(timezone=True))
|
||||
|
||||
@property
|
||||
def users(self):
|
||||
|
@ -1,5 +1,5 @@
|
||||
from enum import Enum
|
||||
from sqlalchemy import Index, ForeignKey, Column, Enum as SQLAEnum, Table
|
||||
from sqlalchemy import Index, ForeignKey, Column, Enum as SQLAEnum, Table, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.event import listen
|
||||
@ -33,6 +33,7 @@ class ApplicationRole(
|
||||
mixins.AuditableMixin,
|
||||
mixins.PermissionsMixin,
|
||||
mixins.DeletableMixin,
|
||||
mixins.ClaimableMixin,
|
||||
):
|
||||
__tablename__ = "application_roles"
|
||||
|
||||
@ -59,6 +60,8 @@ class ApplicationRole(
|
||||
primaryjoin="and_(EnvironmentRole.application_role_id == ApplicationRole.id, EnvironmentRole.deleted == False)",
|
||||
)
|
||||
|
||||
cloud_id = Column(String)
|
||||
|
||||
@property
|
||||
def latest_invitation(self):
|
||||
if self.invitations:
|
||||
|
@ -1,4 +1,4 @@
|
||||
from sqlalchemy import Column, ForeignKey, String, TIMESTAMP, UniqueConstraint
|
||||
from sqlalchemy import Column, ForeignKey, String, UniqueConstraint
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from enum import Enum
|
||||
@ -9,7 +9,11 @@ import atst.models.types as types
|
||||
|
||||
|
||||
class Environment(
|
||||
Base, mixins.TimestampsMixin, mixins.AuditableMixin, mixins.DeletableMixin
|
||||
Base,
|
||||
mixins.TimestampsMixin,
|
||||
mixins.AuditableMixin,
|
||||
mixins.DeletableMixin,
|
||||
mixins.ClaimableMixin,
|
||||
):
|
||||
__tablename__ = "environments"
|
||||
|
||||
@ -28,8 +32,6 @@ class Environment(
|
||||
cloud_id = Column(String)
|
||||
root_user_info = Column(JSONB(none_as_null=True))
|
||||
|
||||
claimed_until = Column(TIMESTAMP(timezone=True))
|
||||
|
||||
roles = relationship(
|
||||
"EnvironmentRole",
|
||||
back_populates="environment",
|
||||
|
@ -1,5 +1,5 @@
|
||||
from enum import Enum
|
||||
from sqlalchemy import Index, ForeignKey, Column, String, TIMESTAMP, Enum as SQLAEnum
|
||||
from sqlalchemy import Index, ForeignKey, Column, String, Enum as SQLAEnum
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
@ -15,7 +15,11 @@ class CSPRole(Enum):
|
||||
|
||||
|
||||
class EnvironmentRole(
|
||||
Base, mixins.TimestampsMixin, mixins.AuditableMixin, mixins.DeletableMixin
|
||||
Base,
|
||||
mixins.TimestampsMixin,
|
||||
mixins.AuditableMixin,
|
||||
mixins.DeletableMixin,
|
||||
mixins.ClaimableMixin,
|
||||
):
|
||||
__tablename__ = "environment_roles"
|
||||
|
||||
@ -33,7 +37,6 @@ class EnvironmentRole(
|
||||
application_role = relationship("ApplicationRole")
|
||||
|
||||
csp_user_id = Column(String())
|
||||
claimed_until = Column(TIMESTAMP(timezone=True))
|
||||
|
||||
class Status(Enum):
|
||||
PENDING = "pending"
|
||||
|
@ -4,3 +4,4 @@ from .permissions import PermissionsMixin
|
||||
from .deletable import DeletableMixin
|
||||
from .invites import InvitesMixin
|
||||
from .state_machines import FSMMixin
|
||||
from .claimable import ClaimableMixin
|
||||
|
5
atst/models/mixins/claimable.py
Normal file
5
atst/models/mixins/claimable.py
Normal file
@ -0,0 +1,5 @@
|
||||
from sqlalchemy import Column, TIMESTAMP
|
||||
|
||||
|
||||
class ClaimableMixin(object):
|
||||
claimed_until = Column(TIMESTAMP(timezone=True))
|
@ -1,3 +1,5 @@
|
||||
from typing import List
|
||||
|
||||
from sqlalchemy import func, sql, Interval, and_, or_
|
||||
from contextlib import contextmanager
|
||||
|
||||
@ -28,7 +30,7 @@ def claim_for_update(resource, minutes=30):
|
||||
.filter(
|
||||
and_(
|
||||
Model.id == resource.id,
|
||||
or_(Model.claimed_until == None, Model.claimed_until <= func.now()),
|
||||
or_(Model.claimed_until.is_(None), Model.claimed_until <= func.now()),
|
||||
)
|
||||
)
|
||||
.update({"claimed_until": claim_until}, synchronize_session="fetch")
|
||||
@ -48,3 +50,51 @@ def claim_for_update(resource, minutes=30):
|
||||
Model.claimed_until != None
|
||||
).update({"claimed_until": None}, synchronize_session="fetch")
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def claim_many_for_update(resources: List, minutes=30):
|
||||
"""
|
||||
Claim a mutually exclusive expiring hold on a group of resources.
|
||||
Uses the database as a central source of time in case the server clocks have drifted.
|
||||
|
||||
Args:
|
||||
resources: A list of SQLAlchemy model instances with a `claimed_until` attribute.
|
||||
minutes: The maximum amount of time, in minutes, to hold the claim.
|
||||
"""
|
||||
Model = resources[0].__class__
|
||||
|
||||
claim_until = func.now() + func.cast(
|
||||
sql.functions.concat(minutes, " MINUTES"), Interval
|
||||
)
|
||||
|
||||
ids = tuple(r.id for r in resources)
|
||||
|
||||
# Optimistically query for and update the resources in question. If they're
|
||||
# already claimed, `rows_updated` will be 0 and we can give up.
|
||||
rows_updated = (
|
||||
db.session.query(Model)
|
||||
.filter(
|
||||
and_(
|
||||
Model.id.in_(ids),
|
||||
or_(Model.claimed_until.is_(None), Model.claimed_until <= func.now()),
|
||||
)
|
||||
)
|
||||
.update({"claimed_until": claim_until}, synchronize_session="fetch")
|
||||
)
|
||||
if rows_updated < 1:
|
||||
# TODO: Generalize this exception class so it can take multiple resources
|
||||
raise ClaimFailedException(resources[0])
|
||||
|
||||
# Fetch the claimed resources
|
||||
claimed = db.session.query(Model).filter(Model.id.in_(ids)).all()
|
||||
|
||||
try:
|
||||
# Give the resource to the caller.
|
||||
yield claimed
|
||||
finally:
|
||||
# Release the claim.
|
||||
db.session.query(Model).filter(Model.id.in_(ids)).filter(
|
||||
Model.claimed_until != None
|
||||
).update({"claimed_until": None}, synchronize_session="fetch")
|
||||
db.session.commit()
|
||||
|
@ -23,8 +23,8 @@ def update_celery(celery, app):
|
||||
"task": "atst.jobs.dispatch_create_atat_admin_user",
|
||||
"schedule": 60,
|
||||
},
|
||||
"beat-dispatch_provision_user": {
|
||||
"task": "atst.jobs.dispatch_provision_user",
|
||||
"beat-dispatch_create_user": {
|
||||
"task": "atst.jobs.dispatch_create_user",
|
||||
"schedule": 60,
|
||||
},
|
||||
}
|
||||
|
@ -7,6 +7,7 @@ from atst.domain.csp.cloud.models import (
|
||||
KeyVaultCredentials,
|
||||
ManagementGroupCSPPayload,
|
||||
ManagementGroupCSPResponse,
|
||||
UserCSPPayload,
|
||||
)
|
||||
|
||||
|
||||
@ -97,3 +98,26 @@ def test_KeyVaultCredentials_enforce_root_creds():
|
||||
assert KeyVaultCredentials(
|
||||
root_tenant_id="an id", root_sp_client_id="C3PO", root_sp_key="beep boop"
|
||||
)
|
||||
|
||||
|
||||
user_payload = {
|
||||
"tenant_id": "123",
|
||||
"display_name": "Han Solo",
|
||||
"tenant_host_name": "rebelalliance",
|
||||
"email": "han@moseisley.cantina",
|
||||
}
|
||||
|
||||
|
||||
def test_UserCSPPayload_mail_nickname():
|
||||
payload = UserCSPPayload(**user_payload)
|
||||
assert payload.mail_nickname == f"han.solo"
|
||||
|
||||
|
||||
def test_UserCSPPayload_user_principal_name():
|
||||
payload = UserCSPPayload(**user_payload)
|
||||
assert payload.user_principal_name == f"han.solo@rebelalliance.onmicrosoft.com"
|
||||
|
||||
|
||||
def test_UserCSPPayload_password():
|
||||
payload = UserCSPPayload(**user_payload)
|
||||
assert payload.password
|
||||
|
@ -86,3 +86,79 @@ def test_disable(session):
|
||||
session.refresh(environment_role)
|
||||
assert member_role.status == ApplicationRoleStatus.DISABLED
|
||||
assert environment_role.deleted
|
||||
|
||||
|
||||
def test_get_pending_creation():
|
||||
|
||||
# ready Applications belonging to the same Portfolio
|
||||
portfolio_one = PortfolioFactory.create()
|
||||
ready_app = ApplicationFactory.create(cloud_id="123", portfolio=portfolio_one)
|
||||
ready_app2 = ApplicationFactory.create(cloud_id="321", portfolio=portfolio_one)
|
||||
|
||||
# ready Application belonging to a new Portfolio
|
||||
ready_app3 = ApplicationFactory.create(cloud_id="567")
|
||||
unready_app = ApplicationFactory.create()
|
||||
|
||||
# two distinct Users
|
||||
user_one = UserFactory.create()
|
||||
user_two = UserFactory.create()
|
||||
|
||||
# Two ApplicationRoles belonging to the same User and
|
||||
# different Applications. These should sort together because
|
||||
# they are all under the same portfolio (portfolio_one).
|
||||
role_one = ApplicationRoleFactory.create(
|
||||
user=user_one, application=ready_app, status=ApplicationRoleStatus.ACTIVE
|
||||
)
|
||||
role_two = ApplicationRoleFactory.create(
|
||||
user=user_one, application=ready_app2, status=ApplicationRoleStatus.ACTIVE
|
||||
)
|
||||
|
||||
# An ApplicationRole belonging to a different User. This will
|
||||
# be included but sort separately because it belongs to a
|
||||
# different user.
|
||||
role_three = ApplicationRoleFactory.create(
|
||||
user=user_two, application=ready_app, status=ApplicationRoleStatus.ACTIVE
|
||||
)
|
||||
|
||||
# An ApplicationRole belonging to one of the existing users
|
||||
# but under a different portfolio. It will sort separately.
|
||||
role_four = ApplicationRoleFactory.create(
|
||||
user=user_one, application=ready_app3, status=ApplicationRoleStatus.ACTIVE
|
||||
)
|
||||
|
||||
# This ApplicationRole will not be in the results because its
|
||||
# application is not ready (implicitly, its cloud_id is not
|
||||
# set.)
|
||||
ApplicationRoleFactory.create(
|
||||
user=UserFactory.create(),
|
||||
application=unready_app,
|
||||
status=ApplicationRoleStatus.ACTIVE,
|
||||
)
|
||||
|
||||
# This ApplicationRole will not be in the results because it
|
||||
# does not have a user associated.
|
||||
ApplicationRoleFactory.create(
|
||||
user=None, application=ready_app, status=ApplicationRoleStatus.ACTIVE,
|
||||
)
|
||||
|
||||
# This ApplicationRole will not be in the results because its
|
||||
# status is not ACTIVE.
|
||||
ApplicationRoleFactory.create(
|
||||
user=UserFactory.create(),
|
||||
application=unready_app,
|
||||
status=ApplicationRoleStatus.DISABLED,
|
||||
)
|
||||
|
||||
app_ids = ApplicationRoles.get_pending_creation()
|
||||
expected_ids = [[role_one.id, role_two.id], [role_three.id], [role_four.id]]
|
||||
# Sort them to produce the same order.
|
||||
assert sorted(app_ids) == sorted(expected_ids)
|
||||
|
||||
|
||||
def test_get_many():
|
||||
ar1 = ApplicationRoleFactory.create()
|
||||
ar2 = ApplicationRoleFactory.create()
|
||||
ApplicationRoleFactory.create()
|
||||
|
||||
result = ApplicationRoles.get_many([ar1.id, ar2.id])
|
||||
assert result == [ar1, ar2]
|
||||
|
104
tests/models/test_utils.py
Normal file
104
tests/models/test_utils.py
Normal file
@ -0,0 +1,104 @@
|
||||
from threading import Thread
|
||||
|
||||
from atst.domain.exceptions import ClaimFailedException
|
||||
from atst.models.utils import claim_for_update, claim_many_for_update
|
||||
|
||||
from tests.factories import EnvironmentFactory
|
||||
|
||||
|
||||
def test_claim_for_update(session):
|
||||
environment = EnvironmentFactory.create()
|
||||
|
||||
satisfied_claims = []
|
||||
exceptions = []
|
||||
|
||||
# Two threads race to do work on environment and check out the lock
|
||||
class FirstThread(Thread):
|
||||
def run(self):
|
||||
try:
|
||||
with claim_for_update(environment) as env:
|
||||
assert env.claimed_until
|
||||
satisfied_claims.append("FirstThread")
|
||||
except ClaimFailedException:
|
||||
exceptions.append("FirstThread")
|
||||
|
||||
class SecondThread(Thread):
|
||||
def run(self):
|
||||
try:
|
||||
with claim_for_update(environment) as env:
|
||||
assert env.claimed_until
|
||||
satisfied_claims.append("SecondThread")
|
||||
except ClaimFailedException:
|
||||
exceptions.append("SecondThread")
|
||||
|
||||
t1 = FirstThread()
|
||||
t2 = SecondThread()
|
||||
t1.start()
|
||||
t2.start()
|
||||
t1.join()
|
||||
t2.join()
|
||||
|
||||
session.refresh(environment)
|
||||
|
||||
assert len(satisfied_claims) == 1
|
||||
assert len(exceptions) == 1
|
||||
|
||||
if satisfied_claims == ["FirstThread"]:
|
||||
assert exceptions == ["SecondThread"]
|
||||
else:
|
||||
assert satisfied_claims == ["SecondThread"]
|
||||
assert exceptions == ["FirstThread"]
|
||||
|
||||
# The claim is released
|
||||
assert environment.claimed_until is None
|
||||
|
||||
|
||||
def test_claim_many_for_update(session):
|
||||
environments = [
|
||||
EnvironmentFactory.create(),
|
||||
EnvironmentFactory.create(),
|
||||
]
|
||||
|
||||
satisfied_claims = []
|
||||
exceptions = []
|
||||
|
||||
# Two threads race to do work on environment and check out the lock
|
||||
class FirstThread(Thread):
|
||||
def run(self):
|
||||
try:
|
||||
with claim_many_for_update(environments) as envs:
|
||||
assert all([e.claimed_until for e in envs])
|
||||
satisfied_claims.append("FirstThread")
|
||||
except ClaimFailedException:
|
||||
exceptions.append("FirstThread")
|
||||
|
||||
class SecondThread(Thread):
|
||||
def run(self):
|
||||
try:
|
||||
with claim_many_for_update(environments) as envs:
|
||||
assert all([e.claimed_until for e in envs])
|
||||
satisfied_claims.append("SecondThread")
|
||||
except ClaimFailedException:
|
||||
exceptions.append("SecondThread")
|
||||
|
||||
t1 = FirstThread()
|
||||
t2 = SecondThread()
|
||||
t1.start()
|
||||
t2.start()
|
||||
t1.join()
|
||||
t2.join()
|
||||
|
||||
for env in environments:
|
||||
session.refresh(env)
|
||||
|
||||
assert len(satisfied_claims) == 1
|
||||
assert len(exceptions) == 1
|
||||
|
||||
if satisfied_claims == ["FirstThread"]:
|
||||
assert exceptions == ["SecondThread"]
|
||||
else:
|
||||
assert satisfied_claims == ["SecondThread"]
|
||||
assert exceptions == ["FirstThread"]
|
||||
|
||||
# The claim is released
|
||||
# assert environment.claimed_until is None
|
@ -2,27 +2,25 @@ import pendulum
|
||||
import pytest
|
||||
from uuid import uuid4
|
||||
from unittest.mock import Mock
|
||||
from threading import Thread
|
||||
|
||||
from atst.domain.csp.cloud import MockCloudProvider
|
||||
from atst.domain.portfolios import Portfolios
|
||||
from atst.models import ApplicationRoleStatus
|
||||
|
||||
from atst.jobs import (
|
||||
RecordFailure,
|
||||
dispatch_create_environment,
|
||||
dispatch_create_application,
|
||||
dispatch_create_user,
|
||||
dispatch_create_atat_admin_user,
|
||||
dispatch_provision_portfolio,
|
||||
dispatch_provision_user,
|
||||
create_environment,
|
||||
do_provision_user,
|
||||
do_create_user,
|
||||
do_provision_portfolio,
|
||||
do_create_environment,
|
||||
do_create_application,
|
||||
do_create_atat_admin_user,
|
||||
)
|
||||
from atst.models.utils import claim_for_update
|
||||
from atst.domain.exceptions import ClaimFailedException
|
||||
from tests.factories import (
|
||||
EnvironmentFactory,
|
||||
EnvironmentRoleFactory,
|
||||
@ -30,6 +28,7 @@ from tests.factories import (
|
||||
PortfolioStateMachineFactory,
|
||||
ApplicationFactory,
|
||||
ApplicationRoleFactory,
|
||||
UserFactory,
|
||||
)
|
||||
from atst.models import CSPRole, EnvironmentRole, ApplicationRoleStatus, JobFailure
|
||||
|
||||
@ -126,6 +125,30 @@ def test_create_application_job_is_idempotent(csp):
|
||||
csp.create_application.assert_not_called()
|
||||
|
||||
|
||||
def test_create_user_job(session, csp):
|
||||
portfolio = PortfolioFactory.create(
|
||||
csp_data={
|
||||
"tenant_id": str(uuid4()),
|
||||
"domain_name": "rebelalliance.onmicrosoft.com",
|
||||
}
|
||||
)
|
||||
application = ApplicationFactory.create(portfolio=portfolio, cloud_id="321")
|
||||
user = UserFactory.create(
|
||||
first_name="Han", last_name="Solo", email="han@example.com"
|
||||
)
|
||||
app_role = ApplicationRoleFactory.create(
|
||||
application=application,
|
||||
user=user,
|
||||
status=ApplicationRoleStatus.ACTIVE,
|
||||
cloud_id=None,
|
||||
)
|
||||
|
||||
do_create_user(csp, [app_role.id])
|
||||
session.refresh(app_role)
|
||||
|
||||
assert app_role.cloud_id
|
||||
|
||||
|
||||
def test_create_atat_admin_user(csp, session):
|
||||
environment = EnvironmentFactory.create(cloud_id="something")
|
||||
do_create_atat_admin_user(csp, environment.id)
|
||||
@ -181,6 +204,29 @@ def test_dispatch_create_application(monkeypatch):
|
||||
mock.delay.assert_called_once_with(application_id=app.id)
|
||||
|
||||
|
||||
def test_dispatch_create_user(monkeypatch):
|
||||
application = ApplicationFactory.create(cloud_id="123")
|
||||
user = UserFactory.create(
|
||||
first_name="Han", last_name="Solo", email="han@example.com"
|
||||
)
|
||||
app_role = ApplicationRoleFactory.create(
|
||||
application=application,
|
||||
user=user,
|
||||
status=ApplicationRoleStatus.ACTIVE,
|
||||
cloud_id=None,
|
||||
)
|
||||
|
||||
mock = Mock()
|
||||
monkeypatch.setattr("atst.jobs.create_user", mock)
|
||||
|
||||
# When dispatch_create_user is called
|
||||
dispatch_create_user.run()
|
||||
|
||||
# It should cause the create_user task to be called once
|
||||
# with the application id
|
||||
mock.delay.assert_called_once_with(application_role_ids=[app_role.id])
|
||||
|
||||
|
||||
def test_dispatch_create_atat_admin_user(session, monkeypatch):
|
||||
portfolio = PortfolioFactory.create(
|
||||
applications=[
|
||||
@ -240,128 +286,6 @@ def test_create_environment_no_dupes(session, celery_app, celery_worker):
|
||||
assert environment.claimed_until == None
|
||||
|
||||
|
||||
def test_claim_for_update(session):
|
||||
portfolio = PortfolioFactory.create(
|
||||
applications=[
|
||||
{"environments": [{"cloud_id": uuid4().hex, "root_user_info": {}}]}
|
||||
],
|
||||
task_orders=[
|
||||
{
|
||||
"create_clins": [
|
||||
{
|
||||
"start_date": pendulum.now().subtract(days=1),
|
||||
"end_date": pendulum.now().add(days=1),
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
)
|
||||
environment = portfolio.applications[0].environments[0]
|
||||
|
||||
satisfied_claims = []
|
||||
exceptions = []
|
||||
|
||||
# Two threads race to do work on environment and check out the lock
|
||||
class FirstThread(Thread):
|
||||
def run(self):
|
||||
try:
|
||||
with claim_for_update(environment):
|
||||
satisfied_claims.append("FirstThread")
|
||||
except ClaimFailedException:
|
||||
exceptions.append("FirstThread")
|
||||
|
||||
class SecondThread(Thread):
|
||||
def run(self):
|
||||
try:
|
||||
with claim_for_update(environment):
|
||||
satisfied_claims.append("SecondThread")
|
||||
except ClaimFailedException:
|
||||
exceptions.append("SecondThread")
|
||||
|
||||
t1 = FirstThread()
|
||||
t2 = SecondThread()
|
||||
t1.start()
|
||||
t2.start()
|
||||
t1.join()
|
||||
t2.join()
|
||||
|
||||
session.refresh(environment)
|
||||
|
||||
assert len(satisfied_claims) == 1
|
||||
assert len(exceptions) == 1
|
||||
|
||||
if satisfied_claims == ["FirstThread"]:
|
||||
assert exceptions == ["SecondThread"]
|
||||
else:
|
||||
assert satisfied_claims == ["SecondThread"]
|
||||
assert exceptions == ["FirstThread"]
|
||||
|
||||
# The claim is released
|
||||
assert environment.claimed_until is None
|
||||
|
||||
|
||||
def test_dispatch_provision_user(csp, session, celery_app, celery_worker, monkeypatch):
|
||||
|
||||
# Given that I have four environment roles:
|
||||
# (A) one of which has a completed status
|
||||
# (B) one of which has an environment that has not been provisioned
|
||||
# (C) one of which is pending, has a provisioned environment but an inactive application role
|
||||
# (D) one of which is pending, has a provisioned environment and has an active application role
|
||||
provisioned_environment = EnvironmentFactory.create(
|
||||
cloud_id="cloud_id", root_user_info={}
|
||||
)
|
||||
unprovisioned_environment = EnvironmentFactory.create()
|
||||
_er_a = EnvironmentRoleFactory.create(
|
||||
environment=provisioned_environment, status=EnvironmentRole.Status.COMPLETED
|
||||
)
|
||||
_er_b = EnvironmentRoleFactory.create(
|
||||
environment=unprovisioned_environment, status=EnvironmentRole.Status.PENDING
|
||||
)
|
||||
_er_c = EnvironmentRoleFactory.create(
|
||||
environment=unprovisioned_environment,
|
||||
status=EnvironmentRole.Status.PENDING,
|
||||
application_role=ApplicationRoleFactory(status=ApplicationRoleStatus.PENDING),
|
||||
)
|
||||
er_d = EnvironmentRoleFactory.create(
|
||||
environment=provisioned_environment,
|
||||
status=EnvironmentRole.Status.PENDING,
|
||||
application_role=ApplicationRoleFactory(status=ApplicationRoleStatus.ACTIVE),
|
||||
)
|
||||
|
||||
mock = Mock()
|
||||
monkeypatch.setattr("atst.jobs.provision_user", mock)
|
||||
|
||||
# When I dispatch the user provisioning task
|
||||
dispatch_provision_user.run()
|
||||
|
||||
# I expect it to dispatch only one call, to EnvironmentRole D
|
||||
mock.delay.assert_called_once_with(environment_role_id=er_d.id)
|
||||
|
||||
|
||||
def test_do_provision_user(csp, session):
|
||||
# Given that I have an EnvironmentRole with a provisioned environment
|
||||
credentials = MockCloudProvider(())._auth_credentials
|
||||
provisioned_environment = EnvironmentFactory.create(
|
||||
cloud_id="cloud_id", root_user_info={"credentials": credentials}
|
||||
)
|
||||
environment_role = EnvironmentRoleFactory.create(
|
||||
environment=provisioned_environment,
|
||||
status=EnvironmentRole.Status.PENDING,
|
||||
role="ADMIN",
|
||||
)
|
||||
|
||||
# When I call the user provisoning task
|
||||
do_provision_user(csp=csp, environment_role_id=environment_role.id)
|
||||
|
||||
session.refresh(environment_role)
|
||||
# I expect that the CSP create_or_update_user method will be called
|
||||
csp.create_or_update_user.assert_called_once_with(
|
||||
credentials, environment_role, CSPRole.ADMIN
|
||||
)
|
||||
# I expect that the EnvironmentRole now has a csp_user_id
|
||||
assert environment_role.csp_user_id
|
||||
|
||||
|
||||
def test_dispatch_provision_portfolio(
|
||||
csp, session, portfolio, celery_app, celery_worker, monkeypatch
|
||||
):
|
||||
|
Loading…
x
Reference in New Issue
Block a user