The previous solution (ad-hoc stream-parsing the CRLs to obtain their issuers and nextUpdate) was too cute. It began breaking on CRLs that had an additional hex 0x30 byte somewhere in their header. I thought that 0x30 was a reserved character only to be used for tags in ASN.1 encoded with DER; it turns out that's not true. Rather than write a full-fledged ASN.1 stream-parser, the simplest solution is to just maintain the list of issuers as a constant in the codebase. This is fine because the issuer for a specific CRL URI should not change. If it does, we've probably got bigger problems. This also removes the Flask app's functionality for updating the local CRL cache. This is being handled out-of-band by a Kubernetes CronJob and is not a concern of the app's. This means that instances of the CRLCache do not have to explicitly track expirations for CRLs. Previously, the in-memory dictionary of CRL issuers and locations included expirations; now it is flattened to not include that information. The CRLCache class has been updated to accept a crl_list kwarg so that unit tests can provide their own alternative CRL lists, since we now hard-code the expected CRLs and issuers. The nightly CRL check job has been updated to check that the hard-coded list of issuers matches what we get when we actually sync the CRLs.
304 lines
9.5 KiB
Python
304 lines
9.5 KiB
Python
import os
|
|
from urllib.parse import urlparse
|
|
|
|
import pytest
|
|
from datetime import datetime
|
|
from flask import session, url_for
|
|
from cryptography.hazmat.primitives.serialization import Encoding
|
|
|
|
from atst.domain.users import Users
|
|
from atst.domain.permission_sets import PermissionSets
|
|
from atst.domain.exceptions import NotFoundError
|
|
from atst.domain.authnid.crl import CRLInvalidException
|
|
from atst.domain.auth import UNPROTECTED_ROUTES
|
|
from atst.domain.authnid.crl import CRLCache
|
|
|
|
from .factories import UserFactory
|
|
from .mocks import DOD_SDN_INFO, DOD_SDN, FIXTURE_EMAIL_ADDRESS
|
|
from .utils import make_crl_list
|
|
|
|
|
|
MOCK_USER = {"id": "438567dd-25fa-4d83-a8cc-8aa8366cb24a"}
|
|
|
|
|
|
def _fetch_user_info(c, t):
    """Return the canned ``MOCK_USER`` record; both arguments are ignored."""
    return MOCK_USER
|
|
|
|
|
|
def _login(client, verify="SUCCESS", sdn=DOD_SDN, cert="", **url_query_args):
    """GET the ``atst.login_redirect`` route with SSL client headers set.

    Args:
        client: test client used to issue the request.
        verify: value sent as ``HTTP_X_SSL_CLIENT_VERIFY``.
        sdn: value sent as ``HTTP_X_SSL_CLIENT_S_DN``.
        cert: value sent as ``HTTP_X_SSL_CLIENT_CERT``.
        **url_query_args: forwarded to ``url_for`` when building the URL.

    Returns:
        Whatever ``client.get`` returns for the request.
    """
    login_url = url_for("atst.login_redirect", **url_query_args)
    ssl_environ = {
        "HTTP_X_SSL_CLIENT_VERIFY": verify,
        "HTTP_X_SSL_CLIENT_S_DN": sdn,
        "HTTP_X_SSL_CLIENT_CERT": cert,
    }
    return client.get(login_url, environ_base=ssl_environ)
|
|
|
|
|
|
def test_successful_login_redirect_non_ccpo(client, monkeypatch):
    """With authentication stubbed out, login redirects to home and sets the session."""

    def _authenticated(*args):
        return True

    def _new_user(*args):
        return UserFactory.create()

    monkeypatch.setattr(
        "atst.domain.authnid.AuthenticationContext.authenticate", _authenticated
    )
    monkeypatch.setattr(
        "atst.domain.authnid.AuthenticationContext.get_user", _new_user
    )

    login_response = _login(client)

    assert login_response.status_code == 302
    assert "home" in login_response.headers["Location"]
    assert session["user_id"]
|
|
|
|
|
|
def test_successful_login_redirect_ccpo(client, monkeypatch):
    """With authentication stubbed out, login redirects to home and sets the session."""

    def _authenticated(*args):
        return True

    def _new_user(*args):
        return UserFactory.create()

    monkeypatch.setattr(
        "atst.domain.authnid.AuthenticationContext.authenticate", _authenticated
    )
    # NOTE(review): `role` is looked up but never attached to the created
    # user, so this test currently exercises the same path as the non-CCPO
    # variant — confirm whether the user was meant to receive this role.
    role = PermissionSets.get(PermissionSets.VIEW_AUDIT_LOG)
    monkeypatch.setattr(
        "atst.domain.authnid.AuthenticationContext.get_user", _new_user
    )

    login_response = _login(client)

    assert login_response.status_code == 302
    assert "home" in login_response.headers["Location"]
    assert session["user_id"]
|
|
|
|
|
|
def test_unsuccessful_login_redirect(client, monkeypatch):
    """A login request without any client-cert headers gets a 401 and no session user."""
    login_url = url_for("atst.login_redirect")
    resp = client.get(login_url)

    assert resp.status_code == 401
    assert "user_id" not in session
|
|
|
|
|
|
# checks that all of the routes in the app are protected by auth
|
|
def is_unprotected(rule):
    """Return True when the rule's endpoint is listed in ``UNPROTECTED_ROUTES``."""
    endpoint = rule.endpoint
    return endpoint in UNPROTECTED_ROUTES
|
|
|
|
|
|
def protected_routes(app):
    """Yield ``(rule, route)`` pairs for every app rule that should require auth.

    Each rule's URL is built by substituting ``1`` for every URL argument.
    Rules whose endpoint is unprotected, or whose built route contains
    ``/static``, are skipped.
    """
    for rule in app.url_map.iter_rules():
        # Any placeholder value works here; we only need a concrete URL.
        mock_args = {argument: 1 for argument in rule.arguments}
        _n, route = rule.build(mock_args)
        if not is_unprotected(rule) and "/static" not in route:
            yield rule, route
|
|
|
|
|
|
def test_protected_routes_redirect_to_login(client, app):
    """Unauthenticated GET/POST requests to protected routes are redirected (302).

    The redirect target must mention the configured ``SERVER_NAME``
    (falling back to ``localhost`` when none is configured).
    """
    server_name = app.config.get("SERVER_NAME") or "localhost"
    # GET is checked before POST for each route, matching the HTTP verbs
    # a rule may declare.
    senders = (("GET", client.get), ("POST", client.post))

    for rule, protected_route in protected_routes(app):
        for method, send in senders:
            if method not in rule.methods:
                continue
            resp = send(protected_route)
            assert resp.status_code == 302
            assert server_name in resp.headers["Location"]
|
|
|
|
|
|
def test_get_protected_route_encodes_redirect(client):
    """The login redirect for a protected route carries that route as ``next``."""
    requested_path = url_for("portfolios.portfolios")
    response = client.get(requested_path)

    expected_redirect = url_for("atst.root", next=requested_path)
    assert expected_redirect in response.headers["Location"]
|
|
|
|
|
|
def test_unprotected_routes_set_user_if_logged_in(client, app, user_session):
    """The helpdocs page shows the user's name only once a session exists.

    This test was previously defined twice with identical bodies; the second
    definition silently replaced the first at import time, so a single copy
    is kept.
    """
    user = UserFactory.create()

    # Anonymous request: page renders, but without the user's name.
    resp = client.get(url_for("atst.helpdocs"))
    assert resp.status_code == 200
    assert user.full_name not in resp.data.decode()

    # Same request after establishing a session for the user.
    user_session(user)
    resp = client.get(url_for("atst.helpdocs"))
    assert resp.status_code == 200
    assert user.full_name in resp.data.decode()
|
|
|
|
|
|
@pytest.fixture
def swap_crl_cache(
    app, ca_key, ca_file, crl_file, make_crl, serialize_pki_object_to_disk
):
    """Yield a callable that replaces ``app.crl_cache`` for the current test.

    Calling the yielded function with a cache installs that cache; calling it
    with no (or a falsy) argument builds a ``CRLCache`` from a freshly made
    CRL written to ``crl_file``. The original cache is put back once the
    fixture resumes after the ``yield``.
    """
    original = app.crl_cache

    def _default_cache():
        # Write a fresh DER-encoded CRL to disk and point a new cache at it.
        crl = make_crl(ca_key)
        serialize_pki_object_to_disk(crl, crl_file, encoding=Encoding.DER)
        crl_dir = os.path.dirname(crl_file)
        crl_list = make_crl_list(crl, crl_file)
        return CRLCache(ca_file, crl_dir, crl_list=crl_list)

    def _swap_crl_cache(new_cache=None):
        if not new_cache:
            new_cache = _default_cache()
        app.crl_cache = new_cache

    yield _swap_crl_cache

    app.crl_cache = original
|
|
|
|
|
|
def test_crl_validation_on_login(
    app,
    client,
    ca_key,
    ca_file,
    crl_file,
    rsa_key,
    make_x509,
    make_crl,
    serialize_pki_object_to_disk,
    swap_crl_cache,
):
    """A cert listed on the CRL is rejected at login; an unlisted one is accepted."""

    def _as_pem(cert):
        return cert.public_bytes(Encoding.PEM).decode()

    good_cert = make_x509(rsa_key(), signer_key=ca_key, cn="luke")
    bad_cert = make_x509(rsa_key(), signer_key=ca_key, cn="darth")

    # Build a CRL that lists only the bad cert's serial and write it to disk.
    crl = make_crl(ca_key, expired_serials=[bad_cert.serial_number])
    serialize_pki_object_to_disk(crl, crl_file, encoding=Encoding.DER)

    crl_dir = os.path.dirname(crl_file)
    crl_list = make_crl_list(good_cert, crl_file)
    swap_crl_cache(CRLCache(ca_file, crl_dir, crl_list=crl_list))

    # bad cert is on the test CRL
    rejected = _login(client, cert=_as_pem(bad_cert))
    assert rejected.status_code == 401
    assert "user_id" not in session

    # good cert is not on the test CRL, passes
    resp = _login(client, cert=_as_pem(good_cert))
    assert session["user_id"]
|
|
|
|
|
|
def test_creates_new_user_on_login(monkeypatch, client, ca_key):
    """Logging in with an unknown cert creates a user from the cert's details.

    Authentication is stubbed to succeed; the fixture certificate is read
    from ``tests/fixtures/<FIXTURE_EMAIL_ADDRESS>.crt``.
    """
    monkeypatch.setattr(
        "atst.domain.authnid.AuthenticationContext.authenticate", lambda *args: True
    )
    # Use a context manager so the fixture file handle is closed promptly
    # instead of being left for the garbage collector.
    with open("tests/fixtures/{}.crt".format(FIXTURE_EMAIL_ADDRESS)) as fixture:
        cert_file = fixture.read()

    # ensure user does not exist
    with pytest.raises(NotFoundError):
        Users.get_by_dod_id(DOD_SDN_INFO["dod_id"])

    _login(client, cert=cert_file)

    user = Users.get_by_dod_id(DOD_SDN_INFO["dod_id"])
    assert user.first_name == DOD_SDN_INFO["first_name"]
    assert user.last_name == DOD_SDN_INFO["last_name"]
    assert user.email == FIXTURE_EMAIL_ADDRESS
|
|
|
|
|
|
def test_creates_new_user_without_email_on_login(
    client, ca_key, rsa_key, make_x509, swap_crl_cache
):
    """Logging in with a generated cert (CN = ``DOD_SDN``) creates a user with no email."""
    cert = make_x509(rsa_key(), signer_key=ca_key, cn=DOD_SDN)
    swap_crl_cache()

    # ensure user does not exist
    with pytest.raises(NotFoundError):
        Users.get_by_dod_id(DOD_SDN_INFO["dod_id"])

    _login(client, cert=cert.public_bytes(Encoding.PEM).decode())

    user = Users.get_by_dod_id(DOD_SDN_INFO["dod_id"])
    assert user.first_name == DOD_SDN_INFO["first_name"]
    assert user.last_name == DOD_SDN_INFO["last_name"]
    # Identity check: comparisons against the None singleton use `is`.
    assert user.email is None
|
|
|
|
|
|
def test_logout(app, client, monkeypatch):
    """After logout, a protected page redirects back to the root route."""

    def _authenticated(s):
        return True

    def _new_user(s):
        return UserFactory.create()

    monkeypatch.setattr(
        "atst.domain.authnid.AuthenticationContext.authenticate", _authenticated
    )
    monkeypatch.setattr(
        "atst.domain.authnid.AuthenticationContext.get_user", _new_user
    )

    # create a real session
    _login(client)
    user_page = url_for("users.user")

    # verify session is valid
    resp_success = client.get(user_page)
    assert resp_success.status_code == 200

    client.get(url_for("atst.logout"))

    # verify that logging out has cleared the session
    resp_failure = client.get(url_for("users.user"))
    assert resp_failure.status_code == 302
    destination = urlparse(resp_failure.headers["Location"]).path
    assert destination == url_for("atst.root")
|
|
|
|
|
|
def test_logging_out_creates_a_flash_message(app, client, monkeypatch):
    """Following the logout redirect lands on a page containing "Logged out"."""

    def _authenticated(s):
        return True

    def _new_user(s):
        return UserFactory.create()

    monkeypatch.setattr(
        "atst.domain.authnid.AuthenticationContext.authenticate", _authenticated
    )
    monkeypatch.setattr(
        "atst.domain.authnid.AuthenticationContext.get_user", _new_user
    )

    _login(client)
    logout_url = url_for("atst.logout")
    logout_response = client.get(logout_url, follow_redirects=True)

    page_text = logout_response.data.decode()
    assert "Logged out" in page_text
|
|
|
|
|
|
def test_redirected_on_login(client, monkeypatch):
    """A ``next`` query argument on login shows up in the redirect location."""

    def _authenticated(*args):
        return True

    def _new_user(*args):
        return UserFactory.create()

    monkeypatch.setattr(
        "atst.domain.authnid.AuthenticationContext.authenticate", _authenticated
    )
    monkeypatch.setattr(
        "atst.domain.authnid.AuthenticationContext.get_user", _new_user
    )

    target_route = url_for("users.user")
    response = _login(client, next=target_route)

    location = response.headers.get("Location")
    assert target_route in location
|
|
|
|
|
|
def test_error_on_invalid_crl(client, monkeypatch):
    """A ``CRLInvalidException`` during authentication yields a 401 with "Error Code 008"."""

    def _raise_crl_error(*args):
        raise CRLInvalidException()

    monkeypatch.setattr(
        "atst.domain.authnid.AuthenticationContext.authenticate", _raise_crl_error
    )

    response = _login(client)
    body = response.data.decode()

    assert response.status_code == 401
    assert "Error Code 008" in body
|
|
|
|
|
|
def test_last_login_set_when_user_logs_in(client, monkeypatch):
    """Login stores a datetime ``last_login`` in the session, older than the user's.

    The user is created with ``last_login`` set to the current time before
    the login request is made.
    """
    user = UserFactory.create(last_login=datetime.now())

    def _authenticated(*args):
        return True

    def _existing_user(*args):
        return user

    monkeypatch.setattr(
        "atst.domain.authnid.AuthenticationContext.authenticate", _authenticated
    )
    monkeypatch.setattr(
        "atst.domain.authnid.AuthenticationContext.get_user", _existing_user
    )

    _login(client)

    assert session["last_login"]
    assert user.last_login > session["last_login"]
    assert isinstance(session["last_login"], datetime)
|