diff --git a/atst/domain/invitations.py b/atst/domain/invitations.py index ee518c4e..00797ddd 100644 --- a/atst/domain/invitations.py +++ b/atst/domain/invitations.py @@ -8,6 +8,27 @@ from atst.domain.workspace_users import WorkspaceUsers from .exceptions import NotFoundError +class WrongUserError(Exception): + def __init__(self, user, invite): + self.user = user + self.invite = invite + + @property + def message(self): + return "User {} with DOD ID {} does not match expected DOD ID {} for invitation {}".format( + self.user.id, self.user.dod_id, self.invite.user.dod_id, self.invite.id + ) + + +class ExpiredError(Exception): + def __init__(self, invite): + self.invite = invite + + @property + def message(self): + return "Invitation {} has expired.".format(self.invite.id) + + class InvitationError(Exception): def __init__(self, invite): self.invite = invite @@ -45,26 +66,36 @@ class Invitations(object): return invite @classmethod - def accept(cls, token): + def accept(cls, user, token): invite = Invitations._get(token) - if invite.is_expired: - invite.status = InvitationStatus.REJECTED - elif invite.is_pending: - invite.status = InvitationStatus.ACCEPTED + if invite.user.dod_id != user.dod_id: + if invite.is_pending: + Invitations._update_status(invite, InvitationStatus.REJECTED) + raise WrongUserError(user, invite) - db.session.add(invite) - db.session.commit() + elif invite.is_expired: + Invitations._update_status(invite, InvitationStatus.REJECTED) + raise ExpiredError(invite) - if invite.is_revoked or invite.is_rejected: + elif invite.is_accepted or invite.is_revoked or invite.is_rejected: raise InvitationError(invite) - WorkspaceUsers.enable(invite.workspace_role) - - return invite + elif invite.is_pending: + Invitations._update_status(invite, InvitationStatus.ACCEPTED) + WorkspaceUsers.enable(invite.workspace_role) + return invite @classmethod def current_expiration_time(cls): return datetime.datetime.now() + datetime.timedelta( minutes=Invitations.EXPIRATION_LIMIT_MINUTES ) + + @classmethod + def _update_status(cls, invite, new_status): + invite.status = new_status + db.session.add(invite) + db.session.commit() + + return invite diff --git a/atst/routes/__init__.py b/atst/routes/__init__.py index 0149ace5..780ff0c3 100644 --- a/atst/routes/__init__.py +++ b/atst/routes/__init__.py @@ -2,6 +2,7 @@ import urllib.parse as url from flask import Blueprint, render_template, g, redirect, session, url_for, request from flask import current_app as app +from jinja2.exceptions import TemplateNotFound import pendulum import os @@ -10,6 +11,7 @@ from atst.domain.users import Users from atst.domain.authnid import AuthenticationContext from atst.domain.audit_log import AuditLog from atst.domain.auth import logout as _logout +from werkzeug.exceptions import NotFound bp = Blueprint("atst", __name__) @@ -77,7 +79,10 @@ def styleguide(): @bp.route("/") def catch_all(path): - return render_template("{}.html".format(path)) + try: + return render_template("{}.html".format(path)) + except TemplateNotFound: + raise NotFound() def _make_authentication_context(): diff --git a/atst/routes/errors.py b/atst/routes/errors.py index 4871f89d..7cf58b46 100644 --- a/atst/routes/errors.py +++ b/atst/routes/errors.py @@ -3,27 +3,35 @@ from flask_wtf.csrf import CSRFError import werkzeug.exceptions as werkzeug_exceptions import atst.domain.exceptions as exceptions -from atst.domain.invitations import InvitationError +from atst.domain.invitations import ( + InvitationError, + ExpiredError as InvitationExpiredError, + WrongUserError as InvitationWrongUserError, +) + + +def log_error(e): + error_message = e.message if hasattr(e, "message") else str(e) + current_app.logger.error(error_message) + + +def handle_error(e, message="Not Found", code=404): + log_error(e) + return render_template("error.html", message=message), code def make_error_pages(app): - def log_error(e): - error_message = e.message if hasattr(e, "message") else str(e) - app.logger.error(error_message) - @app.errorhandler(werkzeug_exceptions.NotFound) @app.errorhandler(exceptions.NotFoundError) @app.errorhandler(exceptions.UnauthorizedError) # pylint: disable=unused-variable def not_found(e): - log_error(e) - return render_template("error.html", message="Not Found"), 404 + return handle_error(e) @app.errorhandler(exceptions.UnauthenticatedError) # pylint: disable=unused-variable def unauthorized(e): - log_error(e) - return render_template("error.html", message="Log in Failed"), 401 + return handle_error(e, message="Log in Failed", code=401) @app.errorhandler(CSRFError) # pylint: disable=unused-variable @@ -43,14 +51,16 @@ def make_error_pages(app): ) @app.errorhandler(InvitationError) + @app.errorhandler(InvitationWrongUserError) # pylint: disable=unused-variable def invalid_invitation(e): - log_error(e) - return ( - render_template( - "error.html", message="The invitation link you clicked is invalid." - ), - 404, + return handle_error(e, message="The link you followed is invalid.", code=404) + + @app.errorhandler(InvitationExpiredError) + # pylint: disable=unused-variable + def invalid_invitation(e): + return handle_error( + e, message="The invitation you followed has expired.", code=404 ) return app diff --git a/atst/routes/workspaces.py b/atst/routes/workspaces.py index 726e1bac..89731326 100644 --- a/atst/routes/workspaces.py +++ b/atst/routes/workspaces.py @@ -361,9 +361,7 @@ def update_member(workspace_id, member_id): @bp.route("/workspaces/invitation/", methods=["GET"]) def accept_invitation(token): - # TODO: check that the current_user DOD ID matches the user associated with - # the invitation - invite = Invitations.accept(token) + invite = Invitations.accept(g.current_user, token) return redirect( url_for("workspaces.show_workspace", workspace_id=invite.workspace.id) diff --git a/templates/error.html b/templates/error.html index c8714c01..f58eac86 100644 --- a/templates/error.html +++ b/templates/error.html @@ -10,6 +10,10 @@

An error occurred.

{% endif %} +{% if g.current_user %} +

Return home.

+{% endif %} + {% endblock %} diff --git a/tests/domain/test_invitations.py b/tests/domain/test_invitations.py index 2ec08a9f..059a8927 100644 --- a/tests/domain/test_invitations.py +++ b/tests/domain/test_invitations.py @@ -2,7 +2,12 @@ import datetime import pytest import re -from atst.domain.invitations import Invitations, InvitationError +from atst.domain.invitations import ( + Invitations, + InvitationError, + WrongUserError, + ExpiredError, +) from atst.models.invitation import Status from tests.factories import ( @@ -31,7 +36,7 @@ def test_accept_invitation(): ws_role = WorkspaceRoleFactory.create(user=user, workspace=workspace) invite = Invitations.create(ws_role, workspace.owner, user) assert invite.is_pending - accepted_invite = Invitations.accept(invite.token) + accepted_invite = Invitations.accept(user, invite.token) assert accepted_invite.is_accepted @@ -42,8 +47,8 @@ def test_accept_expired_invitation(): invite = InvitationFactory.create( user_id=user.id, expiration_time=expiration_time, status=Status.PENDING ) - with pytest.raises(InvitationError): - Invitations.accept(invite.token) + with pytest.raises(ExpiredError): + Invitations.accept(user, invite.token) assert invite.is_rejected @@ -52,11 +57,39 @@ def test_accept_rejected_invite(): user = UserFactory.create() invite = InvitationFactory.create(user_id=user.id, status=Status.REJECTED) with pytest.raises(InvitationError): - Invitations.accept(invite.token) + Invitations.accept(user, invite.token) def test_accept_revoked_invite(): user = UserFactory.create() invite = InvitationFactory.create(user_id=user.id, status=Status.REVOKED) with pytest.raises(InvitationError): - Invitations.accept(invite.token) + Invitations.accept(user, invite.token) + + +def test_wrong_user_accepts_invitation(): + user = UserFactory.create() + wrong_user = UserFactory.create() + invite = InvitationFactory.create(user_id=user.id) + with pytest.raises(WrongUserError): + Invitations.accept(wrong_user, invite.token) + + +def test_user_cannot_accept_invitation_accepted_by_wrong_user(): + user = UserFactory.create() + wrong_user = UserFactory.create() + invite = InvitationFactory.create(user_id=user.id) + with pytest.raises(WrongUserError): + Invitations.accept(wrong_user, invite.token) + with pytest.raises(InvitationError): + Invitations.accept(user, invite.token) + + +def test_accept_invitation_twice(): + workspace = WorkspaceFactory.create() + user = UserFactory.create() + ws_role = WorkspaceRoleFactory.create(user=user, workspace=workspace) + invite = Invitations.create(ws_role, workspace.owner, user) + Invitations.accept(user, invite.token) + with pytest.raises(InvitationError): + Invitations.accept(user, invite.token) diff --git a/tests/routes/test_workspaces.py b/tests/routes/test_workspaces.py index d4f543de..be7097f2 100644 --- a/tests/routes/test_workspaces.py +++ b/tests/routes/test_workspaces.py @@ -1,3 +1,4 @@ +import datetime from flask import url_for from tests.factories import ( @@ -15,6 +16,7 @@ from atst.models.workspace_user import WorkspaceUser from atst.models.workspace_role import Status as WorkspaceRoleStatus from atst.models.invitation import Status as InvitationStatus from atst.queue import queue +from atst.domain.users import Users def test_user_with_permission_has_budget_report_link(client, user_session): @@ -299,7 +301,7 @@ def test_update_member_environment_role_with_no_data(client, user_session): assert EnvironmentRoles.get(user.id, env1_id).role == "developer" -def test_new_member_accepts_valid_invite(client, user_session): +def test_existing_member_accepts_valid_invite(client, user_session): workspace = WorkspaceFactory.create() user = UserFactory.create() ws_role = WorkspaceRoleFactory.create( @@ -325,7 +327,36 @@ def test_new_member_accepts_valid_invite(client, user_session): assert len(Workspaces.for_user(user)) == 1 -def test_new_member_accept_invalid_invite(client, user_session): +def test_new_member_accepts_valid_invite(monkeypatch, client, user_session): + workspace = WorkspaceFactory.create() + user_info = UserFactory.dictionary() + + user_session(workspace.owner) + client.post( + url_for("workspaces.create_member", workspace_id=workspace.id), + data={"workspace_role": "developer", **user_info}, + ) + + user = Users.get_by_dod_id(user_info["dod_id"]) + token = user.invitations[0].token + + monkeypatch.setattr( + "atst.domain.auth.should_redirect_to_user_profile", lambda *args: False + ) + user_session(user) + response = client.get(url_for("workspaces.accept_invitation", token=token)) + + # user is redirected to the workspace view + assert response.status_code == 302 + assert ( + url_for("workspaces.show_workspace", workspace_id=workspace.id) + in response.headers["Location"] + ) + # the user has access to the workspace + assert len(Workspaces.for_user(user)) == 1 + + +def test_member_accepts_invalid_invite(client, user_session): workspace = WorkspaceFactory.create() user = UserFactory.create() ws_role = WorkspaceRoleFactory.create( @@ -355,3 +386,35 @@ def test_user_who_has_not_accepted_workspace_invite_cannot_view(client, user_ses user_session(user) response = client.get("/workspaces/{}/projects".format(workspace.id)) assert response.status_code == 404 + + +def test_user_accepts_invite_with_wrong_dod_id(client, user_session): + workspace = WorkspaceFactory.create() + user = UserFactory.create() + different_user = UserFactory.create() + ws_role = WorkspaceRoleFactory.create( + user=user, workspace=workspace, status=WorkspaceRoleStatus.PENDING + ) + invite = InvitationFactory.create(user_id=user.id, workspace_role_id=ws_role.id) + user_session(different_user) + response = client.get(url_for("workspaces.accept_invitation", token=invite.token)) + + assert response.status_code == 404 + + +def test_user_accepts_expired_invite(client, user_session): + workspace = WorkspaceFactory.create() + user = UserFactory.create() + ws_role = WorkspaceRoleFactory.create( + user=user, workspace=workspace, status=WorkspaceRoleStatus.PENDING + ) + invite = InvitationFactory.create( + user_id=user.id, + workspace_role_id=ws_role.id, + status=InvitationStatus.REJECTED, + expiration_time=datetime.datetime.now() - datetime.timedelta(seconds=1), + ) + user_session(user) + response = client.get(url_for("workspaces.accept_invitation", token=invite.token)) + + assert response.status_code == 404