diff --git a/atst/domain/environments.py b/atst/domain/environments.py index 74976544..491546b2 100644 --- a/atst/domain/environments.py +++ b/atst/domain/environments.py @@ -21,9 +21,13 @@ class Environments(object): @classmethod def create_many(cls, project, names): + environments = [] for name in names: environment = Environment(project=project, name=name) - db.session.add(environment) + environments.append(environment) + + db.session.add_all(environments) + return environments @classmethod def add_member(cls, environment, user, role): @@ -31,8 +35,6 @@ class Environments(object): user=user, environment=environment, role=role ) db.session.add(environment_user) - db.session.commit() - return environment @classmethod diff --git a/atst/domain/projects.py b/atst/domain/projects.py index 663a12f4..e57c9fb1 100644 --- a/atst/domain/projects.py +++ b/atst/domain/projects.py @@ -5,7 +5,6 @@ from atst.domain.exceptions import NotFoundError from atst.models.permissions import Permissions from atst.models.project import Project from atst.models.environment import Environment -from atst.models.environment_role import EnvironmentRole class Projects(object): @@ -16,9 +15,6 @@ class Projects(object): Environments.create_many(project, environment_names) - for environment in project.environments: - Environments.add_member(user, environment, user) - db.session.commit() return project diff --git a/tests/domain/test_projects.py b/tests/domain/test_projects.py index 6a75a792..e473f4df 100644 --- a/tests/domain/test_projects.py +++ b/tests/domain/test_projects.py @@ -1,5 +1,5 @@ from atst.domain.projects import Projects -from tests.factories import RequestFactory +from tests.factories import RequestFactory, UserFactory from atst.domain.workspaces import Workspaces @@ -14,3 +14,16 @@ def test_create_project_with_multiple_environments(): assert project.name == "My Test Project" assert project.description == "Test" assert sorted(e.name for e in project.environments) == ["dev", "prod"] + + +def test_workspace_owner_can_view_environments(): + owner = UserFactory.create() + request = RequestFactory.create(creator=owner) + workspace = Workspaces.create(request) + _project = Projects.create( + owner, workspace, "My Test Project", "Test", ["dev", "prod"] + ) + + project = Projects.get(owner, workspace, _project.id) + + assert len(project.environments) == 2