diff --git a/atst/app.py b/atst/app.py index 696cf645..305d3b6d 100644 --- a/atst/app.py +++ b/atst/app.py @@ -15,6 +15,7 @@ from atst.routes.workspaces import bp as workspace_routes from atst.routes.requests import requests_bp from atst.routes.dev import bp as dev_routes from atst.domain.authnid.crl.validator import Validator +from atst.domain.auth import apply_authentication ENV = os.getenv("FLASK_ENV", "dev") @@ -47,6 +48,8 @@ def make_app(config): if ENV != "production": app.register_blueprint(dev_routes) + apply_authentication(app) + return app @@ -136,3 +139,4 @@ def make_crl_validator(app): ) for e in app.crl_validator.errors: app.logger.error(e) + diff --git a/atst/domain/auth.py b/atst/domain/auth.py index cc3b1a60..4dcb7e22 100644 --- a/atst/domain/auth.py +++ b/atst/domain/auth.py @@ -1,22 +1,23 @@ -from functools import wraps -from flask import g, redirect, url_for, session +from flask import g, redirect, url_for, session, request from atst.domain.users import Users -def login_required(f): +UNPROTECTED_ROUTES = ["atst.root", "atst.login_dev", "atst.login_redirect", "atst.unauthorized"] - @wraps(f) - def decorated_function(*args, **kwargs): - user = get_current_user() - if user: - g.current_user = user - return f(*args, **kwargs) +def apply_authentication(app): + @app.before_request + # pylint: disable=unused-variable + def enforce_login(): - else: - return redirect(url_for("atst.root")) + if not _unprotected_route(request): + user = get_current_user() + if user: + g.current_user = user + + else: + return redirect(url_for("atst.root")) - return decorated_function def get_current_user(): user_id = session.get("user_id") @@ -24,3 +25,8 @@ def get_current_user(): return Users.get(user_id) else: return False + +def _unprotected_route(request): + if request.endpoint in UNPROTECTED_ROUTES: + return True + diff --git a/atst/routes/__init__.py b/atst/routes/__init__.py index 660796d1..1884d28f 100644 --- a/atst/routes/__init__.py +++ b/atst/routes/__init__.py @@ -5,7 +5,6 @@ import pendulum from atst.domain.requests import Requests from atst.domain.users import Users from atst.domain.authnid.utils import parse_sdn -from atst.domain.auth import login_required bp = Blueprint("atst", __name__) @@ -16,19 +15,16 @@ def root(): @bp.route("/home") -@login_required def home(): return render_template("home.html") @bp.route("/styleguide") -@login_required def styleguide(): return render_template("styleguide.html") @bp.route('/') -@login_required def catch_all(path): return render_template("{}.html".format(path)) diff --git a/atst/routes/requests/financial_verification.py b/atst/routes/requests/financial_verification.py index d4b64923..38420287 100644 --- a/atst/routes/requests/financial_verification.py +++ b/atst/routes/requests/financial_verification.py @@ -4,11 +4,9 @@ from flask import request as http_request from . import requests_bp from atst.domain.requests import Requests from atst.forms.financial import FinancialForm -from atst.domain.auth import login_required @requests_bp.route("/requests/verify/", methods=["GET"]) -@login_required def financial_verification(request_id=None): request = Requests.get(request_id) form = FinancialForm(data=request.body.get("financial_verification")) @@ -18,7 +16,6 @@ def financial_verification(request_id=None): @requests_bp.route("/requests/verify/", methods=["POST"]) -@login_required def update_financial_verification(request_id): post_data = http_request.form existing_request = Requests.get(request_id) @@ -43,6 +40,5 @@ def update_financial_verification(request_id): @requests_bp.route("/requests/financial_verification_submitted") -@login_required def financial_verification_submitted(): pass diff --git a/atst/routes/requests/index.py b/atst/routes/requests/index.py index b415163b..0de74bd2 100644 --- a/atst/routes/requests/index.py +++ b/atst/routes/requests/index.py @@ -3,7 +3,6 @@ from flask import render_template, g from . import requests_bp from atst.domain.requests import Requests -from atst.domain.auth import login_required def map_request(user, request): @@ -21,7 +20,6 @@ def map_request(user, request): @requests_bp.route("/requests", methods=["GET"]) -@login_required def requests_index(): requests = [] if ( diff --git a/atst/routes/requests/requests_form.py b/atst/routes/requests/requests_form.py index 8ebd20ca..63a14224 100644 --- a/atst/routes/requests/requests_form.py +++ b/atst/routes/requests/requests_form.py @@ -3,11 +3,9 @@ from flask import g, redirect, render_template, url_for, request as http_request from . import requests_bp from atst.domain.requests import Requests from atst.routes.requests.jedi_request_flow import JEDIRequestFlow -from atst.domain.auth import login_required @requests_bp.route("/requests/new/", methods=["GET"]) -@login_required def requests_form_new(screen): jedi_flow = JEDIRequestFlow(screen, request=None) @@ -26,7 +24,6 @@ def requests_form_new(screen): "/requests/new/", methods=["GET"], defaults={"request_id": None} ) @requests_bp.route("/requests/new//", methods=["GET"]) -@login_required def requests_form_update(screen=1, request_id=None): request = Requests.get(request_id) if request_id is not None else None jedi_flow = JEDIRequestFlow(screen, request, request_id=request_id) @@ -47,7 +44,6 @@ def requests_form_update(screen=1, request_id=None): "/requests/new/", methods=["POST"], defaults={"request_id": None} ) @requests_bp.route("/requests/new//", methods=["POST"]) -@login_required def requests_update(screen=1, request_id=None): screen = int(screen) post_data = http_request.form @@ -92,7 +88,6 @@ def requests_update(screen=1, request_id=None): @requests_bp.route("/requests/submit/", methods=["POST"]) -@login_required def requests_submit(request_id=None): request = Requests.get(request_id) Requests.submit(request) diff --git a/atst/routes/workspaces.py b/atst/routes/workspaces.py index 50c96dd6..e614c08c 100644 --- a/atst/routes/workspaces.py +++ b/atst/routes/workspaces.py @@ -1,7 +1,6 @@ from flask import Blueprint, render_template from atst.domain.workspaces import Projects, Members -from atst.domain.auth import login_required bp = Blueprint("workspaces", __name__) @@ -17,13 +16,11 @@ mock_workspaces = [ @bp.route("/workspaces") -@login_required def workspaces(): return render_template("workspaces.html", page=5, workspaces=mock_workspaces) @bp.route("/workspaces//projects") -@login_required def workspace_projects(workspace_id): projects_repo = Projects() projects = projects_repo.get_many(workspace_id) @@ -33,7 +30,6 @@ def workspace_projects(workspace_id): @bp.route("/workspaces//members") -@login_required def workspace_members(workspace_id): members_repo = Members() members = members_repo.get_many(workspace_id) diff --git a/tests/test_auth.py b/tests/test_auth.py index 6631b567..69cb3166 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -15,8 +15,7 @@ def test_successful_login_redirect(client, monkeypatch): resp = client.get( "/login-redirect", environ_base={ - "HTTP_X_SSL_CLIENT_VERIFY": "SUCCESS", - "HTTP_X_SSL_CLIENT_S_DN": DOD_SDN, + "HTTP_X_SSL_CLIENT_VERIFY": "SUCCESS", "HTTP_X_SSL_CLIENT_S_DN": DOD_SDN }, ) @@ -32,9 +31,10 @@ def test_unsuccessful_login_redirect(client, monkeypatch): assert "unauthorized" in resp.headers["Location"] assert "user_id" not in session -UNPROTECTED_ROUTES = ["/", "/login-dev", "/login-redirect", "/unauthorized"] # checks that all of the routes in the app are protected by auth + + def test_routes_are_protected(client, app): for rule in app.url_map.iter_rules(): args = [1] * len(rule.arguments) @@ -54,10 +54,14 @@ def test_routes_are_protected(client, app): assert resp.headers["Location"] == "http://localhost/" +UNPROTECTED_ROUTES = ["/", "/login-dev", "/login-redirect", "/unauthorized"] + # this implicitly relies on the test config and test CRL in tests/fixtures/crl + + def test_crl_validation_on_login(client): - good_cert = open('ssl/client-certs/atat.mil.crt', 'rb').read() - bad_cert = open('ssl/client-certs/bad-atat.mil.crt', 'rb').read() + good_cert = open("ssl/client-certs/atat.mil.crt", "rb").read() + bad_cert = open("ssl/client-certs/bad-atat.mil.crt", "rb").read() # bad cert is on the test CRL resp = client.get( @@ -65,7 +69,7 @@ def test_crl_validation_on_login(client): environ_base={ "HTTP_X_SSL_CLIENT_VERIFY": "SUCCESS", "HTTP_X_SSL_CLIENT_S_DN": DOD_SDN, - "HTTP_X_SSL_CLIENT_CERT": bad_cert.decode() + "HTTP_X_SSL_CLIENT_CERT": bad_cert.decode(), }, ) assert resp.status_code == 302 @@ -78,10 +82,9 @@ def test_crl_validation_on_login(client): environ_base={ "HTTP_X_SSL_CLIENT_VERIFY": "SUCCESS", "HTTP_X_SSL_CLIENT_S_DN": DOD_SDN, - "HTTP_X_SSL_CLIENT_CERT": good_cert.decode() + "HTTP_X_SSL_CLIENT_CERT": good_cert.decode(), }, ) assert resp.status_code == 302 assert "home" in resp.headers["Location"] assert session["user_id"] -