track invitation state by status enum
This commit is contained in:
@@ -2,18 +2,18 @@ import datetime
|
||||
from sqlalchemy.orm.exc import NoResultFound
|
||||
|
||||
from atst.database import db
|
||||
from atst.models import Invitation
|
||||
from atst.models.invitation import Invitation, Status as InvitationStatus
|
||||
|
||||
from .exceptions import NotFoundError
|
||||
|
||||
|
||||
class InvitationExpired(Exception):
|
||||
def __init__(self, invite_id):
|
||||
self.invite_id = invite_id
|
||||
class InvitationError(Exception):
|
||||
def __init__(self, invite):
|
||||
self.invite = invite
|
||||
|
||||
@property
|
||||
def message(self):
|
||||
return "{} has expired".format(self.invite_id)
|
||||
return "{} has a status of {}".format(self.invite.id, self.invite.status.value)
|
||||
|
||||
|
||||
class Invitations(object):
|
||||
@@ -31,7 +31,12 @@ class Invitations(object):
|
||||
|
||||
@classmethod
|
||||
def create(cls, workspace, inviter, user):
|
||||
invite = Invitation(workspace=workspace, inviter=inviter, user=user, valid=True)
|
||||
invite = Invitation(
|
||||
workspace=workspace,
|
||||
inviter=inviter,
|
||||
user=user,
|
||||
status=InvitationStatus.PENDING,
|
||||
)
|
||||
db.session.add(invite)
|
||||
db.session.commit()
|
||||
|
||||
@@ -40,23 +45,22 @@ class Invitations(object):
|
||||
@classmethod
|
||||
def accept(cls, invite_id):
|
||||
invite = Invitations._get(invite_id)
|
||||
valid = Invitations.is_valid(invite)
|
||||
|
||||
invite.valid = False
|
||||
if Invitations._is_expired(invite):
|
||||
invite.status = InvitationStatus.REJECTED
|
||||
elif invite.is_pending:
|
||||
invite.status = InvitationStatus.ACCEPTED
|
||||
|
||||
db.session.add(invite)
|
||||
db.session.commit()
|
||||
|
||||
if not valid:
|
||||
raise InvitationExpired(invite_id)
|
||||
if invite.is_revoked or invite.is_rejected:
|
||||
raise InvitationError(invite)
|
||||
|
||||
return invite
|
||||
|
||||
@classmethod
|
||||
def is_valid(cls, invite):
|
||||
return invite.valid and not Invitations.is_expired(invite)
|
||||
|
||||
@classmethod
|
||||
def is_expired(cls, invite):
|
||||
def _is_expired(cls, invite):
|
||||
time_created = invite.time_created
|
||||
expiration = datetime.datetime.now(time_created.tzinfo) - datetime.timedelta(
|
||||
minutes=Invitations.EXPIRATION_LIMIT_MINUTES
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from sqlalchemy import Column, ForeignKey, Boolean
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import Column, ForeignKey, Enum as SQLAEnum
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
@@ -6,6 +8,13 @@ from atst.models import Base, types
|
||||
from atst.models.mixins.timestamps import TimestampsMixin
|
||||
|
||||
|
||||
class Status(Enum):
|
||||
ACCEPTED = "accepted"
|
||||
REVOKED = "revoked"
|
||||
PENDING = "pending"
|
||||
REJECTED = "rejected"
|
||||
|
||||
|
||||
class Invitation(Base, TimestampsMixin):
|
||||
__tablename__ = "invitations"
|
||||
|
||||
@@ -20,9 +29,25 @@ class Invitation(Base, TimestampsMixin):
|
||||
inviter_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), index=True)
|
||||
inviter = relationship("User", backref="sent_invites", foreign_keys=[inviter_id])
|
||||
|
||||
valid = Column(Boolean, default=True)
|
||||
status = Column(SQLAEnum(Status, native_enum=False, default=Status.PENDING))
|
||||
|
||||
def __repr__(self):
|
||||
return "<Invitation(user='{}', workspace='{}', id='{}')>".format(
|
||||
self.user.id, self.workspace.id, self.id
|
||||
)
|
||||
|
||||
@property
|
||||
def is_accepted(self):
|
||||
return self.status == Status.ACCEPTED
|
||||
|
||||
@property
|
||||
def is_revoked(self):
|
||||
return self.status == Status.REVOKED
|
||||
|
||||
@property
|
||||
def is_pending(self):
|
||||
return self.status == Status.PENDING
|
||||
|
||||
@property
|
||||
def is_rejected(self):
|
||||
return self.status == Status.REJECTED
|
||||
|
||||
@@ -3,7 +3,7 @@ from flask_wtf.csrf import CSRFError
|
||||
import werkzeug.exceptions as werkzeug_exceptions
|
||||
|
||||
import atst.domain.exceptions as exceptions
|
||||
from atst.domain.invitations import InvitationExpired
|
||||
from atst.domain.invitations import InvitationError
|
||||
|
||||
|
||||
def make_error_pages(app):
|
||||
@@ -42,7 +42,7 @@ def make_error_pages(app):
|
||||
500,
|
||||
)
|
||||
|
||||
@app.errorhandler(InvitationExpired)
|
||||
@app.errorhandler(InvitationError)
|
||||
# pylint: disable=unused-variable
|
||||
def expired_invitation(e):
|
||||
log_error(e)
|
||||
|
||||
Reference in New Issue
Block a user