test_copies.py 7.0 KB


# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
  13. from s3transfer.copies import CopyObjectTask, CopyPartTask
  14. from __tests__ import BaseTaskTest, RecordingSubscriber
  15. class BaseCopyTaskTest(BaseTaskTest):
  16. def setUp(self):
  17. super().setUp()
  18. self.bucket = 'mybucket'
  19. self.key = 'mykey'
  20. self.copy_source = {'Bucket': 'mysourcebucket', 'Key': 'mysourcekey'}
  21. self.extra_args = {}
  22. self.callbacks = []
  23. self.size = 5
  24. class TestCopyObjectTask(BaseCopyTaskTest):
  25. def get_copy_task(self, **kwargs):
  26. default_kwargs = {
  27. 'client': self.client,
  28. 'copy_source': self.copy_source,
  29. 'bucket': self.bucket,
  30. 'key': self.key,
  31. 'extra_args': self.extra_args,
  32. 'callbacks': self.callbacks,
  33. 'size': self.size,
  34. }
  35. default_kwargs.update(kwargs)
  36. return self.get_task(CopyObjectTask, main_kwargs=default_kwargs)
  37. def test_main(self):
  38. self.stubber.add_response(
  39. 'copy_object',
  40. service_response={},
  41. expected_params={
  42. 'Bucket': self.bucket,
  43. 'Key': self.key,
  44. 'CopySource': self.copy_source,
  45. },
  46. )
  47. task = self.get_copy_task()
  48. task()
  49. self.stubber.assert_no_pending_responses()
  50. def test_extra_args(self):
  51. self.extra_args['ACL'] = 'private'
  52. self.stubber.add_response(
  53. 'copy_object',
  54. service_response={},
  55. expected_params={
  56. 'Bucket': self.bucket,
  57. 'Key': self.key,
  58. 'CopySource': self.copy_source,
  59. 'ACL': 'private',
  60. },
  61. )
  62. task = self.get_copy_task()
  63. task()
  64. self.stubber.assert_no_pending_responses()
  65. def test_callbacks_invoked(self):
  66. subscriber = RecordingSubscriber()
  67. self.callbacks.append(subscriber.on_progress)
  68. self.stubber.add_response(
  69. 'copy_object',
  70. service_response={},
  71. expected_params={
  72. 'Bucket': self.bucket,
  73. 'Key': self.key,
  74. 'CopySource': self.copy_source,
  75. },
  76. )
  77. task = self.get_copy_task()
  78. task()
  79. self.stubber.assert_no_pending_responses()
  80. self.assertEqual(subscriber.calculate_bytes_seen(), self.size)
  81. class TestCopyPartTask(BaseCopyTaskTest):
  82. def setUp(self):
  83. super().setUp()
  84. self.copy_source_range = 'bytes=5-9'
  85. self.extra_args['CopySourceRange'] = self.copy_source_range
  86. self.upload_id = 'myuploadid'
  87. self.part_number = 1
  88. self.result_etag = 'my-etag'
  89. self.checksum_sha1 = 'my-checksum_sha1'
  90. def get_copy_task(self, **kwargs):
  91. default_kwargs = {
  92. 'client': self.client,
  93. 'copy_source': self.copy_source,
  94. 'bucket': self.bucket,
  95. 'key': self.key,
  96. 'upload_id': self.upload_id,
  97. 'part_number': self.part_number,
  98. 'extra_args': self.extra_args,
  99. 'callbacks': self.callbacks,
  100. 'size': self.size,
  101. }
  102. default_kwargs.update(kwargs)
  103. return self.get_task(CopyPartTask, main_kwargs=default_kwargs)
  104. def test_main(self):
  105. self.stubber.add_response(
  106. 'upload_part_copy',
  107. service_response={'CopyPartResult': {'ETag': self.result_etag}},
  108. expected_params={
  109. 'Bucket': self.bucket,
  110. 'Key': self.key,
  111. 'CopySource': self.copy_source,
  112. 'UploadId': self.upload_id,
  113. 'PartNumber': self.part_number,
  114. 'CopySourceRange': self.copy_source_range,
  115. },
  116. )
  117. task = self.get_copy_task()
  118. self.assertEqual(
  119. task(), {'PartNumber': self.part_number, 'ETag': self.result_etag}
  120. )
  121. self.stubber.assert_no_pending_responses()
  122. def test_main_with_checksum(self):
  123. self.stubber.add_response(
  124. 'upload_part_copy',
  125. service_response={
  126. 'CopyPartResult': {
  127. 'ETag': self.result_etag,
  128. 'ChecksumSHA1': self.checksum_sha1,
  129. }
  130. },
  131. expected_params={
  132. 'Bucket': self.bucket,
  133. 'Key': self.key,
  134. 'CopySource': self.copy_source,
  135. 'UploadId': self.upload_id,
  136. 'PartNumber': self.part_number,
  137. 'CopySourceRange': self.copy_source_range,
  138. },
  139. )
  140. task = self.get_copy_task(checksum_algorithm="sha1")
  141. self.assertEqual(
  142. task(),
  143. {
  144. 'PartNumber': self.part_number,
  145. 'ETag': self.result_etag,
  146. 'ChecksumSHA1': self.checksum_sha1,
  147. },
  148. )
  149. self.stubber.assert_no_pending_responses()
  150. def test_extra_args(self):
  151. self.extra_args['RequestPayer'] = 'requester'
  152. self.stubber.add_response(
  153. 'upload_part_copy',
  154. service_response={'CopyPartResult': {'ETag': self.result_etag}},
  155. expected_params={
  156. 'Bucket': self.bucket,
  157. 'Key': self.key,
  158. 'CopySource': self.copy_source,
  159. 'UploadId': self.upload_id,
  160. 'PartNumber': self.part_number,
  161. 'CopySourceRange': self.copy_source_range,
  162. 'RequestPayer': 'requester',
  163. },
  164. )
  165. task = self.get_copy_task()
  166. self.assertEqual(
  167. task(), {'PartNumber': self.part_number, 'ETag': self.result_etag}
  168. )
  169. self.stubber.assert_no_pending_responses()
  170. def test_callbacks_invoked(self):
  171. subscriber = RecordingSubscriber()
  172. self.callbacks.append(subscriber.on_progress)
  173. self.stubber.add_response(
  174. 'upload_part_copy',
  175. service_response={'CopyPartResult': {'ETag': self.result_etag}},
  176. expected_params={
  177. 'Bucket': self.bucket,
  178. 'Key': self.key,
  179. 'CopySource': self.copy_source,
  180. 'UploadId': self.upload_id,
  181. 'PartNumber': self.part_number,
  182. 'CopySourceRange': self.copy_source_range,
  183. },
  184. )
  185. task = self.get_copy_task()
  186. self.assertEqual(
  187. task(), {'PartNumber': self.part_number, 'ETag': self.result_etag}
  188. )
  189. self.stubber.assert_no_pending_responses()
  190. self.assertEqual(subscriber.calculate_bytes_seen(), self.size)