diff --git a/atst/domain/pe_numbers.py b/atst/domain/pe_numbers.py index 07d79585..a80e10ef 100644 --- a/atst/domain/pe_numbers.py +++ b/atst/domain/pe_numbers.py @@ -1,3 +1,5 @@ +from sqlalchemy.dialects.postgresql import insert + from atst.models.pe_number import PENumber from .exceptions import NotFoundError @@ -13,3 +15,12 @@ class PENumbers(object): raise NotFoundError("pe_number") return pe_number + + def create_many(self, list_of_pe_numbers): + stmt = insert(PENumber).values(list_of_pe_numbers) + do_update = stmt.on_conflict_do_update( + index_elements=["number"], + set_=dict(description=stmt.excluded.description) + ) + self.db_session.execute(do_update) + self.db_session.commit() diff --git a/script/ingest_pe_numbers.py b/script/ingest_pe_numbers.py index 1df5b132..e3bbecb0 100644 --- a/script/ingest_pe_numbers.py +++ b/script/ingest_pe_numbers.py @@ -1,8 +1,6 @@ from urllib.request import urlopen import csv -from sqlalchemy.dialects.postgresql import insert - # Add root project dir to the python path import os import sys @@ -10,7 +8,7 @@ parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) sys.path.append(parent_dir) from atst.app import make_deps, make_config -from atst.models import PENumber +from atst.domain.pe_numbers import PENumbers def get_pe_numbers(url): @@ -18,23 +16,16 @@ def get_pe_numbers(url): t = response.read().decode("utf-8") return list(csv.reader(t.split("\r\n"))) - -def insert_pe_numbers(db, list_of_pe_numbers): - stmt = insert(PENumber).values(list_of_pe_numbers) - do_update = stmt.on_conflict_do_update( - index_elements=["number"], - set_=dict(description=stmt.excluded.description) - ) - db.execute(do_update) - db.commit() - +def make_pe_number_repo(config): + deps = make_deps(config) + db = deps["db_session"] + return PENumbers(db) if __name__ == "__main__": config = make_config() - deps = make_deps(config) - db = deps["db_session"] url = config["default"]['PE_NUMBER_CSV_URL'] print("Fetching PE numbers from {}".format(url)) pe_numbers = get_pe_numbers(url) print("Inserting {} PE numbers".format(len(pe_numbers))) - insert_pe_numbers(db, pe_numbers) + pe_numbers_repo = make_pe_number_repo(config) + pe_numbers_repo.create_many(pe_numbers) diff --git a/tests/domain/test_pe_numbers.py b/tests/domain/test_pe_numbers.py index 332ae688..028aa712 100644 --- a/tests/domain/test_pe_numbers.py +++ b/tests/domain/test_pe_numbers.py @@ -32,3 +32,10 @@ def test_can_get_pe_number(pe_numbers, new_pe_number): def test_nonexistent_pe_number_raises(pe_numbers): with pytest.raises(NotFoundError): pe_numbers.get("some fake number") + +def test_create_many(pe_numbers): + pen_list = [['123456', 'Land Speeder'], ['7891011', 'Lightsaber']] + pe_numbers.create_many(pen_list) + + assert pe_numbers.get(pen_list[0][0]) + assert pe_numbers.get(pen_list[1][0])