Merge pull request #1359 from dod-ccpo/app-env-provisioning

Application Provisioning
This commit is contained in:
dandds 2020-01-29 11:43:19 -05:00 committed by GitHub
commit 7812da5eae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 611 additions and 118 deletions

View File

@ -0,0 +1,29 @@
"""add applications.claimed_until
Revision ID: 07e0598199f6
Revises: 26319c44a8d5
Create Date: 2020-01-25 13:33:17.711548
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '07e0598199f6' # pragma: allowlist secret
down_revision = '26319c44a8d5' # pragma: allowlist secret
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('applications', sa.Column('claimed_until', sa.TIMESTAMP(timezone=True), nullable=True))
op.add_column('applications', sa.Column('cloud_id', sa.String(), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('applications', 'claimed_until')
op.drop_column('applications', 'cloud_id')
# ### end Alembic commands ###

View File

@ -0,0 +1,60 @@
"""combine job failures
Revision ID: 508957112ed6
Revises: 07e0598199f6
Create Date: 2020-01-25 15:03:06.377442
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '508957112ed6' # pragma: allowlist secret
down_revision = '07e0598199f6' # pragma: allowlist secret
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('job_failures',
sa.Column('time_created', sa.TIMESTAMP(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('time_updated', sa.TIMESTAMP(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('task_id', sa.String(), nullable=False),
sa.Column('entity', sa.String(), nullable=False),
sa.Column('entity_id', sa.String(), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.drop_table('environment_job_failures')
op.drop_table('environment_role_job_failures')
op.drop_table('portfolio_job_failures')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('portfolio_job_failures',
sa.Column('id', sa.INTEGER(), autoincrement=True, nullable=False),
sa.Column('task_id', sa.VARCHAR(), autoincrement=False, nullable=False),
sa.Column('portfolio_id', postgresql.UUID(), autoincrement=False, nullable=False),
sa.ForeignKeyConstraint(['portfolio_id'], ['portfolios.id'], name='portfolio_job_failures_portfolio_id_fkey'),
sa.PrimaryKeyConstraint('id', name='portfolio_job_failures_pkey')
)
op.create_table('environment_role_job_failures',
sa.Column('id', sa.INTEGER(), autoincrement=True, nullable=False),
sa.Column('task_id', sa.VARCHAR(), autoincrement=False, nullable=False),
sa.Column('environment_role_id', postgresql.UUID(), autoincrement=False, nullable=False),
sa.ForeignKeyConstraint(['environment_role_id'], ['environment_roles.id'], name='environment_role_job_failures_environment_role_id_fkey'),
sa.PrimaryKeyConstraint('id', name='environment_role_job_failures_pkey')
)
op.create_table('environment_job_failures',
sa.Column('id', sa.INTEGER(), autoincrement=True, nullable=False),
sa.Column('task_id', sa.VARCHAR(), autoincrement=False, nullable=False),
sa.Column('environment_id', postgresql.UUID(), autoincrement=False, nullable=False),
sa.ForeignKeyConstraint(['environment_id'], ['environments.id'], name='environment_job_failures_environment_id_fkey'),
sa.PrimaryKeyConstraint('id', name='environment_job_failures_pkey')
)
op.drop_table('job_failures')
# ### end Alembic commands ###

View File

@ -1,5 +1,9 @@
from . import BaseDomainClass
from flask import g from flask import g
from sqlalchemy import func, or_
from typing import List
from uuid import UUID
from . import BaseDomainClass
from atst.database import db from atst.database import db
from atst.domain.application_roles import ApplicationRoles from atst.domain.application_roles import ApplicationRoles
from atst.domain.environments import Environments from atst.domain.environments import Environments
@ -10,7 +14,10 @@ from atst.models import (
ApplicationRole, ApplicationRole,
ApplicationRoleStatus, ApplicationRoleStatus,
EnvironmentRole, EnvironmentRole,
Portfolio,
PortfolioStateMachine,
) )
from atst.models.mixins.state_machines import FSMStates
from atst.utils import first_or_none, commit_or_raise_already_exists_error from atst.utils import first_or_none, commit_or_raise_already_exists_error
@ -118,3 +125,21 @@ class Applications(BaseDomainClass):
db.session.commit() db.session.commit()
return invitation return invitation
@classmethod
def get_applications_pending_creation(cls) -> List[UUID]:
results = (
db.session.query(Application.id)
.join(Portfolio)
.join(PortfolioStateMachine)
.filter(PortfolioStateMachine.state == FSMStates.COMPLETED)
.filter(Application.deleted == False)
.filter(Application.cloud_id.is_(None))
.filter(
or_(
Application.claimed_until.is_(None),
Application.claimed_until <= func.now(),
)
)
).all()
return [id_ for id_, in results]

View File

@ -1,15 +1,14 @@
import json
import re import re
from secrets import token_urlsafe from secrets import token_urlsafe
from typing import Dict from typing import Dict
from uuid import uuid4 from uuid import uuid4
from atst.models.application import Application
from atst.models.environment import Environment
from atst.models.user import User
from .cloud_provider_interface import CloudProviderInterface from .cloud_provider_interface import CloudProviderInterface
from .exceptions import AuthenticationException from .exceptions import AuthenticationException
from .models import ( from .models import (
ApplicationCSPPayload,
ApplicationCSPResult,
BillingInstructionCSPPayload, BillingInstructionCSPPayload,
BillingInstructionCSPResult, BillingInstructionCSPResult,
BillingProfileCreationCSPPayload, BillingProfileCreationCSPPayload,
@ -18,6 +17,8 @@ from .models import (
BillingProfileTenantAccessCSPResult, BillingProfileTenantAccessCSPResult,
BillingProfileVerificationCSPPayload, BillingProfileVerificationCSPPayload,
BillingProfileVerificationCSPResult, BillingProfileVerificationCSPResult,
KeyVaultCredentials,
ManagementGroupCSPResponse,
TaskOrderBillingCreationCSPPayload, TaskOrderBillingCreationCSPPayload,
TaskOrderBillingCreationCSPResult, TaskOrderBillingCreationCSPResult,
TaskOrderBillingVerificationCSPPayload, TaskOrderBillingVerificationCSPPayload,
@ -26,6 +27,7 @@ from .models import (
TenantCSPResult, TenantCSPResult,
) )
from .policy import AzurePolicyManager from .policy import AzurePolicyManager
from atst.utils import sha256_hex
AZURE_ENVIRONMENT = "AZURE_PUBLIC_CLOUD" # TBD AZURE_ENVIRONMENT = "AZURE_PUBLIC_CLOUD" # TBD
AZURE_SKU_ID = "?" # probably a static sku specific to ATAT/JEDI AZURE_SKU_ID = "?" # probably a static sku specific to ATAT/JEDI
@ -47,6 +49,7 @@ class AzureSDKProvider(object):
import azure.common.credentials as credentials import azure.common.credentials as credentials
import azure.identity as identity import azure.identity as identity
from azure.keyvault import secrets from azure.keyvault import secrets
from azure.core import exceptions
from msrestazure.azure_cloud import AZURE_PUBLIC_CLOUD from msrestazure.azure_cloud import AZURE_PUBLIC_CLOUD
import adal import adal
@ -85,7 +88,7 @@ class AzureCloudProvider(CloudProviderInterface):
def set_secret(self, secret_key, secret_value): def set_secret(self, secret_key, secret_value):
credential = self._get_client_secret_credential_obj({}) credential = self._get_client_secret_credential_obj({})
secret_client = self.secrets.SecretClient( secret_client = self.sdk.secrets.SecretClient(
vault_url=self.vault_url, credential=credential, vault_url=self.vault_url, credential=credential,
) )
try: try:
@ -98,7 +101,7 @@ class AzureCloudProvider(CloudProviderInterface):
def get_secret(self, secret_key): def get_secret(self, secret_key):
credential = self._get_client_secret_credential_obj({}) credential = self._get_client_secret_credential_obj({})
secret_client = self.secrets.SecretClient( secret_client = self.sdk.secrets.SecretClient(
vault_url=self.vault_url, credential=credential, vault_url=self.vault_url, credential=credential,
) )
try: try:
@ -109,9 +112,7 @@ class AzureCloudProvider(CloudProviderInterface):
exc_info=1, exc_info=1,
) )
def create_environment( def create_environment(self, auth_credentials: Dict, user, environment):
self, auth_credentials: Dict, user: User, environment: Environment
):
# since this operation would only occur within a tenant, should we source the tenant # since this operation would only occur within a tenant, should we source the tenant
# via lookup from environment once we've created the portfolio csp data schema # via lookup from environment once we've created the portfolio csp data schema
# something like this: # something like this:
@ -128,7 +129,7 @@ class AzureCloudProvider(CloudProviderInterface):
credentials, management_group_id, display_name, parent_id, credentials, management_group_id, display_name, parent_id,
) )
return management_group return ManagementGroupCSPResponse(**management_group)
def create_atat_admin_user( def create_atat_admin_user(
self, auth_credentials: Dict, csp_environment_id: str self, auth_credentials: Dict, csp_environment_id: str
@ -167,16 +168,26 @@ class AzureCloudProvider(CloudProviderInterface):
"role_name": role_assignment_id, "role_name": role_assignment_id,
} }
def _create_application(self, auth_credentials: Dict, application: Application): def create_application(self, payload: ApplicationCSPPayload):
management_group_name = str(uuid4()) # can be anything, not just uuid creds = self._source_creds(payload.tenant_id)
display_name = application.name # Does this need to be unique? credentials = self._get_credential_obj(
credentials = self._get_credential_obj(auth_credentials) {
parent_id = "?" # application.portfolio.csp_details.management_group_id "client_id": creds.root_sp_client_id,
"secret_key": creds.root_sp_key,
return self._create_management_group( "tenant_id": creds.root_tenant_id,
credentials, management_group_name, display_name, parent_id, },
resource=AZURE_MANAGEMENT_API,
) )
response = self._create_management_group(
credentials,
payload.management_group_name,
payload.display_name,
payload.parent_id,
)
return ApplicationCSPResult(**response)
def _create_management_group( def _create_management_group(
self, credentials, management_group_id, display_name, parent_id=None, self, credentials, management_group_id, display_name, parent_id=None,
): ):
@ -198,6 +209,9 @@ class AzureCloudProvider(CloudProviderInterface):
# result is a synchronous wait, might need to do a poll instead to handle first mgmt group create # result is a synchronous wait, might need to do a poll instead to handle first mgmt group create
# since we were told it could take 10+ minutes to complete, unless this handles that polling internally # since we were told it could take 10+ minutes to complete, unless this handles that polling internally
# TODO: what to do is status is not 'Succeeded' on the
# response object? Will it always raise its own error
# instead?
return create_request.result() return create_request.result()
def _create_subscription( def _create_subscription(
@ -290,6 +304,7 @@ class AzureCloudProvider(CloudProviderInterface):
sp_token = self._get_sp_token(payload.creds) sp_token = self._get_sp_token(payload.creds)
if sp_token is None: if sp_token is None:
raise AuthenticationException("Could not resolve token for tenant creation") raise AuthenticationException("Could not resolve token for tenant creation")
payload.password = token_urlsafe(16) payload.password = token_urlsafe(16)
create_tenant_body = payload.dict(by_alias=True) create_tenant_body = payload.dict(by_alias=True)
@ -626,3 +641,24 @@ class AzureCloudProvider(CloudProviderInterface):
"secret_key": self.secret_key, "secret_key": self.secret_key,
"tenant_id": self.tenant_id, "tenant_id": self.tenant_id,
} }
def _source_creds(self, tenant_id=None) -> KeyVaultCredentials:
if tenant_id:
return self._source_tenant_creds(tenant_id)
else:
return KeyVaultCredentials(
root_tenant_id=self._root_creds.get("tenant_id"),
root_sp_client_id=self._root_creds.get("client_id"),
root_sp_key=self._root_creds.get("secret_key"),
)
def update_tenant_creds(self, tenant_id, secret):
hashed = sha256_hex(tenant_id)
self.set_secret(hashed, json.dumps(secret))
return secret
def _source_tenant_creds(self, tenant_id):
hashed = sha256_hex(tenant_id)
raw_creds = self.get_secret(hashed)
return KeyVaultCredentials(**json.loads(raw_creds))

View File

@ -1,9 +1,5 @@
from typing import Dict from typing import Dict
from atst.models.user import User
from atst.models.environment import Environment
from atst.models.environment_role import EnvironmentRole
class CloudProviderInterface: class CloudProviderInterface:
def set_secret(self, secret_key: str, secret_value: str): def set_secret(self, secret_key: str, secret_value: str):
@ -15,9 +11,7 @@ class CloudProviderInterface:
def root_creds(self) -> Dict: def root_creds(self) -> Dict:
raise NotImplementedError() raise NotImplementedError()
def create_environment( def create_environment(self, auth_credentials: Dict, user, environment) -> str:
self, auth_credentials: Dict, user: User, environment: Environment
) -> str:
"""Create a new environment in the CSP. """Create a new environment in the CSP.
Arguments: Arguments:
@ -65,7 +59,7 @@ class CloudProviderInterface:
raise NotImplementedError() raise NotImplementedError()
def create_or_update_user( def create_or_update_user(
self, auth_credentials: Dict, user_info: EnvironmentRole, csp_role_id: str self, auth_credentials: Dict, user_info, csp_role_id: str
) -> str: ) -> str:
"""Creates a user or updates an existing user's role. """Creates a user or updates an existing user's role.

View File

@ -17,6 +17,9 @@ from .exceptions import (
UnknownServerException, UnknownServerException,
) )
from .models import ( from .models import (
AZURE_MGMNT_PATH,
ApplicationCSPPayload,
ApplicationCSPResult,
BillingInstructionCSPPayload, BillingInstructionCSPPayload,
BillingInstructionCSPResult, BillingInstructionCSPResult,
BillingProfileCreationCSPPayload, BillingProfileCreationCSPPayload,
@ -340,3 +343,16 @@ class MockCloudProvider(CloudProviderInterface):
self._delay(1, 5) self._delay(1, 5)
if self._with_authorization and credentials != self._auth_credentials: if self._with_authorization and credentials != self._auth_credentials:
raise self.AUTHENTICATION_EXCEPTION raise self.AUTHENTICATION_EXCEPTION
def create_application(self, payload: ApplicationCSPPayload):
self._maybe_raise(self.UNAUTHORIZED_RATE, GeneralCSPException)
return ApplicationCSPResult(
id=f"{AZURE_MGMNT_PATH}{payload.management_group_name}"
)
def get_credentials(self, scope="portfolio", tenant_id=None):
return self.root_creds()
def update_tenant_creds(self, tenant_id, secret):
return secret

View File

@ -1,6 +1,8 @@
from typing import Dict, List, Optional from typing import Dict, List, Optional
import re
from uuid import uuid4
from pydantic import BaseModel, validator from pydantic import BaseModel, validator, root_validator
from atst.utils import snake_to_camel from atst.utils import snake_to_camel
@ -232,3 +234,110 @@ class BillingInstructionCSPResult(AliasModel):
fields = { fields = {
"reported_clin_name": "name", "reported_clin_name": "name",
} }
AZURE_MGMNT_PATH = "/providers/Microsoft.Management/managementGroups/"
MANAGEMENT_GROUP_NAME_REGEX = "^[a-zA-Z0-9\-_\(\)\.]+$"
class ManagementGroupCSPPayload(AliasModel):
"""
:param: management_group_name: Just pass a UUID for this.
:param: display_name: This can contain any character and
spaces, but should be 90 characters or fewer long.
:param: parent_id: This should be the fully qualified Azure ID,
i.e. /providers/Microsoft.Management/managementGroups/[management group ID]
"""
tenant_id: str
management_group_name: Optional[str]
display_name: str
parent_id: str
@validator("management_group_name", pre=True, always=True)
def supply_management_group_name_default(cls, name):
if name:
if re.match(MANAGEMENT_GROUP_NAME_REGEX, name) is None:
raise ValueError(
f"Management group name must match {MANAGEMENT_GROUP_NAME_REGEX}"
)
return name[0:90]
else:
return str(uuid4())
@validator("display_name", pre=True, always=True)
def enforce_display_name_length(cls, name):
return name[0:90]
@validator("parent_id", pre=True, always=True)
def enforce_parent_id_pattern(cls, id_):
if AZURE_MGMNT_PATH not in id_:
return f"{AZURE_MGMNT_PATH}{id_}"
else:
return id_
class ManagementGroupCSPResponse(AliasModel):
id: str
class ApplicationCSPPayload(ManagementGroupCSPPayload):
pass
class ApplicationCSPResult(ManagementGroupCSPResponse):
pass
class KeyVaultCredentials(BaseModel):
root_sp_client_id: Optional[str]
root_sp_key: Optional[str]
root_tenant_id: Optional[str]
tenant_id: Optional[str]
tenant_admin_username: Optional[str]
tenant_admin_password: Optional[str]
tenant_sp_client_id: Optional[str]
tenant_sp_key: Optional[str]
@root_validator(pre=True)
def enforce_admin_creds(cls, values):
tenant_id = values.get("tenant_id")
username = values.get("tenant_admin_username")
password = values.get("tenant_admin_password")
if any([username, password]) and not all([tenant_id, username, password]):
raise ValueError(
"tenant_id, tenant_admin_username, and tenant_admin_password must all be set if any one is"
)
return values
@root_validator(pre=True)
def enforce_sp_creds(cls, values):
tenant_id = values.get("tenant_id")
client_id = values.get("tenant_sp_client_id")
key = values.get("tenant_sp_key")
if any([client_id, key]) and not all([tenant_id, client_id, key]):
raise ValueError(
"tenant_id, tenant_sp_client_id, and tenant_sp_key must all be set if any one is"
)
return values
@root_validator(pre=True)
def enforce_root_creds(cls, values):
sp_creds = [
values.get("root_tenant_id"),
values.get("root_sp_client_id"),
values.get("root_sp_key"),
]
if any(sp_creds) and not all(sp_creds):
raise ValueError(
"root_tenant_id, root_sp_client_id, and root_sp_key must all be set if any one is"
)
return values

View File

@ -3,47 +3,38 @@ import pendulum
from atst.database import db from atst.database import db
from atst.queue import celery from atst.queue import celery
from atst.models import ( from atst.models import EnvironmentRole, JobFailure
EnvironmentJobFailure,
EnvironmentRoleJobFailure,
EnvironmentRole,
PortfolioJobFailure,
)
from atst.domain.csp.cloud.exceptions import GeneralCSPException from atst.domain.csp.cloud.exceptions import GeneralCSPException
from atst.domain.csp.cloud import CloudProviderInterface from atst.domain.csp.cloud import CloudProviderInterface
from atst.domain.applications import Applications
from atst.domain.environments import Environments from atst.domain.environments import Environments
from atst.domain.portfolios import Portfolios from atst.domain.portfolios import Portfolios
from atst.domain.environment_roles import EnvironmentRoles from atst.domain.environment_roles import EnvironmentRoles
from atst.models.utils import claim_for_update from atst.models.utils import claim_for_update
from atst.utils.localization import translate from atst.utils.localization import translate
from atst.domain.csp.cloud.models import ApplicationCSPPayload
class RecordPortfolioFailure(celery.Task): class RecordFailure(celery.Task):
_ENTITIES = [
"portfolio_id",
"application_id",
"environment_id",
"environment_role_id",
]
def _derive_entity_info(self, kwargs):
matches = [e for e in self._ENTITIES if e in kwargs.keys()]
if matches:
match = matches[0]
return {"entity": match.replace("_id", ""), "entity_id": kwargs[match]}
else:
return None
def on_failure(self, exc, task_id, args, kwargs, einfo): def on_failure(self, exc, task_id, args, kwargs, einfo):
if "portfolio_id" in kwargs: info = self._derive_entity_info(kwargs)
failure = PortfolioJobFailure( if info:
portfolio_id=kwargs["portfolio_id"], task_id=task_id failure = JobFailure(**info, task_id=task_id)
)
db.session.add(failure)
db.session.commit()
class RecordEnvironmentFailure(celery.Task):
def on_failure(self, exc, task_id, args, kwargs, einfo):
if "environment_id" in kwargs:
failure = EnvironmentJobFailure(
environment_id=kwargs["environment_id"], task_id=task_id
)
db.session.add(failure)
db.session.commit()
class RecordEnvironmentRoleFailure(celery.Task):
def on_failure(self, exc, task_id, args, kwargs, einfo):
if "environment_role_id" in kwargs:
failure = EnvironmentRoleJobFailure(
environment_role_id=kwargs["environment_role_id"], task_id=task_id
)
db.session.add(failure) db.session.add(failure)
db.session.commit() db.session.commit()
@ -63,6 +54,27 @@ def send_notification_mail(recipients, subject, body):
app.mailer.send(recipients, subject, body) app.mailer.send(recipients, subject, body)
def do_create_application(csp: CloudProviderInterface, application_id=None):
application = Applications.get(application_id)
with claim_for_update(application) as application:
if application.cloud_id:
return
csp_details = application.portfolio.csp_data
parent_id = csp_details.get("root_management_group_id")
tenant_id = csp_details.get("tenant_id")
payload = ApplicationCSPPayload(
tenant_id=tenant_id, display_name=application.name, parent_id=parent_id
)
app_result = csp.create_application(payload)
application.cloud_id = app_result.id
db.session.add(application)
db.session.commit()
def do_create_environment(csp: CloudProviderInterface, environment_id=None): def do_create_environment(csp: CloudProviderInterface, environment_id=None):
environment = Environments.get(environment_id) environment = Environments.get(environment_id)
@ -144,17 +156,22 @@ def do_provision_portfolio(csp: CloudProviderInterface, portfolio_id=None):
fsm.trigger_next_transition() fsm.trigger_next_transition()
@celery.task(bind=True, base=RecordPortfolioFailure) @celery.task(bind=True, base=RecordFailure)
def provision_portfolio(self, portfolio_id=None): def provision_portfolio(self, portfolio_id=None):
do_work(do_provision_portfolio, self, app.csp.cloud, portfolio_id=portfolio_id) do_work(do_provision_portfolio, self, app.csp.cloud, portfolio_id=portfolio_id)
@celery.task(bind=True, base=RecordEnvironmentFailure) @celery.task(bind=True, base=RecordFailure)
def create_application(self, application_id=None):
do_work(do_create_application, self, app.csp.cloud, application_id=application_id)
@celery.task(bind=True, base=RecordFailure)
def create_environment(self, environment_id=None): def create_environment(self, environment_id=None):
do_work(do_create_environment, self, app.csp.cloud, environment_id=environment_id) do_work(do_create_environment, self, app.csp.cloud, environment_id=environment_id)
@celery.task(bind=True, base=RecordEnvironmentFailure) @celery.task(bind=True, base=RecordFailure)
def create_atat_admin_user(self, environment_id=None): def create_atat_admin_user(self, environment_id=None):
do_work( do_work(
do_create_atat_admin_user, self, app.csp.cloud, environment_id=environment_id do_create_atat_admin_user, self, app.csp.cloud, environment_id=environment_id
@ -177,6 +194,12 @@ def dispatch_provision_portfolio(self):
provision_portfolio.delay(portfolio_id=portfolio_id) provision_portfolio.delay(portfolio_id=portfolio_id)
@celery.task(bind=True)
def dispatch_create_application(self):
for application_id in Applications.get_applications_pending_creation():
create_application.delay(application_id=application_id)
@celery.task(bind=True) @celery.task(bind=True)
def dispatch_create_environment(self): def dispatch_create_environment(self):
for environment_id in Environments.get_environments_pending_creation( for environment_id in Environments.get_environments_pending_creation(

View File

@ -7,11 +7,7 @@ from .audit_event import AuditEvent
from .clin import CLIN, JEDICLINType from .clin import CLIN, JEDICLINType
from .environment import Environment from .environment import Environment
from .environment_role import EnvironmentRole, CSPRole from .environment_role import EnvironmentRole, CSPRole
from .job_failure import ( from .job_failure import JobFailure
EnvironmentJobFailure,
EnvironmentRoleJobFailure,
PortfolioJobFailure,
)
from .notification_recipient import NotificationRecipient from .notification_recipient import NotificationRecipient
from .permissions import Permissions from .permissions import Permissions
from .permission_set import PermissionSet from .permission_set import PermissionSet

View File

@ -1,4 +1,4 @@
from sqlalchemy import and_, Column, ForeignKey, String, UniqueConstraint from sqlalchemy import and_, Column, ForeignKey, String, UniqueConstraint, TIMESTAMP
from sqlalchemy.orm import relationship, synonym from sqlalchemy.orm import relationship, synonym
from atst.models.base import Base from atst.models.base import Base
@ -40,6 +40,9 @@ class Application(
), ),
) )
cloud_id = Column(String)
claimed_until = Column(TIMESTAMP(timezone=True))
@property @property
def users(self): def users(self):
return set(role.user for role in self.members) return set(role.user for role in self.members)

View File

@ -30,8 +30,6 @@ class Environment(
claimed_until = Column(TIMESTAMP(timezone=True)) claimed_until = Column(TIMESTAMP(timezone=True))
job_failures = relationship("EnvironmentJobFailure")
roles = relationship( roles = relationship(
"EnvironmentRole", "EnvironmentRole",
back_populates="environment", back_populates="environment",

View File

@ -32,8 +32,6 @@ class EnvironmentRole(
) )
application_role = relationship("ApplicationRole") application_role = relationship("ApplicationRole")
job_failures = relationship("EnvironmentRoleJobFailure")
csp_user_id = Column(String()) csp_user_id = Column(String())
claimed_until = Column(TIMESTAMP(timezone=True)) claimed_until = Column(TIMESTAMP(timezone=True))

View File

@ -1,22 +1,21 @@
from sqlalchemy import Column, ForeignKey from celery.result import AsyncResult
from sqlalchemy import Column, String, Integer
from atst.models.base import Base from atst.models.base import Base
import atst.models.mixins as mixins import atst.models.mixins as mixins
class EnvironmentJobFailure(Base, mixins.JobFailureMixin): class JobFailure(Base, mixins.TimestampsMixin):
__tablename__ = "environment_job_failures" __tablename__ = "job_failures"
environment_id = Column(ForeignKey("environments.id"), nullable=False) id = Column(Integer(), primary_key=True)
task_id = Column(String(), nullable=False)
entity = Column(String(), nullable=False)
entity_id = Column(String(), nullable=False)
@property
def task(self):
if not hasattr(self, "_task"):
self._task = AsyncResult(self.task_id)
class EnvironmentRoleJobFailure(Base, mixins.JobFailureMixin): return self._task
__tablename__ = "environment_role_job_failures"
environment_role_id = Column(ForeignKey("environment_roles.id"), nullable=False)
class PortfolioJobFailure(Base, mixins.JobFailureMixin):
__tablename__ = "portfolio_job_failures"
portfolio_id = Column(ForeignKey("portfolios.id"), nullable=False)

View File

@ -3,5 +3,4 @@ from .auditable import AuditableMixin
from .permissions import PermissionsMixin from .permissions import PermissionsMixin
from .deletable import DeletableMixin from .deletable import DeletableMixin
from .invites import InvitesMixin from .invites import InvitesMixin
from .job_failure import JobFailureMixin
from .state_machines import FSMMixin from .state_machines import FSMMixin

View File

@ -1,14 +0,0 @@
from celery.result import AsyncResult
from sqlalchemy import Column, String, Integer
class JobFailureMixin(object):
id = Column(Integer(), primary_key=True)
task_id = Column(String(), nullable=False)
@property
def task(self):
if not hasattr(self, "_task"):
self._task = AsyncResult(self.task_id)
return self._task

View File

@ -175,7 +175,7 @@ class PortfolioStateMachine(
tenant_id = new_creds.get("tenant_id") tenant_id = new_creds.get("tenant_id")
secret = self.csp.get_secret(tenant_id, new_creds) secret = self.csp.get_secret(tenant_id, new_creds)
secret.update(new_creds) secret.update(new_creds)
self.csp.set_secret(tenant_id, secret) self.csp.update_tenant_creds(tenant_id, secret)
except PydanticValidationError as exc: except PydanticValidationError as exc:
app.logger.error( app.logger.error(
f"Failed to cast response to valid result class {self.__repr__()}:", f"Failed to cast response to valid result class {self.__repr__()}:",

View File

@ -11,6 +11,10 @@ def update_celery(celery, app):
"task": "atst.jobs.dispatch_provision_portfolio", "task": "atst.jobs.dispatch_provision_portfolio",
"schedule": 60, "schedule": 60,
}, },
"beat-dispatch_create_application": {
"task": "atst.jobs.dispatch_create_application",
"schedule": 60,
},
"beat-dispatch_create_environment": { "beat-dispatch_create_environment": {
"task": "atst.jobs.dispatch_create_environment", "task": "atst.jobs.dispatch_create_environment",
"schedule": 60, "schedule": 60,

View File

@ -1,3 +1,4 @@
import hashlib
import re import re
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
@ -41,3 +42,8 @@ def commit_or_raise_already_exists_error(message):
except IntegrityError: except IntegrityError:
db.session.rollback() db.session.rollback()
raise AlreadyExistsError(message) raise AlreadyExistsError(message)
def sha256_hex(string):
hsh = hashlib.sha256(string.encode())
return hsh.digest().hex()

View File

@ -1,11 +1,15 @@
from unittest.mock import Mock import pytest
import json
from uuid import uuid4 from uuid import uuid4
from unittest.mock import Mock, patch
from tests.factories import ApplicationFactory, EnvironmentFactory from tests.factories import ApplicationFactory, EnvironmentFactory
from tests.mock_azure import AUTH_CREDENTIALS, mock_azure from tests.mock_azure import AUTH_CREDENTIALS, mock_azure
from atst.domain.csp.cloud import AzureCloudProvider from atst.domain.csp.cloud import AzureCloudProvider
from atst.domain.csp.cloud.models import ( from atst.domain.csp.cloud.models import (
ApplicationCSPPayload,
ApplicationCSPResult,
BillingInstructionCSPPayload, BillingInstructionCSPPayload,
BillingInstructionCSPResult, BillingInstructionCSPResult,
BillingProfileCreationCSPPayload, BillingProfileCreationCSPPayload,
@ -65,8 +69,8 @@ def test_create_subscription_succeeds(mock_azure: AzureCloudProvider):
def mock_management_group_create(mock_azure, spec_dict): def mock_management_group_create(mock_azure, spec_dict):
mock_azure.sdk.managementgroups.ManagementGroupsAPI.return_value.management_groups.create_or_update.return_value.result.return_value = Mock( mock_azure.sdk.managementgroups.ManagementGroupsAPI.return_value.management_groups.create_or_update.return_value.result.return_value = (
**spec_dict spec_dict
) )
@ -82,12 +86,30 @@ def test_create_environment_succeeds(mock_azure: AzureCloudProvider):
assert result.id == "Test Id" assert result.id == "Test Id"
# mock the get_secret so it returns a JSON string
MOCK_CREDS = {
"tenant_id": str(uuid4()),
"tenant_sp_client_id": str(uuid4()),
"tenant_sp_key": "1234",
}
def mock_get_secret(azure, func):
azure.get_secret = func
return azure
def test_create_application_succeeds(mock_azure: AzureCloudProvider): def test_create_application_succeeds(mock_azure: AzureCloudProvider):
application = ApplicationFactory.create() application = ApplicationFactory.create()
mock_management_group_create(mock_azure, {"id": "Test Id"}) mock_management_group_create(mock_azure, {"id": "Test Id"})
result = mock_azure._create_application(AUTH_CREDENTIALS, application) mock_azure = mock_get_secret(mock_azure, lambda *a, **k: json.dumps(MOCK_CREDS))
payload = ApplicationCSPPayload(
tenant_id="1234", display_name=application.name, parent_id=str(uuid4())
)
result = mock_azure.create_application(payload)
assert result.id == "Test Id" assert result.id == "Test Id"

View File

@ -0,0 +1,99 @@
import pytest
from pydantic import ValidationError
from atst.domain.csp.cloud.models import (
AZURE_MGMNT_PATH,
KeyVaultCredentials,
ManagementGroupCSPPayload,
ManagementGroupCSPResponse,
)
def test_ManagementGroupCSPPayload_management_group_name():
# supplies management_group_name when absent
payload = ManagementGroupCSPPayload(
tenant_id="any-old-id",
display_name="Council of Naboo",
parent_id="Galactic_Senate",
)
assert payload.management_group_name
# validates management_group_name
with pytest.raises(ValidationError):
payload = ManagementGroupCSPPayload(
tenant_id="any-old-id",
management_group_name="council of Naboo 1%^&",
display_name="Council of Naboo",
parent_id="Galactic_Senate",
)
# shortens management_group_name to fit
name = "council_of_naboo".ljust(95, "1")
assert len(name) > 90
payload = ManagementGroupCSPPayload(
tenant_id="any-old-id",
management_group_name=name,
display_name="Council of Naboo",
parent_id="Galactic_Senate",
)
assert len(payload.management_group_name) == 90
def test_ManagementGroupCSPPayload_display_name():
# shortens display_name to fit
name = "Council of Naboo".ljust(95, "1")
assert len(name) > 90
payload = ManagementGroupCSPPayload(
tenant_id="any-old-id", display_name=name, parent_id="Galactic_Senate"
)
assert len(payload.display_name) == 90
def test_ManagementGroupCSPPayload_parent_id():
full_path = f"{AZURE_MGMNT_PATH}Galactic_Senate"
# adds full path
payload = ManagementGroupCSPPayload(
tenant_id="any-old-id",
display_name="Council of Naboo",
parent_id="Galactic_Senate",
)
assert payload.parent_id == full_path
# keeps full path
payload = ManagementGroupCSPPayload(
tenant_id="any-old-id", display_name="Council of Naboo", parent_id=full_path
)
assert payload.parent_id == full_path
def test_ManagementGroupCSPResponse_id():
full_id = "/path/to/naboo-123"
response = ManagementGroupCSPResponse(
**{"id": "/path/to/naboo-123", "other": "stuff"}
)
assert response.id == full_id
def test_KeyVaultCredentials_enforce_admin_creds():
with pytest.raises(ValidationError):
KeyVaultCredentials(tenant_id="an id", tenant_admin_username="C3PO")
assert KeyVaultCredentials(
tenant_id="an id",
tenant_admin_username="C3PO",
tenant_admin_password="beep boop",
)
def test_KeyVaultCredentials_enforce_sp_creds():
with pytest.raises(ValidationError):
KeyVaultCredentials(tenant_id="an id", tenant_sp_client_id="C3PO")
assert KeyVaultCredentials(
tenant_id="an id", tenant_sp_client_id="C3PO", tenant_sp_key="beep boop"
)
def test_KeyVaultCredentials_enforce_root_creds():
with pytest.raises(ValidationError):
KeyVaultCredentials(root_tenant_id="an id", root_sp_client_id="C3PO")
assert KeyVaultCredentials(
root_tenant_id="an id", root_sp_client_id="C3PO", root_sp_key="beep boop"
)

View File

@ -1,3 +1,4 @@
from datetime import datetime, timedelta
import pytest import pytest
from uuid import uuid4 from uuid import uuid4
@ -196,3 +197,20 @@ def test_update_does_not_duplicate_names_within_portfolio():
with pytest.raises(AlreadyExistsError): with pytest.raises(AlreadyExistsError):
Applications.update(dupe_application, {"name": name}) Applications.update(dupe_application, {"name": name})
def test_get_applications_pending_creation():
now = datetime.now()
later = now + timedelta(minutes=30)
portfolio1 = PortfolioFactory.create(state="COMPLETED")
app_ready = ApplicationFactory.create(portfolio=portfolio1)
app_done = ApplicationFactory.create(portfolio=portfolio1, cloud_id="123456")
portfolio2 = PortfolioFactory.create(state="UNSTARTED")
app_not_ready = ApplicationFactory.create(portfolio=portfolio2)
uuids = Applications.get_applications_pending_creation()
assert [app_ready.id] == uuids

View File

@ -7,6 +7,7 @@ import datetime
from atst.forms import data from atst.forms import data
from atst.models import * from atst.models import *
from atst.models.mixins.state_machines import FSMStates
from atst.domain.invitations import PortfolioInvitations from atst.domain.invitations import PortfolioInvitations
from atst.domain.permission_sets import PermissionSets from atst.domain.permission_sets import PermissionSets
@ -121,6 +122,7 @@ class PortfolioFactory(Base):
owner = kwargs.pop("owner", UserFactory.create()) owner = kwargs.pop("owner", UserFactory.create())
members = kwargs.pop("members", []) members = kwargs.pop("members", [])
with_task_orders = kwargs.pop("task_orders", []) with_task_orders = kwargs.pop("task_orders", [])
state = kwargs.pop("state", None)
portfolio = super()._create(model_class, *args, **kwargs) portfolio = super()._create(model_class, *args, **kwargs)
@ -161,6 +163,12 @@ class PortfolioFactory(Base):
permission_sets=perms_set, permission_sets=perms_set,
) )
if state:
state = getattr(FSMStates, state)
fsm = PortfolioStateMachineFactory.create(state=state, portfolio=portfolio)
# setting it in the factory is not working for some reason
fsm.state = state
portfolio.applications = applications portfolio.applications = applications
portfolio.task_orders = task_orders portfolio.task_orders = task_orders
return portfolio return portfolio

View File

@ -72,6 +72,12 @@ def mock_secrets():
return Mock(spec=secrets) return Mock(spec=secrets)
def mock_identity():
import azure.identity as identity
return Mock(spec=identity)
class MockAzureSDK(object): class MockAzureSDK(object):
def __init__(self): def __init__(self):
from msrestazure.azure_cloud import AZURE_PUBLIC_CLOUD from msrestazure.azure_cloud import AZURE_PUBLIC_CLOUD
@ -88,6 +94,7 @@ class MockAzureSDK(object):
self.requests = mock_requests() self.requests = mock_requests()
# may change to a JEDI cloud # may change to a JEDI cloud
self.cloud = AZURE_PUBLIC_CLOUD self.cloud = AZURE_PUBLIC_CLOUD
self.identity = mock_identity()
@pytest.fixture(scope="function") @pytest.fixture(scope="function")

View File

@ -8,9 +8,9 @@ from atst.domain.csp.cloud import MockCloudProvider
from atst.domain.portfolios import Portfolios from atst.domain.portfolios import Portfolios
from atst.jobs import ( from atst.jobs import (
RecordEnvironmentFailure, RecordFailure,
RecordEnvironmentRoleFailure,
dispatch_create_environment, dispatch_create_environment,
dispatch_create_application,
dispatch_create_atat_admin_user, dispatch_create_atat_admin_user,
dispatch_provision_portfolio, dispatch_provision_portfolio,
dispatch_provision_user, dispatch_provision_user,
@ -18,6 +18,7 @@ from atst.jobs import (
do_provision_user, do_provision_user,
do_provision_portfolio, do_provision_portfolio,
do_create_environment, do_create_environment,
do_create_application,
do_create_atat_admin_user, do_create_atat_admin_user,
) )
from atst.models.utils import claim_for_update from atst.models.utils import claim_for_update
@ -27,9 +28,10 @@ from tests.factories import (
EnvironmentRoleFactory, EnvironmentRoleFactory,
PortfolioFactory, PortfolioFactory,
PortfolioStateMachineFactory, PortfolioStateMachineFactory,
ApplicationFactory,
ApplicationRoleFactory, ApplicationRoleFactory,
) )
from atst.models import CSPRole, EnvironmentRole, ApplicationRoleStatus from atst.models import CSPRole, EnvironmentRole, ApplicationRoleStatus, JobFailure
@pytest.fixture(autouse=True, scope="function") @pytest.fixture(autouse=True, scope="function")
@ -43,8 +45,17 @@ def portfolio():
return portfolio return portfolio
def test_environment_job_failure(celery_app, celery_worker): def _find_failure(session, entity, id_):
@celery_app.task(bind=True, base=RecordEnvironmentFailure) return (
session.query(JobFailure)
.filter(JobFailure.entity == entity)
.filter(JobFailure.entity_id == id_)
.one()
)
def test_environment_job_failure(session, celery_app, celery_worker):
@celery_app.task(bind=True, base=RecordFailure)
def _fail_hard(self, environment_id=None): def _fail_hard(self, environment_id=None):
raise ValueError("something bad happened") raise ValueError("something bad happened")
@ -56,13 +67,12 @@ def test_environment_job_failure(celery_app, celery_worker):
with pytest.raises(ValueError): with pytest.raises(ValueError):
task.get() task.get()
assert environment.job_failures job_failure = _find_failure(session, "environment", str(environment.id))
job_failure = environment.job_failures[0]
assert job_failure.task == task assert job_failure.task == task
def test_environment_role_job_failure(celery_app, celery_worker): def test_environment_role_job_failure(session, celery_app, celery_worker):
@celery_app.task(bind=True, base=RecordEnvironmentRoleFailure) @celery_app.task(bind=True, base=RecordFailure)
def _fail_hard(self, environment_role_id=None): def _fail_hard(self, environment_role_id=None):
raise ValueError("something bad happened") raise ValueError("something bad happened")
@ -74,8 +84,7 @@ def test_environment_role_job_failure(celery_app, celery_worker):
with pytest.raises(ValueError): with pytest.raises(ValueError):
task.get() task.get()
assert role.job_failures job_failure = _find_failure(session, "environment_role", str(role.id))
job_failure = role.job_failures[0]
assert job_failure.task == task assert job_failure.task == task
@ -99,6 +108,24 @@ def test_create_environment_job_is_idempotent(csp, session):
csp.create_environment.assert_not_called() csp.create_environment.assert_not_called()
def test_create_application_job(session, csp):
portfolio = PortfolioFactory.create(
csp_data={"tenant_id": str(uuid4()), "root_management_group_id": str(uuid4())}
)
application = ApplicationFactory.create(portfolio=portfolio, cloud_id=None)
do_create_application(csp, application.id)
session.refresh(application)
assert application.cloud_id
def test_create_application_job_is_idempotent(csp):
application = ApplicationFactory.create(cloud_id=uuid4())
do_create_application(csp, application.id)
csp.create_application.assert_not_called()
def test_create_atat_admin_user(csp, session): def test_create_atat_admin_user(csp, session):
environment = EnvironmentFactory.create(cloud_id="something") environment = EnvironmentFactory.create(cloud_id="something")
do_create_atat_admin_user(csp, environment.id) do_create_atat_admin_user(csp, environment.id)
@ -139,6 +166,21 @@ def test_dispatch_create_environment(session, monkeypatch):
mock.delay.assert_called_once_with(environment_id=e1.id) mock.delay.assert_called_once_with(environment_id=e1.id)
def test_dispatch_create_application(monkeypatch):
portfolio = PortfolioFactory.create(state="COMPLETED")
app = ApplicationFactory.create(portfolio=portfolio)
mock = Mock()
monkeypatch.setattr("atst.jobs.create_application", mock)
# When dispatch_create_application is called
dispatch_create_application.run()
# It should cause the create_application task to be called once
# with the application id
mock.delay.assert_called_once_with(application_id=app.id)
def test_dispatch_create_atat_admin_user(session, monkeypatch): def test_dispatch_create_atat_admin_user(session, monkeypatch):
portfolio = PortfolioFactory.create( portfolio = PortfolioFactory.create(
applications=[ applications=[

16
tests/utils/test_hash.py Normal file
View File

@ -0,0 +1,16 @@
import random
import re
import string
from atst.utils import sha256_hex
def test_sha256_hex():
sample = "".join(
random.choices(string.ascii_uppercase + string.digits, k=random.randrange(200))
)
hashed = sha256_hex(sample)
assert re.match("^[a-zA-Z0-9]+$", hashed)
assert len(hashed) == 64
hashed_again = sha256_hex(sample)
assert hashed == hashed_again