create AuthenticationContext to consolidate auth logic
This commit is contained in:
parent
3a41d9f81c
commit
07ce940650
@ -0,0 +1,62 @@
|
|||||||
|
from atst.domain.exceptions import UnauthenticatedError, NotFoundError
|
||||||
|
from atst.domain.users import Users
|
||||||
|
from .utils import parse_sdn, email_from_certificate
|
||||||
|
|
||||||
|
|
||||||
|
class AuthenticationContext():
|
||||||
|
|
||||||
|
def __init__(self, crl_validator, auth_status, sdn, cert):
|
||||||
|
if None in locals().values():
|
||||||
|
raise UnauthenticatedError("Missing required authentication context components")
|
||||||
|
|
||||||
|
self.crl_validator = crl_validator
|
||||||
|
self.auth_status = auth_status
|
||||||
|
self.sdn = sdn
|
||||||
|
self.cert = cert.encode()
|
||||||
|
self._parsed_sdn = None
|
||||||
|
|
||||||
|
|
||||||
|
def authenticate(self):
|
||||||
|
if not self.auth_status == "SUCCESS":
|
||||||
|
raise UnauthenticatedError("SSL/TLS client authentication failed")
|
||||||
|
|
||||||
|
elif not self._crl_check():
|
||||||
|
raise UnauthenticatedError("Client certificate failed CRL check")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def get_user(self):
|
||||||
|
try:
|
||||||
|
return Users.get_by_dod_id(self.parsed_sdn["dod_id"])
|
||||||
|
|
||||||
|
except NotFoundError:
|
||||||
|
email = self._get_user_email()
|
||||||
|
return Users.create(**{"email": email, **self.parsed_sdn})
|
||||||
|
|
||||||
|
def _get_user_email(self):
|
||||||
|
try:
|
||||||
|
return email_from_certificate(self.cert)
|
||||||
|
|
||||||
|
# this just means it is not an email certificate; we might choose to
|
||||||
|
# log in that case
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _crl_check(self):
|
||||||
|
if self.cert:
|
||||||
|
result = self.crl_validator.validate(self.cert)
|
||||||
|
return result
|
||||||
|
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parsed_sdn(self):
|
||||||
|
if not self._parsed_sdn:
|
||||||
|
try:
|
||||||
|
self._parsed_sdn = parse_sdn(self.sdn)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise UnauthenticatedError(str(exc))
|
||||||
|
|
||||||
|
return self._parsed_sdn
|
@ -30,3 +30,7 @@ class UnauthenticatedError(Exception):
|
|||||||
@property
|
@property
|
||||||
def message(self):
|
def message(self):
|
||||||
return str(self)
|
return str(self)
|
||||||
|
|
||||||
|
|
||||||
|
class CRLValidationError(Exception):
|
||||||
|
pass
|
||||||
|
@ -4,8 +4,8 @@ import pendulum
|
|||||||
|
|
||||||
from atst.domain.requests import Requests
|
from atst.domain.requests import Requests
|
||||||
from atst.domain.users import Users
|
from atst.domain.users import Users
|
||||||
from atst.domain.authnid.utils import parse_sdn, email_from_certificate
|
from atst.domain.authnid import AuthenticationContext
|
||||||
from atst.domain.exceptions import UnauthenticatedError, NotFoundError
|
|
||||||
|
|
||||||
bp = Blueprint("atst", __name__)
|
bp = Blueprint("atst", __name__)
|
||||||
|
|
||||||
@ -30,29 +30,23 @@ def catch_all(path):
|
|||||||
return render_template("{}.html".format(path))
|
return render_template("{}.html".format(path))
|
||||||
|
|
||||||
|
|
||||||
# TODO: this should be partly consolidated into a domain function that takes
|
def _make_authentication_context():
|
||||||
# all the necessary UWSGI environment values as args and either returns a user
|
return AuthenticationContext(
|
||||||
# or raises the UnauthenticatedError
|
crl_validator=app.crl_validator,
|
||||||
|
auth_status=request.environ.get("HTTP_X_SSL_CLIENT_VERIFY"),
|
||||||
|
sdn=request.environ.get("HTTP_X_SSL_CLIENT_S_DN"),
|
||||||
|
cert=request.environ.get("HTTP_X_SSL_CLIENT_CERT")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/login-redirect')
|
@bp.route('/login-redirect')
|
||||||
def login_redirect():
|
def login_redirect():
|
||||||
# raise S_DN parse errors
|
auth_context = _make_authentication_context()
|
||||||
if request.environ.get('HTTP_X_SSL_CLIENT_VERIFY') == 'SUCCESS' and _is_valid_certificate(request):
|
auth_context.authenticate()
|
||||||
sdn = request.environ.get('HTTP_X_SSL_CLIENT_S_DN')
|
user = auth_context.get_user()
|
||||||
sdn_parts = parse_sdn(sdn)
|
|
||||||
try:
|
|
||||||
user = Users.get_by_dod_id(sdn_parts["dod_id"])
|
|
||||||
except NotFoundError:
|
|
||||||
try:
|
|
||||||
email = email_from_certificate(request.environ.get('HTTP_X_SSL_CLIENT_CERT').encode())
|
|
||||||
sdn_parts["email"] = email
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
user = Users.create(**sdn_parts)
|
|
||||||
session["user_id"] = user.id
|
session["user_id"] = user.id
|
||||||
|
|
||||||
return redirect(url_for("atst.home"))
|
return redirect(url_for("atst.home"))
|
||||||
else:
|
|
||||||
raise UnauthenticatedError()
|
|
||||||
|
|
||||||
|
|
||||||
def _is_valid_certificate(request):
|
def _is_valid_certificate(request):
|
||||||
|
@ -3,6 +3,7 @@ from flask import session, url_for
|
|||||||
from .mocks import DOD_SDN_INFO, DOD_SDN, FIXTURE_EMAIL_ADDRESS
|
from .mocks import DOD_SDN_INFO, DOD_SDN, FIXTURE_EMAIL_ADDRESS
|
||||||
from atst.domain.users import Users
|
from atst.domain.users import Users
|
||||||
from atst.domain.exceptions import NotFoundError
|
from atst.domain.exceptions import NotFoundError
|
||||||
|
from .factories import UserFactory
|
||||||
|
|
||||||
|
|
||||||
MOCK_USER = {"id": "438567dd-25fa-4d83-a8cc-8aa8366cb24a"}
|
MOCK_USER = {"id": "438567dd-25fa-4d83-a8cc-8aa8366cb24a"}
|
||||||
@ -13,14 +14,14 @@ def _fetch_user_info(c, t):
|
|||||||
|
|
||||||
|
|
||||||
def test_successful_login_redirect(client, monkeypatch):
|
def test_successful_login_redirect(client, monkeypatch):
|
||||||
monkeypatch.setattr("atst.routes._is_valid_certificate", lambda *args: True)
|
monkeypatch.setattr("atst.domain.authnid.AuthenticationContext.authenticate", lambda *args: True)
|
||||||
monkeypatch.setattr("atst.routes.email_from_certificate", lambda *args: None)
|
monkeypatch.setattr("atst.domain.authnid.AuthenticationContext.get_user", lambda *args: UserFactory.create())
|
||||||
|
|
||||||
resp = client.get(
|
resp = client.get(
|
||||||
"/login-redirect",
|
"/login-redirect",
|
||||||
environ_base={
|
environ_base={
|
||||||
"HTTP_X_SSL_CLIENT_VERIFY": "SUCCESS",
|
"HTTP_X_SSL_CLIENT_VERIFY": "SUCCESS",
|
||||||
"HTTP_X_SSL_CLIENT_S_DN": DOD_SDN,
|
"HTTP_X_SSL_CLIENT_S_DN": "",
|
||||||
"HTTP_X_SSL_CLIENT_CERT": "",
|
"HTTP_X_SSL_CLIENT_CERT": "",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@ -94,7 +95,7 @@ def test_crl_validation_on_login(client):
|
|||||||
|
|
||||||
|
|
||||||
def test_creates_new_user_on_login(monkeypatch, client):
|
def test_creates_new_user_on_login(monkeypatch, client):
|
||||||
monkeypatch.setattr("atst.routes._is_valid_certificate", lambda *args: True)
|
monkeypatch.setattr("atst.domain.authnid.AuthenticationContext.authenticate", lambda *args: True)
|
||||||
cert_file = open("tests/fixtures/{}.crt".format(FIXTURE_EMAIL_ADDRESS)).read()
|
cert_file = open("tests/fixtures/{}.crt".format(FIXTURE_EMAIL_ADDRESS)).read()
|
||||||
|
|
||||||
# ensure user does not exist
|
# ensure user does not exist
|
||||||
|
Loading…
x
Reference in New Issue
Block a user