diff --git a/atst/domain/requests.py b/atst/domain/requests.py index 23290085..3d50d887 100644 --- a/atst/domain/requests.py +++ b/atst/domain/requests.py @@ -1,4 +1,6 @@ +from enum import Enum from sqlalchemy import exists, and_, exc +from sqlalchemy.sql import text from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.orm.attributes import flag_modified @@ -154,3 +156,21 @@ class Requests(object): @classmethod def is_pending_ccpo_approval(cls, request): return request.status == RequestStatus.PENDING_CCPO_APPROVAL + + @classmethod + def count_status(self, status): + raw = text(""" +SELECT count(requests_with_status.id) +FROM ( + SELECT DISTINCT ON (rse.request_id) r.*, rse.new_status as status + FROM request_status_events rse JOIN requests r ON r.id = rse.request_id + ORDER BY rse.request_id, rse.sequence DESC +) as requests_with_status +WHERE requests_with_status.status = :status; + """) + if isinstance(status, Enum): + status = status.name + results = db.session.execute(raw, {"status": status}).fetchone() + (count,) = results + return count + diff --git a/tests/conftest.py b/tests/conftest.py index 12e7c320..e5094c1b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -69,6 +69,7 @@ def session(db, request): ] for factory in factory_list: factory._meta.sqlalchemy_session = session + factory._meta.sqlalchemy_session_persistence = "commit" yield session diff --git a/tests/domain/test_requests.py b/tests/domain/test_requests.py index 2044d52b..420294e2 100644 --- a/tests/domain/test_requests.py +++ b/tests/domain/test_requests.py @@ -3,9 +3,10 @@ from uuid import uuid4 from atst.domain.exceptions import NotFoundError from atst.domain.requests import Requests +from atst.models.request import Request from atst.models.request_status_event import RequestStatus -from tests.factories import RequestFactory, UserFactory +from tests.factories import RequestFactory, UserFactory, RequestStatusEventFactory @pytest.fixture(scope="function") @@ -63,3 +64,15 @@ def test_exists(session): request = RequestFactory.create(creator=user_allowed) assert Requests.exists(request.id, user_allowed) assert not Requests.exists(request.id, user_denied) + + +def test_count_status(session): + # make sure table is empty + session.query(Request).delete() + + request1 = RequestFactory.create() + request2 = RequestFactory.create() + RequestStatusEventFactory.create(sequence=2, request_id=request2.id, new_status=RequestStatus.PENDING_FINANCIAL_VERIFICATION) + + assert Requests.count_status(RequestStatus.PENDING_FINANCIAL_VERIFICATION) == 1 + assert Requests.count_status(RequestStatus.STARTED) == 1 diff --git a/tests/factories.py b/tests/factories.py index 3e133c45..ad12ef56 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -39,6 +39,7 @@ class RequestStatusEventFactory(factory.alchemy.SQLAlchemyModelFactory): model = RequestStatusEvent id = factory.Sequence(lambda x: uuid4()) + sequence = 1 class RequestFactory(factory.alchemy.SQLAlchemyModelFactory):