Merge pull request #1364 from dod-ccpo/safe_redirect
Add fn to ensure a url matches an app url pattern
This commit is contained in:
commit
088bd37c6b
@ -14,7 +14,9 @@ from flask import (
|
|||||||
from jinja2.exceptions import TemplateNotFound
|
from jinja2.exceptions import TemplateNotFound
|
||||||
import pendulum
|
import pendulum
|
||||||
import os
|
import os
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound, MethodNotAllowed
|
||||||
|
from werkzeug.routing import RequestRedirect
|
||||||
|
|
||||||
|
|
||||||
from atst.domain.users import Users
|
from atst.domain.users import Users
|
||||||
from atst.domain.authnid import AuthenticationContext
|
from atst.domain.authnid import AuthenticationContext
|
||||||
@ -61,17 +63,36 @@ def _make_authentication_context():
|
|||||||
|
|
||||||
|
|
||||||
def redirect_after_login_url():
|
def redirect_after_login_url():
|
||||||
if request.args.get("next"):
|
|
||||||
returl = request.args.get("next")
|
returl = request.args.get("next")
|
||||||
if request.args.get(app.form_cache.PARAM_NAME):
|
if match_url_pattern(returl):
|
||||||
returl += "?" + url.urlencode(
|
param_name = request.args.get(app.form_cache.PARAM_NAME)
|
||||||
{app.form_cache.PARAM_NAME: request.args.get(app.form_cache.PARAM_NAME)}
|
if param_name:
|
||||||
)
|
returl += "?" + url.urlencode({app.form_cache.PARAM_NAME: param_name})
|
||||||
return returl
|
return returl
|
||||||
else:
|
else:
|
||||||
return url_for("atst.home")
|
return url_for("atst.home")
|
||||||
|
|
||||||
|
|
||||||
|
def match_url_pattern(url, method="GET"):
|
||||||
|
"""Ensure a url matches a url pattern in the flask app
|
||||||
|
inspired by https://stackoverflow.com/questions/38488134/get-the-flask-view-function-that-matches-a-url/38488506#38488506
|
||||||
|
"""
|
||||||
|
server_name = app.config.get("SERVER_NAME") or "localhost"
|
||||||
|
adapter = app.url_map.bind(server_name=server_name)
|
||||||
|
|
||||||
|
try:
|
||||||
|
match = adapter.match(url, method=method)
|
||||||
|
except RequestRedirect as e:
|
||||||
|
# recursively match redirects
|
||||||
|
return match_url_pattern(e.new_url, method)
|
||||||
|
except (MethodNotAllowed, NotFound):
|
||||||
|
# no match
|
||||||
|
return None
|
||||||
|
|
||||||
|
if match[0] in app.view_functions:
|
||||||
|
return url
|
||||||
|
|
||||||
|
|
||||||
def current_user_setup(user):
|
def current_user_setup(user):
|
||||||
session["user_id"] = user.id
|
session["user_id"] = user.id
|
||||||
session["last_login"] = user.last_login
|
session["last_login"] = user.last_login
|
||||||
|
@ -3,6 +3,7 @@ from flask import Blueprint, render_template, g, request as http_request, redire
|
|||||||
from atst.forms.edit_user import EditUserForm
|
from atst.forms.edit_user import EditUserForm
|
||||||
from atst.domain.users import Users
|
from atst.domain.users import Users
|
||||||
from atst.utils.flash import formatted_flash as flash
|
from atst.utils.flash import formatted_flash as flash
|
||||||
|
from atst.routes import match_url_pattern
|
||||||
|
|
||||||
|
|
||||||
bp = Blueprint("users", __name__)
|
bp = Blueprint("users", __name__)
|
||||||
@ -35,7 +36,7 @@ def update_user():
|
|||||||
if form.validate():
|
if form.validate():
|
||||||
Users.update(user, form.data)
|
Users.update(user, form.data)
|
||||||
flash("user_updated")
|
flash("user_updated")
|
||||||
if next_url:
|
if match_url_pattern(next_url):
|
||||||
return redirect(next_url)
|
return redirect(next_url)
|
||||||
|
|
||||||
return render_template(
|
return render_template(
|
||||||
|
@ -1,7 +1,36 @@
|
|||||||
from tests.factories import UserFactory
|
from tests.factories import UserFactory, PortfolioFactory
|
||||||
|
from atst.routes import match_url_pattern
|
||||||
|
|
||||||
|
|
||||||
def test_root_redirects_if_user_is_logged_in(client, user_session):
|
def test_root_redirects_if_user_is_logged_in(client, user_session):
|
||||||
user_session(UserFactory.create())
|
user_session(UserFactory.create())
|
||||||
response = client.get("/", follow_redirects=False)
|
response = client.get("/", follow_redirects=False)
|
||||||
assert "home" in response.location
|
assert "home" in response.location
|
||||||
|
|
||||||
|
|
||||||
|
def test_match_url_pattern(client):
|
||||||
|
|
||||||
|
assert not match_url_pattern(None)
|
||||||
|
assert match_url_pattern("/home") == "/home"
|
||||||
|
|
||||||
|
portfolio = PortfolioFactory()
|
||||||
|
# matches a URL with an argument
|
||||||
|
assert (
|
||||||
|
match_url_pattern(f"/portfolios/{portfolio.id}") # /portfolios/<portfolio_id>
|
||||||
|
== f"/portfolios/{portfolio.id}"
|
||||||
|
)
|
||||||
|
# matches a url with a query string
|
||||||
|
assert (
|
||||||
|
match_url_pattern(f"/portfolios/{portfolio.id}?foo=bar")
|
||||||
|
== f"/portfolios/{portfolio.id}?foo=bar"
|
||||||
|
)
|
||||||
|
# matches a URL only with a valid method
|
||||||
|
assert not match_url_pattern(f"/portfolios/{portfolio.id}/edit")
|
||||||
|
assert (
|
||||||
|
match_url_pattern(f"/portfolios/{portfolio.id}/edit", method="POST")
|
||||||
|
== f"/portfolios/{portfolio.id}/edit"
|
||||||
|
)
|
||||||
|
|
||||||
|
# returns None for URL that doesn't match a view function
|
||||||
|
assert not match_url_pattern("/pwned")
|
||||||
|
assert not match_url_pattern("http://www.hackersite.com/pwned")
|
||||||
|
@ -28,7 +28,7 @@ def test_user_can_update_profile(user_session, client):
|
|||||||
def test_user_is_redirected_when_updating_profile(user_session, client):
|
def test_user_is_redirected_when_updating_profile(user_session, client):
|
||||||
user = UserFactory.create()
|
user = UserFactory.create()
|
||||||
user_session(user)
|
user_session(user)
|
||||||
next_url = "/requests"
|
next_url = "/home"
|
||||||
|
|
||||||
user_data = user.to_dictionary()
|
user_data = user.to_dictionary()
|
||||||
user_data["date_latest_training"] = user_data["date_latest_training"].strftime(
|
user_data["date_latest_training"] = user_data["date_latest_training"].strftime(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user