From 1a5800cbc5493ea58ab7455ceb2d7756420674d4 Mon Sep 17 00:00:00 2001 From: dandds Date: Tue, 7 Aug 2018 15:40:51 -0400 Subject: [PATCH] Requests domain module can determine if user can view request --- atst/domain/requests.py | 8 ++++++++ tests/domain/test_requests.py | 10 +++++++++- tests/factories.py | 9 ++++++--- 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/atst/domain/requests.py b/atst/domain/requests.py index e7b64e5d..843c7243 100644 --- a/atst/domain/requests.py +++ b/atst/domain/requests.py @@ -143,3 +143,11 @@ class Requests(object): return request.status == "incomplete" and all( section in existing_request_sections for section in all_request_sections ) + + @classmethod + def is_creator(cls, request_id, user_id): + try: + db.session.query(Request).filter_by(id=request_id, creator=user_id).one() + return True + except NoResultFound: + return False diff --git a/tests/domain/test_requests.py b/tests/domain/test_requests.py index 1d696097..609dae90 100644 --- a/tests/domain/test_requests.py +++ b/tests/domain/test_requests.py @@ -5,7 +5,7 @@ from atst.domain.exceptions import NotFoundError from atst.domain.requests import Requests from atst.models.request_status_event import RequestStatus -from tests.factories import RequestFactory +from tests.factories import RequestFactory, UserFactory @pytest.fixture(scope="function") @@ -48,3 +48,11 @@ def test_dont_auto_approve_if_no_dollar_value_specified(new_request): request = Requests.submit(new_request) assert request.status == RequestStatus.PENDING_CCPO_APPROVAL + + +def test_can_check_if_user_created_request(session): + user_allowed = UserFactory.create() + user_denied = UserFactory.create() + request = RequestFactory.create(creator=user_allowed.id) + assert Requests.is_creator(request.id, user_allowed.id) + assert not Requests.is_creator(request.id, user_denied.id) diff --git a/tests/factories.py b/tests/factories.py index dc68eb5c..d731c42f 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -1,3 +1,5 @@ +import random +import string import factory from uuid import uuid4 @@ -46,7 +48,8 @@ class UserFactory(factory.alchemy.SQLAlchemyModelFactory): model = User id = factory.Sequence(lambda x: uuid4()) - email = "fake.user@mail.com" - first_name = "Fake" - last_name = "User" + email = factory.Faker("email") + first_name = factory.Faker("first_name") + last_name = factory.Faker("last_name") atat_role = factory.SubFactory(RoleFactory) + dod_id = factory.LazyFunction(lambda: "".join(random.choices(string.digits, k=10)))