atst/tests/test_auth.py
dandds 0b5acde4c4 Stream-parse CRLs for caching file locations.
AT-AT needs to maintain a key-value CRL cache where each key is the DER
byte-string of the issuer and the value is a dictionary of the CRL file
path and expiration. This way when it checks a client certificate, it
can load the correct CRL by comparing the issuers. This is preferable to
loading all of the CRLs in-memory. However, it still requires that AT-AT
load and parse each CRL when the application boots. Because the CRLs
are large on disk and even larger once parsed into memory, this causes
the application's memory usage to spike to nearly 900MB (resting usage
is around 50MB).

This change introduces a small function that parses the CRL ad hoc to
obtain only the information we need from it: the issuer and the
expiration. It does this by reading the CRL byte-by-byte until it
reaches the ASN.1 sequence that corresponds to the issuer, and then looks
ahead to find the nextUpdate field (i.e., the expiration date). The
CRLCache class uses this function to build its cache and JSON-serializes
the cache to disk. If another AT-AT application process finds the
serialized version, it will load that copy instead of rebuilding it. It
also changes the signature of CRLCache's `__init__` method: it now
expects the CRL directory as its second argument instead of a list of
locations.

The Python script invoked by `script/sync-crls` will rebuild the
location cache each time it's run. This means that when the Kubernetes
CronJob for CRLs runs, it will refresh the cache each time. When a new
application container boots, it will get the refreshed cache.

This also adds a nightly CircleCI job to sync the CRLs and test that the
ad-hoc parsing function returns the same result as a proper parsing
using the Python cryptography library. This provides extra insurance
that the function is returning correct results on real data.
2019-11-04 08:36:03 -05:00

300 lines
9.3 KiB
Python

import os
from urllib.parse import urlparse
import pytest
from datetime import datetime
from flask import session, url_for
from cryptography.hazmat.primitives.serialization import Encoding
from .mocks import DOD_SDN_INFO, DOD_SDN, FIXTURE_EMAIL_ADDRESS
from atst.domain.users import Users
from atst.domain.permission_sets import PermissionSets
from atst.domain.exceptions import NotFoundError
from atst.domain.authnid.crl import CRLInvalidException
from atst.domain.auth import UNPROTECTED_ROUTES
from atst.domain.authnid.crl import CRLCache
from .factories import UserFactory
# Canned user-info payload returned by the _fetch_user_info stub.
MOCK_USER = dict(id="438567dd-25fa-4d83-a8cc-8aa8366cb24a")
def _fetch_user_info(c, t):
    """Return the canned ``MOCK_USER`` payload.

    Both positional arguments are ignored. Presumably they exist only to
    mirror the signature of a real user-info fetcher (TODO confirm: no
    caller is visible in this file).
    """
    user_info = MOCK_USER
    return user_info
def _login(client, verify="SUCCESS", sdn=DOD_SDN, cert="", **url_query_args):
    """GET the login-redirect route with client-certificate headers set.

    ``verify``, ``sdn`` and ``cert`` are injected into the WSGI environ as the
    ``X-SSL-Client-Verify``, ``X-SSL-Client-S-DN`` and ``X-SSL-Client-Cert``
    request headers. Any extra keyword arguments are forwarded to ``url_for``
    when building the login-redirect URL (e.g. ``next=...``).

    Returns the test client's response.
    """
    login_url = url_for("atst.login_redirect", **url_query_args)
    cert_headers = {
        "HTTP_X_SSL_CLIENT_VERIFY": verify,
        "HTTP_X_SSL_CLIENT_S_DN": sdn,
        "HTTP_X_SSL_CLIENT_CERT": cert,
    }
    return client.get(login_url, environ_base=cert_headers)
def test_successful_login_redirect_non_ccpo(client, monkeypatch):
    """With authentication stubbed to succeed, login redirects to the home page
    and stores the user id in the session."""
    auth_context = "atst.domain.authnid.AuthenticationContext"
    monkeypatch.setattr(auth_context + ".authenticate", lambda *args: True)
    monkeypatch.setattr(
        auth_context + ".get_user", lambda *args: UserFactory.create()
    )

    response = _login(client)

    assert response.status_code == 302
    assert "home" in response.headers["Location"]
    assert session["user_id"]
def test_successful_login_redirect_ccpo(client, monkeypatch):
    """A user holding the VIEW_AUDIT_LOG permission set is also redirected to
    the home page on login, with the user id stored in the session."""
    monkeypatch.setattr(
        "atst.domain.authnid.AuthenticationContext.authenticate", lambda *args: True
    )
    role = PermissionSets.get(PermissionSets.VIEW_AUDIT_LOG)
    # Attach the permission set to the user; without this the test is
    # identical to the non-CCPO case and `role` is dead code.
    monkeypatch.setattr(
        "atst.domain.authnid.AuthenticationContext.get_user",
        lambda *args: UserFactory.create(permission_sets=[role]),
    )

    resp = _login(client)

    assert resp.status_code == 302
    assert "home" in resp.headers["Location"]
    assert session["user_id"]
def test_unsuccessful_login_redirect(client, monkeypatch):
    """A plain GET of login-redirect, sent without any client-certificate
    headers, is rejected with 401 and leaves no user in the session."""
    login_url = url_for("atst.login_redirect")
    response = client.get(login_url)

    assert response.status_code == 401
    assert "user_id" not in session
# checks that all of the routes in the app are protected by auth
def is_unprotected(rule):
    """Return True when ``rule``'s endpoint is listed in UNPROTECTED_ROUTES."""
    endpoint = rule.endpoint
    return endpoint in UNPROTECTED_ROUTES
def protected_routes(app):
    """Yield ``(rule, route)`` for every URL rule of ``app`` that needs auth.

    Each rule is built into a concrete path by filling every URL argument
    with the placeholder value ``1``. Rules whose endpoint is unprotected,
    and any built path containing ``"/static"``, are skipped.
    """
    for rule in app.url_map.iter_rules():
        placeholder_args = {name: 1 for name in rule.arguments}
        _domain, route = rule.build(placeholder_args)
        if not (is_unprotected(rule) or "/static" in route):
            yield rule, route
def test_protected_routes_redirect_to_login(client, app):
    """Anonymous GET/POST requests to every protected route get a 302 whose
    Location points at the app's server name (``localhost`` by default)."""
    server_name = app.config.get("SERVER_NAME") or "localhost"
    for rule, protected_route in protected_routes(app):
        # Exercise GET before POST, only for the methods the rule accepts.
        for method, send in (("GET", client.get), ("POST", client.post)):
            if method not in rule.methods:
                continue
            resp = send(protected_route)
            assert resp.status_code == 302
            assert server_name in resp.headers["Location"]
def test_get_protected_route_encodes_redirect(client):
    """An anonymous GET of a protected page redirects to the root URL with the
    requested page carried in the ``next`` query parameter."""
    target = url_for("portfolios.portfolios")
    location = client.get(target).headers["Location"]

    expected_redirect = url_for("atst.root", next=target)
    assert expected_redirect in location
def test_unprotected_routes_set_user_if_logged_in(client, app, user_session):
    """Unprotected pages show the user's name only once a session exists.

    The help-docs page is fetched anonymously (name absent), then again after
    a session is established for the user via the ``user_session`` fixture
    (name present).
    """
    user = UserFactory.create()

    resp = client.get(url_for("atst.helpdocs"))
    assert resp.status_code == 200
    assert user.full_name not in resp.data.decode()

    user_session(user)
    resp = client.get(url_for("atst.helpdocs"))
    assert resp.status_code == 200
    assert user.full_name in resp.data.decode()
@pytest.fixture
def swap_crl_cache(
    app, ca_key, ca_file, crl_file, make_crl, serialize_pki_object_to_disk
):
    """Yield a function that replaces ``app.crl_cache`` for the current test.

    Calling the yielded function with a truthy cache installs that cache.
    Calling it with no (or a falsy) argument writes ``make_crl(ca_key)`` to
    ``crl_file`` as DER and installs a ``CRLCache`` built from ``ca_file`` and
    the directory containing ``crl_file``. The app's original cache is put
    back after the test.
    """
    saved_cache = app.crl_cache

    def _swap_crl_cache(new_cache=None):
        if not new_cache:
            crl = make_crl(ca_key)
            serialize_pki_object_to_disk(crl, crl_file, encoding=Encoding.DER)
            new_cache = CRLCache(ca_file, os.path.dirname(crl_file))
        app.crl_cache = new_cache

    yield _swap_crl_cache
    app.crl_cache = saved_cache
def test_crl_validation_on_login(
    app,
    client,
    ca_key,
    ca_file,
    crl_file,
    rsa_key,
    make_x509,
    make_crl,
    serialize_pki_object_to_disk,
    swap_crl_cache,
):
    """A certificate whose serial is on the CRL is refused (401); a certificate
    signed by the same CA but not on the CRL logs in."""
    luke_cert = make_x509(rsa_key(), signer_key=ca_key, cn="luke")
    darth_cert = make_x509(rsa_key(), signer_key=ca_key, cn="darth")

    # Put only darth's serial on the CRL and install a cache backed by it.
    listing_crl = make_crl(ca_key, expired_serials=[darth_cert.serial_number])
    serialize_pki_object_to_disk(listing_crl, crl_file, encoding=Encoding.DER)
    swap_crl_cache(CRLCache(ca_file, os.path.dirname(crl_file)))

    def as_pem(cert):
        return cert.public_bytes(Encoding.PEM).decode()

    # Listed certificate: rejected, nobody logged in.
    resp = _login(client, cert=as_pem(darth_cert))
    assert resp.status_code == 401
    assert "user_id" not in session

    # Unlisted certificate: login succeeds.
    _login(client, cert=as_pem(luke_cert))
    assert session["user_id"]
def test_creates_new_user_on_login(monkeypatch, client, ca_key):
    """Logging in with the fixture certificate for an unknown DoD id creates a
    user whose name matches DOD_SDN_INFO and whose email is the fixture email."""
    monkeypatch.setattr(
        "atst.domain.authnid.AuthenticationContext.authenticate", lambda *args: True
    )
    # Use a context manager so the fixture file handle is closed promptly
    # instead of leaking until garbage collection (ResourceWarning).
    with open("tests/fixtures/{}.crt".format(FIXTURE_EMAIL_ADDRESS)) as cert_fh:
        cert_file = cert_fh.read()

    # ensure user does not exist
    with pytest.raises(NotFoundError):
        Users.get_by_dod_id(DOD_SDN_INFO["dod_id"])

    _login(client, cert=cert_file)

    user = Users.get_by_dod_id(DOD_SDN_INFO["dod_id"])
    assert user.first_name == DOD_SDN_INFO["first_name"]
    assert user.last_name == DOD_SDN_INFO["last_name"]
    assert user.email == FIXTURE_EMAIL_ADDRESS
def test_creates_new_user_without_email_on_login(
    client, ca_key, rsa_key, make_x509, swap_crl_cache
):
    """A certificate generated with only ``cn=DOD_SDN`` creates a user whose
    name matches DOD_SDN_INFO and whose email is None."""
    cert = make_x509(rsa_key(), signer_key=ca_key, cn=DOD_SDN)
    # Install a fresh default CRL cache so the generated cert can be checked.
    swap_crl_cache()

    # ensure user does not exist
    with pytest.raises(NotFoundError):
        Users.get_by_dod_id(DOD_SDN_INFO["dod_id"])

    _login(client, cert=cert.public_bytes(Encoding.PEM).decode())

    user = Users.get_by_dod_id(DOD_SDN_INFO["dod_id"])
    assert user.first_name == DOD_SDN_INFO["first_name"]
    assert user.last_name == DOD_SDN_INFO["last_name"]
    # Compare to the None singleton by identity, not equality (PEP 8 / E711).
    assert user.email is None
def test_logout(app, client, monkeypatch):
    """After logout, the user page no longer loads (200) but redirects (302)
    to the root URL."""
    stubs = (
        ("authenticate", lambda s: True),
        ("get_user", lambda s: UserFactory.create()),
    )
    for attr, stub in stubs:
        monkeypatch.setattr(
            "atst.domain.authnid.AuthenticationContext." + attr, stub
        )

    # Establish a real session through the login flow and confirm it works.
    _login(client)
    assert client.get(url_for("users.user")).status_code == 200

    client.get(url_for("atst.logout"))

    # With the session cleared, the user page bounces back to the root.
    after_logout = client.get(url_for("users.user"))
    assert after_logout.status_code == 302
    redirect_path = urlparse(after_logout.headers["Location"]).path
    assert redirect_path == url_for("atst.root")
def test_logging_out_creates_a_flash_message(app, client, monkeypatch):
    """The page reached by following the logout redirect says "Logged out"."""
    auth_context = "atst.domain.authnid.AuthenticationContext"
    monkeypatch.setattr(auth_context + ".authenticate", lambda s: True)
    monkeypatch.setattr(auth_context + ".get_user", lambda s: UserFactory.create())

    _login(client)

    landing = client.get(url_for("atst.logout"), follow_redirects=True)
    assert "Logged out" in landing.data.decode()
def test_redirected_on_login(client, monkeypatch):
    """A ``next`` query parameter passed to login-redirect ends up in the
    redirect's Location header."""
    auth_context = "atst.domain.authnid.AuthenticationContext"
    monkeypatch.setattr(auth_context + ".authenticate", lambda *args: True)
    monkeypatch.setattr(
        auth_context + ".get_user", lambda *args: UserFactory.create()
    )

    destination = url_for("users.user")
    location = _login(client, next=destination).headers.get("Location")

    assert destination in location
def test_error_on_invalid_crl(client, monkeypatch):
    """If authentication raises CRLInvalidException, login returns 401 and the
    page shows "Error Code 008"."""

    def _fail_with_invalid_crl(*args):
        raise CRLInvalidException()

    monkeypatch.setattr(
        "atst.domain.authnid.AuthenticationContext.authenticate",
        _fail_with_invalid_crl,
    )

    response = _login(client)
    body = response.data.decode()

    assert response.status_code == 401
    assert "Error Code 008" in body
def test_last_login_set_when_user_logs_in(client, monkeypatch):
    """After login, the session holds a truthy datetime ``last_login`` that is
    earlier than the user's ``last_login``."""
    user = UserFactory.create(last_login=datetime.now())
    auth_context = "atst.domain.authnid.AuthenticationContext"
    monkeypatch.setattr(auth_context + ".authenticate", lambda *args: True)
    monkeypatch.setattr(auth_context + ".get_user", lambda *args: user)

    _login(client)

    session_login = session["last_login"]
    assert session_login
    assert user.last_login > session_login
    assert isinstance(session_login, datetime)