Add fn to ensure a url matches an app url pattern
In some functions, we redirect a user based on a parameter in a query string. This commit adds a function that checks to see if a given url matches a url pattern of a view function. This will help us ensure that the url passed as the next parameter isn't malicious.
This commit is contained in:
@@ -14,7 +14,9 @@ from flask import (
|
||||
from jinja2.exceptions import TemplateNotFound
|
||||
import pendulum
|
||||
import os
|
||||
from werkzeug.exceptions import NotFound
|
||||
from werkzeug.exceptions import NotFound, MethodNotAllowed
|
||||
from werkzeug.routing import RequestRedirect
|
||||
|
||||
|
||||
from atst.domain.users import Users
|
||||
from atst.domain.authnid import AuthenticationContext
|
||||
@@ -61,17 +63,36 @@ def _make_authentication_context():
|
||||
|
||||
|
||||
def redirect_after_login_url():
|
||||
if request.args.get("next"):
|
||||
returl = request.args.get("next")
|
||||
if request.args.get(app.form_cache.PARAM_NAME):
|
||||
returl += "?" + url.urlencode(
|
||||
{app.form_cache.PARAM_NAME: request.args.get(app.form_cache.PARAM_NAME)}
|
||||
)
|
||||
returl = request.args.get("next")
|
||||
if match_url_pattern(returl):
|
||||
param_name = request.args.get(app.form_cache.PARAM_NAME)
|
||||
if param_name:
|
||||
returl += "?" + url.urlencode({app.form_cache.PARAM_NAME: param_name})
|
||||
return returl
|
||||
else:
|
||||
return url_for("atst.home")
|
||||
|
||||
|
||||
def match_url_pattern(url, method="GET"):
|
||||
"""Ensure a url matches a url pattern in the flask app
|
||||
inspired by https://stackoverflow.com/questions/38488134/get-the-flask-view-function-that-matches-a-url/38488506#38488506
|
||||
"""
|
||||
server_name = app.config.get("SERVER_NAME") or "localhost"
|
||||
adapter = app.url_map.bind(server_name=server_name)
|
||||
|
||||
try:
|
||||
match = adapter.match(url, method=method)
|
||||
except RequestRedirect as e:
|
||||
# recursively match redirects
|
||||
return match_url_pattern(e.new_url, method)
|
||||
except (MethodNotAllowed, NotFound):
|
||||
# no match
|
||||
return None
|
||||
|
||||
if match[0] in app.view_functions:
|
||||
return url
|
||||
|
||||
|
||||
def current_user_setup(user):
|
||||
session["user_id"] = user.id
|
||||
session["last_login"] = user.last_login
|
||||
|
@@ -3,6 +3,7 @@ from flask import Blueprint, render_template, g, request as http_request, redire
|
||||
from atst.forms.edit_user import EditUserForm
|
||||
from atst.domain.users import Users
|
||||
from atst.utils.flash import formatted_flash as flash
|
||||
from atst.routes import match_url_pattern
|
||||
|
||||
|
||||
bp = Blueprint("users", __name__)
|
||||
@@ -35,7 +36,7 @@ def update_user():
|
||||
if form.validate():
|
||||
Users.update(user, form.data)
|
||||
flash("user_updated")
|
||||
if next_url:
|
||||
if match_url_pattern(next_url):
|
||||
return redirect(next_url)
|
||||
|
||||
return render_template(
|
||||
|
Reference in New Issue
Block a user