diff --git a/alembic/versions/b5b17d465166_requests.py b/alembic/versions/b5b17d465166_requests.py index dd701f9c..b8d47c06 100644 --- a/alembic/versions/b5b17d465166_requests.py +++ b/alembic/versions/b5b17d465166_requests.py @@ -36,6 +36,9 @@ def upgrade(): ) # ### end Alembic commands ### + db = op.get_bind() + db.execute("CREATE SEQUENCE request_status_events_sequence_seq OWNED BY request_status_events.sequence;") + def downgrade(): # ### commands auto generated by Alembic - please adjust! ### diff --git a/atst/app.py b/atst/app.py index aa101838..058a3046 100644 --- a/atst/app.py +++ b/atst/app.py @@ -20,6 +20,7 @@ from atst.api_client import ApiClient from atst.sessions import RedisSessions from atst import ui_modules from atst import ui_methods +from atst.database import make_db ENV = os.getenv("TORNADO_ENV", "dev") @@ -55,7 +56,7 @@ def make_app(config, deps, **kwargs): url( r"/requests", Request, - {"page": "requests", "requests_client": deps["requests_client"]}, + {"page": "requests", "db_session": deps["db_session"]}, name="requests", ), url( @@ -160,6 +161,7 @@ def make_deps(config): ) return { + "db_session": make_db(config), "authz_client": ApiClient( config["default"]["AUTHZ_BASE_URL"], api_version="v1", diff --git a/atst/domain/__init__.py b/atst/domain/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/atst/domain/exceptions.py b/atst/domain/exceptions.py new file mode 100644 index 00000000..802997d4 --- /dev/null +++ b/atst/domain/exceptions.py @@ -0,0 +1,16 @@ +class NotFoundError(Exception): + def __init__(self, resource_name): + self.resource_name = resource_name + + @property + def message(self): + return "No {} could be found.".format(self.resource_name) + + +class AlreadyExistsError(Exception): + def __init__(self, resource_name): + self.resource_name = resource_name + + @property + def message(self): + return "{} already exists".format(self.resource_name) diff --git a/atst/domain/requests.py b/atst/domain/requests.py new file mode 100644 index 00000000..09acb0af --- /dev/null +++ b/atst/domain/requests.py @@ -0,0 +1,130 @@ +import tornado.gen +from sqlalchemy import exists, and_ +from sqlalchemy.orm.exc import NoResultFound +from sqlalchemy.orm.attributes import flag_modified + +from atst.models import Request, RequestStatusEvent +from .exceptions import NotFoundError + + +def deep_merge(source, destination: dict): + """ + Merge source dict into destination dict recursively. + """ + + def _deep_merge(a, b): + for key, value in a.items(): + if isinstance(value, dict): + node = b.setdefault(key, {}) + _deep_merge(value, node) + else: + b[key] = value + + return b + + return _deep_merge(source, dict(destination)) + + +class Requests(object): + AUTO_APPROVE_THRESHOLD = 1000000 + + def __init__(self, db_session): + self.db_session = db_session + + def create(self, creator_id, body): + request = Request(creator=creator_id, body=body) + + status_event = RequestStatusEvent(new_status="incomplete") + request.status_events.append(status_event) + + self.db_session.add(request) + self.db_session.commit() + + return request + + def exists(self, request_id, creator_id): + return self.db_session.query( + exists().where( + and_(Request.id == request_id, Request.creator == creator_id) + ) + ).scalar() + + def get(self, request_id): + try: + request = self.db_session.query(Request).filter_by(id=request_id).one() + except NoResultFound: + raise NotFoundError("request") + + return request + + def get_many(self, creator_id=None): + filters = [] + if creator_id: + filters.append(Request.creator == creator_id) + + requests = ( + self.db_session.query(Request) + .filter(*filters) + .order_by(Request.time_created.desc()) + .all() + ) + return requests + + @tornado.gen.coroutine + def submit(self, request): + request.status_events.append(StatusEvent(new_status="submitted")) + + if Requests.should_auto_approve(request): + request.status_events.append(StatusEvent(new_status="approved")) + + self.db_session.add(request) + self.db_session.commit() + + return request + + @tornado.gen.coroutine + def update(self, request_id, request_delta): + try: + # Query for request matching id, acquiring a row-level write lock. + # https://www.postgresql.org/docs/10/static/sql-select.html#SQL-FOR-UPDATE-SHARE + request = ( + self.db_session.query(Request) + .filter_by(id=request_id) + .with_for_update(of=Request) + .one() + ) + except NoResultFound: + return + + request.body = deep_merge(request_delta, request.body) + + if Requests.should_allow_submission(request): + request.status_events.append(StatusEvent(new_status="pending_submission")) + + # Without this, sqlalchemy won't notice the change to request.body, + # since it doesn't track dictionary mutations by default. + flag_modified(request, "body") + + self.db_session.add(request) + self.db_session.commit() + + @classmethod + def should_auto_approve(cls, request): + try: + dollar_value = request.body["details_of_use"]["dollar_value"] + except KeyError: + return False + + return dollar_value < cls.AUTO_APPROVE_THRESHOLD + + @classmethod + def should_allow_submission(cls, request): + all_request_sections = [ + "details_of_use", + "information_about_you", + "primary_poc", + ] + existing_request_sections = request.body.keys() + return request.status == "incomplete" and all( + section in existing_request_sections for section in all_request_sections + ) diff --git a/atst/handlers/request.py b/atst/handlers/request.py index d9bec160..1373280a 100644 --- a/atst/handlers/request.py +++ b/atst/handlers/request.py @@ -2,16 +2,17 @@ import tornado import pendulum from atst.handler import BaseHandler +from atst.domain.requests import Requests def map_request(user, request): - time_created = pendulum.parse(request["time_created"]) + time_created = pendulum.instance(request.time_created) is_new = time_created.add(days=1) > pendulum.now() return { - "order_id": request["id"], + "order_id": request.id, "is_new": is_new, - "status": request["status"], + "status": request.status, "app_count": 1, "date": time_created.format("M/DD/YYYY"), "full_name": "{} {}".format(user["first_name"], user["last_name"]), @@ -19,9 +20,9 @@ def map_request(user, request): class Request(BaseHandler): - def initialize(self, page, requests_client): + def initialize(self, page, db_session): self.page = page - self.requests_client = requests_client + self.requests = Requests(db_session) @tornado.web.authenticated @tornado.gen.coroutine @@ -33,11 +34,10 @@ class Request(BaseHandler): @tornado.gen.coroutine def fetch_requests(self, user): + requests = [] if "review_and_approve_jedi_workspace_request" in user["atat_permissions"]: - response = yield self.requests_client.get("/requests") + requests = self.requests.get_many() else: - response = yield self.requests_client.get( - "/requests?creator_id={}".format(user["id"]) - ) + requests = self.requests.get_many(creator_id=user["id"]) - return response.json["requests"] + return requests diff --git a/atst/models/request.py b/atst/models/request.py index baaa92b3..42bec6ed 100644 --- a/atst/models/request.py +++ b/atst/models/request.py @@ -14,9 +14,9 @@ class Request(Base): creator = Column(UUID(as_uuid=True)) time_created = Column(DateTime(timezone=True), server_default=func.now()) body = Column(JSONB) - status_events = relationship('StatusEvent', + status_events = relationship('RequestStatusEvent', backref='request', - order_by='StatusEvent.sequence') + order_by='RequestStatusEvent.sequence') @property def status(self):