From cc28f539990255bdc0dd24a2dc62c2c8d3485ed7 Mon Sep 17 00:00:00 2001 From: dandds Date: Sat, 1 Feb 2020 12:07:36 -0500 Subject: [PATCH] Function for claiming multiple resources at once. Like claim_for_update, the claim_many_for_update claims resources with an expiring lock. This was written to allow the updating of multiple application roles with a single cloud_id, since multiple application roles will map to a single Azure Active Directory user. --- atst/models/utils.py | 52 ++++++++++++++++++- tests/models/test_utils.py | 104 +++++++++++++++++++++++++++++++++++++ tests/test_jobs.py | 63 ---------------------- 3 files changed, 155 insertions(+), 64 deletions(-) create mode 100644 tests/models/test_utils.py diff --git a/atst/models/utils.py b/atst/models/utils.py index 6059d33e..7fba3206 100644 --- a/atst/models/utils.py +++ b/atst/models/utils.py @@ -1,3 +1,5 @@ +from typing import List + from sqlalchemy import func, sql, Interval, and_, or_ from contextlib import contextmanager @@ -28,7 +30,7 @@ def claim_for_update(resource, minutes=30): .filter( and_( Model.id == resource.id, - or_(Model.claimed_until == None, Model.claimed_until <= func.now()), + or_(Model.claimed_until.is_(None), Model.claimed_until <= func.now()), ) ) .update({"claimed_until": claim_until}, synchronize_session="fetch") @@ -48,3 +50,51 @@ def claim_for_update(resource, minutes=30): Model.claimed_until != None ).update({"claimed_until": None}, synchronize_session="fetch") db.session.commit() + + +@contextmanager +def claim_many_for_update(resources: List, minutes=30): + """ + Claim a mutually exclusive expiring hold on a group of resources. + Uses the database as a central source of time in case the server clocks have drifted. + + Args: + resources: A list of SQLAlchemy model instances with a `claimed_until` attribute. + minutes: The maximum amount of time, in minutes, to hold the claim. + """ + Model = resources[0].__class__ + + claim_until = func.now() + func.cast( + sql.functions.concat(minutes, " MINUTES"), Interval + ) + + ids = tuple(r.id for r in resources) + + # Optimistically query for and update the resources in question. If they're + # already claimed, `rows_updated` will be 0 and we can give up. + rows_updated = ( + db.session.query(Model) + .filter( + and_( + Model.id.in_(ids), + or_(Model.claimed_until.is_(None), Model.claimed_until <= func.now()), + ) + ) + .update({"claimed_until": claim_until}, synchronize_session="fetch") + ) + if rows_updated < 1: + # TODO: Generalize this exception class so it can take multiple resources + raise ClaimFailedException(resources[0]) + + # Fetch the claimed resources + claimed = db.session.query(Model).filter(Model.id.in_(ids)).all() + + try: + # Give the resource to the caller. + yield claimed + finally: + # Release the claim. + db.session.query(Model).filter(Model.id.in_(ids)).filter( + Model.claimed_until != None + ).update({"claimed_until": None}, synchronize_session="fetch") + db.session.commit() diff --git a/tests/models/test_utils.py b/tests/models/test_utils.py new file mode 100644 index 00000000..d73de13a --- /dev/null +++ b/tests/models/test_utils.py @@ -0,0 +1,104 @@ +from threading import Thread + +from atst.domain.exceptions import ClaimFailedException +from atst.models.utils import claim_for_update, claim_many_for_update + +from tests.factories import EnvironmentFactory + + +def test_claim_for_update(session): + environment = EnvironmentFactory.create() + + satisfied_claims = [] + exceptions = [] + + # Two threads race to do work on environment and check out the lock + class FirstThread(Thread): + def run(self): + try: + with claim_for_update(environment) as env: + assert env.claimed_until + satisfied_claims.append("FirstThread") + except ClaimFailedException: + exceptions.append("FirstThread") + + class SecondThread(Thread): + def run(self): + try: + with claim_for_update(environment) as env: + assert env.claimed_until + satisfied_claims.append("SecondThread") + except ClaimFailedException: + exceptions.append("SecondThread") + + t1 = FirstThread() + t2 = SecondThread() + t1.start() + t2.start() + t1.join() + t2.join() + + session.refresh(environment) + + assert len(satisfied_claims) == 1 + assert len(exceptions) == 1 + + if satisfied_claims == ["FirstThread"]: + assert exceptions == ["SecondThread"] + else: + assert satisfied_claims == ["SecondThread"] + assert exceptions == ["FirstThread"] + + # The claim is released + assert environment.claimed_until is None + + +def test_claim_many_for_update(session): + environments = [ + EnvironmentFactory.create(), + EnvironmentFactory.create(), + ] + + satisfied_claims = [] + exceptions = [] + + # Two threads race to do work on environment and check out the lock + class FirstThread(Thread): + def run(self): + try: + with claim_many_for_update(environments) as envs: + assert all([e.claimed_until for e in envs]) + satisfied_claims.append("FirstThread") + except ClaimFailedException: + exceptions.append("FirstThread") + + class SecondThread(Thread): + def run(self): + try: + with claim_many_for_update(environments) as envs: + assert all([e.claimed_until for e in envs]) + satisfied_claims.append("SecondThread") + except ClaimFailedException: + exceptions.append("SecondThread") + + t1 = FirstThread() + t2 = SecondThread() + t1.start() + t2.start() + t1.join() + t2.join() + + for env in environments: + session.refresh(env) + + assert len(satisfied_claims) == 1 + assert len(exceptions) == 1 + + if satisfied_claims == ["FirstThread"]: + assert exceptions == ["SecondThread"] + else: + assert satisfied_claims == ["SecondThread"] + assert exceptions == ["FirstThread"] + + # The claim is released + # assert environment.claimed_until is None diff --git a/tests/test_jobs.py b/tests/test_jobs.py index 2ac5f408..827aeb67 100644 --- a/tests/test_jobs.py +++ b/tests/test_jobs.py @@ -2,7 +2,6 @@ import pendulum import pytest from uuid import uuid4 from unittest.mock import Mock -from threading import Thread from atst.domain.csp.cloud import MockCloudProvider from atst.domain.portfolios import Portfolios @@ -21,8 +20,6 @@ from atst.jobs import ( do_create_application, do_create_atat_admin_user, ) -from atst.models.utils import claim_for_update -from atst.domain.exceptions import ClaimFailedException from tests.factories import ( EnvironmentFactory, EnvironmentRoleFactory, @@ -240,66 +237,6 @@ def test_create_environment_no_dupes(session, celery_app, celery_worker): assert environment.claimed_until == None -def test_claim_for_update(session): - portfolio = PortfolioFactory.create( - applications=[ - {"environments": [{"cloud_id": uuid4().hex, "root_user_info": {}}]} - ], - task_orders=[ - { - "create_clins": [ - { - "start_date": pendulum.now().subtract(days=1), - "end_date": pendulum.now().add(days=1), - } - ] - } - ], - ) - environment = portfolio.applications[0].environments[0] - - satisfied_claims = [] - exceptions = [] - - # Two threads race to do work on environment and check out the lock - class FirstThread(Thread): - def run(self): - try: - with claim_for_update(environment): - satisfied_claims.append("FirstThread") - except ClaimFailedException: - exceptions.append("FirstThread") - - class SecondThread(Thread): - def run(self): - try: - with claim_for_update(environment): - satisfied_claims.append("SecondThread") - except ClaimFailedException: - exceptions.append("SecondThread") - - t1 = FirstThread() - t2 = SecondThread() - t1.start() - t2.start() - t1.join() - t2.join() - - session.refresh(environment) - - assert len(satisfied_claims) == 1 - assert len(exceptions) == 1 - - if satisfied_claims == ["FirstThread"]: - assert exceptions == ["SecondThread"] - else: - assert satisfied_claims == ["SecondThread"] - assert exceptions == ["FirstThread"] - - # The claim is released - assert environment.claimed_until is None - - def test_dispatch_provision_user(csp, session, celery_app, celery_worker, monkeypatch): # Given that I have four environment roles: