diff --git a/atst/domain/auth.py b/atst/domain/auth.py index 9af97dd7..999553dc 100644 --- a/atst/domain/auth.py +++ b/atst/domain/auth.py @@ -21,7 +21,7 @@ def apply_authentication(app): if user: g.current_user = user elif not _unprotected_route(request): - return redirect(url_for("atst.root")) + return redirect(url_for("atst.root", next=request.path)) def get_current_user(): diff --git a/atst/routes/__init__.py b/atst/routes/__init__.py index 4edbdb57..fd5b213a 100644 --- a/atst/routes/__init__.py +++ b/atst/routes/__init__.py @@ -1,4 +1,5 @@ from flask import Blueprint, render_template, g, redirect, session, url_for, request + from flask import current_app as app import pendulum diff --git a/tests/test_auth.py b/tests/test_auth.py index 47d7c6e1..85bc62e0 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -88,12 +88,19 @@ def test_protected_routes_redirect_to_login(client, app): if "GET" in rule.methods: resp = client.get(protected_route) assert resp.status_code == 302 - assert resp.headers["Location"] == "http://localhost/" + assert "http://localhost/" in resp.headers["Location"] if "POST" in rule.methods: resp = client.post(protected_route) assert resp.status_code == 302 - assert resp.headers["Location"] == "http://localhost/" + assert "http://localhost/" in resp.headers["Location"] + + +def test_get_protected_route_encodes_redirect(client): + workspace_index = url_for("workspaces.workspaces") + response = client.get(workspace_index) + redirect = url_for("atst.root", next=workspace_index) + assert redirect in response.headers["Location"] def test_unprotected_routes_set_user_if_logged_in(client, app, user_session):