Set current user for public routes as well
This commit is contained in:
parent
0a8c488f1e
commit
a27c1b5712
@ -17,14 +17,11 @@ def apply_authentication(app):
|
|||||||
@app.before_request
|
@app.before_request
|
||||||
# pylint: disable=unused-variable
|
# pylint: disable=unused-variable
|
||||||
def enforce_login():
|
def enforce_login():
|
||||||
|
user = get_current_user()
|
||||||
if not _unprotected_route(request):
|
if user:
|
||||||
user = get_current_user()
|
g.current_user = user
|
||||||
if user:
|
elif not _unprotected_route(request):
|
||||||
g.current_user = user
|
return redirect(url_for("atst.root"))
|
||||||
|
|
||||||
else:
|
|
||||||
return redirect(url_for("atst.root"))
|
|
||||||
|
|
||||||
|
|
||||||
def get_current_user():
|
def get_current_user():
|
||||||
|
@ -73,25 +73,42 @@ def is_unprotected(rule):
|
|||||||
return rule.endpoint in UNPROTECTED_ROUTES
|
return rule.endpoint in UNPROTECTED_ROUTES
|
||||||
|
|
||||||
|
|
||||||
def test_routes_are_protected(client, app):
|
def protected_routes(app):
|
||||||
for rule in app.url_map.iter_rules():
|
for rule in app.url_map.iter_rules():
|
||||||
args = [1] * len(rule.arguments)
|
args = [1] * len(rule.arguments)
|
||||||
mock_args = dict(zip(rule.arguments, args))
|
mock_args = dict(zip(rule.arguments, args))
|
||||||
_n, route = rule.build(mock_args)
|
_n, route = rule.build(mock_args)
|
||||||
if is_unprotected(rule) or "/static" in route:
|
if is_unprotected(rule) or "/static" in route:
|
||||||
continue
|
continue
|
||||||
|
yield rule, route
|
||||||
|
|
||||||
|
|
||||||
|
def test_protected_routes_redirect_to_login(client, app):
|
||||||
|
for rule, protected_route in protected_routes(app):
|
||||||
if "GET" in rule.methods:
|
if "GET" in rule.methods:
|
||||||
resp = client.get(route)
|
resp = client.get(protected_route)
|
||||||
assert resp.status_code == 302
|
assert resp.status_code == 302
|
||||||
assert resp.headers["Location"] == "http://localhost/"
|
assert resp.headers["Location"] == "http://localhost/"
|
||||||
|
|
||||||
if "POST" in rule.methods:
|
if "POST" in rule.methods:
|
||||||
resp = client.post(route)
|
resp = client.post(protected_route)
|
||||||
assert resp.status_code == 302
|
assert resp.status_code == 302
|
||||||
assert resp.headers["Location"] == "http://localhost/"
|
assert resp.headers["Location"] == "http://localhost/"
|
||||||
|
|
||||||
|
|
||||||
|
def test_unprotected_routes_set_user_if_logged_in(client, app, user_session):
|
||||||
|
user = UserFactory.create()
|
||||||
|
|
||||||
|
resp = client.get(url_for("atst.helpdocs"))
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert user.full_name not in resp.data.decode()
|
||||||
|
|
||||||
|
user_session(user)
|
||||||
|
resp = client.get(url_for("atst.helpdocs"))
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert user.full_name in resp.data.decode()
|
||||||
|
|
||||||
|
|
||||||
# this implicitly relies on the test config and test CRL in tests/fixtures/crl
|
# this implicitly relies on the test config and test CRL in tests/fixtures/crl
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user