diff --git a/atst/domain/auth.py b/atst/domain/auth.py index 9af97dd7..999553dc 100644 --- a/atst/domain/auth.py +++ b/atst/domain/auth.py @@ -21,7 +21,7 @@ def apply_authentication(app): if user: g.current_user = user elif not _unprotected_route(request): - return redirect(url_for("atst.root")) + return redirect(url_for("atst.root", next=request.path)) def get_current_user(): diff --git a/atst/routes/__init__.py b/atst/routes/__init__.py index 4edbdb57..d13bdb36 100644 --- a/atst/routes/__init__.py +++ b/atst/routes/__init__.py @@ -1,4 +1,6 @@ +import urllib.parse as url from flask import Blueprint, render_template, g, redirect, session, url_for, request + from flask import current_app as app import pendulum @@ -15,7 +17,16 @@ bp = Blueprint("atst", __name__) @bp.route("/") def root(): - return render_template("login.html") + redirect_url = app.config.get("CAC_URL") + if request.args.get("next"): + redirect_url = url.urljoin( + redirect_url, + "?{}".format(url.urlencode({"next": request.args.get("next")})), + ) + + return render_template( + "login.html", redirect=bool(request.args.get("next")), redirect_url=redirect_url + ) @bp.route("/help") @@ -70,6 +81,13 @@ def _make_authentication_context(): ) +def redirect_after_login_url(): + if request.args.get("next"): + return request.args.get("next") + else: + return url_for("atst.home") + + @bp.route("/login-redirect") def login_redirect(): auth_context = _make_authentication_context() @@ -77,13 +95,13 @@ def login_redirect(): user = auth_context.get_user() session["user_id"] = user.id - return redirect(url_for(".home")) + return redirect(redirect_after_login_url()) @bp.route("/logout") def logout(): _logout() - return redirect(url_for(".home")) + return redirect(url_for(".root")) @bp.route("/activity-history") diff --git a/atst/routes/dev.py b/atst/routes/dev.py index 0b53daf0..e9b107e8 100644 --- a/atst/routes/dev.py +++ b/atst/routes/dev.py @@ -1,5 +1,6 @@ -from flask import Blueprint, request, session, redirect, url_for +from flask import Blueprint, request, session, redirect +from . import redirect_after_login_url from atst.domain.users import Users bp = Blueprint("dev", __name__) @@ -63,4 +64,4 @@ def login_dev(): ) session["user_id"] = user.id - return redirect(url_for("atst.home")) + return redirect(redirect_after_login_url()) diff --git a/templates/login.html b/templates/login.html index 7beaee66..0ff48b72 100644 --- a/templates/login.html +++ b/templates/login.html @@ -15,13 +15,20 @@ - Sign in with CAC + Sign in with CAC {% if g.dev %} - DEV Login + DEV Login {% endif %} + {% if redirect %} + {{ Alert('Log in Required.', + message='After you log in, you will be redirected to your destination page.', + level='warning' + ) }} + {% endif %} + {{ Alert('Certificate Selection', message='When you are prompted to select a certificate, please select E-mail Certificate from the provided choices.', actions=[ diff --git a/tests/test_auth.py b/tests/test_auth.py index 47d7c6e1..87e90f77 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -17,9 +17,9 @@ def _fetch_user_info(c, t): return MOCK_USER -def _login(client, verify="SUCCESS", sdn=DOD_SDN, cert=""): +def _login(client, verify="SUCCESS", sdn=DOD_SDN, cert="", **url_query_args): return client.get( - url_for("atst.login_redirect"), + url_for("atst.login_redirect", **url_query_args), environ_base={ "HTTP_X_SSL_CLIENT_VERIFY": verify, "HTTP_X_SSL_CLIENT_S_DN": sdn, @@ -88,12 +88,19 @@ def test_protected_routes_redirect_to_login(client, app): if "GET" in rule.methods: resp = client.get(protected_route) assert resp.status_code == 302 - assert resp.headers["Location"] == "http://localhost/" + assert "http://localhost/" in resp.headers["Location"] if "POST" in rule.methods: resp = client.post(protected_route) assert resp.status_code == 302 - assert resp.headers["Location"] == "http://localhost/" + assert "http://localhost/" in resp.headers["Location"] + + +def test_get_protected_route_encodes_redirect(client): + workspace_index = url_for("workspaces.workspaces") + response = client.get(workspace_index) + redirect = url_for("atst.root", next=workspace_index) + assert redirect in response.headers["Location"] def test_unprotected_routes_set_user_if_logged_in(client, app, user_session): @@ -178,3 +185,16 @@ def test_logout(app, client, monkeypatch): assert resp_failure.status_code == 302 destination = urlparse(resp_failure.headers["Location"]).path assert destination == url_for("atst.root") + + +def test_redirected_on_login(client, monkeypatch): + monkeypatch.setattr( + "atst.domain.authnid.AuthenticationContext.authenticate", lambda *args: True + ) + monkeypatch.setattr( + "atst.domain.authnid.AuthenticationContext.get_user", + lambda *args: UserFactory.create(), + ) + target_route = url_for("requests.requests_form_new", screen=1) + response = _login(client, next=target_route) + assert target_route in response.headers.get("Location")