test_bandwidth.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465
  1. # Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License"). You
  4. # may not use this file except in compliance with the License. A copy of
  5. # the License is located at
  6. #
  7. # http://aws.amazon.com/apache2.0/
  8. #
  9. # or in the "license" file accompanying this file. This file is
  10. # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
  11. # ANY KIND, either express or implied. See the License for the specific
  12. # language governing permissions and limitations under the License.
  13. import os
  14. import shutil
  15. import tempfile
  16. import mock
  17. from __tests__ import unittest
  18. from s3transfer.bandwidth import RequestExceededException
  19. from s3transfer.bandwidth import RequestToken
  20. from s3transfer.bandwidth import TimeUtils
  21. from s3transfer.bandwidth import BandwidthLimiter
  22. from s3transfer.bandwidth import BandwidthLimitedStream
  23. from s3transfer.bandwidth import LeakyBucket
  24. from s3transfer.bandwidth import ConsumptionScheduler
  25. from s3transfer.bandwidth import BandwidthRateTracker
  26. from s3transfer.futures import TransferCoordinator
  27. class FixedIncrementalTickTimeUtils(TimeUtils):
  28. def __init__(self, seconds_per_tick=1.0):
  29. self._count = 0
  30. self._seconds_per_tick = seconds_per_tick
  31. def time(self):
  32. current_count = self._count
  33. self._count += self._seconds_per_tick
  34. return current_count
  35. class TestTimeUtils(unittest.TestCase):
  36. @mock.patch('time.time')
  37. def test_time(self, mock_time):
  38. mock_return_val = 1
  39. mock_time.return_value = mock_return_val
  40. time_utils = TimeUtils()
  41. self.assertEqual(time_utils.time(), mock_return_val)
  42. @mock.patch('time.sleep')
  43. def test_sleep(self, mock_sleep):
  44. time_utils = TimeUtils()
  45. time_utils.sleep(1)
  46. self.assertEqual(
  47. mock_sleep.call_args_list,
  48. [mock.call(1)]
  49. )
  50. class BaseBandwidthLimitTest(unittest.TestCase):
  51. def setUp(self):
  52. self.leaky_bucket = mock.Mock(LeakyBucket)
  53. self.time_utils = mock.Mock(TimeUtils)
  54. self.tempdir = tempfile.mkdtemp()
  55. self.content = b'a' * 1024 * 1024
  56. self.filename = os.path.join(self.tempdir, 'myfile')
  57. with open(self.filename, 'wb') as f:
  58. f.write(self.content)
  59. self.coordinator = TransferCoordinator()
  60. def tearDown(self):
  61. shutil.rmtree(self.tempdir)
  62. def assert_consume_calls(self, amts):
  63. expected_consume_args = [
  64. mock.call(amt, mock.ANY) for amt in amts
  65. ]
  66. self.assertEqual(
  67. self.leaky_bucket.consume.call_args_list,
  68. expected_consume_args
  69. )
  70. class TestBandwidthLimiter(BaseBandwidthLimitTest):
  71. def setUp(self):
  72. super(TestBandwidthLimiter, self).setUp()
  73. self.bandwidth_limiter = BandwidthLimiter(self.leaky_bucket)
  74. def test_get_bandwidth_limited_stream(self):
  75. with open(self.filename, 'rb') as f:
  76. stream = self.bandwidth_limiter.get_bandwith_limited_stream(
  77. f, self.coordinator)
  78. self.assertIsInstance(stream, BandwidthLimitedStream)
  79. self.assertEqual(stream.read(len(self.content)), self.content)
  80. self.assert_consume_calls(amts=[len(self.content)])
  81. def test_get_disabled_bandwidth_limited_stream(self):
  82. with open(self.filename, 'rb') as f:
  83. stream = self.bandwidth_limiter.get_bandwith_limited_stream(
  84. f, self.coordinator, enabled=False)
  85. self.assertIsInstance(stream, BandwidthLimitedStream)
  86. self.assertEqual(stream.read(len(self.content)), self.content)
  87. self.leaky_bucket.consume.assert_not_called()
  88. class TestBandwidthLimitedStream(BaseBandwidthLimitTest):
  89. def setUp(self):
  90. super(TestBandwidthLimitedStream, self).setUp()
  91. self.bytes_threshold = 1
  92. def tearDown(self):
  93. shutil.rmtree(self.tempdir)
  94. def get_bandwidth_limited_stream(self, f):
  95. return BandwidthLimitedStream(
  96. f, self.leaky_bucket, self.coordinator, self.time_utils,
  97. self.bytes_threshold)
  98. def assert_sleep_calls(self, amts):
  99. expected_sleep_args_list = [
  100. mock.call(amt) for amt in amts
  101. ]
  102. self.assertEqual(
  103. self.time_utils.sleep.call_args_list,
  104. expected_sleep_args_list
  105. )
  106. def get_unique_consume_request_tokens(self):
  107. return set(
  108. call_args[0][1] for call_args in
  109. self.leaky_bucket.consume.call_args_list
  110. )
  111. def test_read(self):
  112. with open(self.filename, 'rb') as f:
  113. stream = self.get_bandwidth_limited_stream(f)
  114. data = stream.read(len(self.content))
  115. self.assertEqual(self.content, data)
  116. self.assert_consume_calls(amts=[len(self.content)])
  117. self.assert_sleep_calls(amts=[])
  118. def test_retries_on_request_exceeded(self):
  119. with open(self.filename, 'rb') as f:
  120. stream = self.get_bandwidth_limited_stream(f)
  121. retry_time = 1
  122. amt_requested = len(self.content)
  123. self.leaky_bucket.consume.side_effect = [
  124. RequestExceededException(amt_requested, retry_time),
  125. len(self.content)
  126. ]
  127. data = stream.read(len(self.content))
  128. self.assertEqual(self.content, data)
  129. self.assert_consume_calls(amts=[amt_requested, amt_requested])
  130. self.assert_sleep_calls(amts=[retry_time])
  131. def test_with_transfer_coordinator_exception(self):
  132. self.coordinator.set_exception(ValueError())
  133. with open(self.filename, 'rb') as f:
  134. stream = self.get_bandwidth_limited_stream(f)
  135. with self.assertRaises(ValueError):
  136. stream.read(len(self.content))
  137. def test_read_when_bandwidth_limiting_disabled(self):
  138. with open(self.filename, 'rb') as f:
  139. stream = self.get_bandwidth_limited_stream(f)
  140. stream.disable_bandwidth_limiting()
  141. data = stream.read(len(self.content))
  142. self.assertEqual(self.content, data)
  143. self.assertFalse(self.leaky_bucket.consume.called)
  144. def test_read_toggle_disable_enable_bandwidth_limiting(self):
  145. with open(self.filename, 'rb') as f:
  146. stream = self.get_bandwidth_limited_stream(f)
  147. stream.disable_bandwidth_limiting()
  148. data = stream.read(1)
  149. self.assertEqual(self.content[:1], data)
  150. self.assert_consume_calls(amts=[])
  151. stream.enable_bandwidth_limiting()
  152. data = stream.read(len(self.content) - 1)
  153. self.assertEqual(self.content[1:], data)
  154. self.assert_consume_calls(amts=[len(self.content) - 1])
  155. def test_seek(self):
  156. mock_fileobj = mock.Mock()
  157. stream = self.get_bandwidth_limited_stream(mock_fileobj)
  158. stream.seek(1)
  159. self.assertEqual(
  160. mock_fileobj.seek.call_args_list,
  161. [mock.call(1, 0)]
  162. )
  163. def test_tell(self):
  164. mock_fileobj = mock.Mock()
  165. stream = self.get_bandwidth_limited_stream(mock_fileobj)
  166. stream.tell()
  167. self.assertEqual(
  168. mock_fileobj.tell.call_args_list,
  169. [mock.call()]
  170. )
  171. def test_close(self):
  172. mock_fileobj = mock.Mock()
  173. stream = self.get_bandwidth_limited_stream(mock_fileobj)
  174. stream.close()
  175. self.assertEqual(
  176. mock_fileobj.close.call_args_list,
  177. [mock.call()]
  178. )
  179. def test_context_manager(self):
  180. mock_fileobj = mock.Mock()
  181. stream = self.get_bandwidth_limited_stream(mock_fileobj)
  182. with stream as stream_handle:
  183. self.assertIs(stream_handle, stream)
  184. self.assertEqual(
  185. mock_fileobj.close.call_args_list,
  186. [mock.call()]
  187. )
  188. def test_reuses_request_token(self):
  189. with open(self.filename, 'rb') as f:
  190. stream = self.get_bandwidth_limited_stream(f)
  191. stream.read(1)
  192. stream.read(1)
  193. self.assertEqual(len(self.get_unique_consume_request_tokens()), 1)
  194. def test_request_tokens_unique_per_stream(self):
  195. with open(self.filename, 'rb') as f:
  196. stream = self.get_bandwidth_limited_stream(f)
  197. stream.read(1)
  198. with open(self.filename, 'rb') as f:
  199. stream = self.get_bandwidth_limited_stream(f)
  200. stream.read(1)
  201. self.assertEqual(len(self.get_unique_consume_request_tokens()), 2)
  202. def test_call_consume_after_reaching_threshold(self):
  203. self.bytes_threshold = 2
  204. with open(self.filename, 'rb') as f:
  205. stream = self.get_bandwidth_limited_stream(f)
  206. self.assertEqual(stream.read(1), self.content[:1])
  207. self.assert_consume_calls(amts=[])
  208. self.assertEqual(stream.read(1), self.content[1:2])
  209. self.assert_consume_calls(amts=[2])
  210. def test_resets_after_reaching_threshold(self):
  211. self.bytes_threshold = 2
  212. with open(self.filename, 'rb') as f:
  213. stream = self.get_bandwidth_limited_stream(f)
  214. self.assertEqual(stream.read(2), self.content[:2])
  215. self.assert_consume_calls(amts=[2])
  216. self.assertEqual(stream.read(1), self.content[2:3])
  217. self.assert_consume_calls(amts=[2])
  218. def test_pending_bytes_seen_on_close(self):
  219. self.bytes_threshold = 2
  220. with open(self.filename, 'rb') as f:
  221. stream = self.get_bandwidth_limited_stream(f)
  222. self.assertEqual(stream.read(1), self.content[:1])
  223. self.assert_consume_calls(amts=[])
  224. stream.close()
  225. self.assert_consume_calls(amts=[1])
  226. def test_no_bytes_remaining_on(self):
  227. self.bytes_threshold = 2
  228. with open(self.filename, 'rb') as f:
  229. stream = self.get_bandwidth_limited_stream(f)
  230. self.assertEqual(stream.read(2), self.content[:2])
  231. self.assert_consume_calls(amts=[2])
  232. stream.close()
  233. # There should have been no more consume() calls made
  234. # as all bytes have been accounted for in the previous
  235. # consume() call.
  236. self.assert_consume_calls(amts=[2])
  237. def test_disable_bandwidth_limiting_with_pending_bytes_seen_on_close(self):
  238. self.bytes_threshold = 2
  239. with open(self.filename, 'rb') as f:
  240. stream = self.get_bandwidth_limited_stream(f)
  241. self.assertEqual(stream.read(1), self.content[:1])
  242. self.assert_consume_calls(amts=[])
  243. stream.disable_bandwidth_limiting()
  244. stream.close()
  245. self.assert_consume_calls(amts=[])
  246. def test_signal_transferring(self):
  247. with open(self.filename, 'rb') as f:
  248. stream = self.get_bandwidth_limited_stream(f)
  249. stream.signal_not_transferring()
  250. data = stream.read(1)
  251. self.assertEqual(self.content[:1], data)
  252. self.assert_consume_calls(amts=[])
  253. stream.signal_transferring()
  254. data = stream.read(len(self.content) - 1)
  255. self.assertEqual(self.content[1:], data)
  256. self.assert_consume_calls(amts=[len(self.content) - 1])
  257. class TestLeakyBucket(unittest.TestCase):
  258. def setUp(self):
  259. self.max_rate = 1
  260. self.time_now = 1.0
  261. self.time_utils = mock.Mock(TimeUtils)
  262. self.time_utils.time.return_value = self.time_now
  263. self.scheduler = mock.Mock(ConsumptionScheduler)
  264. self.scheduler.is_scheduled.return_value = False
  265. self.rate_tracker = mock.Mock(BandwidthRateTracker)
  266. self.leaky_bucket = LeakyBucket(
  267. self.max_rate, self.time_utils, self.rate_tracker,
  268. self.scheduler
  269. )
  270. def set_projected_rate(self, rate):
  271. self.rate_tracker.get_projected_rate.return_value = rate
  272. def set_retry_time(self, retry_time):
  273. self.scheduler.schedule_consumption.return_value = retry_time
  274. def assert_recorded_consumed_amt(self, expected_amt):
  275. self.assertEqual(
  276. self.rate_tracker.record_consumption_rate.call_args,
  277. mock.call(expected_amt, self.time_utils.time.return_value))
  278. def assert_was_scheduled(self, amt, token):
  279. self.assertEqual(
  280. self.scheduler.schedule_consumption.call_args,
  281. mock.call(amt, token, amt/(self.max_rate))
  282. )
  283. def assert_nothing_scheduled(self):
  284. self.assertFalse(self.scheduler.schedule_consumption.called)
  285. def assert_processed_request_token(self, request_token):
  286. self.assertEqual(
  287. self.scheduler.process_scheduled_consumption.call_args,
  288. mock.call(request_token)
  289. )
  290. def test_consume_under_max_rate(self):
  291. amt = 1
  292. self.set_projected_rate(self.max_rate/2)
  293. self.assertEqual(self.leaky_bucket.consume(amt, RequestToken()), amt)
  294. self.assert_recorded_consumed_amt(amt)
  295. self.assert_nothing_scheduled()
  296. def test_consume_at_max_rate(self):
  297. amt = 1
  298. self.set_projected_rate(self.max_rate)
  299. self.assertEqual(self.leaky_bucket.consume(amt, RequestToken()), amt)
  300. self.assert_recorded_consumed_amt(amt)
  301. self.assert_nothing_scheduled()
  302. def test_consume_over_max_rate(self):
  303. amt = 1
  304. retry_time = 2.0
  305. self.set_projected_rate(self.max_rate + 1)
  306. self.set_retry_time(retry_time)
  307. request_token = RequestToken()
  308. try:
  309. self.leaky_bucket.consume(amt, request_token)
  310. self.fail('A RequestExceededException should have been thrown')
  311. except RequestExceededException as e:
  312. self.assertEqual(e.requested_amt, amt)
  313. self.assertEqual(e.retry_time, retry_time)
  314. self.assert_was_scheduled(amt, request_token)
  315. def test_consume_with_scheduled_retry(self):
  316. amt = 1
  317. self.set_projected_rate(self.max_rate + 1)
  318. self.scheduler.is_scheduled.return_value = True
  319. request_token = RequestToken()
  320. self.assertEqual(self.leaky_bucket.consume(amt, request_token), amt)
  321. # Nothing new should have been scheduled but the request token
  322. # should have been processed.
  323. self.assert_nothing_scheduled()
  324. self.assert_processed_request_token(request_token)
  325. class TestConsumptionScheduler(unittest.TestCase):
  326. def setUp(self):
  327. self.scheduler = ConsumptionScheduler()
  328. def test_schedule_consumption(self):
  329. token = RequestToken()
  330. consume_time = 5
  331. actual_wait_time = self.scheduler.schedule_consumption(
  332. 1, token, consume_time)
  333. self.assertEqual(consume_time, actual_wait_time)
  334. def test_schedule_consumption_for_multiple_requests(self):
  335. token = RequestToken()
  336. consume_time = 5
  337. actual_wait_time = self.scheduler.schedule_consumption(
  338. 1, token, consume_time)
  339. self.assertEqual(consume_time, actual_wait_time)
  340. other_consume_time = 3
  341. other_token = RequestToken()
  342. next_wait_time = self.scheduler.schedule_consumption(
  343. 1, other_token, other_consume_time)
  344. # This wait time should be the previous time plus its desired
  345. # wait time
  346. self.assertEqual(next_wait_time, consume_time + other_consume_time)
  347. def test_is_scheduled(self):
  348. token = RequestToken()
  349. consume_time = 5
  350. self.scheduler.schedule_consumption(1, token, consume_time)
  351. self.assertTrue(self.scheduler.is_scheduled(token))
  352. def test_is_not_scheduled(self):
  353. self.assertFalse(self.scheduler.is_scheduled(RequestToken()))
  354. def test_process_scheduled_consumption(self):
  355. token = RequestToken()
  356. consume_time = 5
  357. self.scheduler.schedule_consumption(1, token, consume_time)
  358. self.scheduler.process_scheduled_consumption(token)
  359. self.assertFalse(self.scheduler.is_scheduled(token))
  360. different_time = 7
  361. # The previous consume time should have no affect on the next wait tim
  362. # as it has been completed.
  363. self.assertEqual(
  364. self.scheduler.schedule_consumption(1, token, different_time),
  365. different_time
  366. )
  367. class TestBandwidthRateTracker(unittest.TestCase):
  368. def setUp(self):
  369. self.alpha = 0.8
  370. self.rate_tracker = BandwidthRateTracker(self.alpha)
  371. def test_current_rate_at_initilizations(self):
  372. self.assertEqual(self.rate_tracker.current_rate, 0.0)
  373. def test_current_rate_after_one_recorded_point(self):
  374. self.rate_tracker.record_consumption_rate(1, 1)
  375. # There is no last time point to do a diff against so return a
  376. # current rate of 0.0
  377. self.assertEqual(self.rate_tracker.current_rate, 0.0)
  378. def test_current_rate(self):
  379. self.rate_tracker.record_consumption_rate(1, 1)
  380. self.rate_tracker.record_consumption_rate(1, 2)
  381. self.rate_tracker.record_consumption_rate(1, 3)
  382. self.assertEqual(self.rate_tracker.current_rate, 0.96)
  383. def test_get_projected_rate_at_initilizations(self):
  384. self.assertEqual(self.rate_tracker.get_projected_rate(1, 1), 0.0)
  385. def test_get_projected_rate(self):
  386. self.rate_tracker.record_consumption_rate(1, 1)
  387. self.rate_tracker.record_consumption_rate(1, 2)
  388. projected_rate = self.rate_tracker.get_projected_rate(1, 3)
  389. self.assertEqual(projected_rate, 0.96)
  390. self.rate_tracker.record_consumption_rate(1, 3)
  391. self.assertEqual(self.rate_tracker.current_rate, projected_rate)
  392. def test_get_projected_rate_for_same_timestamp(self):
  393. self.rate_tracker.record_consumption_rate(1, 1)
  394. self.assertEqual(
  395. self.rate_tracker.get_projected_rate(1, 1),
  396. float('inf')
  397. )