diff --git a/atst/domain/authz.py b/atst/domain/authz.py index 506970b8..20bd539f 100644 --- a/atst/domain/authz.py +++ b/atst/domain/authz.py @@ -1,4 +1,5 @@ from atst.domain.workspace_users import WorkspaceUsers +from atst.models.permissions import Permissions class Authorization(object): @@ -10,3 +11,15 @@ class Authorization(object): @classmethod def is_in_workspace(cls, user, workspace): return user in workspace.users + + @classmethod + def can_view_request(cls, user, request): + if ( + Permissions.REVIEW_AND_APPROVE_JEDI_WORKSPACE_REQUEST + in user.atat_permissions + ): + return True + elif request.creator == user: + return True + + return False diff --git a/atst/domain/requests.py b/atst/domain/requests.py index 57c69d4c..3169a95e 100644 --- a/atst/domain/requests.py +++ b/atst/domain/requests.py @@ -5,13 +5,14 @@ from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.orm.attributes import flag_modified from werkzeug.datastructures import FileStorage +from atst.database import db +from atst.domain.authz import Authorization +from atst.domain.task_orders import TaskOrders +from atst.domain.workspaces import Workspaces from atst.models.request import Request from atst.models.request_status_event import RequestStatusEvent, RequestStatus -from atst.domain.workspaces import Workspaces -from atst.database import db -from atst.domain.task_orders import TaskOrders -from .exceptions import NotFoundError +from .exceptions import NotFoundError, UnauthorizedError def deep_merge(source, destination: dict): @@ -59,12 +60,15 @@ class Requests(object): return False @classmethod - def get(cls, request_id): + def get(cls, user, request_id): try: request = db.session.query(Request).filter_by(id=request_id).one() except NoResultFound: raise NotFoundError("request") + if not Authorization.can_view_request(user, request): + raise UnauthorizedError(user, "get request") + return request @classmethod diff --git a/atst/routes/requests/financial_verification.py b/atst/routes/requests/financial_verification.py index 17137387..fc22f8f4 100644 --- a/atst/routes/requests/financial_verification.py +++ b/atst/routes/requests/financial_verification.py @@ -1,4 +1,4 @@ -from flask import render_template, redirect, url_for +from flask import g, render_template, redirect, url_for from flask import request as http_request from . import requests_bp @@ -15,7 +15,7 @@ def financial_form(data): @requests_bp.route("/requests/verify/", methods=["GET"]) def financial_verification(request_id=None): - request = Requests.get(request_id) + request = Requests.get(g.current_user, request_id) form = financial_form(request.body.get("financial_verification")) return render_template( "requests/financial_verification.html", @@ -28,7 +28,7 @@ def financial_verification(request_id=None): @requests_bp.route("/requests/verify/", methods=["POST"]) def update_financial_verification(request_id): post_data = http_request.form - existing_request = Requests.get(request_id) + existing_request = Requests.get(g.current_user, request_id) form = financial_form(post_data) rerender_args = dict( request_id=request_id, f=form, extended=http_request.args.get("extended") diff --git a/atst/routes/requests/requests_form.py b/atst/routes/requests/requests_form.py index b4af5c73..08a40e20 100644 --- a/atst/routes/requests/requests_form.py +++ b/atst/routes/requests/requests_form.py @@ -46,7 +46,7 @@ def requests_form_update(screen=1, request_id=None): if request_id: _check_can_view_request(request_id) - request = Requests.get(request_id) if request_id is not None else None + request = Requests.get(g.current_user, request_id) if request_id is not None else None jedi_flow = JEDIRequestFlow( screen, request=request, request_id=request_id, current_user=g.current_user ) @@ -72,7 +72,7 @@ def requests_update(screen=1, request_id=None): screen = int(screen) post_data = http_request.form current_user = g.current_user - existing_request = Requests.get(request_id) if request_id is not None else None + existing_request = Requests.get(g.current_user, request_id) if request_id is not None else None jedi_flow = JEDIRequestFlow( screen, post_data=post_data, @@ -110,7 +110,7 @@ def requests_update(screen=1, request_id=None): @requests_bp.route("/requests/submit/", methods=["POST"]) def requests_submit(request_id=None): - request = Requests.get(request_id) + request = Requests.get(g.current_user, request_id) Requests.submit(request) if request.status == RequestStatus.PENDING_FINANCIAL_VERIFICATION: @@ -122,7 +122,7 @@ def requests_submit(request_id=None): @requests_bp.route("/requests/pending/", methods=["GET"]) def view_pending_request(request_id=None): - request = Requests.get(request_id) + request = Requests.get(g.current_user, request_id) return render_template("requests/view_pending.html", data=request.body) diff --git a/tests/domain/test_requests.py b/tests/domain/test_requests.py index 7d4e4f3e..38a76e93 100644 --- a/tests/domain/test_requests.py +++ b/tests/domain/test_requests.py @@ -21,14 +21,15 @@ def new_request(session): def test_can_get_request(new_request): - request = Requests.get(new_request.id) + request = Requests.get(new_request.creator, new_request.id) assert request.id == new_request.id def test_nonexistent_request_raises(): + a_user = UserFactory.build() with pytest.raises(NotFoundError): - Requests.get(uuid4()) + Requests.get(a_user, uuid4()) def test_new_request_has_started_status(): diff --git a/tests/routes/test_financial_verification.py b/tests/routes/test_financial_verification.py index f8a89691..6163d5b0 100644 --- a/tests/routes/test_financial_verification.py +++ b/tests/routes/test_financial_verification.py @@ -5,7 +5,7 @@ from flask import url_for from atst.eda_client import MockEDAClient from tests.mocks import MOCK_REQUEST, MOCK_USER -from tests.factories import PENumberFactory, RequestFactory +from tests.factories import PENumberFactory, RequestFactory, UserFactory class TestPENumberInForm: @@ -30,12 +30,14 @@ class TestPENumberInForm: monkeypatch.setattr( "atst.forms.financial.FinancialForm.validate", lambda s: True ) + user = UserFactory.create() monkeypatch.setattr( - "atst.domain.auth.get_current_user", lambda *args: MOCK_USER + "atst.domain.auth.get_current_user", lambda *args: user ) + return user - def submit_data(self, client, data, extended=False): - request = RequestFactory.create(body=MOCK_REQUEST.body) + def submit_data(self, client, user, data, extended=False): + request = RequestFactory.create(creator=user, body=MOCK_REQUEST.body) url_kwargs = {"request_id": request.id} if extended: url_kwargs["extended"] = True @@ -47,43 +49,43 @@ class TestPENumberInForm: return response def test_submit_request_form_with_invalid_pe_id(self, monkeypatch, client): - self._set_monkeypatches(monkeypatch) + user = self._set_monkeypatches(monkeypatch) - response = self.submit_data(client, self.required_data) + response = self.submit_data(client, user, self.required_data) assert "We couldn't find that PE number" in response.data.decode() assert response.status_code == 200 def test_submit_request_form_with_unchanged_pe_id(self, monkeypatch, client): - self._set_monkeypatches(monkeypatch) + user = self._set_monkeypatches(monkeypatch) data = dict(self.required_data) data["pe_id"] = MOCK_REQUEST.body["financial_verification"]["pe_id"] - response = self.submit_data(client, data) + response = self.submit_data(client, user, data) assert response.status_code == 302 assert "/workspaces" in response.headers.get("Location") def test_submit_request_form_with_new_valid_pe_id(self, monkeypatch, client): - self._set_monkeypatches(monkeypatch) + user = self._set_monkeypatches(monkeypatch) pe = PENumberFactory.create(number="8675309U", description="sample PE number") data = dict(self.required_data) data["pe_id"] = pe.number - response = self.submit_data(client, data) + response = self.submit_data(client, user, data) assert response.status_code == 302 assert "/workspaces" in response.headers.get("Location") def test_submit_request_form_with_missing_pe_id(self, monkeypatch, client): - self._set_monkeypatches(monkeypatch) + user = self._set_monkeypatches(monkeypatch) data = dict(self.required_data) data["pe_id"] = "" - response = self.submit_data(client, data) + response = self.submit_data(client, user, data) assert "There were some errors" in response.data.decode() assert response.status_code == 200 @@ -91,41 +93,44 @@ class TestPENumberInForm: def test_submit_financial_form_with_invalid_task_order( self, monkeypatch, user_session, client ): - user_session() + user = UserFactory.create() + user_session(user) data = dict(self.required_data) data["pe_id"] = MOCK_REQUEST.body["financial_verification"]["pe_id"] data["task_order_number"] = "1234" - response = self.submit_data(client, data) + response = self.submit_data(client, user, data) assert "enter TO information manually" in response.data.decode() def test_submit_financial_form_with_valid_task_order( self, monkeypatch, user_session, client ): - monkeypatch.setattr("atst.domain.requests.Requests.get", lambda i: MOCK_REQUEST) - user_session() + user = UserFactory.create() + monkeypatch.setattr("atst.domain.requests.Requests.get", lambda *args: MOCK_REQUEST) + user_session(user) data = dict(self.required_data) data["pe_id"] = MOCK_REQUEST.body["financial_verification"]["pe_id"] data["task_order_number"] = MockEDAClient.MOCK_CONTRACT_NUMBER - response = self.submit_data(client, data) + response = self.submit_data(client, user, data) assert "enter TO information manually" not in response.data.decode() def test_submit_extended_financial_form( self, monkeypatch, user_session, client, extended_financial_verification_data ): - request = RequestFactory.create() - monkeypatch.setattr("atst.domain.requests.Requests.get", lambda i: request) + user = UserFactory.create() + request = RequestFactory.create(creator=user) + monkeypatch.setattr("atst.domain.requests.Requests.get", lambda *args: request) monkeypatch.setattr("atst.forms.financial.validate_pe_id", lambda *args: True) user_session() data = {**self.required_data, **extended_financial_verification_data} data["task_order_number"] = "1234567" - response = self.submit_data(client, data, extended=True) + response = self.submit_data(client, user, data, extended=True) assert response.status_code == 302 assert "/projects/new" in response.headers.get("Location") @@ -134,11 +139,12 @@ class TestPENumberInForm: self, monkeypatch, user_session, client, extended_financial_verification_data ): monkeypatch.setattr("atst.forms.financial.validate_pe_id", lambda *args: True) - user_session() + user = UserFactory.create() + user_session(user) data = {**self.required_data, **extended_financial_verification_data} data["task_order_number"] = "1234567" del (data["clin_0001"]) - response = self.submit_data(client, data, extended=True) + response = self.submit_data(client, user, data, extended=True) assert response.status_code == 200 diff --git a/tests/routes/test_request_new.py b/tests/routes/test_request_new.py index 3687900f..1b41eef5 100644 --- a/tests/routes/test_request_new.py +++ b/tests/routes/test_request_new.py @@ -122,7 +122,7 @@ def test_am_poc_causes_poc_to_be_autopopulated(client, user_session): headers={"Content-Type": "application/x-www-form-urlencoded"}, data="am_poc=yes", ) - request = Requests.get(request.id) + request = Requests.get(creator, request.id) assert request.body["primary_poc"]["dodid_poc"] == creator.dod_id @@ -167,7 +167,7 @@ def test_poc_details_can_be_autopopulated_on_new_request(client, user_session): data="am_poc=yes", ) request_id = response.headers["Location"].split("/")[-1] - request = Requests.get(request_id) + request = Requests.get(creator, request_id) assert request.body["primary_poc"]["dodid_poc"] == creator.dod_id @@ -191,7 +191,7 @@ def test_poc_autofill_checks_information_about_you_form_first(client, user_sessi headers={"Content-Type": "application/x-www-form-urlencoded"}, data=urlencode(poc_input), ) - request = Requests.get(request.id) + request = Requests.get(creator, request.id) assert dict_contains( request.body["primary_poc"], { diff --git a/tests/test_integration.py b/tests/test_integration.py index 633f14b1..8b674f9b 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -55,7 +55,7 @@ def test_stepthrough_request_form(user_session, screens, client): # at this point, the real request we made and the mock_request bodies # should be equivalent - assert Requests.get(req_id).body == mock_request.body + assert Requests.get(user, req_id).body == mock_request.body # finish the review and submit step client.post( @@ -63,5 +63,5 @@ def test_stepthrough_request_form(user_session, screens, client): headers={"Content-Type": "application/x-www-form-urlencoded"}, ) - finished_request = Requests.get(req_id) + finished_request = Requests.get(user, req_id) assert finished_request.status == RequestStatus.PENDING_CCPO_APPROVAL