diff --git a/atst/forms/forms.py b/atst/forms/forms.py index ce0ff791..44405b56 100644 --- a/atst/forms/forms.py +++ b/atst/forms/forms.py @@ -1,7 +1,8 @@ from flask_wtf import FlaskForm +from flask import current_app, request as http_request -class ValidatedForm(FlaskForm): +class _ValidatedForm(FlaskForm): def perform_extra_validation(self, *args, **kwargs): """Performs any applicable extra validation. Must return True if the form is valid or False otherwise.""" @@ -12,3 +13,11 @@ class ValidatedForm(FlaskForm): _data = super().data _data.pop("csrf_token", None) return _data + + +class ValidatedForm(_ValidatedForm): + def __init__(self, formdata=None, **kwargs): + formdata = formdata or {} + cached_data = current_app.form_cache.from_request(http_request) + cached_data.update(formdata) + super().__init__(cached_data, **kwargs) diff --git a/atst/routes/requests/financial_verification.py b/atst/routes/requests/financial_verification.py index 2f609d31..65c6d8ed 100644 --- a/atst/routes/requests/financial_verification.py +++ b/atst/routes/requests/financial_verification.py @@ -92,14 +92,13 @@ class FinancialVerificationBase(object): class GetFinancialVerificationForm(FinancialVerificationBase): - def __init__(self, user, request, cached_data=None, is_extended=False): + def __init__(self, user, request, is_extended=False): self.user = user self.request = request - self.cached_data = cached_data or {} self.is_extended = is_extended def execute(self): - form = self._get_form(self.request, self.is_extended, formdata=self.cached_data) + form = self._get_form(self.request, self.is_extended) form.reset() return form @@ -194,10 +193,7 @@ def financial_verification(request_id): ) form = GetFinancialVerificationForm( - g.current_user, - request, - is_extended=is_extended, - cached_data=app.form_cache.from_request(http_request), + g.current_user, request, is_extended=is_extended ).execute() return render_template( diff --git a/atst/routes/requests/requests_form.py b/atst/routes/requests/requests_form.py index a2a04e4f..617b0940 100644 --- a/atst/routes/requests/requests_form.py +++ b/atst/routes/requests/requests_form.py @@ -36,10 +36,7 @@ def option_data(): @requests_bp.route("/requests/new/", methods=["GET"]) def requests_form_new(screen): - cached_data = current_app.form_cache.from_request(http_request) - jedi_flow = JEDIRequestFlow( - screen, request=None, current_user=g.current_user, post_data=cached_data - ) + jedi_flow = JEDIRequestFlow(screen, request=None, current_user=g.current_user) return render_template( "requests/screen-%d.html" % int(screen), diff --git a/atst/utils/form_cache.py b/atst/utils/form_cache.py index c7ac6d80..7dec1470 100644 --- a/atst/utils/form_cache.py +++ b/atst/utils/form_cache.py @@ -1,6 +1,6 @@ from hashlib import sha256 import json -from werkzeug.datastructures import ImmutableMultiDict +from werkzeug.datastructures import MultiDict DEFAULT_CACHE_NAME = "formcache" @@ -16,10 +16,11 @@ class FormCache(object): cache_key = http_request.args.get(self.PARAM_NAME) if cache_key: return self.read(cache_key) + return MultiDict() def write(self, formdata, expiry_seconds=3600, key_prefix=DEFAULT_CACHE_NAME): value = json.dumps(formdata) - hash_ = sha256().hexdigest() + hash_ = self._hash() self.redis.setex( name=self._key(key_prefix, hash_), value=value, time=expiry_seconds ) @@ -28,8 +29,12 @@ class FormCache(object): def read(self, formdata_key, key_prefix=DEFAULT_CACHE_NAME): data = self.redis.get(self._key(key_prefix, formdata_key)) dict_data = json.loads(data) if data is not None else {} - return ImmutableMultiDict(dict_data) + return MultiDict(dict_data) @staticmethod def _key(prefix, hash_): return "{}:{}".format(prefix, hash_) + + @staticmethod + def _hash(): + return sha256().hexdigest()