Refactor RequiredIfNot custom validator, add tests

This commit is contained in:
Montana 2019-01-15 16:44:31 -05:00
parent 5d4fee9546
commit ae494d3bb5
6 changed files with 16213 additions and 26 deletions

View File

@ -10,9 +10,9 @@ from wtforms.fields import (
) )
from wtforms.fields.html5 import DateField, TelField from wtforms.fields.html5 import DateField, TelField
from wtforms.widgets import ListWidget, CheckboxInput from wtforms.widgets import ListWidget, CheckboxInput
from wtforms.validators import Required, Length, StopValidation from wtforms.validators import Required, Length
from atst.forms.validators import IsNumber, PhoneNumber from atst.forms.validators import IsNumber, PhoneNumber, RequiredIfNot
from .forms import CacheableForm from .forms import CacheableForm
from .data import ( from .data import (
@ -26,24 +26,6 @@ from .data import (
from atst.utils.localization import translate from atst.utils.localization import translate
class RequiredIfNot(Required):
# a validator which makes a field required only if
# another field has a falsy value
def __init__(self, other_field_name, *args, **kwargs):
self.other_field_name = other_field_name
super(RequiredIfNot, self).__init__(*args, **kwargs)
def __call__(self, form, field):
other_field = form._fields.get(self.other_field_name)
if other_field is None:
raise Exception('no field named "%s" in form' % self.other_field_name)
if not bool(other_field.data):
super(RequiredIfNot, self).__call__(form, field)
else:
raise StopValidation()
class AppInfoForm(CacheableForm): class AppInfoForm(CacheableForm):
portfolio_name = StringField( portfolio_name = StringField(
translate("forms.task_order.portfolio_name_label"), translate("forms.task_order.portfolio_name_label"),

View File

@ -1,5 +1,5 @@
import re import re
from wtforms.validators import ValidationError from wtforms.validators import ValidationError, StopValidation
import pendulum import pendulum
from datetime import datetime from datetime import datetime
from atst.utils.localization import translate from atst.utils.localization import translate
@ -78,3 +78,18 @@ def ListItemsUnique(message=translate("forms.validators.list_items_unique_messag
raise ValidationError(message) raise ValidationError(message)
return _list_items_unique return _list_items_unique
def RequiredIfNot(other_field_name, message=translate("forms.validators.is_required")):
def _required_if_not(form, field):
other_field = form._fields.get(other_field_name)
if other_field is None:
raise Exception('no field named "%s" in form' % self.other_field_name)
if not bool(other_field.data):
if field.data is None:
raise ValidationError(message)
else:
raise StopValidation()
return _required_if_not

16141
package-lock.json generated Normal file

File diff suppressed because it is too large Load Diff

View File

@ -6,6 +6,7 @@ import alembic.command
from logging.config import dictConfig from logging.config import dictConfig
from werkzeug.datastructures import FileStorage from werkzeug.datastructures import FileStorage
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from collections import OrderedDict
from atst.app import make_app, make_config from atst.app import make_app, make_config
from atst.database import db as _db from atst.database import db as _db
@ -84,14 +85,17 @@ def session(db, request):
class DummyForm(dict): class DummyForm(dict):
pass def __init__(self, data=OrderedDict(), errors=(), raw_data=None):
self._fields = data
self.errors = list(errors)
class DummyField(object): class DummyField(object):
def __init__(self, data=None, errors=(), raw_data=None): def __init__(self, data=None, errors=(), raw_data=None, name=None):
self.data = data self.data = data
self.errors = list(errors) self.errors = list(errors)
self.raw_data = raw_data self.raw_data = raw_data
self.name = name
@pytest.fixture @pytest.fixture
@ -99,6 +103,15 @@ def dummy_form():
return DummyForm() return DummyForm()
@pytest.fixture
def dummy_form_with_field():
def set_field(name, value):
data = DummyField(data=value, name=name)
return DummyForm(data=OrderedDict({name: data}))
return set_field
@pytest.fixture @pytest.fixture
def dummy_field(): def dummy_field():
return DummyField() return DummyField()

View File

@ -1,7 +1,13 @@
from wtforms.validators import ValidationError from wtforms.validators import ValidationError, StopValidation
import pytest import pytest, copy
from atst.forms.validators import Name, IsNumber, PhoneNumber, ListItemsUnique from atst.forms.validators import (
Name,
IsNumber,
PhoneNumber,
ListItemsUnique,
RequiredIfNot,
)
class TestIsNumber: class TestIsNumber:
@ -73,3 +79,32 @@ class TestListItemsUnique:
dummy_field.data = invalid dummy_field.data = invalid
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
validator(dummy_form, dummy_field) validator(dummy_form, dummy_field)
class TestRequiredIfNot:
def test_RequiredIfNot_requires_field_if_arg_is_falsy(
self, dummy_form_with_field, dummy_field
):
form = dummy_form_with_field("arg", False)
validator = RequiredIfNot("arg")
dummy_field.data = None
with pytest.raises(ValidationError):
validator(form, dummy_field)
def test_RequiredIfNot_does_not_require_field_if_arg_is_truthy(
self, dummy_form_with_field, dummy_field
):
form = dummy_form_with_field("arg", True)
validator = RequiredIfNot("arg")
dummy_field.data = None
with pytest.raises(StopValidation):
validator(form, dummy_field)
def test_RequiredIfNot_arg_is_None_raises_error(self, dummy_form, dummy_field):
validator = RequiredIfNot("arg")
dummy_field.data = "some data"
with pytest.raises(Exception):
validator(dummy_form, dummy_field)

View File

@ -209,6 +209,7 @@ forms:
list_items_unique_message: Items must be unique list_items_unique_message: Items must be unique
name_message: 'This field accepts letters, numbers, commas, apostrophes, hyphens, and periods.' name_message: 'This field accepts letters, numbers, commas, apostrophes, hyphens, and periods.'
phone_number_message: Please enter a valid 5 or 10 digit phone number. phone_number_message: Please enter a valid 5 or 10 digit phone number.
is_required: This field is required.
portfolio: portfolio:
name_label: Portfolio Name name_label: Portfolio Name
name_length_validation_message: Portfolio names must be at least 4 and not more than 50 characters name_length_validation_message: Portfolio names must be at least 4 and not more than 50 characters