import os import pytest import alembic.config import alembic.command from logging.config import dictConfig from werkzeug.datastructures import FileStorage from collections import OrderedDict from unittest.mock import Mock from atat.app import make_app, make_config from atat.database import db as _db import tests.factories as factories from tests.mocks import PDF_FILENAME, PDF_FILENAME2 from tests.utils import FakeLogger, FakeNotificationSender import pendulum from cryptography.hazmat.primitives.asymmetric import rsa from cryptography import x509 from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.serialization import Encoding from cryptography.x509.oid import NameOID dictConfig({"version": 1, "handlers": {"wsgi": {"class": "logging.NullHandler"}}}) @pytest.fixture(scope="session") def app(request): config = make_config() _app = make_app(config) ctx = _app.app_context() ctx.push() yield _app ctx.pop() @pytest.fixture(autouse=True) def skip_audit_log(request): """ Conditionally skip tests marked with 'audit_log' based on the USE_AUDIT_LOG config value. """ config = make_config() if request.node.get_closest_marker("audit_log"): use_audit_log = config.get("USE_AUDIT_LOG", False) if not use_audit_log: pytest.skip("audit log feature flag disabled") @pytest.fixture(scope="function") def no_debug_app(request): config = make_config(direct_config={"DEBUG": False}) _app = make_app(config) ctx = _app.app_context() ctx.push() yield _app ctx.pop() @pytest.fixture(scope="function") def no_debug_client(no_debug_app): yield no_debug_app.test_client() def apply_migrations(): """Applies all alembic migrations.""" alembic_config = os.path.join(os.path.dirname(__file__), "../", "alembic.ini") config = alembic.config.Config(alembic_config) app_config = make_config() config.set_main_option("sqlalchemy.url", app_config["DATABASE_URI"]) alembic.command.upgrade(config, "head") @pytest.fixture(scope="session") def db(app, request): _db.app = app apply_migrations() yield _db _db.drop_all() @pytest.fixture(scope="function", autouse=True) def session(db, request): """Creates a new database session for a test.""" connection = db.engine.connect() transaction = connection.begin() options = dict(bind=connection, binds={}) session = db.create_scoped_session(options=options) db.session = session factory_list = [ cls for _name, cls in factories.__dict__.items() if isinstance(cls, type) and cls.__module__ == "tests.factories" ] for factory in factory_list: factory._meta.sqlalchemy_session = session factory._meta.sqlalchemy_session_persistence = "commit" yield session transaction.rollback() connection.close() session.remove() class DummyForm(dict): def __init__(self, data=OrderedDict(), errors=(), raw_data=None): self._fields = data self.errors = list(errors) class DummyField(object): def __init__(self, data=None, errors=(), raw_data=None, name=None): self.data = data self.errors = list(errors) self.raw_data = raw_data self.name = name @pytest.fixture def dummy_form(): return DummyForm() @pytest.fixture def dummy_field(): return DummyField() @pytest.fixture def user_session(monkeypatch, session): def set_user_session(user=None): monkeypatch.setattr( "atat.domain.auth.get_current_user", lambda *args: user or factories.UserFactory.create(), ) return set_user_session @pytest.fixture def pdf_upload(): with open(PDF_FILENAME, "rb") as fp: yield FileStorage(fp, content_type="application/pdf") @pytest.fixture def pdf_upload2(): with open(PDF_FILENAME2, "rb") as fp: yield FileStorage(fp, content_type="application/pdf") @pytest.fixture def downloaded_task_order(): with open(PDF_FILENAME, "rb") as fp: yield {"name": "mock.pdf", "content": fp.read()} @pytest.fixture def extended_financial_verification_data(pdf_upload): return { "funding_type": "RDTE", "funding_type_other": "other", "expiration_date": "1/1/{}".format(pendulum.today().year + 1), "clin_0001": "50000", "clin_0003": "13000", "clin_1001": "30000", "clin_1003": "7000", "clin_2001": "30000", "clin_2003": "7000", "legacy_task_order": pdf_upload, } @pytest.fixture def crl_failover_open_app(app): app.config.update({"CRL_FAIL_OPEN": True}) yield app app.config.update({"CRL_FAIL_OPEN": False}) @pytest.fixture def rsa_key(): def _rsa_key(): return rsa.generate_private_key( public_exponent=65537, key_size=2048, backend=default_backend() ) return _rsa_key @pytest.fixture def ca_key(rsa_key): return rsa_key() @pytest.fixture def make_x509(): def _make_x509(private_key, signer_key=None, cn="ATAT", signer_cn="ATAT"): if signer_key is None: signer_key = private_key public_key = private_key.public_key() builder = x509.CertificateBuilder() builder = builder.subject_name( x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, cn)]) ) builder = builder.issuer_name( x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, signer_cn)]) ) if signer_key == private_key: builder = builder.add_extension( x509.BasicConstraints(ca=True, path_length=None), critical=True ) builder = builder.not_valid_before(pendulum.today().subtract(days=2)) builder = builder.not_valid_after(pendulum.today().add(days=30)) builder = builder.serial_number(x509.random_serial_number()) builder = builder.public_key(public_key) certificate = builder.sign( private_key=signer_key, algorithm=hashes.SHA256(), backend=default_backend() ) return certificate return _make_x509 @pytest.fixture def make_crl(): def _make_crl( private_key, last_update_days=-1, next_update_days=30, cn="ATAT", expired_serials=None, ): builder = x509.CertificateRevocationListBuilder() builder = builder.issuer_name( x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, cn)]) ) last_update = pendulum.today().add(days=last_update_days) next_update = pendulum.today().add(days=next_update_days) builder = builder.last_update(last_update) builder = builder.next_update(next_update) if expired_serials: for serial in expired_serials: builder = add_revoked_cert(builder, serial, last_update) crl = builder.sign( private_key=private_key, algorithm=hashes.SHA256(), backend=default_backend(), ) return crl return _make_crl def add_revoked_cert(crl_builder, serial, revocation_date): revoked_cert = ( x509.RevokedCertificateBuilder() .serial_number(serial) .revocation_date(revocation_date) .build(default_backend()) ) return crl_builder.add_revoked_certificate(revoked_cert) @pytest.fixture def serialize_pki_object_to_disk(): def _serialize_pki_object_to_disk(obj, name, encoding=Encoding.PEM): with open(name, "wb") as file_: file_.write(obj.public_bytes(encoding)) return name return _serialize_pki_object_to_disk @pytest.fixture def ca_file(make_x509, ca_key, tmpdir, serialize_pki_object_to_disk): ca = make_x509(ca_key) ca_out = tmpdir.join("atat-ca.crt") serialize_pki_object_to_disk(ca, ca_out) return ca_out @pytest.fixture def expired_crl_file(make_crl, ca_key, tmpdir, serialize_pki_object_to_disk): crl = make_crl(ca_key, last_update_days=-7, next_update_days=-1) crl_out = tmpdir.join("atat-expired.crl") serialize_pki_object_to_disk(crl, crl_out, encoding=Encoding.DER) return crl_out @pytest.fixture def crl_file(make_crl, ca_key, tmpdir, serialize_pki_object_to_disk): crl = make_crl(ca_key) crl_out = tmpdir.join("atat-valid.crl") serialize_pki_object_to_disk(crl, crl_out, encoding=Encoding.DER) return crl_out @pytest.fixture def mock_logger(app): real_logger = app.logger app.logger = FakeLogger() yield app.logger app.logger = real_logger @pytest.fixture(scope="function", autouse=True) def notification_sender(app): real_notification_sender = app.notification_sender app.notification_sender = FakeNotificationSender() yield app.notification_sender app.notification_sender = real_notification_sender # This is the only effective means I could find to disable logging. Setting a # `celery_enable_logging` fixture to return False should work according to the # docs, but doesn't: # https://docs.celeryproject.org/en/latest/userguide/testing.html#celery-enable-logging-override-to-enable-logging-in-embedded-workers @pytest.fixture(scope="function") def celery_worker_parameters(): return {"loglevel": "FATAL"}