Add fn to ensure a url matches an app url pattern

In some functions, we redirect a user based on a parameter in a query
string.  This commit adds a function that checks to see if a given url
matches a url pattern of a view function. This will help us ensure that
the url passed  as the next parameter isn't malicious.
This commit is contained in:
graham-dds 2020-01-28 11:30:38 -05:00
parent 7812da5eae
commit 82ef8f3574
4 changed files with 61 additions and 10 deletions

View File

@ -14,7 +14,9 @@ from flask import (
from jinja2.exceptions import TemplateNotFound from jinja2.exceptions import TemplateNotFound
import pendulum import pendulum
import os import os
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound, MethodNotAllowed
from werkzeug.routing import RequestRedirect
from atst.domain.users import Users from atst.domain.users import Users
from atst.domain.authnid import AuthenticationContext from atst.domain.authnid import AuthenticationContext
@ -61,17 +63,36 @@ def _make_authentication_context():
def redirect_after_login_url(): def redirect_after_login_url():
if request.args.get("next"):
returl = request.args.get("next") returl = request.args.get("next")
if request.args.get(app.form_cache.PARAM_NAME): if match_url_pattern(returl):
returl += "?" + url.urlencode( param_name = request.args.get(app.form_cache.PARAM_NAME)
{app.form_cache.PARAM_NAME: request.args.get(app.form_cache.PARAM_NAME)} if param_name:
) returl += "?" + url.urlencode({app.form_cache.PARAM_NAME: param_name})
return returl return returl
else: else:
return url_for("atst.home") return url_for("atst.home")
def match_url_pattern(url, method="GET"):
"""Ensure a url matches a url pattern in the flask app
inspired by https://stackoverflow.com/questions/38488134/get-the-flask-view-function-that-matches-a-url/38488506#38488506
"""
server_name = app.config.get("SERVER_NAME") or "localhost"
adapter = app.url_map.bind(server_name=server_name)
try:
match = adapter.match(url, method=method)
except RequestRedirect as e:
# recursively match redirects
return match_url_pattern(e.new_url, method)
except (MethodNotAllowed, NotFound):
# no match
return None
if match[0] in app.view_functions:
return url
def current_user_setup(user): def current_user_setup(user):
session["user_id"] = user.id session["user_id"] = user.id
session["last_login"] = user.last_login session["last_login"] = user.last_login

View File

@ -3,6 +3,7 @@ from flask import Blueprint, render_template, g, request as http_request, redire
from atst.forms.edit_user import EditUserForm from atst.forms.edit_user import EditUserForm
from atst.domain.users import Users from atst.domain.users import Users
from atst.utils.flash import formatted_flash as flash from atst.utils.flash import formatted_flash as flash
from atst.routes import match_url_pattern
bp = Blueprint("users", __name__) bp = Blueprint("users", __name__)
@ -35,7 +36,7 @@ def update_user():
if form.validate(): if form.validate():
Users.update(user, form.data) Users.update(user, form.data)
flash("user_updated") flash("user_updated")
if next_url: if match_url_pattern(next_url):
return redirect(next_url) return redirect(next_url)
return render_template( return render_template(

View File

@ -1,7 +1,36 @@
from tests.factories import UserFactory from tests.factories import UserFactory, PortfolioFactory
from atst.routes import match_url_pattern
def test_root_redirects_if_user_is_logged_in(client, user_session): def test_root_redirects_if_user_is_logged_in(client, user_session):
user_session(UserFactory.create()) user_session(UserFactory.create())
response = client.get("/", follow_redirects=False) response = client.get("/", follow_redirects=False)
assert "home" in response.location assert "home" in response.location
def test_match_url_pattern(client):
assert not match_url_pattern(None)
assert match_url_pattern("/home") == "/home"
portfolio = PortfolioFactory()
# matches a URL with an argument
assert (
match_url_pattern(f"/portfolios/{portfolio.id}") # /portfolios/<portfolio_id>
== f"/portfolios/{portfolio.id}"
)
# matches a url with a query string
assert (
match_url_pattern(f"/portfolios/{portfolio.id}?foo=bar")
== f"/portfolios/{portfolio.id}?foo=bar"
)
# matches a URL only with a valid method
assert not match_url_pattern(f"/portfolios/{portfolio.id}/edit")
assert (
match_url_pattern(f"/portfolios/{portfolio.id}/edit", method="POST")
== f"/portfolios/{portfolio.id}/edit"
)
# returns None for URL that doesn't match a view function
assert not match_url_pattern("/pwned")
assert not match_url_pattern("http://www.hackersite.com/pwned")

View File

@ -28,7 +28,7 @@ def test_user_can_update_profile(user_session, client):
def test_user_is_redirected_when_updating_profile(user_session, client): def test_user_is_redirected_when_updating_profile(user_session, client):
user = UserFactory.create() user = UserFactory.create()
user_session(user) user_session(user)
next_url = "/requests" next_url = "/home"
user_data = user.to_dictionary() user_data = user.to_dictionary()
user_data["date_latest_training"] = user_data["date_latest_training"].strftime( user_data["date_latest_training"] = user_data["date_latest_training"].strftime(