diff --git a/atst/domain/auth.py b/atst/domain/auth.py index 74c6d2d8..9af97dd7 100644 --- a/atst/domain/auth.py +++ b/atst/domain/auth.py @@ -17,14 +17,11 @@ def apply_authentication(app): @app.before_request # pylint: disable=unused-variable def enforce_login(): - - if not _unprotected_route(request): - user = get_current_user() - if user: - g.current_user = user - - else: - return redirect(url_for("atst.root")) + user = get_current_user() + if user: + g.current_user = user + elif not _unprotected_route(request): + return redirect(url_for("atst.root")) def get_current_user(): diff --git a/tests/test_auth.py b/tests/test_auth.py index 6ecd1ee1..47d7c6e1 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -73,25 +73,42 @@ def is_unprotected(rule): return rule.endpoint in UNPROTECTED_ROUTES -def test_routes_are_protected(client, app): +def protected_routes(app): for rule in app.url_map.iter_rules(): args = [1] * len(rule.arguments) mock_args = dict(zip(rule.arguments, args)) _n, route = rule.build(mock_args) if is_unprotected(rule) or "/static" in route: continue + yield rule, route + +def test_protected_routes_redirect_to_login(client, app): + for rule, protected_route in protected_routes(app): if "GET" in rule.methods: - resp = client.get(route) + resp = client.get(protected_route) assert resp.status_code == 302 assert resp.headers["Location"] == "http://localhost/" if "POST" in rule.methods: - resp = client.post(route) + resp = client.post(protected_route) assert resp.status_code == 302 assert resp.headers["Location"] == "http://localhost/" +def test_unprotected_routes_set_user_if_logged_in(client, app, user_session): + user = UserFactory.create() + + resp = client.get(url_for("atst.helpdocs")) + assert resp.status_code == 200 + assert user.full_name not in resp.data.decode() + + user_session(user) + resp = client.get(url_for("atst.helpdocs")) + assert resp.status_code == 200 + assert user.full_name in resp.data.decode() + + # this implicitly relies on the test config and test CRL in tests/fixtures/crl