fernet.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. # This file is dual licensed under the terms of the Apache License, Version
  2. # 2.0, and the BSD License. See the LICENSE file in the root of this repository
  3. # for complete details.
  4. from __future__ import absolute_import, division, print_function
  5. import base64
  6. import binascii
  7. import os
  8. import struct
  9. import time
  10. import six
  11. from cryptography import utils
  12. from cryptography.exceptions import InvalidSignature
  13. from cryptography.hazmat.backends import _get_backend
  14. from cryptography.hazmat.primitives import hashes, padding
  15. from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
  16. from cryptography.hazmat.primitives.hmac import HMAC
  17. class InvalidToken(Exception):
  18. pass
  19. _MAX_CLOCK_SKEW = 60
  20. class Fernet(object):
  21. def __init__(self, key, backend=None):
  22. backend = _get_backend(backend)
  23. key = base64.urlsafe_b64decode(key)
  24. if len(key) != 32:
  25. raise ValueError(
  26. "Fernet key must be 32 url-safe base64-encoded bytes."
  27. )
  28. self._signing_key = key[:16]
  29. self._encryption_key = key[16:]
  30. self._backend = backend
  31. @classmethod
  32. def generate_key(cls):
  33. return base64.urlsafe_b64encode(os.urandom(32))
  34. def encrypt(self, data):
  35. return self.encrypt_at_time(data, int(time.time()))
  36. def encrypt_at_time(self, data, current_time):
  37. iv = os.urandom(16)
  38. return self._encrypt_from_parts(data, current_time, iv)
  39. def _encrypt_from_parts(self, data, current_time, iv):
  40. utils._check_bytes("data", data)
  41. padder = padding.PKCS7(algorithms.AES.block_size).padder()
  42. padded_data = padder.update(data) + padder.finalize()
  43. encryptor = Cipher(
  44. algorithms.AES(self._encryption_key), modes.CBC(iv), self._backend
  45. ).encryptor()
  46. ciphertext = encryptor.update(padded_data) + encryptor.finalize()
  47. basic_parts = (
  48. b"\x80" + struct.pack(">Q", current_time) + iv + ciphertext
  49. )
  50. h = HMAC(self._signing_key, hashes.SHA256(), backend=self._backend)
  51. h.update(basic_parts)
  52. hmac = h.finalize()
  53. return base64.urlsafe_b64encode(basic_parts + hmac)
  54. def decrypt(self, token, ttl=None):
  55. timestamp, data = Fernet._get_unverified_token_data(token)
  56. return self._decrypt_data(data, timestamp, ttl, int(time.time()))
  57. def decrypt_at_time(self, token, ttl, current_time):
  58. if ttl is None:
  59. raise ValueError(
  60. "decrypt_at_time() can only be used with a non-None ttl"
  61. )
  62. timestamp, data = Fernet._get_unverified_token_data(token)
  63. return self._decrypt_data(data, timestamp, ttl, current_time)
  64. def extract_timestamp(self, token):
  65. timestamp, data = Fernet._get_unverified_token_data(token)
  66. # Verify the token was not tampered with.
  67. self._verify_signature(data)
  68. return timestamp
  69. @staticmethod
  70. def _get_unverified_token_data(token):
  71. utils._check_bytes("token", token)
  72. try:
  73. data = base64.urlsafe_b64decode(token)
  74. except (TypeError, binascii.Error):
  75. raise InvalidToken
  76. if not data or six.indexbytes(data, 0) != 0x80:
  77. raise InvalidToken
  78. try:
  79. (timestamp,) = struct.unpack(">Q", data[1:9])
  80. except struct.error:
  81. raise InvalidToken
  82. return timestamp, data
  83. def _verify_signature(self, data):
  84. h = HMAC(self._signing_key, hashes.SHA256(), backend=self._backend)
  85. h.update(data[:-32])
  86. try:
  87. h.verify(data[-32:])
  88. except InvalidSignature:
  89. raise InvalidToken
  90. def _decrypt_data(self, data, timestamp, ttl, current_time):
  91. if ttl is not None:
  92. if timestamp + ttl < current_time:
  93. raise InvalidToken
  94. if current_time + _MAX_CLOCK_SKEW < timestamp:
  95. raise InvalidToken
  96. self._verify_signature(data)
  97. iv = data[9:25]
  98. ciphertext = data[25:-32]
  99. decryptor = Cipher(
  100. algorithms.AES(self._encryption_key), modes.CBC(iv), self._backend
  101. ).decryptor()
  102. plaintext_padded = decryptor.update(ciphertext)
  103. try:
  104. plaintext_padded += decryptor.finalize()
  105. except ValueError:
  106. raise InvalidToken
  107. unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
  108. unpadded = unpadder.update(plaintext_padded)
  109. try:
  110. unpadded += unpadder.finalize()
  111. except ValueError:
  112. raise InvalidToken
  113. return unpadded
  114. class MultiFernet(object):
  115. def __init__(self, fernets):
  116. fernets = list(fernets)
  117. if not fernets:
  118. raise ValueError(
  119. "MultiFernet requires at least one Fernet instance"
  120. )
  121. self._fernets = fernets
  122. def encrypt(self, msg):
  123. return self.encrypt_at_time(msg, int(time.time()))
  124. def encrypt_at_time(self, msg, current_time):
  125. return self._fernets[0].encrypt_at_time(msg, current_time)
  126. def rotate(self, msg):
  127. timestamp, data = Fernet._get_unverified_token_data(msg)
  128. for f in self._fernets:
  129. try:
  130. p = f._decrypt_data(data, timestamp, None, None)
  131. break
  132. except InvalidToken:
  133. pass
  134. else:
  135. raise InvalidToken
  136. iv = os.urandom(16)
  137. return self._fernets[0]._encrypt_from_parts(p, timestamp, iv)
  138. def decrypt(self, msg, ttl=None):
  139. for f in self._fernets:
  140. try:
  141. return f.decrypt(msg, ttl)
  142. except InvalidToken:
  143. pass
  144. raise InvalidToken
  145. def decrypt_at_time(self, msg, ttl, current_time):
  146. for f in self._fernets:
  147. try:
  148. return f.decrypt_at_time(msg, ttl, current_time)
  149. except InvalidToken:
  150. pass
  151. raise InvalidToken