diff --git a/atst/domain/auth.py b/atst/domain/auth.py
index 9af97dd7..999553dc 100644
--- a/atst/domain/auth.py
+++ b/atst/domain/auth.py
@@ -21,7 +21,7 @@ def apply_authentication(app):
if user:
g.current_user = user
elif not _unprotected_route(request):
- return redirect(url_for("atst.root"))
+ return redirect(url_for("atst.root", next=request.path))
def get_current_user():
diff --git a/atst/routes/__init__.py b/atst/routes/__init__.py
index 4edbdb57..d13bdb36 100644
--- a/atst/routes/__init__.py
+++ b/atst/routes/__init__.py
@@ -1,4 +1,6 @@
+import urllib.parse as url
from flask import Blueprint, render_template, g, redirect, session, url_for, request
+
from flask import current_app as app
import pendulum
@@ -15,7 +17,16 @@ bp = Blueprint("atst", __name__)
@bp.route("/")
def root():
- return render_template("login.html")
+ redirect_url = app.config.get("CAC_URL")
+ if request.args.get("next"):
+ redirect_url = url.urljoin(
+ redirect_url,
+ "?{}".format(url.urlencode({"next": request.args.get("next")})),
+ )
+
+ return render_template(
+ "login.html", redirect=bool(request.args.get("next")), redirect_url=redirect_url
+ )
@bp.route("/help")
@@ -70,6 +81,13 @@ def _make_authentication_context():
)
+def redirect_after_login_url():
+ if request.args.get("next"):
+ return request.args.get("next")
+ else:
+ return url_for("atst.home")
+
+
@bp.route("/login-redirect")
def login_redirect():
auth_context = _make_authentication_context()
@@ -77,13 +95,13 @@ def login_redirect():
user = auth_context.get_user()
session["user_id"] = user.id
- return redirect(url_for(".home"))
+ return redirect(redirect_after_login_url())
@bp.route("/logout")
def logout():
_logout()
- return redirect(url_for(".home"))
+ return redirect(url_for(".root"))
@bp.route("/activity-history")
diff --git a/atst/routes/dev.py b/atst/routes/dev.py
index 0b53daf0..e9b107e8 100644
--- a/atst/routes/dev.py
+++ b/atst/routes/dev.py
@@ -1,5 +1,6 @@
-from flask import Blueprint, request, session, redirect, url_for
+from flask import Blueprint, request, session, redirect
+from . import redirect_after_login_url
from atst.domain.users import Users
bp = Blueprint("dev", __name__)
@@ -63,4 +64,4 @@ def login_dev():
)
session["user_id"] = user.id
- return redirect(url_for("atst.home"))
+ return redirect(redirect_after_login_url())
diff --git a/templates/login.html b/templates/login.html
index 7beaee66..0ff48b72 100644
--- a/templates/login.html
+++ b/templates/login.html
@@ -15,13 +15,20 @@
- Sign in with CAC
+ Sign in with CAC
{% if g.dev %}
- DEV Login
+ DEV Login
{% endif %}
+ {% if redirect %}
+ {{ Alert('Log in Required.',
+ message='After you log in, you will be redirected to your destination page.',
+ level='warning'
+ ) }}
+ {% endif %}
+
{{ Alert('Certificate Selection',
message='When you are prompted to select a certificate, please select E-mail Certificate from the provided choices.',
actions=[
diff --git a/tests/test_auth.py b/tests/test_auth.py
index 47d7c6e1..87e90f77 100644
--- a/tests/test_auth.py
+++ b/tests/test_auth.py
@@ -17,9 +17,9 @@ def _fetch_user_info(c, t):
return MOCK_USER
-def _login(client, verify="SUCCESS", sdn=DOD_SDN, cert=""):
+def _login(client, verify="SUCCESS", sdn=DOD_SDN, cert="", **url_query_args):
return client.get(
- url_for("atst.login_redirect"),
+ url_for("atst.login_redirect", **url_query_args),
environ_base={
"HTTP_X_SSL_CLIENT_VERIFY": verify,
"HTTP_X_SSL_CLIENT_S_DN": sdn,
@@ -88,12 +88,19 @@ def test_protected_routes_redirect_to_login(client, app):
if "GET" in rule.methods:
resp = client.get(protected_route)
assert resp.status_code == 302
- assert resp.headers["Location"] == "http://localhost/"
+ assert "http://localhost/" in resp.headers["Location"]
if "POST" in rule.methods:
resp = client.post(protected_route)
assert resp.status_code == 302
- assert resp.headers["Location"] == "http://localhost/"
+ assert "http://localhost/" in resp.headers["Location"]
+
+
+def test_get_protected_route_encodes_redirect(client):
+ workspace_index = url_for("workspaces.workspaces")
+ response = client.get(workspace_index)
+ redirect = url_for("atst.root", next=workspace_index)
+ assert redirect in response.headers["Location"]
def test_unprotected_routes_set_user_if_logged_in(client, app, user_session):
@@ -178,3 +185,16 @@ def test_logout(app, client, monkeypatch):
assert resp_failure.status_code == 302
destination = urlparse(resp_failure.headers["Location"]).path
assert destination == url_for("atst.root")
+
+
+def test_redirected_on_login(client, monkeypatch):
+ monkeypatch.setattr(
+ "atst.domain.authnid.AuthenticationContext.authenticate", lambda *args: True
+ )
+ monkeypatch.setattr(
+ "atst.domain.authnid.AuthenticationContext.get_user",
+ lambda *args: UserFactory.create(),
+ )
+ target_route = url_for("requests.requests_form_new", screen=1)
+ response = _login(client, next=target_route)
+ assert target_route in response.headers.get("Location")