track invitation state by status enum

This commit is contained in:
dandds 2018-10-26 13:31:21 -04:00
parent 6125041a93
commit d5998ed370
7 changed files with 110 additions and 32 deletions

View File

@ -0,0 +1,30 @@
"""add status to invitation
Revision ID: 03654d08f5ff
Revises: 2bec1868a22a
Create Date: 2018-10-26 12:59:16.709080
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '03654d08f5ff'
down_revision = '2bec1868a22a'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('invitations', sa.Column('status', sa.Enum('ACCEPTED', 'REVOKED', 'PENDING', 'REJECTED', name='status', native_enum=False), nullable=True))
op.drop_column('invitations', 'valid')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('invitations', sa.Column('valid', sa.BOOLEAN(), autoincrement=False, nullable=True))
op.drop_column('invitations', 'status')
# ### end Alembic commands ###

View File

@ -2,18 +2,18 @@ import datetime
from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.orm.exc import NoResultFound
from atst.database import db from atst.database import db
from atst.models import Invitation from atst.models.invitation import Invitation, Status as InvitationStatus
from .exceptions import NotFoundError from .exceptions import NotFoundError
class InvitationExpired(Exception): class InvitationError(Exception):
def __init__(self, invite_id): def __init__(self, invite):
self.invite_id = invite_id self.invite = invite
@property @property
def message(self): def message(self):
return "{} has expired".format(self.invite_id) return "{} has a status of {}".format(self.invite.id, self.invite.status.value)
class Invitations(object): class Invitations(object):
@ -31,7 +31,12 @@ class Invitations(object):
@classmethod @classmethod
def create(cls, workspace, inviter, user): def create(cls, workspace, inviter, user):
invite = Invitation(workspace=workspace, inviter=inviter, user=user, valid=True) invite = Invitation(
workspace=workspace,
inviter=inviter,
user=user,
status=InvitationStatus.PENDING,
)
db.session.add(invite) db.session.add(invite)
db.session.commit() db.session.commit()
@ -40,23 +45,22 @@ class Invitations(object):
@classmethod @classmethod
def accept(cls, invite_id): def accept(cls, invite_id):
invite = Invitations._get(invite_id) invite = Invitations._get(invite_id)
valid = Invitations.is_valid(invite)
invite.valid = False if Invitations._is_expired(invite):
invite.status = InvitationStatus.REJECTED
elif invite.is_pending:
invite.status = InvitationStatus.ACCEPTED
db.session.add(invite) db.session.add(invite)
db.session.commit() db.session.commit()
if not valid: if invite.is_revoked or invite.is_rejected:
raise InvitationExpired(invite_id) raise InvitationError(invite)
return invite return invite
@classmethod @classmethod
def is_valid(cls, invite): def _is_expired(cls, invite):
return invite.valid and not Invitations.is_expired(invite)
@classmethod
def is_expired(cls, invite):
time_created = invite.time_created time_created = invite.time_created
expiration = datetime.datetime.now(time_created.tzinfo) - datetime.timedelta( expiration = datetime.datetime.now(time_created.tzinfo) - datetime.timedelta(
minutes=Invitations.EXPIRATION_LIMIT_MINUTES minutes=Invitations.EXPIRATION_LIMIT_MINUTES

View File

@ -1,4 +1,6 @@
from sqlalchemy import Column, ForeignKey, Boolean from enum import Enum
from sqlalchemy import Column, ForeignKey, Enum as SQLAEnum
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
@ -6,6 +8,13 @@ from atst.models import Base, types
from atst.models.mixins.timestamps import TimestampsMixin from atst.models.mixins.timestamps import TimestampsMixin
class Status(Enum):
ACCEPTED = "accepted"
REVOKED = "revoked"
PENDING = "pending"
REJECTED = "rejected"
class Invitation(Base, TimestampsMixin): class Invitation(Base, TimestampsMixin):
__tablename__ = "invitations" __tablename__ = "invitations"
@ -20,9 +29,25 @@ class Invitation(Base, TimestampsMixin):
inviter_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), index=True) inviter_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), index=True)
inviter = relationship("User", backref="sent_invites", foreign_keys=[inviter_id]) inviter = relationship("User", backref="sent_invites", foreign_keys=[inviter_id])
valid = Column(Boolean, default=True) status = Column(SQLAEnum(Status, native_enum=False, default=Status.PENDING))
def __repr__(self): def __repr__(self):
return "<Invitation(user='{}', workspace='{}', id='{}')>".format( return "<Invitation(user='{}', workspace='{}', id='{}')>".format(
self.user.id, self.workspace.id, self.id self.user.id, self.workspace.id, self.id
) )
@property
def is_accepted(self):
return self.status == Status.ACCEPTED
@property
def is_revoked(self):
return self.status == Status.REVOKED
@property
def is_pending(self):
return self.status == Status.PENDING
@property
def is_rejected(self):
return self.status == Status.REJECTED

View File

@ -3,7 +3,7 @@ from flask_wtf.csrf import CSRFError
import werkzeug.exceptions as werkzeug_exceptions import werkzeug.exceptions as werkzeug_exceptions
import atst.domain.exceptions as exceptions import atst.domain.exceptions as exceptions
from atst.domain.invitations import InvitationExpired from atst.domain.invitations import InvitationError
def make_error_pages(app): def make_error_pages(app):
@ -42,7 +42,7 @@ def make_error_pages(app):
500, 500,
) )
@app.errorhandler(InvitationExpired) @app.errorhandler(InvitationError)
# pylint: disable=unused-variable # pylint: disable=unused-variable
def expired_invitation(e): def expired_invitation(e):
log_error(e) log_error(e)

View File

@ -1,7 +1,8 @@
import datetime import datetime
import pytest import pytest
from atst.domain.invitations import Invitations, InvitationExpired from atst.domain.invitations import Invitations, InvitationError
from atst.models.invitation import Status
from tests.factories import WorkspaceFactory, UserFactory, InvitationFactory from tests.factories import WorkspaceFactory, UserFactory, InvitationFactory
@ -13,16 +14,16 @@ def test_create_invitation():
assert invite.user == user assert invite.user == user
assert invite.workspace == workspace assert invite.workspace == workspace
assert invite.inviter == workspace.owner assert invite.inviter == workspace.owner
assert invite.valid assert invite.status == Status.PENDING
def test_accept_invitation(): def test_accept_invitation():
workspace = WorkspaceFactory.create() workspace = WorkspaceFactory.create()
user = UserFactory.create() user = UserFactory.create()
invite = Invitations.create(workspace, workspace.owner, user) invite = Invitations.create(workspace, workspace.owner, user)
assert invite.valid assert invite.is_pending
accepted_invite = Invitations.accept(invite.id) accepted_invite = Invitations.accept(invite.id)
assert not accepted_invite.valid assert accepted_invite.is_accepted
def test_accept_expired_invitation(): def test_accept_expired_invitation():
@ -31,19 +32,32 @@ def test_accept_expired_invitation():
increment = Invitations.EXPIRATION_LIMIT_MINUTES + 1 increment = Invitations.EXPIRATION_LIMIT_MINUTES + 1
created_at = datetime.datetime.now() - datetime.timedelta(minutes=increment) created_at = datetime.datetime.now() - datetime.timedelta(minutes=increment)
invite = InvitationFactory.create( invite = InvitationFactory.create(
workspace_id=workspace.id, user_id=user.id, time_created=created_at, valid=True workspace_id=workspace.id,
user_id=user.id,
time_created=created_at,
status=Status.PENDING,
) )
with pytest.raises(InvitationExpired): with pytest.raises(InvitationError):
Invitations.accept(invite.id) Invitations.accept(invite.id)
assert not invite.valid assert invite.is_rejected
def test_accept_invalid_invite(): def test_accept_rejected_invite():
workspace = WorkspaceFactory.create() workspace = WorkspaceFactory.create()
user = UserFactory.create() user = UserFactory.create()
invite = InvitationFactory.create( invite = InvitationFactory.create(
workspace_id=workspace.id, user_id=user.id, valid=False workspace_id=workspace.id, user_id=user.id, status=Status.REJECTED
) )
with pytest.raises(InvitationExpired): with pytest.raises(InvitationError):
Invitations.accept(invite.id)
def test_accept_revoked_invite():
workspace = WorkspaceFactory.create()
user = UserFactory.create()
invite = InvitationFactory.create(
workspace_id=workspace.id, user_id=user.id, status=Status.REVOKED
)
with pytest.raises(InvitationError):
Invitations.accept(invite.id) Invitations.accept(invite.id)

View File

@ -20,7 +20,7 @@ from atst.models.workspace import Workspace
from atst.domain.roles import Roles from atst.domain.roles import Roles
from atst.models.workspace_role import WorkspaceRole from atst.models.workspace_role import WorkspaceRole
from atst.models.environment_role import EnvironmentRole from atst.models.environment_role import EnvironmentRole
from atst.models.invitation import Invitation from atst.models.invitation import Invitation, Status as InvitationStatus
from atst.domain.workspaces import Workspaces from atst.domain.workspaces import Workspaces
@ -339,3 +339,5 @@ class EnvironmentRoleFactory(Base):
class InvitationFactory(Base): class InvitationFactory(Base):
class Meta: class Meta:
model = Invitation model = Invitation
status = InvitationStatus.PENDING

View File

@ -8,6 +8,7 @@ from atst.domain.environments import Environments
from atst.domain.environment_roles import EnvironmentRoles from atst.domain.environment_roles import EnvironmentRoles
from atst.domain.invitations import Invitations from atst.domain.invitations import Invitations
from atst.models.workspace_user import WorkspaceUser from atst.models.workspace_user import WorkspaceUser
from atst.models.invitation import Status as InvitationStatus
from atst.queue import queue from atst.queue import queue
@ -316,7 +317,7 @@ def test_new_member_accepts_valid_invite(client, user_session):
in response.headers["Location"] in response.headers["Location"]
) )
# the one-time use invite is no longer usable # the one-time use invite is no longer usable
assert invite.valid == False assert invite.is_accepted
# the user has access to the workspace # the user has access to the workspace
assert len(Workspaces.for_user(user)) == 1 assert len(Workspaces.for_user(user)) == 1
@ -327,7 +328,9 @@ def test_new_member_accept_invalid_invite(client, user_session):
user = UserFactory.create() user = UserFactory.create()
member = WorkspaceUsers.add(user, workspace.id, "developer") member = WorkspaceUsers.add(user, workspace.id, "developer")
invite = InvitationFactory.create( invite = InvitationFactory.create(
user_id=member.user.id, workspace_id=workspace.id, valid=False user_id=member.user.id,
workspace_id=workspace.id,
status=InvitationStatus.REJECTED,
) )
user_session(user) user_session(user)
response = client.get(url_for("workspaces.accept_invitation", invite_id=invite.id)) response = client.get(url_for("workspaces.accept_invitation", invite_id=invite.id))