diff --git a/atst/forms/task_order.py b/atst/forms/task_order.py index 93268aae..961c45de 100644 --- a/atst/forms/task_order.py +++ b/atst/forms/task_order.py @@ -10,7 +10,7 @@ from wtforms.fields import ( ) from wtforms.fields.html5 import DateField, TelField from wtforms.widgets import ListWidget, CheckboxInput -from wtforms.validators import Required, Length +from wtforms.validators import Length from atst.forms.validators import IsNumber, PhoneNumber, RequiredIf @@ -117,7 +117,11 @@ class OversightForm(CacheableForm): ) ko_dod_id = StringField( translate("forms.task_order.oversight_dod_id_label"), - validators=[RequiredIf("ko_invite"), Length(min=10), IsNumber()], + validators=[ + RequiredIf(lambda form: form._fields.get("ko_invite").data), + Length(min=10), + IsNumber(), + ], ) am_cor = BooleanField(translate("forms.task_order.oversight_am_cor_label")) @@ -128,13 +132,16 @@ class OversightForm(CacheableForm): cor_email = StringField(translate("forms.task_order.oversight_email_label")) cor_phone_number = TelField( translate("forms.task_order.oversight_phone_label"), - validators=[RequiredIf("am_cor", False), PhoneNumber()], + validators=[ + RequiredIf(lambda form: not form._fields.get("am_cor").data), + PhoneNumber(), + ], ) cor_dod_id = StringField( translate("forms.task_order.oversight_dod_id_label"), validators=[ - RequiredIf("am_cor", False), - RequiredIf("cor_invite"), + RequiredIf(lambda form: not form._fields.get("am_cor").data), + RequiredIf(lambda form: form._fields.get("cor_invite").data), Length(min=10), IsNumber(), ], @@ -150,7 +157,11 @@ class OversightForm(CacheableForm): ) so_dod_id = StringField( translate("forms.task_order.oversight_dod_id_label"), - validators=[RequiredIf("so_invite"), Length(min=10), IsNumber()], + validators=[ + RequiredIf(lambda form: form._fields.get("so_invite").data), + Length(min=10), + IsNumber(), + ], ) ko_invite = BooleanField( diff --git a/atst/forms/validators.py b/atst/forms/validators.py index 716225e7..c239c12e 100644 --- a/atst/forms/validators.py +++ b/atst/forms/validators.py @@ -80,32 +80,19 @@ def ListItemsUnique(message=translate("forms.validators.list_items_unique_messag return _list_items_unique -def RequiredIf( - other_field_name, checked=True, message=translate("forms.validators.is_required") -): +def RequiredIf(other_field, message=translate("forms.validators.is_required")): """ A validator which makes a field required only if another field has a truthy value Args: - other_field_name (str): the name of the field we check before - determining if this field is required - checked (bool): the value of other_field_name that we want to check against; - if checked is True, we require the field if other_field_name's field value - is truthy; if checked is False, we require the field if other_field_name's - field value is falsy + other_field_value (function): calling this on form results in + the boolean value of another field that we want to check against; + if it's True, we require the field message (str): an optional message to display if the field is required but hasNone value """ def _required_if(form, field): - other_field = form._fields.get(other_field_name) - if other_field is None: - raise Exception('no field named "%s" in form' % self.other_field_name) - - field_required = ( - bool(other_field.data) if checked else not bool(other_field.data) - ) - - if field_required: + if other_field(form): if field.data is None: raise ValidationError(message) else: diff --git a/tests/conftest.py b/tests/conftest.py index 781fd515..def8bf93 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -103,15 +103,6 @@ def dummy_form(): return DummyForm() -@pytest.fixture -def dummy_form_with_field(): - def set_field(name, value): - data = DummyField(data=value, name=name) - return DummyForm(data=OrderedDict({name: data})) - - return set_field - - @pytest.fixture def dummy_field(): return DummyField() diff --git a/tests/forms/test_validators.py b/tests/forms/test_validators.py index 267ddbc6..fe12cd58 100644 --- a/tests/forms/test_validators.py +++ b/tests/forms/test_validators.py @@ -82,56 +82,18 @@ class TestListItemsUnique: class TestRequiredIf: - def test_RequiredIf_requires_field_if_arg_is_truthy( - self, dummy_form_with_field, dummy_field - ): - form = dummy_form_with_field("arg", True) - validator = RequiredIf("arg") + def test_RequiredIf_requires_field_if_arg_is_truthy(self, dummy_form, dummy_field): + validator = RequiredIf(lambda form: True) dummy_field.data = None with pytest.raises(ValidationError): - validator(form, dummy_field) + validator(dummy_form, dummy_field) def test_RequiredIf_does_not_require_field_if_arg_is_falsy( - self, dummy_form_with_field, dummy_field + self, dummy_form, dummy_field ): - form = dummy_form_with_field("arg", False) - validator = RequiredIf("arg") + validator = RequiredIf(lambda form: False) dummy_field.data = None with pytest.raises(StopValidation): - validator(form, dummy_field) - - def test_RequiredIf_arg_is_None_raises_error(self, dummy_form, dummy_field): - validator = RequiredIf("arg") - dummy_field.data = "some data" - - with pytest.raises(Exception): - validator(dummy_form, dummy_field) - - def test_not_RequiredIf_requires_field_if_arg_is_falsy( - self, dummy_form_with_field, dummy_field - ): - form = dummy_form_with_field("arg", False) - validator = RequiredIf("arg", False) - dummy_field.data = None - - with pytest.raises(ValidationError): - validator(form, dummy_field) - - def test_not_RequiredIf_does_not_require_field_if_arg_is_truthy( - self, dummy_form_with_field, dummy_field - ): - form = dummy_form_with_field("arg", True) - validator = RequiredIf("arg", False) - dummy_field.data = None - - with pytest.raises(StopValidation): - validator(form, dummy_field) - - def test_not_RequiredIf_arg_is_None_raises_error(self, dummy_form, dummy_field): - validator = RequiredIf("arg", False) - dummy_field.data = "some data" - - with pytest.raises(Exception): validator(dummy_form, dummy_field)