diff --git a/tests/conftest.py b/tests/conftest.py index bdc59286..c3add296 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,9 +6,10 @@ import alembic.command from atst.app import make_app, make_config from atst.database import db as _db from .mocks import MOCK_USER +import tests.factories as factories -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def app(request): config = make_config() @@ -27,11 +28,11 @@ def apply_migrations(): alembic_config = os.path.join(os.path.dirname(__file__), "../", "alembic.ini") config = alembic.config.Config(alembic_config) app_config = make_config() - config.set_main_option('sqlalchemy.url', app_config["DATABASE_URI"]) - alembic.command.upgrade(config, 'head') + config.set_main_option("sqlalchemy.url", app_config["DATABASE_URI"]) + alembic.command.upgrade(config, "head") -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def db(app, request): _db.app = app @@ -43,7 +44,7 @@ def db(app, request): _db.drop_all() -@pytest.fixture(scope='function', autouse=True) +@pytest.fixture(scope="function", autouse=True) def session(db, request): """Creates a new database session for a test.""" connection = db.engine.connect() @@ -54,6 +55,14 @@ def session(db, request): db.session = session + factory_list = [ + cls + for _name, cls in factories.__dict__.items() + if isinstance(cls, type) and cls.__module__ == "tests.factories" + ] + for factory in factory_list: + factory._meta.sqlalchemy_session = session + yield session transaction.rollback() @@ -66,6 +75,7 @@ class DummyForm(dict): class DummyField(object): + def __init__(self, data=None, errors=(), raw_data=None): self.data = data self.errors = list(errors) @@ -81,10 +91,11 @@ def dummy_form(): def dummy_field(): return DummyField() + @pytest.fixture def user_session(monkeypatch): - def set_user_session(user = MOCK_USER): + def set_user_session(user=MOCK_USER): monkeypatch.setattr("atst.domain.auth.get_current_user", lambda *args: user) return set_user_session diff --git a/tests/domain/test_pe_numbers.py b/tests/domain/test_pe_numbers.py index 60c410b5..98b90470 100644 --- a/tests/domain/test_pe_numbers.py +++ b/tests/domain/test_pe_numbers.py @@ -6,20 +6,8 @@ from atst.domain.pe_numbers import PENumbers from tests.factories import PENumberFactory -@pytest.fixture(scope="function") -def new_pe_number(session): - def make_pe_number(**kwargs): - pen = PENumberFactory.create(**kwargs) - session.add(pen) - session.commit() - - return pen - - return make_pe_number - - -def test_can_get_pe_number(new_pe_number): - new_pen = new_pe_number(number="0701367F", description="Combat Support - Offensive") +def test_can_get_pe_number(): + new_pen = PENumberFactory.create(number="0701367F", description="Combat Support - Offensive") pen = PENumbers.get(new_pen.number) assert pen.number == new_pen.number diff --git a/tests/domain/test_requests.py b/tests/domain/test_requests.py index 901caf73..0597714a 100644 --- a/tests/domain/test_requests.py +++ b/tests/domain/test_requests.py @@ -9,11 +9,7 @@ from tests.factories import RequestFactory @pytest.fixture(scope="function") def new_request(session): - created_request = RequestFactory.create() - session.add(created_request) - session.commit() - - return created_request + return RequestFactory.create() def test_can_get_request(new_request): diff --git a/tests/domain/test_task_orders.py b/tests/domain/test_task_orders.py index 2f03a6d0..ba422bc5 100644 --- a/tests/domain/test_task_orders.py +++ b/tests/domain/test_task_orders.py @@ -6,20 +6,8 @@ from atst.domain.task_orders import TaskOrders from tests.factories import TaskOrderFactory -@pytest.fixture(scope="function") -def new_task_order(session): - def make_task_order(**kwargs): - to = TaskOrderFactory.create(**kwargs) - session.add(to) - session.commit() - - return to - - return make_task_order - - -def test_can_get_task_order(new_task_order): - new_to = new_task_order(number="0101969F") +def test_can_get_task_order(): + new_to = TaskOrderFactory.create(number="0101969F") to = TaskOrders.get(new_to.number) assert to.id == to.id diff --git a/tests/factories.py b/tests/factories.py index 11b11a1b..5cefc1eb 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -8,27 +8,36 @@ from atst.models.user import User from atst.models.role import Role -class RequestFactory(factory.Factory): +class RequestFactory(factory.alchemy.SQLAlchemyModelFactory): + class Meta: model = Request id = factory.Sequence(lambda x: uuid4()) -class PENumberFactory(factory.Factory): + +class PENumberFactory(factory.alchemy.SQLAlchemyModelFactory): + class Meta: model = PENumber -class TaskOrderFactory(factory.Factory): + +class TaskOrderFactory(factory.alchemy.SQLAlchemyModelFactory): + class Meta: model = TaskOrder -class RoleFactory(factory.Factory): + +class RoleFactory(factory.alchemy.SQLAlchemyModelFactory): + class Meta: model = Role permissions = [] -class UserFactory(factory.Factory): + +class UserFactory(factory.alchemy.SQLAlchemyModelFactory): + class Meta: model = User diff --git a/tests/mocks.py b/tests/mocks.py index 1e44f96e..0307997e 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -1,8 +1,8 @@ from tests.factories import RequestFactory, UserFactory -MOCK_USER = UserFactory.create() -MOCK_REQUEST = RequestFactory.create( +MOCK_USER = UserFactory.build() +MOCK_REQUEST = RequestFactory.build( creator=MOCK_USER.id, body={ "financial_verification": { diff --git a/tests/routes/test_financial_verification.py b/tests/routes/test_financial_verification.py index 60f0f9b0..e15638df 100644 --- a/tests/routes/test_financial_verification.py +++ b/tests/routes/test_financial_verification.py @@ -61,11 +61,9 @@ class TestPENumberInForm: assert response.status_code == 302 assert "/requests/financial_verification_submitted" in response.headers.get("Location") - def test_submit_request_form_with_new_valid_pe_id(self, session, monkeypatch, client): + def test_submit_request_form_with_new_valid_pe_id(self, monkeypatch, client): self._set_monkeypatches(monkeypatch) pe = PENumberFactory.create(number="8675309U", description="sample PE number") - session.add(pe) - session.commit() data = dict(self.required_data) data['pe_id'] = pe.number