diff --git a/atst/domain/workspaces.py b/atst/domain/workspaces.py index 3997ab9f..cd161267 100644 --- a/atst/domain/workspaces.py +++ b/atst/domain/workspaces.py @@ -1,7 +1,7 @@ from sqlalchemy.orm.exc import NoResultFound from atst.database import db -from atst.domain.exceptions import NotFoundError +from atst.domain.exceptions import NotFoundError, UnauthorizedError from atst.models.workspace import Workspace from atst.models.workspace_role import WorkspaceRole from atst.domain.roles import Roles @@ -27,12 +27,15 @@ class Workspaces(object): return workspace @classmethod - def get(cls, workspace_id): + def get(cls, user, workspace_id): try: workspace = db.session.query(Workspace).filter_by(id=workspace_id).one() except NoResultFound: raise NotFoundError("workspace") + if user not in workspace.users: + raise UnauthorizedError(user, "get workspace") + return workspace @classmethod diff --git a/atst/models/workspace.py b/atst/models/workspace.py index 151889a6..1f82a594 100644 --- a/atst/models/workspace.py +++ b/atst/models/workspace.py @@ -19,3 +19,7 @@ class Workspace(Base, TimestampsMixin): @property def owner(self): return self.request.creator + + @property + def users(self): + return set(role.user for role in self.roles) diff --git a/tests/domain/test_workspaces.py b/tests/domain/test_workspaces.py index 29dc231e..abdd2cca 100644 --- a/tests/domain/test_workspaces.py +++ b/tests/domain/test_workspaces.py @@ -1,9 +1,8 @@ import pytest from uuid import uuid4 -from atst.domain.exceptions import NotFoundError +from atst.domain.exceptions import NotFoundError, UnauthorizedError from atst.domain.workspaces import Workspaces -from atst.domain.workspace_users import WorkspaceUsers from tests.factories import WorkspaceFactory, RequestFactory, UserFactory @@ -21,15 +20,9 @@ def test_default_workspace_name_is_request_id(): assert workspace.name == str(request.id) -def test_can_get_workspace(): - workspace = WorkspaceFactory.create() - found = Workspaces.get(workspace.id) - assert workspace == found - - -def test_nonexistent_workspace_raises(): +def test_get_nonexistent_workspace_raises(): with pytest.raises(NotFoundError): - Workspaces.get(uuid4()) + Workspaces.get(UserFactory.build(), uuid4()) def test_can_get_workspace_by_request(): @@ -42,8 +35,7 @@ def test_creating_workspace_adds_owner(): user = UserFactory.create() request = RequestFactory.create(creator=user) workspace = Workspaces.create(request) - workspace_user = WorkspaceUsers.get(workspace.id, user.id) - assert workspace_user.workspace_role + assert workspace.roles[0].user == user def test_workspace_has_timestamps(): @@ -52,7 +44,13 @@ def test_workspace_has_timestamps(): assert workspace.time_created == workspace.time_updated -def test_workspace_has_roles(): - request = RequestFactory.create() - workspace = Workspaces.create(request) - assert workspace.roles[0].user == request.creator +def test_workspaces_get_ensures_user_is_in_workspace(): + owner = UserFactory.create() + outside_user = UserFactory.create() + workspace = Workspaces.create(RequestFactory.create(creator=owner)) + + workspace_ = Workspaces.get(owner, workspace.id) + assert workspace_ == workspace + + with pytest.raises(UnauthorizedError): + Workspaces.get(outside_user, workspace.id)