diff --git a/atst/app.py b/atst/app.py index 591e5c65..b309806a 100644 --- a/atst/app.py +++ b/atst/app.py @@ -24,6 +24,7 @@ from atst.models.permissions import Permissions from atst.eda_client import MockEDAClient from atst.uploader import Uploader from atst.utils import mailer +from atst.utils.form_cache import FormCache from atst.queue import queue @@ -67,6 +68,8 @@ def make_app(config): if ENV != "prod": app.register_blueprint(dev_routes) + app.form_cache = FormCache(app.redis) + apply_authentication(app) return app diff --git a/atst/forms/ccpo_review.py b/atst/forms/ccpo_review.py index b2413a3e..b727cea0 100644 --- a/atst/forms/ccpo_review.py +++ b/atst/forms/ccpo_review.py @@ -2,11 +2,11 @@ from wtforms.fields.html5 import EmailField, TelField from wtforms.fields import StringField, TextAreaField from wtforms.validators import Email, Optional -from .forms import ValidatedForm +from .forms import CacheableForm from .validators import Name, PhoneNumber -class CCPOReviewForm(ValidatedForm): +class CCPOReviewForm(CacheableForm): comment = TextAreaField( "Instructions or comments", description="Provide instructions or notes for additional information that is necessary to approve the request here. The requestor may then re-submit the updated request or initiate contact outside of AT-AT if further discussion is required. This message will be shared with the person making the JEDI request..", diff --git a/atst/forms/edit_user.py b/atst/forms/edit_user.py index 3728bf84..6d406942 100644 --- a/atst/forms/edit_user.py +++ b/atst/forms/edit_user.py @@ -5,7 +5,7 @@ from wtforms.fields import RadioField, StringField from wtforms.validators import Email, DataRequired, Optional from .fields import SelectField -from .forms import ValidatedForm +from .forms import CacheableForm from .data import SERVICE_BRANCHES from atst.models.user import User @@ -77,7 +77,7 @@ def inherit_user_field(field_name): return inherit_field(USER_FIELDS[field_name], required=required) -class EditUserForm(ValidatedForm): +class EditUserForm(CacheableForm): first_name = inherit_user_field("first_name") last_name = inherit_user_field("last_name") diff --git a/atst/forms/financial.py b/atst/forms/financial.py index cfc86a20..7ee17cfa 100644 --- a/atst/forms/financial.py +++ b/atst/forms/financial.py @@ -7,7 +7,7 @@ from flask_wtf.file import FileAllowed from werkzeug.datastructures import FileStorage from .fields import NewlineListField, SelectField, NumberStringField -from atst.forms.forms import ValidatedForm +from atst.forms.forms import CacheableForm from .data import FUNDING_TYPES from .validators import DateRange @@ -31,7 +31,7 @@ def coerce_choice(val): return val.value -class TaskOrderForm(ValidatedForm): +class TaskOrderForm(CacheableForm): def do_validate_number(self): for field in self: if field.name != "task_order-number": @@ -127,7 +127,7 @@ class TaskOrderForm(ValidatedForm): ) -class RequestFinancialVerificationForm(ValidatedForm): +class RequestFinancialVerificationForm(CacheableForm): uii_ids = NewlineListField( "Unique Item Identifier (UII)s related to your application(s) if you already have them.", description="If you have more than one UII, place each one on a new line.", @@ -174,7 +174,7 @@ class RequestFinancialVerificationForm(ValidatedForm): self.uii_ids.process_data(self.uii_ids.data) -class FinancialVerificationForm(ValidatedForm): +class FinancialVerificationForm(CacheableForm): task_order = FormField(TaskOrderForm) request = FormField(RequestFinancialVerificationForm) diff --git a/atst/forms/forms.py b/atst/forms/forms.py index ce0ff791..eeeb48d2 100644 --- a/atst/forms/forms.py +++ b/atst/forms/forms.py @@ -1,4 +1,5 @@ from flask_wtf import FlaskForm +from flask import current_app, request as http_request class ValidatedForm(FlaskForm): @@ -12,3 +13,11 @@ class ValidatedForm(FlaskForm): _data = super().data _data.pop("csrf_token", None) return _data + + +class CacheableForm(ValidatedForm): + def __init__(self, formdata=None, **kwargs): + formdata = formdata or {} + cached_data = current_app.form_cache.from_request(http_request) + cached_data.update(formdata) + super().__init__(cached_data, **kwargs) diff --git a/atst/forms/internal_comment.py b/atst/forms/internal_comment.py index 583db7a1..7711ff04 100644 --- a/atst/forms/internal_comment.py +++ b/atst/forms/internal_comment.py @@ -1,10 +1,10 @@ from wtforms.fields import TextAreaField from wtforms.validators import InputRequired -from .forms import ValidatedForm +from .forms import CacheableForm -class InternalCommentForm(ValidatedForm): +class InternalCommentForm(CacheableForm): text = TextAreaField( "CCPO Internal Notes", default="", diff --git a/atst/forms/new_request.py b/atst/forms/new_request.py index 4b564c37..bc05f412 100644 --- a/atst/forms/new_request.py +++ b/atst/forms/new_request.py @@ -4,7 +4,7 @@ from wtforms.fields import BooleanField, RadioField, StringField, TextAreaField from wtforms.validators import Email, Length, Optional, InputRequired, DataRequired from .fields import SelectField -from .forms import ValidatedForm +from .forms import CacheableForm from .edit_user import USER_FIELDS, inherit_field from .data import ( SERVICE_BRANCHES, @@ -16,7 +16,7 @@ from .validators import DateRange, IsNumber from atst.domain.requests import Requests -class DetailsOfUseForm(ValidatedForm): +class DetailsOfUseForm(CacheableForm): def validate(self, *args, **kwargs): if self.jedi_migration.data == "no": self.rationalization_software_systems.validators.append(Optional()) @@ -162,7 +162,7 @@ class DetailsOfUseForm(ValidatedForm): ) -class InformationAboutYouForm(ValidatedForm): +class InformationAboutYouForm(CacheableForm): fname_request = inherit_field(USER_FIELDS["first_name"]) lname_request = inherit_field(USER_FIELDS["last_name"]) email_request = inherit_field(USER_FIELDS["email"]) @@ -174,7 +174,7 @@ class InformationAboutYouForm(ValidatedForm): date_latest_training = inherit_field(USER_FIELDS["date_latest_training"]) -class WorkspaceOwnerForm(ValidatedForm): +class WorkspaceOwnerForm(CacheableForm): def validate(self, *args, **kwargs): if self.am_poc.data: # Prepend Optional validators so that the validation chain @@ -203,5 +203,5 @@ class WorkspaceOwnerForm(ValidatedForm): ) -class ReviewAndSubmitForm(ValidatedForm): +class ReviewAndSubmitForm(CacheableForm): reviewed = BooleanField("I have reviewed this data and it is correct.") diff --git a/atst/forms/workspace.py b/atst/forms/workspace.py index 9fde98bb..5434676f 100644 --- a/atst/forms/workspace.py +++ b/atst/forms/workspace.py @@ -1,10 +1,10 @@ from wtforms.fields import StringField from wtforms.validators import Length -from .forms import ValidatedForm +from .forms import CacheableForm -class WorkspaceForm(ValidatedForm): +class WorkspaceForm(CacheableForm): name = StringField( "Workspace Name", validators=[ diff --git a/atst/routes/__init__.py b/atst/routes/__init__.py index 9fd690a9..93f04601 100644 --- a/atst/routes/__init__.py +++ b/atst/routes/__init__.py @@ -96,7 +96,12 @@ def _make_authentication_context(): def redirect_after_login_url(): if request.args.get("next"): - return request.args.get("next") + returl = request.args.get("next") + if request.args.get(app.form_cache.PARAM_NAME): + returl += "?" + url.urlencode( + {app.form_cache.PARAM_NAME: request.args.get(app.form_cache.PARAM_NAME)} + ) + return returl else: return url_for("atst.home") diff --git a/atst/routes/errors.py b/atst/routes/errors.py index 7cf58b46..3f32b110 100644 --- a/atst/routes/errors.py +++ b/atst/routes/errors.py @@ -37,7 +37,10 @@ def make_error_pages(app): # pylint: disable=unused-variable def session_expired(e): log_error(e) - return redirect(url_for("atst.root", sessionExpired=True, next=request.path)) + url_args = {"sessionExpired": True, "next": request.path} + if request.method == "POST": + url_args[app.form_cache.PARAM_NAME] = app.form_cache.write(request.form) + return redirect(url_for("atst.root", **url_args)) @app.errorhandler(Exception) # pylint: disable=unused-variable diff --git a/atst/routes/requests/financial_verification.py b/atst/routes/requests/financial_verification.py index 320db9c6..8f2369ac 100644 --- a/atst/routes/requests/financial_verification.py +++ b/atst/routes/requests/financial_verification.py @@ -99,6 +99,7 @@ class GetFinancialVerificationForm(FinancialVerificationBase): def execute(self): form = self._get_form(self.request, self.is_extended) + form.reset() return form @@ -178,6 +179,7 @@ class SaveFinancialVerificationDraft(FinancialVerificationBase): return updated_request +@requests_bp.route("/requests/verify//draft", methods=["GET"]) @requests_bp.route("/requests/verify/", methods=["GET"]) def financial_verification(request_id): request = Requests.get(g.current_user, request_id) diff --git a/atst/utils/form_cache.py b/atst/utils/form_cache.py new file mode 100644 index 00000000..7dec1470 --- /dev/null +++ b/atst/utils/form_cache.py @@ -0,0 +1,40 @@ +from hashlib import sha256 +import json +from werkzeug.datastructures import MultiDict + + +DEFAULT_CACHE_NAME = "formcache" + + +class FormCache(object): + PARAM_NAME = "formCache" + + def __init__(self, redis): + self.redis = redis + + def from_request(self, http_request): + cache_key = http_request.args.get(self.PARAM_NAME) + if cache_key: + return self.read(cache_key) + return MultiDict() + + def write(self, formdata, expiry_seconds=3600, key_prefix=DEFAULT_CACHE_NAME): + value = json.dumps(formdata) + hash_ = self._hash() + self.redis.setex( + name=self._key(key_prefix, hash_), value=value, time=expiry_seconds + ) + return hash_ + + def read(self, formdata_key, key_prefix=DEFAULT_CACHE_NAME): + data = self.redis.get(self._key(key_prefix, formdata_key)) + dict_data = json.loads(data) if data is not None else {} + return MultiDict(dict_data) + + @staticmethod + def _key(prefix, hash_): + return "{}:{}".format(prefix, hash_) + + @staticmethod + def _hash(): + return sha256().hexdigest() diff --git a/tests/utils/test_form_cache.py b/tests/utils/test_form_cache.py new file mode 100644 index 00000000..c399acd1 --- /dev/null +++ b/tests/utils/test_form_cache.py @@ -0,0 +1,22 @@ +import pytest +from werkzeug.datastructures import ImmutableMultiDict + +from atst.utils.form_cache import DEFAULT_CACHE_NAME, FormCache + + +@pytest.fixture +def form_cache(app): + return FormCache(app.redis) + + +def test_cache_form_data(app, form_cache): + data = ImmutableMultiDict({"kessel_run": "12 parsecs"}) + key = form_cache.write(data) + assert app.redis.get("{}:{}".format(DEFAULT_CACHE_NAME, key)) + + +def test_retrieve_form_data(form_cache): + data = ImmutableMultiDict({"class": "corellian"}) + key = form_cache.write(data) + retrieved = form_cache.read(key) + assert retrieved == data