diff --git a/atst/routes/__init__.py b/atst/routes/__init__.py index 780ff0c3..0ea12952 100644 --- a/atst/routes/__init__.py +++ b/atst/routes/__init__.py @@ -96,7 +96,10 @@ def _make_authentication_context(): def redirect_after_login_url(): if request.args.get("next"): - return request.args.get("next") + returl = request.args.get("next") + if request.args.get("formCache"): + returl += "?" + url.urlencode({"formCache": request.args.get("formCache")}) + return returl else: return url_for("atst.home") diff --git a/atst/routes/errors.py b/atst/routes/errors.py index 7cf58b46..56551ee0 100644 --- a/atst/routes/errors.py +++ b/atst/routes/errors.py @@ -8,6 +8,7 @@ from atst.domain.invitations import ( ExpiredError as InvitationExpiredError, WrongUserError as InvitationWrongUserError, ) +from atst.utils.form_cache import cache_form_data def log_error(e): @@ -37,7 +38,10 @@ def make_error_pages(app): # pylint: disable=unused-variable def session_expired(e): log_error(e) - return redirect(url_for("atst.root", sessionExpired=True, next=request.path)) + url_args = {"sessionExpired": True, "next": request.path} + if request.method == "POST": + url_args["formCache"] = cache_form_data(app.redis, request.form) + return redirect(url_for("atst.root", **url_args)) @app.errorhandler(Exception) # pylint: disable=unused-variable diff --git a/atst/routes/requests/financial_verification.py b/atst/routes/requests/financial_verification.py index 320db9c6..58940be0 100644 --- a/atst/routes/requests/financial_verification.py +++ b/atst/routes/requests/financial_verification.py @@ -1,4 +1,4 @@ -from flask import g, render_template, redirect, url_for +from flask import g, render_template, redirect, url_for, current_app as app from flask import request as http_request from werkzeug.datastructures import ImmutableMultiDict, FileStorage @@ -13,6 +13,7 @@ from atst.domain.requests.financial_verification import ( ) from atst.models.attachment import Attachment from atst.domain.task_orders import TaskOrders +from atst.utils.form_cache import retrieve_form_data def fv_extended(_http_request): @@ -91,6 +92,12 @@ class FinancialVerificationBase(object): raise FormValidationError(form) +def existing_form_data(): + key = http_request.args.get("formCache") + if key: + return retrieve_form_data(app.redis, key) + + class GetFinancialVerificationForm(FinancialVerificationBase): def __init__(self, user, request, is_extended=False): self.user = user @@ -98,7 +105,7 @@ class GetFinancialVerificationForm(FinancialVerificationBase): self.is_extended = is_extended def execute(self): - form = self._get_form(self.request, self.is_extended) + form = self._get_form(self.request, self.is_extended, formdata=existing_form_data()) return form