123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282 |
- # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License"). You
- # may not use this file except in compliance with the License. A copy of
- # the License is located at
- #
- # http://aws.amazon.com/apache2.0/
- #
- # or in the "license" file accompanying this file. This file is
- # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
- # ANY KIND, either express or implied. See the License for the specific
- # language governing permissions and limitations under the License.
- import logging
- import time
- import weakref
- from botocore import xform_name
- from botocore.exceptions import BotoCoreError, ConnectionError, HTTPClientError
- from botocore.model import OperationNotFoundError
- from botocore.utils import CachedProperty
- logger = logging.getLogger(__name__)
class EndpointDiscoveryException(BotoCoreError):
    """Common base class for the endpoint discovery errors in this module."""
class EndpointDiscoveryRequired(EndpointDiscoveryException):
    """Raised when an operation requires Endpoint Discovery but it is off."""

    fmt = 'Endpoint Discovery is not enabled but this operation requires it.'
class EndpointDiscoveryRefreshFailed(EndpointDiscoveryException):
    """Raised when Endpoint Discovery fails to refresh the known endpoints."""

    fmt = 'Endpoint Discovery failed to refresh the required endpoints.'
def block_endpoint_discovery_required_operations(model, **kwargs):
    """Reject operations that are marked as requiring endpoint discovery.

    :param model: Object whose ``endpoint_discovery`` attribute is read; the
        attribute may be falsy (e.g. ``None``) or a mapping.
    :param kwargs: Accepted and ignored.
    :raises EndpointDiscoveryRequired: If ``endpoint_discovery`` is truthy and
        its ``'required'`` entry is truthy.
    """
    discovery_config = model.endpoint_discovery
    if not discovery_config:
        # Nothing declared for this operation, so there is nothing to block.
        return
    if discovery_config.get('required'):
        raise EndpointDiscoveryRequired()
class EndpointDiscoveryModel:
    """Read-only helper around a service model's endpoint discovery data.

    :param service_model: Object exposing ``endpoint_discovery_operation``
        and ``operation_model(name)``.
    """

    def __init__(self, service_model):
        self._service_model = service_model

    @CachedProperty
    def discovery_operation_name(self):
        """Name of the discovery operation after ``xform_name`` conversion."""
        discovery_operation = self._service_model.endpoint_discovery_operation
        return xform_name(discovery_operation.name)

    @CachedProperty
    def discovery_operation_keys(self):
        """List of input member names accepted by the discovery operation.

        Empty when the discovery operation has no input shape.
        """
        discovery_operation = self._service_model.endpoint_discovery_operation
        input_shape = discovery_operation.input_shape
        if not input_shape:
            return []
        return list(input_shape.members)

    def discovery_required_for(self, operation_name):
        """Return whether ``operation_name`` is marked as requiring discovery.

        Returns ``False`` when the operation is unknown or declares no
        endpoint discovery configuration at all.
        """
        try:
            operation_model = self._service_model.operation_model(
                operation_name
            )
        except OperationNotFoundError:
            return False
        endpoint_discovery = operation_model.endpoint_discovery
        if not endpoint_discovery:
            # No discovery configuration (None or empty) means it cannot be
            # required; previously a None value raised AttributeError here.
            return False
        return endpoint_discovery.get('required', False)

    def discovery_operation_kwargs(self, **kwargs):
        """Filter ``kwargs`` down to what the discovery operation accepts.

        ``Operation`` and ``Identifiers`` are dropped together when
        ``Identifiers`` is missing or empty; any key that is not an input
        member of the discovery operation is dropped as well.
        """
        input_keys = self.discovery_operation_keys
        # Operation and Identifiers are only sent if there are Identifiers
        if not kwargs.get('Identifiers'):
            kwargs.pop('Operation', None)
            kwargs.pop('Identifiers', None)
        return {k: v for k, v in kwargs.items() if k in input_keys}

    def gather_identifiers(self, operation, params):
        """Collect the endpoint discovery ids supplied in ``params``.

        :param operation: Object whose ``input_shape`` is traversed.
        :param params: Mapping of the caller-supplied parameters.
        :returns: Dict of member name to supplied value; empty when the
            operation has no input shape.
        """
        input_shape = operation.input_shape
        if input_shape is None:
            return {}
        return self._gather_ids(input_shape, params)

    def _gather_ids(self, shape, params, ids=None):
        # Traverse the input shape and corresponding parameters, gathering
        # any input fields labeled as an endpoint discovery id
        if ids is None:
            ids = {}
        for member_name, member_shape in shape.members.items():
            if member_name not in params:
                # A parameter the caller did not supply cannot contribute an
                # identifier; skip it instead of raising KeyError.
                continue
            if member_shape.metadata.get('endpointdiscoveryid'):
                ids[member_name] = params[member_name]
            elif member_shape.type_name == 'structure':
                self._gather_ids(member_shape, params[member_name], ids)
        return ids
class EndpointDiscoveryManager:
    """Discovers endpoints through a client and caches them per cache key.

    :param client: Client used to call the discovery operation; only a weak
        proxy to it is kept.
    :param cache: Optional mapping used to store discovered endpoints; a new
        dict is created when omitted.
    :param current_time: Optional zero-argument callable returning the
        current time in seconds; defaults to ``time.time``.
    :param always_discover: When false, discovery only runs for operations
        that are marked as requiring it.
    """

    def __init__(
        self, client, cache=None, current_time=None, always_discover=True
    ):
        if cache is None:
            cache = {}
        self._cache = cache
        # Maps cache key -> time until which discovery is not retried.
        self._failed_attempts = {}
        if current_time is None:
            current_time = time.time
        self._time = current_time
        self._always_discover = always_discover

        # This needs to be a weak ref in order to prevent memory leaks on
        # python 2.6
        self._client = weakref.proxy(client)
        self._model = EndpointDiscoveryModel(client.meta.service_model)

    def _parse_endpoints(self, response):
        """Stamp each endpoint in ``response`` with an ``Expiration`` time.

        The endpoint dicts are modified in place and the same list is
        returned.
        """
        endpoints = response['Endpoints']
        current_time = self._time()
        for endpoint in endpoints:
            # A missing or null cache period counts as zero minutes, so the
            # entry is immediately stale rather than raising TypeError.
            cache_time = endpoint.get('CachePeriodInMinutes') or 0
            endpoint['Expiration'] = current_time + cache_time * 60
        return endpoints

    def _cache_item(self, value):
        """Convert a dict into a sorted tuple of items; pass others through."""
        if isinstance(value, dict):
            return tuple(sorted(value.items()))
        return value

    def _create_cache_key(self, **kwargs):
        """Build a cache key tuple from the filtered discovery kwargs."""
        kwargs = self._model.discovery_operation_kwargs(**kwargs)
        return tuple(self._cache_item(v) for k, v in sorted(kwargs.items()))

    def gather_identifiers(self, operation, params):
        """Delegate identifier gathering to the discovery model."""
        return self._model.gather_identifiers(operation, params)

    def delete_endpoints(self, **kwargs):
        """Drop any cached endpoints stored under the key for ``kwargs``."""
        cache_key = self._create_cache_key(**kwargs)
        self._cache.pop(cache_key, None)

    def _describe_endpoints(self, **kwargs):
        # This is effectively a proxy to whatever name/kwargs the service
        # supports for endpoint discovery.
        kwargs = self._model.discovery_operation_kwargs(**kwargs)
        operation_name = self._model.discovery_operation_name
        discovery_operation = getattr(self._client, operation_name)
        logger.debug('Discovering endpoints with kwargs: %s', kwargs)
        return discovery_operation(**kwargs)

    def _get_current_endpoints(self, key):
        """Return the unexpired endpoints cached under ``key``.

        Returns ``None`` when nothing is cached for ``key``; the returned
        list may be empty when every cached entry has expired.
        """
        if key not in self._cache:
            return None
        now = self._time()
        return [e for e in self._cache[key] if now < e['Expiration']]

    def _refresh_current_endpoints(self, **kwargs):
        """Call the discovery operation and cache its endpoints.

        On ``ConnectionError`` or ``HTTPClientError`` the failure is recorded
        for 60 seconds and ``None`` is returned.
        """
        cache_key = self._create_cache_key(**kwargs)
        try:
            response = self._describe_endpoints(**kwargs)
            endpoints = self._parse_endpoints(response)
            self._cache[cache_key] = endpoints
            self._failed_attempts.pop(cache_key, None)
            return endpoints
        except (ConnectionError, HTTPClientError):
            self._failed_attempts[cache_key] = self._time() + 60
            return None

    def _recently_failed(self, cache_key):
        """Return True while a recorded failure for ``cache_key`` is active.

        An expired failure record is removed as a side effect.
        """
        if cache_key in self._failed_attempts:
            now = self._time()
            if now < self._failed_attempts[cache_key]:
                return True
            del self._failed_attempts[cache_key]
        return False

    def _select_endpoint(self, endpoints):
        """Return the ``Address`` of the first endpoint in the list."""
        return endpoints[0]['Address']

    def describe_endpoint(self, **kwargs):
        """Return an endpoint address for the given operation, or ``None``.

        :param kwargs: Must contain ``Operation``; the remaining entries are
            used to build the cache key and the discovery request.
        :returns: The selected endpoint address, or ``None`` when discovery
            is skipped or fails for an operation that does not require it.
        :raises EndpointDiscoveryRefreshFailed: If discovery is required and
            no endpoint (fresh or stale) could be obtained.
        """
        operation = kwargs['Operation']
        discovery_required = self._model.discovery_required_for(operation)

        if not self._always_discover and not discovery_required:
            # Discovery set to only run on required operations
            logger.debug(
                'Optional discovery disabled. Skipping discovery for '
                'Operation: %s',
                operation,
            )
            return None

        # Get the endpoint for the provided operation and identifiers
        cache_key = self._create_cache_key(**kwargs)
        endpoints = self._get_current_endpoints(cache_key)
        if endpoints:
            return self._select_endpoint(endpoints)
        # All known endpoints are stale
        recently_failed = self._recently_failed(cache_key)
        if not recently_failed:
            # We haven't failed to discover recently, go ahead and refresh
            endpoints = self._refresh_current_endpoints(**kwargs)
            if endpoints:
                return self._select_endpoint(endpoints)
        # Discovery has failed recently, do our best to get an endpoint
        logger.debug('Endpoint Discovery has failed for: %s', kwargs)
        stale_entries = self._cache.get(cache_key, None)
        if stale_entries:
            # We have stale entries, use those while discovery is failing
            return self._select_endpoint(stale_entries)
        if discovery_required:
            # It looks strange to be checking recently_failed again but,
            # this informs us as to whether or not we tried to refresh earlier
            if recently_failed:
                # Discovery is required and the refresh above was skipped
                # because of the recent failure, so attempt it now.
                endpoints = self._refresh_current_endpoints(**kwargs)
                if endpoints:
                    return self._select_endpoint(endpoints)
            # No endpoints even after a refresh, raise hard error
            raise EndpointDiscoveryRefreshFailed()
        # Discovery is optional, just use the default endpoint for now
        return None
class EndpointDiscoveryHandler:
    """Event handlers that connect a discovery manager to request events.

    :param manager: Object providing ``gather_identifiers``,
        ``describe_endpoint`` and ``delete_endpoints``.
    """

    def __init__(self, manager):
        self._manager = manager

    def register(self, events, service_id):
        """Register this handler's callbacks for ``service_id`` on ``events``."""
        parameter_event = 'before-parameter-build.%s' % service_id
        events.register(parameter_event, self.gather_identifiers)

        request_event = 'request-created.%s' % service_id
        events.register_first(request_event, self.discover_endpoint)

        retry_event = 'needs-retry.%s' % service_id
        events.register(retry_event, self.handle_retries)

    def gather_identifiers(self, params, model, context, **kwargs):
        """Store the operation's discovery identifiers in ``context``.

        Nothing is stored when ``model.endpoint_discovery`` is ``None``.
        """
        # Only continue if the operation supports endpoint discovery
        if model.endpoint_discovery is None:
            return
        context['discovery'] = {
            'identifiers': self._manager.gather_identifiers(model, params)
        }

    def discover_endpoint(self, request, operation_name, **kwargs):
        """Point ``request.url`` at a discovered endpoint when one is found.

        The request is left untouched when its context holds no identifiers
        or when the manager returns no endpoint.
        """
        discovery_context = request.context.get('discovery', {})
        ids = discovery_context.get('identifiers')
        if ids is None:
            return

        endpoint = self._manager.describe_endpoint(
            Operation=operation_name, Identifiers=ids
        )
        if endpoint is None:
            logger.debug('Failed to discover and inject endpoint')
            return

        # Addresses without an http(s) prefix get the https scheme.
        if not endpoint.startswith('http'):
            endpoint = 'https://' + endpoint
        logger.debug('Injecting discovered endpoint: %s', endpoint)
        request.url = endpoint

    def handle_retries(self, request_dict, response, operation, **kwargs):
        """Evict cached endpoints after an invalid-endpoint response.

        :returns: ``0`` when the cached endpoints were deleted, otherwise
            ``None``.
        """
        if response is None:
            return None

        _http_response, parsed = response
        status = parsed.get('ResponseMetadata', {}).get('HTTPStatusCode')
        error_code = parsed.get('Error', {}).get('Code')
        endpoint_rejected = (
            status == 421 or error_code == 'InvalidEndpointException'
        )
        if not endpoint_rejected:
            return None

        discovery_context = request_dict.get('context', {}).get('discovery', {})
        ids = discovery_context.get('identifiers')
        if ids is None:
            return None

        # Delete the cached endpoints, forcing a refresh on retry
        # TODO: Improve eviction behavior to only evict the bad endpoint if
        # there are multiple. This will almost certainly require a lock.
        self._manager.delete_endpoints(
            Operation=operation.name, Identifiers=ids
        )
        return 0
|