test__refresh_worker.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. # Copyright 2023 Google LLC
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import pickle
  15. import random
  16. import threading
  17. import time
  18. import mock
  19. import pytest # type: ignore
  20. from google.auth import _refresh_worker, credentials, exceptions
  21. MAIN_THREAD_SLEEP_MS = 100 / 1000
  22. class MockCredentialsImpl(credentials.Credentials):
  23. def __init__(self, sleep_seconds=None):
  24. self.refresh_count = 0
  25. self.token = None
  26. self.sleep_seconds = sleep_seconds if sleep_seconds else None
  27. def refresh(self, request):
  28. if self.sleep_seconds:
  29. time.sleep(self.sleep_seconds)
  30. self.token = request
  31. self.refresh_count += 1
  32. @pytest.fixture
  33. def test_thread_count():
  34. return 25
  35. def _cred_spinlock(cred):
  36. while cred.token is None: # pragma: NO COVER
  37. time.sleep(MAIN_THREAD_SLEEP_MS)
  38. def test_invalid_start_refresh():
  39. w = _refresh_worker.RefreshThreadManager()
  40. with pytest.raises(exceptions.InvalidValue):
  41. w.start_refresh(None, None)
  42. def test_start_refresh():
  43. w = _refresh_worker.RefreshThreadManager()
  44. cred = MockCredentialsImpl()
  45. request = mock.MagicMock()
  46. assert w.start_refresh(cred, request)
  47. assert w._worker is not None
  48. _cred_spinlock(cred)
  49. assert cred.token == request
  50. assert cred.refresh_count == 1
  51. def test_nonblocking_start_refresh():
  52. w = _refresh_worker.RefreshThreadManager()
  53. cred = MockCredentialsImpl(sleep_seconds=1)
  54. request = mock.MagicMock()
  55. assert w.start_refresh(cred, request)
  56. assert w._worker is not None
  57. assert not cred.token
  58. assert cred.refresh_count == 0
  59. def test_multiple_refreshes_multiple_workers(test_thread_count):
  60. w = _refresh_worker.RefreshThreadManager()
  61. cred = MockCredentialsImpl()
  62. request = mock.MagicMock()
  63. def _thread_refresh():
  64. time.sleep(random.randrange(0, 5))
  65. assert w.start_refresh(cred, request)
  66. threads = [
  67. threading.Thread(target=_thread_refresh) for _ in range(test_thread_count)
  68. ]
  69. for t in threads:
  70. t.start()
  71. _cred_spinlock(cred)
  72. assert cred.token == request
  73. # There is a chance only one thread has enough time to perform a refresh.
  74. # Generally multiple threads will have time to perform a refresh
  75. assert cred.refresh_count > 0
  76. def test_refresh_error():
  77. w = _refresh_worker.RefreshThreadManager()
  78. cred = mock.MagicMock()
  79. request = mock.MagicMock()
  80. cred.refresh.side_effect = exceptions.RefreshError("Failed to refresh")
  81. assert w.start_refresh(cred, request)
  82. while w._worker._error_info is None: # pragma: NO COVER
  83. time.sleep(MAIN_THREAD_SLEEP_MS)
  84. assert w._worker is not None
  85. assert isinstance(w._worker._error_info, exceptions.RefreshError)
  86. def test_refresh_error_call_refresh_again():
  87. w = _refresh_worker.RefreshThreadManager()
  88. cred = mock.MagicMock()
  89. request = mock.MagicMock()
  90. cred.refresh.side_effect = exceptions.RefreshError("Failed to refresh")
  91. assert w.start_refresh(cred, request)
  92. while w._worker._error_info is None: # pragma: NO COVER
  93. time.sleep(MAIN_THREAD_SLEEP_MS)
  94. assert not w.start_refresh(cred, request)
  95. def test_refresh_dead_worker():
  96. cred = MockCredentialsImpl()
  97. request = mock.MagicMock()
  98. w = _refresh_worker.RefreshThreadManager()
  99. w._worker = None
  100. w.start_refresh(cred, request)
  101. _cred_spinlock(cred)
  102. assert cred.token == request
  103. assert cred.refresh_count == 1
  104. def test_pickle():
  105. w = _refresh_worker.RefreshThreadManager()
  106. pickled_manager = pickle.dumps(w)
  107. manager = pickle.loads(pickled_manager)
  108. assert isinstance(manager, _refresh_worker.RefreshThreadManager)