diff --git a/atst/app.py b/atst/app.py
index 591e5c65..b309806a 100644
--- a/atst/app.py
+++ b/atst/app.py
@@ -24,6 +24,7 @@ from atst.models.permissions import Permissions
from atst.eda_client import MockEDAClient
from atst.uploader import Uploader
from atst.utils import mailer
+from atst.utils.form_cache import FormCache
from atst.queue import queue
@@ -67,6 +68,8 @@ def make_app(config):
if ENV != "prod":
app.register_blueprint(dev_routes)
+ app.form_cache = FormCache(app.redis)
+
apply_authentication(app)
return app
diff --git a/atst/forms/ccpo_review.py b/atst/forms/ccpo_review.py
index b2413a3e..b727cea0 100644
--- a/atst/forms/ccpo_review.py
+++ b/atst/forms/ccpo_review.py
@@ -2,11 +2,11 @@ from wtforms.fields.html5 import EmailField, TelField
from wtforms.fields import StringField, TextAreaField
from wtforms.validators import Email, Optional
-from .forms import ValidatedForm
+from .forms import CacheableForm
from .validators import Name, PhoneNumber
-class CCPOReviewForm(ValidatedForm):
+class CCPOReviewForm(CacheableForm):
comment = TextAreaField(
"Instructions or comments",
description="Provide instructions or notes for additional information that is necessary to approve the request here. The requestor may then re-submit the updated request or initiate contact outside of AT-AT if further discussion is required. This message will be shared with the person making the JEDI request..",
diff --git a/atst/forms/edit_user.py b/atst/forms/edit_user.py
index 3728bf84..6d406942 100644
--- a/atst/forms/edit_user.py
+++ b/atst/forms/edit_user.py
@@ -5,7 +5,7 @@ from wtforms.fields import RadioField, StringField
from wtforms.validators import Email, DataRequired, Optional
from .fields import SelectField
-from .forms import ValidatedForm
+from .forms import CacheableForm
from .data import SERVICE_BRANCHES
from atst.models.user import User
@@ -77,7 +77,7 @@ def inherit_user_field(field_name):
return inherit_field(USER_FIELDS[field_name], required=required)
-class EditUserForm(ValidatedForm):
+class EditUserForm(CacheableForm):
first_name = inherit_user_field("first_name")
last_name = inherit_user_field("last_name")
diff --git a/atst/forms/financial.py b/atst/forms/financial.py
index cfc86a20..7ee17cfa 100644
--- a/atst/forms/financial.py
+++ b/atst/forms/financial.py
@@ -7,7 +7,7 @@ from flask_wtf.file import FileAllowed
from werkzeug.datastructures import FileStorage
from .fields import NewlineListField, SelectField, NumberStringField
-from atst.forms.forms import ValidatedForm
+from atst.forms.forms import CacheableForm
from .data import FUNDING_TYPES
from .validators import DateRange
@@ -31,7 +31,7 @@ def coerce_choice(val):
return val.value
-class TaskOrderForm(ValidatedForm):
+class TaskOrderForm(CacheableForm):
def do_validate_number(self):
for field in self:
if field.name != "task_order-number":
@@ -127,7 +127,7 @@ class TaskOrderForm(ValidatedForm):
)
-class RequestFinancialVerificationForm(ValidatedForm):
+class RequestFinancialVerificationForm(CacheableForm):
uii_ids = NewlineListField(
"Unique Item Identifier (UII)s related to your application(s) if you already have them.",
description="If you have more than one UII, place each one on a new line.",
@@ -174,7 +174,7 @@ class RequestFinancialVerificationForm(ValidatedForm):
self.uii_ids.process_data(self.uii_ids.data)
-class FinancialVerificationForm(ValidatedForm):
+class FinancialVerificationForm(CacheableForm):
task_order = FormField(TaskOrderForm)
request = FormField(RequestFinancialVerificationForm)
diff --git a/atst/forms/forms.py b/atst/forms/forms.py
index ce0ff791..eeeb48d2 100644
--- a/atst/forms/forms.py
+++ b/atst/forms/forms.py
@@ -1,4 +1,5 @@
from flask_wtf import FlaskForm
+from flask import current_app, request as http_request
class ValidatedForm(FlaskForm):
@@ -12,3 +13,11 @@ class ValidatedForm(FlaskForm):
_data = super().data
_data.pop("csrf_token", None)
return _data
+
+
+class CacheableForm(ValidatedForm):
+ def __init__(self, formdata=None, **kwargs):
+ formdata = formdata or {}
+ cached_data = current_app.form_cache.from_request(http_request)
+ cached_data.update(formdata)
+ super().__init__(cached_data, **kwargs)
diff --git a/atst/forms/internal_comment.py b/atst/forms/internal_comment.py
index 583db7a1..7711ff04 100644
--- a/atst/forms/internal_comment.py
+++ b/atst/forms/internal_comment.py
@@ -1,10 +1,10 @@
from wtforms.fields import TextAreaField
from wtforms.validators import InputRequired
-from .forms import ValidatedForm
+from .forms import CacheableForm
-class InternalCommentForm(ValidatedForm):
+class InternalCommentForm(CacheableForm):
text = TextAreaField(
"CCPO Internal Notes",
default="",
diff --git a/atst/forms/new_request.py b/atst/forms/new_request.py
index 4b564c37..bc05f412 100644
--- a/atst/forms/new_request.py
+++ b/atst/forms/new_request.py
@@ -4,7 +4,7 @@ from wtforms.fields import BooleanField, RadioField, StringField, TextAreaField
from wtforms.validators import Email, Length, Optional, InputRequired, DataRequired
from .fields import SelectField
-from .forms import ValidatedForm
+from .forms import CacheableForm
from .edit_user import USER_FIELDS, inherit_field
from .data import (
SERVICE_BRANCHES,
@@ -16,7 +16,7 @@ from .validators import DateRange, IsNumber
from atst.domain.requests import Requests
-class DetailsOfUseForm(ValidatedForm):
+class DetailsOfUseForm(CacheableForm):
def validate(self, *args, **kwargs):
if self.jedi_migration.data == "no":
self.rationalization_software_systems.validators.append(Optional())
@@ -162,7 +162,7 @@ class DetailsOfUseForm(ValidatedForm):
)
-class InformationAboutYouForm(ValidatedForm):
+class InformationAboutYouForm(CacheableForm):
fname_request = inherit_field(USER_FIELDS["first_name"])
lname_request = inherit_field(USER_FIELDS["last_name"])
email_request = inherit_field(USER_FIELDS["email"])
@@ -174,7 +174,7 @@ class InformationAboutYouForm(ValidatedForm):
date_latest_training = inherit_field(USER_FIELDS["date_latest_training"])
-class WorkspaceOwnerForm(ValidatedForm):
+class WorkspaceOwnerForm(CacheableForm):
def validate(self, *args, **kwargs):
if self.am_poc.data:
# Prepend Optional validators so that the validation chain
@@ -203,5 +203,5 @@ class WorkspaceOwnerForm(ValidatedForm):
)
-class ReviewAndSubmitForm(ValidatedForm):
+class ReviewAndSubmitForm(CacheableForm):
reviewed = BooleanField("I have reviewed this data and it is correct.")
diff --git a/atst/forms/workspace.py b/atst/forms/workspace.py
index 9fde98bb..5434676f 100644
--- a/atst/forms/workspace.py
+++ b/atst/forms/workspace.py
@@ -1,10 +1,10 @@
from wtforms.fields import StringField
from wtforms.validators import Length
-from .forms import ValidatedForm
+from .forms import CacheableForm
-class WorkspaceForm(ValidatedForm):
+class WorkspaceForm(CacheableForm):
name = StringField(
"Workspace Name",
validators=[
diff --git a/atst/routes/__init__.py b/atst/routes/__init__.py
index 9fd690a9..93f04601 100644
--- a/atst/routes/__init__.py
+++ b/atst/routes/__init__.py
@@ -96,7 +96,12 @@ def _make_authentication_context():
def redirect_after_login_url():
if request.args.get("next"):
- return request.args.get("next")
+ returl = request.args.get("next")
+ if request.args.get(app.form_cache.PARAM_NAME):
+ returl += "?" + url.urlencode(
+ {app.form_cache.PARAM_NAME: request.args.get(app.form_cache.PARAM_NAME)}
+ )
+ return returl
else:
return url_for("atst.home")
diff --git a/atst/routes/errors.py b/atst/routes/errors.py
index 7cf58b46..3f32b110 100644
--- a/atst/routes/errors.py
+++ b/atst/routes/errors.py
@@ -37,7 +37,10 @@ def make_error_pages(app):
# pylint: disable=unused-variable
def session_expired(e):
log_error(e)
- return redirect(url_for("atst.root", sessionExpired=True, next=request.path))
+ url_args = {"sessionExpired": True, "next": request.path}
+ if request.method == "POST":
+ url_args[app.form_cache.PARAM_NAME] = app.form_cache.write(request.form)
+ return redirect(url_for("atst.root", **url_args))
@app.errorhandler(Exception)
# pylint: disable=unused-variable
diff --git a/atst/routes/requests/financial_verification.py b/atst/routes/requests/financial_verification.py
index 320db9c6..8f2369ac 100644
--- a/atst/routes/requests/financial_verification.py
+++ b/atst/routes/requests/financial_verification.py
@@ -99,6 +99,7 @@ class GetFinancialVerificationForm(FinancialVerificationBase):
def execute(self):
form = self._get_form(self.request, self.is_extended)
+ form.reset()
return form
@@ -178,6 +179,7 @@ class SaveFinancialVerificationDraft(FinancialVerificationBase):
return updated_request
+@requests_bp.route("/requests/verify//draft", methods=["GET"])
@requests_bp.route("/requests/verify/", methods=["GET"])
def financial_verification(request_id):
request = Requests.get(g.current_user, request_id)
diff --git a/atst/utils/form_cache.py b/atst/utils/form_cache.py
new file mode 100644
index 00000000..7dec1470
--- /dev/null
+++ b/atst/utils/form_cache.py
@@ -0,0 +1,40 @@
+from hashlib import sha256
+import json
+from werkzeug.datastructures import MultiDict
+
+
+DEFAULT_CACHE_NAME = "formcache"
+
+
+class FormCache(object):
+ PARAM_NAME = "formCache"
+
+ def __init__(self, redis):
+ self.redis = redis
+
+ def from_request(self, http_request):
+ cache_key = http_request.args.get(self.PARAM_NAME)
+ if cache_key:
+ return self.read(cache_key)
+ return MultiDict()
+
+ def write(self, formdata, expiry_seconds=3600, key_prefix=DEFAULT_CACHE_NAME):
+ value = json.dumps(formdata)
+ hash_ = self._hash()
+ self.redis.setex(
+ name=self._key(key_prefix, hash_), value=value, time=expiry_seconds
+ )
+ return hash_
+
+ def read(self, formdata_key, key_prefix=DEFAULT_CACHE_NAME):
+ data = self.redis.get(self._key(key_prefix, formdata_key))
+ dict_data = json.loads(data) if data is not None else {}
+ return MultiDict(dict_data)
+
+ @staticmethod
+ def _key(prefix, hash_):
+ return "{}:{}".format(prefix, hash_)
+
+ @staticmethod
+ def _hash():
+ return sha256().hexdigest()
diff --git a/tests/utils/test_form_cache.py b/tests/utils/test_form_cache.py
new file mode 100644
index 00000000..c399acd1
--- /dev/null
+++ b/tests/utils/test_form_cache.py
@@ -0,0 +1,22 @@
+import pytest
+from werkzeug.datastructures import ImmutableMultiDict
+
+from atst.utils.form_cache import DEFAULT_CACHE_NAME, FormCache
+
+
+@pytest.fixture
+def form_cache(app):
+ return FormCache(app.redis)
+
+
+def test_cache_form_data(app, form_cache):
+ data = ImmutableMultiDict({"kessel_run": "12 parsecs"})
+ key = form_cache.write(data)
+ assert app.redis.get("{}:{}".format(DEFAULT_CACHE_NAME, key))
+
+
+def test_retrieve_form_data(form_cache):
+ data = ImmutableMultiDict({"class": "corellian"})
+ key = form_cache.write(data)
+ retrieved = form_cache.read(key)
+ assert retrieved == data