diff --git a/atst/domain/authz.py b/atst/domain/authz.py index ae7e2b38..ce422736 100644 --- a/atst/domain/authz.py +++ b/atst/domain/authz.py @@ -17,28 +17,6 @@ class Authorization(object): def is_in_workspace(cls, user, workspace): return user in workspace.users - @classmethod - def can_view_request(cls, user, request): - if ( - Permissions.REVIEW_AND_APPROVE_JEDI_WORKSPACE_REQUEST - in user.atat_permissions - ): - return True - elif request.creator == user: - return True - - return False - - @classmethod - def check_can_approve_request(cls, user): - if ( - Permissions.REVIEW_AND_APPROVE_JEDI_WORKSPACE_REQUEST - in user.atat_permissions - ): - return True - else: - raise UnauthorizedError(user, "cannot review and approve requests") - @classmethod def check_workspace_permission(cls, user, workspace, permission, message): if not Authorization.has_workspace_permission(user, workspace, permission): diff --git a/atst/domain/requests/authorization.py b/atst/domain/requests/authorization.py new file mode 100644 index 00000000..adda0dbe --- /dev/null +++ b/atst/domain/requests/authorization.py @@ -0,0 +1,29 @@ +from atst.models.permissions import Permissions +from atst.domain.authz import Authorization +from atst.domain.exceptions import UnauthorizedError + + +class RequestsAuthorization(object): + def __init__(self, user, request): + self.user = user + self.request = request + + @property + def can_view(self): + return ( + Authorization.has_atat_permission( + self.user, Permissions.REVIEW_AND_APPROVE_JEDI_WORKSPACE_REQUEST + ) + or self.request.creator == self.user + ) + + def check_can_view(self, message): + if not self.can_view: + raise UnauthorizedError(self.user, message) + + def check_can_approve(self): + return Authorization.check_atat_permission( + self.user, + Permissions.REVIEW_AND_APPROVE_JEDI_WORKSPACE_REQUEST, + "cannot review and approve requests", + ) diff --git a/atst/domain/requests/requests.py b/atst/domain/requests/requests.py index e61df2a9..4011791c 100644 --- a/atst/domain/requests/requests.py +++ b/atst/domain/requests/requests.py @@ -1,7 +1,6 @@ from werkzeug.datastructures import FileStorage import dateutil -from atst.domain.authz import Authorization from atst.domain.task_orders import TaskOrders from atst.domain.workspaces import Workspaces from atst.models.request_revision import RequestRevision @@ -10,9 +9,8 @@ from atst.models.request_review import RequestReview from atst.models.request_internal_comment import RequestInternalComment from atst.utils import deep_merge -from atst.domain.exceptions import UnauthorizedError - from .query import RequestsQuery +from .authorization import RequestsAuthorization def create_revision_from_request_body(body): @@ -47,18 +45,13 @@ class Requests(object): @classmethod def get(cls, user, request_id): request = RequestsQuery.get(request_id) - - if not Authorization.can_view_request(user, request): - raise UnauthorizedError(user, "get request") - + RequestsAuthorization(user, request).check_can_view("get request") return request @classmethod def get_for_approval(cls, user, request_id): request = RequestsQuery.get(request_id) - - Authorization.check_can_approve_request(user) - + RequestsAuthorization(user, request).check_can_approve() return request @classmethod @@ -226,7 +219,7 @@ class Requests(object): @classmethod def add_internal_comment(cls, user, request, comment_text): - Authorization.check_can_approve_request(user) + RequestsAuthorization(user, request).check_can_approve() comment = RequestInternalComment(request=request, text=comment_text, user=user) RequestsQuery.add_and_commit(comment) return request diff --git a/tests/domain/test_authz.py b/tests/domain/test_authz.py deleted file mode 100644 index 32368cb6..00000000 --- a/tests/domain/test_authz.py +++ /dev/null @@ -1,20 +0,0 @@ -from atst.domain.authz import Authorization -from atst.domain.roles import Roles - -from tests.factories import RequestFactory, UserFactory - - -def test_creator_can_view_own_request(): - user = UserFactory.create() - request = RequestFactory.create(creator=user) - assert Authorization.can_view_request(user, request) - - other_user = UserFactory.create() - assert not Authorization.can_view_request(other_user, request) - - -def test_ccpo_user_can_view_request(): - role = Roles.get("ccpo") - ccpo_user = UserFactory.create(atat_role=role) - request = RequestFactory.create() - assert Authorization.can_view_request(ccpo_user, request)