diff --git a/atst/domain/workspace_users.py b/atst/domain/workspace_users.py index 7e217c38..a67c0cbb 100644 --- a/atst/domain/workspace_users.py +++ b/atst/domain/workspace_users.py @@ -21,7 +21,8 @@ class WorkspaceUsers(object): try: workspace_role = ( - WorkspaceRole.query.join(User) + db.session.query(WorkspaceRole) + .join(User) .filter(User.id == user_id, WorkspaceRole.workspace_id == workspace_id) .one() ) diff --git a/atst/domain/workspaces.py b/atst/domain/workspaces.py index 7e675ec1..4ba9c591 100644 --- a/atst/domain/workspaces.py +++ b/atst/domain/workspaces.py @@ -3,6 +3,8 @@ from sqlalchemy.orm.exc import NoResultFound from atst.database import db from atst.domain.exceptions import NotFoundError from atst.models.workspace import Workspace +from atst.models.workspace_role import WorkspaceRole +from atst.domain.roles import Roles class Workspaces(object): @@ -13,7 +15,10 @@ class Workspaces(object): @classmethod def create(cls, request, name=None): name = name or request.id - return Workspace(request=request, name=name) + workspace = Workspace(request=request, name=name) + role = Roles.get("owner") + wr = WorkspaceRole(user_id=request.creator.id, role=role, workspace_id=workspace.id) + return workspace @classmethod def get(cls, workspace_id): diff --git a/atst/models/workspace.py b/atst/models/workspace.py index 56c62188..03370c34 100644 --- a/atst/models/workspace.py +++ b/atst/models/workspace.py @@ -14,3 +14,7 @@ class Workspace(Base): request = relationship("Request") name = Column(String, unique=True) + + @property + def owner(self): + return self.request.creator diff --git a/tests/domain/test_workspaces.py b/tests/domain/test_workspaces.py index 3e44848b..fc52a719 100644 --- a/tests/domain/test_workspaces.py +++ b/tests/domain/test_workspaces.py @@ -3,8 +3,9 @@ from uuid import uuid4 from atst.domain.exceptions import NotFoundError from atst.domain.workspaces import Workspaces +from atst.domain.workspace_users import WorkspaceUsers -from tests.factories import WorkspaceFactory, RequestFactory, TaskOrderFactory +from tests.factories import WorkspaceFactory, RequestFactory, UserFactory def test_can_create_workspace(): @@ -32,3 +33,12 @@ def test_can_get_workspace_by_request(): workspace = WorkspaceFactory.create() found = Workspaces.get_by_request(workspace.request) assert workspace == found + + +def test_creating_workspace_adds_owner(): + user = UserFactory.create() + request = RequestFactory.create(creator=user) + workspace = Workspaces.create(request) + workspace_user = WorkspaceUsers.get(workspace.id, user.id) + assert workspace_user.workspace_role +