diff --git a/atst/domain/invitations.py b/atst/domain/invitations.py index ee518c4e..65df2390 100644 --- a/atst/domain/invitations.py +++ b/atst/domain/invitations.py @@ -8,6 +8,15 @@ from atst.domain.workspace_users import WorkspaceUsers from .exceptions import NotFoundError +class WrongUserError(Exception): + def __init__(self, user, invite): + self.user = user + self.invite = invite + + @property + def message(self): + return "User {} with DOD ID {} does not match expected DOD ID {} for invitation {}".format(self.user.id, self.user.dod_id, self.invite.user.dod_id, self.invite.id) + class InvitationError(Exception): def __init__(self, invite): self.invite = invite @@ -45,9 +54,12 @@ class Invitations(object): return invite @classmethod - def accept(cls, token): + def accept(cls, user, token): invite = Invitations._get(token) + if invite.user.dod_id != user.dod_id: + raise WrongUserError(user, invite) + if invite.is_expired: invite.status = InvitationStatus.REJECTED elif invite.is_pending: diff --git a/atst/routes/errors.py b/atst/routes/errors.py index 4871f89d..49b5fbc9 100644 --- a/atst/routes/errors.py +++ b/atst/routes/errors.py @@ -3,7 +3,7 @@ from flask_wtf.csrf import CSRFError import werkzeug.exceptions as werkzeug_exceptions import atst.domain.exceptions as exceptions -from atst.domain.invitations import InvitationError +from atst.domain.invitations import InvitationError, WrongUserError as InvitationWrongUserError def make_error_pages(app): @@ -43,12 +43,13 @@ def make_error_pages(app): ) @app.errorhandler(InvitationError) + @app.errorhandler(InvitationWrongUserError) # pylint: disable=unused-variable def invalid_invitation(e): log_error(e) return ( render_template( - "error.html", message="The invitation link you clicked is invalid." + "error.html", message="The link you followed is invalid." ), 404, ) diff --git a/atst/routes/workspaces.py b/atst/routes/workspaces.py index 726e1bac..35807e00 100644 --- a/atst/routes/workspaces.py +++ b/atst/routes/workspaces.py @@ -363,7 +363,7 @@ def update_member(workspace_id, member_id): def accept_invitation(token): # TODO: check that the current_user DOD ID matches the user associated with # the invitation - invite = Invitations.accept(token) + invite = Invitations.accept(g.current_user, token) return redirect( url_for("workspaces.show_workspace", workspace_id=invite.workspace.id) diff --git a/tests/domain/test_invitations.py b/tests/domain/test_invitations.py index 2ec08a9f..2f677357 100644 --- a/tests/domain/test_invitations.py +++ b/tests/domain/test_invitations.py @@ -2,7 +2,7 @@ import datetime import pytest import re -from atst.domain.invitations import Invitations, InvitationError +from atst.domain.invitations import Invitations, InvitationError, WrongUserError from atst.models.invitation import Status from tests.factories import ( @@ -31,7 +31,7 @@ def test_accept_invitation(): ws_role = WorkspaceRoleFactory.create(user=user, workspace=workspace) invite = Invitations.create(ws_role, workspace.owner, user) assert invite.is_pending - accepted_invite = Invitations.accept(invite.token) + accepted_invite = Invitations.accept(user, invite.token) assert accepted_invite.is_accepted @@ -43,7 +43,7 @@ def test_accept_expired_invitation(): user_id=user.id, expiration_time=expiration_time, status=Status.PENDING ) with pytest.raises(InvitationError): - Invitations.accept(invite.token) + Invitations.accept(user, invite.token) assert invite.is_rejected @@ -52,11 +52,20 @@ def test_accept_rejected_invite(): user = UserFactory.create() invite = InvitationFactory.create(user_id=user.id, status=Status.REJECTED) with pytest.raises(InvitationError): - Invitations.accept(invite.token) + Invitations.accept(user, invite.token) def test_accept_revoked_invite(): user = UserFactory.create() invite = InvitationFactory.create(user_id=user.id, status=Status.REVOKED) with pytest.raises(InvitationError): - Invitations.accept(invite.token) + Invitations.accept(user, invite.token) + + +def test_wrong_user_accepts_invitation(): + user = UserFactory.create() + wrong_user = UserFactory.create() + invite = InvitationFactory.create(user_id=user.id) + with pytest.raises(WrongUserError): + Invitations.accept(wrong_user, invite.token) + diff --git a/tests/routes/test_workspaces.py b/tests/routes/test_workspaces.py index d4f543de..da0d8a1a 100644 --- a/tests/routes/test_workspaces.py +++ b/tests/routes/test_workspaces.py @@ -355,3 +355,19 @@ def test_user_who_has_not_accepted_workspace_invite_cannot_view(client, user_ses user_session(user) response = client.get("/workspaces/{}/projects".format(workspace.id)) assert response.status_code == 404 + + +def test_user_accepts_invite_with_wrong_dod_id(client, user_session): + workspace = WorkspaceFactory.create() + user = UserFactory.create() + different_user = UserFactory.create() + ws_role = WorkspaceRoleFactory.create( + user=user, workspace=workspace, status=WorkspaceRoleStatus.PENDING + ) + invite = InvitationFactory.create( + user_id=user.id, workspace_role_id=ws_role.id + ) + user_session(different_user) + response = client.get(url_for("workspaces.accept_invitation", token=invite.token)) + + assert response.status_code == 404