test_crt.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. # Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License"). You
  4. # may not use this file except in compliance with the License. A copy of
  5. # the License is located at
  6. #
  7. # http://aws.amazon.com/apache2.0/
  8. #
  9. # or in the "license" file accompanying this file. This file is
  10. # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
  11. # ANY KIND, either express or implied. See the License for the specific
  12. # language governing permissions and limitations under the License.
  13. import mock
  14. import unittest
  15. import threading
  16. import re
  17. from concurrent.futures import Future
  18. from botocore.session import Session
  19. from s3transfer.subscribers import BaseSubscriber
  20. from __tests__ import FileCreator
  21. from __tests__ import requires_crt, HAS_CRT
  22. if HAS_CRT:
  23. import s3transfer.crt
  24. import awscrt
  25. class submitThread(threading.Thread):
  26. def __init__(self, transfer_manager, futures, callargs):
  27. threading.Thread.__init__(self)
  28. self._transfer_manager = transfer_manager
  29. self._futures = futures
  30. self._callargs = callargs
  31. def run(self):
  32. self._futures.append(self._transfer_manager.download(*self._callargs))
  33. class RecordingSubscriber(BaseSubscriber):
  34. def __init__(self):
  35. self.on_queued_called = False
  36. self.on_done_called = False
  37. self.bytes_transferred = 0
  38. self.on_queued_future = None
  39. self.on_done_future = None
  40. def on_queued(self, future, **kwargs):
  41. self.on_queued_called = True
  42. self.on_queued_future = future
  43. def on_done(self, future, **kwargs):
  44. self.on_done_called = True
  45. self.on_done_future = future
  46. @requires_crt
  47. class TestCRTTransferManager(unittest.TestCase):
  48. def setUp(self):
  49. self.region = 'us-west-2'
  50. self.bucket = "test_bucket"
  51. self.key = "test_key"
  52. self.files = FileCreator()
  53. self.filename = self.files.create_file('myfile', 'my content')
  54. self.expected_path = "/" + self.bucket + "/" + self.key
  55. self.expected_host = "s3.%s.amazonaws.com" % (self.region)
  56. self.s3_request = mock.Mock(awscrt.s3.S3Request)
  57. self.s3_crt_client = mock.Mock(awscrt.s3.S3Client)
  58. self.s3_crt_client.make_request.return_value = self.s3_request
  59. self.session = Session()
  60. self.session.set_config_variable('region', self.region)
  61. self.request_serializer = s3transfer.crt.BotocoreCRTRequestSerializer(
  62. self.session)
  63. self.transfer_manager = s3transfer.crt.CRTTransferManager(
  64. crt_s3_client=self.s3_crt_client,
  65. crt_request_serializer=self.request_serializer)
  66. self.record_subscriber = RecordingSubscriber()
  67. def tearDown(self):
  68. self.files.remove_all()
  69. def _assert_subscribers_called(self, expected_future=None):
  70. self.assertTrue(self.record_subscriber.on_queued_called)
  71. self.assertTrue(self.record_subscriber.on_done_called)
  72. if expected_future:
  73. self.assertIs(
  74. self.record_subscriber.on_queued_future,
  75. expected_future)
  76. self.assertIs(
  77. self.record_subscriber.on_done_future,
  78. expected_future)
  79. def _invoke_done_callbacks(self, **kwargs):
  80. callargs = self.s3_crt_client.make_request.call_args
  81. callargs_kwargs = callargs[1]
  82. on_done = callargs_kwargs["on_done"]
  83. on_done(error=None)
  84. def _simulate_file_download(self, recv_filepath):
  85. self.files.create_file(recv_filepath, "fake resopnse")
  86. def _simulate_make_request_side_effect(self, **kwargs):
  87. if kwargs.get('recv_filepath'):
  88. self._simulate_file_download(kwargs['recv_filepath'])
  89. self._invoke_done_callbacks()
  90. return mock.DEFAULT
  91. def test_upload(self):
  92. self.s3_crt_client.make_request.side_effect = self._simulate_make_request_side_effect
  93. future = self.transfer_manager.upload(
  94. self.filename, self.bucket, self.key, {}, [self.record_subscriber])
  95. future.result()
  96. callargs = self.s3_crt_client.make_request.call_args
  97. callargs_kwargs = callargs[1]
  98. self.assertEqual(callargs_kwargs["send_filepath"], self.filename)
  99. self.assertIsNone(callargs_kwargs["recv_filepath"])
  100. self.assertEqual(callargs_kwargs["type"],
  101. awscrt.s3.S3RequestType.PUT_OBJECT)
  102. crt_request = callargs_kwargs["request"]
  103. self.assertEqual("PUT", crt_request.method)
  104. self.assertEqual(self.expected_path, crt_request.path)
  105. self.assertEqual(self.expected_host, crt_request.headers.get("host"))
  106. self._assert_subscribers_called(future)
  107. def test_download(self):
  108. self.s3_crt_client.make_request.side_effect = self._simulate_make_request_side_effect
  109. future = self.transfer_manager.download(
  110. self.bucket, self.key, self.filename, {}, [self.record_subscriber])
  111. future.result()
  112. callargs = self.s3_crt_client.make_request.call_args
  113. callargs_kwargs = callargs[1]
  114. # the recv_filepath will be set to a temporary file path with some
  115. # random suffix
  116. self.assertTrue(re.match(self.filename + ".*",
  117. callargs_kwargs["recv_filepath"]))
  118. self.assertIsNone(callargs_kwargs["send_filepath"])
  119. self.assertEqual(callargs_kwargs["type"],
  120. awscrt.s3.S3RequestType.GET_OBJECT)
  121. crt_request = callargs_kwargs["request"]
  122. self.assertEqual("GET", crt_request.method)
  123. self.assertEqual(self.expected_path, crt_request.path)
  124. self.assertEqual(self.expected_host, crt_request.headers.get("host"))
  125. self._assert_subscribers_called(future)
  126. with open(self.filename, 'rb') as f:
  127. # Check the fake response overwrites the file because of download
  128. self.assertEqual(f.read(), b'fake resopnse')
  129. def test_delete(self):
  130. self.s3_crt_client.make_request.side_effect = self._simulate_make_request_side_effect
  131. future = self.transfer_manager.delete(
  132. self.bucket, self.key, {}, [self.record_subscriber])
  133. future.result()
  134. callargs = self.s3_crt_client.make_request.call_args
  135. callargs_kwargs = callargs[1]
  136. self.assertIsNone(callargs_kwargs["send_filepath"])
  137. self.assertIsNone(callargs_kwargs["recv_filepath"])
  138. self.assertEqual(callargs_kwargs["type"],
  139. awscrt.s3.S3RequestType.DEFAULT)
  140. crt_request = callargs_kwargs["request"]
  141. self.assertEqual("DELETE", crt_request.method)
  142. self.assertEqual(self.expected_path, crt_request.path)
  143. self.assertEqual(self.expected_host, crt_request.headers.get("host"))
  144. self._assert_subscribers_called(future)
  145. def test_blocks_when_max_requests_processes_reached(self):
  146. futures = []
  147. callargs = (self.bucket, self.key, self.filename, {}, [])
  148. max_request_processes = 128 # the hard coded max processes
  149. all_concurrent = max_request_processes + 1
  150. threads = []
  151. for i in range(0, all_concurrent):
  152. thread = submitThread(self.transfer_manager, futures, callargs)
  153. thread.start()
  154. threads.append(thread)
  155. self.assertLessEqual(
  156. self.s3_crt_client.make_request.call_count,
  157. max_request_processes)
  158. # Release lock
  159. callargs = self.s3_crt_client.make_request.call_args
  160. callargs_kwargs = callargs[1]
  161. on_done = callargs_kwargs["on_done"]
  162. on_done(error=None)
  163. for thread in threads:
  164. thread.join()
  165. self.assertEqual(
  166. self.s3_crt_client.make_request.call_count,
  167. all_concurrent)
  168. def _cancel_function(self):
  169. self.cancel_called = True
  170. self.s3_request.finished_future.set_exception(
  171. awscrt.exceptions.from_code(0))
  172. self._invoke_done_callbacks()
  173. def test_cancel(self):
  174. self.s3_request.finished_future = Future()
  175. self.cancel_called = False
  176. self.s3_request.cancel = self._cancel_function
  177. try:
  178. with self.transfer_manager:
  179. future = self.transfer_manager.upload(
  180. self.filename, self.bucket, self.key, {}, [])
  181. raise KeyboardInterrupt()
  182. except KeyboardInterrupt:
  183. pass
  184. with self.assertRaises(awscrt.exceptions.AwsCrtError):
  185. future.result()
  186. self.assertTrue(self.cancel_called)
  187. def test_serializer_error_handling(self):
  188. class SerializationException(Exception):
  189. pass
  190. class ExceptionRaisingSerializer(s3transfer.crt.BaseCRTRequestSerializer):
  191. def serialize_http_request(self, transfer_type, future):
  192. raise SerializationException()
  193. not_impl_serializer = ExceptionRaisingSerializer()
  194. transfer_manager = s3transfer.crt.CRTTransferManager(
  195. crt_s3_client=self.s3_crt_client,
  196. crt_request_serializer=not_impl_serializer)
  197. future = transfer_manager.upload(
  198. self.filename, self.bucket, self.key, {}, [])
  199. with self.assertRaises(SerializationException):
  200. future.result()
  201. def test_crt_s3_client_error_handling(self):
  202. self.s3_crt_client.make_request.side_effect = awscrt.exceptions.from_code(
  203. 0)
  204. future = self.transfer_manager.upload(
  205. self.filename, self.bucket, self.key, {}, [])
  206. with self.assertRaises(awscrt.exceptions.AwsCrtError):
  207. future.result()