Merge branch 'staging' into azure-admin-provisioning

This commit is contained in:
tomdds 2020-01-30 11:17:33 -05:00 committed by GitHub
commit 6480060b8d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 61 additions and 10 deletions

View File

@ -14,7 +14,9 @@ from flask import (
from jinja2.exceptions import TemplateNotFound
import pendulum
import os
from werkzeug.exceptions import NotFound
from werkzeug.exceptions import NotFound, MethodNotAllowed
from werkzeug.routing import RequestRedirect
from atst.domain.users import Users
from atst.domain.authnid import AuthenticationContext
@ -61,17 +63,36 @@ def _make_authentication_context():
def redirect_after_login_url():
if request.args.get("next"):
returl = request.args.get("next")
if request.args.get(app.form_cache.PARAM_NAME):
returl += "?" + url.urlencode(
{app.form_cache.PARAM_NAME: request.args.get(app.form_cache.PARAM_NAME)}
)
returl = request.args.get("next")
if match_url_pattern(returl):
param_name = request.args.get(app.form_cache.PARAM_NAME)
if param_name:
returl += "?" + url.urlencode({app.form_cache.PARAM_NAME: param_name})
return returl
else:
return url_for("atst.home")
def match_url_pattern(url, method="GET"):
"""Ensure a url matches a url pattern in the flask app
inspired by https://stackoverflow.com/questions/38488134/get-the-flask-view-function-that-matches-a-url/38488506#38488506
"""
server_name = app.config.get("SERVER_NAME") or "localhost"
adapter = app.url_map.bind(server_name=server_name)
try:
match = adapter.match(url, method=method)
except RequestRedirect as e:
# recursively match redirects
return match_url_pattern(e.new_url, method)
except (MethodNotAllowed, NotFound):
# no match
return None
if match[0] in app.view_functions:
return url
def current_user_setup(user):
session["user_id"] = user.id
session["last_login"] = user.last_login

View File

@ -3,6 +3,7 @@ from flask import Blueprint, render_template, g, request as http_request, redire
from atst.forms.edit_user import EditUserForm
from atst.domain.users import Users
from atst.utils.flash import formatted_flash as flash
from atst.routes import match_url_pattern
bp = Blueprint("users", __name__)
@ -35,7 +36,7 @@ def update_user():
if form.validate():
Users.update(user, form.data)
flash("user_updated")
if next_url:
if match_url_pattern(next_url):
return redirect(next_url)
return render_template(

View File

@ -1,7 +1,36 @@
from tests.factories import UserFactory
from tests.factories import UserFactory, PortfolioFactory
from atst.routes import match_url_pattern
def test_root_redirects_if_user_is_logged_in(client, user_session):
user_session(UserFactory.create())
response = client.get("/", follow_redirects=False)
assert "home" in response.location
def test_match_url_pattern(client):
assert not match_url_pattern(None)
assert match_url_pattern("/home") == "/home"
portfolio = PortfolioFactory()
# matches a URL with an argument
assert (
match_url_pattern(f"/portfolios/{portfolio.id}") # /portfolios/<portfolio_id>
== f"/portfolios/{portfolio.id}"
)
# matches a url with a query string
assert (
match_url_pattern(f"/portfolios/{portfolio.id}?foo=bar")
== f"/portfolios/{portfolio.id}?foo=bar"
)
# matches a URL only with a valid method
assert not match_url_pattern(f"/portfolios/{portfolio.id}/edit")
assert (
match_url_pattern(f"/portfolios/{portfolio.id}/edit", method="POST")
== f"/portfolios/{portfolio.id}/edit"
)
# returns None for URL that doesn't match a view function
assert not match_url_pattern("/pwned")
assert not match_url_pattern("http://www.hackersite.com/pwned")

View File

@ -28,7 +28,7 @@ def test_user_can_update_profile(user_session, client):
def test_user_is_redirected_when_updating_profile(user_session, client):
user = UserFactory.create()
user_session(user)
next_url = "/requests"
next_url = "/home"
user_data = user.to_dictionary()
user_data["date_latest_training"] = user_data["date_latest_training"].strftime(