diff --git a/atst/domain/auth.py b/atst/domain/auth.py index a60a503a..5ac56fc9 100644 --- a/atst/domain/auth.py +++ b/atst/domain/auth.py @@ -21,10 +21,26 @@ def apply_authentication(app): user = get_current_user() if user: g.current_user = user + if should_redirect_to_user_profile(request, user): + return redirect(url_for("users.user", next=request.path)) elif not _unprotected_route(request): return redirect(url_for("atst.root", next=request.path)) +def should_redirect_to_user_profile(request, user): + has_complete_profile = user.profile_complete + is_unprotected_route = _unprotected_route(request) + is_requesting_user_endpoint = request.endpoint in [ + "users.user", + "users.update_user", + ] + + if has_complete_profile or is_unprotected_route or is_requesting_user_endpoint: + return False + + return True + + def get_current_user(): user_id = session.get("user_id") if user_id: diff --git a/atst/routes/users.py b/atst/routes/users.py index e57584a7..f658b823 100644 --- a/atst/routes/users.py +++ b/atst/routes/users.py @@ -1,4 +1,4 @@ -from flask import Blueprint, render_template, g, request as http_request +from flask import Blueprint, render_template, g, request as http_request, redirect from atst.forms.edit_user import EditUserForm from atst.domain.users import Users @@ -10,16 +10,21 @@ bp = Blueprint("users", __name__) def user(): user = g.current_user form = EditUserForm(data=user.to_dictionary()) - return render_template("user/edit.html", form=form, user=user) + return render_template( + "user/edit.html", next=http_request.args.get("next"), form=form, user=user + ) @bp.route("/user", methods=["POST"]) def update_user(): user = g.current_user form = EditUserForm(http_request.form) - rerender_args = {"form": form, "user": user} + next_url = http_request.args.get("next") + rerender_args = {"form": form, "user": user, "next": next_url} if form.validate(): Users.update(user, form.data) rerender_args["updated"] = True + if next_url: + return redirect(next_url) return render_template("user/edit.html", **rerender_args) diff --git a/templates/user/edit.html b/templates/user/edit.html index 6578c969..bd247a2f 100644 --- a/templates/user/edit.html +++ b/templates/user/edit.html @@ -4,6 +4,13 @@ {% block content %}
+ {% if next is not none %} + {{ Alert('You must complete your profile', + message='

Before continuing, you must complete your profile

', + level='info' + ) }} + {% endif %} + {% if form.errors %} {{ Alert('There were some errors', message="

Please see below.

", @@ -25,7 +32,7 @@
- {% set form_action = url_for('users.update_user') %} + {% set form_action = url_for('users.update_user', next=next) %} {% include "fragments/edit_user_form.html" %} diff --git a/tests/routes/test_auth.py b/tests/routes/test_auth.py new file mode 100644 index 00000000..aef34fbd --- /dev/null +++ b/tests/routes/test_auth.py @@ -0,0 +1,45 @@ +from flask import url_for +from urllib.parse import quote + +from tests.factories import UserFactory + + +def test_request_page_with_complete_profile(client, user_session): + user = UserFactory.create() + user_session(user) + response = client.get("/requests", follow_redirects=False) + assert response.status_code == 200 + + +def test_redirect_when_profile_missing_fields(client, user_session): + user = UserFactory.create(date_latest_training=None) + user_session(user) + requested_url = "/requests" + response = client.get(requested_url, follow_redirects=False) + assert response.status_code == 302 + assert "/user?next={}".format(quote(requested_url, safe="")) in response.location + + +def test_unprotected_route_with_incomplete_profile(client, user_session): + user = UserFactory.create(date_latest_training=None) + user_session(user) + response = client.get("/about", follow_redirects=False) + assert response.status_code == 200 + + +def test_completing_user_profile(client, user_session): + user = UserFactory.create(phone_number=None) + user_session(user) + response = client.get("/requests", follow_redirects=True) + assert b"You must complete your profile" in response.data + + updated_data = {**user.to_dictionary(), "phone_number": "5558675309"} + updated_data["date_latest_training"] = updated_data[ + "date_latest_training" + ].strftime("%m/%d/%Y") + response = client.post(url_for("users.update_user"), data=updated_data) + assert response.status_code == 200 + + response = client.get("/requests", follow_redirects=False) + assert response.status_code == 200 + assert b"You must complete your profile" not in response.data diff --git a/tests/routes/test_users.py b/tests/routes/test_users.py index 74fbdc44..50a723a1 100644 --- a/tests/routes/test_users.py +++ b/tests/routes/test_users.py @@ -23,3 +23,21 @@ def test_user_can_update_profile(user_session, client): updated_user = Users.get_by_dod_id(user.dod_id) assert updated_user.first_name == "chad" assert updated_user.last_name == "vader" + + +def test_user_is_redirected_when_updating_profile(user_session, client): + user = UserFactory.create() + user_session(user) + next_url = "/requests" + + user_data = user.to_dictionary() + user_data["date_latest_training"] = user_data["date_latest_training"].strftime( + "%m/%d/%Y" + ) + response = client.post( + url_for("users.update_user", next=next_url), + data=user_data, + follow_redirects=False, + ) + assert response.status_code == 302 + assert response.location.endswith(next_url)