test_crt.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570
  1. # Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License"). You
  4. # may not use this file except in compliance with the License. A copy of
  5. # the License is located at
  6. #
  7. # http://aws.amazon.com/apache2.0/
  8. #
  9. # or in the "license" file accompanying this file. This file is
  10. # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
  11. # ANY KIND, either express or implied. See the License for the specific
  12. # language governing permissions and limitations under the License.
  13. import fnmatch
  14. import io
  15. import threading
  16. import time
  17. from concurrent.futures import Future
  18. from botocore.session import Session
  19. from s3transfer.subscribers import BaseSubscriber
  20. from __tests__ import (
  21. HAS_CRT,
  22. FileCreator,
  23. NonSeekableReader,
  24. NonSeekableWriter,
  25. mock,
  26. requires_crt,
  27. unittest,
  28. )
  29. if HAS_CRT:
  30. import awscrt
  31. import s3transfer.crt
  32. class submitThread(threading.Thread):
  33. def __init__(self, transfer_manager, futures, callargs):
  34. threading.Thread.__init__(self)
  35. self._transfer_manager = transfer_manager
  36. self._futures = futures
  37. self._callargs = callargs
  38. def run(self):
  39. self._futures.append(self._transfer_manager.download(*self._callargs))
  40. class RecordingSubscriber(BaseSubscriber):
  41. def __init__(self):
  42. self.on_queued_called = False
  43. self.on_done_called = False
  44. self.bytes_transferred = 0
  45. self.on_queued_future = None
  46. self.on_done_future = None
  47. def on_queued(self, future, **kwargs):
  48. self.on_queued_called = True
  49. self.on_queued_future = future
  50. def on_done(self, future, **kwargs):
  51. self.on_done_called = True
  52. self.on_done_future = future
  53. @requires_crt
  54. class TestCRTTransferManager(unittest.TestCase):
  55. def setUp(self):
  56. self.region = 'us-west-2'
  57. self.bucket = "test_bucket"
  58. self.key = "test_key"
  59. self.expected_content = b'my content'
  60. self.expected_download_content = b'new content'
  61. self.files = FileCreator()
  62. self.filename = self.files.create_file(
  63. 'myfile', self.expected_content, mode='wb'
  64. )
  65. self.expected_path = "/" + self.bucket + "/" + self.key
  66. self.expected_host = "s3.%s.amazonaws.com" % (self.region)
  67. self.s3_request = mock.Mock(awscrt.s3.S3Request)
  68. self.s3_crt_client = mock.Mock(awscrt.s3.S3Client)
  69. self.s3_crt_client.make_request.side_effect = (
  70. self._simulate_make_request_side_effect
  71. )
  72. self.session = Session()
  73. self.session.set_config_variable('region', self.region)
  74. self.request_serializer = s3transfer.crt.BotocoreCRTRequestSerializer(
  75. self.session
  76. )
  77. self.transfer_manager = s3transfer.crt.CRTTransferManager(
  78. crt_s3_client=self.s3_crt_client,
  79. crt_request_serializer=self.request_serializer,
  80. )
  81. self.record_subscriber = RecordingSubscriber()
  82. def tearDown(self):
  83. self.files.remove_all()
  84. def _assert_expected_crt_http_request(
  85. self,
  86. crt_http_request,
  87. expected_http_method='GET',
  88. expected_host=None,
  89. expected_path=None,
  90. expected_body_content=None,
  91. expected_content_length=None,
  92. expected_missing_headers=None,
  93. ):
  94. if expected_host is None:
  95. expected_host = self.expected_host
  96. if expected_path is None:
  97. expected_path = self.expected_path
  98. self.assertEqual(crt_http_request.method, expected_http_method)
  99. self.assertEqual(crt_http_request.headers.get("host"), expected_host)
  100. self.assertEqual(crt_http_request.path, expected_path)
  101. if expected_body_content is not None:
  102. # Note: The underlying CRT awscrt.io.InputStream does not expose
  103. # a public read method so we have to reach into the private,
  104. # underlying stream to determine the content. We should update
  105. # to use a public interface if a public interface is ever exposed.
  106. self.assertEqual(
  107. crt_http_request.body_stream._stream.read(),
  108. expected_body_content,
  109. )
  110. if expected_content_length is not None:
  111. self.assertEqual(
  112. crt_http_request.headers.get('Content-Length'),
  113. str(expected_content_length),
  114. )
  115. if expected_missing_headers is not None:
  116. header_names = [
  117. header[0].lower() for header in crt_http_request.headers
  118. ]
  119. for expected_missing_header in expected_missing_headers:
  120. self.assertNotIn(expected_missing_header.lower(), header_names)
  121. def _assert_subscribers_called(self, expected_future=None):
  122. self.assertTrue(self.record_subscriber.on_queued_called)
  123. self.assertTrue(self.record_subscriber.on_done_called)
  124. if expected_future:
  125. self.assertIs(
  126. self.record_subscriber.on_queued_future, expected_future
  127. )
  128. self.assertIs(
  129. self.record_subscriber.on_done_future, expected_future
  130. )
  131. def _get_expected_upload_checksum_config(self, **overrides):
  132. checksum_config_kwargs = {
  133. 'algorithm': awscrt.s3.S3ChecksumAlgorithm.CRC32,
  134. 'location': awscrt.s3.S3ChecksumLocation.TRAILER,
  135. }
  136. checksum_config_kwargs.update(overrides)
  137. return awscrt.s3.S3ChecksumConfig(**checksum_config_kwargs)
  138. def _get_expected_download_checksum_config(self, **overrides):
  139. checksum_config_kwargs = {
  140. 'validate_response': True,
  141. }
  142. checksum_config_kwargs.update(overrides)
  143. return awscrt.s3.S3ChecksumConfig(**checksum_config_kwargs)
  144. def _invoke_done_callbacks(self, **kwargs):
  145. callargs = self.s3_crt_client.make_request.call_args
  146. callargs_kwargs = callargs[1]
  147. on_done = callargs_kwargs["on_done"]
  148. on_done(error=None)
  149. def _simulate_file_download(self, recv_filepath):
  150. self.files.create_file(
  151. recv_filepath, self.expected_download_content, mode='wb'
  152. )
  153. def _simulate_on_body_download(self, on_body_callback):
  154. on_body_callback(chunk=self.expected_download_content, offset=0)
  155. def _simulate_make_request_side_effect(self, **kwargs):
  156. if kwargs.get('recv_filepath'):
  157. self._simulate_file_download(kwargs['recv_filepath'])
  158. if kwargs.get('on_body'):
  159. self._simulate_on_body_download(kwargs['on_body'])
  160. self._invoke_done_callbacks()
  161. return self.s3_request
  162. def test_upload(self):
  163. future = self.transfer_manager.upload(
  164. self.filename, self.bucket, self.key, {}, [self.record_subscriber]
  165. )
  166. future.result()
  167. callargs_kwargs = self.s3_crt_client.make_request.call_args[1]
  168. self.assertEqual(
  169. callargs_kwargs,
  170. {
  171. 'request': mock.ANY,
  172. 'type': awscrt.s3.S3RequestType.PUT_OBJECT,
  173. 'send_filepath': self.filename,
  174. 'on_progress': mock.ANY,
  175. 'on_done': mock.ANY,
  176. 'checksum_config': self._get_expected_upload_checksum_config(),
  177. },
  178. )
  179. self._assert_expected_crt_http_request(
  180. callargs_kwargs["request"],
  181. expected_http_method='PUT',
  182. expected_content_length=len(self.expected_content),
  183. expected_missing_headers=['Content-MD5'],
  184. )
  185. self._assert_subscribers_called(future)
  186. def test_upload_from_seekable_stream(self):
  187. with open(self.filename, 'rb') as f:
  188. future = self.transfer_manager.upload(
  189. f, self.bucket, self.key, {}, [self.record_subscriber]
  190. )
  191. future.result()
  192. callargs_kwargs = self.s3_crt_client.make_request.call_args[1]
  193. self.assertEqual(
  194. callargs_kwargs,
  195. {
  196. 'request': mock.ANY,
  197. 'type': awscrt.s3.S3RequestType.PUT_OBJECT,
  198. 'send_filepath': None,
  199. 'on_progress': mock.ANY,
  200. 'on_done': mock.ANY,
  201. 'checksum_config': self._get_expected_upload_checksum_config(),
  202. },
  203. )
  204. self._assert_expected_crt_http_request(
  205. callargs_kwargs["request"],
  206. expected_http_method='PUT',
  207. expected_body_content=self.expected_content,
  208. expected_content_length=len(self.expected_content),
  209. expected_missing_headers=['Content-MD5'],
  210. )
  211. self._assert_subscribers_called(future)
  212. def test_upload_from_nonseekable_stream(self):
  213. nonseekable_stream = NonSeekableReader(self.expected_content)
  214. future = self.transfer_manager.upload(
  215. nonseekable_stream,
  216. self.bucket,
  217. self.key,
  218. {},
  219. [self.record_subscriber],
  220. )
  221. future.result()
  222. callargs_kwargs = self.s3_crt_client.make_request.call_args[1]
  223. self.assertEqual(
  224. callargs_kwargs,
  225. {
  226. 'request': mock.ANY,
  227. 'type': awscrt.s3.S3RequestType.PUT_OBJECT,
  228. 'send_filepath': None,
  229. 'on_progress': mock.ANY,
  230. 'on_done': mock.ANY,
  231. 'checksum_config': self._get_expected_upload_checksum_config(),
  232. },
  233. )
  234. self._assert_expected_crt_http_request(
  235. callargs_kwargs["request"],
  236. expected_http_method='PUT',
  237. expected_body_content=self.expected_content,
  238. expected_missing_headers=[
  239. 'Content-MD5',
  240. 'Content-Length',
  241. 'Transfer-Encoding',
  242. ],
  243. )
  244. self._assert_subscribers_called(future)
  245. def test_upload_override_checksum_algorithm(self):
  246. future = self.transfer_manager.upload(
  247. self.filename,
  248. self.bucket,
  249. self.key,
  250. {'ChecksumAlgorithm': 'CRC32C'},
  251. [self.record_subscriber],
  252. )
  253. future.result()
  254. callargs_kwargs = self.s3_crt_client.make_request.call_args[1]
  255. self.assertEqual(
  256. callargs_kwargs,
  257. {
  258. 'request': mock.ANY,
  259. 'type': awscrt.s3.S3RequestType.PUT_OBJECT,
  260. 'send_filepath': self.filename,
  261. 'on_progress': mock.ANY,
  262. 'on_done': mock.ANY,
  263. 'checksum_config': self._get_expected_upload_checksum_config(
  264. algorithm=awscrt.s3.S3ChecksumAlgorithm.CRC32C
  265. ),
  266. },
  267. )
  268. self._assert_expected_crt_http_request(
  269. callargs_kwargs["request"],
  270. expected_http_method='PUT',
  271. expected_content_length=len(self.expected_content),
  272. expected_missing_headers=[
  273. 'Content-MD5',
  274. 'x-amz-sdk-checksum-algorithm',
  275. 'X-Amz-Trailer',
  276. ],
  277. )
  278. self._assert_subscribers_called(future)
  279. def test_upload_override_checksum_algorithm_accepts_lowercase(self):
  280. future = self.transfer_manager.upload(
  281. self.filename,
  282. self.bucket,
  283. self.key,
  284. {'ChecksumAlgorithm': 'crc32c'},
  285. [self.record_subscriber],
  286. )
  287. future.result()
  288. callargs_kwargs = self.s3_crt_client.make_request.call_args[1]
  289. self.assertEqual(
  290. callargs_kwargs,
  291. {
  292. 'request': mock.ANY,
  293. 'type': awscrt.s3.S3RequestType.PUT_OBJECT,
  294. 'send_filepath': self.filename,
  295. 'on_progress': mock.ANY,
  296. 'on_done': mock.ANY,
  297. 'checksum_config': self._get_expected_upload_checksum_config(
  298. algorithm=awscrt.s3.S3ChecksumAlgorithm.CRC32C
  299. ),
  300. },
  301. )
  302. self._assert_expected_crt_http_request(
  303. callargs_kwargs["request"],
  304. expected_http_method='PUT',
  305. expected_content_length=len(self.expected_content),
  306. expected_missing_headers=[
  307. 'Content-MD5',
  308. 'x-amz-sdk-checksum-algorithm',
  309. 'X-Amz-Trailer',
  310. ],
  311. )
  312. self._assert_subscribers_called(future)
  313. def test_upload_throws_error_for_unsupported_checksum(self):
  314. with self.assertRaisesRegex(
  315. ValueError, 'ChecksumAlgorithm: UNSUPPORTED not supported'
  316. ):
  317. self.transfer_manager.upload(
  318. self.filename,
  319. self.bucket,
  320. self.key,
  321. {'ChecksumAlgorithm': 'UNSUPPORTED'},
  322. [self.record_subscriber],
  323. )
  324. def test_download(self):
  325. future = self.transfer_manager.download(
  326. self.bucket, self.key, self.filename, {}, [self.record_subscriber]
  327. )
  328. future.result()
  329. callargs_kwargs = self.s3_crt_client.make_request.call_args[1]
  330. self.assertEqual(
  331. callargs_kwargs,
  332. {
  333. 'request': mock.ANY,
  334. 'type': awscrt.s3.S3RequestType.GET_OBJECT,
  335. 'recv_filepath': mock.ANY,
  336. 'on_progress': mock.ANY,
  337. 'on_done': mock.ANY,
  338. 'on_body': None,
  339. 'checksum_config': self._get_expected_download_checksum_config(),
  340. },
  341. )
  342. # the recv_filepath will be set to a temporary file path with some
  343. # random suffix
  344. self.assertTrue(
  345. fnmatch.fnmatch(
  346. callargs_kwargs["recv_filepath"],
  347. f'{self.filename}.*',
  348. )
  349. )
  350. self._assert_expected_crt_http_request(
  351. callargs_kwargs["request"],
  352. expected_http_method='GET',
  353. expected_content_length=0,
  354. )
  355. self._assert_subscribers_called(future)
  356. with open(self.filename, 'rb') as f:
  357. # Check the fake response overwrites the file because of download
  358. self.assertEqual(f.read(), self.expected_download_content)
  359. def test_download_to_seekable_stream(self):
  360. with open(self.filename, 'wb') as f:
  361. future = self.transfer_manager.download(
  362. self.bucket, self.key, f, {}, [self.record_subscriber]
  363. )
  364. future.result()
  365. callargs_kwargs = self.s3_crt_client.make_request.call_args[1]
  366. self.assertEqual(
  367. callargs_kwargs,
  368. {
  369. 'request': mock.ANY,
  370. 'type': awscrt.s3.S3RequestType.GET_OBJECT,
  371. 'recv_filepath': None,
  372. 'on_progress': mock.ANY,
  373. 'on_done': mock.ANY,
  374. 'on_body': mock.ANY,
  375. 'checksum_config': self._get_expected_download_checksum_config(),
  376. },
  377. )
  378. self._assert_expected_crt_http_request(
  379. callargs_kwargs["request"],
  380. expected_http_method='GET',
  381. expected_content_length=0,
  382. )
  383. self._assert_subscribers_called(future)
  384. with open(self.filename, 'rb') as f:
  385. # Check the fake response overwrites the file because of download
  386. self.assertEqual(f.read(), self.expected_download_content)
  387. def test_download_to_nonseekable_stream(self):
  388. underlying_stream = io.BytesIO()
  389. nonseekable_stream = NonSeekableWriter(underlying_stream)
  390. future = self.transfer_manager.download(
  391. self.bucket,
  392. self.key,
  393. nonseekable_stream,
  394. {},
  395. [self.record_subscriber],
  396. )
  397. future.result()
  398. callargs_kwargs = self.s3_crt_client.make_request.call_args[1]
  399. self.assertEqual(
  400. callargs_kwargs,
  401. {
  402. 'request': mock.ANY,
  403. 'type': awscrt.s3.S3RequestType.GET_OBJECT,
  404. 'recv_filepath': None,
  405. 'on_progress': mock.ANY,
  406. 'on_done': mock.ANY,
  407. 'on_body': mock.ANY,
  408. 'checksum_config': self._get_expected_download_checksum_config(),
  409. },
  410. )
  411. self._assert_expected_crt_http_request(
  412. callargs_kwargs["request"],
  413. expected_http_method='GET',
  414. expected_content_length=0,
  415. )
  416. self._assert_subscribers_called(future)
  417. self.assertEqual(
  418. underlying_stream.getvalue(), self.expected_download_content
  419. )
  420. def test_delete(self):
  421. future = self.transfer_manager.delete(
  422. self.bucket, self.key, {}, [self.record_subscriber]
  423. )
  424. future.result()
  425. callargs_kwargs = self.s3_crt_client.make_request.call_args[1]
  426. self.assertEqual(
  427. callargs_kwargs,
  428. {
  429. 'request': mock.ANY,
  430. 'type': awscrt.s3.S3RequestType.DEFAULT,
  431. 'on_progress': mock.ANY,
  432. 'on_done': mock.ANY,
  433. },
  434. )
  435. self._assert_expected_crt_http_request(
  436. callargs_kwargs["request"],
  437. expected_http_method='DELETE',
  438. expected_content_length=0,
  439. )
  440. self._assert_subscribers_called(future)
  441. def test_blocks_when_max_requests_processes_reached(self):
  442. self.s3_crt_client.make_request.return_value = self.s3_request
  443. # We simulate blocking by not invoking the on_done callbacks for
  444. # all of the requests we send. The default side effect invokes all
  445. # callbacks so we need to unset the side effect to avoid on_done from
  446. # being called in the child threads.
  447. self.s3_crt_client.make_request.side_effect = None
  448. futures = []
  449. callargs = (self.bucket, self.key, self.filename, {}, [])
  450. max_request_processes = 128 # the hard coded max processes
  451. all_concurrent = max_request_processes + 1
  452. threads = []
  453. for i in range(0, all_concurrent):
  454. thread = submitThread(self.transfer_manager, futures, callargs)
  455. thread.start()
  456. threads.append(thread)
  457. # Sleep until the expected max requests has been reached
  458. while len(futures) < max_request_processes:
  459. time.sleep(0.05)
  460. self.assertLessEqual(
  461. self.s3_crt_client.make_request.call_count, max_request_processes
  462. )
  463. # Release lock
  464. callargs = self.s3_crt_client.make_request.call_args
  465. callargs_kwargs = callargs[1]
  466. on_done = callargs_kwargs["on_done"]
  467. on_done(error=None)
  468. for thread in threads:
  469. thread.join()
  470. self.assertEqual(
  471. self.s3_crt_client.make_request.call_count, all_concurrent
  472. )
  473. def _cancel_function(self):
  474. self.cancel_called = True
  475. self.s3_request.finished_future.set_exception(
  476. awscrt.exceptions.from_code(0)
  477. )
  478. self._invoke_done_callbacks()
  479. def test_cancel(self):
  480. self.s3_request.finished_future = Future()
  481. self.cancel_called = False
  482. self.s3_request.cancel = self._cancel_function
  483. try:
  484. with self.transfer_manager:
  485. future = self.transfer_manager.upload(
  486. self.filename, self.bucket, self.key, {}, []
  487. )
  488. raise KeyboardInterrupt()
  489. except KeyboardInterrupt:
  490. pass
  491. with self.assertRaises(awscrt.exceptions.AwsCrtError):
  492. future.result()
  493. self.assertTrue(self.cancel_called)
  494. def test_serializer_error_handling(self):
  495. class SerializationException(Exception):
  496. pass
  497. class ExceptionRaisingSerializer(
  498. s3transfer.crt.BaseCRTRequestSerializer
  499. ):
  500. def serialize_http_request(self, transfer_type, future):
  501. raise SerializationException()
  502. not_impl_serializer = ExceptionRaisingSerializer()
  503. transfer_manager = s3transfer.crt.CRTTransferManager(
  504. crt_s3_client=self.s3_crt_client,
  505. crt_request_serializer=not_impl_serializer,
  506. )
  507. future = transfer_manager.upload(
  508. self.filename, self.bucket, self.key, {}, []
  509. )
  510. with self.assertRaises(SerializationException):
  511. future.result()
  512. def test_crt_s3_client_error_handling(self):
  513. self.s3_crt_client.make_request.side_effect = (
  514. awscrt.exceptions.from_code(0)
  515. )
  516. future = self.transfer_manager.upload(
  517. self.filename, self.bucket, self.key, {}, []
  518. )
  519. with self.assertRaises(awscrt.exceptions.AwsCrtError):
  520. future.result()