Currently, we use both Python's built-in datetime library and Pendulum for datetime operations. For the sake of consistency, we should stick to one library, and since Pendulum has the more ergonomic API, I went with it wherever possible. The places where we didn't / couldn't replace datetime are:

- Instance checks. Pendulum's objects are subclasses of Python's native datetime objects, so it's still useful to import datetime when calling isinstance().
- WTForms date validators expect datetime-style format strings, while Pendulum uses its own tokens for formatting/parsing. As such, our custom DateRange validator still needs to use datetime.strptime() to handle that format (see the sketch below).
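For illustration, here is a minimal sketch of both cases. The `DateRange` below is a simplified stand-in with made-up parameters (`lower_bound`, `message`, `string_format`), not the project's actual validator:

```python
from datetime import datetime

import pendulum
from wtforms.validators import ValidationError


# Instance checks: pendulum's DateTime subclasses the stdlib datetime,
# so isinstance() against the native type still works.
assert isinstance(pendulum.now(), datetime)


class DateRange:
    """Simplified stand-in for our validator: WTForms fields hand us raw
    strings in datetime-style formats, so we parse with datetime.strptime()
    rather than pendulum's own format tokens."""

    def __init__(self, lower_bound, message, string_format="%m/%d/%Y"):
        self.lower_bound = lower_bound  # a date; the earliest value we accept
        self.message = message
        self.string_format = string_format

    def __call__(self, form, field):
        try:
            parsed = datetime.strptime(field.data, self.string_format)
        except (TypeError, ValueError):
            raise ValidationError(self.message)
        if parsed.date() < self.lower_bound:
            raise ValidationError(self.message)
```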
import os
import pytest
import alembic.config
import alembic.command
from logging.config import dictConfig
from werkzeug.datastructures import FileStorage
from collections import OrderedDict
from unittest.mock import Mock

from atst.app import make_app, make_config
from atst.database import db as _db
import tests.factories as factories
from tests.mocks import PDF_FILENAME, PDF_FILENAME2
from tests.utils import FakeLogger, FakeNotificationSender

import pendulum
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.serialization import Encoding
from cryptography.x509.oid import NameOID


dictConfig({"version": 1, "handlers": {"wsgi": {"class": "logging.NullHandler"}}})


@pytest.fixture(scope="session")
def app(request):
    config = make_config()
    _app = make_app(config)

    ctx = _app.app_context()
    ctx.push()

    yield _app

    ctx.pop()


@pytest.fixture(autouse=True)
def skip_audit_log(request):
    """
    Conditionally skip tests marked with 'audit_log' based on the
    USE_AUDIT_LOG config value.
    """
    config = make_config()
    if request.node.get_closest_marker("audit_log"):
        use_audit_log = config.get("USE_AUDIT_LOG", False)
        if not use_audit_log:
            pytest.skip("audit log feature flag disabled")


@pytest.fixture(scope="function")
def no_debug_app(request):
    config = make_config(direct_config={"DEBUG": False})
    _app = make_app(config)

    ctx = _app.app_context()
    ctx.push()

    yield _app

    ctx.pop()


@pytest.fixture(scope="function")
def no_debug_client(no_debug_app):
    yield no_debug_app.test_client()


def apply_migrations():
    """Applies all alembic migrations."""
    alembic_config = os.path.join(os.path.dirname(__file__), "../", "alembic.ini")
    config = alembic.config.Config(alembic_config)
    app_config = make_config()
    config.set_main_option("sqlalchemy.url", app_config["DATABASE_URI"])
    alembic.command.upgrade(config, "head")


@pytest.fixture(scope="session")
def db(app, request):

    _db.app = app

    apply_migrations()

    yield _db

    _db.drop_all()


@pytest.fixture(scope="function", autouse=True)
def session(db, request):
    """Creates a new database session for a test."""
    connection = db.engine.connect()
    transaction = connection.begin()

    options = dict(bind=connection, binds={})
    session = db.create_scoped_session(options=options)

    db.session = session

    # Bind every factory in tests.factories to this session so test data is
    # written inside the transaction and rolled back after each test.
    factory_list = [
        cls
        for _name, cls in factories.__dict__.items()
        if isinstance(cls, type) and cls.__module__ == "tests.factories"
    ]
    for factory in factory_list:
        factory._meta.sqlalchemy_session = session
        factory._meta.sqlalchemy_session_persistence = "commit"

    yield session

    transaction.rollback()
    connection.close()
    session.remove()


class DummyForm(dict):
    def __init__(self, data=OrderedDict(), errors=(), raw_data=None):
        self._fields = data
        self.errors = list(errors)


class DummyField(object):
    def __init__(self, data=None, errors=(), raw_data=None, name=None):
        self.data = data
        self.errors = list(errors)
        self.raw_data = raw_data
        self.name = name


@pytest.fixture
def dummy_form():
    return DummyForm()


@pytest.fixture
def dummy_field():
    return DummyField()


@pytest.fixture
def user_session(monkeypatch, session):
    def set_user_session(user=None):
        monkeypatch.setattr(
            "atst.domain.auth.get_current_user",
            lambda *args: user or factories.UserFactory.create(),
        )

    return set_user_session


@pytest.fixture
def pdf_upload():
    with open(PDF_FILENAME, "rb") as fp:
        yield FileStorage(fp, content_type="application/pdf")


@pytest.fixture
def pdf_upload2():
    with open(PDF_FILENAME2, "rb") as fp:
        yield FileStorage(fp, content_type="application/pdf")


@pytest.fixture
def downloaded_task_order():
    with open(PDF_FILENAME, "rb") as fp:
        yield {"name": "mock.pdf", "content": fp.read()}


@pytest.fixture
def extended_financial_verification_data(pdf_upload):
    return {
        "funding_type": "RDTE",
        "funding_type_other": "other",
        "expiration_date": "1/1/{}".format(pendulum.today().year + 1),
        "clin_0001": "50000",
        "clin_0003": "13000",
        "clin_1001": "30000",
        "clin_1003": "7000",
        "clin_2001": "30000",
        "clin_2003": "7000",
        "legacy_task_order": pdf_upload,
    }


@pytest.fixture
def crl_failover_open_app(app):
    app.config.update({"CRL_FAIL_OPEN": True})
    yield app
    app.config.update({"CRL_FAIL_OPEN": False})


@pytest.fixture
def rsa_key():
    def _rsa_key():
        return rsa.generate_private_key(
            public_exponent=65537, key_size=2048, backend=default_backend()
        )

    return _rsa_key


@pytest.fixture
def ca_key(rsa_key):
    return rsa_key()


@pytest.fixture
def make_x509():
    def _make_x509(private_key, signer_key=None, cn="ATAT", signer_cn="ATAT"):
        if signer_key is None:
            signer_key = private_key

        public_key = private_key.public_key()
        builder = x509.CertificateBuilder()
        builder = builder.subject_name(
            x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, cn)])
        )
        builder = builder.issuer_name(
            x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, signer_cn)])
        )
        if signer_key == private_key:
            builder = builder.add_extension(
                x509.BasicConstraints(ca=True, path_length=None), critical=True
            )
        builder = builder.not_valid_before(pendulum.today().subtract(days=2))
        builder = builder.not_valid_after(pendulum.today().add(days=30))
        builder = builder.serial_number(x509.random_serial_number())
        builder = builder.public_key(public_key)
        certificate = builder.sign(
            private_key=signer_key, algorithm=hashes.SHA256(), backend=default_backend()
        )

        return certificate

    return _make_x509


@pytest.fixture
def make_crl():
    def _make_crl(
        private_key,
        last_update_days=-1,
        next_update_days=30,
        cn="ATAT",
        expired_serials=None,
    ):
        builder = x509.CertificateRevocationListBuilder()
        builder = builder.issuer_name(
            x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, cn)])
        )
        last_update = pendulum.today().add(days=last_update_days)
        next_update = pendulum.today().add(days=next_update_days)
        builder = builder.last_update(last_update)
        builder = builder.next_update(next_update)
        if expired_serials:
            for serial in expired_serials:
                builder = add_revoked_cert(builder, serial, last_update)

        crl = builder.sign(
            private_key=private_key,
            algorithm=hashes.SHA256(),
            backend=default_backend(),
        )

        return crl

    return _make_crl


def add_revoked_cert(crl_builder, serial, revocation_date):
    revoked_cert = (
        x509.RevokedCertificateBuilder()
        .serial_number(serial)
        .revocation_date(revocation_date)
        .build(default_backend())
    )
    return crl_builder.add_revoked_certificate(revoked_cert)


@pytest.fixture
def serialize_pki_object_to_disk():
    def _serialize_pki_object_to_disk(obj, name, encoding=Encoding.PEM):
        with open(name, "wb") as file_:
            file_.write(obj.public_bytes(encoding))

        return name

    return _serialize_pki_object_to_disk


@pytest.fixture
def ca_file(make_x509, ca_key, tmpdir, serialize_pki_object_to_disk):
    ca = make_x509(ca_key)
    ca_out = tmpdir.join("atat-ca.crt")
    serialize_pki_object_to_disk(ca, ca_out)

    return ca_out


@pytest.fixture
def expired_crl_file(make_crl, ca_key, tmpdir, serialize_pki_object_to_disk):
    crl = make_crl(ca_key, last_update_days=-7, next_update_days=-1)
    crl_out = tmpdir.join("atat-expired.crl")
    serialize_pki_object_to_disk(crl, crl_out, encoding=Encoding.DER)

    return crl_out


@pytest.fixture
def crl_file(make_crl, ca_key, tmpdir, serialize_pki_object_to_disk):
    crl = make_crl(ca_key)
    crl_out = tmpdir.join("atat-valid.crl")
    serialize_pki_object_to_disk(crl, crl_out, encoding=Encoding.DER)

    return crl_out


@pytest.fixture
def mock_logger(app):
    real_logger = app.logger
    app.logger = FakeLogger()

    yield app.logger

    app.logger = real_logger


@pytest.fixture(scope="function", autouse=True)
def notification_sender(app):
    real_notification_sender = app.notification_sender
    app.notification_sender = FakeNotificationSender()

    yield app.notification_sender

    app.notification_sender = real_notification_sender


# This is the only effective means I could find to disable logging. Setting a
# `celery_enable_logging` fixture to return False should work according to the
# docs, but doesn't:
# https://docs.celeryproject.org/en/latest/userguide/testing.html#celery-enable-logging-override-to-enable-logging-in-embedded-workers
@pytest.fixture(scope="function")
def celery_worker_parameters():
    return {"loglevel": "FATAL"}