algorithms.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403
  1. import hashlib
  2. import hmac
  3. import json
  4. from .compat import constant_time_compare, string_types
  5. from .exceptions import InvalidKeyError
  6. from .utils import (
  7. base64url_decode, base64url_encode, der_to_raw_signature,
  8. force_bytes, force_unicode, from_base64url_uint, raw_to_der_signature,
  9. to_base64url_uint
  10. )
  11. try:
  12. from cryptography.hazmat.primitives import hashes
  13. from cryptography.hazmat.primitives.serialization import (
  14. load_pem_private_key, load_pem_public_key, load_ssh_public_key
  15. )
  16. from cryptography.hazmat.primitives.asymmetric.rsa import (
  17. RSAPrivateKey, RSAPublicKey, RSAPrivateNumbers, RSAPublicNumbers,
  18. rsa_recover_prime_factors, rsa_crt_dmp1, rsa_crt_dmq1, rsa_crt_iqmp
  19. )
  20. from cryptography.hazmat.primitives.asymmetric.ec import (
  21. EllipticCurvePrivateKey, EllipticCurvePublicKey
  22. )
  23. from cryptography.hazmat.primitives.asymmetric import ec, padding
  24. from cryptography.hazmat.backends import default_backend
  25. from cryptography.exceptions import InvalidSignature
  26. has_crypto = True
  27. except ImportError:
  28. has_crypto = False
  # Names of the algorithms that get_default_algorithms() registers only when
  # has_crypto is true (RSA PKCS#1 v1.5, ECDSA and RSA-PSS families).
  requires_cryptography = {
      'RS256', 'RS384', 'RS512',
      'ES256', 'ES384', 'ES521', 'ES512',
      'PS256', 'PS384', 'PS512',
  }
  def get_default_algorithms():
      """
      Build the mapping of algorithm name -> algorithm instance that the
      library implements.

      'none' and the HMAC family are always present. The RSA, ECDSA and
      RSA-PSS families are added only when has_crypto is true.
      """
      registry = {}

      # Always available: no third-party dependency needed.
      registry['none'] = NoneAlgorithm()
      registry['HS256'] = HMACAlgorithm(HMACAlgorithm.SHA256)
      registry['HS384'] = HMACAlgorithm(HMACAlgorithm.SHA384)
      registry['HS512'] = HMACAlgorithm(HMACAlgorithm.SHA512)

      if not has_crypto:
          return registry

      registry['RS256'] = RSAAlgorithm(RSAAlgorithm.SHA256)
      registry['RS384'] = RSAAlgorithm(RSAAlgorithm.SHA384)
      registry['RS512'] = RSAAlgorithm(RSAAlgorithm.SHA512)
      registry['ES256'] = ECAlgorithm(ECAlgorithm.SHA256)
      registry['ES384'] = ECAlgorithm(ECAlgorithm.SHA384)
      registry['ES521'] = ECAlgorithm(ECAlgorithm.SHA512)
      # Backward compat for #219 fix
      registry['ES512'] = ECAlgorithm(ECAlgorithm.SHA512)
      registry['PS256'] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256)
      registry['PS384'] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384)
      registry['PS512'] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512)
      return registry
  class Algorithm(object):
      """
      The interface for an algorithm used to sign and verify tokens.

      Every method here raises NotImplementedError; concrete algorithms
      override the operations they support.
      """

      def prepare_key(self, key):
          """
          Performs necessary validation and conversions on the key and returns
          the key value in the proper format for sign() and verify().
          """
          raise NotImplementedError

      def sign(self, msg, key):
          """
          Returns a digital signature for the specified message
          using the specified key value.
          """
          raise NotImplementedError

      def verify(self, msg, key, sig):
          """
          Verifies that the specified digital signature is valid
          for the specified message and key values.
          """
          raise NotImplementedError

      @staticmethod
      def to_jwk(key_obj):
          """
          Serializes a given key into a JWK.

          The key type is algorithm-specific; this base implementation
          supports none and always raises NotImplementedError.
          """
          raise NotImplementedError

      @staticmethod
      def from_jwk(jwk):
          """
          Deserializes a given JWK back into the key representation used by
          the algorithm.

          The resulting type is algorithm-specific; this base implementation
          supports none and always raises NotImplementedError.
          """
          raise NotImplementedError
  class NoneAlgorithm(Algorithm):
      """
      Placeholder algorithm for the case where no signing or verification
      operations are required.
      """

      def prepare_key(self, key):
          """
          Accept only "no key": None, or the empty string (which is
          normalised to None). Anything else raises InvalidKeyError.
          """
          if key == '' or key is None:
              return None
          raise InvalidKeyError('When alg = "none", key value must be None.')

      def sign(self, msg, key):
          """Return an empty signature regardless of message and key."""
          return b''

      def verify(self, msg, key, sig):
          """Always report the signature as not verified."""
          return False
  class HMACAlgorithm(Algorithm):
      """
      Performs signing and verification operations using HMAC
      and the specified hash function.
      """
      SHA256 = hashlib.sha256
      SHA384 = hashlib.sha384
      SHA512 = hashlib.sha512

      def __init__(self, hash_alg):
          # hash_alg is passed to hmac.new() as the digest constructor,
          # e.g. one of the SHA* class attributes above.
          self.hash_alg = hash_alg

      def prepare_key(self, key):
          """
          Return the secret passed through force_bytes().

          Raises InvalidKeyError when the secret contains a marker of a
          public key, certificate or SSH key, since such material should
          not be used as an HMAC secret.
          """
          key = force_bytes(key)

          invalid_strings = [
              b'-----BEGIN PUBLIC KEY-----',
              b'-----BEGIN CERTIFICATE-----',
              b'-----BEGIN RSA PUBLIC KEY-----',
              b'ssh-rsa'
          ]

          if any(string_value in key for string_value in invalid_strings):
              raise InvalidKeyError(
                  'The specified key is an asymmetric key or x509 certificate and'
                  ' should not be used as an HMAC secret.')

          return key

      @staticmethod
      def to_jwk(key_obj):
          """Serialize an HMAC secret as a JSON string JWK with kty 'oct'."""
          return json.dumps({
              'k': force_unicode(base64url_encode(force_bytes(key_obj))),
              'kty': 'oct'
          })

      @staticmethod
      def from_jwk(jwk):
          """
          Parse a JSON string JWK with kty 'oct' and return the decoded
          value of its 'k' member.

          Raises InvalidKeyError when the input is not valid JSON, is not a
          JSON object, is not an 'oct' key, or has no 'k' member.
          """
          # Malformed JSON is reported the same way RSAAlgorithm.from_jwk
          # reports it, instead of leaking a bare ValueError.
          try:
              obj = json.loads(jwk)
          except ValueError:
              raise InvalidKeyError('Key is not valid JSON')

          # Valid JSON that is not an object (e.g. a list) has no .get().
          if not isinstance(obj, dict) or obj.get('kty') != 'oct':
              raise InvalidKeyError('Not an HMAC key')

          if 'k' not in obj:
              raise InvalidKeyError('HMAC key is missing the "k" parameter')

          return base64url_decode(obj['k'])

      def sign(self, msg, key):
          """Return the raw HMAC digest of msg under key."""
          return hmac.new(key, msg, self.hash_alg).digest()

      def verify(self, msg, key, sig):
          """Recompute the digest and compare it to sig in constant time."""
          return constant_time_compare(sig, self.sign(msg, key))
  if has_crypto:

      class RSAAlgorithm(Algorithm):
          """
          Signs and verifies messages with RSASSA-PKCS-v1_5 and the
          configured hash function.
          """
          SHA256 = hashes.SHA256
          SHA384 = hashes.SHA384
          SHA512 = hashes.SHA512

          def __init__(self, hash_alg):
              # hash_alg is a hash class; it is instantiated on every
              # sign()/verify() call.
              self.hash_alg = hash_alg

          def prepare_key(self, key):
              """
              Return an RSA key object for the given key material.

              Key objects are returned unchanged. Strings are loaded as an
              SSH public key when they start with 'ssh-rsa', otherwise as a
              PEM private key, falling back to a PEM public key when that
              raises ValueError. Any other type raises TypeError.
              """
              if isinstance(key, (RSAPrivateKey, RSAPublicKey)):
                  return key

              if not isinstance(key, string_types):
                  raise TypeError('Expecting a PEM-formatted key.')

              material = force_bytes(key)
              try:
                  if material.startswith(b'ssh-rsa'):
                      return load_ssh_public_key(material, backend=default_backend())
                  return load_pem_private_key(material, password=None, backend=default_backend())
              except ValueError:
                  return load_pem_public_key(material, backend=default_backend())

          @staticmethod
          def to_jwk(key_obj):
              """
              Serialize an RSA private or public key object to a JSON string
              JWK. Raises InvalidKeyError for any other object.
              """
              if getattr(key_obj, 'private_numbers', None):
                  # Private key
                  private = key_obj.private_numbers()
                  jwk = {
                      'kty': 'RSA',
                      'key_ops': ['sign'],
                      'n': force_unicode(to_base64url_uint(private.public_numbers.n)),
                      'e': force_unicode(to_base64url_uint(private.public_numbers.e)),
                      'd': force_unicode(to_base64url_uint(private.d)),
                      'p': force_unicode(to_base64url_uint(private.p)),
                      'q': force_unicode(to_base64url_uint(private.q)),
                      'dp': force_unicode(to_base64url_uint(private.dmp1)),
                      'dq': force_unicode(to_base64url_uint(private.dmq1)),
                      'qi': force_unicode(to_base64url_uint(private.iqmp))
                  }
              elif getattr(key_obj, 'verify', None):
                  # Public key
                  public = key_obj.public_numbers()
                  jwk = {
                      'kty': 'RSA',
                      'key_ops': ['verify'],
                      'n': force_unicode(to_base64url_uint(public.n)),
                      'e': force_unicode(to_base64url_uint(public.e))
                  }
              else:
                  raise InvalidKeyError('Not a public or private key')

              return json.dumps(jwk)

          @staticmethod
          def from_jwk(jwk):
              """
              Build an RSA key object from a JSON string JWK.

              A private key is returned when 'd', 'e' and 'n' are all
              present, a public key when only 'n' and 'e' are. Raises
              InvalidKeyError for invalid JSON, a non-RSA kty, a multi-prime
              key ('oth'), a partial set of CRT parameters, or missing
              public parameters.
              """
              try:
                  fields = json.loads(jwk)
              except ValueError:
                  raise InvalidKeyError('Key is not valid JSON')

              if fields.get('kty') != 'RSA':
                  raise InvalidKeyError('Not an RSA key')

              # Both key kinds need the public parameters.
              if 'e' not in fields or 'n' not in fields:
                  raise InvalidKeyError('Not a public or private key')

              if 'd' not in fields:
                  # Public key
                  public_only = RSAPublicNumbers(
                      from_base64url_uint(fields['e']), from_base64url_uint(fields['n'])
                  )
                  return public_only.public_key(default_backend())

              # Private key
              if 'oth' in fields:
                  raise InvalidKeyError('Unsupported RSA private key: > 2 primes not supported')

              # The CRT parameters must be given all together or not at all.
              crt_names = ['p', 'q', 'dp', 'dq', 'qi']
              crt_present = sum(1 for crt_name in crt_names if crt_name in fields)
              if 0 < crt_present < len(crt_names):
                  raise InvalidKeyError('RSA key must include all parameters if any are present besides d')

              public_numbers = RSAPublicNumbers(
                  from_base64url_uint(fields['e']), from_base64url_uint(fields['n'])
              )

              if crt_present:
                  d = from_base64url_uint(fields['d'])
                  p = from_base64url_uint(fields['p'])
                  q = from_base64url_uint(fields['q'])
                  dmp1 = from_base64url_uint(fields['dp'])
                  dmq1 = from_base64url_uint(fields['dq'])
                  iqmp = from_base64url_uint(fields['qi'])
              else:
                  # Only d was supplied: recover the primes and derive the
                  # CRT values from them.
                  d = from_base64url_uint(fields['d'])
                  p, q = rsa_recover_prime_factors(
                      public_numbers.n, d, public_numbers.e
                  )
                  dmp1 = rsa_crt_dmp1(d, p)
                  dmq1 = rsa_crt_dmq1(d, q)
                  iqmp = rsa_crt_iqmp(p, q)

              private_numbers = RSAPrivateNumbers(
                  d=d,
                  p=p,
                  q=q,
                  dmp1=dmp1,
                  dmq1=dmq1,
                  iqmp=iqmp,
                  public_numbers=public_numbers
              )
              return private_numbers.private_key(default_backend())

          def sign(self, msg, key):
              """Return the PKCS#1 v1.5 signature of msg made with key."""
              return key.sign(msg, padding.PKCS1v15(), self.hash_alg())

          def verify(self, msg, key, sig):
              """Return True if sig verifies for msg, False on InvalidSignature."""
              try:
                  key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
              except InvalidSignature:
                  return False
              return True

      class ECAlgorithm(Algorithm):
          """
          Signs and verifies messages with ECDSA and the configured hash
          function.
          """
          SHA256 = hashes.SHA256
          SHA384 = hashes.SHA384
          SHA512 = hashes.SHA512

          def __init__(self, hash_alg):
              self.hash_alg = hash_alg

          def prepare_key(self, key):
              """
              Return an elliptic-curve key object for the given key material.

              Key objects are returned unchanged. Strings are loaded as an
              SSH public key when they start with 'ecdsa-sha2-', otherwise
              as a PEM public key, falling back to a PEM private key when
              that raises ValueError. Any other type raises TypeError.
              """
              if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
                  return key

              if not isinstance(key, string_types):
                  raise TypeError('Expecting a PEM-formatted key.')

              material = force_bytes(key)

              # We cannot tell a signing key from a verifying key up front,
              # so the verifying (public) key is attempted first.
              try:
                  if material.startswith(b'ecdsa-sha2-'):
                      return load_ssh_public_key(material, backend=default_backend())
                  return load_pem_public_key(material, backend=default_backend())
              except ValueError:
                  return load_pem_private_key(material, password=None, backend=default_backend())

          def sign(self, msg, key):
              """Sign msg and return the signature converted from DER to raw form."""
              der_signature = key.sign(msg, ec.ECDSA(self.hash_alg()))
              return der_to_raw_signature(der_signature, key.curve)

          def verify(self, msg, key, sig):
              """
              Return True if the raw signature sig verifies for msg.

              False is returned when the raw-to-DER conversion raises
              ValueError or when verification raises InvalidSignature.
              """
              try:
                  der_signature = raw_to_der_signature(sig, key.curve)
              except ValueError:
                  return False

              try:
                  key.verify(der_signature, msg, ec.ECDSA(self.hash_alg()))
              except InvalidSignature:
                  return False
              return True

      class RSAPSSAlgorithm(RSAAlgorithm):
          """
          Performs a signature using RSASSA-PSS with MGF1
          """

          def sign(self, msg, key):
              """Return the PSS signature of msg made with key."""
              # The salt length is the digest size of the configured hash.
              pss = padding.PSS(
                  mgf=padding.MGF1(self.hash_alg()),
                  salt_length=self.hash_alg.digest_size
              )
              return key.sign(msg, pss, self.hash_alg())

          def verify(self, msg, key, sig):
              """Return True if sig verifies for msg, False on InvalidSignature."""
              pss = padding.PSS(
                  mgf=padding.MGF1(self.hash_alg()),
                  salt_length=self.hash_alg.digest_size
              )
              try:
                  key.verify(sig, msg, pss, self.hash_alg())
              except InvalidSignature:
                  return False
              return True