diff --git a/atst/domain/common/__init__.py b/atst/domain/common/__init__.py new file mode 100644 index 00000000..f829496f --- /dev/null +++ b/atst/domain/common/__init__.py @@ -0,0 +1 @@ +from .query import Query diff --git a/atst/domain/common/query.py b/atst/domain/common/query.py new file mode 100644 index 00000000..af4e069c --- /dev/null +++ b/atst/domain/common/query.py @@ -0,0 +1,38 @@ +from sqlalchemy.exc import DataError +from sqlalchemy.orm.exc import NoResultFound + +from atst.domain.exceptions import NotFoundError +from atst.database import db + + +class Query(object): + + model = None + + @property + def resource_name(cls): + return cls.model.__class__.lower() + + @classmethod + def create(cls, **kwargs): + # pylint: disable=E1102 + return cls.model(**kwargs) + + @classmethod + def get(cls, id_): + try: + resource = db.session.query(cls.model).filter_by(id=id_).one() + return resource + except (NoResultFound, DataError): + raise NotFoundError(cls.resource_name) + + @classmethod + def get_all(cls): + return db.session.query(cls.model).all() + + @classmethod + def add_and_commit(cls, resource): + db.session.add(resource) + db.session.commit() + return resource + diff --git a/atst/domain/requests/__init__.py b/atst/domain/requests/__init__.py new file mode 100644 index 00000000..88072d60 --- /dev/null +++ b/atst/domain/requests/__init__.py @@ -0,0 +1 @@ +from .requests import Requests, create_revision_from_request_body diff --git a/atst/domain/requests/query.py b/atst/domain/requests/query.py new file mode 100644 index 00000000..10e6e434 --- /dev/null +++ b/atst/domain/requests/query.py @@ -0,0 +1,71 @@ +from sqlalchemy import exists, and_, exc, text + +from atst.database import db +from atst.domain.common import Query +from atst.models.request import Request + + +class RequestsQuery(Query): + model = Request + + @classmethod + def exists(cls, request_id, creator): + try: + return db.session.query( + exists().where( + and_(Request.id == request_id, Request.creator == creator) + ) + ).scalar() + + except exc.DataError: + return False + + @classmethod + def get_many(cls, creator=None): + filters = [] + if creator: + filters.append(Request.creator == creator) + + requests = ( + db.session.query(Request) + .filter(*filters) + .order_by(Request.time_created.desc()) + .all() + ) + return requests + + @classmethod + def get_with_lock(cls, request_id): + try: + # Query for request matching id, acquiring a row-level write lock. + # https://www.postgresql.org/docs/10/static/sql-select.html#SQL-FOR-UPDATE-SHARE + return ( + db.session.query(Request) + .filter_by(id=request_id) + .with_for_update(of=Request) + .one() + ) + + except NoResultFound: + raise NotFoundError("requests") + + @classmethod + def status_count(cls, status, creator=None): + bindings = {"status": status.name} + raw = """ +SELECT count(requests_with_status.id) +FROM ( + SELECT DISTINCT ON (rse.request_id) r.*, rse.new_status as status + FROM request_status_events rse JOIN requests r ON r.id = rse.request_id + ORDER BY rse.request_id, rse.sequence DESC +) as requests_with_status +WHERE requests_with_status.status = :status + """ + + if creator: + raw += " AND requests_with_status.user_id = :user_id" + bindings["user_id"] = creator.id + + results = db.session.execute(text(raw), bindings).fetchone() + (count,) = results + return count diff --git a/atst/domain/requests.py b/atst/domain/requests/requests.py similarity index 71% rename from atst/domain/requests.py rename to atst/domain/requests/requests.py index c2026b3d..71245f23 100644 --- a/atst/domain/requests.py +++ b/atst/domain/requests/requests.py @@ -1,11 +1,6 @@ -from enum import Enum -from sqlalchemy import exists, and_, exc -from sqlalchemy.sql import text -from sqlalchemy.orm.exc import NoResultFound from werkzeug.datastructures import FileStorage import dateutil -from atst.database import db from atst.domain.authz import Authorization from atst.domain.task_orders import TaskOrders from atst.domain.workspaces import Workspaces @@ -16,7 +11,9 @@ from atst.models.request_review import RequestReview from atst.models.request_internal_comment import RequestInternalComment from atst.utils import deep_merge -from .exceptions import NotFoundError, UnauthorizedError +from atst.domain.exceptions import UnauthorizedError + +from .query import RequestsQuery def create_revision_from_request_body(body): @@ -38,38 +35,19 @@ class Requests(object): @classmethod def create(cls, creator, body): revision = create_revision_from_request_body(body) - request = Request(creator=creator, revisions=[revision]) + request = RequestsQuery.create(creator=creator, revisions=[revision]) request = Requests.set_status(request, RequestStatus.STARTED) - - db.session.add(request) - db.session.commit() + request = RequestsQuery.add_and_commit(request) return request @classmethod def exists(cls, request_id, creator): - try: - return db.session.query( - exists().where( - and_(Request.id == request_id, Request.creator == creator) - ) - ).scalar() - - except exc.DataError: - return False - - @classmethod - def _get(cls, user, request_id): - try: - request = db.session.query(Request).filter_by(id=request_id).one() - except (NoResultFound, exc.DataError): - raise NotFoundError("request") - - return request + return RequestsQuery.exists(request_id, creator) @classmethod def get(cls, user, request_id): - request = Requests._get(user, request_id) + request = RequestsQuery.get(request_id) if not Authorization.can_view_request(user, request): raise UnauthorizedError(user, "get request") @@ -78,7 +56,7 @@ class Requests(object): @classmethod def get_for_approval(cls, user, request_id): - request = Requests._get(user, request_id) + request = RequestsQuery.get(request_id) Authorization.check_can_approve_request(user) @@ -86,17 +64,7 @@ class Requests(object): @classmethod def get_many(cls, creator=None): - filters = [] - if creator: - filters.append(Request.creator == creator) - - requests = ( - db.session.query(Request) - .filter(*filters) - .order_by(Request.time_created.desc()) - .all() - ) - return requests + return RequestsQuery.get_many(creator) @classmethod def submit(cls, request): @@ -109,47 +77,28 @@ class Requests(object): new_status = RequestStatus.PENDING_CCPO_ACCEPTANCE request = Requests.set_status(request, new_status) - - db.session.add(request) - db.session.commit() + request = RequestsQuery.add_and_commit(request) return request @classmethod def update(cls, request_id, request_delta): - request = Requests._get_with_lock(request_id) + request = RequestsQuery.get_with_lock(request_id) new_body = deep_merge(request_delta, request.body) revision = create_revision_from_request_body(new_body) request.revisions.append(revision) - db.session.add(request) - db.session.commit() + request = RequestsQuery.add_and_commit(request) return request - @classmethod - def _get_with_lock(cls, request_id): - try: - # Query for request matching id, acquiring a row-level write lock. - # https://www.postgresql.org/docs/10/static/sql-select.html#SQL-FOR-UPDATE-SHARE - return ( - db.session.query(Request) - .filter_by(id=request_id) - .with_for_update(of=Request) - .one() - ) - - except NoResultFound: - raise NotFoundError() - @classmethod def approve_and_create_workspace(cls, request): approved_request = Requests.set_status(request, RequestStatus.APPROVED) workspace = Workspaces.create(approved_request) - db.session.add(approved_request) - db.session.commit() + RequestsQuery.add_and_commit(approved_request) return workspace @@ -205,26 +154,7 @@ class Requests(object): @classmethod def status_count(cls, status, creator=None): - if isinstance(status, Enum): - status = status.name - bindings = {"status": status} - raw = """ -SELECT count(requests_with_status.id) -FROM ( - SELECT DISTINCT ON (rse.request_id) r.*, rse.new_status as status - FROM request_status_events rse JOIN requests r ON r.id = rse.request_id - ORDER BY rse.request_id, rse.sequence DESC -) as requests_with_status -WHERE requests_with_status.status = :status - """ - - if creator: - raw += " AND requests_with_status.user_id = :user_id" - bindings["user_id"] = creator.id - - results = db.session.execute(text(raw), bindings).fetchone() - (count,) = results - return count + return RequestsQuery.status_count(status, creator) @classmethod def in_progress_count(cls): @@ -251,7 +181,7 @@ WHERE requests_with_status.status = :status @classmethod def update_financial_verification(cls, request_id, financial_data): - request = Requests._get_with_lock(request_id) + request = RequestsQuery.get_with_lock(request_id) request_data = financial_data.copy() task_order_data = { @@ -283,20 +213,14 @@ WHERE requests_with_status.status = :status @classmethod def submit_financial_verification(cls, request): - Requests.set_status(request, RequestStatus.PENDING_CCPO_APPROVAL) - - db.session.add(request) - db.session.commit() - + request = Requests.set_status(request, RequestStatus.PENDING_CCPO_APPROVAL) + request = RequestsQuery.add_and_commit(request) return request @classmethod def _add_review(cls, user, request, review_data): request.latest_status.review = RequestReview(reviewer=user, **review_data) - - db.session.add(request) - db.session.commit() - + request = RequestsQuery.add_and_commit(request) return request @classmethod @@ -320,9 +244,6 @@ WHERE requests_with_status.status = :status @classmethod def update_internal_comments(cls, user, request, comment_text): Authorization.check_can_approve_request(user) - request.internal_comments = RequestInternalComment(text=comment_text, user=user) - db.session.add(request) - db.session.commit() - + request = RequestsQuery.add_and_commit(request) return request diff --git a/atst/domain/workspaces/query.py b/atst/domain/workspaces/query.py index 69dfe992..0fb94d97 100644 --- a/atst/domain/workspaces/query.py +++ b/atst/domain/workspaces/query.py @@ -1,43 +1,12 @@ from sqlalchemy.orm.exc import NoResultFound from atst.database import db +from atst.domain.common import Query from atst.domain.exceptions import NotFoundError from atst.models.workspace import Workspace from atst.models.workspace_role import WorkspaceRole -class Query(object): - - model = None - - @property - def resource_name(cls): - return cls.model.__class__.lower() - - @classmethod - def create(cls, **kwargs): - # pylint: disable=E1102 - return cls.model(**kwargs) - - @classmethod - def get(cls, id_): - try: - resource = db.session.query(cls.model).filter_by(id=id_).one() - return resource - except NoResultFound: - raise NotFoundError(cls.resource_name) - - @classmethod - def get_all(cls): - return db.session.query(cls.model).all() - - @classmethod - def add_and_commit(cls, resource): - db.session.add(resource) - db.session.commit() - return resource - - class WorkspaceQuery(Query): model = Workspace diff --git a/tests/models/test_requests.py b/tests/models/test_requests.py index f1409481..7f5a3650 100644 --- a/tests/models/test_requests.py +++ b/tests/models/test_requests.py @@ -4,7 +4,8 @@ from tests.factories import ( RequestStatusEventFactory, RequestReviewFactory, ) -from atst.domain.requests import Requests, RequestStatus +from atst.domain.requests import Requests +from atst.models.request_status_event import RequestStatus def test_pending_financial_requires_mo_action():