diff --git a/atst/domain/audit_log.py b/atst/domain/audit_log.py index 238cff2e..4e7abca9 100644 --- a/atst/domain/audit_log.py +++ b/atst/domain/audit_log.py @@ -8,8 +8,9 @@ class AuditEventQuery(Query): model = AuditEvent @classmethod - def get_all(cls): - return db.session.query(cls.model).order_by(cls.model.time_created.desc()).all() + def get_all(cls, pagination): + query = db.session.query(cls.model).order_by(cls.model.time_created.desc()) + return cls.paginate(query, pagination) class AuditLog(object): @@ -28,11 +29,11 @@ class AuditLog(object): return cls._log(resource=resource, action=action) @classmethod - def get_all_events(cls, user): + def get_all_events(cls, user, pagination=None): Authorization.check_atat_permission( user, Permissions.VIEW_AUDIT_LOG, "view audit log" ) - return AuditEventQuery.get_all() + return AuditEventQuery.get_all(pagination) @classmethod def _resource_type(cls, resource): diff --git a/atst/domain/common/query.py b/atst/domain/common/query.py index 4f55d6c0..8f554268 100644 --- a/atst/domain/common/query.py +++ b/atst/domain/common/query.py @@ -5,6 +5,37 @@ from atst.domain.exceptions import NotFoundError from atst.database import db +class Paginator(object): + """ + Uses the Flask-SQLAlchemy extension's pagination method to paginate + a query set. + + Also acts as a proxy object so that the results of the query set can be iterated + over without needing to call `.items`. + """ + + def __init__(self, query_set): + self.query_set = query_set + + @classmethod + def paginate(cls, query, pagination=None): + if pagination is not None: + return cls( + query.paginate(page=pagination["page"], per_page=pagination["per_page"]) + ) + else: + return query.all() + + def __getattr__(self, name): + return getattr(self.query_set, name) + + def __iter__(self): + return self.items.__iter__() + + def __len__(self): + return self.items.__len__() + + class Query(object): model = None @@ -35,3 +66,7 @@ class Query(object): db.session.add(resource) db.session.commit() return resource + + @classmethod + def paginate(cls, query, pagination): + return Paginator.paginate(query, pagination) diff --git a/tests/domain/test_audit_log.py b/tests/domain/test_audit_log.py index 1a2d95f7..387da03c 100644 --- a/tests/domain/test_audit_log.py +++ b/tests/domain/test_audit_log.py @@ -22,3 +22,12 @@ def test_non_admin_cannot_view_audit_log(developer): def test_ccpo_can_iview_audit_log(ccpo): AuditLog.get_all_events(ccpo) + + +def test_paginate_audit_log(ccpo): + user = UserFactory.create() + for _ in range(100): + AuditLog.log_system_event(user, action="create") + + events = AuditLog.get_all_events(ccpo, pagination={"per_page": 25, "page": 2}) + assert len(events) == 25