Fix some LGTM errors and start sketching in credential update functionality

This commit is contained in:
tomdds 2020-01-24 11:12:06 -05:00
parent 910920af44
commit e9d03ec68b
3 changed files with 15 additions and 13 deletions

View File

@ -380,7 +380,7 @@ class CloudProviderInterface:
def set_secret(self, secret_key: str, secret_value: str): def set_secret(self, secret_key: str, secret_value: str):
raise NotImplementedError() raise NotImplementedError()
def get_secret(self, secret_key: str, secret_value: str): def get_secret(self, secret_key: str):
raise NotImplementedError() raise NotImplementedError()
def root_creds(self) -> Dict: def root_creds(self) -> Dict:
@ -520,6 +520,12 @@ class MockCloudProvider(CloudProviderInterface):
def root_creds(self): def root_creds(self):
return self._auth_credentials return self._auth_credentials
def set_secret(self, secret_key: str, secret_value: str):
pass
def get_secret(self, secret_key: str):
return {}
def create_environment(self, auth_credentials, user, environment): def create_environment(self, auth_credentials, user, environment):
self._authorize(auth_credentials) self._authorize(auth_credentials)
@ -846,7 +852,7 @@ class AzureCloudProvider(CloudProviderInterface):
self.policy_manager = AzurePolicyManager(config["AZURE_POLICY_LOCATION"]) self.policy_manager = AzurePolicyManager(config["AZURE_POLICY_LOCATION"])
def set_secret(secret_key, secret_value): def set_secret(self, secret_key, secret_value):
credential = self._get_client_secret_credential_obj() credential = self._get_client_secret_credential_obj()
secret_client = self.secrets.SecretClient( secret_client = self.secrets.SecretClient(
vault_url=self.vault_url, credential=credential, vault_url=self.vault_url, credential=credential,
@ -859,7 +865,7 @@ class AzureCloudProvider(CloudProviderInterface):
exc_info=1, exc_info=1,
) )
def get_secret(secret_key): def get_secret(self, secret_key):
credential = self._get_client_secret_credential_obj() credential = self._get_client_secret_credential_obj()
secret_client = self.secrets.SecretClient( secret_client = self.secrets.SecretClient(
vault_url=self.vault_url, credential=credential, vault_url=self.vault_url, credential=credential,

View File

@ -1,5 +1,4 @@
from random import choice, choices from random import choice, choices
import re
import string import string
from sqlalchemy import Column, ForeignKey, Enum as SQLAEnum from sqlalchemy import Column, ForeignKey, Enum as SQLAEnum
@ -119,7 +118,6 @@ class PortfolioStateMachine(
elif state_obj.is_CREATED: elif state_obj.is_CREATED:
# the create trigger for the next stage should be in the available # the create trigger for the next stage should be in the available
# triggers for the current state # triggers for the current state
triggers = self.machine.get_triggers(state_obj.name)
create_trigger = next( create_trigger = next(
filter( filter(
lambda trigger: trigger.startswith("create_"), lambda trigger: trigger.startswith("create_"),
@ -205,11 +203,10 @@ class PortfolioStateMachine(
dc = cls(**stage_data) dc = cls(**stage_data)
if getattr(dc, "get_creds", None) is not None: if getattr(dc, "get_creds", None) is not None:
new_creds = dc.get_creds() new_creds = dc.get_creds()
print("creds to report") tenant_id = new_creds.get("tenant_id")
print(new_creds) secret = self.csp.get_secret(tenant_id)
# TODO: how/where to store these secret.update(new_creds)
# TODO: credential schema self.csp.set_secret(tenant_id, secret)
# self.store_creds(self.portfolio, new_creds)
except PydanticValidationError as exc: except PydanticValidationError as exc:
app.logger.error( app.logger.error(

View File

@ -2,7 +2,6 @@ import pytest
import re import re
from tests.factories import ( from tests.factories import (
PortfolioFactory,
PortfolioStateMachineFactory, PortfolioStateMachineFactory,
TaskOrderFactory, TaskOrderFactory,
CLINFactory, CLINFactory,
@ -36,7 +35,7 @@ def test_state_machine_trigger_next_transition(portfolio):
def test_state_machine_compose_state(portfolio): def test_state_machine_compose_state(portfolio):
sm = PortfolioStateMachineFactory.create(portfolio=portfolio) PortfolioStateMachineFactory.create(portfolio=portfolio)
assert ( assert (
compose_state(AzureStages.TENANT, StageStates.CREATED) compose_state(AzureStages.TENANT, StageStates.CREATED)
== FSMStates.TENANT_CREATED == FSMStates.TENANT_CREATED
@ -44,7 +43,7 @@ def test_state_machine_compose_state(portfolio):
def test_state_machine_valid_data_classes_for_stages(portfolio): def test_state_machine_valid_data_classes_for_stages(portfolio):
sm = PortfolioStateMachineFactory.create(portfolio=portfolio) PortfolioStateMachineFactory.create(portfolio=portfolio)
for stage in AzureStages: for stage in AzureStages:
assert get_stage_csp_class(stage.name.lower(), "payload") is not None assert get_stage_csp_class(stage.name.lower(), "payload") is not None
assert get_stage_csp_class(stage.name.lower(), "result") is not None assert get_stage_csp_class(stage.name.lower(), "result") is not None