diff --git a/atst/domain/auth.py b/atst/domain/auth.py
index a60a503a..5ac56fc9 100644
--- a/atst/domain/auth.py
+++ b/atst/domain/auth.py
@@ -21,10 +21,26 @@ def apply_authentication(app):
user = get_current_user()
if user:
g.current_user = user
+ if should_redirect_to_user_profile(request, user):
+ return redirect(url_for("users.user", next=request.path))
elif not _unprotected_route(request):
return redirect(url_for("atst.root", next=request.path))
+def should_redirect_to_user_profile(request, user):
+ has_complete_profile = user.profile_complete
+ is_unprotected_route = _unprotected_route(request)
+ is_requesting_user_endpoint = request.endpoint in [
+ "users.user",
+ "users.update_user",
+ ]
+
+ if has_complete_profile or is_unprotected_route or is_requesting_user_endpoint:
+ return False
+
+ return True
+
+
def get_current_user():
user_id = session.get("user_id")
if user_id:
diff --git a/atst/routes/users.py b/atst/routes/users.py
index e57584a7..f658b823 100644
--- a/atst/routes/users.py
+++ b/atst/routes/users.py
@@ -1,4 +1,4 @@
-from flask import Blueprint, render_template, g, request as http_request
+from flask import Blueprint, render_template, g, request as http_request, redirect
from atst.forms.edit_user import EditUserForm
from atst.domain.users import Users
@@ -10,16 +10,21 @@ bp = Blueprint("users", __name__)
def user():
user = g.current_user
form = EditUserForm(data=user.to_dictionary())
- return render_template("user/edit.html", form=form, user=user)
+ return render_template(
+ "user/edit.html", next=http_request.args.get("next"), form=form, user=user
+ )
@bp.route("/user", methods=["POST"])
def update_user():
user = g.current_user
form = EditUserForm(http_request.form)
- rerender_args = {"form": form, "user": user}
+ next_url = http_request.args.get("next")
+ rerender_args = {"form": form, "user": user, "next": next_url}
if form.validate():
Users.update(user, form.data)
rerender_args["updated"] = True
+ if next_url:
+ return redirect(next_url)
return render_template("user/edit.html", **rerender_args)
diff --git a/templates/user/edit.html b/templates/user/edit.html
index 6578c969..bd247a2f 100644
--- a/templates/user/edit.html
+++ b/templates/user/edit.html
@@ -4,6 +4,13 @@
{% block content %}
+ {% if next is not none %}
+ {{ Alert('You must complete your profile',
+ message='
Before continuing, you must complete your profile
',
+ level='info'
+ ) }}
+ {% endif %}
+
{% if form.errors %}
{{ Alert('There were some errors',
message="
Please see below.
",
@@ -25,7 +32,7 @@
- {% set form_action = url_for('users.update_user') %}
+ {% set form_action = url_for('users.update_user', next=next) %}
{% include "fragments/edit_user_form.html" %}
diff --git a/tests/routes/test_auth.py b/tests/routes/test_auth.py
new file mode 100644
index 00000000..aef34fbd
--- /dev/null
+++ b/tests/routes/test_auth.py
@@ -0,0 +1,45 @@
+from flask import url_for
+from urllib.parse import quote
+
+from tests.factories import UserFactory
+
+
+def test_request_page_with_complete_profile(client, user_session):
+ user = UserFactory.create()
+ user_session(user)
+ response = client.get("/requests", follow_redirects=False)
+ assert response.status_code == 200
+
+
+def test_redirect_when_profile_missing_fields(client, user_session):
+ user = UserFactory.create(date_latest_training=None)
+ user_session(user)
+ requested_url = "/requests"
+ response = client.get(requested_url, follow_redirects=False)
+ assert response.status_code == 302
+ assert "/user?next={}".format(quote(requested_url, safe="")) in response.location
+
+
+def test_unprotected_route_with_incomplete_profile(client, user_session):
+ user = UserFactory.create(date_latest_training=None)
+ user_session(user)
+ response = client.get("/about", follow_redirects=False)
+ assert response.status_code == 200
+
+
+def test_completing_user_profile(client, user_session):
+ user = UserFactory.create(phone_number=None)
+ user_session(user)
+ response = client.get("/requests", follow_redirects=True)
+ assert b"You must complete your profile" in response.data
+
+ updated_data = {**user.to_dictionary(), "phone_number": "5558675309"}
+ updated_data["date_latest_training"] = updated_data[
+ "date_latest_training"
+ ].strftime("%m/%d/%Y")
+ response = client.post(url_for("users.update_user"), data=updated_data)
+ assert response.status_code == 200
+
+ response = client.get("/requests", follow_redirects=False)
+ assert response.status_code == 200
+ assert b"You must complete your profile" not in response.data
diff --git a/tests/routes/test_users.py b/tests/routes/test_users.py
index 74fbdc44..50a723a1 100644
--- a/tests/routes/test_users.py
+++ b/tests/routes/test_users.py
@@ -23,3 +23,21 @@ def test_user_can_update_profile(user_session, client):
updated_user = Users.get_by_dod_id(user.dod_id)
assert updated_user.first_name == "chad"
assert updated_user.last_name == "vader"
+
+
+def test_user_is_redirected_when_updating_profile(user_session, client):
+ user = UserFactory.create()
+ user_session(user)
+ next_url = "/requests"
+
+ user_data = user.to_dictionary()
+ user_data["date_latest_training"] = user_data["date_latest_training"].strftime(
+ "%m/%d/%Y"
+ )
+ response = client.post(
+ url_for("users.update_user", next=next_url),
+ data=user_data,
+ follow_redirects=False,
+ )
+ assert response.status_code == 302
+ assert response.location.endswith(next_url)