# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License'). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the 'license' file accompanying this file. This file is
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import io
import hashlib
import math
import os
import platform
import shutil
import string
import tempfile

try:
    import unittest2 as unittest
except ImportError:
    import unittest

import botocore.session
from botocore.stub import Stubber
from botocore.compat import six

from s3transfer.manager import TransferConfig
from s3transfer.futures import IN_MEMORY_UPLOAD_TAG
from s3transfer.futures import IN_MEMORY_DOWNLOAD_TAG
from s3transfer.futures import TransferCoordinator
from s3transfer.futures import TransferMeta
from s3transfer.futures import TransferFuture
from s3transfer.futures import BoundedExecutor
from s3transfer.futures import NonThreadedExecutor
from s3transfer.subscribers import BaseSubscriber
from s3transfer.utils import OSUtils
from s3transfer.utils import CallArgs
from s3transfer.utils import TaskSemaphore
from s3transfer.utils import SlidingWindowSemaphore


ORIGINAL_EXECUTOR_CLS = BoundedExecutor.EXECUTOR_CLS

# Detect if CRT is available for use
try:
    import awscrt.s3
    HAS_CRT = True
except ImportError:
    HAS_CRT = False


def setup_package():
    if is_serial_implementation():
        BoundedExecutor.EXECUTOR_CLS = NonThreadedExecutor


def teardown_package():
    BoundedExecutor.EXECUTOR_CLS = ORIGINAL_EXECUTOR_CLS


def is_serial_implementation():
    return os.environ.get('USE_SERIAL_EXECUTOR', False)
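

# Note: the switch above is driven entirely by the USE_SERIAL_EXECUTOR
# environment variable; any non-empty value is treated as truthy here. An
# illustrative (not prescriptive) invocation:
#
#     USE_SERIAL_EXECUTOR=True <test runner command>
#
# which makes setup_package() swap BoundedExecutor's executor class for the
# single-threaded NonThreadedExecutor for the duration of the test run.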


def assert_files_equal(first, second):
    if os.path.getsize(first) != os.path.getsize(second):
        raise AssertionError("Files are not equal: %s, %s" % (first, second))
    first_md5 = md5_checksum(first)
    second_md5 = md5_checksum(second)
    if first_md5 != second_md5:
        raise AssertionError(
            "Files are not equal: %s(md5=%s) != %s(md5=%s)" % (
                first, first_md5, second, second_md5))


def md5_checksum(filename):
    checksum = hashlib.md5()
    with open(filename, 'rb') as f:
        for chunk in iter(lambda: f.read(8192), b''):
            checksum.update(chunk)
    return checksum.hexdigest()


def random_bucket_name(prefix='s3transfer', num_chars=10):
    base = string.ascii_lowercase + string.digits
    random_bytes = bytearray(os.urandom(num_chars))
    return prefix + ''.join([base[b % len(base)] for b in random_bytes])


def skip_if_windows(reason):
    """Decorator to skip tests that should not be run on windows.

    Example usage:

        @skip_if_windows("Not valid")
        def test_some_non_windows_stuff(self):
            self.assertEqual(...)

    """
    def decorator(func):
        return unittest.skipIf(
            platform.system() not in ['Darwin', 'Linux'], reason)(func)
    return decorator


def skip_if_using_serial_implementation(reason):
    """Decorator to skip tests when running as the serial implementation"""
    def decorator(func):
        return unittest.skipIf(
            is_serial_implementation(), reason)(func)
    return decorator


def requires_crt(cls, reason=None):
    if reason is None:
        reason = "Test requires awscrt to be installed."
    return unittest.skipIf(not HAS_CRT, reason)(cls)
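

# Illustrative sketch of how the skip decorators above compose on a test
# case (the class name, test name, and reasons are made up for the example):
#
#     @requires_crt
#     class TestCrtTransfers(unittest.TestCase):
#         @skip_if_windows("Relies on POSIX file permissions")
#         @skip_if_using_serial_implementation("Relies on threaded execution")
#         def test_something(self):
#             ...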


class StreamWithError(object):
    """A wrapper to simulate errors while reading from a stream

    :param stream: The underlying stream to read from
    :param exception_type: The exception type to throw
    :param num_reads: The number of times to allow a read before raising
        the exception. A value of zero indicates to raise the error on the
        first read.
    """
    def __init__(self, stream, exception_type, num_reads=0):
        self._stream = stream
        self._exception_type = exception_type
        self._num_reads = num_reads
        self._count = 0

    def read(self, n=-1):
        if self._count == self._num_reads:
            raise self._exception_type
        self._count += 1
        return self._stream.read(n)
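

# Illustrative usage of StreamWithError (a sketch, not taken from a test):
#
#     stream = StreamWithError(six.BytesIO(b'foo bar baz'), ValueError,
#                              num_reads=1)
#     stream.read(4)  # returns b'foo '
#     stream.read(4)  # raises ValueError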


class FileSizeProvider(object):
    def __init__(self, file_size):
        self.file_size = file_size

    def on_queued(self, future, **kwargs):
        future.meta.provide_transfer_size(self.file_size)
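

# FileSizeProvider quacks like a subscriber: passing it in a transfer call's
# subscribers list pre-populates the future's transfer size so it does not
# have to be discovered. A sketch (manager, bucket, and key are placeholders):
#
#     manager.upload(filename, bucket, key,
#                    subscribers=[FileSizeProvider(file_size=1024)])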


class FileCreator(object):
    def __init__(self):
        self.rootdir = tempfile.mkdtemp()

    def remove_all(self):
        shutil.rmtree(self.rootdir)

    def create_file(self, filename, contents, mode='w'):
        """Creates a file in a tmpdir

        ``filename`` should be a relative path, e.g. "foo/bar/baz.txt"
        It will be translated into a full path in a tmp dir.

        ``mode`` is the mode the file should be opened in, either ``w`` or
        ``wb``.

        Returns the full path to the file.
        """
        full_path = os.path.join(self.rootdir, filename)
        if not os.path.isdir(os.path.dirname(full_path)):
            os.makedirs(os.path.dirname(full_path))
        with open(full_path, mode) as f:
            f.write(contents)
        return full_path

    def create_file_with_size(self, filename, filesize):
        filename = self.create_file(filename, contents='')
        chunksize = 8192
        with open(filename, 'wb') as f:
            for i in range(int(math.ceil(filesize / float(chunksize)))):
                f.write(b'a' * chunksize)
        return filename

    def append_file(self, filename, contents):
        """Append contents to a file

        ``filename`` should be a relative path, e.g. "foo/bar/baz.txt"
        It will be translated into a full path in a tmp dir.

        Returns the full path to the file.
        """
        full_path = os.path.join(self.rootdir, filename)
        if not os.path.isdir(os.path.dirname(full_path)):
            os.makedirs(os.path.dirname(full_path))
        with open(full_path, 'a') as f:
            f.write(contents)
        return full_path

    def full_path(self, filename):
        """Translate relative path to full path in temp dir.

        f.full_path('foo/bar.txt') -> /tmp/asdfasd/foo/bar.txt
        """
        return os.path.join(self.rootdir, filename)
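

# Illustrative FileCreator usage (a sketch; the paths and sizes are arbitrary):
#
#     files = FileCreator()
#     path = files.create_file('foo/bar.txt', 'contents')
#     big_path = files.create_file_with_size('big.bin', filesize=1024 * 1024)
#     files.remove_all()  # clean up the temporary directory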


class RecordingOSUtils(OSUtils):
    """An OSUtil abstraction that records openings and renamings"""
    def __init__(self):
        super(RecordingOSUtils, self).__init__()
        self.open_records = []
        self.rename_records = []

    def open(self, filename, mode):
        self.open_records.append((filename, mode))
        return super(RecordingOSUtils, self).open(filename, mode)

    def rename_file(self, current_filename, new_filename):
        self.rename_records.append((current_filename, new_filename))
        super(RecordingOSUtils, self).rename_file(
            current_filename, new_filename)


class RecordingSubscriber(BaseSubscriber):
    def __init__(self):
        self.on_queued_calls = []
        self.on_progress_calls = []
        self.on_done_calls = []

    def on_queued(self, **kwargs):
        self.on_queued_calls.append(kwargs)

    def on_progress(self, **kwargs):
        self.on_progress_calls.append(kwargs)

    def on_done(self, **kwargs):
        self.on_done_calls.append(kwargs)

    def calculate_bytes_seen(self, **kwargs):
        amount_seen = 0
        for call in self.on_progress_calls:
            amount_seen += call['bytes_transferred']
        return amount_seen
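

# Illustrative RecordingSubscriber usage (a sketch; manager, bucket, key, and
# expected_size are placeholders):
#
#     subscriber = RecordingSubscriber()
#     future = manager.upload(filename, bucket, key,
#                             subscribers=[subscriber])
#     future.result()
#     assert subscriber.calculate_bytes_seen() == expected_size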


class TransferCoordinatorWithInterrupt(TransferCoordinator):
    """Used to inject keyboard interrupts"""
    def result(self):
        raise KeyboardInterrupt()


class RecordingExecutor(object):
    """A wrapper on an executor to record calls made to submit()

    You can access the submissions property to receive a list of dictionaries
    that represents all submissions where each dictionary is formatted::

        {
            'task': the submitted task,
            'tag': the tag passed to submit() (or None),
            'block': the value of the ``block`` argument
        }
    """
    def __init__(self, executor):
        self._executor = executor
        self.submissions = []

    def submit(self, task, tag=None, block=True):
        future = self._executor.submit(task, tag, block)
        self.submissions.append(
            {
                'task': task,
                'tag': tag,
                'block': block
            }
        )
        return future

    def shutdown(self):
        self._executor.shutdown()


class StubbedClientTest(unittest.TestCase):
    def setUp(self):
        self.session = botocore.session.get_session()
        self.region = 'us-west-2'
        self.client = self.session.create_client(
            's3', self.region, aws_access_key_id='foo',
            aws_secret_access_key='bar')
        self.stubber = Stubber(self.client)
        self.stubber.activate()

    def tearDown(self):
        self.stubber.deactivate()

    def reset_stubber_with_new_client(self, override_client_kwargs):
        client_kwargs = {
            'service_name': 's3',
            'region_name': self.region,
            'aws_access_key_id': 'foo',
            'aws_secret_access_key': 'bar'
        }
        client_kwargs.update(override_client_kwargs)
        self.client = self.session.create_client(**client_kwargs)
        self.stubber = Stubber(self.client)
        self.stubber.activate()
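

# Illustrative subclass of StubbedClientTest (a sketch; the operation and
# response values are made up for the example):
#
#     class TestHeadObject(StubbedClientTest):
#         def test_head_object(self):
#             self.stubber.add_response('head_object', {'ContentLength': 5})
#             response = self.client.head_object(Bucket='bucket', Key='key')
#             self.assertEqual(response['ContentLength'], 5)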


class BaseTaskTest(StubbedClientTest):
    def setUp(self):
        super(BaseTaskTest, self).setUp()
        self.transfer_coordinator = TransferCoordinator()

    def get_task(self, task_cls, **kwargs):
        if 'transfer_coordinator' not in kwargs:
            kwargs['transfer_coordinator'] = self.transfer_coordinator
        return task_cls(**kwargs)

    def get_transfer_future(self, call_args=None):
        return TransferFuture(
            meta=TransferMeta(call_args),
            coordinator=self.transfer_coordinator
        )


class BaseSubmissionTaskTest(BaseTaskTest):
    def setUp(self):
        super(BaseSubmissionTaskTest, self).setUp()
        self.config = TransferConfig()
        self.osutil = OSUtils()
        self.executor = BoundedExecutor(
            1000,
            1,
            {
                IN_MEMORY_UPLOAD_TAG: TaskSemaphore(10),
                IN_MEMORY_DOWNLOAD_TAG: SlidingWindowSemaphore(10)
            }
        )

    def tearDown(self):
        super(BaseSubmissionTaskTest, self).tearDown()
        self.executor.shutdown()


class BaseGeneralInterfaceTest(StubbedClientTest):
    """A general test class to ensure consistency across TransferManager methods

    This class is never run directly; it is meant to be subclassed so that the
    subclass picks up the various tests that every TransferManager method must
    pass from a functionality standpoint.
    """
    __test__ = False

    def manager(self):
        """The transfer manager to use"""
        raise NotImplementedError('method is not implemented')

    @property
    def method(self):
        """The transfer manager method to invoke i.e. upload()"""
        raise NotImplementedError('method is not implemented')

    def create_call_kwargs(self):
        """The kwargs to be passed to the transfer manager method"""
        raise NotImplementedError('create_call_kwargs is not implemented')

    def create_invalid_extra_args(self):
        """A value for extra_args that will cause validation errors"""
        raise NotImplementedError(
            'create_invalid_extra_args is not implemented')

    def create_stubbed_responses(self):
        """A list of stubbed responses that will cause the request to succeed

        Each element of this list is a dictionary that will be used as the
        keyword arguments to botocore.Stubber.add_response(). For example::

            [{'method': 'put_object', 'service_response': {}}]
        """
        raise NotImplementedError(
            'create_stubbed_responses is not implemented')

    def create_expected_progress_callback_info(self):
        """A list of kwargs expected to be passed to each progress callback

        Note that the ``future`` kwarg does not need to be added to each
        dictionary provided in the list as it is injected for you. An example
        is::

            [
                {'bytes_transferred': 4},
                {'bytes_transferred': 4},
                {'bytes_transferred': 2}
            ]

        This indicates that the progress callback will be called three
        times and pass along the specified keyword arguments and
        corresponding values.
        """
        raise NotImplementedError(
            'create_expected_progress_callback_info is not implemented')

    def _setup_default_stubbed_responses(self):
        for stubbed_response in self.create_stubbed_responses():
            self.stubber.add_response(**stubbed_response)

    def test_returns_future_with_meta(self):
        self._setup_default_stubbed_responses()
        future = self.method(**self.create_call_kwargs())
        # result() is called so that the entire process executes before we
        # try to clean up resources in the tearDown.
        future.result()

        # Assert the return value is a future with metadata associated to it.
        self.assertIsInstance(future, TransferFuture)
        self.assertIsInstance(future.meta, TransferMeta)

    def test_returns_correct_call_args(self):
        self._setup_default_stubbed_responses()
        call_kwargs = self.create_call_kwargs()
        future = self.method(**call_kwargs)
        # result() is called so that the entire process executes before we
        # try to clean up resources in the tearDown.
        future.result()

        # Assert that there are call args associated to the metadata
        self.assertIsInstance(future.meta.call_args, CallArgs)
        # Assert that all of the arguments passed to the method exist and
        # are of the correct value in call_args.
        for param, value in call_kwargs.items():
            self.assertEqual(value, getattr(future.meta.call_args, param))

    def test_has_transfer_id_associated_to_future(self):
        self._setup_default_stubbed_responses()
        call_kwargs = self.create_call_kwargs()
        future = self.method(**call_kwargs)
        # result() is called so that the entire process executes before we
        # try to clean up resources in the tearDown.
        future.result()

        # Assert that a transfer id was associated to the future. Since this
        # is the first transfer request made through this transfer manager,
        # the id will be zero.
        self.assertEqual(future.meta.transfer_id, 0)

        # If we make a second request, the transfer id should have
        # incremented by one for that new TransferFuture.
        self._setup_default_stubbed_responses()
        future = self.method(**call_kwargs)
        future.result()
        self.assertEqual(future.meta.transfer_id, 1)

    def test_invalid_extra_args(self):
        with self.assertRaisesRegexp(ValueError, 'Invalid extra_args'):
            self.method(
                extra_args=self.create_invalid_extra_args(),
                **self.create_call_kwargs()
            )

    def test_for_callback_kwargs_correctness(self):
        # Add the stubbed responses before invoking the method
        self._setup_default_stubbed_responses()

        subscriber = RecordingSubscriber()
        future = self.method(
            subscribers=[subscriber], **self.create_call_kwargs())
        # We call shutdown instead of result on the future because the future
        # could be finished but the done callback could still be going.
        # The manager's shutdown method ensures everything completes.
        self.manager.shutdown()

        # Assert the various subscribers were called with the
        # expected kwargs
        expected_progress_calls = self.create_expected_progress_callback_info()
        for expected_progress_call in expected_progress_calls:
            expected_progress_call['future'] = future

        self.assertEqual(subscriber.on_queued_calls, [{'future': future}])
        self.assertEqual(subscriber.on_progress_calls, expected_progress_calls)
        self.assertEqual(subscriber.on_done_calls, [{'future': future}])
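

# Illustrative sketch of a BaseGeneralInterfaceTest subclass. The class name,
# attributes, and values are placeholders; real subclasses build a
# TransferManager in setUp() and expose it via the ``manager`` accessor:
#
#     class TestUploadInterface(BaseGeneralInterfaceTest):
#         __test__ = True
#
#         @property
#         def manager(self):
#             return self._manager  # a TransferManager built in setUp()
#
#         @property
#         def method(self):
#             return self.manager.upload
#
#         def create_call_kwargs(self):
#             return {'fileobj': self.filename, 'bucket': 'bucket',
#                     'key': 'key'}
#
#         def create_invalid_extra_args(self):
#             return {'Foo': 'bar'}
#
#         def create_stubbed_responses(self):
#             return [{'method': 'put_object', 'service_response': {}}]
#
#         def create_expected_progress_callback_info(self):
#             return [{'bytes_transferred': len(self.content)}]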


class NonSeekableReader(io.RawIOBase):
    def __init__(self, b=b''):
        super(NonSeekableReader, self).__init__()
        self._data = six.BytesIO(b)

    def seekable(self):
        return False

    def writable(self):
        return False

    def readable(self):
        return True

    def write(self, b):
        # This is needed because Python will not always raise the correct
        # kind of error even though writable() returns False.
        raise io.UnsupportedOperation("write")

    def read(self, n=-1):
        return self._data.read(n)


class NonSeekableWriter(io.RawIOBase):
    def __init__(self, fileobj):
        super(NonSeekableWriter, self).__init__()
        self._fileobj = fileobj

    def seekable(self):
        return False

    def writable(self):
        return True

    def readable(self):
        return False

    def write(self, b):
        self._fileobj.write(b)

    def read(self, n=-1):
        raise io.UnsupportedOperation("read")
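

# Illustrative usage of the non-seekable stream wrappers (a sketch; manager,
# bucket, key, and path are placeholders):
#
#     manager.upload(NonSeekableReader(b'my body'), bucket, key)
#
#     with open(path, 'wb') as f:
#         manager.download(bucket, key, NonSeekableWriter(f))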