Set current user for public routes as well

This commit is contained in:
Patrick Smith 2018-09-26 16:50:10 -04:00
parent 0a8c488f1e
commit a27c1b5712
2 changed files with 25 additions and 11 deletions

View File

@ -17,14 +17,11 @@ def apply_authentication(app):
@app.before_request @app.before_request
# pylint: disable=unused-variable # pylint: disable=unused-variable
def enforce_login(): def enforce_login():
user = get_current_user()
if not _unprotected_route(request): if user:
user = get_current_user() g.current_user = user
if user: elif not _unprotected_route(request):
g.current_user = user return redirect(url_for("atst.root"))
else:
return redirect(url_for("atst.root"))
def get_current_user(): def get_current_user():

View File

@ -73,25 +73,42 @@ def is_unprotected(rule):
return rule.endpoint in UNPROTECTED_ROUTES return rule.endpoint in UNPROTECTED_ROUTES
def test_routes_are_protected(client, app): def protected_routes(app):
for rule in app.url_map.iter_rules(): for rule in app.url_map.iter_rules():
args = [1] * len(rule.arguments) args = [1] * len(rule.arguments)
mock_args = dict(zip(rule.arguments, args)) mock_args = dict(zip(rule.arguments, args))
_n, route = rule.build(mock_args) _n, route = rule.build(mock_args)
if is_unprotected(rule) or "/static" in route: if is_unprotected(rule) or "/static" in route:
continue continue
yield rule, route
def test_protected_routes_redirect_to_login(client, app):
for rule, protected_route in protected_routes(app):
if "GET" in rule.methods: if "GET" in rule.methods:
resp = client.get(route) resp = client.get(protected_route)
assert resp.status_code == 302 assert resp.status_code == 302
assert resp.headers["Location"] == "http://localhost/" assert resp.headers["Location"] == "http://localhost/"
if "POST" in rule.methods: if "POST" in rule.methods:
resp = client.post(route) resp = client.post(protected_route)
assert resp.status_code == 302 assert resp.status_code == 302
assert resp.headers["Location"] == "http://localhost/" assert resp.headers["Location"] == "http://localhost/"
def test_unprotected_routes_set_user_if_logged_in(client, app, user_session):
user = UserFactory.create()
resp = client.get(url_for("atst.helpdocs"))
assert resp.status_code == 200
assert user.full_name not in resp.data.decode()
user_session(user)
resp = client.get(url_for("atst.helpdocs"))
assert resp.status_code == 200
assert user.full_name in resp.data.decode()
# this implicitly relies on the test config and test CRL in tests/fixtures/crl # this implicitly relies on the test config and test CRL in tests/fixtures/crl