diff --git a/alembic/versions/04fe150da553_add_request_revision.py b/alembic/versions/04fe150da553_add_request_revision.py index 2a3cb040..574d3e7d 100644 --- a/alembic/versions/04fe150da553_add_request_revision.py +++ b/alembic/versions/04fe150da553_add_request_revision.py @@ -30,20 +30,20 @@ def upgrade(): sa.Column('lname_poc', sa.String(), nullable=True), sa.Column('jedi_usage', sa.String(), nullable=True), sa.Column('start_date', sa.Date(), nullable=True), - sa.Column('cloud_native', sa.Boolean(), nullable=True), + sa.Column('cloud_native', sa.String(), nullable=True), sa.Column('dollar_value', sa.Integer(), nullable=True), sa.Column('dod_component', sa.String(), nullable=True), sa.Column('data_transfers', sa.String(), nullable=True), sa.Column('expected_completion_date', sa.String(), nullable=True), - sa.Column('jedi_migration', sa.Boolean(), nullable=True), + sa.Column('jedi_migration', sa.String(), nullable=True), sa.Column('num_software_systems', sa.Integer(), nullable=True), sa.Column('number_user_sessions', sa.Integer(), nullable=True), sa.Column('average_daily_traffic', sa.Integer(), nullable=True), - sa.Column('engineering_assessment', sa.Boolean(), nullable=True), - sa.Column('technical_support_team', sa.Boolean(), nullable=True), + sa.Column('engineering_assessment', sa.String(), nullable=True), + sa.Column('technical_support_team', sa.String(), nullable=True), sa.Column('estimated_monthly_spend', sa.Integer(), nullable=True), sa.Column('average_daily_traffic_gb', sa.Integer(), nullable=True), - sa.Column('rationalization_software_systems', sa.Boolean(), nullable=True), + sa.Column('rationalization_software_systems', sa.String(), nullable=True), sa.Column('organization_providing_assistance', sa.String(), nullable=True), sa.Column('citizenship', sa.String(), nullable=True), sa.Column('designation', sa.String(), nullable=True), diff --git a/alembic/versions/a903ebe91ad5_add_sequence_to_request_revision.py b/alembic/versions/a903ebe91ad5_add_sequence_to_request_revision.py new file mode 100644 index 00000000..14348c64 --- /dev/null +++ b/alembic/versions/a903ebe91ad5_add_sequence_to_request_revision.py @@ -0,0 +1,28 @@ +"""add sequence to request revision + +Revision ID: a903ebe91ad5 +Revises: 04fe150da553 +Create Date: 2018-08-30 13:45:35.561657 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'a903ebe91ad5' +down_revision = '04fe150da553' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + db = op.get_bind() + op.add_column('request_revisions', sa.Column('sequence', sa.BigInteger(), nullable=False)) + db.execute("CREATE SEQUENCE request_revisions_sequence_seq OWNED BY request_revisions.sequence;") + # ### end Alembic commands ### + + +def downgrade(): + op.drop_column('request_revisions', 'sequence') diff --git a/atst/domain/requests.py b/atst/domain/requests.py index d44be200..468878a4 100644 --- a/atst/domain/requests.py +++ b/atst/domain/requests.py @@ -4,33 +4,30 @@ from sqlalchemy.sql import text from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.orm.attributes import flag_modified from werkzeug.datastructures import FileStorage +import pendulum from atst.database import db from atst.domain.authz import Authorization from atst.domain.task_orders import TaskOrders from atst.domain.workspaces import Workspaces from atst.models.request import Request +from atst.models.request_revision import RequestRevision from atst.models.request_status_event import RequestStatusEvent, RequestStatus +from atst.utils import deep_merge from .exceptions import NotFoundError, UnauthorizedError -def deep_merge(source, destination: dict): - """ - Merge source dict into destination dict recursively. - """ - - def _deep_merge(a, b): - for key, value in a.items(): - if isinstance(value, dict): - node = b.setdefault(key, {}) - _deep_merge(value, node) - else: - b[key] = value - - return b - - return _deep_merge(source, dict(destination)) +def create_revision_from_request_body(body): + body = {k: v for p in body.values() for k, v in p.items()} + TIMESTAMPS = ["start_date", "date_latest_training"] + coerced_timestamps = { + k: pendulum.parse(v) + for k, v in body.items() + if k in TIMESTAMPS and isinstance(v, str) + } + body = {**body, **coerced_timestamps} + return RequestRevision(**body) class Requests(object): @@ -39,7 +36,8 @@ class Requests(object): @classmethod def create(cls, creator, body): - request = Request(creator=creator, body=body) + revision = create_revision_from_request_body(body) + request = Request(creator=creator, revisions=[revision]) request = Requests.set_status(request, RequestStatus.STARTED) db.session.add(request) @@ -105,7 +103,10 @@ class Requests(object): @classmethod def update(cls, request_id, request_delta): request = Requests._get_with_lock(request_id) - request = Requests._merge_body(request, request_delta) + + new_body = deep_merge(request_delta, request.body) + revision = create_revision_from_request_body(new_body) + request.revisions.append(revision) db.session.add(request) db.session.commit() @@ -129,13 +130,7 @@ class Requests(object): @classmethod def _merge_body(cls, request, request_delta): - request.body = deep_merge(request_delta, request.body) - - # Without this, sqlalchemy won't notice the change to request.body, - # since it doesn't track dictionary mutations by default. - flag_modified(request, "body") - - return request + return deep_merge(request_delta, request.body) @classmethod def approve_and_create_workspace(cls, request): @@ -264,12 +259,7 @@ WHERE requests_with_status.status = :status if task_order: request.task_order = task_order - request = Requests._merge_body( - request, {"financial_verification": request_data} - ) - - db.session.add(request) - db.session.commit() + request = Requests.update(request.id, {"financial_verification": request_data}) return request diff --git a/atst/forms/fields.py b/atst/forms/fields.py index 60eb7f26..27db5742 100644 --- a/atst/forms/fields.py +++ b/atst/forms/fields.py @@ -16,7 +16,7 @@ class DateField(DateField): if values: self.data = values[0] else: - self.data = [] + self.data = None class NewlineListField(Field): diff --git a/atst/models/request.py b/atst/models/request.py index 68875e08..51825ebf 100644 --- a/atst/models/request.py +++ b/atst/models/request.py @@ -7,7 +7,23 @@ import pendulum from atst.models import Base from atst.models.types import Id from atst.models.request_status_event import RequestStatus -from atst.utils import first_or_none +from atst.utils import deep_merge, first_or_none + + +def map_properties_to_dict(properties, instance): + return { + field: getattr(instance, field) + for field in properties + if getattr(instance, field) is not None + } + + +def update_dict_with_properties(instance, body, top_level_key, properties): + new_properties = map_properties_to_dict(properties, instance) + if new_properties: + body[top_level_key] = new_properties + + return body class Request(Base): @@ -28,7 +44,77 @@ class Request(Base): task_order_id = Column(ForeignKey("task_order.id")) task_order = relationship("TaskOrder") - revisions = relationship("RequestRevision", back_populates="request") + revisions = relationship( + "RequestRevision", back_populates="request", order_by="RequestRevision.sequence" + ) + + @property + def latest_revision(self): + if self.revisions: + return self.revisions[-1] + + else: + return RequestRevision(request=self) + + PRIMARY_POC_FIELDS = ["am_poc", "dodid_poc", "email_poc", "fname_poc", "lname_poc"] + DETAILS_OF_USE_FIELDS = [ + "jedi_usage", + "start_date", + "cloud_native", + "dollar_value", + "dod_component", + "data_transfers", + "expected_completion_date", + "jedi_migration", + "num_software_systems", + "number_user_sessions", + "average_daily_traffic", + "engineering_assessment", + "technical_support_team", + "estimated_monthly_spend", + "average_daily_traffic_gb", + "rationalization_software_systems", + "organization_providing_assistance", + ] + INFORMATION_ABOUT_YOU_FIELDS = [ + "citizenship", + "designation", + "phone_number", + "email_request", + "fname_request", + "lname_request", + "service_branch", + "date_latest_training", + ] + FINANCIAL_VERIFICATION_FIELDS = [ + "pe_id", + "task_order_number", + "fname_co", + "lname_co", + "email_co", + "office_co", + "fname_cor", + "lname_cor", + "email_cor", + "office_cor", + "uii_ids", + "treasury_code", + "ba_code", + ] + + @property + def body(self): + current = self.latest_revision + body = {} + for top_level_key, properties in [ + ("primary_poc", Request.PRIMARY_POC_FIELDS), + ("details_of_use", Request.DETAILS_OF_USE_FIELDS), + ("information_about_you", Request.INFORMATION_ABOUT_YOU_FIELDS), + ("financial_verification", Request.FINANCIAL_VERIFICATION_FIELDS), + ]: + body = update_dict_with_properties(current, body, top_level_key, properties) + + return body @property def status(self): @@ -40,7 +126,7 @@ class Request(Base): @property def annual_spend(self): - monthly = self.body.get("details_of_use", {}).get("estimated_monthly_spend", 0) + monthly = self.latest_revision.estimated_monthly_spend or 0 return monthly * 12 @property diff --git a/atst/models/request_revision.py b/atst/models/request_revision.py index 138d32fb..193e2ca3 100644 --- a/atst/models/request_revision.py +++ b/atst/models/request_revision.py @@ -1,5 +1,15 @@ import pendulum -from sqlalchemy import Column, func, ForeignKey, String, Boolean, Integer, Date +from sqlalchemy import ( + Column, + func, + ForeignKey, + String, + Boolean, + Integer, + Date, + BigInteger, + Sequence, +) from sqlalchemy.types import DateTime from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import relationship @@ -15,6 +25,9 @@ class RequestRevision(Base, TimestampsMixin): id = Id() request_id = Column(ForeignKey("requests.id"), nullable=False) request = relationship("Request", back_populates="revisions") + sequence = Column( + BigInteger, Sequence("request_revisions_sequence_seq"), nullable=False + ) # primary_poc am_poc = Column(Boolean, default=False) @@ -26,20 +39,20 @@ class RequestRevision(Base, TimestampsMixin): # details_of_use jedi_usage = Column(String) start_date = Column(Date()) - cloud_native = Column(Boolean) + cloud_native = Column(String) dollar_value = Column(Integer) dod_component = Column(String) data_transfers = Column(String) expected_completion_date = Column(String) - jedi_migration = Column(Boolean) + jedi_migration = Column(String) num_software_systems = Column(Integer) number_user_sessions = Column(Integer) average_daily_traffic = Column(Integer) - engineering_assessment = Column(Boolean) - technical_support_team = Column(Boolean) + engineering_assessment = Column(String) + technical_support_team = Column(String) estimated_monthly_spend = Column(Integer) average_daily_traffic_gb = Column(Integer) - rationalization_software_systems = Column(Boolean) + rationalization_software_systems = Column(String) organization_providing_assistance = Column(String) # information_about_you @@ -66,13 +79,3 @@ class RequestRevision(Base, TimestampsMixin): uii_ids = Column(String) treasury_code = Column(String) ba_code = Column(String) - - _BOOLS = ["am_poc", "jedi_migration", "engineering_assessment", "technical_support_team", "rationalization_software_systems", "cloud_native"] - _TIMESTAMPS = ["start_date", "date_latest_training"] - - @classmethod - def create_from_request_body(cls, request, **body): - coerced_bools = {k: v == "yes" for k,v in body.items() if k in RequestRevision._BOOLS} - coerced_timestamps = {k: pendulum.parse(v) for k,v in body.items() if k in RequestRevision._TIMESTAMPS} - body = {**body, **coerced_bools, **coerced_timestamps} - return RequestRevision(request=request, **body) diff --git a/atst/utils.py b/atst/utils.py index db4933f4..3923a341 100644 --- a/atst/utils.py +++ b/atst/utils.py @@ -1,2 +1,20 @@ def first_or_none(predicate, lst): return next((x for x in lst if predicate(x)), None) + + +def deep_merge(source, destination: dict): + """ + Merge source dict into destination dict recursively. + """ + + def _deep_merge(a, b): + for key, value in a.items(): + if isinstance(value, dict): + node = b.setdefault(key, {}) + _deep_merge(value, node) + else: + b[key] = value + + return b + + return _deep_merge(source, dict(destination)) diff --git a/tests/domain/test_requests.py b/tests/domain/test_requests.py index dfa6018a..d1ab45e8 100644 --- a/tests/domain/test_requests.py +++ b/tests/domain/test_requests.py @@ -12,6 +12,7 @@ from tests.factories import ( UserFactory, RequestStatusEventFactory, TaskOrderFactory, + RequestRevisionFactory, ) @@ -20,10 +21,11 @@ def new_request(session): return RequestFactory.create() -def test_can_get_request(new_request): - request = Requests.get(new_request.creator, new_request.id) +def test_can_get_request(): + factory_req = RequestFactory.create() + request = Requests.get(factory_req.creator, factory_req.id) - assert request.id == new_request.id + assert request.id == factory_req.id def test_nonexistent_request_raises(): @@ -37,28 +39,30 @@ def test_new_request_has_started_status(): assert request.status == RequestStatus.STARTED -def test_auto_approve_less_than_1m(new_request): - new_request.body = {"details_of_use": {"dollar_value": 999999}} +def test_auto_approve_less_than_1m(): + new_request = RequestFactory.create(initial_revision={"dollar_value": 999999}) request = Requests.submit(new_request) assert request.status == RequestStatus.PENDING_FINANCIAL_VERIFICATION -def test_dont_auto_approve_if_dollar_value_is_1m_or_above(new_request): - new_request.body = {"details_of_use": {"dollar_value": 1000000}} +def test_dont_auto_approve_if_dollar_value_is_1m_or_above(): + new_request = RequestFactory.create(initial_revision={"dollar_value": 1000000}) request = Requests.submit(new_request) assert request.status == RequestStatus.PENDING_CCPO_APPROVAL -def test_dont_auto_approve_if_no_dollar_value_specified(new_request): - new_request.body = {"details_of_use": {}} +def test_dont_auto_approve_if_no_dollar_value_specified(): + new_request = RequestFactory.create(initial_revision={}) request = Requests.submit(new_request) assert request.status == RequestStatus.PENDING_CCPO_APPROVAL -def test_should_allow_submission(new_request): +def test_should_allow_submission(): + new_request = RequestFactory.create() + assert Requests.should_allow_submission(new_request) RequestStatusEventFactory.create( @@ -66,7 +70,8 @@ def test_should_allow_submission(new_request): ) assert Requests.should_allow_submission(new_request) - del new_request.body["details_of_use"] + # new, blank revision + RequestRevisionFactory.create(request=new_request) assert not Requests.should_allow_submission(new_request) diff --git a/tests/factories.py b/tests/factories.py index 51be428b..4a90649b 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -2,9 +2,11 @@ import random import string import factory from uuid import uuid4 +import datetime from atst.forms.data import SERVICE_BRANCHES from atst.models.request import Request +from atst.models.request_revision import RequestRevision from atst.models.request_status_event import RequestStatusEvent, RequestStatus from atst.models.pe_number import PENumber from atst.models.task_order import TaskOrder @@ -42,6 +44,13 @@ class RequestStatusEventFactory(factory.alchemy.SQLAlchemyModelFactory): sequence = 1 +class RequestRevisionFactory(factory.alchemy.SQLAlchemyModelFactory): + class Meta: + model = RequestRevision + + id = factory.Sequence(lambda x: uuid4()) + + class RequestFactory(factory.alchemy.SQLAlchemyModelFactory): class Meta: model = Request @@ -51,48 +60,58 @@ class RequestFactory(factory.alchemy.SQLAlchemyModelFactory): RequestStatusEventFactory, "request", new_status=RequestStatus.STARTED ) creator = factory.SubFactory(UserFactory) - body = factory.LazyAttribute(lambda r: RequestFactory.build_request_body(r.creator)) + revisions = factory.LazyAttribute( + lambda r: [RequestFactory.create_initial_revision(r)] + ) + + class Params: + initial_revision = None @classmethod - def build_request_body(cls, user, dollar_value=1000000): - return { - "primary_poc": { - "am_poc": False, - "dodid_poc": user.dod_id, - "email_poc": user.email, - "fname_poc": user.first_name, - "lname_poc": user.last_name, - }, - "details_of_use": { - "jedi_usage": "adf", - "start_date": "2018-08-08", - "cloud_native": "yes", - "dollar_value": dollar_value, - "dod_component": SERVICE_BRANCHES[2][1], - "data_transfers": "Less than 100GB", - "expected_completion_date": "Less than 1 month", - "jedi_migration": "yes", - "num_software_systems": 1, - "number_user_sessions": 2, - "average_daily_traffic": 1, - "engineering_assessment": "yes", - "technical_support_team": "yes", - "estimated_monthly_spend": 100, - "average_daily_traffic_gb": 4, - "rationalization_software_systems": "yes", - "organization_providing_assistance": "In-house staff", - }, - "information_about_you": { - "citizenship": "United States", - "designation": "military", - "phone_number": "1234567890", - "email_request": user.email, - "fname_request": user.first_name, - "lname_request": user.last_name, - "service_branch": SERVICE_BRANCHES[1][1], - "date_latest_training": "2018-08-06", - }, - } + def create_initial_revision(cls, request, dollar_value=1000000): + user = request.creator + default_data = dict( + am_poc=False, + dodid_poc=user.dod_id, + email_poc=user.email, + fname_poc=user.first_name, + lname_poc=user.last_name, + jedi_usage="adf", + start_date=datetime.datetime(2018, 8, 8, tzinfo=datetime.timezone.utc), + cloud_native="yes", + dollar_value=dollar_value, + dod_component=SERVICE_BRANCHES[2][1], + data_transfers="Less than 100GB", + expected_completion_date="Less than 1 month", + jedi_migration="yes", + num_software_systems=1, + number_user_sessions=2, + average_daily_traffic=1, + engineering_assessment="yes", + technical_support_team="yes", + estimated_monthly_spend=100, + average_daily_traffic_gb=4, + rationalization_software_systems="yes", + organization_providing_assistance="In-house staff", + citizenship="United States", + designation="military", + phone_number="1234567890", + email_request=user.email, + fname_request=user.first_name, + lname_request=user.last_name, + service_branch=SERVICE_BRANCHES[1][1], + date_latest_training=datetime.datetime( + 2018, 8, 6, tzinfo=datetime.timezone.utc + ), + ) + + data = ( + request.initial_revision + if request.initial_revision is not None + else default_data + ) + + return RequestRevisionFactory.build(**data) class PENumberFactory(factory.alchemy.SQLAlchemyModelFactory): diff --git a/tests/mocks.py b/tests/mocks.py index f374b792..e61d80c1 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -2,9 +2,7 @@ from tests.factories import RequestFactory, UserFactory MOCK_USER = UserFactory.build() -MOCK_REQUEST = RequestFactory.build( - creator=MOCK_USER.id, body={"financial_verification": {"pe_id": "0203752A"}} -) +MOCK_REQUEST = RequestFactory.build(creator=MOCK_USER) DOD_SDN_INFO = {"first_name": "ART", "last_name": "GARFUNKEL", "dod_id": "5892460358"} DOD_SDN = f"CN={DOD_SDN_INFO['last_name']}.{DOD_SDN_INFO['first_name']}.G.{DOD_SDN_INFO['dod_id']},OU=OTHER,OU=PKI,OU=DoD,O=U.S. Government,C=US" diff --git a/tests/routes/test_financial_verification.py b/tests/routes/test_financial_verification.py index af1f0a09..527f4ffc 100644 --- a/tests/routes/test_financial_verification.py +++ b/tests/routes/test_financial_verification.py @@ -35,7 +35,7 @@ class TestPENumberInForm: return user def submit_data(self, client, user, data, extended=False): - request = RequestFactory.create(creator=user, body=MOCK_REQUEST.body) + request = RequestFactory.create(creator=user) url_kwargs = {"request_id": request.id} if extended: url_kwargs["extended"] = True @@ -58,7 +58,7 @@ class TestPENumberInForm: user = self._set_monkeypatches(monkeypatch) data = dict(self.required_data) - data["pe_id"] = MOCK_REQUEST.body["financial_verification"]["pe_id"] + data["pe_id"] = "0101110F" response = self.submit_data(client, user, data) @@ -95,7 +95,7 @@ class TestPENumberInForm: user_session(user) data = dict(self.required_data) - data["pe_id"] = MOCK_REQUEST.body["financial_verification"]["pe_id"] + data["pe_id"] = "0101110F" data["task_order_number"] = "1234" response = self.submit_data(client, user, data) @@ -112,7 +112,7 @@ class TestPENumberInForm: user_session(user) data = dict(self.required_data) - data["pe_id"] = MOCK_REQUEST.body["financial_verification"]["pe_id"] + data["pe_id"] = "0101110F" data["task_order_number"] = MockEDAClient.MOCK_CONTRACT_NUMBER response = self.submit_data(client, user, data) diff --git a/tests/routes/test_request_new.py b/tests/routes/test_request_new.py index 1b41eef5..594ea114 100644 --- a/tests/routes/test_request_new.py +++ b/tests/routes/test_request_new.py @@ -1,5 +1,5 @@ import re -from tests.factories import RequestFactory, UserFactory +from tests.factories import RequestFactory, UserFactory, RequestRevisionFactory from atst.domain.roles import Roles from atst.domain.requests import Requests from urllib.parse import urlencode @@ -75,10 +75,12 @@ def test_nonexistent_request(client, user_session): assert response.status_code == 404 -def test_creator_info_is_autopopulated(monkeypatch, client, user_session): +def test_creator_info_is_autopopulated_for_existing_request( + monkeypatch, client, user_session +): user = UserFactory.create() user_session(user) - request = RequestFactory.create(creator=user, body={"information_about_you": {}}) + request = RequestFactory.create(creator=user, initial_revision={}) response = client.get("/requests/new/2/{}".format(request.id)) body = response.data.decode() @@ -104,7 +106,7 @@ def test_non_creator_info_is_not_autopopulated(monkeypatch, client, user_session user = UserFactory.create() creator = UserFactory.create() user_session(user) - request = RequestFactory.create(creator=creator, body={"information_about_you": {}}) + request = RequestFactory.create(creator=creator, initial_revision={}) response = client.get("/requests/new/2/{}".format(request.id)) body = response.data.decode() @@ -116,7 +118,7 @@ def test_non_creator_info_is_not_autopopulated(monkeypatch, client, user_session def test_am_poc_causes_poc_to_be_autopopulated(client, user_session): creator = UserFactory.create() user_session(creator) - request = RequestFactory.create(creator=creator, body={}) + request = RequestFactory.create(creator=creator, initial_revision={}) client.post( "/requests/new/3/{}".format(request.id), headers={"Content-Type": "application/x-www-form-urlencoded"}, @@ -129,7 +131,7 @@ def test_am_poc_causes_poc_to_be_autopopulated(client, user_session): def test_not_am_poc_requires_poc_info_to_be_completed(client, user_session): creator = UserFactory.create() user_session(creator) - request = RequestFactory.create(creator=creator, body={}) + request = RequestFactory.create(creator=creator, initial_revision={}) response = client.post( "/requests/new/3/{}".format(request.id), headers={"Content-Type": "application/x-www-form-urlencoded"}, @@ -142,7 +144,7 @@ def test_not_am_poc_requires_poc_info_to_be_completed(client, user_session): def test_not_am_poc_allows_user_to_fill_in_poc_info(client, user_session): creator = UserFactory.create() user_session(creator) - request = RequestFactory.create(creator=creator, body={}) + request = RequestFactory.create(creator=creator, initial_revision={}) poc_input = { "am_poc": "no", "fname_poc": "test", @@ -177,13 +179,11 @@ def test_poc_autofill_checks_information_about_you_form_first(client, user_sessi user_session(creator) request = RequestFactory.create( creator=creator, - body={ - "information_about_you": { - "fname_request": "Alice", - "lname_request": "Adams", - "email_request": "alice.adams@mail.mil", - } - }, + initial_revision=dict( + fname_request="Alice", + lname_request="Adams", + email_request="alice.adams@mail.mil", + ), ) poc_input = {"am_poc": "yes"} client.post( diff --git a/tests/test_integration.py b/tests/test_integration.py index 8b674f9b..65236100 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -15,7 +15,8 @@ def screens(app): def test_stepthrough_request_form(user_session, screens, client): user = UserFactory.create() user_session(user) - mock_request = RequestFactory.stub() + mock_request = RequestFactory.create() + mock_body = mock_request.body def post_form(url, redirects=False, data=""): return client.post( @@ -33,6 +34,7 @@ def test_stepthrough_request_form(user_session, screens, client): # destination url prelim_resp = post_form(req_url, data=data) response = post_form(req_url, True, data=data) + assert prelim_resp.status_code == 302 return (prelim_resp.headers.get("Location"), response) # GET the initial form @@ -44,7 +46,7 @@ def test_stepthrough_request_form(user_session, screens, client): for i in range(1, len(screens)): # get appropriate form data to POST for this section section = screens[i - 1]["section"] - post_data = urlencode(mock_request.body[section]) + post_data = urlencode(mock_body[section]) effective_url, resp = take_a_step(i, req=req_id, data=post_data) req_id = effective_url.split("/")[-1] @@ -55,7 +57,7 @@ def test_stepthrough_request_form(user_session, screens, client): # at this point, the real request we made and the mock_request bodies # should be equivalent - assert Requests.get(user, req_id).body == mock_request.body + assert Requests.get(user, req_id).body == mock_body # finish the review and submit step client.post(