diff --git a/tests/conftest.py b/tests/conftest.py index c650b5f1..baafaeda 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,18 +1,21 @@ import pytest +from sqlalchemy.orm import sessionmaker, scoped_session from atst.app import make_app, make_deps, make_config +from atst.database import make_db from tests.mocks import MockApiClient, MockFundzClient, MockRequestsClient, MockAuthzClient from atst.sessions import DictSessions @pytest.fixture -def app(): +def app(db): TEST_DEPS = { "authz_client": MockAuthzClient("authz"), "requests_client": MockRequestsClient("requests"), "authnid_client": MockApiClient("authnid"), "fundz_client": MockFundzClient("fundz"), "sessions": DictSessions(), + "db_session": db } config = make_config() @@ -21,6 +24,26 @@ def app(): return make_app(config, deps) + +@pytest.fixture(scope='function') +def db(): + + # Override db with a new SQLAlchemy session so that we can rollback + # each test's transaction. + # Inspiration: https://docs.sqlalchemy.org/en/latest/orm/session_transaction.html#session-external-transaction + config = make_config() + database = make_db(config) + connection = database.get_bind().connect() + transaction = connection.begin() + db = scoped_session(sessionmaker(bind=connection)) + + yield db + + db.close() + transaction.rollback() + connection.close() + + class DummyForm(dict): pass diff --git a/tests/domain/__init__.py b/tests/domain/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/domain/test_requests.py b/tests/domain/test_requests.py new file mode 100644 index 00000000..f45d5be2 --- /dev/null +++ b/tests/domain/test_requests.py @@ -0,0 +1,55 @@ +import pytest +from uuid import uuid4 + +from atst.domain.exceptions import NotFoundError +from atst.domain.requests import Requests + +from tests.factories import RequestFactory + + +@pytest.fixture() +def requests(db): + return Requests(db) + +@pytest.fixture(scope="function") +def new_request(db): + created_request = RequestFactory.create() + db.add(created_request) + db.commit() + + return created_request + + +def test_can_get_request(requests, new_request): + request = requests.get(new_request.id) + + assert request.id == new_request.id + + +def test_nonexistent_request_raises(requests): + with pytest.raises(NotFoundError): + requests.get(uuid4()) + + +@pytest.mark.gen_test +def test_auto_approve_less_than_1m(requests, new_request): + new_request.body = {"details_of_use": {"dollar_value": 999999}} + request = yield requests.submit(new_request) + + assert request.status == 'approved' + + +@pytest.mark.gen_test +def test_dont_auto_approve_if_dollar_value_is_1m_or_above(requests, new_request): + new_request.body = {"details_of_use": {"dollar_value": 1000000}} + request = yield requests.submit(new_request) + + assert request.status == 'submitted' + + +@pytest.mark.gen_test +def test_dont_auto_approve_if_no_dollar_value_specified(requests, new_request): + new_request.body = {"details_of_use": {}} + request = yield requests.submit(new_request) + + assert request.status == 'submitted' diff --git a/tests/factories.py b/tests/factories.py new file mode 100644 index 00000000..b07b36fa --- /dev/null +++ b/tests/factories.py @@ -0,0 +1,11 @@ +import factory +from uuid import uuid4 + +from atst.models import Request, RequestStatusEvent + + +class RequestFactory(factory.Factory): + class Meta: + model = Request + + id = factory.Sequence(lambda x: uuid4())