From 82ef8f3574dfaf9c5229ddb03e151e44cbda3288 Mon Sep 17 00:00:00 2001 From: graham-dds Date: Tue, 28 Jan 2020 11:30:38 -0500 Subject: [PATCH] Add fn to ensure a url matches an app url pattern In some functions, we redirect a user based on a parameter in a query string. This commit adds a function that checks to see if a given url matches a url pattern of a view function. This will help us ensure that the url passed as the next parameter isn't malicious. --- atst/routes/__init__.py | 35 ++++++++++++++++++++++++++++------- atst/routes/users.py | 3 ++- tests/routes/test_root.py | 31 ++++++++++++++++++++++++++++++- tests/routes/test_users.py | 2 +- 4 files changed, 61 insertions(+), 10 deletions(-) diff --git a/atst/routes/__init__.py b/atst/routes/__init__.py index 78934400..52873e4c 100644 --- a/atst/routes/__init__.py +++ b/atst/routes/__init__.py @@ -14,7 +14,9 @@ from flask import ( from jinja2.exceptions import TemplateNotFound import pendulum import os -from werkzeug.exceptions import NotFound +from werkzeug.exceptions import NotFound, MethodNotAllowed +from werkzeug.routing import RequestRedirect + from atst.domain.users import Users from atst.domain.authnid import AuthenticationContext @@ -61,17 +63,36 @@ def _make_authentication_context(): def redirect_after_login_url(): - if request.args.get("next"): - returl = request.args.get("next") - if request.args.get(app.form_cache.PARAM_NAME): - returl += "?" + url.urlencode( - {app.form_cache.PARAM_NAME: request.args.get(app.form_cache.PARAM_NAME)} - ) + returl = request.args.get("next") + if match_url_pattern(returl): + param_name = request.args.get(app.form_cache.PARAM_NAME) + if param_name: + returl += "?" + url.urlencode({app.form_cache.PARAM_NAME: param_name}) return returl else: return url_for("atst.home") +def match_url_pattern(url, method="GET"): + """Ensure a url matches a url pattern in the flask app + inspired by https://stackoverflow.com/questions/38488134/get-the-flask-view-function-that-matches-a-url/38488506#38488506 + """ + server_name = app.config.get("SERVER_NAME") or "localhost" + adapter = app.url_map.bind(server_name=server_name) + + try: + match = adapter.match(url, method=method) + except RequestRedirect as e: + # recursively match redirects + return match_url_pattern(e.new_url, method) + except (MethodNotAllowed, NotFound): + # no match + return None + + if match[0] in app.view_functions: + return url + + def current_user_setup(user): session["user_id"] = user.id session["last_login"] = user.last_login diff --git a/atst/routes/users.py b/atst/routes/users.py index b9325f93..ec5557aa 100644 --- a/atst/routes/users.py +++ b/atst/routes/users.py @@ -3,6 +3,7 @@ from flask import Blueprint, render_template, g, request as http_request, redire from atst.forms.edit_user import EditUserForm from atst.domain.users import Users from atst.utils.flash import formatted_flash as flash +from atst.routes import match_url_pattern bp = Blueprint("users", __name__) @@ -35,7 +36,7 @@ def update_user(): if form.validate(): Users.update(user, form.data) flash("user_updated") - if next_url: + if match_url_pattern(next_url): return redirect(next_url) return render_template( diff --git a/tests/routes/test_root.py b/tests/routes/test_root.py index b06befaf..b012e87d 100644 --- a/tests/routes/test_root.py +++ b/tests/routes/test_root.py @@ -1,7 +1,36 @@ -from tests.factories import UserFactory +from tests.factories import UserFactory, PortfolioFactory +from atst.routes import match_url_pattern def test_root_redirects_if_user_is_logged_in(client, user_session): user_session(UserFactory.create()) response = client.get("/", follow_redirects=False) assert "home" in response.location + + +def test_match_url_pattern(client): + + assert not match_url_pattern(None) + assert match_url_pattern("/home") == "/home" + + portfolio = PortfolioFactory() + # matches a URL with an argument + assert ( + match_url_pattern(f"/portfolios/{portfolio.id}") # /portfolios/ + == f"/portfolios/{portfolio.id}" + ) + # matches a url with a query string + assert ( + match_url_pattern(f"/portfolios/{portfolio.id}?foo=bar") + == f"/portfolios/{portfolio.id}?foo=bar" + ) + # matches a URL only with a valid method + assert not match_url_pattern(f"/portfolios/{portfolio.id}/edit") + assert ( + match_url_pattern(f"/portfolios/{portfolio.id}/edit", method="POST") + == f"/portfolios/{portfolio.id}/edit" + ) + + # returns None for URL that doesn't match a view function + assert not match_url_pattern("/pwned") + assert not match_url_pattern("http://www.hackersite.com/pwned") diff --git a/tests/routes/test_users.py b/tests/routes/test_users.py index 50a723a1..47a67f80 100644 --- a/tests/routes/test_users.py +++ b/tests/routes/test_users.py @@ -28,7 +28,7 @@ def test_user_can_update_profile(user_session, client): def test_user_is_redirected_when_updating_profile(user_session, client): user = UserFactory.create() user_session(user) - next_url = "/requests" + next_url = "/home" user_data = user.to_dictionary() user_data["date_latest_training"] = user_data["date_latest_training"].strftime(