models.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. from __future__ import absolute_import
  2. import six
  3. from django.db.models import signals
  4. from django.db.models.fields import BigIntegerField, Field
  5. from bitfield.forms import BitFormField
  6. from bitfield.query import BitQueryLookupWrapper
  7. from bitfield.types import Bit, BitHandler
  8. # Count binary capacity. Truncate "0b" prefix from binary form.
  9. # Twice faster than bin(i)[2:] or math.floor(math.log(i))
  10. MAX_FLAG_COUNT = int(len(bin(BigIntegerField.MAX_BIGINT)) - 2)
  11. class BitFieldFlags(object):
  12. def __init__(self, flags):
  13. if len(flags) > MAX_FLAG_COUNT:
  14. raise ValueError("Too many flags")
  15. self._flags = flags
  16. def __repr__(self):
  17. return repr(self._flags)
  18. def __iter__(self):
  19. for flag in self._flags:
  20. yield flag
  21. def __getattr__(self, key):
  22. if key == "_flags":
  23. # Since __getattr__ is for fallback, reaching here from Python
  24. # means that there's no '_flags' attribute in this object,
  25. # which may be caused by intermediate state while copying etc.
  26. raise AttributeError(
  27. "'%s' object has no attribute '%s'" % (self.__class__.__name__, key)
  28. )
  29. try:
  30. flags = self._flags
  31. except AttributeError:
  32. raise AttributeError(
  33. "'%s' object has no attribute '%s'" % (self.__class__.__name__, key)
  34. )
  35. try:
  36. flag = flags.index(key)
  37. except ValueError:
  38. raise AttributeError("flag {} is not registered".format(key))
  39. return Bit(flag)
  40. def iteritems(self):
  41. for flag in self._flags:
  42. yield flag, Bit(self._flags.index(flag))
  43. def iterkeys(self):
  44. for flag in self._flags:
  45. yield flag
  46. def itervalues(self):
  47. for flag in self._flags:
  48. yield Bit(self._flags.index(flag))
  49. def items(self):
  50. return list(self.iteritems())
  51. def keys(self):
  52. return list(self.iterkeys())
  53. def values(self):
  54. return list(self.itervalues())
  55. class BitFieldCreator(object):
  56. """
  57. A placeholder class that provides a way to set the attribute on the model.
  58. Descriptor for BitFields. Checks to make sure that all flags of the
  59. instance match the class. This is to handle the case when caching
  60. an older version of the instance and a newer version of the class is
  61. available (usually during deploys).
  62. """
  63. def __init__(self, field):
  64. self.field = field
  65. def __set__(self, obj, value):
  66. obj.__dict__[self.field.name] = self.field.to_python(value)
  67. def __get__(self, obj, type=None):
  68. if obj is None:
  69. return BitFieldFlags(self.field.flags)
  70. retval = obj.__dict__[self.field.name]
  71. if self.field.__class__ is BitField:
  72. # Update flags from class in case they've changed.
  73. retval._keys = self.field.flags
  74. return retval
  75. class BitField(BigIntegerField):
  76. def contribute_to_class(self, cls, name, **kwargs):
  77. super(BitField, self).contribute_to_class(cls, name, **kwargs)
  78. setattr(cls, self.name, BitFieldCreator(self))
  79. def __init__(self, flags, default=None, *args, **kwargs):
  80. if isinstance(flags, dict):
  81. # Get only integer keys in correct range
  82. valid_keys = (
  83. k
  84. for k in flags.keys()
  85. if isinstance(k, int) and (0 <= k < MAX_FLAG_COUNT)
  86. )
  87. if not valid_keys:
  88. raise ValueError("Wrong keys or empty dictionary")
  89. # Fill list with values from dict or with empty values
  90. flags = [flags.get(i, "") for i in range(max(valid_keys) + 1)]
  91. if len(flags) > MAX_FLAG_COUNT:
  92. raise ValueError("Too many flags")
  93. self._arg_flags = flags
  94. flags = list(flags)
  95. labels = []
  96. for num, flag in enumerate(flags):
  97. if isinstance(flag, (tuple, list)):
  98. flags[num] = flag[0]
  99. labels.append(flag[1])
  100. else:
  101. labels.append(flag)
  102. if isinstance(default, (list, tuple, set, frozenset)):
  103. new_value = 0
  104. for flag in default:
  105. new_value |= Bit(flags.index(flag))
  106. default = new_value
  107. BigIntegerField.__init__(self, default=default, *args, **kwargs)
  108. self.flags = flags
  109. self.labels = labels
  110. def formfield(self, form_class=BitFormField, **kwargs):
  111. choices = [(k, self.labels[self.flags.index(k)]) for k in self.flags]
  112. return Field.formfield(self, form_class, choices=choices, **kwargs)
  113. def get_prep_value(self, value):
  114. if value is None:
  115. return None
  116. if isinstance(value, (BitHandler, Bit)):
  117. value = value.mask
  118. return int(value)
  119. # def get_db_prep_save(self, value, connection):
  120. # if isinstance(value, Bit):
  121. # return BitQuerySaveWrapper(self.model._meta.db_table, self.name, value)
  122. # return super(BitField, self).get_db_prep_save(value, connection=connection)
  123. def to_python(self, value):
  124. if isinstance(value, Bit):
  125. value = value.mask
  126. if not isinstance(value, BitHandler):
  127. # Regression for #1425: fix bad data that was created resulting
  128. # in negative values for flags. Compute the value that would
  129. # have been visible ot the application to preserve compatibility.
  130. if isinstance(value, six.integer_types) and value < 0:
  131. new_value = 0
  132. for bit_number, _ in enumerate(self.flags):
  133. new_value |= value & (2**bit_number)
  134. value = new_value
  135. value = BitHandler(value, self.flags, self.labels)
  136. else:
  137. # Ensure flags are consistent for unpickling
  138. value._keys = self.flags
  139. return value
  140. def deconstruct(self):
  141. name, path, args, kwargs = super(BitField, self).deconstruct()
  142. args.insert(0, self._arg_flags)
  143. return name, path, args, kwargs
  144. BitField.register_lookup(BitQueryLookupWrapper)
  145. class CompositeBitFieldWrapper(object):
  146. def __init__(self, fields):
  147. self.fields = fields
  148. def __getattr__(self, attr):
  149. if attr == "fields":
  150. return super(CompositeBitFieldWrapper, self).__getattr__(attr)
  151. for field in self.fields:
  152. if hasattr(field, attr):
  153. return getattr(field, attr)
  154. raise AttributeError("%s is not a valid flag" % attr)
  155. def __hasattr__(self, attr):
  156. if attr == "fields":
  157. return super(CompositeBitFieldWrapper, self).__hasattr__(attr)
  158. for field in self.fields:
  159. if hasattr(field, attr):
  160. return True
  161. return False
  162. def __setattr__(self, attr, value):
  163. if attr == "fields":
  164. super(CompositeBitFieldWrapper, self).__setattr__(attr, value)
  165. return
  166. for field in self.fields:
  167. if hasattr(field, attr):
  168. setattr(field, attr, value)
  169. return
  170. raise AttributeError("%s is not a valid flag" % attr)
  171. class CompositeBitField(object):
  172. is_relation = False
  173. many_to_many = False
  174. concrete = False
  175. def __init__(self, fields):
  176. self.fields = fields
  177. def contribute_to_class(self, cls, name):
  178. self.name = name
  179. self.model = cls
  180. cls._meta.private_fields.append(self)
  181. signals.class_prepared.connect(self.validate_fields, sender=cls)
  182. setattr(cls, name, self)
  183. def validate_fields(self, sender, **kwargs):
  184. cls = sender
  185. model_fields = dict(
  186. [(f.name, f) for f in cls._meta.fields if f.name in self.fields]
  187. )
  188. all_flags = sum([model_fields[f].flags for f in self.fields], [])
  189. if len(all_flags) != len(set(all_flags)):
  190. raise ValueError("BitField flags must be unique.")
  191. def __get__(self, instance, instance_type=None):
  192. fields = [getattr(instance, f) for f in self.fields]
  193. return CompositeBitFieldWrapper(fields)
  194. def __set__(self, *args, **kwargs):
  195. raise NotImplementedError("CompositeBitField cannot be set.")