diff --git a/alembic/versions/07e0598199f6_add_applications_claimed_until.py b/alembic/versions/07e0598199f6_add_applications_claimed_until.py new file mode 100644 index 00000000..9c5d3abc --- /dev/null +++ b/alembic/versions/07e0598199f6_add_applications_claimed_until.py @@ -0,0 +1,29 @@ +"""add applications.claimed_until + +Revision ID: 07e0598199f6 +Revises: 26319c44a8d5 +Create Date: 2020-01-25 13:33:17.711548 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = '07e0598199f6' # pragma: allowlist secret +down_revision = '26319c44a8d5' # pragma: allowlist secret +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('applications', sa.Column('claimed_until', sa.TIMESTAMP(timezone=True), nullable=True)) + op.add_column('applications', sa.Column('cloud_id', sa.String(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('applications', 'claimed_until') + op.drop_column('applications', 'cloud_id') + # ### end Alembic commands ### diff --git a/alembic/versions/508957112ed6_combine_job_failures.py b/alembic/versions/508957112ed6_combine_job_failures.py new file mode 100644 index 00000000..9d40bb12 --- /dev/null +++ b/alembic/versions/508957112ed6_combine_job_failures.py @@ -0,0 +1,60 @@ +"""combine job failures + +Revision ID: 508957112ed6 +Revises: 07e0598199f6 +Create Date: 2020-01-25 15:03:06.377442 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '508957112ed6' # pragma: allowlist secret +down_revision = '07e0598199f6' # pragma: allowlist secret +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('job_failures', + sa.Column('time_created', sa.TIMESTAMP(timezone=True), server_default=sa.text('now()'), nullable=False), + sa.Column('time_updated', sa.TIMESTAMP(timezone=True), server_default=sa.text('now()'), nullable=False), + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('task_id', sa.String(), nullable=False), + sa.Column('entity', sa.String(), nullable=False), + sa.Column('entity_id', sa.String(), nullable=False), + sa.PrimaryKeyConstraint('id') + ) + op.drop_table('environment_job_failures') + op.drop_table('environment_role_job_failures') + op.drop_table('portfolio_job_failures') + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('portfolio_job_failures', + sa.Column('id', sa.INTEGER(), autoincrement=True, nullable=False), + sa.Column('task_id', sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column('portfolio_id', postgresql.UUID(), autoincrement=False, nullable=False), + sa.ForeignKeyConstraint(['portfolio_id'], ['portfolios.id'], name='portfolio_job_failures_portfolio_id_fkey'), + sa.PrimaryKeyConstraint('id', name='portfolio_job_failures_pkey') + ) + op.create_table('environment_role_job_failures', + sa.Column('id', sa.INTEGER(), autoincrement=True, nullable=False), + sa.Column('task_id', sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column('environment_role_id', postgresql.UUID(), autoincrement=False, nullable=False), + sa.ForeignKeyConstraint(['environment_role_id'], ['environment_roles.id'], name='environment_role_job_failures_environment_role_id_fkey'), + sa.PrimaryKeyConstraint('id', name='environment_role_job_failures_pkey') + ) + op.create_table('environment_job_failures', + sa.Column('id', sa.INTEGER(), autoincrement=True, nullable=False), + sa.Column('task_id', sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column('environment_id', postgresql.UUID(), autoincrement=False, nullable=False), + sa.ForeignKeyConstraint(['environment_id'], ['environments.id'], name='environment_job_failures_environment_id_fkey'), + sa.PrimaryKeyConstraint('id', name='environment_job_failures_pkey') + ) + op.drop_table('job_failures') + # ### end Alembic commands ### diff --git a/atst/domain/applications.py b/atst/domain/applications.py index 3dbb9953..b9df260e 100644 --- a/atst/domain/applications.py +++ b/atst/domain/applications.py @@ -1,5 +1,9 @@ -from . import BaseDomainClass from flask import g +from sqlalchemy import func, or_ +from typing import List +from uuid import UUID + +from . import BaseDomainClass from atst.database import db from atst.domain.application_roles import ApplicationRoles from atst.domain.environments import Environments @@ -10,7 +14,10 @@ from atst.models import ( ApplicationRole, ApplicationRoleStatus, EnvironmentRole, + Portfolio, + PortfolioStateMachine, ) +from atst.models.mixins.state_machines import FSMStates from atst.utils import first_or_none, commit_or_raise_already_exists_error @@ -118,3 +125,21 @@ class Applications(BaseDomainClass): db.session.commit() return invitation + + @classmethod + def get_applications_pending_creation(cls) -> List[UUID]: + results = ( + db.session.query(Application.id) + .join(Portfolio) + .join(PortfolioStateMachine) + .filter(PortfolioStateMachine.state == FSMStates.COMPLETED) + .filter(Application.deleted == False) + .filter(Application.cloud_id.is_(None)) + .filter( + or_( + Application.claimed_until.is_(None), + Application.claimed_until <= func.now(), + ) + ) + ).all() + return [id_ for id_, in results] diff --git a/atst/domain/csp/cloud/azure_cloud_provider.py b/atst/domain/csp/cloud/azure_cloud_provider.py index 5a6a9253..84a9238c 100644 --- a/atst/domain/csp/cloud/azure_cloud_provider.py +++ b/atst/domain/csp/cloud/azure_cloud_provider.py @@ -1,15 +1,14 @@ +import json import re from secrets import token_urlsafe from typing import Dict from uuid import uuid4 -from atst.models.application import Application -from atst.models.environment import Environment -from atst.models.user import User - from .cloud_provider_interface import CloudProviderInterface from .exceptions import AuthenticationException from .models import ( + ApplicationCSPPayload, + ApplicationCSPResult, BillingInstructionCSPPayload, BillingInstructionCSPResult, BillingProfileCreationCSPPayload, @@ -18,6 +17,8 @@ from .models import ( BillingProfileTenantAccessCSPResult, BillingProfileVerificationCSPPayload, BillingProfileVerificationCSPResult, + KeyVaultCredentials, + ManagementGroupCSPResponse, TaskOrderBillingCreationCSPPayload, TaskOrderBillingCreationCSPResult, TaskOrderBillingVerificationCSPPayload, @@ -26,6 +27,7 @@ from .models import ( TenantCSPResult, ) from .policy import AzurePolicyManager +from atst.utils import sha256_hex AZURE_ENVIRONMENT = "AZURE_PUBLIC_CLOUD" # TBD AZURE_SKU_ID = "?" # probably a static sku specific to ATAT/JEDI @@ -47,6 +49,7 @@ class AzureSDKProvider(object): import azure.common.credentials as credentials import azure.identity as identity from azure.keyvault import secrets + from azure.core import exceptions from msrestazure.azure_cloud import AZURE_PUBLIC_CLOUD import adal @@ -85,7 +88,7 @@ class AzureCloudProvider(CloudProviderInterface): def set_secret(self, secret_key, secret_value): credential = self._get_client_secret_credential_obj({}) - secret_client = self.secrets.SecretClient( + secret_client = self.sdk.secrets.SecretClient( vault_url=self.vault_url, credential=credential, ) try: @@ -98,7 +101,7 @@ class AzureCloudProvider(CloudProviderInterface): def get_secret(self, secret_key): credential = self._get_client_secret_credential_obj({}) - secret_client = self.secrets.SecretClient( + secret_client = self.sdk.secrets.SecretClient( vault_url=self.vault_url, credential=credential, ) try: @@ -109,9 +112,7 @@ class AzureCloudProvider(CloudProviderInterface): exc_info=1, ) - def create_environment( - self, auth_credentials: Dict, user: User, environment: Environment - ): + def create_environment(self, auth_credentials: Dict, user, environment): # since this operation would only occur within a tenant, should we source the tenant # via lookup from environment once we've created the portfolio csp data schema # something like this: @@ -128,7 +129,7 @@ class AzureCloudProvider(CloudProviderInterface): credentials, management_group_id, display_name, parent_id, ) - return management_group + return ManagementGroupCSPResponse(**management_group) def create_atat_admin_user( self, auth_credentials: Dict, csp_environment_id: str @@ -167,16 +168,26 @@ class AzureCloudProvider(CloudProviderInterface): "role_name": role_assignment_id, } - def _create_application(self, auth_credentials: Dict, application: Application): - management_group_name = str(uuid4()) # can be anything, not just uuid - display_name = application.name # Does this need to be unique? - credentials = self._get_credential_obj(auth_credentials) - parent_id = "?" # application.portfolio.csp_details.management_group_id - - return self._create_management_group( - credentials, management_group_name, display_name, parent_id, + def create_application(self, payload: ApplicationCSPPayload): + creds = self._source_creds(payload.tenant_id) + credentials = self._get_credential_obj( + { + "client_id": creds.root_sp_client_id, + "secret_key": creds.root_sp_key, + "tenant_id": creds.root_tenant_id, + }, + resource=AZURE_MANAGEMENT_API, ) + response = self._create_management_group( + credentials, + payload.management_group_name, + payload.display_name, + payload.parent_id, + ) + + return ApplicationCSPResult(**response) + def _create_management_group( self, credentials, management_group_id, display_name, parent_id=None, ): @@ -198,6 +209,9 @@ class AzureCloudProvider(CloudProviderInterface): # result is a synchronous wait, might need to do a poll instead to handle first mgmt group create # since we were told it could take 10+ minutes to complete, unless this handles that polling internally + # TODO: what to do is status is not 'Succeeded' on the + # response object? Will it always raise its own error + # instead? return create_request.result() def _create_subscription( @@ -290,6 +304,7 @@ class AzureCloudProvider(CloudProviderInterface): sp_token = self._get_sp_token(payload.creds) if sp_token is None: raise AuthenticationException("Could not resolve token for tenant creation") + payload.password = token_urlsafe(16) create_tenant_body = payload.dict(by_alias=True) @@ -626,3 +641,24 @@ class AzureCloudProvider(CloudProviderInterface): "secret_key": self.secret_key, "tenant_id": self.tenant_id, } + + def _source_creds(self, tenant_id=None) -> KeyVaultCredentials: + if tenant_id: + return self._source_tenant_creds(tenant_id) + else: + return KeyVaultCredentials( + root_tenant_id=self._root_creds.get("tenant_id"), + root_sp_client_id=self._root_creds.get("client_id"), + root_sp_key=self._root_creds.get("secret_key"), + ) + + def update_tenant_creds(self, tenant_id, secret): + hashed = sha256_hex(tenant_id) + self.set_secret(hashed, json.dumps(secret)) + + return secret + + def _source_tenant_creds(self, tenant_id): + hashed = sha256_hex(tenant_id) + raw_creds = self.get_secret(hashed) + return KeyVaultCredentials(**json.loads(raw_creds)) diff --git a/atst/domain/csp/cloud/cloud_provider_interface.py b/atst/domain/csp/cloud/cloud_provider_interface.py index 7f975c07..5f4b9ab5 100644 --- a/atst/domain/csp/cloud/cloud_provider_interface.py +++ b/atst/domain/csp/cloud/cloud_provider_interface.py @@ -1,9 +1,5 @@ from typing import Dict -from atst.models.user import User -from atst.models.environment import Environment -from atst.models.environment_role import EnvironmentRole - class CloudProviderInterface: def set_secret(self, secret_key: str, secret_value: str): @@ -15,9 +11,7 @@ class CloudProviderInterface: def root_creds(self) -> Dict: raise NotImplementedError() - def create_environment( - self, auth_credentials: Dict, user: User, environment: Environment - ) -> str: + def create_environment(self, auth_credentials: Dict, user, environment) -> str: """Create a new environment in the CSP. Arguments: @@ -65,7 +59,7 @@ class CloudProviderInterface: raise NotImplementedError() def create_or_update_user( - self, auth_credentials: Dict, user_info: EnvironmentRole, csp_role_id: str + self, auth_credentials: Dict, user_info, csp_role_id: str ) -> str: """Creates a user or updates an existing user's role. diff --git a/atst/domain/csp/cloud/mock_cloud_provider.py b/atst/domain/csp/cloud/mock_cloud_provider.py index a6c338b5..10d62e15 100644 --- a/atst/domain/csp/cloud/mock_cloud_provider.py +++ b/atst/domain/csp/cloud/mock_cloud_provider.py @@ -17,6 +17,9 @@ from .exceptions import ( UnknownServerException, ) from .models import ( + AZURE_MGMNT_PATH, + ApplicationCSPPayload, + ApplicationCSPResult, BillingInstructionCSPPayload, BillingInstructionCSPResult, BillingProfileCreationCSPPayload, @@ -340,3 +343,16 @@ class MockCloudProvider(CloudProviderInterface): self._delay(1, 5) if self._with_authorization and credentials != self._auth_credentials: raise self.AUTHENTICATION_EXCEPTION + + def create_application(self, payload: ApplicationCSPPayload): + self._maybe_raise(self.UNAUTHORIZED_RATE, GeneralCSPException) + + return ApplicationCSPResult( + id=f"{AZURE_MGMNT_PATH}{payload.management_group_name}" + ) + + def get_credentials(self, scope="portfolio", tenant_id=None): + return self.root_creds() + + def update_tenant_creds(self, tenant_id, secret): + return secret diff --git a/atst/domain/csp/cloud/models.py b/atst/domain/csp/cloud/models.py index c6bf0ede..b4ff9232 100644 --- a/atst/domain/csp/cloud/models.py +++ b/atst/domain/csp/cloud/models.py @@ -1,6 +1,8 @@ from typing import Dict, List, Optional +import re +from uuid import uuid4 -from pydantic import BaseModel, validator +from pydantic import BaseModel, validator, root_validator from atst.utils import snake_to_camel @@ -232,3 +234,110 @@ class BillingInstructionCSPResult(AliasModel): fields = { "reported_clin_name": "name", } + + +AZURE_MGMNT_PATH = "/providers/Microsoft.Management/managementGroups/" + +MANAGEMENT_GROUP_NAME_REGEX = "^[a-zA-Z0-9\-_\(\)\.]+$" + + +class ManagementGroupCSPPayload(AliasModel): + """ + :param: management_group_name: Just pass a UUID for this. + :param: display_name: This can contain any character and + spaces, but should be 90 characters or fewer long. + :param: parent_id: This should be the fully qualified Azure ID, + i.e. /providers/Microsoft.Management/managementGroups/[management group ID] + """ + + tenant_id: str + management_group_name: Optional[str] + display_name: str + parent_id: str + + @validator("management_group_name", pre=True, always=True) + def supply_management_group_name_default(cls, name): + if name: + if re.match(MANAGEMENT_GROUP_NAME_REGEX, name) is None: + raise ValueError( + f"Management group name must match {MANAGEMENT_GROUP_NAME_REGEX}" + ) + + return name[0:90] + else: + return str(uuid4()) + + @validator("display_name", pre=True, always=True) + def enforce_display_name_length(cls, name): + return name[0:90] + + @validator("parent_id", pre=True, always=True) + def enforce_parent_id_pattern(cls, id_): + if AZURE_MGMNT_PATH not in id_: + return f"{AZURE_MGMNT_PATH}{id_}" + else: + return id_ + + +class ManagementGroupCSPResponse(AliasModel): + id: str + + +class ApplicationCSPPayload(ManagementGroupCSPPayload): + pass + + +class ApplicationCSPResult(ManagementGroupCSPResponse): + pass + + +class KeyVaultCredentials(BaseModel): + root_sp_client_id: Optional[str] + root_sp_key: Optional[str] + root_tenant_id: Optional[str] + + tenant_id: Optional[str] + + tenant_admin_username: Optional[str] + tenant_admin_password: Optional[str] + + tenant_sp_client_id: Optional[str] + tenant_sp_key: Optional[str] + + @root_validator(pre=True) + def enforce_admin_creds(cls, values): + tenant_id = values.get("tenant_id") + username = values.get("tenant_admin_username") + password = values.get("tenant_admin_password") + if any([username, password]) and not all([tenant_id, username, password]): + raise ValueError( + "tenant_id, tenant_admin_username, and tenant_admin_password must all be set if any one is" + ) + + return values + + @root_validator(pre=True) + def enforce_sp_creds(cls, values): + tenant_id = values.get("tenant_id") + client_id = values.get("tenant_sp_client_id") + key = values.get("tenant_sp_key") + if any([client_id, key]) and not all([tenant_id, client_id, key]): + raise ValueError( + "tenant_id, tenant_sp_client_id, and tenant_sp_key must all be set if any one is" + ) + + return values + + @root_validator(pre=True) + def enforce_root_creds(cls, values): + sp_creds = [ + values.get("root_tenant_id"), + values.get("root_sp_client_id"), + values.get("root_sp_key"), + ] + if any(sp_creds) and not all(sp_creds): + raise ValueError( + "root_tenant_id, root_sp_client_id, and root_sp_key must all be set if any one is" + ) + + return values diff --git a/atst/jobs.py b/atst/jobs.py index 7172343b..14256336 100644 --- a/atst/jobs.py +++ b/atst/jobs.py @@ -3,47 +3,38 @@ import pendulum from atst.database import db from atst.queue import celery -from atst.models import ( - EnvironmentJobFailure, - EnvironmentRoleJobFailure, - EnvironmentRole, - PortfolioJobFailure, -) +from atst.models import EnvironmentRole, JobFailure from atst.domain.csp.cloud.exceptions import GeneralCSPException from atst.domain.csp.cloud import CloudProviderInterface +from atst.domain.applications import Applications from atst.domain.environments import Environments from atst.domain.portfolios import Portfolios from atst.domain.environment_roles import EnvironmentRoles from atst.models.utils import claim_for_update from atst.utils.localization import translate +from atst.domain.csp.cloud.models import ApplicationCSPPayload -class RecordPortfolioFailure(celery.Task): +class RecordFailure(celery.Task): + _ENTITIES = [ + "portfolio_id", + "application_id", + "environment_id", + "environment_role_id", + ] + + def _derive_entity_info(self, kwargs): + matches = [e for e in self._ENTITIES if e in kwargs.keys()] + if matches: + match = matches[0] + return {"entity": match.replace("_id", ""), "entity_id": kwargs[match]} + else: + return None + def on_failure(self, exc, task_id, args, kwargs, einfo): - if "portfolio_id" in kwargs: - failure = PortfolioJobFailure( - portfolio_id=kwargs["portfolio_id"], task_id=task_id - ) - db.session.add(failure) - db.session.commit() - - -class RecordEnvironmentFailure(celery.Task): - def on_failure(self, exc, task_id, args, kwargs, einfo): - if "environment_id" in kwargs: - failure = EnvironmentJobFailure( - environment_id=kwargs["environment_id"], task_id=task_id - ) - db.session.add(failure) - db.session.commit() - - -class RecordEnvironmentRoleFailure(celery.Task): - def on_failure(self, exc, task_id, args, kwargs, einfo): - if "environment_role_id" in kwargs: - failure = EnvironmentRoleJobFailure( - environment_role_id=kwargs["environment_role_id"], task_id=task_id - ) + info = self._derive_entity_info(kwargs) + if info: + failure = JobFailure(**info, task_id=task_id) db.session.add(failure) db.session.commit() @@ -63,6 +54,27 @@ def send_notification_mail(recipients, subject, body): app.mailer.send(recipients, subject, body) +def do_create_application(csp: CloudProviderInterface, application_id=None): + application = Applications.get(application_id) + + with claim_for_update(application) as application: + + if application.cloud_id: + return + + csp_details = application.portfolio.csp_data + parent_id = csp_details.get("root_management_group_id") + tenant_id = csp_details.get("tenant_id") + payload = ApplicationCSPPayload( + tenant_id=tenant_id, display_name=application.name, parent_id=parent_id + ) + + app_result = csp.create_application(payload) + application.cloud_id = app_result.id + db.session.add(application) + db.session.commit() + + def do_create_environment(csp: CloudProviderInterface, environment_id=None): environment = Environments.get(environment_id) @@ -144,17 +156,22 @@ def do_provision_portfolio(csp: CloudProviderInterface, portfolio_id=None): fsm.trigger_next_transition() -@celery.task(bind=True, base=RecordPortfolioFailure) +@celery.task(bind=True, base=RecordFailure) def provision_portfolio(self, portfolio_id=None): do_work(do_provision_portfolio, self, app.csp.cloud, portfolio_id=portfolio_id) -@celery.task(bind=True, base=RecordEnvironmentFailure) +@celery.task(bind=True, base=RecordFailure) +def create_application(self, application_id=None): + do_work(do_create_application, self, app.csp.cloud, application_id=application_id) + + +@celery.task(bind=True, base=RecordFailure) def create_environment(self, environment_id=None): do_work(do_create_environment, self, app.csp.cloud, environment_id=environment_id) -@celery.task(bind=True, base=RecordEnvironmentFailure) +@celery.task(bind=True, base=RecordFailure) def create_atat_admin_user(self, environment_id=None): do_work( do_create_atat_admin_user, self, app.csp.cloud, environment_id=environment_id @@ -177,6 +194,12 @@ def dispatch_provision_portfolio(self): provision_portfolio.delay(portfolio_id=portfolio_id) +@celery.task(bind=True) +def dispatch_create_application(self): + for application_id in Applications.get_applications_pending_creation(): + create_application.delay(application_id=application_id) + + @celery.task(bind=True) def dispatch_create_environment(self): for environment_id in Environments.get_environments_pending_creation( diff --git a/atst/models/__init__.py b/atst/models/__init__.py index f6c48306..dfb1c19d 100644 --- a/atst/models/__init__.py +++ b/atst/models/__init__.py @@ -7,11 +7,7 @@ from .audit_event import AuditEvent from .clin import CLIN, JEDICLINType from .environment import Environment from .environment_role import EnvironmentRole, CSPRole -from .job_failure import ( - EnvironmentJobFailure, - EnvironmentRoleJobFailure, - PortfolioJobFailure, -) +from .job_failure import JobFailure from .notification_recipient import NotificationRecipient from .permissions import Permissions from .permission_set import PermissionSet diff --git a/atst/models/application.py b/atst/models/application.py index a7bdadba..1af9e39f 100644 --- a/atst/models/application.py +++ b/atst/models/application.py @@ -1,4 +1,4 @@ -from sqlalchemy import and_, Column, ForeignKey, String, UniqueConstraint +from sqlalchemy import and_, Column, ForeignKey, String, UniqueConstraint, TIMESTAMP from sqlalchemy.orm import relationship, synonym from atst.models.base import Base @@ -40,6 +40,9 @@ class Application( ), ) + cloud_id = Column(String) + claimed_until = Column(TIMESTAMP(timezone=True)) + @property def users(self): return set(role.user for role in self.members) diff --git a/atst/models/environment.py b/atst/models/environment.py index 115f3ed7..a0713c63 100644 --- a/atst/models/environment.py +++ b/atst/models/environment.py @@ -30,8 +30,6 @@ class Environment( claimed_until = Column(TIMESTAMP(timezone=True)) - job_failures = relationship("EnvironmentJobFailure") - roles = relationship( "EnvironmentRole", back_populates="environment", diff --git a/atst/models/environment_role.py b/atst/models/environment_role.py index 21f033e0..24aaeb7e 100644 --- a/atst/models/environment_role.py +++ b/atst/models/environment_role.py @@ -32,8 +32,6 @@ class EnvironmentRole( ) application_role = relationship("ApplicationRole") - job_failures = relationship("EnvironmentRoleJobFailure") - csp_user_id = Column(String()) claimed_until = Column(TIMESTAMP(timezone=True)) diff --git a/atst/models/job_failure.py b/atst/models/job_failure.py index 7a7f010a..5f9eee6c 100644 --- a/atst/models/job_failure.py +++ b/atst/models/job_failure.py @@ -1,22 +1,21 @@ -from sqlalchemy import Column, ForeignKey +from celery.result import AsyncResult +from sqlalchemy import Column, String, Integer from atst.models.base import Base import atst.models.mixins as mixins -class EnvironmentJobFailure(Base, mixins.JobFailureMixin): - __tablename__ = "environment_job_failures" +class JobFailure(Base, mixins.TimestampsMixin): + __tablename__ = "job_failures" - environment_id = Column(ForeignKey("environments.id"), nullable=False) + id = Column(Integer(), primary_key=True) + task_id = Column(String(), nullable=False) + entity = Column(String(), nullable=False) + entity_id = Column(String(), nullable=False) + @property + def task(self): + if not hasattr(self, "_task"): + self._task = AsyncResult(self.task_id) -class EnvironmentRoleJobFailure(Base, mixins.JobFailureMixin): - __tablename__ = "environment_role_job_failures" - - environment_role_id = Column(ForeignKey("environment_roles.id"), nullable=False) - - -class PortfolioJobFailure(Base, mixins.JobFailureMixin): - __tablename__ = "portfolio_job_failures" - - portfolio_id = Column(ForeignKey("portfolios.id"), nullable=False) + return self._task diff --git a/atst/models/mixins/__init__.py b/atst/models/mixins/__init__.py index 955171ab..e95b2516 100644 --- a/atst/models/mixins/__init__.py +++ b/atst/models/mixins/__init__.py @@ -3,5 +3,4 @@ from .auditable import AuditableMixin from .permissions import PermissionsMixin from .deletable import DeletableMixin from .invites import InvitesMixin -from .job_failure import JobFailureMixin from .state_machines import FSMMixin diff --git a/atst/models/mixins/job_failure.py b/atst/models/mixins/job_failure.py deleted file mode 100644 index c4f4cfa4..00000000 --- a/atst/models/mixins/job_failure.py +++ /dev/null @@ -1,14 +0,0 @@ -from celery.result import AsyncResult -from sqlalchemy import Column, String, Integer - - -class JobFailureMixin(object): - id = Column(Integer(), primary_key=True) - task_id = Column(String(), nullable=False) - - @property - def task(self): - if not hasattr(self, "_task"): - self._task = AsyncResult(self.task_id) - - return self._task diff --git a/atst/models/portfolio_state_machine.py b/atst/models/portfolio_state_machine.py index cf42710b..be9324b1 100644 --- a/atst/models/portfolio_state_machine.py +++ b/atst/models/portfolio_state_machine.py @@ -175,7 +175,7 @@ class PortfolioStateMachine( tenant_id = new_creds.get("tenant_id") secret = self.csp.get_secret(tenant_id, new_creds) secret.update(new_creds) - self.csp.set_secret(tenant_id, secret) + self.csp.update_tenant_creds(tenant_id, secret) except PydanticValidationError as exc: app.logger.error( f"Failed to cast response to valid result class {self.__repr__()}:", diff --git a/atst/queue.py b/atst/queue.py index 1dce690c..70718150 100644 --- a/atst/queue.py +++ b/atst/queue.py @@ -11,6 +11,10 @@ def update_celery(celery, app): "task": "atst.jobs.dispatch_provision_portfolio", "schedule": 60, }, + "beat-dispatch_create_application": { + "task": "atst.jobs.dispatch_create_application", + "schedule": 60, + }, "beat-dispatch_create_environment": { "task": "atst.jobs.dispatch_create_environment", "schedule": 60, diff --git a/atst/utils/__init__.py b/atst/utils/__init__.py index 09c63dea..79d5362a 100644 --- a/atst/utils/__init__.py +++ b/atst/utils/__init__.py @@ -1,3 +1,4 @@ +import hashlib import re from sqlalchemy.exc import IntegrityError @@ -41,3 +42,8 @@ def commit_or_raise_already_exists_error(message): except IntegrityError: db.session.rollback() raise AlreadyExistsError(message) + + +def sha256_hex(string): + hsh = hashlib.sha256(string.encode()) + return hsh.digest().hex() diff --git a/tests/domain/cloud/test_azure_csp.py b/tests/domain/cloud/test_azure_csp.py index 0648ec1e..39fa2f77 100644 --- a/tests/domain/cloud/test_azure_csp.py +++ b/tests/domain/cloud/test_azure_csp.py @@ -1,11 +1,15 @@ -from unittest.mock import Mock +import pytest +import json from uuid import uuid4 +from unittest.mock import Mock, patch from tests.factories import ApplicationFactory, EnvironmentFactory from tests.mock_azure import AUTH_CREDENTIALS, mock_azure from atst.domain.csp.cloud import AzureCloudProvider from atst.domain.csp.cloud.models import ( + ApplicationCSPPayload, + ApplicationCSPResult, BillingInstructionCSPPayload, BillingInstructionCSPResult, BillingProfileCreationCSPPayload, @@ -65,8 +69,8 @@ def test_create_subscription_succeeds(mock_azure: AzureCloudProvider): def mock_management_group_create(mock_azure, spec_dict): - mock_azure.sdk.managementgroups.ManagementGroupsAPI.return_value.management_groups.create_or_update.return_value.result.return_value = Mock( - **spec_dict + mock_azure.sdk.managementgroups.ManagementGroupsAPI.return_value.management_groups.create_or_update.return_value.result.return_value = ( + spec_dict ) @@ -82,12 +86,30 @@ def test_create_environment_succeeds(mock_azure: AzureCloudProvider): assert result.id == "Test Id" +# mock the get_secret so it returns a JSON string +MOCK_CREDS = { + "tenant_id": str(uuid4()), + "tenant_sp_client_id": str(uuid4()), + "tenant_sp_key": "1234", +} + + +def mock_get_secret(azure, func): + azure.get_secret = func + + return azure + + def test_create_application_succeeds(mock_azure: AzureCloudProvider): application = ApplicationFactory.create() - mock_management_group_create(mock_azure, {"id": "Test Id"}) - result = mock_azure._create_application(AUTH_CREDENTIALS, application) + mock_azure = mock_get_secret(mock_azure, lambda *a, **k: json.dumps(MOCK_CREDS)) + + payload = ApplicationCSPPayload( + tenant_id="1234", display_name=application.name, parent_id=str(uuid4()) + ) + result = mock_azure.create_application(payload) assert result.id == "Test Id" diff --git a/tests/domain/cloud/test_models.py b/tests/domain/cloud/test_models.py new file mode 100644 index 00000000..d9fc963d --- /dev/null +++ b/tests/domain/cloud/test_models.py @@ -0,0 +1,99 @@ +import pytest + +from pydantic import ValidationError + +from atst.domain.csp.cloud.models import ( + AZURE_MGMNT_PATH, + KeyVaultCredentials, + ManagementGroupCSPPayload, + ManagementGroupCSPResponse, +) + + +def test_ManagementGroupCSPPayload_management_group_name(): + # supplies management_group_name when absent + payload = ManagementGroupCSPPayload( + tenant_id="any-old-id", + display_name="Council of Naboo", + parent_id="Galactic_Senate", + ) + assert payload.management_group_name + # validates management_group_name + with pytest.raises(ValidationError): + payload = ManagementGroupCSPPayload( + tenant_id="any-old-id", + management_group_name="council of Naboo 1%^&", + display_name="Council of Naboo", + parent_id="Galactic_Senate", + ) + # shortens management_group_name to fit + name = "council_of_naboo".ljust(95, "1") + + assert len(name) > 90 + payload = ManagementGroupCSPPayload( + tenant_id="any-old-id", + management_group_name=name, + display_name="Council of Naboo", + parent_id="Galactic_Senate", + ) + assert len(payload.management_group_name) == 90 + + +def test_ManagementGroupCSPPayload_display_name(): + # shortens display_name to fit + name = "Council of Naboo".ljust(95, "1") + assert len(name) > 90 + payload = ManagementGroupCSPPayload( + tenant_id="any-old-id", display_name=name, parent_id="Galactic_Senate" + ) + assert len(payload.display_name) == 90 + + +def test_ManagementGroupCSPPayload_parent_id(): + full_path = f"{AZURE_MGMNT_PATH}Galactic_Senate" + # adds full path + payload = ManagementGroupCSPPayload( + tenant_id="any-old-id", + display_name="Council of Naboo", + parent_id="Galactic_Senate", + ) + assert payload.parent_id == full_path + # keeps full path + payload = ManagementGroupCSPPayload( + tenant_id="any-old-id", display_name="Council of Naboo", parent_id=full_path + ) + assert payload.parent_id == full_path + + +def test_ManagementGroupCSPResponse_id(): + full_id = "/path/to/naboo-123" + response = ManagementGroupCSPResponse( + **{"id": "/path/to/naboo-123", "other": "stuff"} + ) + assert response.id == full_id + + +def test_KeyVaultCredentials_enforce_admin_creds(): + with pytest.raises(ValidationError): + KeyVaultCredentials(tenant_id="an id", tenant_admin_username="C3PO") + assert KeyVaultCredentials( + tenant_id="an id", + tenant_admin_username="C3PO", + tenant_admin_password="beep boop", + ) + + +def test_KeyVaultCredentials_enforce_sp_creds(): + with pytest.raises(ValidationError): + KeyVaultCredentials(tenant_id="an id", tenant_sp_client_id="C3PO") + assert KeyVaultCredentials( + tenant_id="an id", tenant_sp_client_id="C3PO", tenant_sp_key="beep boop" + ) + + +def test_KeyVaultCredentials_enforce_root_creds(): + with pytest.raises(ValidationError): + KeyVaultCredentials(root_tenant_id="an id", root_sp_client_id="C3PO") + assert KeyVaultCredentials( + root_tenant_id="an id", root_sp_client_id="C3PO", root_sp_key="beep boop" + ) diff --git a/tests/domain/test_applications.py b/tests/domain/test_applications.py index 9fda3114..02dd3124 100644 --- a/tests/domain/test_applications.py +++ b/tests/domain/test_applications.py @@ -1,3 +1,4 @@ +from datetime import datetime, timedelta import pytest from uuid import uuid4 @@ -196,3 +197,20 @@ def test_update_does_not_duplicate_names_within_portfolio(): with pytest.raises(AlreadyExistsError): Applications.update(dupe_application, {"name": name}) + + +def test_get_applications_pending_creation(): + now = datetime.now() + later = now + timedelta(minutes=30) + + portfolio1 = PortfolioFactory.create(state="COMPLETED") + app_ready = ApplicationFactory.create(portfolio=portfolio1) + + app_done = ApplicationFactory.create(portfolio=portfolio1, cloud_id="123456") + + portfolio2 = PortfolioFactory.create(state="UNSTARTED") + app_not_ready = ApplicationFactory.create(portfolio=portfolio2) + + uuids = Applications.get_applications_pending_creation() + + assert [app_ready.id] == uuids diff --git a/tests/factories.py b/tests/factories.py index d9af7c40..b7a63243 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -7,6 +7,7 @@ import datetime from atst.forms import data from atst.models import * +from atst.models.mixins.state_machines import FSMStates from atst.domain.invitations import PortfolioInvitations from atst.domain.permission_sets import PermissionSets @@ -121,6 +122,7 @@ class PortfolioFactory(Base): owner = kwargs.pop("owner", UserFactory.create()) members = kwargs.pop("members", []) with_task_orders = kwargs.pop("task_orders", []) + state = kwargs.pop("state", None) portfolio = super()._create(model_class, *args, **kwargs) @@ -161,6 +163,12 @@ class PortfolioFactory(Base): permission_sets=perms_set, ) + if state: + state = getattr(FSMStates, state) + fsm = PortfolioStateMachineFactory.create(state=state, portfolio=portfolio) + # setting it in the factory is not working for some reason + fsm.state = state + portfolio.applications = applications portfolio.task_orders = task_orders return portfolio diff --git a/tests/mock_azure.py b/tests/mock_azure.py index 7fa67667..4f37848e 100644 --- a/tests/mock_azure.py +++ b/tests/mock_azure.py @@ -72,6 +72,12 @@ def mock_secrets(): return Mock(spec=secrets) +def mock_identity(): + import azure.identity as identity + + return Mock(spec=identity) + + class MockAzureSDK(object): def __init__(self): from msrestazure.azure_cloud import AZURE_PUBLIC_CLOUD @@ -88,6 +94,7 @@ class MockAzureSDK(object): self.requests = mock_requests() # may change to a JEDI cloud self.cloud = AZURE_PUBLIC_CLOUD + self.identity = mock_identity() @pytest.fixture(scope="function") diff --git a/tests/test_jobs.py b/tests/test_jobs.py index ff8e4602..2ac5f408 100644 --- a/tests/test_jobs.py +++ b/tests/test_jobs.py @@ -8,9 +8,9 @@ from atst.domain.csp.cloud import MockCloudProvider from atst.domain.portfolios import Portfolios from atst.jobs import ( - RecordEnvironmentFailure, - RecordEnvironmentRoleFailure, + RecordFailure, dispatch_create_environment, + dispatch_create_application, dispatch_create_atat_admin_user, dispatch_provision_portfolio, dispatch_provision_user, @@ -18,6 +18,7 @@ from atst.jobs import ( do_provision_user, do_provision_portfolio, do_create_environment, + do_create_application, do_create_atat_admin_user, ) from atst.models.utils import claim_for_update @@ -27,9 +28,10 @@ from tests.factories import ( EnvironmentRoleFactory, PortfolioFactory, PortfolioStateMachineFactory, + ApplicationFactory, ApplicationRoleFactory, ) -from atst.models import CSPRole, EnvironmentRole, ApplicationRoleStatus +from atst.models import CSPRole, EnvironmentRole, ApplicationRoleStatus, JobFailure @pytest.fixture(autouse=True, scope="function") @@ -43,8 +45,17 @@ def portfolio(): return portfolio -def test_environment_job_failure(celery_app, celery_worker): - @celery_app.task(bind=True, base=RecordEnvironmentFailure) +def _find_failure(session, entity, id_): + return ( + session.query(JobFailure) + .filter(JobFailure.entity == entity) + .filter(JobFailure.entity_id == id_) + .one() + ) + + +def test_environment_job_failure(session, celery_app, celery_worker): + @celery_app.task(bind=True, base=RecordFailure) def _fail_hard(self, environment_id=None): raise ValueError("something bad happened") @@ -56,13 +67,12 @@ def test_environment_job_failure(celery_app, celery_worker): with pytest.raises(ValueError): task.get() - assert environment.job_failures - job_failure = environment.job_failures[0] + job_failure = _find_failure(session, "environment", str(environment.id)) assert job_failure.task == task -def test_environment_role_job_failure(celery_app, celery_worker): - @celery_app.task(bind=True, base=RecordEnvironmentRoleFailure) +def test_environment_role_job_failure(session, celery_app, celery_worker): + @celery_app.task(bind=True, base=RecordFailure) def _fail_hard(self, environment_role_id=None): raise ValueError("something bad happened") @@ -74,8 +84,7 @@ def test_environment_role_job_failure(celery_app, celery_worker): with pytest.raises(ValueError): task.get() - assert role.job_failures - job_failure = role.job_failures[0] + job_failure = _find_failure(session, "environment_role", str(role.id)) assert job_failure.task == task @@ -99,6 +108,24 @@ def test_create_environment_job_is_idempotent(csp, session): csp.create_environment.assert_not_called() +def test_create_application_job(session, csp): + portfolio = PortfolioFactory.create( + csp_data={"tenant_id": str(uuid4()), "root_management_group_id": str(uuid4())} + ) + application = ApplicationFactory.create(portfolio=portfolio, cloud_id=None) + do_create_application(csp, application.id) + session.refresh(application) + + assert application.cloud_id + + +def test_create_application_job_is_idempotent(csp): + application = ApplicationFactory.create(cloud_id=uuid4()) + do_create_application(csp, application.id) + + csp.create_application.assert_not_called() + + def test_create_atat_admin_user(csp, session): environment = EnvironmentFactory.create(cloud_id="something") do_create_atat_admin_user(csp, environment.id) @@ -139,6 +166,21 @@ def test_dispatch_create_environment(session, monkeypatch): mock.delay.assert_called_once_with(environment_id=e1.id) +def test_dispatch_create_application(monkeypatch): + portfolio = PortfolioFactory.create(state="COMPLETED") + app = ApplicationFactory.create(portfolio=portfolio) + + mock = Mock() + monkeypatch.setattr("atst.jobs.create_application", mock) + + # When dispatch_create_application is called + dispatch_create_application.run() + + # It should cause the create_application task to be called once + # with the application id + mock.delay.assert_called_once_with(application_id=app.id) + + def test_dispatch_create_atat_admin_user(session, monkeypatch): portfolio = PortfolioFactory.create( applications=[ diff --git a/tests/utils/test_hash.py b/tests/utils/test_hash.py new file mode 100644 index 00000000..5cfb8489 --- /dev/null +++ b/tests/utils/test_hash.py @@ -0,0 +1,16 @@ +import random +import re +import string + +from atst.utils import sha256_hex + + +def test_sha256_hex(): + sample = "".join( + random.choices(string.ascii_uppercase + string.digits, k=random.randrange(200)) + ) + hashed = sha256_hex(sample) + assert re.match("^[a-zA-Z0-9]+$", hashed) + assert len(hashed) == 64 + hashed_again = sha256_hex(sample) + assert hashed == hashed_again