diff --git a/atst/domain/authnid/crl/__init__.py b/atst/domain/authnid/crl/__init__.py index 0dcf1fc0..f1c93fb7 100644 --- a/atst/domain/authnid/crl/__init__.py +++ b/atst/domain/authnid/crl/__init__.py @@ -71,15 +71,24 @@ class CRLCache(): def _build_store(self, issuer): store = self.store_class() store.set_flags(crypto.X509StoreFlags.CRL_CHECK) - crl_location = self.crl_cache[issuer] + crl_location = self._get_crl_location(issuer) with open(crl_location, "rb") as crl_file: crl = crypto.load_crl(crypto.FILETYPE_ASN1, crl_file.read()) store.add_crl(crl) store = self._add_certificate_chain_to_store(store, crl.get_issuer()) return store + def _get_crl_location(self, issuer): + crl_location = self.crl_cache.get(issuer) + + if not crl_location: + raise CRLRevocationException("Could not find matching CRL for issuer") + + return crl_location + # this _should_ happen just twice for the DoD PKI (intermediary, root) but # theoretically it can build a longer certificate chain + def _add_certificate_chain_to_store(self, store, issuer): ca = self.certificate_authorities.get(issuer.der()) store.add_cert(ca) @@ -87,6 +96,6 @@ class CRLCache(): if issuer == ca.get_subject(): # i.e., it is the root CA and we are at the end of the chain return store + else: return self._add_certificate_chain_to_store(store, ca.get_issuer()) - diff --git a/tests/domain/authnid/test_crl.py b/tests/domain/authnid/test_crl.py index 5bd009be..19757353 100644 --- a/tests/domain/authnid/test_crl.py +++ b/tests/domain/authnid/test_crl.py @@ -4,11 +4,15 @@ import re import os import shutil from OpenSSL import crypto, SSL + from atst.domain.authnid.crl import crl_check, CRLCache, CRLRevocationException import atst.domain.authnid.crl.util as util +from tests.mocks import FIXTURE_EMAIL_ADDRESS + class MockX509Store(): + def __init__(self): self.crls = [] self.certs = [] @@ -24,46 +28,69 @@ class MockX509Store(): def test_can_build_crl_list(monkeypatch): - location = 'ssl/client-certs/client-ca.der.crl' - cache = CRLCache('ssl/client-certs/client-ca.crt', crl_locations=[location], store_class=MockX509Store) + location = "ssl/client-certs/client-ca.der.crl" + cache = CRLCache( + "ssl/client-certs/client-ca.crt", + crl_locations=[location], + store_class=MockX509Store, + ) assert len(cache.crl_cache.keys()) == 1 def test_can_build_trusted_root_list(): - location = 'ssl/server-certs/ca-chain.pem' - cache = CRLCache(root_location=location, crl_locations=[], store_class=MockX509Store) + location = "ssl/server-certs/ca-chain.pem" + cache = CRLCache( + root_location=location, crl_locations=[], store_class=MockX509Store + ) with open(location) as f: content = f.read() - assert len(cache.certificate_authorities.keys()) == content.count('BEGIN CERT') + assert len(cache.certificate_authorities.keys()) == content.count("BEGIN CERT") + def test_can_validate_certificate(): - cache = CRLCache('ssl/server-certs/ca-chain.pem', crl_locations=['ssl/client-certs/client-ca.der.crl']) - good_cert = open('ssl/client-certs/atat.mil.crt', 'rb').read() - bad_cert = open('ssl/client-certs/bad-atat.mil.crt', 'rb').read() + cache = CRLCache( + "ssl/server-certs/ca-chain.pem", + crl_locations=["ssl/client-certs/client-ca.der.crl"], + ) + good_cert = open("ssl/client-certs/atat.mil.crt", "rb").read() + bad_cert = open("ssl/client-certs/bad-atat.mil.crt", "rb").read() assert crl_check(cache, good_cert) with pytest.raises(CRLRevocationException): crl_check(cache, bad_cert) + def test_can_dynamically_update_crls(tmpdir): - crl_file = tmpdir.join('test.crl') - shutil.copyfile('ssl/client-certs/client-ca.der.crl', crl_file) - cache = CRLCache('ssl/server-certs/ca-chain.pem', crl_locations=[crl_file]) - cert = open('ssl/client-certs/atat.mil.crt', 'rb').read() + crl_file = tmpdir.join("test.crl") + shutil.copyfile("ssl/client-certs/client-ca.der.crl", crl_file) + cache = CRLCache("ssl/server-certs/ca-chain.pem", crl_locations=[crl_file]) + cert = open("ssl/client-certs/atat.mil.crt", "rb").read() assert crl_check(cache, cert) # override the original CRL with one that revokes atat.mil.crt - shutil.copyfile('tests/fixtures/test.der.crl', crl_file) + shutil.copyfile("tests/fixtures/test.der.crl", crl_file) with pytest.raises(CRLRevocationException): assert crl_check(cache, cert) + +def test_throws_error_for_missing_issuer(): + cache = CRLCache("ssl/server-certs/ca-chain.pem", crl_locations=[]) + cert = open("tests/fixtures/{}.crt".format(FIXTURE_EMAIL_ADDRESS), "rb").read() + with pytest.raises(CRLRevocationException) as exc: + assert crl_check(cache, cert) + (message,) = exc.value.args + assert "issuer" in message + + def test_parse_disa_pki_list(): - with open('tests/fixtures/disa-pki.html') as disa: + with open("tests/fixtures/disa-pki.html") as disa: disa_html = disa.read() crl_list = util.crl_list_from_disa_html(disa_html) - href_matches = re.findall('DOD(ROOT|EMAIL|ID)?CA', disa_html) + href_matches = re.findall("DOD(ROOT|EMAIL|ID)?CA", disa_html) assert len(crl_list) > 0 assert len(crl_list) == len(href_matches) + class MockStreamingResponse(): + def __init__(self, content_chunks, code=200): self.content_chunks = content_chunks self.status_code = code @@ -77,13 +104,19 @@ class MockStreamingResponse(): def __exit__(self, *args): pass + def test_write_crl(tmpdir, monkeypatch): - monkeypatch.setattr('requests.get', lambda u, **kwargs: MockStreamingResponse([b'it worked'])) - crl = 'crl_1' + monkeypatch.setattr( + "requests.get", lambda u, **kwargs: MockStreamingResponse([b"it worked"]) + ) + crl = "crl_1" assert util.write_crl(tmpdir, "random_target_dir", crl) assert [p.basename for p in tmpdir.listdir()] == [crl] - assert [p.read() for p in tmpdir.listdir()] == ['it worked'] + assert [p.read() for p in tmpdir.listdir()] == ["it worked"] + def test_skips_crl_if_it_has_not_been_modified(tmpdir, monkeypatch): - monkeypatch.setattr('requests.get', lambda u, **kwargs: MockStreamingResponse([b'it worked'], 304)) - assert not util.write_crl(tmpdir, "random_target_dir", 'crl_file_name') + monkeypatch.setattr( + "requests.get", lambda u, **kwargs: MockStreamingResponse([b"it worked"], 304) + ) + assert not util.write_crl(tmpdir, "random_target_dir", "crl_file_name")