Function for claiming multiple resources at once.

Like claim_for_update, the claim_many_for_update claims resources with
an expiring lock. This was written to allow the updating of multiple
application roles with a single cloud_id, since multiple application
roles will map to a single Azure Active Directory user.
This commit is contained in:
dandds 2020-02-01 12:07:36 -05:00
parent 1b45502fe5
commit cc28f53999
3 changed files with 155 additions and 64 deletions

View File

@ -1,3 +1,5 @@
from typing import List
from sqlalchemy import func, sql, Interval, and_, or_ from sqlalchemy import func, sql, Interval, and_, or_
from contextlib import contextmanager from contextlib import contextmanager
@ -28,7 +30,7 @@ def claim_for_update(resource, minutes=30):
.filter( .filter(
and_( and_(
Model.id == resource.id, Model.id == resource.id,
or_(Model.claimed_until == None, Model.claimed_until <= func.now()), or_(Model.claimed_until.is_(None), Model.claimed_until <= func.now()),
) )
) )
.update({"claimed_until": claim_until}, synchronize_session="fetch") .update({"claimed_until": claim_until}, synchronize_session="fetch")
@ -48,3 +50,51 @@ def claim_for_update(resource, minutes=30):
Model.claimed_until != None Model.claimed_until != None
).update({"claimed_until": None}, synchronize_session="fetch") ).update({"claimed_until": None}, synchronize_session="fetch")
db.session.commit() db.session.commit()
@contextmanager
def claim_many_for_update(resources: List, minutes=30):
"""
Claim a mutually exclusive expiring hold on a group of resources.
Uses the database as a central source of time in case the server clocks have drifted.
Args:
resources: A list of SQLAlchemy model instances with a `claimed_until` attribute.
minutes: The maximum amount of time, in minutes, to hold the claim.
"""
Model = resources[0].__class__
claim_until = func.now() + func.cast(
sql.functions.concat(minutes, " MINUTES"), Interval
)
ids = tuple(r.id for r in resources)
# Optimistically query for and update the resources in question. If they're
# already claimed, `rows_updated` will be 0 and we can give up.
rows_updated = (
db.session.query(Model)
.filter(
and_(
Model.id.in_(ids),
or_(Model.claimed_until.is_(None), Model.claimed_until <= func.now()),
)
)
.update({"claimed_until": claim_until}, synchronize_session="fetch")
)
if rows_updated < 1:
# TODO: Generalize this exception class so it can take multiple resources
raise ClaimFailedException(resources[0])
# Fetch the claimed resources
claimed = db.session.query(Model).filter(Model.id.in_(ids)).all()
try:
# Give the resource to the caller.
yield claimed
finally:
# Release the claim.
db.session.query(Model).filter(Model.id.in_(ids)).filter(
Model.claimed_until != None
).update({"claimed_until": None}, synchronize_session="fetch")
db.session.commit()

104
tests/models/test_utils.py Normal file
View File

@ -0,0 +1,104 @@
from threading import Thread
from atst.domain.exceptions import ClaimFailedException
from atst.models.utils import claim_for_update, claim_many_for_update
from tests.factories import EnvironmentFactory
def test_claim_for_update(session):
environment = EnvironmentFactory.create()
satisfied_claims = []
exceptions = []
# Two threads race to do work on environment and check out the lock
class FirstThread(Thread):
def run(self):
try:
with claim_for_update(environment) as env:
assert env.claimed_until
satisfied_claims.append("FirstThread")
except ClaimFailedException:
exceptions.append("FirstThread")
class SecondThread(Thread):
def run(self):
try:
with claim_for_update(environment) as env:
assert env.claimed_until
satisfied_claims.append("SecondThread")
except ClaimFailedException:
exceptions.append("SecondThread")
t1 = FirstThread()
t2 = SecondThread()
t1.start()
t2.start()
t1.join()
t2.join()
session.refresh(environment)
assert len(satisfied_claims) == 1
assert len(exceptions) == 1
if satisfied_claims == ["FirstThread"]:
assert exceptions == ["SecondThread"]
else:
assert satisfied_claims == ["SecondThread"]
assert exceptions == ["FirstThread"]
# The claim is released
assert environment.claimed_until is None
def test_claim_many_for_update(session):
environments = [
EnvironmentFactory.create(),
EnvironmentFactory.create(),
]
satisfied_claims = []
exceptions = []
# Two threads race to do work on environment and check out the lock
class FirstThread(Thread):
def run(self):
try:
with claim_many_for_update(environments) as envs:
assert all([e.claimed_until for e in envs])
satisfied_claims.append("FirstThread")
except ClaimFailedException:
exceptions.append("FirstThread")
class SecondThread(Thread):
def run(self):
try:
with claim_many_for_update(environments) as envs:
assert all([e.claimed_until for e in envs])
satisfied_claims.append("SecondThread")
except ClaimFailedException:
exceptions.append("SecondThread")
t1 = FirstThread()
t2 = SecondThread()
t1.start()
t2.start()
t1.join()
t2.join()
for env in environments:
session.refresh(env)
assert len(satisfied_claims) == 1
assert len(exceptions) == 1
if satisfied_claims == ["FirstThread"]:
assert exceptions == ["SecondThread"]
else:
assert satisfied_claims == ["SecondThread"]
assert exceptions == ["FirstThread"]
# The claim is released
# assert environment.claimed_until is None

View File

@ -2,7 +2,6 @@ import pendulum
import pytest import pytest
from uuid import uuid4 from uuid import uuid4
from unittest.mock import Mock from unittest.mock import Mock
from threading import Thread
from atst.domain.csp.cloud import MockCloudProvider from atst.domain.csp.cloud import MockCloudProvider
from atst.domain.portfolios import Portfolios from atst.domain.portfolios import Portfolios
@ -21,8 +20,6 @@ from atst.jobs import (
do_create_application, do_create_application,
do_create_atat_admin_user, do_create_atat_admin_user,
) )
from atst.models.utils import claim_for_update
from atst.domain.exceptions import ClaimFailedException
from tests.factories import ( from tests.factories import (
EnvironmentFactory, EnvironmentFactory,
EnvironmentRoleFactory, EnvironmentRoleFactory,
@ -240,66 +237,6 @@ def test_create_environment_no_dupes(session, celery_app, celery_worker):
assert environment.claimed_until == None assert environment.claimed_until == None
def test_claim_for_update(session):
portfolio = PortfolioFactory.create(
applications=[
{"environments": [{"cloud_id": uuid4().hex, "root_user_info": {}}]}
],
task_orders=[
{
"create_clins": [
{
"start_date": pendulum.now().subtract(days=1),
"end_date": pendulum.now().add(days=1),
}
]
}
],
)
environment = portfolio.applications[0].environments[0]
satisfied_claims = []
exceptions = []
# Two threads race to do work on environment and check out the lock
class FirstThread(Thread):
def run(self):
try:
with claim_for_update(environment):
satisfied_claims.append("FirstThread")
except ClaimFailedException:
exceptions.append("FirstThread")
class SecondThread(Thread):
def run(self):
try:
with claim_for_update(environment):
satisfied_claims.append("SecondThread")
except ClaimFailedException:
exceptions.append("SecondThread")
t1 = FirstThread()
t2 = SecondThread()
t1.start()
t2.start()
t1.join()
t2.join()
session.refresh(environment)
assert len(satisfied_claims) == 1
assert len(exceptions) == 1
if satisfied_claims == ["FirstThread"]:
assert exceptions == ["SecondThread"]
else:
assert satisfied_claims == ["SecondThread"]
assert exceptions == ["FirstThread"]
# The claim is released
assert environment.claimed_until is None
def test_dispatch_provision_user(csp, session, celery_app, celery_worker, monkeypatch): def test_dispatch_provision_user(csp, session, celery_app, celery_worker, monkeypatch):
# Given that I have four environment roles: # Given that I have four environment roles: