diff --git a/atst/app.py b/atst/app.py index c30baa88..c1b5d231 100644 --- a/atst/app.py +++ b/atst/app.py @@ -67,6 +67,7 @@ def make_app(config): def make_flask_callbacks(app): @app.before_request def _set_globals(): + g.current_user = None g.dev = os.getenv("FLASK_ENV", "dev") == "dev" g.matchesPath = lambda href: re.match("^" + href, request.path) g.modal = request.args.get("modal", None) @@ -74,7 +75,7 @@ def make_flask_callbacks(app): @app.after_request def _cleanup(response): - g.pop("current_user", None) + g.current_user = None return response diff --git a/atst/routes/__init__.py b/atst/routes/__init__.py index d586ef6e..c4b76de3 100644 --- a/atst/routes/__init__.py +++ b/atst/routes/__init__.py @@ -18,6 +18,9 @@ bp = Blueprint("atst", __name__) @bp.route("/") def root(): + if g.current_user: + return redirect(url_for(".home")) + redirect_url = app.config.get("CAC_URL") if request.args.get("next"): redirect_url = url.urljoin( diff --git a/tests/routes/test_root.py b/tests/routes/test_root.py new file mode 100644 index 00000000..b06befaf --- /dev/null +++ b/tests/routes/test_root.py @@ -0,0 +1,7 @@ +from tests.factories import UserFactory + + +def test_root_redirects_if_user_is_logged_in(client, user_session): + user_session(UserFactory.create()) + response = client.get("/", follow_redirects=False) + assert "home" in response.location diff --git a/tests/test_routes.py b/tests/test_routes.py index c2d77150..b1c03a4a 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -4,7 +4,6 @@ import pytest @pytest.mark.parametrize( "path", ( - "/", "/workspaces", "/requests", "/requests/new/1",