Make claim_for_update easier to follow

This commit is contained in:
richard-dds
2019-09-16 17:03:57 -04:00
parent 4624acd1c5
commit 67a2905d51

View File

@@ -2,6 +2,7 @@ import pendulum
import pytest import pytest
from uuid import uuid4 from uuid import uuid4
from unittest.mock import Mock from unittest.mock import Mock
from threading import Thread
from atst.models import Environment from atst.models import Environment
from atst.domain.csp.cloud import MockCloudProvider from atst.domain.csp.cloud import MockCloudProvider
@@ -25,8 +26,6 @@ from tests.factories import (
PortfolioFactory, PortfolioFactory,
) )
from threading import Thread
def test_environment_job_failure(celery_app, celery_worker): def test_environment_job_failure(celery_app, celery_worker):
@celery_app.task(bind=True, base=RecordEnvironmentFailure) @celery_app.task(bind=True, base=RecordEnvironmentFailure)
@@ -239,7 +238,7 @@ def test_create_environment_no_dupes(session, celery_app, celery_worker):
assert environment.claimed_at == None assert environment.claimed_at == None
def test_claim(session): def test_claim_for_update(session):
portfolio = PortfolioFactory.create( portfolio = PortfolioFactory.create(
applications=[ applications=[
{ {
@@ -265,19 +264,21 @@ def test_claim(session):
) )
environment = portfolio.applications[0].environments[0] environment = portfolio.applications[0].environments[0]
events = [] satisfied_claims = []
# Two threads that race to acquire a claim on the same environment.
# SecondThread's claim will be rejected, which will result in a ClaimFailedException.
class FirstThread(Thread): class FirstThread(Thread):
def run(self): def run(self):
with claim_for_update(environment): with claim_for_update(environment):
events.append("first") satisfied_claims.append("FirstThread")
class SecondThread(Thread): class SecondThread(Thread):
def run(self): def run(self):
try: try:
with claim_for_update(environment): with claim_for_update(environment):
events.append("second") satisfied_claims.append("SecondThread")
except Exception: except ClaimFailedException:
pass pass
t1 = FirstThread() t1 = FirstThread()
@@ -287,4 +288,4 @@ def test_claim(session):
t1.join() t1.join()
t2.join() t2.join()
assert events == ["first"] assert satisfied_claims == ["FirstThread"]