From ad1e1e771bd0fa29e28d730bace7dd623703b9bc Mon Sep 17 00:00:00 2001 From: dandds Date: Mon, 6 Aug 2018 09:18:53 -0400 Subject: [PATCH] extract get_current_user, fix tests --- atst/domain/auth.py | 12 ++++++++++-- tests/mocks.py | 6 ++++++ tests/test_auth.py | 7 +------ tests/test_routes.py | 16 +++++++++------- 4 files changed, 26 insertions(+), 15 deletions(-) diff --git a/atst/domain/auth.py b/atst/domain/auth.py index e14f2aa2..d1a8ec36 100644 --- a/atst/domain/auth.py +++ b/atst/domain/auth.py @@ -8,11 +8,19 @@ def login_required(f): @wraps(f) def decorated_function(*args, **kwargs): - if session.get("user_id"): - g.user = Users.get(session.get("user_id")) + user = get_current_user() + if user: + g.user = user return f(*args, **kwargs) else: return redirect(url_for("atst.root")) return decorated_function + +def get_current_user(): + user_id = session.get("user_id") + if user_id: + return Users.get(user_id) + else: + return False diff --git a/tests/mocks.py b/tests/mocks.py index c8903099..c39e6de6 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -19,6 +19,12 @@ MOCK_REQUEST = RequestFactory.create( }, } ) +DOD_SDN_INFO = { + 'first_name': 'ART', + 'last_name': 'GARFUNKEL', + 'dod_id': '5892460358' + } +DOD_SDN = f"CN={DOD_SDN_INFO['last_name']}.{DOD_SDN_INFO['first_name']}.G.{DOD_SDN_INFO['dod_id']},OU=OTHER,OU=PKI,OU=DoD,O=U.S. Government,C=US" class MockApiClient(ApiClient): diff --git a/tests/test_auth.py b/tests/test_auth.py index fb036869..8d6c2e60 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,13 +1,8 @@ from flask import session +from .mocks import DOD_SDN MOCK_USER = {"id": "438567dd-25fa-4d83-a8cc-8aa8366cb24a"} -DOD_SDN_INFO = { - 'first_name': 'ART', - 'last_name': 'GARFUNKEL', - 'dod_id': '5892460358' - } -DOD_SDN = f"CN={DOD_SDN_INFO['last_name']}.{DOD_SDN_INFO['first_name']}.G.{DOD_SDN_INFO['dod_id']},OU=OTHER,OU=PKI,OU=DoD,O=U.S. Government,C=US" def _fetch_user_info(c, t): diff --git a/tests/test_routes.py b/tests/test_routes.py index 71a4f9bc..1c050c84 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -1,5 +1,6 @@ -def test_routes(client): - for path in ( +import pytest + +@pytest.mark.parametrize("path", ( "/", "/home", "/workspaces", @@ -9,8 +10,9 @@ def test_routes(client): "/users", "/reports", "/calculator", - ): - response = client.get(path) - if response.status_code == 404: - __import__('ipdb').set_trace() - assert response.status_code == 200 + )) +def test_routes(path, client, monkeypatch): + monkeypatch.setattr("atst.domain.auth.get_current_user", lambda *args: True) + + response = client.get(path) + assert response.status_code == 200