track invitation state by status enum

This commit is contained in:
dandds 2018-10-26 13:31:21 -04:00
parent 6125041a93
commit d5998ed370
7 changed files with 110 additions and 32 deletions

View File

@ -0,0 +1,30 @@
"""add status to invitation
Revision ID: 03654d08f5ff
Revises: 2bec1868a22a
Create Date: 2018-10-26 12:59:16.709080
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '03654d08f5ff'
down_revision = '2bec1868a22a'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('invitations', sa.Column('status', sa.Enum('ACCEPTED', 'REVOKED', 'PENDING', 'REJECTED', name='status', native_enum=False), nullable=True))
op.drop_column('invitations', 'valid')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('invitations', sa.Column('valid', sa.BOOLEAN(), autoincrement=False, nullable=True))
op.drop_column('invitations', 'status')
# ### end Alembic commands ###

View File

@ -2,18 +2,18 @@ import datetime
from sqlalchemy.orm.exc import NoResultFound
from atst.database import db
from atst.models import Invitation
from atst.models.invitation import Invitation, Status as InvitationStatus
from .exceptions import NotFoundError
class InvitationExpired(Exception):
def __init__(self, invite_id):
self.invite_id = invite_id
class InvitationError(Exception):
def __init__(self, invite):
self.invite = invite
@property
def message(self):
return "{} has expired".format(self.invite_id)
return "{} has a status of {}".format(self.invite.id, self.invite.status.value)
class Invitations(object):
@ -31,7 +31,12 @@ class Invitations(object):
@classmethod
def create(cls, workspace, inviter, user):
invite = Invitation(workspace=workspace, inviter=inviter, user=user, valid=True)
invite = Invitation(
workspace=workspace,
inviter=inviter,
user=user,
status=InvitationStatus.PENDING,
)
db.session.add(invite)
db.session.commit()
@ -40,23 +45,22 @@ class Invitations(object):
@classmethod
def accept(cls, invite_id):
invite = Invitations._get(invite_id)
valid = Invitations.is_valid(invite)
invite.valid = False
if Invitations._is_expired(invite):
invite.status = InvitationStatus.REJECTED
elif invite.is_pending:
invite.status = InvitationStatus.ACCEPTED
db.session.add(invite)
db.session.commit()
if not valid:
raise InvitationExpired(invite_id)
if invite.is_revoked or invite.is_rejected:
raise InvitationError(invite)
return invite
@classmethod
def is_valid(cls, invite):
return invite.valid and not Invitations.is_expired(invite)
@classmethod
def is_expired(cls, invite):
def _is_expired(cls, invite):
time_created = invite.time_created
expiration = datetime.datetime.now(time_created.tzinfo) - datetime.timedelta(
minutes=Invitations.EXPIRATION_LIMIT_MINUTES

View File

@ -1,4 +1,6 @@
from sqlalchemy import Column, ForeignKey, Boolean
from enum import Enum
from sqlalchemy import Column, ForeignKey, Enum as SQLAEnum
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
@ -6,6 +8,13 @@ from atst.models import Base, types
from atst.models.mixins.timestamps import TimestampsMixin
class Status(Enum):
ACCEPTED = "accepted"
REVOKED = "revoked"
PENDING = "pending"
REJECTED = "rejected"
class Invitation(Base, TimestampsMixin):
__tablename__ = "invitations"
@ -20,9 +29,25 @@ class Invitation(Base, TimestampsMixin):
inviter_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), index=True)
inviter = relationship("User", backref="sent_invites", foreign_keys=[inviter_id])
valid = Column(Boolean, default=True)
status = Column(SQLAEnum(Status, native_enum=False, default=Status.PENDING))
def __repr__(self):
return "<Invitation(user='{}', workspace='{}', id='{}')>".format(
self.user.id, self.workspace.id, self.id
)
@property
def is_accepted(self):
return self.status == Status.ACCEPTED
@property
def is_revoked(self):
return self.status == Status.REVOKED
@property
def is_pending(self):
return self.status == Status.PENDING
@property
def is_rejected(self):
return self.status == Status.REJECTED

View File

@ -3,7 +3,7 @@ from flask_wtf.csrf import CSRFError
import werkzeug.exceptions as werkzeug_exceptions
import atst.domain.exceptions as exceptions
from atst.domain.invitations import InvitationExpired
from atst.domain.invitations import InvitationError
def make_error_pages(app):
@ -42,7 +42,7 @@ def make_error_pages(app):
500,
)
@app.errorhandler(InvitationExpired)
@app.errorhandler(InvitationError)
# pylint: disable=unused-variable
def expired_invitation(e):
log_error(e)

View File

@ -1,7 +1,8 @@
import datetime
import pytest
from atst.domain.invitations import Invitations, InvitationExpired
from atst.domain.invitations import Invitations, InvitationError
from atst.models.invitation import Status
from tests.factories import WorkspaceFactory, UserFactory, InvitationFactory
@ -13,16 +14,16 @@ def test_create_invitation():
assert invite.user == user
assert invite.workspace == workspace
assert invite.inviter == workspace.owner
assert invite.valid
assert invite.status == Status.PENDING
def test_accept_invitation():
workspace = WorkspaceFactory.create()
user = UserFactory.create()
invite = Invitations.create(workspace, workspace.owner, user)
assert invite.valid
assert invite.is_pending
accepted_invite = Invitations.accept(invite.id)
assert not accepted_invite.valid
assert accepted_invite.is_accepted
def test_accept_expired_invitation():
@ -31,19 +32,32 @@ def test_accept_expired_invitation():
increment = Invitations.EXPIRATION_LIMIT_MINUTES + 1
created_at = datetime.datetime.now() - datetime.timedelta(minutes=increment)
invite = InvitationFactory.create(
workspace_id=workspace.id, user_id=user.id, time_created=created_at, valid=True
workspace_id=workspace.id,
user_id=user.id,
time_created=created_at,
status=Status.PENDING,
)
with pytest.raises(InvitationExpired):
with pytest.raises(InvitationError):
Invitations.accept(invite.id)
assert not invite.valid
assert invite.is_rejected
def test_accept_invalid_invite():
def test_accept_rejected_invite():
workspace = WorkspaceFactory.create()
user = UserFactory.create()
invite = InvitationFactory.create(
workspace_id=workspace.id, user_id=user.id, valid=False
workspace_id=workspace.id, user_id=user.id, status=Status.REJECTED
)
with pytest.raises(InvitationExpired):
with pytest.raises(InvitationError):
Invitations.accept(invite.id)
def test_accept_revoked_invite():
workspace = WorkspaceFactory.create()
user = UserFactory.create()
invite = InvitationFactory.create(
workspace_id=workspace.id, user_id=user.id, status=Status.REVOKED
)
with pytest.raises(InvitationError):
Invitations.accept(invite.id)

View File

@ -20,7 +20,7 @@ from atst.models.workspace import Workspace
from atst.domain.roles import Roles
from atst.models.workspace_role import WorkspaceRole
from atst.models.environment_role import EnvironmentRole
from atst.models.invitation import Invitation
from atst.models.invitation import Invitation, Status as InvitationStatus
from atst.domain.workspaces import Workspaces
@ -339,3 +339,5 @@ class EnvironmentRoleFactory(Base):
class InvitationFactory(Base):
class Meta:
model = Invitation
status = InvitationStatus.PENDING

View File

@ -8,6 +8,7 @@ from atst.domain.environments import Environments
from atst.domain.environment_roles import EnvironmentRoles
from atst.domain.invitations import Invitations
from atst.models.workspace_user import WorkspaceUser
from atst.models.invitation import Status as InvitationStatus
from atst.queue import queue
@ -316,7 +317,7 @@ def test_new_member_accepts_valid_invite(client, user_session):
in response.headers["Location"]
)
# the one-time use invite is no longer usable
assert invite.valid == False
assert invite.is_accepted
# the user has access to the workspace
assert len(Workspaces.for_user(user)) == 1
@ -327,7 +328,9 @@ def test_new_member_accept_invalid_invite(client, user_session):
user = UserFactory.create()
member = WorkspaceUsers.add(user, workspace.id, "developer")
invite = InvitationFactory.create(
user_id=member.user.id, workspace_id=workspace.id, valid=False
user_id=member.user.id,
workspace_id=workspace.id,
status=InvitationStatus.REJECTED,
)
user_session(user)
response = client.get(url_for("workspaces.accept_invitation", invite_id=invite.id))