diff --git a/atst/domain/requests.py b/atst/domain/requests.py index fcdd78bb..85fb688a 100644 --- a/atst/domain/requests.py +++ b/atst/domain/requests.py @@ -31,8 +31,8 @@ class Requests(object): AUTO_APPROVE_THRESHOLD = 1000000 @classmethod - def create(cls, creator_id, body): - request = Request(creator=creator_id, body=body) + def create(cls, creator, body): + request = Request(creator=creator, body=body) request = Requests.set_status(request, RequestStatus.STARTED) db.session.add(request) @@ -41,11 +41,11 @@ class Requests(object): return request @classmethod - def exists(cls, request_id, creator_id): + def exists(cls, request_id, creator): try: return db.session.query( exists().where( - and_(Request.id == request_id, Request.creator == creator_id) + and_(Request.id == request_id, Request.creator == creator) ) ).scalar() except exc.DataError: diff --git a/atst/routes/requests/requests_form.py b/atst/routes/requests/requests_form.py index ad415775..d384abdf 100644 --- a/atst/routes/requests/requests_form.py +++ b/atst/routes/requests/requests_form.py @@ -111,7 +111,7 @@ def requests_submit(request_id=None): def _check_can_view_request(request_id): if Permissions.REVIEW_AND_APPROVE_JEDI_WORKSPACE_REQUEST in g.current_user.atat_permissions: pass - elif Requests.exists(request_id, g.current_user.id): + elif Requests.exists(request_id, g.current_user): pass else: raise UnauthorizedError(g.current_user, "view request {}".format(request_id)) diff --git a/tests/conftest.py b/tests/conftest.py index c3add296..ab912679 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,6 @@ import alembic.command from atst.app import make_app, make_config from atst.database import db as _db -from .mocks import MOCK_USER import tests.factories as factories @@ -75,7 +74,6 @@ class DummyForm(dict): class DummyField(object): - def __init__(self, data=None, errors=(), raw_data=None): self.data = data self.errors = list(errors) @@ -93,9 +91,11 @@ def dummy_field(): @pytest.fixture -def user_session(monkeypatch): - - def set_user_session(user=MOCK_USER): - monkeypatch.setattr("atst.domain.auth.get_current_user", lambda *args: user) +def user_session(monkeypatch, session): + def set_user_session(user=None): + monkeypatch.setattr( + "atst.domain.auth.get_current_user", + lambda *args: user or factories.UserFactory.build(), + ) return set_user_session diff --git a/tests/domain/test_requests.py b/tests/domain/test_requests.py index fad23c51..ebdc64a0 100644 --- a/tests/domain/test_requests.py +++ b/tests/domain/test_requests.py @@ -53,6 +53,6 @@ def test_dont_auto_approve_if_no_dollar_value_specified(new_request): def test_exists(session): user_allowed = UserFactory.create() user_denied = UserFactory.create() - request = RequestFactory.create(creator=user_allowed.id) - assert Requests.exists(request.id, user_allowed.id) - assert not Requests.exists(request.id, user_denied.id) + request = RequestFactory.create(creator=user_allowed) + assert Requests.exists(request.id, user_allowed) + assert not Requests.exists(request.id, user_denied) diff --git a/tests/factories.py b/tests/factories.py index 68209045..cf81fa71 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -20,7 +20,7 @@ class RoleFactory(factory.alchemy.SQLAlchemyModelFactory): permissions = [] - + class UserFactory(factory.alchemy.SQLAlchemyModelFactory): class Meta: model = User @@ -47,7 +47,7 @@ class RequestFactory(factory.alchemy.SQLAlchemyModelFactory): id = factory.Sequence(lambda x: uuid4()) status_events = factory.RelatedFactory( - RequestStatusFactory, "request", new_status=RequestStatus.STARTED + RequestStatusEventFactory, "request", new_status=RequestStatus.STARTED ) creator = factory.SubFactory(UserFactory) body = factory.LazyAttribute(lambda r: RequestFactory.build_request_body(r.creator)) diff --git a/tests/routes/test_request_new.py b/tests/routes/test_request_new.py index e402e622..e31aae79 100644 --- a/tests/routes/test_request_new.py +++ b/tests/routes/test_request_new.py @@ -33,7 +33,7 @@ def test_submit_valid_request_form(monkeypatch, client, user_session): def test_owner_can_view_request(client, user_session): user = UserFactory.create() user_session(user) - request = RequestFactory.create(creator=user.id) + request = RequestFactory.create(creator=user) response = client.get("/requests/new/1/{}".format(request.id), follow_redirects=True)