diff --git a/alembic/versions/03654d08f5ff_add_status_to_invitation.py b/alembic/versions/03654d08f5ff_add_status_to_invitation.py new file mode 100644 index 00000000..0d06b1e5 --- /dev/null +++ b/alembic/versions/03654d08f5ff_add_status_to_invitation.py @@ -0,0 +1,30 @@ +"""add status to invitation + +Revision ID: 03654d08f5ff +Revises: 2bec1868a22a +Create Date: 2018-10-26 12:59:16.709080 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '03654d08f5ff' +down_revision = '2bec1868a22a' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('invitations', sa.Column('status', sa.Enum('ACCEPTED', 'REVOKED', 'PENDING', 'REJECTED', name='status', native_enum=False), nullable=True)) + op.drop_column('invitations', 'valid') + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('invitations', sa.Column('valid', sa.BOOLEAN(), autoincrement=False, nullable=True)) + op.drop_column('invitations', 'status') + # ### end Alembic commands ### diff --git a/atst/domain/invitations.py b/atst/domain/invitations.py index 0f36cecb..c2e2be2e 100644 --- a/atst/domain/invitations.py +++ b/atst/domain/invitations.py @@ -2,18 +2,18 @@ import datetime from sqlalchemy.orm.exc import NoResultFound from atst.database import db -from atst.models import Invitation +from atst.models.invitation import Invitation, Status as InvitationStatus from .exceptions import NotFoundError -class InvitationExpired(Exception): - def __init__(self, invite_id): - self.invite_id = invite_id +class InvitationError(Exception): + def __init__(self, invite): + self.invite = invite @property def message(self): - return "{} has expired".format(self.invite_id) + return "{} has a status of {}".format(self.invite.id, self.invite.status.value) class Invitations(object): @@ -31,7 +31,12 @@ class Invitations(object): @classmethod def create(cls, workspace, inviter, user): - invite = Invitation(workspace=workspace, inviter=inviter, user=user, valid=True) + invite = Invitation( + workspace=workspace, + inviter=inviter, + user=user, + status=InvitationStatus.PENDING, + ) db.session.add(invite) db.session.commit() @@ -40,23 +45,22 @@ class Invitations(object): @classmethod def accept(cls, invite_id): invite = Invitations._get(invite_id) - valid = Invitations.is_valid(invite) - invite.valid = False + if Invitations._is_expired(invite): + invite.status = InvitationStatus.REJECTED + elif invite.is_pending: + invite.status = InvitationStatus.ACCEPTED + db.session.add(invite) db.session.commit() - if not valid: - raise InvitationExpired(invite_id) + if invite.is_revoked or invite.is_rejected: + raise InvitationError(invite) return invite @classmethod - def is_valid(cls, invite): - return invite.valid and not Invitations.is_expired(invite) - - @classmethod - def is_expired(cls, invite): + def _is_expired(cls, invite): time_created = invite.time_created expiration = datetime.datetime.now(time_created.tzinfo) - datetime.timedelta( minutes=Invitations.EXPIRATION_LIMIT_MINUTES diff --git a/atst/models/invitation.py b/atst/models/invitation.py index 95d170ac..0d8bc8d8 100644 --- a/atst/models/invitation.py +++ b/atst/models/invitation.py @@ -1,4 +1,6 @@ -from sqlalchemy import Column, ForeignKey, Boolean +from enum import Enum + +from sqlalchemy import Column, ForeignKey, Enum as SQLAEnum from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import relationship @@ -6,6 +8,13 @@ from atst.models import Base, types from atst.models.mixins.timestamps import TimestampsMixin +class Status(Enum): + ACCEPTED = "accepted" + REVOKED = "revoked" + PENDING = "pending" + REJECTED = "rejected" + + class Invitation(Base, TimestampsMixin): __tablename__ = "invitations" @@ -20,9 +29,25 @@ class Invitation(Base, TimestampsMixin): inviter_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), index=True) inviter = relationship("User", backref="sent_invites", foreign_keys=[inviter_id]) - valid = Column(Boolean, default=True) + status = Column(SQLAEnum(Status, native_enum=False, default=Status.PENDING)) def __repr__(self): return "".format( self.user.id, self.workspace.id, self.id ) + + @property + def is_accepted(self): + return self.status == Status.ACCEPTED + + @property + def is_revoked(self): + return self.status == Status.REVOKED + + @property + def is_pending(self): + return self.status == Status.PENDING + + @property + def is_rejected(self): + return self.status == Status.REJECTED diff --git a/atst/routes/errors.py b/atst/routes/errors.py index b579996d..34830911 100644 --- a/atst/routes/errors.py +++ b/atst/routes/errors.py @@ -3,7 +3,7 @@ from flask_wtf.csrf import CSRFError import werkzeug.exceptions as werkzeug_exceptions import atst.domain.exceptions as exceptions -from atst.domain.invitations import InvitationExpired +from atst.domain.invitations import InvitationError def make_error_pages(app): @@ -42,7 +42,7 @@ def make_error_pages(app): 500, ) - @app.errorhandler(InvitationExpired) + @app.errorhandler(InvitationError) # pylint: disable=unused-variable def expired_invitation(e): log_error(e) diff --git a/tests/domain/test_invitations.py b/tests/domain/test_invitations.py index c0fc8901..864fa79d 100644 --- a/tests/domain/test_invitations.py +++ b/tests/domain/test_invitations.py @@ -1,7 +1,8 @@ import datetime import pytest -from atst.domain.invitations import Invitations, InvitationExpired +from atst.domain.invitations import Invitations, InvitationError +from atst.models.invitation import Status from tests.factories import WorkspaceFactory, UserFactory, InvitationFactory @@ -13,16 +14,16 @@ def test_create_invitation(): assert invite.user == user assert invite.workspace == workspace assert invite.inviter == workspace.owner - assert invite.valid + assert invite.status == Status.PENDING def test_accept_invitation(): workspace = WorkspaceFactory.create() user = UserFactory.create() invite = Invitations.create(workspace, workspace.owner, user) - assert invite.valid + assert invite.is_pending accepted_invite = Invitations.accept(invite.id) - assert not accepted_invite.valid + assert accepted_invite.is_accepted def test_accept_expired_invitation(): @@ -31,19 +32,32 @@ def test_accept_expired_invitation(): increment = Invitations.EXPIRATION_LIMIT_MINUTES + 1 created_at = datetime.datetime.now() - datetime.timedelta(minutes=increment) invite = InvitationFactory.create( - workspace_id=workspace.id, user_id=user.id, time_created=created_at, valid=True + workspace_id=workspace.id, + user_id=user.id, + time_created=created_at, + status=Status.PENDING, ) - with pytest.raises(InvitationExpired): + with pytest.raises(InvitationError): Invitations.accept(invite.id) - assert not invite.valid + assert invite.is_rejected -def test_accept_invalid_invite(): +def test_accept_rejected_invite(): workspace = WorkspaceFactory.create() user = UserFactory.create() invite = InvitationFactory.create( - workspace_id=workspace.id, user_id=user.id, valid=False + workspace_id=workspace.id, user_id=user.id, status=Status.REJECTED ) - with pytest.raises(InvitationExpired): + with pytest.raises(InvitationError): + Invitations.accept(invite.id) + + +def test_accept_revoked_invite(): + workspace = WorkspaceFactory.create() + user = UserFactory.create() + invite = InvitationFactory.create( + workspace_id=workspace.id, user_id=user.id, status=Status.REVOKED + ) + with pytest.raises(InvitationError): Invitations.accept(invite.id) diff --git a/tests/factories.py b/tests/factories.py index 0db8128f..eac728a9 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -20,7 +20,7 @@ from atst.models.workspace import Workspace from atst.domain.roles import Roles from atst.models.workspace_role import WorkspaceRole from atst.models.environment_role import EnvironmentRole -from atst.models.invitation import Invitation +from atst.models.invitation import Invitation, Status as InvitationStatus from atst.domain.workspaces import Workspaces @@ -339,3 +339,5 @@ class EnvironmentRoleFactory(Base): class InvitationFactory(Base): class Meta: model = Invitation + + status = InvitationStatus.PENDING diff --git a/tests/routes/test_workspaces.py b/tests/routes/test_workspaces.py index f9d460a5..d90203de 100644 --- a/tests/routes/test_workspaces.py +++ b/tests/routes/test_workspaces.py @@ -8,6 +8,7 @@ from atst.domain.environments import Environments from atst.domain.environment_roles import EnvironmentRoles from atst.domain.invitations import Invitations from atst.models.workspace_user import WorkspaceUser +from atst.models.invitation import Status as InvitationStatus from atst.queue import queue @@ -316,7 +317,7 @@ def test_new_member_accepts_valid_invite(client, user_session): in response.headers["Location"] ) # the one-time use invite is no longer usable - assert invite.valid == False + assert invite.is_accepted # the user has access to the workspace assert len(Workspaces.for_user(user)) == 1 @@ -327,7 +328,9 @@ def test_new_member_accept_invalid_invite(client, user_session): user = UserFactory.create() member = WorkspaceUsers.add(user, workspace.id, "developer") invite = InvitationFactory.create( - user_id=member.user.id, workspace_id=workspace.id, valid=False + user_id=member.user.id, + workspace_id=workspace.id, + status=InvitationStatus.REJECTED, ) user_session(user) response = client.get(url_for("workspaces.accept_invitation", invite_id=invite.id))