Refactor RequiredIfNot custom validator, add tests
This commit is contained in:
parent
5d4fee9546
commit
ae494d3bb5
@ -10,9 +10,9 @@ from wtforms.fields import (
|
|||||||
)
|
)
|
||||||
from wtforms.fields.html5 import DateField, TelField
|
from wtforms.fields.html5 import DateField, TelField
|
||||||
from wtforms.widgets import ListWidget, CheckboxInput
|
from wtforms.widgets import ListWidget, CheckboxInput
|
||||||
from wtforms.validators import Required, Length, StopValidation
|
from wtforms.validators import Required, Length
|
||||||
|
|
||||||
from atst.forms.validators import IsNumber, PhoneNumber
|
from atst.forms.validators import IsNumber, PhoneNumber, RequiredIfNot
|
||||||
|
|
||||||
from .forms import CacheableForm
|
from .forms import CacheableForm
|
||||||
from .data import (
|
from .data import (
|
||||||
@ -26,24 +26,6 @@ from .data import (
|
|||||||
from atst.utils.localization import translate
|
from atst.utils.localization import translate
|
||||||
|
|
||||||
|
|
||||||
class RequiredIfNot(Required):
|
|
||||||
# a validator which makes a field required only if
|
|
||||||
# another field has a falsy value
|
|
||||||
|
|
||||||
def __init__(self, other_field_name, *args, **kwargs):
|
|
||||||
self.other_field_name = other_field_name
|
|
||||||
super(RequiredIfNot, self).__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
def __call__(self, form, field):
|
|
||||||
other_field = form._fields.get(self.other_field_name)
|
|
||||||
if other_field is None:
|
|
||||||
raise Exception('no field named "%s" in form' % self.other_field_name)
|
|
||||||
if not bool(other_field.data):
|
|
||||||
super(RequiredIfNot, self).__call__(form, field)
|
|
||||||
else:
|
|
||||||
raise StopValidation()
|
|
||||||
|
|
||||||
|
|
||||||
class AppInfoForm(CacheableForm):
|
class AppInfoForm(CacheableForm):
|
||||||
portfolio_name = StringField(
|
portfolio_name = StringField(
|
||||||
translate("forms.task_order.portfolio_name_label"),
|
translate("forms.task_order.portfolio_name_label"),
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import re
|
import re
|
||||||
from wtforms.validators import ValidationError
|
from wtforms.validators import ValidationError, StopValidation
|
||||||
import pendulum
|
import pendulum
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from atst.utils.localization import translate
|
from atst.utils.localization import translate
|
||||||
@ -78,3 +78,18 @@ def ListItemsUnique(message=translate("forms.validators.list_items_unique_messag
|
|||||||
raise ValidationError(message)
|
raise ValidationError(message)
|
||||||
|
|
||||||
return _list_items_unique
|
return _list_items_unique
|
||||||
|
|
||||||
|
|
||||||
|
def RequiredIfNot(other_field_name, message=translate("forms.validators.is_required")):
|
||||||
|
def _required_if_not(form, field):
|
||||||
|
other_field = form._fields.get(other_field_name)
|
||||||
|
if other_field is None:
|
||||||
|
raise Exception('no field named "%s" in form' % self.other_field_name)
|
||||||
|
|
||||||
|
if not bool(other_field.data):
|
||||||
|
if field.data is None:
|
||||||
|
raise ValidationError(message)
|
||||||
|
else:
|
||||||
|
raise StopValidation()
|
||||||
|
|
||||||
|
return _required_if_not
|
||||||
|
16141
package-lock.json
generated
Normal file
16141
package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
@ -6,6 +6,7 @@ import alembic.command
|
|||||||
from logging.config import dictConfig
|
from logging.config import dictConfig
|
||||||
from werkzeug.datastructures import FileStorage
|
from werkzeug.datastructures import FileStorage
|
||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
from atst.app import make_app, make_config
|
from atst.app import make_app, make_config
|
||||||
from atst.database import db as _db
|
from atst.database import db as _db
|
||||||
@ -84,14 +85,17 @@ def session(db, request):
|
|||||||
|
|
||||||
|
|
||||||
class DummyForm(dict):
|
class DummyForm(dict):
|
||||||
pass
|
def __init__(self, data=OrderedDict(), errors=(), raw_data=None):
|
||||||
|
self._fields = data
|
||||||
|
self.errors = list(errors)
|
||||||
|
|
||||||
|
|
||||||
class DummyField(object):
|
class DummyField(object):
|
||||||
def __init__(self, data=None, errors=(), raw_data=None):
|
def __init__(self, data=None, errors=(), raw_data=None, name=None):
|
||||||
self.data = data
|
self.data = data
|
||||||
self.errors = list(errors)
|
self.errors = list(errors)
|
||||||
self.raw_data = raw_data
|
self.raw_data = raw_data
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -99,6 +103,15 @@ def dummy_form():
|
|||||||
return DummyForm()
|
return DummyForm()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def dummy_form_with_field():
|
||||||
|
def set_field(name, value):
|
||||||
|
data = DummyField(data=value, name=name)
|
||||||
|
return DummyForm(data=OrderedDict({name: data}))
|
||||||
|
|
||||||
|
return set_field
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def dummy_field():
|
def dummy_field():
|
||||||
return DummyField()
|
return DummyField()
|
||||||
|
@ -1,7 +1,13 @@
|
|||||||
from wtforms.validators import ValidationError
|
from wtforms.validators import ValidationError, StopValidation
|
||||||
import pytest
|
import pytest, copy
|
||||||
|
|
||||||
from atst.forms.validators import Name, IsNumber, PhoneNumber, ListItemsUnique
|
from atst.forms.validators import (
|
||||||
|
Name,
|
||||||
|
IsNumber,
|
||||||
|
PhoneNumber,
|
||||||
|
ListItemsUnique,
|
||||||
|
RequiredIfNot,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestIsNumber:
|
class TestIsNumber:
|
||||||
@ -73,3 +79,32 @@ class TestListItemsUnique:
|
|||||||
dummy_field.data = invalid
|
dummy_field.data = invalid
|
||||||
with pytest.raises(ValidationError):
|
with pytest.raises(ValidationError):
|
||||||
validator(dummy_form, dummy_field)
|
validator(dummy_form, dummy_field)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRequiredIfNot:
|
||||||
|
def test_RequiredIfNot_requires_field_if_arg_is_falsy(
|
||||||
|
self, dummy_form_with_field, dummy_field
|
||||||
|
):
|
||||||
|
form = dummy_form_with_field("arg", False)
|
||||||
|
validator = RequiredIfNot("arg")
|
||||||
|
dummy_field.data = None
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
validator(form, dummy_field)
|
||||||
|
|
||||||
|
def test_RequiredIfNot_does_not_require_field_if_arg_is_truthy(
|
||||||
|
self, dummy_form_with_field, dummy_field
|
||||||
|
):
|
||||||
|
form = dummy_form_with_field("arg", True)
|
||||||
|
validator = RequiredIfNot("arg")
|
||||||
|
dummy_field.data = None
|
||||||
|
|
||||||
|
with pytest.raises(StopValidation):
|
||||||
|
validator(form, dummy_field)
|
||||||
|
|
||||||
|
def test_RequiredIfNot_arg_is_None_raises_error(self, dummy_form, dummy_field):
|
||||||
|
validator = RequiredIfNot("arg")
|
||||||
|
dummy_field.data = "some data"
|
||||||
|
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
validator(dummy_form, dummy_field)
|
||||||
|
@ -209,6 +209,7 @@ forms:
|
|||||||
list_items_unique_message: Items must be unique
|
list_items_unique_message: Items must be unique
|
||||||
name_message: 'This field accepts letters, numbers, commas, apostrophes, hyphens, and periods.'
|
name_message: 'This field accepts letters, numbers, commas, apostrophes, hyphens, and periods.'
|
||||||
phone_number_message: Please enter a valid 5 or 10 digit phone number.
|
phone_number_message: Please enter a valid 5 or 10 digit phone number.
|
||||||
|
is_required: This field is required.
|
||||||
portfolio:
|
portfolio:
|
||||||
name_label: Portfolio Name
|
name_label: Portfolio Name
|
||||||
name_length_validation_message: Portfolio names must be at least 4 and not more than 50 characters
|
name_length_validation_message: Portfolio names must be at least 4 and not more than 50 characters
|
||||||
|
Loading…
x
Reference in New Issue
Block a user