diff --git a/atst/domain/workspace_users.py b/atst/domain/workspace_users.py index 9ff139cd..857be057 100644 --- a/atst/domain/workspace_users.py +++ b/atst/domain/workspace_users.py @@ -1,17 +1,23 @@ from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.dialects.postgresql import insert -from atst.models import User, WorkspaceRole, Role +from atst.models.workspace_role import WorkspaceRole +from atst.models.workspace_user import WorkspaceUser +from atst.models.user import User +from .roles import Roles +from .users import Users from .exceptions import NotFoundError class WorkspaceUsers(object): def __init__(self, db_session): self.db_session = db_session + self.roles_repo = Roles(db_session) + self.users_repo = Users(db_session) def get(self, workspace_id, user_id): try: - user = User.query.filter_by(id=user_id).one() + user = self.users_repo.get(user_id) except NoResultFound: raise NotFoundError("user") @@ -31,18 +37,18 @@ class WorkspaceUsers(object): for user_dict in workspace_user_dicts: try: - user = User.query.filter_by(id=user_dict["id"]).one() + user = self.users_repo.get(user_dict["id"]) except NoResultFound: - default_role = Role.query.filter_by(name="developer").one_or_none() + default_role = self.roles_repo.get("developer") user = User(id=user_dict["id"], atat_role=default_role) try: - role = Role.query.filter_by(name=user_dict["workspace_role"]).one() + role = self.roles_repo.get(user_dict["workspace_role"]) except NoResultFound: raise NotFoundError("role") try: - existing_workspace_role = WorkspaceRole.query.filter( + existing_workspace_role = self.db_session.query(WorkspaceRole).filter( WorkspaceRole.user == user, WorkspaceRole.workspace_id == workspace_id, ).one() diff --git a/tests/domain/test_workspace_users.py b/tests/domain/test_workspace_users.py new file mode 100644 index 00000000..e86662f3 --- /dev/null +++ b/tests/domain/test_workspace_users.py @@ -0,0 +1,43 @@ +import pytest +from uuid import uuid4 + +from atst.domain.workspace_users import WorkspaceUsers +from atst.domain.users import Users + + +@pytest.fixture() +def users_repo(db): + return Users(db) + +@pytest.fixture() +def workspace_users_repo(db): + return WorkspaceUsers(db) + +def test_can_create_new_workspace_user(users_repo, workspace_users_repo): + workspace_id = uuid4() + user = users_repo.create(uuid4(), "developer") + + workspace_user_dicts = [ + {"id": user.id, "workspace_role": "owner"} + ] + + workspace_users = workspace_users_repo.add_many(workspace_id, workspace_user_dicts) + + assert workspace_users[0].user.id == user.id + assert workspace_users[0].user.atat_role.name == "developer" + assert workspace_users[0].workspace_role.role.name == "owner" + + +def test_can_update_existing_workspace_user(users_repo, workspace_users_repo): + workspace_id = uuid4() + user = users_repo.create(uuid4(), "developer") + + workspace_users_repo.add_many(workspace_id, [ + {"id": user.id, "workspace_role": "owner"} + ]) + workspace_users = workspace_users_repo.add_many(workspace_id, [ + {"id": user.id, "workspace_role": "developer"} + ]) + + assert workspace_users[0].user.id == user.id + assert workspace_users[0].workspace_role.role.name == "developer"