Refactor RequiredIf validator

This commit is contained in:
Montana 2019-01-20 09:47:50 -05:00
parent 9eca3c6acc
commit d51663e075
4 changed files with 27 additions and 76 deletions

View File

@ -10,7 +10,7 @@ from wtforms.fields import (
) )
from wtforms.fields.html5 import DateField, TelField from wtforms.fields.html5 import DateField, TelField
from wtforms.widgets import ListWidget, CheckboxInput from wtforms.widgets import ListWidget, CheckboxInput
from wtforms.validators import Required, Length from wtforms.validators import Length
from atst.forms.validators import IsNumber, PhoneNumber, RequiredIf from atst.forms.validators import IsNumber, PhoneNumber, RequiredIf
@ -117,7 +117,11 @@ class OversightForm(CacheableForm):
) )
ko_dod_id = StringField( ko_dod_id = StringField(
translate("forms.task_order.oversight_dod_id_label"), translate("forms.task_order.oversight_dod_id_label"),
validators=[RequiredIf("ko_invite"), Length(min=10), IsNumber()], validators=[
RequiredIf(lambda form: form._fields.get("ko_invite").data),
Length(min=10),
IsNumber(),
],
) )
am_cor = BooleanField(translate("forms.task_order.oversight_am_cor_label")) am_cor = BooleanField(translate("forms.task_order.oversight_am_cor_label"))
@ -128,13 +132,16 @@ class OversightForm(CacheableForm):
cor_email = StringField(translate("forms.task_order.oversight_email_label")) cor_email = StringField(translate("forms.task_order.oversight_email_label"))
cor_phone_number = TelField( cor_phone_number = TelField(
translate("forms.task_order.oversight_phone_label"), translate("forms.task_order.oversight_phone_label"),
validators=[RequiredIf("am_cor", False), PhoneNumber()], validators=[
RequiredIf(lambda form: not form._fields.get("am_cor").data),
PhoneNumber(),
],
) )
cor_dod_id = StringField( cor_dod_id = StringField(
translate("forms.task_order.oversight_dod_id_label"), translate("forms.task_order.oversight_dod_id_label"),
validators=[ validators=[
RequiredIf("am_cor", False), RequiredIf(lambda form: not form._fields.get("am_cor").data),
RequiredIf("cor_invite"), RequiredIf(lambda form: form._fields.get("cor_invite").data),
Length(min=10), Length(min=10),
IsNumber(), IsNumber(),
], ],
@ -150,7 +157,11 @@ class OversightForm(CacheableForm):
) )
so_dod_id = StringField( so_dod_id = StringField(
translate("forms.task_order.oversight_dod_id_label"), translate("forms.task_order.oversight_dod_id_label"),
validators=[RequiredIf("so_invite"), Length(min=10), IsNumber()], validators=[
RequiredIf(lambda form: form._fields.get("so_invite").data),
Length(min=10),
IsNumber(),
],
) )
ko_invite = BooleanField( ko_invite = BooleanField(

View File

@ -80,32 +80,19 @@ def ListItemsUnique(message=translate("forms.validators.list_items_unique_messag
return _list_items_unique return _list_items_unique
def RequiredIf( def RequiredIf(other_field, message=translate("forms.validators.is_required")):
other_field_name, checked=True, message=translate("forms.validators.is_required")
):
""" A validator which makes a field required only if another field """ A validator which makes a field required only if another field
has a truthy value has a truthy value
Args: Args:
other_field_name (str): the name of the field we check before other_field_value (function): calling this on form results in
determining if this field is required the boolean value of another field that we want to check against;
checked (bool): the value of other_field_name that we want to check against; if it's True, we require the field
if checked is True, we require the field if other_field_name's field value
is truthy; if checked is False, we require the field if other_field_name's
field value is falsy
message (str): an optional message to display if the field is message (str): an optional message to display if the field is
required but hasNone value required but hasNone value
""" """
def _required_if(form, field): def _required_if(form, field):
other_field = form._fields.get(other_field_name) if other_field(form):
if other_field is None:
raise Exception('no field named "%s" in form' % self.other_field_name)
field_required = (
bool(other_field.data) if checked else not bool(other_field.data)
)
if field_required:
if field.data is None: if field.data is None:
raise ValidationError(message) raise ValidationError(message)
else: else:

View File

@ -103,15 +103,6 @@ def dummy_form():
return DummyForm() return DummyForm()
@pytest.fixture
def dummy_form_with_field():
def set_field(name, value):
data = DummyField(data=value, name=name)
return DummyForm(data=OrderedDict({name: data}))
return set_field
@pytest.fixture @pytest.fixture
def dummy_field(): def dummy_field():
return DummyField() return DummyField()

View File

@ -82,56 +82,18 @@ class TestListItemsUnique:
class TestRequiredIf: class TestRequiredIf:
def test_RequiredIf_requires_field_if_arg_is_truthy( def test_RequiredIf_requires_field_if_arg_is_truthy(self, dummy_form, dummy_field):
self, dummy_form_with_field, dummy_field validator = RequiredIf(lambda form: True)
):
form = dummy_form_with_field("arg", True)
validator = RequiredIf("arg")
dummy_field.data = None dummy_field.data = None
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
validator(form, dummy_field) validator(dummy_form, dummy_field)
def test_RequiredIf_does_not_require_field_if_arg_is_falsy( def test_RequiredIf_does_not_require_field_if_arg_is_falsy(
self, dummy_form_with_field, dummy_field self, dummy_form, dummy_field
): ):
form = dummy_form_with_field("arg", False) validator = RequiredIf(lambda form: False)
validator = RequiredIf("arg")
dummy_field.data = None dummy_field.data = None
with pytest.raises(StopValidation): with pytest.raises(StopValidation):
validator(form, dummy_field)
def test_RequiredIf_arg_is_None_raises_error(self, dummy_form, dummy_field):
validator = RequiredIf("arg")
dummy_field.data = "some data"
with pytest.raises(Exception):
validator(dummy_form, dummy_field)
def test_not_RequiredIf_requires_field_if_arg_is_falsy(
self, dummy_form_with_field, dummy_field
):
form = dummy_form_with_field("arg", False)
validator = RequiredIf("arg", False)
dummy_field.data = None
with pytest.raises(ValidationError):
validator(form, dummy_field)
def test_not_RequiredIf_does_not_require_field_if_arg_is_truthy(
self, dummy_form_with_field, dummy_field
):
form = dummy_form_with_field("arg", True)
validator = RequiredIf("arg", False)
dummy_field.data = None
with pytest.raises(StopValidation):
validator(form, dummy_field)
def test_not_RequiredIf_arg_is_None_raises_error(self, dummy_form, dummy_field):
validator = RequiredIf("arg", False)
dummy_field.data = "some data"
with pytest.raises(Exception):
validator(dummy_form, dummy_field) validator(dummy_form, dummy_field)