test_processpool.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License"). You
  4. # may not use this file except in compliance with the License. A copy of
  5. # the License is located at
  6. #
  7. # http://aws.amazon.com/apache2.0/
  8. #
  9. # or in the "license" file accompanying this file. This file is
  10. # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
  11. # ANY KIND, either express or implied. See the License for the specific
  12. # language governing permissions and limitations under the License.
  13. import glob
  14. import os
  15. from multiprocessing.managers import BaseManager
  16. import mock
  17. import botocore.exceptions
  18. import botocore.session
  19. from botocore.stub import Stubber
  20. from __tests__ import unittest
  21. from __tests__ import FileCreator
  22. from s3transfer.compat import six
  23. from s3transfer.exceptions import CancelledError
  24. from s3transfer.processpool import ProcessTransferConfig
  25. from s3transfer.processpool import ProcessPoolDownloader
  26. from s3transfer.processpool import ClientFactory
  27. class StubbedClient(object):
  28. def __init__(self):
  29. self._client = botocore.session.get_session().create_client(
  30. 's3', 'us-west-2', aws_access_key_id='foo',
  31. aws_secret_access_key='bar')
  32. self._stubber = Stubber(self._client)
  33. self._stubber.activate()
  34. self._caught_stubber_errors = []
  35. def get_object(self, **kwargs):
  36. return self._client.get_object(**kwargs)
  37. def head_object(self, **kwargs):
  38. return self._client.head_object(**kwargs)
  39. def add_response(self, *args, **kwargs):
  40. self._stubber.add_response(*args, **kwargs)
  41. def add_client_error(self, *args, **kwargs):
  42. self._stubber.add_client_error(*args, **kwargs)
  43. class StubbedClientManager(BaseManager):
  44. pass
  45. StubbedClientManager.register('StubbedClient', StubbedClient)
  46. # Ideally a Mock would be used here. However, they cannot be pickled
  47. # for Windows. So instead we define a factory class at the module level that
  48. # can return a stubbed client we initialized in the setUp.
  49. class StubbedClientFactory(object):
  50. def __init__(self, stubbed_client):
  51. self._stubbed_client = stubbed_client
  52. def __call__(self, *args, **kwargs):
  53. # The __call__ is defined so we can provide an instance of the
  54. # StubbedClientFactory to mock.patch() and have the instance be
  55. # returned when the patched class is instantiated.
  56. return self
  57. def create_client(self):
  58. return self._stubbed_client
  59. class TestProcessPoolDownloader(unittest.TestCase):
  60. def setUp(self):
  61. # The stubbed client needs to run in a manager to be shared across
  62. # processes and have it properly consume the stubbed response across
  63. # processes.
  64. self.manager = StubbedClientManager()
  65. self.manager.start()
  66. self.stubbed_client = self.manager.StubbedClient()
  67. self.stubbed_client_factory = StubbedClientFactory(self.stubbed_client)
  68. self.client_factory_patch = mock.patch(
  69. 's3transfer.processpool.ClientFactory',
  70. self.stubbed_client_factory
  71. )
  72. self.client_factory_patch.start()
  73. self.files = FileCreator()
  74. self.config = ProcessTransferConfig(
  75. max_request_processes=1
  76. )
  77. self.downloader = ProcessPoolDownloader(config=self.config)
  78. self.bucket = 'mybucket'
  79. self.key = 'mykey'
  80. self.filename = self.files.full_path('filename')
  81. self.remote_contents = b'my content'
  82. self.stream = six.BytesIO(self.remote_contents)
  83. def tearDown(self):
  84. self.manager.shutdown()
  85. self.client_factory_patch.stop()
  86. self.files.remove_all()
  87. def assert_contents(self, filename, expected_contents):
  88. self.assertTrue(os.path.exists(filename))
  89. with open(filename, 'rb') as f:
  90. self.assertEqual(f.read(), expected_contents)
  91. def test_download_file(self):
  92. self.stubbed_client.add_response(
  93. 'head_object', {'ContentLength': len(self.remote_contents)})
  94. self.stubbed_client.add_response(
  95. 'get_object', {'Body': self.stream}
  96. )
  97. with self.downloader:
  98. self.downloader.download_file(self.bucket, self.key, self.filename)
  99. self.assert_contents(self.filename, self.remote_contents)
  100. def test_download_multiple_files(self):
  101. self.stubbed_client.add_response(
  102. 'get_object', {'Body': self.stream}
  103. )
  104. self.stubbed_client.add_response(
  105. 'get_object', {'Body': six.BytesIO(self.remote_contents)}
  106. )
  107. with self.downloader:
  108. self.downloader.download_file(
  109. self.bucket, self.key, self.filename,
  110. expected_size=len(self.remote_contents))
  111. other_file = self.files.full_path('filename2')
  112. self.downloader.download_file(
  113. self.bucket, self.key, other_file,
  114. expected_size=len(self.remote_contents))
  115. self.assert_contents(self.filename, self.remote_contents)
  116. self.assert_contents(other_file, self.remote_contents)
  117. def test_download_file_ranged_download(self):
  118. half_of_content_length = int(len(self.remote_contents)/2)
  119. self.stubbed_client.add_response(
  120. 'head_object', {'ContentLength': len(self.remote_contents)})
  121. self.stubbed_client.add_response(
  122. 'get_object', {
  123. 'Body': six.BytesIO(
  124. self.remote_contents[:half_of_content_length])}
  125. )
  126. self.stubbed_client.add_response(
  127. 'get_object', {
  128. 'Body': six.BytesIO(
  129. self.remote_contents[half_of_content_length:])}
  130. )
  131. downloader = ProcessPoolDownloader(
  132. config=ProcessTransferConfig(
  133. multipart_chunksize=half_of_content_length,
  134. multipart_threshold=half_of_content_length,
  135. max_request_processes=1
  136. )
  137. )
  138. with downloader:
  139. downloader.download_file(self.bucket, self.key, self.filename)
  140. self.assert_contents(self.filename, self.remote_contents)
  141. def test_download_file_extra_args(self):
  142. self.stubbed_client.add_response(
  143. 'head_object', {'ContentLength': len(self.remote_contents)},
  144. expected_params={
  145. 'Bucket': self.bucket, 'Key': self.key,
  146. 'VersionId': 'versionid'
  147. }
  148. )
  149. self.stubbed_client.add_response(
  150. 'get_object', {'Body': self.stream},
  151. expected_params={
  152. 'Bucket': self.bucket, 'Key': self.key,
  153. 'VersionId': 'versionid'
  154. }
  155. )
  156. with self.downloader:
  157. self.downloader.download_file(
  158. self.bucket, self.key, self.filename,
  159. extra_args={'VersionId': 'versionid'}
  160. )
  161. self.assert_contents(self.filename, self.remote_contents)
  162. def test_download_file_expected_size(self):
  163. self.stubbed_client.add_response(
  164. 'get_object', {'Body': self.stream}
  165. )
  166. with self.downloader:
  167. self.downloader.download_file(
  168. self.bucket, self.key, self.filename,
  169. expected_size=len(self.remote_contents))
  170. self.assert_contents(self.filename, self.remote_contents)
  171. def test_cleans_up_tempfile_on_failure(self):
  172. self.stubbed_client.add_client_error('get_object', 'NoSuchKey')
  173. with self.downloader:
  174. self.downloader.download_file(
  175. self.bucket, self.key, self.filename,
  176. expected_size=len(self.remote_contents))
  177. self.assertFalse(os.path.exists(self.filename))
  178. # Any tempfile should have been erased as well
  179. possible_matches = glob.glob('%s*' % self.filename + os.extsep)
  180. self.assertEqual(possible_matches, [])
  181. def test_validates_extra_args(self):
  182. with self.downloader:
  183. with self.assertRaises(ValueError):
  184. self.downloader.download_file(
  185. self.bucket, self.key, self.filename,
  186. extra_args={'NotSupported': 'NotSupported'}
  187. )
  188. def test_result_with_success(self):
  189. self.stubbed_client.add_response(
  190. 'get_object', {'Body': self.stream}
  191. )
  192. with self.downloader:
  193. future = self.downloader.download_file(
  194. self.bucket, self.key, self.filename,
  195. expected_size=len(self.remote_contents))
  196. self.assertIsNone(future.result())
  197. def test_result_with_exception(self):
  198. self.stubbed_client.add_client_error('get_object', 'NoSuchKey')
  199. with self.downloader:
  200. future = self.downloader.download_file(
  201. self.bucket, self.key, self.filename,
  202. expected_size=len(self.remote_contents))
  203. with self.assertRaises(botocore.exceptions.ClientError):
  204. future.result()
  205. def test_result_with_cancel(self):
  206. self.stubbed_client.add_response(
  207. 'get_object', {'Body': self.stream}
  208. )
  209. with self.downloader:
  210. future = self.downloader.download_file(
  211. self.bucket, self.key, self.filename,
  212. expected_size=len(self.remote_contents))
  213. future.cancel()
  214. with self.assertRaises(CancelledError):
  215. future.result()
  216. def test_shutdown_with_no_downloads(self):
  217. downloader = ProcessPoolDownloader()
  218. try:
  219. downloader.shutdown()
  220. except AttributeError:
  221. self.fail(
  222. 'The downloader should be able to be shutdown even though '
  223. 'the downloader was never started.'
  224. )
  225. def test_shutdown_with_no_downloads_and_ctrl_c(self):
  226. # Special shutdown logic happens if a KeyboardInterrupt is raised in
  227. # the context manager. However, this logic can not happen if the
  228. # downloader was never started. So a KeyboardInterrupt should be
  229. # the only exception propagated.
  230. with self.assertRaises(KeyboardInterrupt):
  231. with self.downloader:
  232. raise KeyboardInterrupt()