diff --git a/atst/app.py b/atst/app.py index cbca4c22..ddf66a40 100644 --- a/atst/app.py +++ b/atst/app.py @@ -16,7 +16,7 @@ from atst.routes.workspaces import bp as workspace_routes from atst.routes.requests import requests_bp from atst.routes.dev import bp as dev_routes from atst.routes.errors import make_error_pages -from atst.domain.authnid.crl import Validator, CRLCache +from atst.domain.authnid.crl import CRLCache from atst.domain.auth import apply_authentication diff --git a/atst/domain/authnid/__init__.py b/atst/domain/authnid/__init__.py index 80d645b8..35a4477c 100644 --- a/atst/domain/authnid/__init__.py +++ b/atst/domain/authnid/__init__.py @@ -1,17 +1,18 @@ from atst.domain.exceptions import UnauthenticatedError, NotFoundError from atst.domain.users import Users from .utils import parse_sdn, email_from_certificate +from .crl import Validator class AuthenticationContext(): - def __init__(self, crl_validator, auth_status, sdn, cert): + def __init__(self, crl_cache, auth_status, sdn, cert): if None in locals().values(): raise UnauthenticatedError( "Missing required authentication context components" ) - self.crl_validator = crl_validator + self.crl_cache = crl_cache self.auth_status = auth_status self.sdn = sdn self.cert = cert.encode() @@ -44,8 +45,9 @@ class AuthenticationContext(): return None def _crl_check(self): + validator = Validator(self.crl_cache, self.cert) if self.cert: - result = self.crl_validator.validate(self.cert) + result = validator.validate() return result else: diff --git a/atst/domain/authnid/crl/__init__.py b/atst/domain/authnid/crl/__init__.py index dc988546..5bbc7c72 100644 --- a/atst/domain/authnid/crl/__init__.py +++ b/atst/domain/authnid/crl/__init__.py @@ -56,28 +56,50 @@ class CRLCache(): # theoretically it can build a longer certificate chain def _add_certificate_chain_to_store(self, store, issuer): ca = self.certificate_authorities.get(issuer.der()) - # i.e., it is the root CA - if issuer == ca.get_subject(): - return store - store.add_cert(ca) - return self._add_certificate_chain_to_store(store, ca.get_issuer()) + if issuer == ca.get_subject(): + # i.e., it is the root CA and we are at the end of the chain + return store + else: + return self._add_certificate_chain_to_store(store, ca.get_issuer()) + + def get_store(self, cert): + return self._check_cache(cert.get_issuer().der()) + + def _check_cache(self, issuer): + if issuer in self.crl_cache: + filename, checksum = self.crl_cache[issuer] + if sha256_checksum(filename) != checksum: + issuer, store = self._build_store(filename) + self.x509_stores[issuer] = store + return store + else: + return self.x509_stores[issuer] class Validator: - _PEM_RE = re.compile( - b"-----BEGIN CERTIFICATE-----\r?.+?\r?-----END CERTIFICATE-----\r?\n?", - re.DOTALL, - ) - - def __init__(self, root, crl_locations=[], base_store=crypto.X509Store, logger=None): - self.crl_locations = crl_locations - self.root = root - self.base_store = base_store + def __init__(self, cache, cert, logger=None): + self.cache = cache + self.cert = cert self.logger = logger - self._reset() + + def validate(self): + parsed = crypto.load_certificate(crypto.FILETYPE_PEM, self.cert) + store = self.cache.get_store(parsed) + context = crypto.X509StoreContext(store, parsed) + try: + context.verify_certificate() + return True + + except crypto.X509StoreContextError as err: + self.log_error( + "Certificate revoked or errored. Error: {}. Args: {}".format( + type(err), err.args + ) + ) + return False def _add_roots(self, roots): with open(filename, "rb") as f: @@ -161,26 +183,3 @@ class Validator: return error.args == self.PRELOADED_CRL or error.args == self.PRELOADED_CERT # Checks that the CRL currently in-memory is up-to-date via the checksum. - - def refresh_cache(self, cert): - der = cert.get_issuer().der() - if der in self.cache: - filename, checksum = self.cache[der] - if sha256_checksum(filename) != checksum: - self._reset() - - def validate(self, cert): - parsed = crypto.load_certificate(crypto.FILETYPE_PEM, cert) - self.refresh_cache(parsed) - context = crypto.X509StoreContext(self.store, parsed) - try: - context.verify_certificate() - return True - - except crypto.X509StoreContextError as err: - self.log_error( - "Certificate revoked or errored. Error: {}. Args: {}".format( - type(err), err.args - ) - ) - return False diff --git a/atst/routes/__init__.py b/atst/routes/__init__.py index f8c4199c..68c83437 100644 --- a/atst/routes/__init__.py +++ b/atst/routes/__init__.py @@ -32,7 +32,7 @@ def catch_all(path): def _make_authentication_context(): return AuthenticationContext( - crl_validator=app.crl_validator, + crl_cache=app.crl_cache, auth_status=request.environ.get("HTTP_X_SSL_CLIENT_VERIFY"), sdn=request.environ.get("HTTP_X_SSL_CLIENT_S_DN"), cert=request.environ.get("HTTP_X_SSL_CLIENT_CERT") diff --git a/tests/domain/authnid/test_authentication_context.py b/tests/domain/authnid/test_authentication_context.py index f2a359af..f07d6e94 100644 --- a/tests/domain/authnid/test_authentication_context.py +++ b/tests/domain/authnid/test_authentication_context.py @@ -10,25 +10,23 @@ from tests.factories import UserFactory CERT = open("tests/fixtures/{}.crt".format(FIXTURE_EMAIL_ADDRESS)).read() -class MockCRLValidator(): - - def __init__(self, value): - self.value = value - - def validate(self, cert): - return self.value +class MockCRLCache(): + def get_store(self, cert): + pass -def test_can_authenticate(): +def test_can_authenticate(monkeypatch): + monkeypatch.setattr("atst.domain.authnid.Validator.validate", lambda s: True) auth_context = AuthenticationContext( - MockCRLValidator(True), "SUCCESS", DOD_SDN, CERT + MockCRLCache(), "SUCCESS", DOD_SDN, CERT ) assert auth_context.authenticate() -def test_unsuccessful_status(): +def test_unsuccessful_status(monkeypatch): + monkeypatch.setattr("atst.domain.authnid.Validator.validate", lambda s: True) auth_context = AuthenticationContext( - MockCRLValidator(True), "FAILURE", DOD_SDN, CERT + MockCRLCache(), "FAILURE", DOD_SDN, CERT ) with pytest.raises(UnauthenticatedError) as excinfo: assert auth_context.authenticate() @@ -37,9 +35,10 @@ def test_unsuccessful_status(): assert "client authentication" in message -def test_crl_check_fails(): +def test_crl_check_fails(monkeypatch): + monkeypatch.setattr("atst.domain.authnid.Validator.validate", lambda s: False) auth_context = AuthenticationContext( - MockCRLValidator(False), "SUCCESS", DOD_SDN, CERT + MockCRLCache(), "SUCCESS", DOD_SDN, CERT ) with pytest.raises(UnauthenticatedError) as excinfo: assert auth_context.authenticate() @@ -48,9 +47,10 @@ def test_crl_check_fails(): assert "CRL check" in message -def test_bad_sdn(): +def test_bad_sdn(monkeypatch): + monkeypatch.setattr("atst.domain.authnid.Validator.validate", lambda s: True) auth_context = AuthenticationContext( - MockCRLValidator(True), "SUCCESS", "abc123", CERT + MockCRLCache(), "SUCCESS", "abc123", CERT ) with pytest.raises(UnauthenticatedError) as excinfo: auth_context.get_user() @@ -59,33 +59,36 @@ def test_bad_sdn(): assert "SDN" in message -def test_user_exists(): +def test_user_exists(monkeypatch): + monkeypatch.setattr("atst.domain.authnid.Validator.validate", lambda s: True) user = UserFactory.create(**DOD_SDN_INFO) auth_context = AuthenticationContext( - MockCRLValidator(True), "SUCCESS", DOD_SDN, CERT + MockCRLCache(), "SUCCESS", DOD_SDN, CERT ) auth_user = auth_context.get_user() assert auth_user == user -def test_creates_user(): +def test_creates_user(monkeypatch): + monkeypatch.setattr("atst.domain.authnid.Validator.validate", lambda s: True) # check user does not exist with pytest.raises(NotFoundError): Users.get_by_dod_id(DOD_SDN_INFO["dod_id"]) auth_context = AuthenticationContext( - MockCRLValidator(True), "SUCCESS", DOD_SDN, CERT + MockCRLCache(), "SUCCESS", DOD_SDN, CERT ) user = auth_context.get_user() assert user.dod_id == DOD_SDN_INFO["dod_id"] assert user.email == FIXTURE_EMAIL_ADDRESS -def test_user_cert_has_no_email(): +def test_user_cert_has_no_email(monkeypatch): + monkeypatch.setattr("atst.domain.authnid.Validator.validate", lambda s: True) cert = open("ssl/client-certs/atat.mil.crt").read() auth_context = AuthenticationContext( - MockCRLValidator(True), "SUCCESS", DOD_SDN, cert + MockCRLCache(), "SUCCESS", DOD_SDN, cert ) user = auth_context.get_user() diff --git a/tests/domain/authnid/test_crl.py b/tests/domain/authnid/test_crl.py index 1b9fa2ec..fcd371b3 100644 --- a/tests/domain/authnid/test_crl.py +++ b/tests/domain/authnid/test_crl.py @@ -4,7 +4,7 @@ import re import os import shutil from OpenSSL import crypto, SSL -from atst.domain.authnid.crl import Validator +from atst.domain.authnid.crl import Validator, CRLCache import atst.domain.authnid.crl.util as util @@ -24,38 +24,33 @@ class MockX509Store(): def test_can_build_crl_list(monkeypatch): location = 'ssl/client-certs/client-ca.der.crl' - validator = Validator(crl_locations=[location], base_store=MockX509Store) - assert len(validator.store.crls) == 1 + cache = CRLCache('ssl/client-certs/client-ca.crt', crl_locations=[location], store_class=MockX509Store) + for store in cache.x509_stores.values(): + assert len(store.crls) == 1 def test_can_build_trusted_root_list(): location = 'ssl/server-certs/ca-chain.pem' - validator = Validator(roots=[location], base_store=MockX509Store) + cache = CRLCache(root_location=location, crl_locations=[], store_class=MockX509Store) with open(location) as f: content = f.read() - assert len(validator.store.certs) == content.count('BEGIN CERT') + assert len(cache.certificate_authorities.keys()) == content.count('BEGIN CERT') def test_can_validate_certificate(): - validator = Validator( - roots=['ssl/server-certs/ca-chain.pem'], - crl_locations=['ssl/client-certs/client-ca.der.crl'] - ) + cache = CRLCache('ssl/server-certs/ca-chain.pem', crl_locations=['ssl/client-certs/client-ca.der.crl']) good_cert = open('ssl/client-certs/atat.mil.crt', 'rb').read() bad_cert = open('ssl/client-certs/bad-atat.mil.crt', 'rb').read() - assert validator.validate(good_cert) - assert validator.validate(bad_cert) == False + assert Validator(cache, good_cert).validate() + assert Validator(cache, bad_cert).validate() == False def test_can_dynamically_update_crls(tmpdir): crl_file = tmpdir.join('test.crl') shutil.copyfile('ssl/client-certs/client-ca.der.crl', crl_file) - validator = Validator( - roots=['ssl/server-certs/ca-chain.pem'], - crl_locations=[crl_file] - ) + cache = CRLCache('ssl/server-certs/ca-chain.pem', crl_locations=[crl_file]) cert = open('ssl/client-certs/atat.mil.crt', 'rb').read() - assert validator.validate(cert) + assert Validator(cache, cert).validate() # override the original CRL with one that revokes atat.mil.crt shutil.copyfile('tests/fixtures/test.der.crl', crl_file) - assert validator.validate(cert) == False + assert Validator(cache, cert).validate() == False def test_parse_disa_pki_list(): with open('tests/fixtures/disa-pki.html') as disa: diff --git a/tests/fixtures/test.der.crl b/tests/fixtures/test.der.crl index dc8310f2..24708b0b 100644 Binary files a/tests/fixtures/test.der.crl and b/tests/fixtures/test.der.crl differ