api_jws.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. import binascii
  2. import json
  3. import warnings
  4. try:
  5. # import required by mypy to perform type checking, not used for normal execution
  6. from typing import Callable, Dict, List, Optional, Union # NOQA
  7. except ImportError:
  8. pass
  9. from .algorithms import (
  10. Algorithm, get_default_algorithms, has_crypto, requires_cryptography # NOQA
  11. )
  12. from .compat import Mapping, binary_type, string_types, text_type
  13. from .exceptions import (
  14. DecodeError, InvalidAlgorithmError, InvalidSignatureError,
  15. InvalidTokenError
  16. )
  17. from .utils import base64url_decode, base64url_encode, force_bytes, merge_dict
  18. class PyJWS(object):
  19. header_typ = 'JWT'
  20. def __init__(self, algorithms=None, options=None):
  21. self._algorithms = get_default_algorithms()
  22. self._valid_algs = (set(algorithms) if algorithms is not None
  23. else set(self._algorithms))
  24. # Remove algorithms that aren't on the whitelist
  25. for key in list(self._algorithms.keys()):
  26. if key not in self._valid_algs:
  27. del self._algorithms[key]
  28. if not options:
  29. options = {}
  30. self.options = merge_dict(self._get_default_options(), options)
  31. @staticmethod
  32. def _get_default_options():
  33. return {
  34. 'verify_signature': True
  35. }
  36. def register_algorithm(self, alg_id, alg_obj):
  37. """
  38. Registers a new Algorithm for use when creating and verifying tokens.
  39. """
  40. if alg_id in self._algorithms:
  41. raise ValueError('Algorithm already has a handler.')
  42. if not isinstance(alg_obj, Algorithm):
  43. raise TypeError('Object is not of type `Algorithm`')
  44. self._algorithms[alg_id] = alg_obj
  45. self._valid_algs.add(alg_id)
  46. def unregister_algorithm(self, alg_id):
  47. """
  48. Unregisters an Algorithm for use when creating and verifying tokens
  49. Throws KeyError if algorithm is not registered.
  50. """
  51. if alg_id not in self._algorithms:
  52. raise KeyError('The specified algorithm could not be removed'
  53. ' because it is not registered.')
  54. del self._algorithms[alg_id]
  55. self._valid_algs.remove(alg_id)
  56. def get_algorithms(self):
  57. """
  58. Returns a list of supported values for the 'alg' parameter.
  59. """
  60. return list(self._valid_algs)
  61. def encode(self,
  62. payload, # type: Union[Dict, bytes]
  63. key, # type: str
  64. algorithm='HS256', # type: str
  65. headers=None, # type: Optional[Dict]
  66. json_encoder=None # type: Optional[Callable]
  67. ):
  68. segments = []
  69. if algorithm is None:
  70. algorithm = 'none'
  71. if algorithm not in self._valid_algs:
  72. pass
  73. # Header
  74. header = {'typ': self.header_typ, 'alg': algorithm}
  75. if headers:
  76. self._validate_headers(headers)
  77. header.update(headers)
  78. json_header = force_bytes(
  79. json.dumps(
  80. header,
  81. separators=(',', ':'),
  82. cls=json_encoder
  83. )
  84. )
  85. segments.append(base64url_encode(json_header))
  86. segments.append(base64url_encode(payload))
  87. # Segments
  88. signing_input = b'.'.join(segments)
  89. try:
  90. alg_obj = self._algorithms[algorithm]
  91. key = alg_obj.prepare_key(key)
  92. signature = alg_obj.sign(signing_input, key)
  93. except KeyError:
  94. if not has_crypto and algorithm in requires_cryptography:
  95. raise NotImplementedError(
  96. "Algorithm '%s' could not be found. Do you have cryptography "
  97. "installed?" % algorithm
  98. )
  99. else:
  100. raise NotImplementedError('Algorithm not supported')
  101. segments.append(base64url_encode(signature))
  102. return b'.'.join(segments)
  103. def decode(self,
  104. jwt, # type: str
  105. key='', # type: str
  106. verify=True, # type: bool
  107. algorithms=None, # type: List[str]
  108. options=None, # type: Dict
  109. **kwargs):
  110. merged_options = merge_dict(self.options, options)
  111. verify_signature = merged_options['verify_signature']
  112. if verify_signature and not algorithms:
  113. warnings.warn(
  114. 'It is strongly recommended that you pass in a ' +
  115. 'value for the "algorithms" argument when calling decode(). ' +
  116. 'This argument will be mandatory in a future version.',
  117. DeprecationWarning
  118. )
  119. payload, signing_input, header, signature = self._load(jwt)
  120. if not verify:
  121. warnings.warn('The verify parameter is deprecated. '
  122. 'Please use verify_signature in options instead.',
  123. DeprecationWarning, stacklevel=2)
  124. elif verify_signature:
  125. self._verify_signature(payload, signing_input, header, signature,
  126. key, algorithms)
  127. return payload
  128. def get_unverified_header(self, jwt):
  129. """Returns back the JWT header parameters as a dict()
  130. Note: The signature is not verified so the header parameters
  131. should not be fully trusted until signature verification is complete
  132. """
  133. headers = self._load(jwt)[2]
  134. self._validate_headers(headers)
  135. return headers
  136. def _load(self, jwt):
  137. if isinstance(jwt, text_type):
  138. jwt = jwt.encode('utf-8')
  139. if not issubclass(type(jwt), binary_type):
  140. raise DecodeError("Invalid token type. Token must be a {0}".format(
  141. binary_type))
  142. try:
  143. signing_input, crypto_segment = jwt.rsplit(b'.', 1)
  144. header_segment, payload_segment = signing_input.split(b'.', 1)
  145. except ValueError:
  146. raise DecodeError('Not enough segments')
  147. try:
  148. header_data = base64url_decode(header_segment)
  149. except (TypeError, binascii.Error):
  150. raise DecodeError('Invalid header padding')
  151. try:
  152. header = json.loads(header_data.decode('utf-8'))
  153. except ValueError as e:
  154. raise DecodeError('Invalid header string: %s' % e)
  155. if not isinstance(header, Mapping):
  156. raise DecodeError('Invalid header string: must be a json object')
  157. try:
  158. payload = base64url_decode(payload_segment)
  159. except (TypeError, binascii.Error):
  160. raise DecodeError('Invalid payload padding')
  161. try:
  162. signature = base64url_decode(crypto_segment)
  163. except (TypeError, binascii.Error):
  164. raise DecodeError('Invalid crypto padding')
  165. return (payload, signing_input, header, signature)
  166. def _verify_signature(self, payload, signing_input, header, signature,
  167. key='', algorithms=None):
  168. alg = header.get('alg')
  169. if algorithms is not None and alg not in algorithms:
  170. raise InvalidAlgorithmError('The specified alg value is not allowed')
  171. try:
  172. alg_obj = self._algorithms[alg]
  173. key = alg_obj.prepare_key(key)
  174. if not alg_obj.verify(signing_input, key, signature):
  175. raise InvalidSignatureError('Signature verification failed')
  176. except KeyError:
  177. raise InvalidAlgorithmError('Algorithm not supported')
  178. def _validate_headers(self, headers):
  179. if 'kid' in headers:
  180. self._validate_kid(headers['kid'])
  181. def _validate_kid(self, kid):
  182. if not isinstance(kid, string_types):
  183. raise InvalidTokenError('Key ID header parameter must be a string')
  184. _jws_global_obj = PyJWS()
  185. encode = _jws_global_obj.encode
  186. decode = _jws_global_obj.decode
  187. register_algorithm = _jws_global_obj.register_algorithm
  188. unregister_algorithm = _jws_global_obj.unregister_algorithm
  189. get_unverified_header = _jws_global_obj.get_unverified_header