# discovery.py (11 KB)
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import logging
import time
import weakref

from botocore import xform_name
from botocore.exceptions import BotoCoreError, ConnectionError, HTTPClientError
from botocore.model import OperationNotFoundError
from botocore.utils import CachedProperty
  20. logger = logging.getLogger(__name__)
  21. class EndpointDiscoveryException(BotoCoreError):
  22. pass
  23. class EndpointDiscoveryRequired(EndpointDiscoveryException):
  24. """Endpoint Discovery is disabled but is required for this operation."""
  25. fmt = 'Endpoint Discovery is not enabled but this operation requires it.'
  26. class EndpointDiscoveryRefreshFailed(EndpointDiscoveryException):
  27. """Endpoint Discovery failed to the refresh the known endpoints."""
  28. fmt = 'Endpoint Discovery failed to refresh the required endpoints.'
  29. def block_endpoint_discovery_required_operations(model, **kwargs):
  30. endpoint_discovery = model.endpoint_discovery
  31. if endpoint_discovery and endpoint_discovery.get('required'):
  32. raise EndpointDiscoveryRequired()
  33. class EndpointDiscoveryModel:
  34. def __init__(self, service_model):
  35. self._service_model = service_model
  36. @CachedProperty
  37. def discovery_operation_name(self):
  38. discovery_operation = self._service_model.endpoint_discovery_operation
  39. return xform_name(discovery_operation.name)
  40. @CachedProperty
  41. def discovery_operation_keys(self):
  42. discovery_operation = self._service_model.endpoint_discovery_operation
  43. keys = []
  44. if discovery_operation.input_shape:
  45. keys = list(discovery_operation.input_shape.members.keys())
  46. return keys
  47. def discovery_required_for(self, operation_name):
  48. try:
  49. operation_model = self._service_model.operation_model(
  50. operation_name
  51. )
  52. return operation_model.endpoint_discovery.get('required', False)
  53. except OperationNotFoundError:
  54. return False
  55. def discovery_operation_kwargs(self, **kwargs):
  56. input_keys = self.discovery_operation_keys
  57. # Operation and Identifiers are only sent if there are Identifiers
  58. if not kwargs.get('Identifiers'):
  59. kwargs.pop('Operation', None)
  60. kwargs.pop('Identifiers', None)
  61. return {k: v for k, v in kwargs.items() if k in input_keys}
  62. def gather_identifiers(self, operation, params):
  63. return self._gather_ids(operation.input_shape, params)
  64. def _gather_ids(self, shape, params, ids=None):
  65. # Traverse the input shape and corresponding parameters, gathering
  66. # any input fields labeled as an endpoint discovery id
  67. if ids is None:
  68. ids = {}
  69. for member_name, member_shape in shape.members.items():
  70. if member_shape.metadata.get('endpointdiscoveryid'):
  71. ids[member_name] = params[member_name]
  72. elif (
  73. member_shape.type_name == 'structure' and member_name in params
  74. ):
  75. self._gather_ids(member_shape, params[member_name], ids)
  76. return ids
  77. class EndpointDiscoveryManager:
  78. def __init__(
  79. self, client, cache=None, current_time=None, always_discover=True
  80. ):
  81. if cache is None:
  82. cache = {}
  83. self._cache = cache
  84. self._failed_attempts = {}
  85. if current_time is None:
  86. current_time = time.time
  87. self._time = current_time
  88. self._always_discover = always_discover
  89. # This needs to be a weak ref in order to prevent memory leaks on
  90. # python 2.6
  91. self._client = weakref.proxy(client)
  92. self._model = EndpointDiscoveryModel(client.meta.service_model)
  93. def _parse_endpoints(self, response):
  94. endpoints = response['Endpoints']
  95. current_time = self._time()
  96. for endpoint in endpoints:
  97. cache_time = endpoint.get('CachePeriodInMinutes')
  98. endpoint['Expiration'] = current_time + cache_time * 60
  99. return endpoints
  100. def _cache_item(self, value):
  101. if isinstance(value, dict):
  102. return tuple(sorted(value.items()))
  103. else:
  104. return value
  105. def _create_cache_key(self, **kwargs):
  106. kwargs = self._model.discovery_operation_kwargs(**kwargs)
  107. return tuple(self._cache_item(v) for k, v in sorted(kwargs.items()))
  108. def gather_identifiers(self, operation, params):
  109. return self._model.gather_identifiers(operation, params)
  110. def delete_endpoints(self, **kwargs):
  111. cache_key = self._create_cache_key(**kwargs)
  112. if cache_key in self._cache:
  113. del self._cache[cache_key]
  114. def _describe_endpoints(self, **kwargs):
  115. # This is effectively a proxy to whatever name/kwargs the service
  116. # supports for endpoint discovery.
  117. kwargs = self._model.discovery_operation_kwargs(**kwargs)
  118. operation_name = self._model.discovery_operation_name
  119. discovery_operation = getattr(self._client, operation_name)
  120. logger.debug('Discovering endpoints with kwargs: %s', kwargs)
  121. return discovery_operation(**kwargs)
  122. def _get_current_endpoints(self, key):
  123. if key not in self._cache:
  124. return None
  125. now = self._time()
  126. return [e for e in self._cache[key] if now < e['Expiration']]
  127. def _refresh_current_endpoints(self, **kwargs):
  128. cache_key = self._create_cache_key(**kwargs)
  129. try:
  130. response = self._describe_endpoints(**kwargs)
  131. endpoints = self._parse_endpoints(response)
  132. self._cache[cache_key] = endpoints
  133. self._failed_attempts.pop(cache_key, None)
  134. return endpoints
  135. except (ConnectionError, HTTPClientError):
  136. self._failed_attempts[cache_key] = self._time() + 60
  137. return None
  138. def _recently_failed(self, cache_key):
  139. if cache_key in self._failed_attempts:
  140. now = self._time()
  141. if now < self._failed_attempts[cache_key]:
  142. return True
  143. del self._failed_attempts[cache_key]
  144. return False
  145. def _select_endpoint(self, endpoints):
  146. return endpoints[0]['Address']
  147. def describe_endpoint(self, **kwargs):
  148. operation = kwargs['Operation']
  149. discovery_required = self._model.discovery_required_for(operation)
  150. if not self._always_discover and not discovery_required:
  151. # Discovery set to only run on required operations
  152. logger.debug(
  153. 'Optional discovery disabled. Skipping discovery for Operation: %s'
  154. % operation
  155. )
  156. return None
  157. # Get the endpoint for the provided operation and identifiers
  158. cache_key = self._create_cache_key(**kwargs)
  159. endpoints = self._get_current_endpoints(cache_key)
  160. if endpoints:
  161. return self._select_endpoint(endpoints)
  162. # All known endpoints are stale
  163. recently_failed = self._recently_failed(cache_key)
  164. if not recently_failed:
  165. # We haven't failed to discover recently, go ahead and refresh
  166. endpoints = self._refresh_current_endpoints(**kwargs)
  167. if endpoints:
  168. return self._select_endpoint(endpoints)
  169. # Discovery has failed recently, do our best to get an endpoint
  170. logger.debug('Endpoint Discovery has failed for: %s', kwargs)
  171. stale_entries = self._cache.get(cache_key, None)
  172. if stale_entries:
  173. # We have stale entries, use those while discovery is failing
  174. return self._select_endpoint(stale_entries)
  175. if discovery_required:
  176. # It looks strange to be checking recently_failed again but,
  177. # this informs us as to whether or not we tried to refresh earlier
  178. if recently_failed:
  179. # Discovery is required and we haven't already refreshed
  180. endpoints = self._refresh_current_endpoints(**kwargs)
  181. if endpoints:
  182. return self._select_endpoint(endpoints)
  183. # No endpoints even refresh, raise hard error
  184. raise EndpointDiscoveryRefreshFailed()
  185. # Discovery is optional, just use the default endpoint for now
  186. return None
  187. class EndpointDiscoveryHandler:
  188. def __init__(self, manager):
  189. self._manager = manager
  190. def register(self, events, service_id):
  191. events.register(
  192. 'before-parameter-build.%s' % service_id, self.gather_identifiers
  193. )
  194. events.register_first(
  195. 'request-created.%s' % service_id, self.discover_endpoint
  196. )
  197. events.register('needs-retry.%s' % service_id, self.handle_retries)
  198. def gather_identifiers(self, params, model, context, **kwargs):
  199. endpoint_discovery = model.endpoint_discovery
  200. # Only continue if the operation supports endpoint discovery
  201. if endpoint_discovery is None:
  202. return
  203. ids = self._manager.gather_identifiers(model, params)
  204. context['discovery'] = {'identifiers': ids}
  205. def discover_endpoint(self, request, operation_name, **kwargs):
  206. ids = request.context.get('discovery', {}).get('identifiers')
  207. if ids is None:
  208. return
  209. endpoint = self._manager.describe_endpoint(
  210. Operation=operation_name, Identifiers=ids
  211. )
  212. if endpoint is None:
  213. logger.debug('Failed to discover and inject endpoint')
  214. return
  215. if not endpoint.startswith('http'):
  216. endpoint = 'https://' + endpoint
  217. logger.debug('Injecting discovered endpoint: %s', endpoint)
  218. request.url = endpoint
  219. def handle_retries(self, request_dict, response, operation, **kwargs):
  220. if response is None:
  221. return None
  222. _, response = response
  223. status = response.get('ResponseMetadata', {}).get('HTTPStatusCode')
  224. error_code = response.get('Error', {}).get('Code')
  225. if status != 421 and error_code != 'InvalidEndpointException':
  226. return None
  227. context = request_dict.get('context', {})
  228. ids = context.get('discovery', {}).get('identifiers')
  229. if ids is None:
  230. return None
  231. # Delete the cached endpoints, forcing a refresh on retry
  232. # TODO: Improve eviction behavior to only evict the bad endpoint if
  233. # there are multiple. This will almost certainly require a lock.
  234. self._manager.delete_endpoints(
  235. Operation=operation.name, Identifiers=ids
  236. )
  237. return 0