diff --git a/atst/routes/requests/financial_verification.py b/atst/routes/requests/financial_verification.py index 23304e14..c1fe43bc 100644 --- a/atst/routes/requests/financial_verification.py +++ b/atst/routes/requests/financial_verification.py @@ -142,32 +142,30 @@ class SaveFinancialVerificationDraft(FinancialVerificationBase): valid = True if not form.validate_draft(): - valid = False - - if ( - valid - and form.pe_id.data - and not self.pe_validator.validate(self.request, form.pe_id.data) - ): - self._apply_pe_number_error(form.pe_id) - valid = False - - if ( - valid - and form.task_order_number.data - and not self.task_order_validator.validate(form.task_order_number.data) - ): - self._apply_task_order_number_error(form.task_order_number) - valid = False - - if not valid: form.reset() raise FormValidationError(form) - else: - updated_request = Requests.update_financial_verification( - self.request.id, form.data - ) + + if form.pe_id.data and not self.pe_validator.validate( + self.request, form.pe_id.data + ): + valid = False + self._apply_pe_number_error(form.pe_id) + + if form.task_order_number.data and not self.task_order_validator.validate( + form.task_order_number.data + ): + valid = False + self._apply_task_order_number_error(form.task_order_number) + + updated_request = Requests.update_financial_verification( + self.request.id, form.data + ) + + if valid: return {"request": updated_request} + else: + form.reset() + raise FormValidationError(form) @requests_bp.route("/requests/verify/", methods=["GET"]) diff --git a/tests/routes/test_financial_verification.py b/tests/routes/test_financial_verification.py index 7863b834..c2e22da2 100644 --- a/tests/routes/test_financial_verification.py +++ b/tests/routes/test_financial_verification.py @@ -48,16 +48,6 @@ FalseValidator = MagicMock() FalseValidator.validate = MagicMock(return_value=False) -class MockPEValidator(object): - def validate(self, request, field): - return True - - -class MockTaskOrderValidator(object): - def validate(self, field): - return True - - def test_update(fv_data): request = RequestFactory.create() user = UserFactory.create() @@ -171,3 +161,16 @@ def test_save_draft_with_invalid_pe_number(fv_data): with pytest.raises(FormValidationError): assert save_draft.execute() + + +def test_save_draft_re_enter_pe_number(fv_data): + request = RequestFactory.create() + user = UserFactory.create() + data = {**fv_data, "pe_id": "0101228M"} + save_fv = SaveFinancialVerificationDraft( + PENumberValidator(), TrueValidator, user, request, data, is_extended=False + ) + + with pytest.raises(FormValidationError): + save_fv.execute() + response_context = save_fv.execute()