From 87baa1f873136318c62d384b9fcd6b8426006ab2 Mon Sep 17 00:00:00 2001 From: dandds Date: Wed, 31 Oct 2018 12:06:42 -0400 Subject: [PATCH] more fine-grained errors for invalid invitations --- atst/domain/invitations.py | 41 ++++++++++++++++++++++---------- atst/routes/errors.py | 39 ++++++++++++++++++------------ tests/domain/test_invitations.py | 18 ++++++++++++-- tests/routes/test_workspaces.py | 28 ++++++++++++++++------ 4 files changed, 90 insertions(+), 36 deletions(-) diff --git a/atst/domain/invitations.py b/atst/domain/invitations.py index 65df2390..a3946399 100644 --- a/atst/domain/invitations.py +++ b/atst/domain/invitations.py @@ -15,7 +15,19 @@ class WrongUserError(Exception): @property def message(self): - return "User {} with DOD ID {} does not match expected DOD ID {} for invitation {}".format(self.user.id, self.user.dod_id, self.invite.user.dod_id, self.invite.id) + return "User {} with DOD ID {} does not match expected DOD ID {} for invitation {}".format( + self.user.id, self.user.dod_id, self.invite.user.dod_id, self.invite.id + ) + + +class ExpiredError(Exception): + def __init__(self, invite): + self.invite = invite + + @property + def message(self): + return "Invitation {} has expired.".format(self.invite.id) + class InvitationError(Exception): def __init__(self, invite): @@ -60,23 +72,28 @@ class Invitations(object): if invite.user.dod_id != user.dod_id: raise WrongUserError(user, invite) - if invite.is_expired: - invite.status = InvitationStatus.REJECTED - elif invite.is_pending: - invite.status = InvitationStatus.ACCEPTED + elif invite.is_expired: + Invitations._update_status(invite, InvitationStatus.REJECTED) + raise ExpiredError(invite) - db.session.add(invite) - db.session.commit() - - if invite.is_revoked or invite.is_rejected: + elif invite.is_accepted or invite.is_revoked or invite.is_rejected: raise InvitationError(invite) - WorkspaceUsers.enable(invite.workspace_role) - - return invite + elif invite.is_pending: + Invitations._update_status(invite, InvitationStatus.ACCEPTED) + WorkspaceUsers.enable(invite.workspace_role) + return invite @classmethod def current_expiration_time(cls): return datetime.datetime.now() + datetime.timedelta( minutes=Invitations.EXPIRATION_LIMIT_MINUTES ) + + @classmethod + def _update_status(cls, invite, new_status): + invite.status = new_status + db.session.add(invite) + db.session.commit() + + return invite diff --git a/atst/routes/errors.py b/atst/routes/errors.py index 49b5fbc9..7cf58b46 100644 --- a/atst/routes/errors.py +++ b/atst/routes/errors.py @@ -3,27 +3,35 @@ from flask_wtf.csrf import CSRFError import werkzeug.exceptions as werkzeug_exceptions import atst.domain.exceptions as exceptions -from atst.domain.invitations import InvitationError, WrongUserError as InvitationWrongUserError +from atst.domain.invitations import ( + InvitationError, + ExpiredError as InvitationExpiredError, + WrongUserError as InvitationWrongUserError, +) + + +def log_error(e): + error_message = e.message if hasattr(e, "message") else str(e) + current_app.logger.error(error_message) + + +def handle_error(e, message="Not Found", code=404): + log_error(e) + return render_template("error.html", message=message), code def make_error_pages(app): - def log_error(e): - error_message = e.message if hasattr(e, "message") else str(e) - app.logger.error(error_message) - @app.errorhandler(werkzeug_exceptions.NotFound) @app.errorhandler(exceptions.NotFoundError) @app.errorhandler(exceptions.UnauthorizedError) # pylint: disable=unused-variable def not_found(e): - log_error(e) - return render_template("error.html", message="Not Found"), 404 + return handle_error(e) @app.errorhandler(exceptions.UnauthenticatedError) # pylint: disable=unused-variable def unauthorized(e): - log_error(e) - return render_template("error.html", message="Log in Failed"), 401 + return handle_error(e, message="Log in Failed", code=401) @app.errorhandler(CSRFError) # pylint: disable=unused-variable @@ -46,12 +54,13 @@ def make_error_pages(app): @app.errorhandler(InvitationWrongUserError) # pylint: disable=unused-variable def invalid_invitation(e): - log_error(e) - return ( - render_template( - "error.html", message="The link you followed is invalid." - ), - 404, + return handle_error(e, message="The link you followed is invalid.", code=404) + + @app.errorhandler(InvitationExpiredError) + # pylint: disable=unused-variable + def invalid_invitation(e): + return handle_error( + e, message="The invitation you followed has expired.", code=404 ) return app diff --git a/tests/domain/test_invitations.py b/tests/domain/test_invitations.py index 2f677357..5a2198d9 100644 --- a/tests/domain/test_invitations.py +++ b/tests/domain/test_invitations.py @@ -2,7 +2,12 @@ import datetime import pytest import re -from atst.domain.invitations import Invitations, InvitationError, WrongUserError +from atst.domain.invitations import ( + Invitations, + InvitationError, + WrongUserError, + ExpiredError, +) from atst.models.invitation import Status from tests.factories import ( @@ -42,7 +47,7 @@ def test_accept_expired_invitation(): invite = InvitationFactory.create( user_id=user.id, expiration_time=expiration_time, status=Status.PENDING ) - with pytest.raises(InvitationError): + with pytest.raises(ExpiredError): Invitations.accept(user, invite.token) assert invite.is_rejected @@ -69,3 +74,12 @@ def test_wrong_user_accepts_invitation(): with pytest.raises(WrongUserError): Invitations.accept(wrong_user, invite.token) + +def test_accept_invitation_twice(): + workspace = WorkspaceFactory.create() + user = UserFactory.create() + ws_role = WorkspaceRoleFactory.create(user=user, workspace=workspace) + invite = Invitations.create(ws_role, workspace.owner, user) + Invitations.accept(user, invite.token) + with pytest.raises(InvitationError): + Invitations.accept(user, invite.token) diff --git a/tests/routes/test_workspaces.py b/tests/routes/test_workspaces.py index af1402e3..29bc4f87 100644 --- a/tests/routes/test_workspaces.py +++ b/tests/routes/test_workspaces.py @@ -1,3 +1,4 @@ +import datetime from flask import url_for from tests.factories import ( @@ -333,10 +334,7 @@ def test_new_member_accepts_valid_invite(client, user_session): user_session(workspace.owner) client.post( url_for("workspaces.create_member", workspace_id=workspace.id), - data={ - "workspace_role": "developer", - **user_info, - } + data={"workspace_role": "developer", **user_info}, ) user = Users.get_by_dod_id(user_info["dod_id"]) @@ -394,10 +392,26 @@ def test_user_accepts_invite_with_wrong_dod_id(client, user_session): ws_role = WorkspaceRoleFactory.create( user=user, workspace=workspace, status=WorkspaceRoleStatus.PENDING ) - invite = InvitationFactory.create( - user_id=user.id, workspace_role_id=ws_role.id - ) + invite = InvitationFactory.create(user_id=user.id, workspace_role_id=ws_role.id) user_session(different_user) response = client.get(url_for("workspaces.accept_invitation", token=invite.token)) assert response.status_code == 404 + + +def test_user_accepts_expired_invite(client, user_session): + workspace = WorkspaceFactory.create() + user = UserFactory.create() + ws_role = WorkspaceRoleFactory.create( + user=user, workspace=workspace, status=WorkspaceRoleStatus.PENDING + ) + invite = InvitationFactory.create( + user_id=user.id, + workspace_role_id=ws_role.id, + status=InvitationStatus.REJECTED, + expiration_time=datetime.datetime.now() - datetime.timedelta(seconds=1), + ) + user_session(user) + response = client.get(url_for("workspaces.accept_invitation", token=invite.token)) + + assert response.status_code == 404