diff --git a/.circleci/config.yml b/.circleci/config.yml index 9d57630f..5f304ed7 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -13,6 +13,7 @@ defaults: PGDATABASE: circle_test REDIS_URI: redis://localhost:6379 PIP_VERSION: 18.* + CRL_STORAGE_PROVIDER: CLOUDFILES dockerCmdEnvironment: &dockerCmdEnvironment APP_USER: atst APP_GROUP: atat @@ -83,18 +84,9 @@ jobs: paths: - ./node_modules key: node-v1-{{ .Branch }}-{{ checksum "yarn.lock" }} - - restore_cache: - name: "Load Cache: CRLs" - keys: - - disa-crls-v2 - run: name: "Update CRLs" command: ./script/sync-crls - - save_cache: - name: "Save Cache: CRLs" - paths: - - ./crl - key: disa-crls-v2-{{ .Branch }}-{{ epoch}} - run: name: "Generate build info" command: ./script/generate_build_info.sh diff --git a/.gitignore b/.gitignore index c7205185..d4d60946 100644 --- a/.gitignore +++ b/.gitignore @@ -36,6 +36,7 @@ config/dev.ini # CRLs /crl +/crls /crl-tmp *.bk diff --git a/atst/app.py b/atst/app.py index 19e09ca2..d3cb357a 100644 --- a/atst/app.py +++ b/atst/app.py @@ -50,10 +50,10 @@ def make_app(config): app.config.update({"SESSION_REDIS": app.redis}) make_flask_callbacks(app) - make_crl_validator(app) register_filters(app) make_eda_client(app) make_csp_provider(app) + make_crl_validator(app) make_mailer(app) queue.init_app(app) @@ -200,10 +200,13 @@ def make_crl_validator(app): app.crl_cache = NoOpCRLCache(logger=app.logger) else: crl_locations = [] - for filename in pathlib.Path(app.config["CRL_DIRECTORY"]).glob("*.crl"): + for filename in pathlib.Path(app.config["CRL_CONTAINER"]).glob("*.crl"): crl_locations.append(filename.absolute()) app.crl_cache = CRLCache( - app.config["CA_CHAIN"], crl_locations, logger=app.logger + app.config["CA_CHAIN"], + crl_locations, + logger=app.logger, + crl_update_func=app.csp.crls.sync_crls, ) diff --git a/atst/domain/authnid/crl/__init__.py b/atst/domain/authnid/crl/__init__.py index fb81d54d..937a2db1 100644 --- a/atst/domain/authnid/crl/__init__.py +++ b/atst/domain/authnid/crl/__init__.py @@ -2,6 +2,7 @@ import sys import os import re import hashlib +from datetime import datetime from OpenSSL import crypto, SSL @@ -56,13 +57,20 @@ class CRLCache(CRLInterface): ) def __init__( - self, root_location, crl_locations=[], store_class=crypto.X509Store, logger=None + self, + root_location, + crl_locations=[], + store_class=crypto.X509Store, + logger=None, + crl_update_func=None, ): + self._crl_locations = crl_locations self.logger = logger self.store_class = store_class self.certificate_authorities = {} + self._crl_update_func = crl_update_func self._load_roots(root_location) - self._build_crl_cache(crl_locations) + self._build_crl_cache() def _get_store(self, cert): return self._build_store(cert.get_issuer()) @@ -76,12 +84,17 @@ class CRLCache(CRLInterface): def _parse_roots(self, root_str): return [match.group(0) for match in self._PEM_RE.finditer(root_str)] - def _build_crl_cache(self, crl_locations): + def _build_crl_cache(self): self.crl_cache = {} - for crl_location in crl_locations: + for crl_location in self._crl_locations: crl = self._load_crl(crl_location) if crl: - self.crl_cache[crl.get_issuer().der()] = crl_location + issuer_der = crl.get_issuer().der() + expires = crl.to_cryptography().next_update + self.crl_cache[issuer_der] = { + "location": crl_location, + "expires": expires, + } def _load_crl(self, crl_location): with open(crl_location, "rb") as crl_file: @@ -94,26 +107,41 @@ class CRLCache(CRLInterface): store = self.store_class() self._log_info("STORE ID: {}. Building store.".format(id(store))) store.set_flags(crypto.X509StoreFlags.CRL_CHECK) - crl_location = self.crl_cache.get(issuer.der()) + crl_info = self.crl_cache.get(issuer.der(), {}) issuer_name = get_common_name(issuer) - if not crl_location: + if not crl_info: raise CRLRevocationException( "Could not find matching CRL for issuer with Common Name {}".format( issuer_name ) ) - crl = self._load_crl(crl_location) + self._manage_expiring(crl_info["expires"]) + + crl = self._load_crl(crl_info["location"]) store.add_crl(crl) + self._log_info( "STORE ID: {}. Adding CRL with issuer Common Name {}".format( id(store), issuer_name ) ) + store = self._add_certificate_chain_to_store(store, crl.get_issuer()) return store + def _manage_expiring(self, crl_expiry): + """ + If a CRL is expired and CRLCache has been given a function for updating + the physical CRL locations, run the update function and then rebuild + the CRL cache. + """ + current = datetime.now(crl_expiry.tzinfo) + if self._crl_update_func and current > crl_expiry: + self._crl_update_func() + self._build_crl_cache() + # this _should_ happen just twice for the DoD PKI (intermediary, root) but # theoretically it can build a longer certificate chain diff --git a/atst/domain/authnid/crl/util.py b/atst/domain/authnid/crl/util.py deleted file mode 100644 index f5a702fd..00000000 --- a/atst/domain/authnid/crl/util.py +++ /dev/null @@ -1,113 +0,0 @@ -import requests -import re -import os -import pendulum -from html.parser import HTMLParser - -_DISA_CRLS = "https://iasecontent.disa.mil/pki-pke/data/crls/dod_crldps.htm" - -MODIFIED_TIME_BUFFER = 15 * 60 - - -def fetch_disa(): - response = requests.get(_DISA_CRLS) - return response.text - - -class DISAParser(HTMLParser): - crl_list = [] - _CRL_MATCH = re.compile("DOD(ROOT|EMAIL|ID)?CA") - - def handle_starttag(self, tag, attrs): - if tag == "a": - href = [pair[1] for pair in attrs if pair[0] == "href"].pop() - if re.search(self._CRL_MATCH, href): - self.crl_list.append(href) - - -def crl_list_from_disa_html(html): - parser = DISAParser() - parser.reset() - parser.feed(html) - return parser.crl_list - - -def crl_local_path(out_dir, crl_location): - name = re.split("/", crl_location)[-1] - crl = os.path.join(out_dir, name) - return crl - - -def existing_crl_modification_time(crl): - if os.path.exists(crl): - prev_time = os.path.getmtime(crl) - buffered = prev_time + MODIFIED_TIME_BUFFER - mod_time = prev_time if pendulum.now().timestamp() < buffered else buffered - dt = pendulum.from_timestamp(mod_time, tz="GMT") - return dt.format("ddd, DD MMM YYYY HH:mm:ss zz") - - else: - return False - - -def write_crl(out_dir, target_dir, crl_location): - crl = crl_local_path(out_dir, crl_location) - existing = crl_local_path(target_dir, crl_location) - options = {"stream": True} - mod_time = existing_crl_modification_time(existing) - if mod_time: - options["headers"] = {"If-Modified-Since": mod_time} - - with requests.get(crl_location, **options) as response: - if response.status_code == 304: - return False - - with open(crl, "wb") as crl_file: - for chunk in response.iter_content(chunk_size=1024): - if chunk: - crl_file.write(chunk) - - return True - - -def remove_bad_crl(out_dir, crl_location): - crl = crl_local_path(out_dir, crl_location) - if os.path.isfile(crl): - os.remove(crl) - - -def refresh_crls(out_dir, target_dir, logger): - disa_html = fetch_disa() - crl_list = crl_list_from_disa_html(disa_html) - for crl_location in crl_list: - logger.info("updating CRL from {}".format(crl_location)) - try: - if write_crl(out_dir, target_dir, crl_location): - logger.info("successfully synced CRL from {}".format(crl_location)) - else: - logger.info("no updates for CRL from {}".format(crl_location)) - except requests.exceptions.RequestException: - if logger: - logger.error( - "Error downloading {}, removing file and continuing anyway".format( - crl_location - ) - ) - remove_bad_crl(out_dir, crl_location) - - -if __name__ == "__main__": - import sys - import logging - - logging.basicConfig( - level=logging.INFO, format="[%(asctime)s]:%(levelname)s: %(message)s" - ) - logger = logging.getLogger() - logger.info("Updating CRLs") - try: - refresh_crls(sys.argv[1], sys.argv[2], logger) - except Exception as err: - logger.exception("Fatal error encountered, stopping") - sys.exit(1) - logger.info("Finished updating CRLs") diff --git a/atst/domain/csp/__init__.py b/atst/domain/csp/__init__.py index a40d200c..6e5c4bbb 100644 --- a/atst/domain/csp/__init__.py +++ b/atst/domain/csp/__init__.py @@ -1,5 +1,5 @@ from .cloud import MockCloudProvider -from .files import RackspaceFileProvider +from .files import RackspaceFileProvider, RackspaceCRLProvider from .reports import MockReportingProvider @@ -8,6 +8,7 @@ class MockCSP: self.cloud = MockCloudProvider() self.files = RackspaceFileProvider(app) self.reports = MockReportingProvider() + self.crls = RackspaceCRLProvider(app) def make_csp_provider(app): diff --git a/atst/domain/csp/files.py b/atst/domain/csp/files.py index a74403e5..3e14068c 100644 --- a/atst/domain/csp/files.py +++ b/atst/domain/csp/files.py @@ -1,4 +1,6 @@ -from tempfile import NamedTemporaryFile +import os +import tarfile +from tempfile import NamedTemporaryFile, TemporaryDirectory from uuid import uuid4 from libcloud.storage.types import Provider @@ -34,23 +36,26 @@ class FileProviderInterface: raise NotImplementedError() +def get_rackspace_container(provider, container=None, **kwargs): + if provider == "LOCAL": # pragma: no branch + kwargs["key"] = container + if not os.path.exists(container): + os.mkdir(container) + container = "" + + driver = get_driver(getattr(Provider, provider))(**kwargs) + return driver.get_container(container) + + class RackspaceFileProvider(FileProviderInterface): def __init__(self, app): - self.container = self._get_container( + self.container = get_rackspace_container( provider=app.config.get("STORAGE_PROVIDER"), container=app.config.get("STORAGE_CONTAINER"), key=app.config.get("STORAGE_KEY"), secret=app.config.get("STORAGE_SECRET"), ) - def _get_container(self, provider, container=None, key=None, secret=None): - if provider == "LOCAL": # pragma: no branch - key = container - container = "" - - driver = get_driver(getattr(Provider, provider))(key=key, secret=secret) - return driver.get_container(container) - def upload(self, fyle): self._enforce_mimetype(fyle) @@ -70,3 +75,38 @@ class RackspaceFileProvider(FileProviderInterface): with NamedTemporaryFile() as tempfile: obj.download(tempfile.name, overwrite_existing=True) return open(tempfile.name, "rb") + + +class CRLProviderInterface: + def sync_crls(self): # pragma: no cover + """ + Retrieve copies of the CRLs and unpack them to disk. + """ + raise NotImplementedError() + + +class RackspaceCRLProvider(CRLProviderInterface): + def __init__(self, app): + provider = app.config.get("CRL_STORAGE_PROVIDER") or app.config.get( + "STORAGE_PROVIDER" + ) + self.container = get_rackspace_container( + provider=provider, + container=app.config.get("CRL_STORAGE_CONTAINER"), + key=app.config.get("STORAGE_KEY"), + secret=app.config.get("STORAGE_SECRET"), + region=app.config.get("CRL_STORAGE_REGION"), + ) + self._crl_dir = app.config.get("CRL_STORAGE_CONTAINER") + self._object_name = app.config.get("STORAGE_CRL_ARCHIVE_NAME") + + def sync_crls(self): + if not os.path.exists(self._crl_dir): + os.mkdir(self._crl_dir) + + obj = self.container.get_object(object_name=self._object_name) + with TemporaryDirectory() as tempdir: + dl_path = os.path.join(tempdir, self._object_name) + obj.download(dl_path, overwrite_existing=True) + archive = tarfile.open(dl_path, "r:bz2") + archive.extractall(self._crl_dir) diff --git a/config/base.ini b/config/base.ini index b8044bff..0629ec29 100644 --- a/config/base.ini +++ b/config/base.ini @@ -3,7 +3,9 @@ CAC_URL = http://localhost:8000/login-redirect CA_CHAIN = ssl/server-certs/ca-chain.pem CLASSIFIED = false COOKIE_SECRET = some-secret-please-replace -CRL_DIRECTORY = crl +CRL_STORAGE_CONTAINER = crls +CRL_STORAGE_PROVIDER = LOCAL +CRL_STORAGE_REGION = iad DISABLE_CRL_CHECK = false DEBUG = true ENVIRONMENT = dev @@ -24,5 +26,8 @@ SESSION_COOKIE_NAME=atat SESSION_TYPE = redis SESSION_USE_SIGNER = True STORAGE_CONTAINER=uploads +STORAGE_KEY='' +STORAGE_SECRET='' STORAGE_PROVIDER=LOCAL +STORAGE_CRL_ARCHIVE_NAME = dod_crls.tar.bz WTF_CSRF_ENABLED = true diff --git a/config/ci.ini b/config/ci.ini index 7f89af01..f472488d 100644 --- a/config/ci.ini +++ b/config/ci.ini @@ -3,5 +3,5 @@ DEBUG = false PGHOST = postgreshost PGDATABASE = atat_test REDIS_URI = redis://redishost:6379 -CRL_DIRECTORY = tests/fixtures/crl +CRL_CONTAINER = tests/fixtures/crl WTF_CSRF_ENABLED = false diff --git a/config/selenium.ini b/config/selenium.ini index 054b8ed3..47f730ef 100644 --- a/config/selenium.ini +++ b/config/selenium.ini @@ -1,3 +1,3 @@ [default] PGDATABASE = atat_selenium -CRL_DIRECTORY = tests/fixtures/crl +CRL_CONTAINER = tests/fixtures/crl diff --git a/config/test.ini b/config/test.ini index b5ac5230..3a3b0b1b 100644 --- a/config/test.ini +++ b/config/test.ini @@ -2,6 +2,6 @@ DEBUG = false ENVIRONMENT = test PGDATABASE = atat_test -CRL_DIRECTORY = tests/fixtures/crl +CRL_CONTAINER = tests/fixtures/crl WTF_CSRF_ENABLED = false STORAGE_PROVIDER=LOCAL diff --git a/script/sync-crls b/script/sync-crls index 3b4eb027..82b6631b 100755 --- a/script/sync-crls +++ b/script/sync-crls @@ -1,22 +1,14 @@ -#!/bin/bash +#! .venv/bin/python +# Add root application dir to the python path +import os +import sys -# script/sync-crls: update the DOD CRLs and place them where authnid expects them -set -e -cd "$(dirname "$0")/.." +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.append(parent_dir) -if [[ $# -eq 0 ]]; then - TMP_DIR=crl-tmp -else - TMP_DIR=$1 -fi +from atst.app import make_config, make_app -mkdir -p $TMP_DIR -pipenv run python ./atst/domain/authnid/crl/util.py $TMP_DIR crl -mkdir -p crl -rsync -rq --min-size 400 $TMP_DIR/. crl/. -rm -rf $TMP_DIR - -if [[ $FLASK_ENV != "prod" ]]; then - # place our test CRL there - cp ssl/client-certs/client-ca.der.crl crl/ -fi +if __name__ == "__main__": + config = make_config({"DISABLE_CRL_CHECK": True}) + app = make_app(config) + app.csp.crls.sync_crls() diff --git a/tests/domain/authnid/test_crl.py b/tests/domain/authnid/test_crl.py index c03d353b..c385c4bb 100644 --- a/tests/domain/authnid/test_crl.py +++ b/tests/domain/authnid/test_crl.py @@ -3,14 +3,118 @@ import pytest import re import os import shutil +from datetime import datetime, timezone, timedelta from OpenSSL import crypto, SSL +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.serialization import Encoding +from cryptography.x509.oid import NameOID from atst.domain.authnid.crl import CRLCache, CRLRevocationException, NoOpCRLCache -import atst.domain.authnid.crl.util as util from tests.mocks import FIXTURE_EMAIL_ADDRESS, DOD_CN +def rsa_key(): + return rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend() + ) + + +@pytest.fixture +def ca_key(): + return rsa_key() + + +@pytest.fixture +def make_x509(): + def _make_x509(private_key, signer_key=None, cn="ATAT", signer_cn="ATAT"): + if signer_key is None: + signer_key = private_key + + one_day = timedelta(1, 0, 0) + public_key = private_key.public_key() + builder = x509.CertificateBuilder() + builder = builder.subject_name( + x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, cn)]) + ) + builder = builder.issuer_name( + x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, signer_cn)]) + ) + if signer_key == private_key: + builder = builder.add_extension( + x509.BasicConstraints(ca=True, path_length=None), critical=True + ) + builder = builder.not_valid_before(datetime.today() - (one_day * 2)) + builder = builder.not_valid_after(datetime.today() + (one_day * 30)) + builder = builder.serial_number(x509.random_serial_number()) + builder = builder.public_key(public_key) + certificate = builder.sign( + private_key=signer_key, algorithm=hashes.SHA256(), backend=default_backend() + ) + + return certificate + + return _make_x509 + + +@pytest.fixture +def make_crl(): + def _make_crl(private_key, last_update_days=-1, next_update_days=30, cn="ATAT"): + one_day = timedelta(1, 0, 0) + builder = x509.CertificateRevocationListBuilder() + builder = builder.issuer_name( + x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, cn)]) + ) + builder = builder.last_update(datetime.today() + (one_day * last_update_days)) + builder = builder.next_update(datetime.today() + (one_day * next_update_days)) + crl = builder.sign( + private_key=private_key, + algorithm=hashes.SHA256(), + backend=default_backend(), + ) + + return crl + + return _make_crl + + +def serialize_pki_object_to_disk(obj, name, encoding=Encoding.PEM): + with open(name, "wb") as file_: + file_.write(obj.public_bytes(encoding)) + + return name + + +@pytest.fixture +def ca_file(make_x509, ca_key, tmpdir): + ca = make_x509(ca_key) + ca_out = tmpdir.join("atat-ca.crt") + serialize_pki_object_to_disk(ca, ca_out) + + return ca_out + + +@pytest.fixture +def expired_crl_file(make_crl, ca_key, tmpdir): + crl = make_crl(ca_key, last_update_days=-7, next_update_days=-1) + crl_out = tmpdir.join("atat-expired.crl") + serialize_pki_object_to_disk(crl, crl_out, encoding=Encoding.DER) + + return crl_out + + +@pytest.fixture +def crl_file(make_crl, ca_key, tmpdir): + crl = make_crl(ca_key) + crl_out = tmpdir.join("atat-valid.crl") + serialize_pki_object_to_disk(crl, crl_out, encoding=Encoding.DER) + + return crl_out + + class MockX509Store: def __init__(self): self.crls = [] @@ -26,14 +130,17 @@ class MockX509Store: pass -def test_can_build_crl_list(monkeypatch): - location = "ssl/client-certs/client-ca.der.crl" - cache = CRLCache( - "ssl/client-certs/client-ca.crt", - crl_locations=[location], - store_class=MockX509Store, - ) +def test_can_build_crl_list(ca_file, ca_key, make_crl, tmpdir): + crl = make_crl(ca_key) + crl_file = tmpdir.join("atat.crl") + serialize_pki_object_to_disk(crl, crl_file, encoding=Encoding.DER) + + cache = CRLCache(ca_file, crl_locations=[crl_file], store_class=MockX509Store) + issuer_der = crl.issuer.public_bytes(default_backend()) assert len(cache.crl_cache.keys()) == 1 + assert issuer_der in cache.crl_cache + assert cache.crl_cache[issuer_der]["location"] == crl_file + assert cache.crl_cache[issuer_der]["expires"] == crl.next_update def test_can_build_trusted_root_list(): @@ -104,47 +211,6 @@ def test_multistep_certificate_chain(): assert cache.crl_check(cert) -def test_parse_disa_pki_list(): - with open("tests/fixtures/disa-pki.html") as disa: - disa_html = disa.read() - crl_list = util.crl_list_from_disa_html(disa_html) - href_matches = re.findall("DOD(ROOT|EMAIL|ID)?CA", disa_html) - assert len(crl_list) > 0 - assert len(crl_list) == len(href_matches) - - -class MockStreamingResponse: - def __init__(self, content_chunks, code=200): - self.content_chunks = content_chunks - self.status_code = code - - def iter_content(self, chunk_size=0): - return self.content_chunks - - def __enter__(self): - return self - - def __exit__(self, *args): - pass - - -def test_write_crl(tmpdir, monkeypatch): - monkeypatch.setattr( - "requests.get", lambda u, **kwargs: MockStreamingResponse([b"it worked"]) - ) - crl = "crl_1" - assert util.write_crl(tmpdir, "random_target_dir", crl) - assert [p.basename for p in tmpdir.listdir()] == [crl] - assert [p.read() for p in tmpdir.listdir()] == ["it worked"] - - -def test_skips_crl_if_it_has_not_been_modified(tmpdir, monkeypatch): - monkeypatch.setattr( - "requests.get", lambda u, **kwargs: MockStreamingResponse([b"it worked"], 304) - ) - assert not util.write_crl(tmpdir, "random_target_dir", "crl_file_name") - - class FakeLogger: def __init__(self): self.messages = [] @@ -159,29 +225,27 @@ class FakeLogger: self.messages.append(msg) -def test_refresh_crls_with_error(tmpdir, monkeypatch): - def _mock_create_connection(*args, **kwargs): - raise TimeoutError - - fake_crl = "https://fakecrl.com/fake.crl" - - monkeypatch.setattr( - "urllib3.util.connection.create_connection", _mock_create_connection - ) - monkeypatch.setattr("atst.domain.authnid.crl.util.fetch_disa", lambda *args: None) - monkeypatch.setattr( - "atst.domain.authnid.crl.util.crl_list_from_disa_html", lambda *args: [fake_crl] - ) - - logger = FakeLogger() - util.refresh_crls(tmpdir, tmpdir, logger) - - assert "Error downloading {}".format(fake_crl) in logger.messages[-1] - - def test_no_op_crl_cache_logs_common_name(): logger = FakeLogger() cert = open("ssl/client-certs/atat.mil.crt", "rb").read() cache = NoOpCRLCache(logger=logger) assert cache.crl_check(cert) assert "ART.GARFUNKEL.1234567890" in logger.messages[-1] + + +def test_updates_expired_certs(ca_file, expired_crl_file, crl_file, ca_key, make_x509): + """ + Given a CRLCache object with an expired CRL and a function for updating the + CRLs, the CRLCache should run the update function before checking a + certificate that requires the expired CRL. + """ + client_cert = make_x509(rsa_key(), signer_key=ca_key, cn="chewbacca") + client_pem = client_cert.public_bytes(Encoding.PEM) + + def _crl_update_func(): + shutil.copyfile(crl_file, expired_crl_file) + + crl_cache = CRLCache( + ca_file, crl_locations=[expired_crl_file], crl_update_func=_crl_update_func + ) + crl_cache.crl_check(client_pem)