models.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. from typing import Any, Mapping, Optional, Sequence, Type, TypeVar, cast
  2. from django.db.models.fields import BigIntegerField
  3. from bitfield.query import BitQueryExactLookupStub
  4. from bitfield.types import Bit, BitHandler
  5. # Count binary capacity. Truncate "0b" prefix from binary form.
  6. # Twice faster than bin(i)[2:] or math.floor(math.log(i))
  7. MAX_FLAG_COUNT = int(len(bin(BigIntegerField.MAX_BIGINT)) - 2)
  8. class BitFieldFlags:
  9. def __init__(self, flags):
  10. if len(flags) > MAX_FLAG_COUNT:
  11. raise ValueError("Too many flags")
  12. self._flags = flags
  13. def __repr__(self):
  14. return repr(self._flags)
  15. def __iter__(self):
  16. yield from self._flags
  17. def __getattr__(self, key):
  18. if key not in self._flags:
  19. raise AttributeError
  20. return Bit(self._flags.index(key))
  21. __getitem__ = __getattr__
  22. def iteritems(self):
  23. for flag in self._flags:
  24. yield flag, Bit(self._flags.index(flag))
  25. def iterkeys(self):
  26. yield from self._flags
  27. def itervalues(self):
  28. for flag in self._flags:
  29. yield Bit(self._flags.index(flag))
  30. def items(self):
  31. return list(self.iteritems())
  32. def keys(self):
  33. return list(self.iterkeys())
  34. def values(self):
  35. return list(self.itervalues())
  36. class BitFieldCreator:
  37. """
  38. A placeholder class that provides a way to set the attribute on the model.
  39. Descriptor for BitFields. Checks to make sure that all flags of the
  40. instance match the class. This is to handle the case when caching
  41. an older version of the instance and a newer version of the class is
  42. available (usually during deploys).
  43. """
  44. def __init__(self, field):
  45. self.field = field
  46. def __set__(self, obj, value):
  47. obj.__dict__[self.field.name] = self.field.to_python(value)
  48. def __get__(self, obj, type=None):
  49. if obj is None:
  50. return BitFieldFlags(self.field.flags)
  51. retval = obj.__dict__[self.field.name]
  52. if self.field.__class__ is BitField:
  53. # Update flags from class in case they've changed.
  54. retval._keys = self.field.flags
  55. return retval
  56. class BitField(BigIntegerField):
  57. def contribute_to_class(self, cls, name, **kwargs):
  58. super().contribute_to_class(cls, name, **kwargs)
  59. setattr(cls, self.name, BitFieldCreator(self))
  60. def __init__(self, flags, default=None, *args, **kwargs):
  61. if isinstance(flags, dict):
  62. # Get only integer keys in correct range
  63. valid_keys = (
  64. k for k in flags.keys() if isinstance(k, int) and (0 <= k < MAX_FLAG_COUNT)
  65. )
  66. if not valid_keys:
  67. raise ValueError("Wrong keys or empty dictionary")
  68. # Fill list with values from dict or with empty values
  69. flags = [flags.get(i, "") for i in range(max(valid_keys) + 1)]
  70. if len(flags) > MAX_FLAG_COUNT:
  71. raise ValueError("Too many flags")
  72. self._arg_flags = flags
  73. flags = list(flags)
  74. labels = []
  75. for num, flag in enumerate(flags):
  76. if isinstance(flag, (tuple, list)):
  77. flags[num] = flag[0]
  78. labels.append(flag[1])
  79. else:
  80. labels.append(flag)
  81. if isinstance(default, (list, tuple, set, frozenset)):
  82. new_value = 0
  83. for flag in default:
  84. new_value |= Bit(flags.index(flag))
  85. default = new_value
  86. kwargs["default"] = default
  87. BigIntegerField.__init__(self, *args, **kwargs)
  88. self.flags = flags
  89. self.labels = labels
  90. def pre_save(self, instance, add):
  91. value = getattr(instance, self.attname)
  92. return value
  93. def get_prep_value(self, value):
  94. if value is None:
  95. return None
  96. if isinstance(value, (BitHandler, Bit)):
  97. value = value.mask
  98. return int(value)
  99. def to_python(self, value):
  100. if isinstance(value, Bit):
  101. value = value.mask
  102. if not isinstance(value, BitHandler):
  103. # Regression for #1425: fix bad data that was created resulting
  104. # in negative values for flags. Compute the value that would
  105. # have been visible ot the application to preserve compatibility.
  106. if isinstance(value, int) and value < 0:
  107. new_value = 0
  108. for bit_number, _ in enumerate(self.flags):
  109. new_value |= value & (2**bit_number)
  110. value = new_value
  111. value = BitHandler(value, self.flags, self.labels)
  112. else:
  113. # Ensure flags are consistent for unpickling
  114. value._keys = self.flags
  115. return value
  116. def deconstruct(self):
  117. name, path, args, kwargs = super().deconstruct()
  118. args.insert(0, self._arg_flags)
  119. return name, path, args, kwargs
  120. def flags_from_annotations(annotations: Mapping[str, type]) -> Sequence[str]:
  121. flags = []
  122. for attr, ty in annotations.items():
  123. assert ty in ("bool", bool), f"bitfields can only hold bools, {attr} is {ty!r}"
  124. flags.append(attr)
  125. return flags
  126. class TypedBitfieldMeta(type):
  127. def __new__(cls, name, bases, clsdict):
  128. if name == "TypedClassBitField":
  129. return type.__new__(cls, name, bases, clsdict)
  130. flags = {}
  131. for attr, ty in clsdict["__annotations__"].items():
  132. if attr.startswith("_"):
  133. continue
  134. if attr in ("bitfield_default", "bitfield_null", "bitfield_db_column"):
  135. continue
  136. flags[attr] = ty
  137. return BitField(
  138. flags=flags_from_annotations(flags),
  139. default=clsdict.get("bitfield_default"),
  140. null=clsdict.get("bitfield_null") or False,
  141. db_column=clsdict.get("bitfield_db_column"),
  142. )
  143. def __int__(self) -> int:
  144. raise NotImplementedError()
class TypedClassBitField(metaclass=TypedBitfieldMeta):
    """
    A wrapper around BitField that allows you to access its fields as instance
    attributes in a type-safe way.

    Subclass it and annotate each flag as ``bool``. The metaclass turns the
    subclass statement into a ``BitField`` whose flags are the public
    annotations, in declaration order. The ``bitfield_*`` names below are
    options, not flags; a ``bitfield_db_column`` class attribute is also
    honoured.
    """

    # Passed as BitField(default=...); None when not set on the subclass.
    bitfield_default: Optional[Any]
    # Passed as BitField(null=...); a falsy or missing value means False.
    bitfield_null: bool
    # Underscore-prefixed, so the metaclass never treats it as a flag.
    # NOTE(review): presumably declared only for type checkers — confirm.
    _value: int
  153. T = TypeVar("T")
  154. def typed_dict_bitfield(definition: Type[T], default=None, null=False) -> T:
  155. """
  156. A wrapper around BitField that allows you to access its fields as
  157. dictionary keys attributes in a type-safe way.
  158. Prefer `TypedClassBitField` over this if you can help it. This function
  159. only exists to make it simpler to type bitfields with fields that are not
  160. valid Python identifiers, but has limitations for how far it can provide
  161. type safety.
  162. """
  163. assert issubclass(definition, dict)
  164. return cast(
  165. T,
  166. BitField(
  167. flags=flags_from_annotations(definition.__annotations__), default=default, null=null
  168. ),
  169. )
# Module-import side effect: register the custom lookup class from
# bitfield.query on BitField via Django's register_lookup.
# NOTE(review): judging by its name it stands in for the "exact" lookup on
# BitField queries — confirm against bitfield.query.
BitField.register_lookup(BitQueryExactLookupStub)