atst/tests/conftest.py
2019-10-28 13:55:34 -04:00

344 lines
9.1 KiB
Python

import os
import pytest
import alembic.config
import alembic.command
from logging.config import dictConfig
from werkzeug.datastructures import FileStorage
from collections import OrderedDict
from atst.app import make_app, make_config
from atst.database import db as _db
import tests.factories as factories
from tests.mocks import PDF_FILENAME, PDF_FILENAME2
from tests.utils import FakeLogger, FakeNotificationSender
from datetime import datetime, timedelta
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.serialization import Encoding
from cryptography.x509.oid import NameOID
# Configure stdlib logging at import time: the handler named "wsgi" is bound
# to logging.NullHandler, which discards every record sent to it.
dictConfig(
    {
        "version": 1,
        "handlers": {
            "wsgi": {"class": "logging.NullHandler"},
        },
    }
)
@pytest.fixture(scope="session")
def app(request):
    """Build one application for the whole test session.

    Its app context is pushed before the first test that needs it and popped
    when the session is torn down.
    """
    application = make_app(make_config())
    app_ctx = application.app_context()
    app_ctx.push()
    yield application
    app_ctx.pop()
@pytest.fixture(autouse=True)
def skip_audit_log(request):
    """
    Conditionally skip tests marked with 'audit_log' based on the
    USE_AUDIT_LOG config value.

    Because this fixture is autouse it runs for every test, so the config is
    only built for tests that actually carry the ``audit_log`` marker.
    Unmarked tests return immediately without calling ``make_config()``.
    """
    if request.node.get_closest_marker("audit_log") is None:
        return
    use_audit_log = make_config().get("USE_AUDIT_LOG", False)
    if not use_audit_log:
        pytest.skip("audit log feature flag disabled")
@pytest.fixture(scope="function")
def no_debug_app(request):
    """Yield a fresh application built with DEBUG forced off.

    The app's context stays pushed for the duration of a single test.
    """
    debug_off_config = make_config(direct_config={"DEBUG": False})
    application = make_app(debug_off_config)
    app_ctx = application.app_context()
    app_ctx.push()
    yield application
    app_ctx.pop()
@pytest.fixture(scope="function")
def no_debug_client(no_debug_app):
    """Yield a test client bound to the DEBUG-disabled application."""
    client = no_debug_app.test_client()
    yield client
def apply_migrations():
    """Applies all alembic migrations.

    Loads ``alembic.ini`` from the directory above this file, points its
    ``sqlalchemy.url`` at the application's ``DATABASE_URI``, and upgrades
    to the ``head`` revision.
    """
    ini_path = os.path.join(os.path.dirname(__file__), "../", "alembic.ini")
    alembic_cfg = alembic.config.Config(ini_path)
    database_uri = make_config()["DATABASE_URI"]
    alembic_cfg.set_main_option("sqlalchemy.url", database_uri)
    alembic.command.upgrade(alembic_cfg, "head")
@pytest.fixture(scope="session")
def db(app, request):
    """Session-wide database handle.

    Binds the shared ``_db`` object to the session app and runs the alembic
    migrations before yielding it. After the session ends, it drops every
    table known to ``_db``.
    """
    _db.app = app
    apply_migrations()
    yield _db
    # NOTE(review): drop_all() only drops tables registered on _db's
    # metadata; alembic's version table is presumably not among them. Confirm
    # that a subsequent run's apply_migrations() still recreates the schema.
    _db.drop_all()
@pytest.fixture(scope="function", autouse=True)
def session(db, request):
    """Creates a new database session for a test.

    The session is bound to a dedicated connection with an open outer
    transaction. That transaction is rolled back after the test, so nothing
    the test writes survives it. Every class defined in ``tests.factories``
    is pointed at this session with persistence mode ``"commit"``.
    """
    conn = db.engine.connect()
    outer_txn = conn.begin()
    scoped = db.create_scoped_session(options=dict(bind=conn, binds={}))
    db.session = scoped

    # Collect the classes defined directly in tests.factories (not ones it
    # merely imports) and retarget each at the per-test session.
    factory_classes = [
        candidate
        for candidate in vars(factories).values()
        if isinstance(candidate, type) and candidate.__module__ == "tests.factories"
    ]
    for factory_cls in factory_classes:
        meta = factory_cls._meta
        meta.sqlalchemy_session = scoped
        meta.sqlalchemy_session_persistence = "commit"

    yield scoped

    outer_txn.rollback()
    conn.close()
    scoped.remove()
class DummyForm(dict):
    """Minimal stand-in for a form object in tests.

    Exposes ``_fields`` (the supplied mapping, or a fresh empty OrderedDict)
    and ``errors`` (a list copy of the supplied errors). ``raw_data`` is
    accepted for signature compatibility but ignored.
    """

    def __init__(self, data=None, errors=(), raw_data=None):
        # Build a new OrderedDict per instance. A default of ``OrderedDict()``
        # is evaluated once at definition time and would be shared (and
        # mutated) across every DummyForm created without ``data``.
        self._fields = OrderedDict() if data is None else data
        self.errors = list(errors)
class DummyField:
    """Lightweight placeholder for a form field in tests.

    Stores ``data``, ``errors`` (copied into a new list), ``raw_data`` and
    ``name`` as plain attributes.
    """

    def __init__(self, data=None, errors=(), raw_data=None, name=None):
        self.data, self.errors = data, list(errors)
        self.raw_data, self.name = raw_data, name
@pytest.fixture
def dummy_form():
    """Provide a DummyForm constructed with its default arguments."""
    form = DummyForm()
    return form
@pytest.fixture
def dummy_field():
    """Provide a DummyField constructed with its default arguments."""
    field = DummyField()
    return field
@pytest.fixture
def user_session(monkeypatch, session):
    """Return a helper that stubs out the current user for a test.

    ``set_user_session(user=None)`` patches
    ``atst.domain.auth.get_current_user`` so that it returns ``user``. When
    no user is given, a single user is created with ``UserFactory`` at the
    time ``set_user_session`` is called, and that same user is returned on
    every later lookup.
    """

    def set_user_session(user=None):
        # Resolve the user once. Evaluating ``user or UserFactory.create()``
        # inside the patched callable would mint a brand-new user on every
        # lookup whenever no user was supplied.
        current_user = user or factories.UserFactory.create()
        monkeypatch.setattr(
            "atst.domain.auth.get_current_user",
            lambda *args, **kwargs: current_user,
        )

    return set_user_session
@pytest.fixture
def pdf_upload():
    """Yield the PDF at PDF_FILENAME wrapped in a FileStorage.

    The content type is ``application/pdf``. The underlying file handle
    stays open until the test using the fixture finishes.
    """
    with open(PDF_FILENAME, "rb") as stream:
        upload = FileStorage(stream, content_type="application/pdf")
        yield upload
@pytest.fixture
def pdf_upload2():
    """Yield the PDF at PDF_FILENAME2 wrapped in a FileStorage.

    The content type is ``application/pdf``. The underlying file handle
    stays open until the test using the fixture finishes.
    """
    with open(PDF_FILENAME2, "rb") as stream:
        upload = FileStorage(stream, content_type="application/pdf")
        yield upload
@pytest.fixture
def extended_financial_verification_data(pdf_upload):
    """Form data for an extended financial verification submission.

    ``expiration_date`` is ``"1/1/<next calendar year>"``, and
    ``legacy_task_order`` carries the ``pdf_upload`` FileStorage.
    """
    # ``datetime`` here is the class (``from datetime import datetime``), so
    # ``datetime.date`` is the instance method, not the ``date`` class. The
    # old ``datetime.date.today()`` therefore raised AttributeError.
    next_year = datetime.today().year + 1
    return {
        "funding_type": "RDTE",
        "funding_type_other": "other",
        "expiration_date": "1/1/{}".format(next_year),
        "clin_0001": "50000",
        "clin_0003": "13000",
        "clin_1001": "30000",
        "clin_1003": "7000",
        "clin_2001": "30000",
        "clin_2003": "7000",
        "legacy_task_order": pdf_upload,
    }
@pytest.fixture
def crl_failover_open_app(app):
    """Yield ``app`` with ``CRL_FAIL_OPEN`` enabled for one test.

    After the test, the value that was in place beforehand is restored. If
    the key was absent, it is restored as ``False``.
    """
    # ``app`` is the session-scoped fixture, so hard-coding ``False`` on
    # teardown would clobber any session-wide setting for later tests.
    previous = app.config.get("CRL_FAIL_OPEN", False)
    app.config.update({"CRL_FAIL_OPEN": True})
    yield app
    app.config.update({"CRL_FAIL_OPEN": previous})
@pytest.fixture
def rsa_key():
    """Return a factory that creates a new RSA private key on each call.

    Each key is 2048 bits with public exponent 65537.
    """

    def _generate():
        return rsa.generate_private_key(
            public_exponent=65537,
            key_size=2048,
            backend=default_backend(),
        )

    return _generate
@pytest.fixture
def ca_key(rsa_key):
    """A freshly generated RSA private key, used as the test CA's key."""
    key = rsa_key()
    return key
@pytest.fixture
def make_x509():
    """Return a factory that builds SHA-256-signed X.509 certificates.

    The factory signature is
    ``_make_x509(private_key, signer_key=None, cn="ATAT", signer_cn="ATAT")``:

    - The subject CN is ``cn`` and the issuer CN is ``signer_cn``.
    - The certificate carries ``private_key``'s public key.
    - It is signed with ``signer_key``, which defaults to ``private_key``.
    - When the signing key equals the subject key, a critical
      ``BasicConstraints(ca=True)`` extension is added.
    - Validity runs from two days before to thirty days after the current
      time.
    """

    def _make_x509(private_key, signer_key=None, cn="ATAT", signer_cn="ATAT"):
        if signer_key is None:
            signer_key = private_key
        self_signed = signer_key == private_key
        day = timedelta(days=1)

        def _common_name(value):
            return x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, value)])

        builder = (
            x509.CertificateBuilder()
            .subject_name(_common_name(cn))
            .issuer_name(_common_name(signer_cn))
        )
        if self_signed:
            builder = builder.add_extension(
                x509.BasicConstraints(ca=True, path_length=None), critical=True
            )
        builder = (
            builder.not_valid_before(datetime.today() - 2 * day)
            .not_valid_after(datetime.today() + 30 * day)
            .serial_number(x509.random_serial_number())
            .public_key(private_key.public_key())
        )
        return builder.sign(
            private_key=signer_key, algorithm=hashes.SHA256(), backend=default_backend()
        )

    return _make_x509
@pytest.fixture
def make_crl():
    """Return a factory that builds SHA-256-signed certificate revocation lists.

    The factory signature is ``_make_crl(private_key, last_update_days=-1,
    next_update_days=30, cn="ATAT", expired_serials=None)``:

    - The issuer CN is ``cn``.
    - The last/next update times are offsets in days from the current time;
      negative values lie in the past.
    - Each serial in ``expired_serials`` is listed as revoked, dated at the
      last-update time.
    """

    def _make_crl(
        private_key,
        last_update_days=-1,
        next_update_days=30,
        cn="ATAT",
        expired_serials=None,
    ):
        day = timedelta(days=1)
        issuer = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, cn)])
        last_update = datetime.today() + last_update_days * day
        next_update = datetime.today() + next_update_days * day

        builder = (
            x509.CertificateRevocationListBuilder()
            .issuer_name(issuer)
            .last_update(last_update)
            .next_update(next_update)
        )
        for serial in expired_serials or ():
            builder = add_revoked_cert(builder, serial, last_update)

        return builder.sign(
            private_key=private_key,
            algorithm=hashes.SHA256(),
            backend=default_backend(),
        )

    return _make_crl
def add_revoked_cert(crl_builder, serial, revocation_date):
    """Return the CRL builder produced by adding a revoked entry.

    The entry has serial ``serial`` and revocation date ``revocation_date``.
    """
    entry = x509.RevokedCertificateBuilder()
    entry = entry.serial_number(serial)
    entry = entry.revocation_date(revocation_date)
    revoked = entry.build(default_backend())
    return crl_builder.add_revoked_certificate(revoked)
@pytest.fixture
def serialize_pki_object_to_disk():
    """Return a helper that writes a PKI object's encoded bytes to a file.

    ``_serialize_pki_object_to_disk(obj, name, encoding=Encoding.PEM)`` opens
    ``name`` for binary writing (truncating it) and writes
    ``obj.public_bytes(encoding)``. It returns ``name``.
    """

    def _serialize_pki_object_to_disk(obj, name, encoding=Encoding.PEM):
        with open(name, "wb") as sink:
            encoded = obj.public_bytes(encoding)
            sink.write(encoded)
        return name

    return _serialize_pki_object_to_disk
@pytest.fixture
def ca_file(make_x509, ca_key, tmpdir, serialize_pki_object_to_disk):
    """Path (inside tmpdir) of a PEM-encoded self-signed certificate for ca_key."""
    certificate = make_x509(ca_key)
    destination = tmpdir.join("atat-ca.crt")
    serialize_pki_object_to_disk(certificate, destination)
    return destination
@pytest.fixture
def expired_crl_file(make_crl, ca_key, tmpdir, serialize_pki_object_to_disk):
    """Path (inside tmpdir) of a DER-encoded CRL signed by ca_key that is stale.

    Its last update was 7 days ago and its next update was 1 day ago.
    """
    stale_crl = make_crl(ca_key, last_update_days=-7, next_update_days=-1)
    destination = tmpdir.join("atat-expired.crl")
    serialize_pki_object_to_disk(stale_crl, destination, encoding=Encoding.DER)
    return destination
@pytest.fixture
def crl_file(make_crl, ca_key, tmpdir, serialize_pki_object_to_disk):
    """Path (inside tmpdir) of a DER-encoded CRL signed by ca_key.

    The CRL uses make_crl's default update window.
    """
    current_crl = make_crl(ca_key)
    destination = tmpdir.join("atat-valid.crl")
    serialize_pki_object_to_disk(current_crl, destination, encoding=Encoding.DER)
    return destination
@pytest.fixture
def mock_logger(app):
    """Swap ``app.logger`` for a FakeLogger during one test, then restore it."""
    saved_logger = app.logger
    app.logger = FakeLogger()
    yield app.logger
    app.logger = saved_logger
@pytest.fixture(scope="function", autouse=True)
def notification_sender(app):
    """Swap ``app.notification_sender`` for a FakeNotificationSender in every test.

    The real sender is put back afterwards.
    """
    saved_sender = app.notification_sender
    app.notification_sender = FakeNotificationSender()
    yield app.notification_sender
    app.notification_sender = saved_sender
# This is the only effective way I found to silence Celery worker logging in
# tests. According to the docs, overriding the `celery_enable_logging` fixture
# to return False should work, but in practice it does not:
# https://docs.celeryproject.org/en/latest/userguide/testing.html#celery-enable-logging-override-to-enable-logging-in-embedded-workers
@pytest.fixture(scope="function")
def celery_worker_parameters():
    """Parameters for the Celery test worker: log only at FATAL level."""
    worker_params = {"loglevel": "FATAL"}
    return worker_params