diff --git a/atst/app.py b/atst/app.py index 66299d0d..d3cb357a 100644 --- a/atst/app.py +++ b/atst/app.py @@ -50,10 +50,10 @@ def make_app(config): app.config.update({"SESSION_REDIS": app.redis}) make_flask_callbacks(app) - make_crl_validator(app) register_filters(app) make_eda_client(app) make_csp_provider(app) + make_crl_validator(app) make_mailer(app) queue.init_app(app) @@ -203,7 +203,10 @@ def make_crl_validator(app): for filename in pathlib.Path(app.config["CRL_CONTAINER"]).glob("*.crl"): crl_locations.append(filename.absolute()) app.crl_cache = CRLCache( - app.config["CA_CHAIN"], crl_locations, logger=app.logger + app.config["CA_CHAIN"], + crl_locations, + logger=app.logger, + crl_update_func=app.csp.crls.sync_crls, ) diff --git a/atst/domain/authnid/crl/__init__.py b/atst/domain/authnid/crl/__init__.py index 038f32f1..937a2db1 100644 --- a/atst/domain/authnid/crl/__init__.py +++ b/atst/domain/authnid/crl/__init__.py @@ -2,6 +2,7 @@ import sys import os import re import hashlib +from datetime import datetime from OpenSSL import crypto, SSL @@ -56,13 +57,20 @@ class CRLCache(CRLInterface): ) def __init__( - self, root_location, crl_locations=[], store_class=crypto.X509Store, logger=None + self, + root_location, + crl_locations=[], + store_class=crypto.X509Store, + logger=None, + crl_update_func=None, ): + self._crl_locations = crl_locations self.logger = logger self.store_class = store_class self.certificate_authorities = {} + self._crl_update_func = crl_update_func self._load_roots(root_location) - self._build_crl_cache(crl_locations) + self._build_crl_cache() def _get_store(self, cert): return self._build_store(cert.get_issuer()) @@ -76,9 +84,9 @@ class CRLCache(CRLInterface): def _parse_roots(self, root_str): return [match.group(0) for match in self._PEM_RE.finditer(root_str)] - def _build_crl_cache(self, crl_locations): + def _build_crl_cache(self): self.crl_cache = {} - for crl_location in crl_locations: + for crl_location in self._crl_locations: crl = self._load_crl(crl_location) if crl: issuer_der = crl.get_issuer().der() @@ -109,6 +117,8 @@ class CRLCache(CRLInterface): ) ) + self._manage_expiring(crl_info["expires"]) + crl = self._load_crl(crl_info["location"]) store.add_crl(crl) @@ -121,6 +131,17 @@ class CRLCache(CRLInterface): store = self._add_certificate_chain_to_store(store, crl.get_issuer()) return store + def _manage_expiring(self, crl_expiry): + """ + If a CRL is expired and CRLCache has been given a function for updating + the physical CRL locations, run the update function and then rebuild + the CRL cache. + """ + current = datetime.now(crl_expiry.tzinfo) + if self._crl_update_func and current > crl_expiry: + self._crl_update_func() + self._build_crl_cache() + # this _should_ happen just twice for the DoD PKI (intermediary, root) but # theoretically it can build a longer certificate chain diff --git a/tests/domain/authnid/test_crl.py b/tests/domain/authnid/test_crl.py index 92c6bd73..c385c4bb 100644 --- a/tests/domain/authnid/test_crl.py +++ b/tests/domain/authnid/test_crl.py @@ -17,6 +17,104 @@ from atst.domain.authnid.crl import CRLCache, CRLRevocationException, NoOpCRLCac from tests.mocks import FIXTURE_EMAIL_ADDRESS, DOD_CN +def rsa_key(): + return rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend() + ) + + +@pytest.fixture +def ca_key(): + return rsa_key() + + +@pytest.fixture +def make_x509(): + def _make_x509(private_key, signer_key=None, cn="ATAT", signer_cn="ATAT"): + if signer_key is None: + signer_key = private_key + + one_day = timedelta(1, 0, 0) + public_key = private_key.public_key() + builder = x509.CertificateBuilder() + builder = builder.subject_name( + x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, cn)]) + ) + builder = builder.issuer_name( + x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, signer_cn)]) + ) + if signer_key == private_key: + builder = builder.add_extension( + x509.BasicConstraints(ca=True, path_length=None), critical=True + ) + builder = builder.not_valid_before(datetime.today() - (one_day * 2)) + builder = builder.not_valid_after(datetime.today() + (one_day * 30)) + builder = builder.serial_number(x509.random_serial_number()) + builder = builder.public_key(public_key) + certificate = builder.sign( + private_key=signer_key, algorithm=hashes.SHA256(), backend=default_backend() + ) + + return certificate + + return _make_x509 + + +@pytest.fixture +def make_crl(): + def _make_crl(private_key, last_update_days=-1, next_update_days=30, cn="ATAT"): + one_day = timedelta(1, 0, 0) + builder = x509.CertificateRevocationListBuilder() + builder = builder.issuer_name( + x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, cn)]) + ) + builder = builder.last_update(datetime.today() + (one_day * last_update_days)) + builder = builder.next_update(datetime.today() + (one_day * next_update_days)) + crl = builder.sign( + private_key=private_key, + algorithm=hashes.SHA256(), + backend=default_backend(), + ) + + return crl + + return _make_crl + + +def serialize_pki_object_to_disk(obj, name, encoding=Encoding.PEM): + with open(name, "wb") as file_: + file_.write(obj.public_bytes(encoding)) + + return name + + +@pytest.fixture +def ca_file(make_x509, ca_key, tmpdir): + ca = make_x509(ca_key) + ca_out = tmpdir.join("atat-ca.crt") + serialize_pki_object_to_disk(ca, ca_out) + + return ca_out + + +@pytest.fixture +def expired_crl_file(make_crl, ca_key, tmpdir): + crl = make_crl(ca_key, last_update_days=-7, next_update_days=-1) + crl_out = tmpdir.join("atat-expired.crl") + serialize_pki_object_to_disk(crl, crl_out, encoding=Encoding.DER) + + return crl_out + + +@pytest.fixture +def crl_file(make_crl, ca_key, tmpdir): + crl = make_crl(ca_key) + crl_out = tmpdir.join("atat-valid.crl") + serialize_pki_object_to_disk(crl, crl_out, encoding=Encoding.DER) + + return crl_out + + class MockX509Store: def __init__(self): self.crls = [] @@ -32,20 +130,16 @@ class MockX509Store: pass -def test_can_build_crl_list(ca_key, make_crl, make_x509, tmpdir): - ca = make_x509(ca_key) - ca_out = tmpdir.join("atat.crt") - serialize_pki_object_to_disk(ca, ca_out) - +def test_can_build_crl_list(ca_file, ca_key, make_crl, tmpdir): crl = make_crl(ca_key) - crl_out = tmpdir.join("atat.crl") - serialize_pki_object_to_disk(crl, crl_out, encoding=Encoding.DER) + crl_file = tmpdir.join("atat.crl") + serialize_pki_object_to_disk(crl, crl_file, encoding=Encoding.DER) - cache = CRLCache(ca_out, crl_locations=[crl_out], store_class=MockX509Store) + cache = CRLCache(ca_file, crl_locations=[crl_file], store_class=MockX509Store) issuer_der = crl.issuer.public_bytes(default_backend()) assert len(cache.crl_cache.keys()) == 1 assert issuer_der in cache.crl_cache - assert cache.crl_cache[issuer_der]["location"] == crl_out + assert cache.crl_cache[issuer_der]["location"] == crl_file assert cache.crl_cache[issuer_der]["expires"] == crl.next_update @@ -139,89 +233,19 @@ def test_no_op_crl_cache_logs_common_name(): assert "ART.GARFUNKEL.1234567890" in logger.messages[-1] -def rsa_key(): - return rsa.generate_private_key( - public_exponent=65537, key_size=2048, backend=default_backend() - ) - - -@pytest.fixture -def ca_key(): - return rsa_key() - - -@pytest.fixture -def make_x509(): - def _make_x509(private_key, signer_key=None, cn="ATAT", signer_cn="ATAT"): - if signer_key is None: - signer_key = private_key - - one_day = timedelta(1, 0, 0) - public_key = private_key.public_key() - builder = x509.CertificateBuilder() - builder = builder.subject_name( - x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, cn)]) - ) - builder = builder.issuer_name( - x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, signer_cn)]) - ) - if signer_key == private_key: - builder = builder.add_extension( - x509.BasicConstraints(ca=True, path_length=None), critical=True - ) - builder = builder.not_valid_before(datetime.today() - (one_day * 2)) - builder = builder.not_valid_after(datetime.today() + (one_day * 30)) - builder = builder.serial_number(x509.random_serial_number()) - builder = builder.public_key(public_key) - certificate = builder.sign( - private_key=signer_key, algorithm=hashes.SHA256(), backend=default_backend() - ) - - return certificate - - return _make_x509 - - -@pytest.fixture -def make_crl(): - def _make_crl(private_key, last_update_days=-1, next_update_days=30, cn="ATAT"): - one_day = timedelta(1, 0, 0) - builder = x509.CertificateRevocationListBuilder() - builder = builder.issuer_name( - x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, cn)]) - ) - builder = builder.last_update(datetime.today() + (one_day * last_update_days)) - builder = builder.next_update(datetime.today() + (one_day * next_update_days)) - crl = builder.sign( - private_key=private_key, - algorithm=hashes.SHA256(), - backend=default_backend(), - ) - - return crl - - return _make_crl - - -def serialize_pki_object_to_disk(obj, name, encoding=Encoding.PEM): - with open(name, "wb") as file_: - file_.write(obj.public_bytes(encoding)) - - return name - - -@pytest.mark.skip(reason="not implemented yet") -def test_updates_expired_certs(ca_key, make_crl, make_x509, tmpdir): - ca = make_x509(ca_key) - ca_out = tmpdir.join("atat.crt") - serialize_pki_object_to_disk(ca, ca_out) - - crl = make_crl(ca_key, last_update_days=-7, next_update_days=-1) - crl_out = tmpdir.join("atat.crl") - serialize_pki_object_to_disk(crl, crl_out, encoding=Encoding.DER) - +def test_updates_expired_certs(ca_file, expired_crl_file, crl_file, ca_key, make_x509): + """ + Given a CRLCache object with an expired CRL and a function for updating the + CRLs, the CRLCache should run the update function before checking a + certificate that requires the expired CRL. + """ client_cert = make_x509(rsa_key(), signer_key=ca_key, cn="chewbacca") client_pem = client_cert.public_bytes(Encoding.PEM) - crl_cache = CRLCache(ca_out, crl_locations=[crl_out]) + def _crl_update_func(): + shutil.copyfile(crl_file, expired_crl_file) + + crl_cache = CRLCache( + ca_file, crl_locations=[expired_crl_file], crl_update_func=_crl_update_func + ) crl_cache.crl_check(client_pem)