diff --git a/atst/routes/task_orders/new.py b/atst/routes/task_orders/new.py index e1f01fe8..e3b24ba4 100644 --- a/atst/routes/task_orders/new.py +++ b/atst/routes/task_orders/new.py @@ -46,10 +46,11 @@ TASK_ORDER_SECTIONS = [ class ShowTaskOrderWorkflow: - def __init__(self, user, screen=1, task_order_id=None): + def __init__(self, user, screen=1, task_order_id=None, portfolio_id=None): self.user = user self.screen = screen self.task_order_id = task_order_id + self.portfolio_id = portfolio_id self._section = TASK_ORDER_SECTIONS[screen - 1] self._task_order = None self._form = None @@ -120,19 +121,19 @@ class ShowTaskOrderWorkflow: else: return False - def pf_attributes_read_only(self, portfolio_id=None): + def pf_attributes_read_only(self): if self.task_order: if self.task_order.portfolio.num_task_orders > 1: return True - elif portfolio_id: - if self.get_portfolio(portfolio_id).num_task_orders > 0: + elif self.portfolio_id: + if self.get_portfolio().num_task_orders > 0: return True return False - def get_portfolio(self, portfolio_id=None): + def get_portfolio(self): if self.task_order: return self.task_order.portfolio - return Portfolios.get(self.user, portfolio_id) + return Portfolios.get(self.user, self.portfolio_id) class UpdateTaskOrderWorkflow(ShowTaskOrderWorkflow): @@ -155,7 +156,7 @@ class UpdateTaskOrderWorkflow(ShowTaskOrderWorkflow): @property def form(self): - if self.pf_attributes_read_only(self.portfolio_id) and self.screen == 1: + if self.pf_attributes_read_only() and self.screen == 1: return task_order_form.AppInfoWithExistingPortfolioForm(self.form_data) return self._form @@ -226,7 +227,9 @@ def get_started(): @task_orders_bp.route("/task_orders/new//") @task_orders_bp.route("/portfolios//task_orders/new/") def new(screen, task_order_id=None, portfolio_id=None): - workflow = ShowTaskOrderWorkflow(g.current_user, screen, task_order_id) + workflow = ShowTaskOrderWorkflow( + g.current_user, screen, task_order_id, portfolio_id + ) template_args = { "current": screen, "task_order_id": task_order_id, @@ -235,8 +238,8 @@ def new(screen, task_order_id=None, portfolio_id=None): "complete": workflow.is_complete, } - if workflow.pf_attributes_read_only(portfolio_id): - template_args["portfolio"] = workflow.get_portfolio(portfolio_id=portfolio_id) + if workflow.pf_attributes_read_only(): + template_args["portfolio"] = workflow.get_portfolio() if screen == 1: workflow.form = task_order_form.AppInfoWithExistingPortfolioForm( obj=workflow.task_order diff --git a/tests/routes/task_orders/test_new_task_order.py b/tests/routes/task_orders/test_new_task_order.py index 0cdbf432..4793be65 100644 --- a/tests/routes/task_orders/test_new_task_order.py +++ b/tests/routes/task_orders/test_new_task_order.py @@ -44,8 +44,8 @@ def serialize_dates(data): def test_new_to_can_edit_pf_attributes_screen_1(): portfolio = PortfolioFactory.create() - workflow = ShowTaskOrderWorkflow(user=portfolio.owner) - assert not workflow.pf_attributes_read_only(portfolio.id) + workflow = ShowTaskOrderWorkflow(user=portfolio.owner, portfolio_id=portfolio.id) + assert not workflow.pf_attributes_read_only() def test_new_pf_can_edit_pf_attributes_on_back_navigation(): @@ -61,9 +61,9 @@ def test_to_on_pf_cannot_edit_pf_attributes(): portfolio = PortfolioFactory.create() pf_task_order = TaskOrderFactory(portfolio=portfolio) - workflow = ShowTaskOrderWorkflow(user=portfolio.owner) + workflow = ShowTaskOrderWorkflow(user=portfolio.owner, portfolio_id=portfolio.id) assert portfolio.num_task_orders == 1 - assert workflow.pf_attributes_read_only(portfolio.id) + assert workflow.pf_attributes_read_only() second_task_order = TaskOrderFactory(portfolio=portfolio) second_workflow = ShowTaskOrderWorkflow( @@ -87,8 +87,8 @@ def test_get_portfolio_when_task_order_exists(): def test_get_portfolio_with_portfolio_id(): user = UserFactory.create() portfolio = PortfolioFactory.create(owner=user) - workflow = ShowTaskOrderWorkflow(user=portfolio.owner) - assert portfolio == workflow.get_portfolio(portfolio_id=portfolio.id) + workflow = ShowTaskOrderWorkflow(user=portfolio.owner, portfolio_id=portfolio.id) + assert portfolio == workflow.get_portfolio() # TODO: this test will need to be more complicated when we add validation to