From 8a89c519ebd26116435c518e7fe8c2550662b44f Mon Sep 17 00:00:00 2001 From: dandds Date: Mon, 1 Oct 2018 14:22:44 -0400 Subject: [PATCH] on user login, redirect based on next query parameter if available --- atst/routes/__init__.py | 9 ++++++++- atst/routes/dev.py | 3 ++- tests/test_auth.py | 17 +++++++++++++++-- 3 files changed, 25 insertions(+), 4 deletions(-) diff --git a/atst/routes/__init__.py b/atst/routes/__init__.py index beaa0c8e..18e719f4 100644 --- a/atst/routes/__init__.py +++ b/atst/routes/__init__.py @@ -80,6 +80,13 @@ def _make_authentication_context(): ) +def redirect_url(): + if request.args.get("next"): + return request.args.get("next") + else: + return url_for(".home") + + @bp.route("/login-redirect") def login_redirect(): auth_context = _make_authentication_context() @@ -87,7 +94,7 @@ def login_redirect(): user = auth_context.get_user() session["user_id"] = user.id - return redirect(url_for(".home")) + return redirect(redirect_url()) @bp.route("/logout") diff --git a/atst/routes/dev.py b/atst/routes/dev.py index 0b53daf0..1deff0ac 100644 --- a/atst/routes/dev.py +++ b/atst/routes/dev.py @@ -1,5 +1,6 @@ from flask import Blueprint, request, session, redirect, url_for +from . import redirect_url from atst.domain.users import Users bp = Blueprint("dev", __name__) @@ -63,4 +64,4 @@ def login_dev(): ) session["user_id"] = user.id - return redirect(url_for("atst.home")) + return redirect(redirect_url()) diff --git a/tests/test_auth.py b/tests/test_auth.py index 85bc62e0..87e90f77 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -17,9 +17,9 @@ def _fetch_user_info(c, t): return MOCK_USER -def _login(client, verify="SUCCESS", sdn=DOD_SDN, cert=""): +def _login(client, verify="SUCCESS", sdn=DOD_SDN, cert="", **url_query_args): return client.get( - url_for("atst.login_redirect"), + url_for("atst.login_redirect", **url_query_args), environ_base={ "HTTP_X_SSL_CLIENT_VERIFY": verify, "HTTP_X_SSL_CLIENT_S_DN": sdn, @@ -185,3 +185,16 @@ def test_logout(app, client, monkeypatch): assert resp_failure.status_code == 302 destination = urlparse(resp_failure.headers["Location"]).path assert destination == url_for("atst.root") + + +def test_redirected_on_login(client, monkeypatch): + monkeypatch.setattr( + "atst.domain.authnid.AuthenticationContext.authenticate", lambda *args: True + ) + monkeypatch.setattr( + "atst.domain.authnid.AuthenticationContext.get_user", + lambda *args: UserFactory.create(), + ) + target_route = url_for("requests.requests_form_new", screen=1) + response = _login(client, next=target_route) + assert target_route in response.headers.get("Location")