Format project

This commit is contained in:
richard-dds 2018-08-23 16:25:36 -04:00
parent e9fa4d9ecb
commit daa8634cb4
48 changed files with 415 additions and 282 deletions

View File

@ -85,7 +85,9 @@ def map_config(config):
"SQLALCHEMY_DATABASE_URI": config["default"]["DATABASE_URI"], "SQLALCHEMY_DATABASE_URI": config["default"]["DATABASE_URI"],
"SQLALCHEMY_TRACK_MODIFICATIONS": False, "SQLALCHEMY_TRACK_MODIFICATIONS": False,
"WTF_CSRF_ENABLED": config.getboolean("default", "WTF_CSRF_ENABLED"), "WTF_CSRF_ENABLED": config.getboolean("default", "WTF_CSRF_ENABLED"),
"PERMANENT_SESSION_LIFETIME": config.getint("default", "PERMANENT_SESSION_LIFETIME"), "PERMANENT_SESSION_LIFETIME": config.getint(
"default", "PERMANENT_SESSION_LIFETIME"
),
} }
@ -127,8 +129,10 @@ def make_config():
return map_config(config) return map_config(config)
def make_redis(config): def make_redis(config):
return redis.Redis.from_url(config['REDIS_URI']) return redis.Redis.from_url(config["REDIS_URI"])
def make_crl_validator(app): def make_crl_validator(app):
crl_locations = [] crl_locations = []
@ -136,5 +140,6 @@ def make_crl_validator(app):
crl_locations.append(filename.absolute()) crl_locations.append(filename.absolute())
app.crl_cache = CRLCache(app.config["CA_CHAIN"], crl_locations, logger=app.logger) app.crl_cache = CRLCache(app.config["CA_CHAIN"], crl_locations, logger=app.logger)
def make_eda_client(app): def make_eda_client(app):
app.eda_client = MockEDAClient() app.eda_client = MockEDAClient()

View File

@ -3,14 +3,10 @@ from flask_assets import Environment, Bundle
environment = Environment() environment = Environment()
css = Bundle( css = Bundle(
"../static/assets/index.css", "../static/assets/index.css", output="../static/assets/index.%(version)s.css"
output="../static/assets/index.%(version)s.css",
) )
environment.register("css", css) environment.register("css", css)
js = Bundle( js = Bundle("../static/assets/index.js", output="../static/assets/index.%(version)s.js")
'../static/assets/index.js', environment.register("js_all", js)
output='../static/assets/index.%(version)s.js'
)
environment.register('js_all', js)

View File

@ -3,7 +3,14 @@ from flask import g, redirect, url_for, session, request
from atst.domain.users import Users from atst.domain.users import Users
UNPROTECTED_ROUTES = ["atst.root", "dev.login_dev", "atst.login_redirect", "atst.unauthorized", "static"] UNPROTECTED_ROUTES = [
"atst.root",
"dev.login_dev",
"atst.login_redirect",
"atst.unauthorized",
"static",
]
def apply_authentication(app): def apply_authentication(app):
@app.before_request @app.before_request
@ -26,7 +33,7 @@ def get_current_user():
else: else:
return False return False
def _unprotected_route(request): def _unprotected_route(request):
if request.endpoint in UNPROTECTED_ROUTES: if request.endpoint in UNPROTECTED_ROUTES:
return True return True

View File

@ -4,8 +4,7 @@ from .utils import parse_sdn, email_from_certificate
from .crl import CRLRevocationException from .crl import CRLRevocationException
class AuthenticationContext(): class AuthenticationContext:
def __init__(self, crl_cache, auth_status, sdn, cert): def __init__(self, crl_cache, auth_status, sdn, cert):
if None in locals().values(): if None in locals().values():
raise UnauthenticatedError( raise UnauthenticatedError(

View File

@ -9,14 +9,16 @@ class CRLRevocationException(Exception):
pass pass
class CRLCache(): class CRLCache:
_PEM_RE = re.compile( _PEM_RE = re.compile(
b"-----BEGIN CERTIFICATE-----\r?.+?\r?-----END CERTIFICATE-----\r?\n?", b"-----BEGIN CERTIFICATE-----\r?.+?\r?-----END CERTIFICATE-----\r?\n?",
re.DOTALL, re.DOTALL,
) )
def __init__(self, root_location, crl_locations=[], store_class=crypto.X509Store, logger=None): def __init__(
self, root_location, crl_locations=[], store_class=crypto.X509Store, logger=None
):
self.store_class = store_class self.store_class = store_class
self.certificate_authorities = {} self.certificate_authorities = {}
self._load_roots(root_location) self._load_roots(root_location)
@ -57,7 +59,11 @@ class CRLCache():
with open(crl_location, "rb") as crl_file: with open(crl_location, "rb") as crl_file:
crl = crypto.load_crl(crypto.FILETYPE_ASN1, crl_file.read()) crl = crypto.load_crl(crypto.FILETYPE_ASN1, crl_file.read())
store.add_crl(crl) store.add_crl(crl)
self.log_info("STORE ID: {}. Adding CRL with issuer {}".format(id(store), crl.get_issuer())) self.log_info(
"STORE ID: {}. Adding CRL with issuer {}".format(
id(store), crl.get_issuer()
)
)
store = self._add_certificate_chain_to_store(store, crl.get_issuer()) store = self._add_certificate_chain_to_store(store, crl.get_issuer())
return store return store
@ -75,7 +81,11 @@ class CRLCache():
def _add_certificate_chain_to_store(self, store, issuer): def _add_certificate_chain_to_store(self, store, issuer):
ca = self.certificate_authorities.get(issuer.der()) ca = self.certificate_authorities.get(issuer.der())
store.add_cert(ca) store.add_cert(ca)
self.log_info("STORE ID: {}. Adding CA with subject {}".format(id(store), ca.get_subject())) self.log_info(
"STORE ID: {}. Adding CA with subject {}".format(
id(store), ca.get_subject()
)
)
if issuer == ca.get_issuer(): if issuer == ca.get_issuer():
# i.e., it is the root CA and we are at the end of the chain # i.e., it is the root CA and we are at the end of the chain

View File

@ -25,7 +25,15 @@ def email_from_certificate(cert_file):
return email[0] return email[0]
else: else:
raise ValueError("No email available for certificate with serial {}".format(cert.serial_number)) raise ValueError(
"No email available for certificate with serial {}".format(
cert.serial_number
)
)
except x509.extensions.ExtensionNotFound: except x509.extensions.ExtensionNotFound:
raise ValueError("No subjectAltName available for certificate with serial {}".format(cert.serial_number)) raise ValueError(
"No subjectAltName available for certificate with serial {}".format(
cert.serial_number
)
)

View File

@ -6,7 +6,6 @@ from .exceptions import NotFoundError
class PENumbers(object): class PENumbers(object):
@classmethod @classmethod
def get(cls, number): def get(cls, number):
pe_number = db.session.query(PENumber).get(number) pe_number = db.session.query(PENumber).get(number)

View File

@ -73,9 +73,10 @@ class Requests(object):
filters.append(Request.creator == creator) filters.append(Request.creator == creator)
requests = ( requests = (
db.session.query(Request).filter(*filters).order_by( db.session.query(Request)
Request.time_created.desc() .filter(*filters)
).all() .order_by(Request.time_created.desc())
.all()
) )
return requests return requests
@ -113,9 +114,10 @@ class Requests(object):
# Query for request matching id, acquiring a row-level write lock. # Query for request matching id, acquiring a row-level write lock.
# https://www.postgresql.org/docs/10/static/sql-select.html#SQL-FOR-UPDATE-SHARE # https://www.postgresql.org/docs/10/static/sql-select.html#SQL-FOR-UPDATE-SHARE
return ( return (
db.session.query(Request).filter_by(id=request_id).with_for_update( db.session.query(Request)
of=Request .filter_by(id=request_id)
).one() .with_for_update(of=Request)
.one()
) )
except NoResultFound: except NoResultFound:
@ -153,9 +155,7 @@ class Requests(object):
RequestStatus.STARTED: "mission_owner", RequestStatus.STARTED: "mission_owner",
RequestStatus.PENDING_FINANCIAL_VERIFICATION: "mission_owner", RequestStatus.PENDING_FINANCIAL_VERIFICATION: "mission_owner",
RequestStatus.PENDING_CCPO_APPROVAL: "ccpo", RequestStatus.PENDING_CCPO_APPROVAL: "ccpo",
}.get( }.get(request.status)
request.status
)
@classmethod @classmethod
def should_auto_approve(cls, request): def should_auto_approve(cls, request):
@ -167,13 +167,16 @@ class Requests(object):
return dollar_value < cls.AUTO_APPROVE_THRESHOLD return dollar_value < cls.AUTO_APPROVE_THRESHOLD
_VALID_SUBMISSION_STATUSES = [ _VALID_SUBMISSION_STATUSES = [
RequestStatus.STARTED, RequestStatus.CHANGES_REQUESTED RequestStatus.STARTED,
RequestStatus.CHANGES_REQUESTED,
] ]
@classmethod @classmethod
def should_allow_submission(cls, request): def should_allow_submission(cls, request):
all_request_sections = [ all_request_sections = [
"details_of_use", "information_about_you", "primary_poc" "details_of_use",
"information_about_you",
"primary_poc",
] ]
existing_request_sections = request.body.keys() existing_request_sections = request.body.keys()
return request.status in Requests._VALID_SUBMISSION_STATUSES and all( return request.status in Requests._VALID_SUBMISSION_STATUSES and all(

View File

@ -6,7 +6,6 @@ from .exceptions import NotFoundError
class Roles(object): class Roles(object):
@classmethod @classmethod
def get(cls, role_name): def get(cls, role_name):
try: try:

View File

@ -9,7 +9,6 @@ from .exceptions import NotFoundError, AlreadyExistsError
class Users(object): class Users(object):
@classmethod @classmethod
def get(cls, user_id): def get(cls, user_id):
try: try:

View File

@ -11,7 +11,6 @@ from .exceptions import NotFoundError
class WorkspaceUsers(object): class WorkspaceUsers(object):
@classmethod @classmethod
def get(cls, workspace_id, user_id): def get(cls, workspace_id, user_id):
try: try:

View File

@ -64,8 +64,6 @@ class Workspaces(object):
@classmethod @classmethod
def _create_workspace_role(cls, user, workspace, role_name): def _create_workspace_role(cls, user, workspace, role_name):
role = Roles.get(role_name) role = Roles.get(role_name)
workspace_role = WorkspaceRole( workspace_role = WorkspaceRole(user=user, role=role, workspace=workspace)
user=user, role=role, workspace=workspace
)
db.session.add(workspace_role) db.session.add(workspace_role)
return workspace_role return workspace_role

View File

@ -1,7 +1,8 @@
import re import re
def iconSvg(name): def iconSvg(name):
with open('static/icons/'+name+'.svg') as contents: with open("static/icons/" + name + ".svg") as contents:
return contents.read() return contents.read()
@ -14,8 +15,8 @@ def dollars(value):
def usPhone(number): def usPhone(number):
phone = re.sub(r'\D', '', number) phone = re.sub(r"\D", "", number)
return '+1 ({}) {} - {}'.format(phone[0:3], phone[3:6], phone[6:]) return "+1 ({}) {} - {}".format(phone[0:3], phone[3:6], phone[6:])
def readableInteger(value): def readableInteger(value):
@ -31,9 +32,8 @@ def getOptionLabel(value, options):
def register_filters(app): def register_filters(app):
app.jinja_env.filters['iconSvg'] = iconSvg app.jinja_env.filters["iconSvg"] = iconSvg
app.jinja_env.filters['dollars'] = dollars app.jinja_env.filters["dollars"] = dollars
app.jinja_env.filters['usPhone'] = usPhone app.jinja_env.filters["usPhone"] = usPhone
app.jinja_env.filters['readableInteger'] = readableInteger app.jinja_env.filters["readableInteger"] = readableInteger
app.jinja_env.filters['getOptionLabel'] = getOptionLabel app.jinja_env.filters["getOptionLabel"] = getOptionLabel

View File

@ -3,7 +3,10 @@ SERVICE_BRANCHES = [
("Air Force, Department of the", "Air Force, Department of the"), ("Air Force, Department of the", "Air Force, Department of the"),
("Army and Air Force Exchange Service", "Army and Air Force Exchange Service"), ("Army and Air Force Exchange Service", "Army and Air Force Exchange Service"),
("Army, Department of the", "Army, Department of the"), ("Army, Department of the", "Army, Department of the"),
("Defense Advanced Research Projects Agency", "Defense Advanced Research Projects Agency"), (
"Defense Advanced Research Projects Agency",
"Defense Advanced Research Projects Agency",
),
("Defense Commissary Agency", "Defense Commissary Agency"), ("Defense Commissary Agency", "Defense Commissary Agency"),
("Defense Contract Audit Agency", "Defense Contract Audit Agency"), ("Defense Contract Audit Agency", "Defense Contract Audit Agency"),
("Defense Contract Management Agency", "Defense Contract Management Agency"), ("Defense Contract Management Agency", "Defense Contract Management Agency"),
@ -19,31 +22,55 @@ SERVICE_BRANCHES = [
("Defense Security Cooperation Agency", "Defense Security Cooperation Agency"), ("Defense Security Cooperation Agency", "Defense Security Cooperation Agency"),
("Defense Security Service", "Defense Security Service"), ("Defense Security Service", "Defense Security Service"),
("Defense Technical Information Center", "Defense Technical Information Center"), ("Defense Technical Information Center", "Defense Technical Information Center"),
("Defense Technology Security Administration", "Defense Technology Security Administration"), (
"Defense Technology Security Administration",
"Defense Technology Security Administration",
),
("Defense Threat Reduction Agency", "Defense Threat Reduction Agency"), ("Defense Threat Reduction Agency", "Defense Threat Reduction Agency"),
("DoD Education Activity", "DoD Education Activity"), ("DoD Education Activity", "DoD Education Activity"),
("DoD Human Recourses Activity", "DoD Human Recourses Activity"), ("DoD Human Recourses Activity", "DoD Human Recourses Activity"),
("DoD Inspector General", "DoD Inspector General"), ("DoD Inspector General", "DoD Inspector General"),
("DoD Test Resource Management Center", "DoD Test Resource Management Center"), ("DoD Test Resource Management Center", "DoD Test Resource Management Center"),
("Headquarters Defense Human Resource Activity ", "Headquarters Defense Human Resource Activity "), (
"Headquarters Defense Human Resource Activity ",
"Headquarters Defense Human Resource Activity ",
),
("Joint Staff", "Joint Staff"), ("Joint Staff", "Joint Staff"),
("Missile Defense Agency", "Missile Defense Agency"), ("Missile Defense Agency", "Missile Defense Agency"),
("National Defense University", "National Defense University"), ("National Defense University", "National Defense University"),
("National Geospatial Intelligence Agency (NGA)", "National Geospatial Intelligence Agency (NGA)"), (
("National Oceanic and Atmospheric Administration (NOAA)", "National Oceanic and Atmospheric Administration (NOAA)"), "National Geospatial Intelligence Agency (NGA)",
"National Geospatial Intelligence Agency (NGA)",
),
(
"National Oceanic and Atmospheric Administration (NOAA)",
"National Oceanic and Atmospheric Administration (NOAA)",
),
("National Reconnaissance Office", "National Reconnaissance Office"), ("National Reconnaissance Office", "National Reconnaissance Office"),
("National Reconnaissance Office (NRO)", "National Reconnaissance Office (NRO)"), ("National Reconnaissance Office (NRO)", "National Reconnaissance Office (NRO)"),
("National Security Agency (NSA)", "National Security Agency (NSA)"), ("National Security Agency (NSA)", "National Security Agency (NSA)"),
("National Security Agency-Central Security Service", "National Security Agency-Central Security Service"), (
"National Security Agency-Central Security Service",
"National Security Agency-Central Security Service",
),
("Navy, Department of the", "Navy, Department of the"), ("Navy, Department of the", "Navy, Department of the"),
("Office of Economic Adjustment", "Office of Economic Adjustment"), ("Office of Economic Adjustment", "Office of Economic Adjustment"),
("Office of the Secretary of Defense", "Office of the Secretary of Defense"), ("Office of the Secretary of Defense", "Office of the Secretary of Defense"),
("Pentagon Force Protection Agency", "Pentagon Force Protection Agency"), ("Pentagon Force Protection Agency", "Pentagon Force Protection Agency"),
("Uniform Services University of the Health Sciences", "Uniform Services University of the Health Sciences"), (
"Uniform Services University of the Health Sciences",
"Uniform Services University of the Health Sciences",
),
("US Cyber Command (USCYBERCOM)", "US Cyber Command (USCYBERCOM)"), ("US Cyber Command (USCYBERCOM)", "US Cyber Command (USCYBERCOM)"),
("US Special Operations Command (USSOCOM)", "US Special Operations Command (USSOCOM)"), (
"US Special Operations Command (USSOCOM)",
"US Special Operations Command (USSOCOM)",
),
("US Strategic Command (USSTRATCOM)", "US Strategic Command (USSTRATCOM)"), ("US Strategic Command (USSTRATCOM)", "US Strategic Command (USSTRATCOM)"),
("US Transportation Command (USTRANSCOM)", "US Transportation Command (USTRANSCOM)"), (
"US Transportation Command (USTRANSCOM)",
"US Transportation Command (USTRANSCOM)",
),
("Washington Headquarters Services", "Washington Headquarters Services"), ("Washington Headquarters Services", "Washington Headquarters Services"),
] ]

View File

@ -24,7 +24,7 @@ class NewlineListField(Field):
def _value(self): def _value(self):
if isinstance(self.data, list): if isinstance(self.data, list):
return '\n'.join(self.data) return "\n".join(self.data)
elif self.data: elif self.data:
return self.data return self.data
else: else:
@ -46,8 +46,5 @@ class NewlineListField(Field):
class SelectField(SelectField_): class SelectField(SelectField_):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
render_kw = kwargs.get("render_kw", {}) render_kw = kwargs.get("render_kw", {})
kwargs["render_kw"] = { kwargs["render_kw"] = {**render_kw, "required": False}
**render_kw,
"required": False
}
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)

View File

@ -26,6 +26,7 @@ TREASURY_CODE_REGEX = re.compile(r"^0*([1-9]{4}|[1-9]{6})$")
BA_CODE_REGEX = re.compile(r"^0*[1-9]{2}\w?$") BA_CODE_REGEX = re.compile(r"^0*[1-9]{2}\w?$")
def suggest_pe_id(pe_id): def suggest_pe_id(pe_id):
suggestion = pe_id suggestion = pe_id
match = PE_REGEX.match(pe_id) match = PE_REGEX.match(pe_id)
@ -94,19 +95,26 @@ class BaseFinancialForm(ValidatedForm):
task_order_number = StringField( task_order_number = StringField(
"Task Order Number associated with this request", "Task Order Number associated with this request",
description="Include the original Task Order number (including the 000X at the end). Do not include any modification numbers. Note that there may be a lag between approving a task order and when it becomes available in our system.", description="Include the original Task Order number (including the 000X at the end). Do not include any modification numbers. Note that there may be a lag between approving a task order and when it becomes available in our system.",
validators=[Required()] validators=[Required()],
) )
uii_ids = NewlineListField( uii_ids = NewlineListField(
"Unique Item Identifier (UII)s related to your application(s) if you already have them", "Unique Item Identifier (UII)s related to your application(s) if you already have them",
validators=[Required()] validators=[Required()],
) )
pe_id = StringField("Program Element (PE) Number related to your request", validators=[Required()]) pe_id = StringField(
"Program Element (PE) Number related to your request", validators=[Required()]
)
treasury_code = StringField("Program Treasury Code", validators=[Required(), Regexp(TREASURY_CODE_REGEX)]) treasury_code = StringField(
"Program Treasury Code", validators=[Required(), Regexp(TREASURY_CODE_REGEX)]
)
ba_code = StringField("Program Budget Activity (BA) Code", validators=[Required(), Regexp(BA_CODE_REGEX)]) ba_code = StringField(
"Program Budget Activity (BA) Code",
validators=[Required(), Regexp(BA_CODE_REGEX)],
)
fname_co = StringField("Contracting Officer First Name", validators=[Required()]) fname_co = StringField("Contracting Officer First Name", validators=[Required()])
lname_co = StringField("Contracting Officer Last Name", validators=[Required()]) lname_co = StringField("Contracting Officer Last Name", validators=[Required()])
@ -160,7 +168,7 @@ class ExtendedFinancialForm(BaseFinancialForm):
("OTHER", "Other"), ("OTHER", "Other"),
], ],
validators=[Required()], validators=[Required()],
render_kw={"required": False} render_kw={"required": False},
) )
funding_type_other = StringField("If other, please specify") funding_type_other = StringField("If other, please specify")
@ -169,40 +177,40 @@ class ExtendedFinancialForm(BaseFinancialForm):
"<dl><dt>CLIN 0001</dt> - <dd>Unclassified IaaS and PaaS Amount</dd></dl>", "<dl><dt>CLIN 0001</dt> - <dd>Unclassified IaaS and PaaS Amount</dd></dl>",
validators=[Required()], validators=[Required()],
description="Review your task order document, the amounts for each CLIN must match exactly here", description="Review your task order document, the amounts for each CLIN must match exactly here",
filters=[number_to_int] filters=[number_to_int],
) )
clin_0003 = StringField( clin_0003 = StringField(
"<dl><dt>CLIN 0003</dt> - <dd>Unclassified Cloud Support Package</dd></dl>", "<dl><dt>CLIN 0003</dt> - <dd>Unclassified Cloud Support Package</dd></dl>",
validators=[Required()], validators=[Required()],
description="Review your task order document, the amounts for each CLIN must match exactly here", description="Review your task order document, the amounts for each CLIN must match exactly here",
filters=[number_to_int] filters=[number_to_int],
) )
clin_1001 = StringField( clin_1001 = StringField(
"<dl><dt>CLIN 1001</dt> - <dd>Unclassified IaaS and PaaS Amount <br> OPTION PERIOD 1</dd></dl>", "<dl><dt>CLIN 1001</dt> - <dd>Unclassified IaaS and PaaS Amount <br> OPTION PERIOD 1</dd></dl>",
validators=[Required()], validators=[Required()],
description="Review your task order document, the amounts for each CLIN must match exactly here", description="Review your task order document, the amounts for each CLIN must match exactly here",
filters=[number_to_int] filters=[number_to_int],
) )
clin_1003 = StringField( clin_1003 = StringField(
"<dl><dt>CLIN 1003</dt> - <dd>Unclassified Cloud Support Package <br> OPTION PERIOD 1</dd></dl>", "<dl><dt>CLIN 1003</dt> - <dd>Unclassified Cloud Support Package <br> OPTION PERIOD 1</dd></dl>",
validators=[Required()], validators=[Required()],
description="Review your task order document, the amounts for each CLIN must match exactly here", description="Review your task order document, the amounts for each CLIN must match exactly here",
filters=[number_to_int] filters=[number_to_int],
) )
clin_2001 = StringField( clin_2001 = StringField(
"<dl><dt>CLIN 2001</dt> - <dd>Unclassified IaaS and PaaS Amount <br> OPTION PERIOD 2</dd></dl>", "<dl><dt>CLIN 2001</dt> - <dd>Unclassified IaaS and PaaS Amount <br> OPTION PERIOD 2</dd></dl>",
validators=[Required()], validators=[Required()],
description="Review your task order document, the amounts for each CLIN must match exactly here", description="Review your task order document, the amounts for each CLIN must match exactly here",
filters=[number_to_int] filters=[number_to_int],
) )
clin_2003 = StringField( clin_2003 = StringField(
"<dl><dt>CLIN 2003</dt> - <dd>Unclassified Cloud Support Package <br> OPTION PERIOD 2</dd></dl>", "<dl><dt>CLIN 2003</dt> - <dd>Unclassified Cloud Support Package <br> OPTION PERIOD 2</dd></dl>",
validators=[Required()], validators=[Required()],
description="Review your task order document, the amounts for each CLIN must match exactly here", description="Review your task order document, the amounts for each CLIN must match exactly here",
filters=[number_to_int] filters=[number_to_int],
) )

View File

@ -16,9 +16,11 @@ class OrgForm(ValidatedForm):
email_request = EmailField("E-mail Address", validators=[Required(), Email()]) email_request = EmailField("E-mail Address", validators=[Required(), Email()])
phone_number = TelField("Phone Number", phone_number = TelField(
description='Enter a 10-digit phone number', "Phone Number",
validators=[Required(), PhoneNumber()]) description="Enter a 10-digit phone number",
validators=[Required(), PhoneNumber()],
)
service_branch = SelectField( service_branch = SelectField(
"Service Branch or Agency", "Service Branch or Agency",
@ -49,7 +51,7 @@ class OrgForm(ValidatedForm):
date_latest_training = DateField( date_latest_training = DateField(
"Latest Information Assurance (IA) Training Completion Date", "Latest Information Assurance (IA) Training Completion Date",
description="To complete the training, you can find it in <a class=\"icon-link\" href=\"https://iatraining.disa.mil/eta/disa_cac2018/launchPage.htm\" target=\"_blank\">Information Assurance Cyber Awareness Challange</a> website.", description='To complete the training, you can find it in <a class="icon-link" href="https://iatraining.disa.mil/eta/disa_cac2018/launchPage.htm" target="_blank">Information Assurance Cyber Awareness Challange</a> website.',
validators=[ validators=[
Required(), Required(),
DateRange( DateRange(

View File

@ -6,7 +6,6 @@ from .validators import IsNumber
class POCForm(ValidatedForm): class POCForm(ValidatedForm):
def validate(self, *args, **kwargs): def validate(self, *args, **kwargs):
if self.am_poc.data: if self.am_poc.data:
# Prepend Optional validators so that the validation chain # Prepend Optional validators so that the validation chain
@ -18,11 +17,10 @@ class POCForm(ValidatedForm):
return super().validate(*args, **kwargs) return super().validate(*args, **kwargs)
am_poc = BooleanField( am_poc = BooleanField(
"I am the Workspace Owner", "I am the Workspace Owner",
default=False, default=False,
false_values=(False, "false", "False", "no", "") false_values=(False, "false", "False", "no", ""),
) )
fname_poc = StringField("First Name", validators=[Required()]) fname_poc = StringField("First Name", validators=[Required()])

View File

@ -4,22 +4,26 @@ from wtforms.validators import Optional, Required
from .fields import DateField, SelectField from .fields import DateField, SelectField
from .forms import ValidatedForm from .forms import ValidatedForm
from .data import SERVICE_BRANCHES, ASSISTANCE_ORG_TYPES, DATA_TRANSFER_AMOUNTS, COMPLETION_DATE_RANGES from .data import (
SERVICE_BRANCHES,
ASSISTANCE_ORG_TYPES,
DATA_TRANSFER_AMOUNTS,
COMPLETION_DATE_RANGES,
)
from atst.domain.requests import Requests from atst.domain.requests import Requests
class RequestForm(ValidatedForm): class RequestForm(ValidatedForm):
def validate(self, *args, **kwargs): def validate(self, *args, **kwargs):
if self.jedi_migration.data == 'no': if self.jedi_migration.data == "no":
self.rationalization_software_systems.validators.append(Optional()) self.rationalization_software_systems.validators.append(Optional())
self.technical_support_team.validators.append(Optional()) self.technical_support_team.validators.append(Optional())
self.organization_providing_assistance.validators.append(Optional()) self.organization_providing_assistance.validators.append(Optional())
self.engineering_assessment.validators.append(Optional()) self.engineering_assessment.validators.append(Optional())
self.data_transfers.validators.append(Optional()) self.data_transfers.validators.append(Optional())
self.expected_completion_date.validators.append(Optional()) self.expected_completion_date.validators.append(Optional())
elif self.jedi_migration.data == 'yes': elif self.jedi_migration.data == "yes":
if self.technical_support_team.data == 'no': if self.technical_support_team.data == "no":
self.organization_providing_assistance.validators.append(Optional()) self.organization_providing_assistance.validators.append(Optional())
self.cloud_native.validators.append(Optional()) self.cloud_native.validators.append(Optional())
@ -39,16 +43,15 @@ class RequestForm(ValidatedForm):
"DoD Component", "DoD Component",
description="Identify the DoD component that is requesting access to the JEDI Cloud", description="Identify the DoD component that is requesting access to the JEDI Cloud",
choices=SERVICE_BRANCHES, choices=SERVICE_BRANCHES,
validators=[Required()] validators=[Required()],
) )
jedi_usage = TextAreaField( jedi_usage = TextAreaField(
"JEDI Usage", "JEDI Usage",
description="Your answer will help us provide tangible examples to DoD leadership how and why commercial cloud resources are accelerating the Department's missions", description="Your answer will help us provide tangible examples to DoD leadership how and why commercial cloud resources are accelerating the Department's missions",
validators=[Required()] validators=[Required()],
) )
# Details of Use: Cloud Readiness # Details of Use: Cloud Readiness
num_software_systems = IntegerField( num_software_systems = IntegerField(
"Number of Software Systems", "Number of Software Systems",
@ -121,16 +124,15 @@ class RequestForm(ValidatedForm):
average_daily_traffic = IntegerField( average_daily_traffic = IntegerField(
"Average Daily Traffic (Number of Requests)", "Average Daily Traffic (Number of Requests)",
description="What is the average daily traffic you expect the systems under this cloud contract to use?" description="What is the average daily traffic you expect the systems under this cloud contract to use?",
) )
average_daily_traffic_gb = IntegerField( average_daily_traffic_gb = IntegerField(
"Average Daily Traffic (GB)", "Average Daily Traffic (GB)",
description="What is the average daily traffic you expect the systems under this cloud contract to use?" description="What is the average daily traffic you expect the systems under this cloud contract to use?",
) )
start_date = DateField( start_date = DateField(
description="When do you expect to start using the JEDI Cloud (not for billing purposes)?", description="When do you expect to start using the JEDI Cloud (not for billing purposes)?",
validators=[ validators=[Required()],
Required()]
) )

View File

@ -60,5 +60,3 @@ def ListItemRequired(message="Please provide at least one.", empty_values=("", N
raise ValidationError(message) raise ValidationError(message)
return _list_item_required return _list_item_required

View File

@ -2,6 +2,12 @@ from sqlalchemy import Column, func, TIMESTAMP
class TimestampsMixin(object): class TimestampsMixin(object):
time_created = Column(TIMESTAMP(timezone=True), nullable=False, server_default=func.now()) time_created = Column(
time_updated = Column(TIMESTAMP(timezone=True), nullable=False, server_default=func.now(), onupdate=func.current_timestamp()) TIMESTAMP(timezone=True), nullable=False, server_default=func.now()
)
time_updated = Column(
TIMESTAMP(timezone=True),
nullable=False,
server_default=func.now(),
onupdate=func.current_timestamp(),
)

View File

@ -4,6 +4,7 @@ from sqlalchemy import Column, Integer, String, Enum as SQLAEnum
from atst.models import Base from atst.models import Base
class Source(Enum): class Source(Enum):
MANUAL = "Manual" MANUAL = "Manual"
EDA = "eda" EDA = "eda"

View File

@ -25,7 +25,7 @@ def styleguide():
return render_template("styleguide.html") return render_template("styleguide.html")
@bp.route('/<path:path>') @bp.route("/<path:path>")
def catch_all(path): def catch_all(path):
return render_template("{}.html".format(path)) return render_template("{}.html".format(path))
@ -35,11 +35,11 @@ def _make_authentication_context():
crl_cache=app.crl_cache, crl_cache=app.crl_cache,
auth_status=request.environ.get("HTTP_X_SSL_CLIENT_VERIFY"), auth_status=request.environ.get("HTTP_X_SSL_CLIENT_VERIFY"),
sdn=request.environ.get("HTTP_X_SSL_CLIENT_S_DN"), sdn=request.environ.get("HTTP_X_SSL_CLIENT_S_DN"),
cert=request.environ.get("HTTP_X_SSL_CLIENT_CERT") cert=request.environ.get("HTTP_X_SSL_CLIENT_CERT"),
) )
@bp.route('/login-redirect') @bp.route("/login-redirect")
def login_redirect(): def login_redirect():
auth_context = _make_authentication_context() auth_context = _make_authentication_context()
auth_context.authenticate() auth_context.authenticate()
@ -53,7 +53,7 @@ def login_redirect():
def _is_valid_certificate(request): def _is_valid_certificate(request):
cert = request.environ.get('HTTP_X_SSL_CLIENT_CERT') cert = request.environ.get("HTTP_X_SSL_CLIENT_CERT")
if cert: if cert:
result = app.crl_validator.validate(cert.encode()) result = app.crl_validator.validate(cert.encode())
return result return result

View File

@ -10,45 +10,46 @@ _DEV_USERS = {
"first_name": "Sam", "first_name": "Sam",
"last_name": "Seeceepio", "last_name": "Seeceepio",
"atat_role_name": "ccpo", "atat_role_name": "ccpo",
"email": "sam@test.com" "email": "sam@test.com",
}, },
"amanda": { "amanda": {
"dod_id": "2345678901", "dod_id": "2345678901",
"first_name": "Amanda", "first_name": "Amanda",
"last_name": "Adamson", "last_name": "Adamson",
"atat_role_name": "default", "atat_role_name": "default",
"email": "amanda@test.com" "email": "amanda@test.com",
}, },
"brandon": { "brandon": {
"dod_id": "3456789012", "dod_id": "3456789012",
"first_name": "Brandon", "first_name": "Brandon",
"last_name": "Buchannan", "last_name": "Buchannan",
"atat_role_name": "default", "atat_role_name": "default",
"email": "brandon@test.com" "email": "brandon@test.com",
}, },
"christina": { "christina": {
"dod_id": "4567890123", "dod_id": "4567890123",
"first_name": "Christina", "first_name": "Christina",
"last_name": "Collins", "last_name": "Collins",
"atat_role_name": "default", "atat_role_name": "default",
"email": "christina@test.com" "email": "christina@test.com",
}, },
"dominick": { "dominick": {
"dod_id": "5678901234", "dod_id": "5678901234",
"first_name": "Dominick", "first_name": "Dominick",
"last_name": "Domingo", "last_name": "Domingo",
"atat_role_name": "default", "atat_role_name": "default",
"email": "dominick@test.com" "email": "dominick@test.com",
}, },
"erica": { "erica": {
"dod_id": "6789012345", "dod_id": "6789012345",
"first_name": "Erica", "first_name": "Erica",
"last_name": "Eichner", "last_name": "Eichner",
"atat_role_name": "default", "atat_role_name": "default",
"email": "erica@test.com" "email": "erica@test.com",
}, },
} }
@bp.route("/login-dev") @bp.route("/login-dev")
def login_dev(): def login_dev():
role = request.args.get("username", "amanda") role = request.args.get("username", "amanda")
@ -58,7 +59,7 @@ def login_dev():
atat_role_name=user_data["atat_role_name"], atat_role_name=user_data["atat_role_name"],
first_name=user_data["first_name"], first_name=user_data["first_name"],
last_name=user_data["last_name"], last_name=user_data["last_name"],
email=user_data["email"] email=user_data["email"],
) )
session["user_id"] = user.id session["user_id"] = user.id

View File

@ -11,11 +11,10 @@ def make_error_pages(app):
app.logger.error(e.message) app.logger.error(e.message)
return render_template("not_found.html"), 404 return render_template("not_found.html"), 404
@app.errorhandler(exceptions.UnauthenticatedError) @app.errorhandler(exceptions.UnauthenticatedError)
# pylint: disable=unused-variable # pylint: disable=unused-variable
def unauthorized(e): def unauthorized(e):
app.logger.error(e.message) app.logger.error(e.message)
return render_template('unauthenticated.html'), 401 return render_template("unauthenticated.html"), 401
return app return app

View File

@ -8,6 +8,7 @@ from . import index
from . import requests_form from . import requests_form
from . import financial_verification from . import financial_verification
@requests_bp.context_processor @requests_bp.context_processor
def annual_spend_threshold(): def annual_spend_threshold():
return { "annual_spend_threshold": Requests.ANNUAL_SPEND_THRESHOLD } return {"annual_spend_threshold": Requests.ANNUAL_SPEND_THRESHOLD}

View File

@ -43,7 +43,13 @@ def update_financial_verification(request_id):
if valid: if valid:
Requests.submit_financial_verification(request_id) Requests.submit_financial_verification(request_id)
new_workspace = Requests.approve_and_create_workspace(updated_request) new_workspace = Requests.approve_and_create_workspace(updated_request)
return redirect(url_for("workspaces.workspace_projects", workspace_id=new_workspace.id, newWorkspace=True)) return redirect(
url_for(
"workspaces.workspace_projects",
workspace_id=new_workspace.id,
newWorkspace=True,
)
)
else: else:
form.reset() form.reset()

View File

@ -24,15 +24,18 @@ def map_request(request):
"date": time_created.format("M/DD/YYYY"), "date": time_created.format("M/DD/YYYY"),
"full_name": request.creator.full_name, "full_name": request.creator.full_name,
"annual_usage": annual_usage, "annual_usage": annual_usage,
"edit_link": verify_url if Requests.is_pending_financial_verification( "edit_link": verify_url
request if Requests.is_pending_financial_verification(request)
) else update_url, else update_url,
} }
@requests_bp.route("/requests", methods=["GET"]) @requests_bp.route("/requests", methods=["GET"])
def requests_index(): def requests_index():
if Permissions.REVIEW_AND_APPROVE_JEDI_WORKSPACE_REQUEST in g.current_user.atat_permissions: if (
Permissions.REVIEW_AND_APPROVE_JEDI_WORKSPACE_REQUEST
in g.current_user.atat_permissions
):
return _ccpo_view() return _ccpo_view()
else: else:

View File

@ -129,7 +129,9 @@ class JEDIRequestFlow(object):
if section == "primary_poc": if section == "primary_poc":
if data.get("am_poc", False): if data.get("am_poc", False):
try: try:
request_user_info = self.existing_request.body.get("information_about_you", {}) request_user_info = self.existing_request.body.get(
"information_about_you", {}
)
except AttributeError: except AttributeError:
request_user_info = {} request_user_info = {}

View File

@ -6,7 +6,12 @@ from atst.routes.requests.jedi_request_flow import JEDIRequestFlow
from atst.models.permissions import Permissions from atst.models.permissions import Permissions
from atst.models.request_status_event import RequestStatus from atst.models.request_status_event import RequestStatus
from atst.domain.exceptions import UnauthorizedError from atst.domain.exceptions import UnauthorizedError
from atst.forms.data import SERVICE_BRANCHES, ASSISTANCE_ORG_TYPES, DATA_TRANSFER_AMOUNTS, COMPLETION_DATE_RANGES from atst.forms.data import (
SERVICE_BRANCHES,
ASSISTANCE_ORG_TYPES,
DATA_TRANSFER_AMOUNTS,
COMPLETION_DATE_RANGES,
)
@requests_bp.route("/requests/new/<int:screen>", methods=["GET"]) @requests_bp.route("/requests/new/<int:screen>", methods=["GET"])
@ -27,6 +32,7 @@ def requests_form_new(screen):
completion_date_ranges=COMPLETION_DATE_RANGES, completion_date_ranges=COMPLETION_DATE_RANGES,
) )
@requests_bp.route( @requests_bp.route(
"/requests/new/<int:screen>", methods=["GET"], defaults={"request_id": None} "/requests/new/<int:screen>", methods=["GET"], defaults={"request_id": None}
) )
@ -36,7 +42,9 @@ def requests_form_update(screen=1, request_id=None):
_check_can_view_request(request_id) _check_can_view_request(request_id)
request = Requests.get(request_id) if request_id is not None else None request = Requests.get(request_id) if request_id is not None else None
jedi_flow = JEDIRequestFlow(screen, request=request, request_id=request_id, current_user=g.current_user) jedi_flow = JEDIRequestFlow(
screen, request=request, request_id=request_id, current_user=g.current_user
)
return render_template( return render_template(
"requests/screen-%d.html" % int(screen), "requests/screen-%d.html" % int(screen),
@ -114,10 +122,12 @@ def requests_submit(request_id=None):
# TODO: generalize this, along with other authorizations, into a policy-pattern # TODO: generalize this, along with other authorizations, into a policy-pattern
# for authorization in the application # for authorization in the application
def _check_can_view_request(request_id): def _check_can_view_request(request_id):
if Permissions.REVIEW_AND_APPROVE_JEDI_WORKSPACE_REQUEST in g.current_user.atat_permissions: if (
Permissions.REVIEW_AND_APPROVE_JEDI_WORKSPACE_REQUEST
in g.current_user.atat_permissions
):
pass pass
elif Requests.exists(request_id, g.current_user): elif Requests.exists(request_id, g.current_user):
pass pass
else: else:
raise UnauthorizedError(g.current_user, "view request {}".format(request_id)) raise UnauthorizedError(g.current_user, "view request {}".format(request_id))

View File

@ -8,12 +8,7 @@ from atst.app import make_app, make_config
from atst.database import db as _db from atst.database import db as _db
import tests.factories as factories import tests.factories as factories
dictConfig({ dictConfig({"version": 1, "handlers": {"wsgi": {"class": "logging.NullHandler"}}})
'version': 1,
'handlers': {'wsgi': {
'class': 'logging.NullHandler',
}}
})
@pytest.fixture(scope="session") @pytest.fixture(scope="session")

View File

@ -11,7 +11,7 @@ from tests.factories import UserFactory
CERT = open("tests/fixtures/{}.crt".format(FIXTURE_EMAIL_ADDRESS)).read() CERT = open("tests/fixtures/{}.crt".format(FIXTURE_EMAIL_ADDRESS)).read()
class MockCRLCache(): class MockCRLCache:
def __init__(self, valid=True): def __init__(self, valid=True):
self.valid = valid self.valid = valid
@ -23,16 +23,12 @@ class MockCRLCache():
def test_can_authenticate(): def test_can_authenticate():
auth_context = AuthenticationContext( auth_context = AuthenticationContext(MockCRLCache(), "SUCCESS", DOD_SDN, CERT)
MockCRLCache(), "SUCCESS", DOD_SDN, CERT
)
assert auth_context.authenticate() assert auth_context.authenticate()
def test_unsuccessful_status(): def test_unsuccessful_status():
auth_context = AuthenticationContext( auth_context = AuthenticationContext(MockCRLCache(), "FAILURE", DOD_SDN, CERT)
MockCRLCache(), "FAILURE", DOD_SDN, CERT
)
with pytest.raises(UnauthenticatedError) as excinfo: with pytest.raises(UnauthenticatedError) as excinfo:
assert auth_context.authenticate() assert auth_context.authenticate()
@ -41,9 +37,7 @@ def test_unsuccessful_status():
def test_crl_check_fails(): def test_crl_check_fails():
auth_context = AuthenticationContext( auth_context = AuthenticationContext(MockCRLCache(False), "SUCCESS", DOD_SDN, CERT)
MockCRLCache(False), "SUCCESS", DOD_SDN, CERT
)
with pytest.raises(UnauthenticatedError) as excinfo: with pytest.raises(UnauthenticatedError) as excinfo:
assert auth_context.authenticate() assert auth_context.authenticate()
@ -52,9 +46,7 @@ def test_crl_check_fails():
def test_bad_sdn(): def test_bad_sdn():
auth_context = AuthenticationContext( auth_context = AuthenticationContext(MockCRLCache(), "SUCCESS", "abc123", CERT)
MockCRLCache(), "SUCCESS", "abc123", CERT
)
with pytest.raises(UnauthenticatedError) as excinfo: with pytest.raises(UnauthenticatedError) as excinfo:
auth_context.get_user() auth_context.get_user()
@ -64,9 +56,7 @@ def test_bad_sdn():
def test_user_exists(): def test_user_exists():
user = UserFactory.create(**DOD_SDN_INFO) user = UserFactory.create(**DOD_SDN_INFO)
auth_context = AuthenticationContext( auth_context = AuthenticationContext(MockCRLCache(), "SUCCESS", DOD_SDN, CERT)
MockCRLCache(), "SUCCESS", DOD_SDN, CERT
)
auth_user = auth_context.get_user() auth_user = auth_context.get_user()
assert auth_user == user assert auth_user == user
@ -77,9 +67,7 @@ def test_creates_user():
with pytest.raises(NotFoundError): with pytest.raises(NotFoundError):
Users.get_by_dod_id(DOD_SDN_INFO["dod_id"]) Users.get_by_dod_id(DOD_SDN_INFO["dod_id"])
auth_context = AuthenticationContext( auth_context = AuthenticationContext(MockCRLCache(), "SUCCESS", DOD_SDN, CERT)
MockCRLCache(), "SUCCESS", DOD_SDN, CERT
)
user = auth_context.get_user() user = auth_context.get_user()
assert user.dod_id == DOD_SDN_INFO["dod_id"] assert user.dod_id == DOD_SDN_INFO["dod_id"]
assert user.email == FIXTURE_EMAIL_ADDRESS assert user.email == FIXTURE_EMAIL_ADDRESS
@ -87,9 +75,7 @@ def test_creates_user():
def test_user_cert_has_no_email(): def test_user_cert_has_no_email():
cert = open("ssl/client-certs/atat.mil.crt").read() cert = open("ssl/client-certs/atat.mil.crt").read()
auth_context = AuthenticationContext( auth_context = AuthenticationContext(MockCRLCache(), "SUCCESS", DOD_SDN, cert)
MockCRLCache(), "SUCCESS", DOD_SDN, cert
)
user = auth_context.get_user() user = auth_context.get_user()
assert user.email == None assert user.email == None

View File

@ -11,8 +11,7 @@ import atst.domain.authnid.crl.util as util
from tests.mocks import FIXTURE_EMAIL_ADDRESS from tests.mocks import FIXTURE_EMAIL_ADDRESS
class MockX509Store(): class MockX509Store:
def __init__(self): def __init__(self):
self.crls = [] self.crls = []
self.certs = [] self.certs = []
@ -98,8 +97,7 @@ def test_parse_disa_pki_list():
assert len(crl_list) == len(href_matches) assert len(crl_list) == len(href_matches)
class MockStreamingResponse(): class MockStreamingResponse:
def __init__(self, content_chunks, code=200): def __init__(self, content_chunks, code=200):
self.content_chunks = content_chunks self.content_chunks = content_chunks
self.status_code = code self.status_code = code

View File

@ -18,4 +18,3 @@ def test_invalid_date():
date_str = "This is not a valid data" date_str = "This is not a valid data"
with pytest.raises(ValueError): with pytest.raises(ValueError):
parse_date(date_str) parse_date(date_str)

View File

@ -7,7 +7,9 @@ from tests.factories import PENumberFactory
def test_can_get_pe_number(): def test_can_get_pe_number():
new_pen = PENumberFactory.create(number="0701367F", description="Combat Support - Offensive") new_pen = PENumberFactory.create(
number="0701367F", description="Combat Support - Offensive"
)
pen = PENumbers.get(new_pen.number) pen = PENumbers.get(new_pen.number)
assert pen.number == new_pen.number assert pen.number == new_pen.number
@ -17,8 +19,9 @@ def test_nonexistent_pe_number_raises():
with pytest.raises(NotFoundError): with pytest.raises(NotFoundError):
PENumbers.get("some fake number") PENumbers.get("some fake number")
def test_create_many(): def test_create_many():
pen_list = [['123456', 'Land Speeder'], ['7891011', 'Lightsaber']] pen_list = [["123456", "Land Speeder"], ["7891011", "Lightsaber"]]
PENumbers.create_many(pen_list) PENumbers.create_many(pen_list)
assert PENumbers.get(pen_list[0][0]) assert PENumbers.get(pen_list[0][0])

View File

@ -7,7 +7,12 @@ from atst.models.request import Request
from atst.models.request_status_event import RequestStatus from atst.models.request_status_event import RequestStatus
from atst.models.task_order import Source as TaskOrderSource from atst.models.task_order import Source as TaskOrderSource
from tests.factories import RequestFactory, UserFactory, RequestStatusEventFactory, TaskOrderFactory from tests.factories import (
RequestFactory,
UserFactory,
RequestStatusEventFactory,
TaskOrderFactory,
)
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
@ -55,10 +60,12 @@ def test_dont_auto_approve_if_no_dollar_value_specified(new_request):
def test_should_allow_submission(new_request): def test_should_allow_submission(new_request):
assert Requests.should_allow_submission(new_request) assert Requests.should_allow_submission(new_request)
RequestStatusEventFactory.create(request=new_request, new_status=RequestStatus.CHANGES_REQUESTED) RequestStatusEventFactory.create(
request=new_request, new_status=RequestStatus.CHANGES_REQUESTED
)
assert Requests.should_allow_submission(new_request) assert Requests.should_allow_submission(new_request)
del new_request.body['details_of_use'] del new_request.body["details_of_use"]
assert not Requests.should_allow_submission(new_request) assert not Requests.should_allow_submission(new_request)
@ -76,12 +83,17 @@ def test_status_count(session):
request1 = RequestFactory.create() request1 = RequestFactory.create()
request2 = RequestFactory.create() request2 = RequestFactory.create()
RequestStatusEventFactory.create(sequence=2, request_id=request2.id, new_status=RequestStatus.PENDING_FINANCIAL_VERIFICATION) RequestStatusEventFactory.create(
sequence=2,
request_id=request2.id,
new_status=RequestStatus.PENDING_FINANCIAL_VERIFICATION,
)
assert Requests.status_count(RequestStatus.PENDING_FINANCIAL_VERIFICATION) == 1 assert Requests.status_count(RequestStatus.PENDING_FINANCIAL_VERIFICATION) == 1
assert Requests.status_count(RequestStatus.STARTED) == 1 assert Requests.status_count(RequestStatus.STARTED) == 1
assert Requests.in_progress_count() == 2 assert Requests.in_progress_count() == 2
def test_status_count_scoped_to_creator(session): def test_status_count_scoped_to_creator(session):
# make sure table is empty # make sure table is empty
session.query(Request).delete() session.query(Request).delete()
@ -123,7 +135,7 @@ task_order_financial_data = {
def test_update_financial_verification_without_task_order(): def test_update_financial_verification_without_task_order():
request = RequestFactory.create() request = RequestFactory.create()
financial_data = { **request_financial_data, **task_order_financial_data } financial_data = {**request_financial_data, **task_order_financial_data}
Requests.update_financial_verification(request.id, financial_data) Requests.update_financial_verification(request.id, financial_data)
assert request.task_order assert request.task_order
assert request.task_order.clin_0001 == task_order_financial_data["clin_0001"] assert request.task_order.clin_0001 == task_order_financial_data["clin_0001"]
@ -132,7 +144,7 @@ def test_update_financial_verification_without_task_order():
def test_update_financial_verification_with_task_order(): def test_update_financial_verification_with_task_order():
task_order = TaskOrderFactory.create(source=TaskOrderSource.EDA) task_order = TaskOrderFactory.create(source=TaskOrderSource.EDA)
financial_data = { **request_financial_data, "task_order_number": task_order.number } financial_data = {**request_financial_data, "task_order_number": task_order.number}
request = RequestFactory.create() request = RequestFactory.create()
Requests.update_financial_verification(request.id, financial_data) Requests.update_financial_verification(request.id, financial_data)
assert request.task_order == task_order assert request.task_order == task_order
@ -142,4 +154,3 @@ def test_update_financial_verification_with_invalid_task_order():
request = RequestFactory.create() request = RequestFactory.create()
Requests.update_financial_verification(request.id, request_financial_data) Requests.update_financial_verification(request.id, request_financial_data)
assert not request.task_order assert not request.task_order

View File

@ -16,7 +16,9 @@ def test_can_get_task_order():
def test_can_get_task_order_from_eda(monkeypatch): def test_can_get_task_order_from_eda(monkeypatch):
monkeypatch.setattr("atst.domain.task_orders.TaskOrders._client", lambda: MockEDAClient()) monkeypatch.setattr(
"atst.domain.task_orders.TaskOrders._client", lambda: MockEDAClient()
)
to = TaskOrders.get(MockEDAClient.MOCK_CONTRACT_NUMBER) to = TaskOrders.get(MockEDAClient.MOCK_CONTRACT_NUMBER)
assert to.number == MockEDAClient.MOCK_CONTRACT_NUMBER assert to.number == MockEDAClient.MOCK_CONTRACT_NUMBER
@ -29,6 +31,8 @@ def test_nonexistent_task_order_raises_without_client():
def test_nonexistent_task_order_raises_with_client(monkeypatch): def test_nonexistent_task_order_raises_with_client(monkeypatch):
monkeypatch.setattr("atst.domain.task_orders.TaskOrders._client", lambda: MockEDAClient()) monkeypatch.setattr(
"atst.domain.task_orders.TaskOrders._client", lambda: MockEDAClient()
)
with pytest.raises(NotFoundError): with pytest.raises(NotFoundError):
TaskOrders.get("some other fake numer") TaskOrders.get("some other fake numer")

View File

@ -30,13 +30,16 @@ def test_date_insane_format():
form.date._value() form.date._value()
@pytest.mark.parametrize("input_,expected", [ @pytest.mark.parametrize(
"input_,expected",
[
("", []), ("", []),
("hello", ["hello"]), ("hello", ["hello"]),
("hello\n", ["hello"]), ("hello\n", ["hello"]),
("hello\nworld", ["hello", "world"]), ("hello\nworld", ["hello", "world"]),
("hello\nworld\n", ["hello", "world"]) ("hello\nworld\n", ["hello", "world"]),
]) ],
)
def test_newline_list_process(input_, expected): def test_newline_list_process(input_, expected):
form_data = ImmutableMultiDict({"newline_list": input_}) form_data = ImmutableMultiDict({"newline_list": input_})
form = NewlineListForm(form_data) form = NewlineListForm(form_data)
@ -45,11 +48,10 @@ def test_newline_list_process(input_, expected):
assert form.data == {"newline_list": expected} assert form.data == {"newline_list": expected}
@pytest.mark.parametrize("input_,expected", [ @pytest.mark.parametrize(
([], ""), "input_,expected",
(["hello"], "hello"), [([], ""), (["hello"], "hello"), (["hello", "world"], "hello\nworld")],
(["hello", "world"], "hello\nworld") )
])
def test_newline_list_value(input_, expected): def test_newline_list_value(input_, expected):
form_data = {"newline_list": input_} form_data = {"newline_list": input_}
form = NewlineListForm(data=form_data) form = NewlineListForm(data=form_data)

View File

@ -4,36 +4,37 @@ from atst.forms.financial import suggest_pe_id, FinancialForm, ExtendedFinancial
from atst.eda_client import MockEDAClient from atst.eda_client import MockEDAClient
@pytest.mark.parametrize("input_,expected", [ @pytest.mark.parametrize(
('0603502N', None), "input_,expected",
('0603502NZ', None), [
('603502N', '0603502N'), ("0603502N", None),
('063502N', '0603502N'), ("0603502NZ", None),
('63502N', '0603502N'), ("603502N", "0603502N"),
]) ("063502N", "0603502N"),
("63502N", "0603502N"),
],
)
def test_suggest_pe_id(input_, expected): def test_suggest_pe_id(input_, expected):
assert suggest_pe_id(input_) == expected assert suggest_pe_id(input_) == expected
def test_funding_type_other_not_required_if_funding_type_is_not_other(): def test_funding_type_other_not_required_if_funding_type_is_not_other():
form_data = { form_data = {"funding_type": "PROC"}
"funding_type": "PROC"
}
form = ExtendedFinancialForm(data=form_data) form = ExtendedFinancialForm(data=form_data)
form.validate() form.validate()
assert "funding_type_other" not in form.errors assert "funding_type_other" not in form.errors
def test_funding_type_other_required_if_funding_type_is_other(): def test_funding_type_other_required_if_funding_type_is_other():
form_data = { form_data = {"funding_type": "OTHER"}
"funding_type": "OTHER"
}
form = ExtendedFinancialForm(data=form_data) form = ExtendedFinancialForm(data=form_data)
form.validate() form.validate()
assert "funding_type_other" in form.errors assert "funding_type_other" in form.errors
@pytest.mark.parametrize("input_,expected", [ @pytest.mark.parametrize(
"input_,expected",
[
("1234", True), ("1234", True),
("123456", True), ("123456", True),
("0001234", True), ("0001234", True),
@ -42,7 +43,8 @@ def test_funding_type_other_required_if_funding_type_is_other():
("00012345", False), ("00012345", False),
("0001234567", False), ("0001234567", False),
("000000", False), ("000000", False),
]) ],
)
def test_treasury_code_validation(input_, expected): def test_treasury_code_validation(input_, expected):
form_data = {"treasury_code": input_} form_data = {"treasury_code": input_}
form = FinancialForm(data=form_data) form = FinancialForm(data=form_data)
@ -52,7 +54,9 @@ def test_treasury_code_validation(input_, expected):
assert is_valid == expected assert is_valid == expected
@pytest.mark.parametrize("input_,expected", [ @pytest.mark.parametrize(
"input_,expected",
[
("12", True), ("12", True),
("00012", True), ("00012", True),
("12A", True), ("12A", True),
@ -60,7 +64,8 @@ def test_treasury_code_validation(input_, expected):
("00012A", True), ("00012A", True),
("0001", False), ("0001", False),
("00012AB", False), ("00012AB", False),
]) ],
)
def test_ba_code_validation(input_, expected): def test_ba_code_validation(input_, expected):
form_data = {"ba_code": input_} form_data = {"ba_code": input_}
form = FinancialForm(data=form_data) form = FinancialForm(data=form_data)
@ -69,16 +74,21 @@ def test_ba_code_validation(input_, expected):
assert is_valid == expected assert is_valid == expected
def test_task_order_number_validation(monkeypatch): def test_task_order_number_validation(monkeypatch):
monkeypatch.setattr("atst.domain.task_orders.TaskOrders._client", lambda: MockEDAClient()) monkeypatch.setattr(
"atst.domain.task_orders.TaskOrders._client", lambda: MockEDAClient()
)
monkeypatch.setattr("atst.forms.financial.validate_pe_id", lambda *args: True) monkeypatch.setattr("atst.forms.financial.validate_pe_id", lambda *args: True)
form_invalid = FinancialForm(data={"task_order_number": "1234"}) form_invalid = FinancialForm(data={"task_order_number": "1234"})
form_invalid.perform_extra_validation({}) form_invalid.perform_extra_validation({})
assert "task_order_number" in form_invalid.errors assert "task_order_number" in form_invalid.errors
form_valid = FinancialForm(data={"task_order_number": MockEDAClient.MOCK_CONTRACT_NUMBER}, eda_client=MockEDAClient()) form_valid = FinancialForm(
data={"task_order_number": MockEDAClient.MOCK_CONTRACT_NUMBER},
eda_client=MockEDAClient(),
)
form_valid.perform_extra_validation({}) form_valid.perform_extra_validation({})
assert "task_order_number" not in form_valid.errors assert "task_order_number" not in form_valid.errors

View File

@ -5,7 +5,6 @@ from atst.forms.validators import Alphabet, IsNumber, PhoneNumber
class TestIsNumber: class TestIsNumber:
@pytest.mark.parametrize("valid", ["0", "12", "-12"]) @pytest.mark.parametrize("valid", ["0", "12", "-12"])
def test_IsNumber_accepts_integers(self, valid, dummy_form, dummy_field): def test_IsNumber_accepts_integers(self, valid, dummy_form, dummy_field):
validator = IsNumber() validator = IsNumber()
@ -21,24 +20,18 @@ class TestIsNumber:
class TestPhoneNumber: class TestPhoneNumber:
@pytest.mark.parametrize("valid", ["12345", "1234567890", "(123) 456-7890"])
@pytest.mark.parametrize("valid", [
"12345",
"1234567890",
"(123) 456-7890",
])
def test_PhoneNumber_accepts_valid_numbers(self, valid, dummy_form, dummy_field): def test_PhoneNumber_accepts_valid_numbers(self, valid, dummy_form, dummy_field):
validator = PhoneNumber() validator = PhoneNumber()
dummy_field.data = valid dummy_field.data = valid
validator(dummy_form, dummy_field) validator(dummy_form, dummy_field)
@pytest.mark.parametrize("invalid", [ @pytest.mark.parametrize(
"1234", "invalid", ["1234", "123456", "1234567abc", "(123) 456-789012"]
"123456", )
"1234567abc", def test_PhoneNumber_rejects_invalid_numbers(
"(123) 456-789012", self, invalid, dummy_form, dummy_field
]) ):
def test_PhoneNumber_rejects_invalid_numbers(self, invalid, dummy_form, dummy_field):
validator = PhoneNumber() validator = PhoneNumber()
dummy_field.data = invalid dummy_field.data = invalid
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
@ -46,7 +39,6 @@ class TestPhoneNumber:
class TestAlphabet: class TestAlphabet:
@pytest.mark.parametrize("valid", ["a", "abcde"]) @pytest.mark.parametrize("valid", ["a", "abcde"])
def test_Alphabet_accepts_letters(self, valid, dummy_form, dummy_field): def test_Alphabet_accepts_letters(self, valid, dummy_form, dummy_field):
validator = Alphabet() validator = Alphabet()

View File

@ -69,6 +69,7 @@ def test_request_status_pending_deleted_displayname():
assert request.status_displayname == "Canceled" assert request.status_displayname == "Canceled"
def test_annual_spend(): def test_annual_spend():
request = RequestFactory.create() request = RequestFactory.create()
monthly = request.body.get("details_of_use").get("estimated_monthly_spend") monthly = request.body.get("details_of_use").get("estimated_monthly_spend")

View File

@ -22,7 +22,7 @@ class TestPENumberInForm:
"office_cor": "WHS", "office_cor": "WHS",
"uii_ids": "1234", "uii_ids": "1234",
"treasury_code": "00123456", "treasury_code": "00123456",
"ba_code": "024A" "ba_code": "024A",
} }
extended_data = { extended_data = {
"funding_type": "RDTE", "funding_type": "RDTE",
@ -36,8 +36,12 @@ class TestPENumberInForm:
} }
def _set_monkeypatches(self, monkeypatch): def _set_monkeypatches(self, monkeypatch):
monkeypatch.setattr("atst.forms.financial.FinancialForm.validate", lambda s: True) monkeypatch.setattr(
monkeypatch.setattr("atst.domain.auth.get_current_user", lambda *args: MOCK_USER) "atst.forms.financial.FinancialForm.validate", lambda s: True
)
monkeypatch.setattr(
"atst.domain.auth.get_current_user", lambda *args: MOCK_USER
)
def submit_data(self, client, data, extended=False): def submit_data(self, client, data, extended=False):
request = RequestFactory.create(body=MOCK_REQUEST.body) request = RequestFactory.create(body=MOCK_REQUEST.body)
@ -64,7 +68,7 @@ class TestPENumberInForm:
self._set_monkeypatches(monkeypatch) self._set_monkeypatches(monkeypatch)
data = dict(self.required_data) data = dict(self.required_data)
data['pe_id'] = MOCK_REQUEST.body['financial_verification']['pe_id'] data["pe_id"] = MOCK_REQUEST.body["financial_verification"]["pe_id"]
response = self.submit_data(client, data) response = self.submit_data(client, data)
@ -76,7 +80,7 @@ class TestPENumberInForm:
pe = PENumberFactory.create(number="8675309U", description="sample PE number") pe = PENumberFactory.create(number="8675309U", description="sample PE number")
data = dict(self.required_data) data = dict(self.required_data)
data['pe_id'] = pe.number data["pe_id"] = pe.number
response = self.submit_data(client, data) response = self.submit_data(client, data)
@ -87,32 +91,36 @@ class TestPENumberInForm:
self._set_monkeypatches(monkeypatch) self._set_monkeypatches(monkeypatch)
data = dict(self.required_data) data = dict(self.required_data)
data['pe_id'] = '' data["pe_id"] = ""
response = self.submit_data(client, data) response = self.submit_data(client, data)
assert "There were some errors" in response.data.decode() assert "There were some errors" in response.data.decode()
assert response.status_code == 200 assert response.status_code == 200
def test_submit_financial_form_with_invalid_task_order(self, monkeypatch, user_session, client): def test_submit_financial_form_with_invalid_task_order(
self, monkeypatch, user_session, client
):
monkeypatch.setattr("atst.domain.requests.Requests.get", lambda i: MOCK_REQUEST) monkeypatch.setattr("atst.domain.requests.Requests.get", lambda i: MOCK_REQUEST)
user_session() user_session()
data = dict(self.required_data) data = dict(self.required_data)
data['pe_id'] = MOCK_REQUEST.body['financial_verification']['pe_id'] data["pe_id"] = MOCK_REQUEST.body["financial_verification"]["pe_id"]
data['task_order_number'] = '1234' data["task_order_number"] = "1234"
response = self.submit_data(client, data) response = self.submit_data(client, data)
assert "enter TO information manually" in response.data.decode() assert "enter TO information manually" in response.data.decode()
def test_submit_financial_form_with_valid_task_order(self, monkeypatch, user_session, client): def test_submit_financial_form_with_valid_task_order(
self, monkeypatch, user_session, client
):
monkeypatch.setattr("atst.domain.requests.Requests.get", lambda i: MOCK_REQUEST) monkeypatch.setattr("atst.domain.requests.Requests.get", lambda i: MOCK_REQUEST)
user_session() user_session()
data = dict(self.required_data) data = dict(self.required_data)
data['pe_id'] = MOCK_REQUEST.body['financial_verification']['pe_id'] data["pe_id"] = MOCK_REQUEST.body["financial_verification"]["pe_id"]
data['task_order_number'] = MockEDAClient.MOCK_CONTRACT_NUMBER data["task_order_number"] = MockEDAClient.MOCK_CONTRACT_NUMBER
response = self.submit_data(client, data) response = self.submit_data(client, data)
@ -122,9 +130,9 @@ class TestPENumberInForm:
monkeypatch.setattr("atst.domain.requests.Requests.get", lambda i: MOCK_REQUEST) monkeypatch.setattr("atst.domain.requests.Requests.get", lambda i: MOCK_REQUEST)
user_session() user_session()
data = { **self.required_data, **self.extended_data } data = {**self.required_data, **self.extended_data}
data['pe_id'] = MOCK_REQUEST.body['financial_verification']['pe_id'] data["pe_id"] = MOCK_REQUEST.body["financial_verification"]["pe_id"]
data['task_order_number'] = "1234567" data["task_order_number"] = "1234567"
response = self.submit_data(client, data, extended=True) response = self.submit_data(client, data, extended=True)

View File

@ -8,6 +8,7 @@ from tests.assert_util import dict_contains
ERROR_CLASS = "alert--error" ERROR_CLASS = "alert--error"
def test_submit_invalid_request_form(monkeypatch, client, user_session): def test_submit_invalid_request_form(monkeypatch, client, user_session):
user_session() user_session()
response = client.post( response = client.post(
@ -35,7 +36,9 @@ def test_owner_can_view_request(client, user_session):
user_session(user) user_session(user)
request = RequestFactory.create(creator=user) request = RequestFactory.create(creator=user)
response = client.get("/requests/new/1/{}".format(request.id), follow_redirects=True) response = client.get(
"/requests/new/1/{}".format(request.id), follow_redirects=True
)
assert response.status_code == 200 assert response.status_code == 200
@ -45,7 +48,9 @@ def test_non_owner_cannot_view_request(client, user_session):
user_session(user) user_session(user)
request = RequestFactory.create() request = RequestFactory.create()
response = client.get("/requests/new/1/{}".format(request.id), follow_redirects=True) response = client.get(
"/requests/new/1/{}".format(request.id), follow_redirects=True
)
assert response.status_code == 404 assert response.status_code == 404
@ -56,7 +61,9 @@ def test_ccpo_can_view_request(client, user_session):
user_session(user) user_session(user)
request = RequestFactory.create() request = RequestFactory.create()
response = client.get("/requests/new/1/{}".format(request.id), follow_redirects=True) response = client.get(
"/requests/new/1/{}".format(request.id), follow_redirects=True
)
assert response.status_code == 200 assert response.status_code == 200
@ -80,7 +87,9 @@ def test_creator_info_is_autopopulated(monkeypatch, client, user_session):
assert "initial-value='{}'".format(user.email) in body assert "initial-value='{}'".format(user.email) in body
def test_creator_info_is_autopopulated_for_new_request(monkeypatch, client, user_session): def test_creator_info_is_autopopulated_for_new_request(
monkeypatch, client, user_session
):
user = UserFactory.create() user = UserFactory.create()
user_session(user) user_session(user)
@ -103,6 +112,7 @@ def test_non_creator_info_is_not_autopopulated(monkeypatch, client, user_session
assert not user.last_name in body assert not user.last_name in body
assert not user.email in body assert not user.email in body
def test_am_poc_causes_poc_to_be_autopopulated(client, user_session): def test_am_poc_causes_poc_to_be_autopopulated(client, user_session):
creator = UserFactory.create() creator = UserFactory.create()
user_session(creator) user_session(creator)
@ -124,7 +134,7 @@ def test_not_am_poc_requires_poc_info_to_be_completed(client, user_session):
"/requests/new/3/{}".format(request.id), "/requests/new/3/{}".format(request.id),
headers={"Content-Type": "application/x-www-form-urlencoded"}, headers={"Content-Type": "application/x-www-form-urlencoded"},
data="am_poc=no", data="am_poc=no",
follow_redirects=True follow_redirects=True,
) )
assert ERROR_CLASS in response.data.decode() assert ERROR_CLASS in response.data.decode()
@ -156,7 +166,7 @@ def test_poc_details_can_be_autopopulated_on_new_request(client, user_session):
headers={"Content-Type": "application/x-www-form-urlencoded"}, headers={"Content-Type": "application/x-www-form-urlencoded"},
data="am_poc=yes", data="am_poc=yes",
) )
request_id = response.headers["Location"].split('/')[-1] request_id = response.headers["Location"].split("/")[-1]
request = Requests.get(request_id) request = Requests.get(request_id)
assert request.body["primary_poc"]["dodid_poc"] == creator.dod_id assert request.body["primary_poc"]["dodid_poc"] == creator.dod_id
@ -165,27 +175,31 @@ def test_poc_details_can_be_autopopulated_on_new_request(client, user_session):
def test_poc_autofill_checks_information_about_you_form_first(client, user_session): def test_poc_autofill_checks_information_about_you_form_first(client, user_session):
creator = UserFactory.create() creator = UserFactory.create()
user_session(creator) user_session(creator)
request = RequestFactory.create(creator=creator, body={ request = RequestFactory.create(
creator=creator,
body={
"information_about_you": { "information_about_you": {
"fname_request": "Alice", "fname_request": "Alice",
"lname_request": "Adams", "lname_request": "Adams",
"email_request": "alice.adams@mail.mil" "email_request": "alice.adams@mail.mil",
}
})
poc_input = {
"am_poc": "yes",
} }
},
)
poc_input = {"am_poc": "yes"}
client.post( client.post(
"/requests/new/3/{}".format(request.id), "/requests/new/3/{}".format(request.id),
headers={"Content-Type": "application/x-www-form-urlencoded"}, headers={"Content-Type": "application/x-www-form-urlencoded"},
data=urlencode(poc_input), data=urlencode(poc_input),
) )
request = Requests.get(request.id) request = Requests.get(request.id)
assert dict_contains(request.body["primary_poc"], { assert dict_contains(
request.body["primary_poc"],
{
"fname_poc": "Alice", "fname_poc": "Alice",
"lname_poc": "Adams", "lname_poc": "Adams",
"email_poc": "alice.adams@mail.mil" "email_poc": "alice.adams@mail.mil",
}) },
)
def test_can_review_data(user_session, client): def test_can_review_data(user_session, client):

View File

@ -28,11 +28,16 @@ def test_submit_autoapproved_reviewed_request(monkeypatch, client, user_session)
user_session() user_session()
monkeypatch.setattr("atst.domain.requests.Requests.get", _mock_func) monkeypatch.setattr("atst.domain.requests.Requests.get", _mock_func)
monkeypatch.setattr("atst.domain.requests.Requests.submit", _mock_func) monkeypatch.setattr("atst.domain.requests.Requests.submit", _mock_func)
monkeypatch.setattr("atst.models.request.Request.status", RequestStatus.PENDING_FINANCIAL_VERIFICATION) monkeypatch.setattr(
"atst.models.request.Request.status",
RequestStatus.PENDING_FINANCIAL_VERIFICATION,
)
response = client.post( response = client.post(
"/requests/submit/1", "/requests/submit/1",
headers={"Content-Type": "application/x-www-form-urlencoded"}, headers={"Content-Type": "application/x-www-form-urlencoded"},
data="", data="",
follow_redirects=False, follow_redirects=False,
) )
assert "/requests?modal=pendingFinancialVerification" in response.headers["Location"] assert (
"/requests?modal=pendingFinancialVerification" in response.headers["Location"]
)

View File

@ -15,8 +15,13 @@ def _fetch_user_info(c, t):
def test_successful_login_redirect_non_ccpo(client, monkeypatch): def test_successful_login_redirect_non_ccpo(client, monkeypatch):
monkeypatch.setattr("atst.domain.authnid.AuthenticationContext.authenticate", lambda *args: True) monkeypatch.setattr(
monkeypatch.setattr("atst.domain.authnid.AuthenticationContext.get_user", lambda *args: UserFactory.create()) "atst.domain.authnid.AuthenticationContext.authenticate", lambda *args: True
)
monkeypatch.setattr(
"atst.domain.authnid.AuthenticationContext.get_user",
lambda *args: UserFactory.create(),
)
resp = client.get( resp = client.get(
"/login-redirect", "/login-redirect",
@ -31,10 +36,16 @@ def test_successful_login_redirect_non_ccpo(client, monkeypatch):
assert "requests" in resp.headers["Location"] assert "requests" in resp.headers["Location"]
assert session["user_id"] assert session["user_id"]
def test_successful_login_redirect_ccpo(client, monkeypatch): def test_successful_login_redirect_ccpo(client, monkeypatch):
monkeypatch.setattr("atst.domain.authnid.AuthenticationContext.authenticate", lambda *args: True) monkeypatch.setattr(
"atst.domain.authnid.AuthenticationContext.authenticate", lambda *args: True
)
role = Roles.get("ccpo") role = Roles.get("ccpo")
monkeypatch.setattr("atst.domain.authnid.AuthenticationContext.get_user", lambda *args: UserFactory.create(atat_role=role)) monkeypatch.setattr(
"atst.domain.authnid.AuthenticationContext.get_user",
lambda *args: UserFactory.create(atat_role=role),
)
resp = client.get( resp = client.get(
"/login-redirect", "/login-redirect",
@ -114,7 +125,9 @@ def test_crl_validation_on_login(client):
def test_creates_new_user_on_login(monkeypatch, client): def test_creates_new_user_on_login(monkeypatch, client):
monkeypatch.setattr("atst.domain.authnid.AuthenticationContext.authenticate", lambda *args: True) monkeypatch.setattr(
"atst.domain.authnid.AuthenticationContext.authenticate", lambda *args: True
)
cert_file = open("tests/fixtures/{}.crt".format(FIXTURE_EMAIL_ADDRESS)).read() cert_file = open("tests/fixtures/{}.crt".format(FIXTURE_EMAIL_ADDRESS)).read()
# ensure user does not exist # ensure user does not exist

View File

@ -3,15 +3,18 @@ from atst.eda_client import MockEDAClient
client = MockEDAClient() client = MockEDAClient()
def test_list_contracts(): def test_list_contracts():
results = client.list_contracts() results = client.list_contracts()
assert len(results) == 3 assert len(results) == 3
def test_get_contract(): def test_get_contract():
result = client.get_contract("DCA10096D0052", "y") result = client.get_contract("DCA10096D0052", "y")
assert result["contract_no"] == "DCA10096D0052" assert result["contract_no"] == "DCA10096D0052"
assert result["amount"] == 2000000 assert result["amount"] == 2000000
def test_contract_not_found(): def test_contract_not_found():
result = client.get_contract("abc", "y") result = client.get_contract("abc", "y")
assert result is None assert result is None

View File

@ -3,13 +3,15 @@ import pytest
from atst.filters import dollars from atst.filters import dollars
@pytest.mark.parametrize("input,expected", [ @pytest.mark.parametrize(
('0', '$0'), "input,expected",
('123.00', '$123'), [
('1234567', '$1,234,567'), ("0", "$0"),
('-1234', '$-1,234'), ("123.00", "$123"),
('one', '$0'), ("1234567", "$1,234,567"),
]) ("-1234", "$-1,234"),
("one", "$0"),
],
)
def test_dollar_fomatter(input, expected): def test_dollar_fomatter(input, expected):
assert dollars(input) == expected assert dollars(input) == expected

View File

@ -1,6 +1,9 @@
import pytest import pytest
@pytest.mark.parametrize("path", (
@pytest.mark.parametrize(
"path",
(
"/", "/",
"/home", "/home",
"/workspaces", "/workspaces",
@ -9,7 +12,8 @@ import pytest
"/users", "/users",
"/reports", "/reports",
"/calculator", "/calculator",
)) ),
)
def test_routes(path, client, user_session): def test_routes(path, client, user_session):
user_session() user_session()