api_jwt.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. import json
  2. import warnings
  3. from calendar import timegm
  4. from datetime import datetime, timedelta
  5. try:
  6. # import required by mypy to perform type checking, not used for normal execution
  7. from typing import Callable, Dict, List, Optional, Union # NOQA
  8. except ImportError:
  9. pass
  10. from .api_jws import PyJWS
  11. from .algorithms import Algorithm, get_default_algorithms # NOQA
  12. from .compat import Iterable, Mapping, string_types
  13. from .exceptions import (
  14. DecodeError, ExpiredSignatureError, ImmatureSignatureError,
  15. InvalidAudienceError, InvalidIssuedAtError,
  16. InvalidIssuerError, MissingRequiredClaimError
  17. )
  18. from .utils import merge_dict
  class PyJWT(PyJWS):
      """JSON Web Token codec built on top of ``PyJWS``.

      ``encode`` serializes a mapping of claims to JSON before signing, and
      ``decode`` parses the verified payload back into a mapping and, when
      verification is enabled, validates the registered claims ``iat``,
      ``nbf``, ``exp``, ``iss`` and ``aud``.
      """

      header_type = 'JWT'

      # Exceptions int() can raise on an untrusted JSON claim value:
      # ValueError for bad strings / NaN, TypeError for null, lists and
      # objects, OverflowError for Infinity (which json.loads accepts).
      _NUMERIC_CLAIM_ERRORS = (TypeError, ValueError, OverflowError)

      @staticmethod
      def _get_default_options():
          # type: () -> Dict[str, bool]
          """Return the default verification options.

          Every ``verify_*`` check is enabled and no time claim is required.
          """
          return {
              'verify_signature': True,
              'verify_exp': True,
              'verify_nbf': True,
              'verify_iat': True,
              'verify_aud': True,
              'verify_iss': True,
              'require_exp': False,
              'require_iat': False,
              'require_nbf': False
          }

      def encode(self,
                 payload,  # type: Dict
                 key,  # type: str
                 algorithm='HS256',  # type: str
                 headers=None,  # type: Optional[Dict]
                 json_encoder=None  # type: Optional[Callable]
                 ):
          """Serialize *payload* as compact JSON and sign it.

          ``datetime`` values stored under ``exp``, ``iat`` or ``nbf`` are
          converted to integer seconds since the epoch (via
          ``utctimetuple``) before serialization. The caller's mapping is
          not modified; when a conversion is needed, a shallow copy is
          converted instead.

          :param payload: claims; must be a ``Mapping``.
          :param key: signing key, passed through to ``PyJWS.encode``.
          :param algorithm: signing algorithm name (default ``'HS256'``).
          :param headers: optional extra header fields.
          :param json_encoder: optional JSON encoder class, used to
              serialize the payload and forwarded to ``PyJWS.encode``.
          :returns: the result of ``PyJWS.encode`` for the JSON payload.
          :raises TypeError: if *payload* is not a ``Mapping``.
          """
          import copy

          # Check that we get a mapping
          if not isinstance(payload, Mapping):
              raise TypeError('Expecting a mapping object, as JWT only supports '
                              'JSON objects as payloads.')

          # Collect NumericDate conversions first so the caller's payload is
          # never mutated in place.
          converted = {}
          for time_claim in ['exp', 'iat', 'nbf']:
              value = payload.get(time_claim)
              if isinstance(value, datetime):
                  converted[time_claim] = timegm(value.utctimetuple())

          if converted:
              # copy.copy keeps the concrete mapping type (e.g. OrderedDict),
              # so key order in the serialized JSON is unchanged.
              payload = copy.copy(payload)
              for time_claim, value in converted.items():
                  payload[time_claim] = value  # type: ignore

          json_payload = json.dumps(
              payload,
              separators=(',', ':'),
              cls=json_encoder
          ).encode('utf-8')

          return super(PyJWT, self).encode(
              json_payload, key, algorithm, headers, json_encoder
          )

      def decode(self,
                 jwt,  # type: str
                 key='',  # type: str
                 verify=True,  # type: bool
                 algorithms=None,  # type: List[str]
                 options=None,  # type: Dict
                 **kwargs):
          """Verify *jwt* with ``PyJWS.decode`` and return its JSON payload.

          ``PyJWS.decode`` returns bytes, which are parsed as a JSON object.
          When *verify* is true, the registered claims are then validated
          using ``self.options`` merged with *options*.

          :param jwt: the encoded token.
          :param key: verification key, passed to ``PyJWS.decode``.
          :param verify: if false, ``verify_signature`` defaults to false
              (unless *options* sets it) and claim validation is skipped.
          :param algorithms: allowed algorithm names. A DeprecationWarning
              is emitted when it is omitted while *verify* is true.
          :param options: option overrides; the caller's dict is not
              modified.
          :param kwargs: forwarded to ``PyJWS.decode`` and to claim
              validation (``audience``, ``issuer``, ``leeway``, ...).
          :returns: the decoded payload mapping.
          :raises DecodeError: if the payload is not UTF-8 JSON, or is not
              a JSON object. Errors from ``PyJWS.decode`` and the claim
              validators propagate unchanged.
          """
          if verify and not algorithms:
              warnings.warn(
                  'It is strongly recommended that you pass in a ' +
                  'value for the "algorithms" argument when calling decode(). ' +
                  'This argument will be mandatory in a future version.',
                  DeprecationWarning,
                  stacklevel=2
              )

          # Parse the token up front; the parsed parts are discarded because
          # the payload used below comes from the signature-checked decode.
          self._load(jwt)

          if options is None:
              options = {'verify_signature': verify}
          else:
              # Work on a copy. Calling setdefault on the caller's dict would
              # leak this call's `verify` into later calls that reuse it,
              # e.g. silently disabling signature checks.
              options = dict(options)
              options.setdefault('verify_signature', verify)

          decoded = super(PyJWT, self).decode(
              jwt, key=key, algorithms=algorithms, options=options, **kwargs
          )

          try:
              payload = json.loads(decoded.decode('utf-8'))
          except ValueError as e:
              # Also covers UnicodeDecodeError, a ValueError subclass.
              raise DecodeError('Invalid payload string: %s' % e)
          if not isinstance(payload, Mapping):
              raise DecodeError('Invalid payload string: must be a json object')

          if verify:
              merged_options = merge_dict(self.options, options)
              self._validate_claims(payload, merged_options, **kwargs)

          return payload

      def _validate_claims(self, payload, options, audience=None, issuer=None,
                           leeway=0, **kwargs):
          """Validate the registered claims of a decoded *payload*.

          :param payload: decoded claims mapping.
          :param options: ``verify_*`` / ``require_*`` flags. The deprecated
              ``verify_expiration`` keyword, if given, overwrites
              ``options['verify_exp']``.
          :param audience: expected audience: a string, an iterable of
              strings, or None.
          :param issuer: expected issuer, or None to skip the issuer check.
          :param leeway: clock-skew allowance in seconds, or a ``timedelta``.
          :raises TypeError: if *audience* is not a string, iterable or None.
          """
          if 'verify_expiration' in kwargs:
              options['verify_exp'] = kwargs.get('verify_expiration', True)
              warnings.warn('The verify_expiration parameter is deprecated. '
                            'Please use verify_exp in options instead.',
                            DeprecationWarning)

          if isinstance(leeway, timedelta):
              leeway = leeway.total_seconds()

          if not isinstance(audience, (string_types, type(None), Iterable)):
              raise TypeError('audience must be a string, iterable, or None')

          self._validate_required_claims(payload, options)

          now = timegm(datetime.utcnow().utctimetuple())

          if 'iat' in payload and options.get('verify_iat'):
              self._validate_iat(payload, now, leeway)

          if 'nbf' in payload and options.get('verify_nbf'):
              self._validate_nbf(payload, now, leeway)

          if 'exp' in payload and options.get('verify_exp'):
              self._validate_exp(payload, now, leeway)

          if options.get('verify_iss'):
              self._validate_iss(payload, issuer)

          if options.get('verify_aud'):
              self._validate_aud(payload, audience)

      def _validate_required_claims(self, payload, options):
          """Raise MissingRequiredClaimError for a required time claim that is
          absent or null, according to the ``require_*`` flags."""
          if options.get('require_exp') and payload.get('exp') is None:
              raise MissingRequiredClaimError('exp')

          if options.get('require_iat') and payload.get('iat') is None:
              raise MissingRequiredClaimError('iat')

          if options.get('require_nbf') and payload.get('nbf') is None:
              raise MissingRequiredClaimError('nbf')

      def _validate_iat(self, payload, now, leeway):
          """Check that the ``iat`` claim is convertible to an integer.

          *now* and *leeway* are accepted for parity with the other time
          validators but are not used here.

          :raises InvalidIssuedAtError: if ``iat`` is not integer-like.
          """
          try:
              int(payload['iat'])
          except self._NUMERIC_CLAIM_ERRORS:
              raise InvalidIssuedAtError('Issued At claim (iat) must be an integer.')

      def _validate_nbf(self, payload, now, leeway):
          """Reject the token if ``nbf`` is later than ``now + leeway``.

          :raises DecodeError: if ``nbf`` is not integer-like.
          :raises ImmatureSignatureError: if the token is not yet valid.
          """
          try:
              nbf = int(payload['nbf'])
          except self._NUMERIC_CLAIM_ERRORS:
              raise DecodeError('Not Before claim (nbf) must be an integer.')

          if nbf > (now + leeway):
              raise ImmatureSignatureError('The token is not yet valid (nbf)')

      def _validate_exp(self, payload, now, leeway):
          """Reject the token if ``exp`` is earlier than ``now - leeway``.

          :raises DecodeError: if ``exp`` is not integer-like.
          :raises ExpiredSignatureError: if the token has expired.
          """
          try:
              exp = int(payload['exp'])
          except self._NUMERIC_CLAIM_ERRORS:
              raise DecodeError('Expiration Time claim (exp) must be an'
                                ' integer.')

          if exp < (now - leeway):
              raise ExpiredSignatureError('Signature has expired')

      def _validate_aud(self, payload, audience):
          """Check the ``aud`` claim against the expected *audience*.

          Passes when neither side specifies an audience, or when at least
          one expected audience appears among the token's audiences.

          :raises MissingRequiredClaimError: if *audience* is given but the
              token has no ``aud`` claim.
          :raises InvalidAudienceError: if the token has ``aud`` but no
              audience was expected, if the claim is not a string or list
              of strings, or if nothing matches.
          """
          if audience is None and 'aud' not in payload:
              return

          if audience is not None and 'aud' not in payload:
              # Application specified an audience, but it could not be
              # verified since the token does not contain a claim.
              raise MissingRequiredClaimError('aud')

          if audience is None and 'aud' in payload:
              # Application did not specify an audience, but
              # the token has the 'aud' claim
              raise InvalidAudienceError('Invalid audience')

          audience_claims = payload['aud']

          if isinstance(audience_claims, string_types):
              audience_claims = [audience_claims]
          if not isinstance(audience_claims, list):
              raise InvalidAudienceError('Invalid claim format in token')
          if any(not isinstance(c, string_types) for c in audience_claims):
              raise InvalidAudienceError('Invalid claim format in token')

          if isinstance(audience, string_types):
              audience = [audience]

          if not any(aud in audience_claims for aud in audience):
              raise InvalidAudienceError('Invalid audience')

      def _validate_iss(self, payload, issuer):
          """Check the ``iss`` claim equals *issuer*; skipped if it is None.

          :raises MissingRequiredClaimError: if the token has no ``iss``.
          :raises InvalidIssuerError: if ``iss`` differs from *issuer*.
          """
          if issuer is None:
              return

          if 'iss' not in payload:
              raise MissingRequiredClaimError('iss')

          if payload['iss'] != issuer:
              raise InvalidIssuerError('Invalid issuer')
  # Module-level convenience API: one shared PyJWT instance whose bound
  # methods are re-exported as plain functions. register_algorithm and
  # unregister_algorithm therefore change the algorithms of this shared
  # instance, which encode/decode below also use.
  _jwt_global_obj = PyJWT()
  encode, decode = _jwt_global_obj.encode, _jwt_global_obj.decode
  register_algorithm, unregister_algorithm = (
      _jwt_global_obj.register_algorithm,
      _jwt_global_obj.unregister_algorithm,
  )
  get_unverified_header = _jwt_global_obj.get_unverified_header