Refactor to remove Uploader in favor of RackspaceFileProvider

This commit is contained in:
Patrick Smith 2019-01-02 17:12:55 -05:00
parent cd920373a8
commit e432da0d50
6 changed files with 60 additions and 67 deletions

View File

@ -1,7 +1,27 @@
from atst.uploader import Uploader from tempfile import NamedTemporaryFile
from uuid import uuid4
from libcloud.storage.types import Provider
from libcloud.storage.providers import get_driver
from atst.domain.exceptions import UploadError
class FileProviderInterface: class FileProviderInterface:
_PERMITTED_MIMETYPES = ["application/pdf"]
def _enforce_mimetype(self, fyle):
# TODO: for hardening, we should probably use a better library for
# determining mimetype and not rely on FileUpload's determination
# TODO: we should set MAX_CONTENT_LENGTH in the config to prevent large
# uploads
if not fyle.mimetype in self._PERMITTED_MIMETYPES:
raise UploadError(
"could not upload {} with mimetype {}".format(
fyle.filename, fyle.mimetype
)
)
def upload(self, fyle): # pragma: no cover def upload(self, fyle): # pragma: no cover
"""Store the file object `fyle` in the CSP. This method returns the """Store the file object `fyle` in the CSP. This method returns the
object name that can be used to later look up the file.""" object name that can be used to later look up the file."""
@ -16,15 +36,36 @@ class FileProviderInterface:
class RackspaceFileProvider(FileProviderInterface): class RackspaceFileProvider(FileProviderInterface):
def __init__(self, app): def __init__(self, app):
self.uploader = Uploader( self.container = self._get_container(
provider=app.config.get("STORAGE_PROVIDER"), provider=app.config.get("STORAGE_PROVIDER"),
container=app.config.get("STORAGE_CONTAINER"), container=app.config.get("STORAGE_CONTAINER"),
key=app.config.get("STORAGE_KEY"), key=app.config.get("STORAGE_KEY"),
secret=app.config.get("STORAGE_SECRET"), secret=app.config.get("STORAGE_SECRET"),
) )
def _get_container(self, provider, container=None, key=None, secret=None):
if provider == "LOCAL": # pragma: no branch
key = container
container = ""
driver = get_driver(getattr(Provider, provider))(key=key, secret=secret)
return driver.get_container(container)
def upload(self, fyle): def upload(self, fyle):
return self.uploader.upload(fyle) self._enforce_mimetype(fyle)
object_name = uuid4().hex
with NamedTemporaryFile() as tempfile:
tempfile.write(fyle.stream.read())
self.container.upload_object(
file_path=tempfile.name,
object_name=object_name,
extra={"acl": "private"},
)
return object_name
def download(self, object_name): def download(self, object_name):
return self.uploader.download_stream(object_name) obj = self.container.get_object(object_name=object_name)
with NamedTemporaryFile() as tempfile:
obj.download(tempfile.name, overwrite_existing=True)
return open(tempfile.name, "rb")

View File

@ -30,3 +30,7 @@ class UnauthenticatedError(Exception):
@property @property
def message(self): def message(self):
return str(self) return str(self)
class UploadError(Exception):
pass

View File

@ -5,8 +5,7 @@ from flask import current_app as app
from atst.models import Base, types, mixins from atst.models import Base, types, mixins
from atst.database import db from atst.database import db
from atst.uploader import UploadError from atst.domain.exceptions import NotFoundError, UploadError
from atst.domain.exceptions import NotFoundError
class AttachmentError(Exception): class AttachmentError(Exception):

View File

@ -1,52 +1,2 @@
from tempfile import NamedTemporaryFile
from uuid import uuid4
from libcloud.storage.types import Provider
from libcloud.storage.providers import get_driver
class UploadError(Exception):
pass
class Uploader:
_PERMITTED_MIMETYPES = ["application/pdf"]
def __init__(self, provider, container=None, key=None, secret=None):
self.container = self._get_container(provider, container, key, secret)
def upload(self, fyle):
# TODO: for hardening, we should probably use a better library for
# determining mimetype and not rely on FileUpload's determination
# TODO: we should set MAX_CONTENT_LENGTH in the config to prevent large
# uploads
if not fyle.mimetype in self._PERMITTED_MIMETYPES:
raise UploadError(
"could not upload {} with mimetype {}".format(
fyle.filename, fyle.mimetype
)
)
object_name = uuid4().hex
with NamedTemporaryFile() as tempfile:
tempfile.write(fyle.stream.read())
self.container.upload_object(
file_path=tempfile.name,
object_name=object_name,
extra={"acl": "private"},
)
return object_name
def download_stream(self, object_name):
obj = self.container.get_object(object_name=object_name)
with NamedTemporaryFile() as tempfile:
obj.download(tempfile.name, overwrite_existing=True)
return open(tempfile.name, "rb")
def _get_container(self, provider, container, key, secret):
if provider == "LOCAL":
key = container
container = ""
driver = get_driver(getattr(Provider, provider))(key=key, secret=secret)
return driver.get_container(container)

View File

@ -3,3 +3,4 @@ ENVIRONMENT = test
PGDATABASE = atat_test PGDATABASE = atat_test
CRL_DIRECTORY = tests/fixtures/crl CRL_DIRECTORY = tests/fixtures/crl
WTF_CSRF_ENABLED = false WTF_CSRF_ENABLED = false
STORAGE_PROVIDER=LOCAL

View File

@ -2,26 +2,23 @@ import os
import pytest import pytest
from werkzeug.datastructures import FileStorage from werkzeug.datastructures import FileStorage
from atst.uploader import Uploader, UploadError from atst.domain.csp.files import RackspaceFileProvider
from atst.domain.exceptions import UploadError
from tests.mocks import PDF_FILENAME from tests.mocks import PDF_FILENAME
@pytest.fixture(scope="function")
def upload_dir(tmpdir):
return tmpdir.mkdir("uploads")
@pytest.fixture @pytest.fixture
def uploader(upload_dir): def uploader(app):
return Uploader("LOCAL", container=upload_dir) return RackspaceFileProvider(app)
NONPDF_FILENAME = "tests/fixtures/disa-pki.html" NONPDF_FILENAME = "tests/fixtures/disa-pki.html"
def test_upload(uploader, upload_dir, pdf_upload): def test_upload(app, uploader, pdf_upload):
object_name = uploader.upload(pdf_upload) object_name = uploader.upload(pdf_upload)
upload_dir = app.config["STORAGE_CONTAINER"]
assert os.path.isfile(os.path.join(upload_dir, object_name)) assert os.path.isfile(os.path.join(upload_dir, object_name))
@ -32,17 +29,18 @@ def test_upload_fails_for_non_pdfs(uploader):
uploader.upload(fs) uploader.upload(fs)
def test_download_stream(upload_dir, uploader, pdf_upload): def test_download(app, uploader, pdf_upload):
# write pdf content to upload file storage and make sure it is flushed to # write pdf content to upload file storage and make sure it is flushed to
# disk # disk
pdf_upload.seek(0) pdf_upload.seek(0)
pdf_content = pdf_upload.read() pdf_content = pdf_upload.read()
pdf_upload.close() pdf_upload.close()
upload_dir = app.config["STORAGE_CONTAINER"]
full_path = os.path.join(upload_dir, "abc") full_path = os.path.join(upload_dir, "abc")
with open(full_path, "wb") as output_file: with open(full_path, "wb") as output_file:
output_file.write(pdf_content) output_file.write(pdf_content)
output_file.flush() output_file.flush()
stream = uploader.download_stream("abc") stream = uploader.download("abc")
stream_content = b"".join([b for b in stream]) stream_content = b"".join([b for b in stream])
assert pdf_content == stream_content assert pdf_content == stream_content