diff --git a/atst/domain/csp/files.py b/atst/domain/csp/files.py index d1fa07ae..9fbb545e 100644 --- a/atst/domain/csp/files.py +++ b/atst/domain/csp/files.py @@ -1,7 +1,27 @@ -from atst.uploader import Uploader +from tempfile import NamedTemporaryFile +from uuid import uuid4 + +from libcloud.storage.types import Provider +from libcloud.storage.providers import get_driver + +from atst.domain.exceptions import UploadError class FileProviderInterface: + _PERMITTED_MIMETYPES = ["application/pdf"] + + def _enforce_mimetype(self, fyle): + # TODO: for hardening, we should probably use a better library for + # determining mimetype and not rely on FileUpload's determination + # TODO: we should set MAX_CONTENT_LENGTH in the config to prevent large + # uploads + if not fyle.mimetype in self._PERMITTED_MIMETYPES: + raise UploadError( + "could not upload {} with mimetype {}".format( + fyle.filename, fyle.mimetype + ) + ) + def upload(self, fyle): # pragma: no cover """Store the file object `fyle` in the CSP. This method returns the object name that can be used to later look up the file.""" @@ -16,15 +36,36 @@ class FileProviderInterface: class RackspaceFileProvider(FileProviderInterface): def __init__(self, app): - self.uploader = Uploader( + self.container = self._get_container( provider=app.config.get("STORAGE_PROVIDER"), container=app.config.get("STORAGE_CONTAINER"), key=app.config.get("STORAGE_KEY"), secret=app.config.get("STORAGE_SECRET"), ) + def _get_container(self, provider, container=None, key=None, secret=None): + if provider == "LOCAL": # pragma: no branch + key = container + container = "" + + driver = get_driver(getattr(Provider, provider))(key=key, secret=secret) + return driver.get_container(container) + def upload(self, fyle): - return self.uploader.upload(fyle) + self._enforce_mimetype(fyle) + + object_name = uuid4().hex + with NamedTemporaryFile() as tempfile: + tempfile.write(fyle.stream.read()) + self.container.upload_object( + file_path=tempfile.name, + object_name=object_name, + extra={"acl": "private"}, + ) + return object_name def download(self, object_name): - return self.uploader.download_stream(object_name) + obj = self.container.get_object(object_name=object_name) + with NamedTemporaryFile() as tempfile: + obj.download(tempfile.name, overwrite_existing=True) + return open(tempfile.name, "rb") diff --git a/atst/domain/exceptions.py b/atst/domain/exceptions.py index ec574232..bad0d4b3 100644 --- a/atst/domain/exceptions.py +++ b/atst/domain/exceptions.py @@ -30,3 +30,7 @@ class UnauthenticatedError(Exception): @property def message(self): return str(self) + + +class UploadError(Exception): + pass diff --git a/atst/models/attachment.py b/atst/models/attachment.py index 2dee056f..e4a9d6c2 100644 --- a/atst/models/attachment.py +++ b/atst/models/attachment.py @@ -5,8 +5,7 @@ from flask import current_app as app from atst.models import Base, types, mixins from atst.database import db -from atst.uploader import UploadError -from atst.domain.exceptions import NotFoundError +from atst.domain.exceptions import NotFoundError, UploadError class AttachmentError(Exception): diff --git a/atst/uploader.py b/atst/uploader.py index 5e44e2ee..139597f9 100644 --- a/atst/uploader.py +++ b/atst/uploader.py @@ -1,52 +1,2 @@ -from tempfile import NamedTemporaryFile -from uuid import uuid4 - -from libcloud.storage.types import Provider -from libcloud.storage.providers import get_driver -class UploadError(Exception): - pass - - -class Uploader: - _PERMITTED_MIMETYPES = ["application/pdf"] - - def __init__(self, provider, container=None, key=None, secret=None): - self.container = self._get_container(provider, container, key, secret) - - def upload(self, fyle): - # TODO: for hardening, we should probably use a better library for - # determining mimetype and not rely on FileUpload's determination - # TODO: we should set MAX_CONTENT_LENGTH in the config to prevent large - # uploads - if not fyle.mimetype in self._PERMITTED_MIMETYPES: - raise UploadError( - "could not upload {} with mimetype {}".format( - fyle.filename, fyle.mimetype - ) - ) - - object_name = uuid4().hex - with NamedTemporaryFile() as tempfile: - tempfile.write(fyle.stream.read()) - self.container.upload_object( - file_path=tempfile.name, - object_name=object_name, - extra={"acl": "private"}, - ) - return object_name - - def download_stream(self, object_name): - obj = self.container.get_object(object_name=object_name) - with NamedTemporaryFile() as tempfile: - obj.download(tempfile.name, overwrite_existing=True) - return open(tempfile.name, "rb") - - def _get_container(self, provider, container, key, secret): - if provider == "LOCAL": - key = container - container = "" - - driver = get_driver(getattr(Provider, provider))(key=key, secret=secret) - return driver.get_container(container) diff --git a/config/test.ini b/config/test.ini index fcb43e48..3da77886 100644 --- a/config/test.ini +++ b/config/test.ini @@ -3,3 +3,4 @@ ENVIRONMENT = test PGDATABASE = atat_test CRL_DIRECTORY = tests/fixtures/crl WTF_CSRF_ENABLED = false +STORAGE_PROVIDER=LOCAL diff --git a/tests/test_uploader.py b/tests/domain/csp/test_files.py similarity index 70% rename from tests/test_uploader.py rename to tests/domain/csp/test_files.py index 1e8b86ba..0f50cd11 100644 --- a/tests/test_uploader.py +++ b/tests/domain/csp/test_files.py @@ -2,26 +2,23 @@ import os import pytest from werkzeug.datastructures import FileStorage -from atst.uploader import Uploader, UploadError +from atst.domain.csp.files import RackspaceFileProvider +from atst.domain.exceptions import UploadError from tests.mocks import PDF_FILENAME -@pytest.fixture(scope="function") -def upload_dir(tmpdir): - return tmpdir.mkdir("uploads") - - @pytest.fixture -def uploader(upload_dir): - return Uploader("LOCAL", container=upload_dir) +def uploader(app): + return RackspaceFileProvider(app) NONPDF_FILENAME = "tests/fixtures/disa-pki.html" -def test_upload(uploader, upload_dir, pdf_upload): +def test_upload(app, uploader, pdf_upload): object_name = uploader.upload(pdf_upload) + upload_dir = app.config["STORAGE_CONTAINER"] assert os.path.isfile(os.path.join(upload_dir, object_name)) @@ -32,17 +29,18 @@ def test_upload_fails_for_non_pdfs(uploader): uploader.upload(fs) -def test_download_stream(upload_dir, uploader, pdf_upload): +def test_download(app, uploader, pdf_upload): # write pdf content to upload file storage and make sure it is flushed to # disk pdf_upload.seek(0) pdf_content = pdf_upload.read() pdf_upload.close() + upload_dir = app.config["STORAGE_CONTAINER"] full_path = os.path.join(upload_dir, "abc") with open(full_path, "wb") as output_file: output_file.write(pdf_content) output_file.flush() - stream = uploader.download_stream("abc") + stream = uploader.download("abc") stream_content = b"".join([b for b in stream]) assert pdf_content == stream_content