123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246 |
- from __future__ import absolute_import
- import six
- from django.db.models import signals
- from django.db.models.fields import BigIntegerField, Field
- from bitfield.forms import BitFormField
- from bitfield.query import BitQueryLookupWrapper
- from bitfield.types import Bit, BitHandler
- # Count binary capacity. Truncate "0b" prefix from binary form.
- # Twice faster than bin(i)[2:] or math.floor(math.log(i))
- MAX_FLAG_COUNT = int(len(bin(BigIntegerField.MAX_BIGINT)) - 2)
- class BitFieldFlags(object):
- def __init__(self, flags):
- if len(flags) > MAX_FLAG_COUNT:
- raise ValueError("Too many flags")
- self._flags = flags
- def __repr__(self):
- return repr(self._flags)
- def __iter__(self):
- for flag in self._flags:
- yield flag
- def __getattr__(self, key):
- if key == "_flags":
- # Since __getattr__ is for fallback, reaching here from Python
- # means that there's no '_flags' attribute in this object,
- # which may be caused by intermediate state while copying etc.
- raise AttributeError(
- "'%s' object has no attribute '%s'" % (self.__class__.__name__, key)
- )
- try:
- flags = self._flags
- except AttributeError:
- raise AttributeError(
- "'%s' object has no attribute '%s'" % (self.__class__.__name__, key)
- )
- try:
- flag = flags.index(key)
- except ValueError:
- raise AttributeError("flag {} is not registered".format(key))
- return Bit(flag)
- def iteritems(self):
- for flag in self._flags:
- yield flag, Bit(self._flags.index(flag))
- def iterkeys(self):
- for flag in self._flags:
- yield flag
- def itervalues(self):
- for flag in self._flags:
- yield Bit(self._flags.index(flag))
- def items(self):
- return list(self.iteritems())
- def keys(self):
- return list(self.iterkeys())
- def values(self):
- return list(self.itervalues())
- class BitFieldCreator(object):
- """
- A placeholder class that provides a way to set the attribute on the model.
- Descriptor for BitFields. Checks to make sure that all flags of the
- instance match the class. This is to handle the case when caching
- an older version of the instance and a newer version of the class is
- available (usually during deploys).
- """
- def __init__(self, field):
- self.field = field
- def __set__(self, obj, value):
- obj.__dict__[self.field.name] = self.field.to_python(value)
- def __get__(self, obj, type=None):
- if obj is None:
- return BitFieldFlags(self.field.flags)
- retval = obj.__dict__[self.field.name]
- if self.field.__class__ is BitField:
- # Update flags from class in case they've changed.
- retval._keys = self.field.flags
- return retval
- class BitField(BigIntegerField):
- def contribute_to_class(self, cls, name, **kwargs):
- super(BitField, self).contribute_to_class(cls, name, **kwargs)
- setattr(cls, self.name, BitFieldCreator(self))
- def __init__(self, flags, default=None, *args, **kwargs):
- if isinstance(flags, dict):
- # Get only integer keys in correct range
- valid_keys = (
- k
- for k in flags.keys()
- if isinstance(k, int) and (0 <= k < MAX_FLAG_COUNT)
- )
- if not valid_keys:
- raise ValueError("Wrong keys or empty dictionary")
- # Fill list with values from dict or with empty values
- flags = [flags.get(i, "") for i in range(max(valid_keys) + 1)]
- if len(flags) > MAX_FLAG_COUNT:
- raise ValueError("Too many flags")
- self._arg_flags = flags
- flags = list(flags)
- labels = []
- for num, flag in enumerate(flags):
- if isinstance(flag, (tuple, list)):
- flags[num] = flag[0]
- labels.append(flag[1])
- else:
- labels.append(flag)
- if isinstance(default, (list, tuple, set, frozenset)):
- new_value = 0
- for flag in default:
- new_value |= Bit(flags.index(flag))
- default = new_value
- BigIntegerField.__init__(self, default=default, *args, **kwargs)
- self.flags = flags
- self.labels = labels
- def formfield(self, form_class=BitFormField, **kwargs):
- choices = [(k, self.labels[self.flags.index(k)]) for k in self.flags]
- return Field.formfield(self, form_class, choices=choices, **kwargs)
- def get_prep_value(self, value):
- if value is None:
- return None
- if isinstance(value, (BitHandler, Bit)):
- value = value.mask
- return int(value)
- # def get_db_prep_save(self, value, connection):
- # if isinstance(value, Bit):
- # return BitQuerySaveWrapper(self.model._meta.db_table, self.name, value)
- # return super(BitField, self).get_db_prep_save(value, connection=connection)
- def to_python(self, value):
- if isinstance(value, Bit):
- value = value.mask
- if not isinstance(value, BitHandler):
- # Regression for #1425: fix bad data that was created resulting
- # in negative values for flags. Compute the value that would
- # have been visible ot the application to preserve compatibility.
- if isinstance(value, six.integer_types) and value < 0:
- new_value = 0
- for bit_number, _ in enumerate(self.flags):
- new_value |= value & (2**bit_number)
- value = new_value
- value = BitHandler(value, self.flags, self.labels)
- else:
- # Ensure flags are consistent for unpickling
- value._keys = self.flags
- return value
- def deconstruct(self):
- name, path, args, kwargs = super(BitField, self).deconstruct()
- args.insert(0, self._arg_flags)
- return name, path, args, kwargs
- BitField.register_lookup(BitQueryLookupWrapper)
- class CompositeBitFieldWrapper(object):
- def __init__(self, fields):
- self.fields = fields
- def __getattr__(self, attr):
- if attr == "fields":
- return super(CompositeBitFieldWrapper, self).__getattr__(attr)
- for field in self.fields:
- if hasattr(field, attr):
- return getattr(field, attr)
- raise AttributeError("%s is not a valid flag" % attr)
- def __hasattr__(self, attr):
- if attr == "fields":
- return super(CompositeBitFieldWrapper, self).__hasattr__(attr)
- for field in self.fields:
- if hasattr(field, attr):
- return True
- return False
- def __setattr__(self, attr, value):
- if attr == "fields":
- super(CompositeBitFieldWrapper, self).__setattr__(attr, value)
- return
- for field in self.fields:
- if hasattr(field, attr):
- setattr(field, attr, value)
- return
- raise AttributeError("%s is not a valid flag" % attr)
- class CompositeBitField(object):
- is_relation = False
- many_to_many = False
- concrete = False
- def __init__(self, fields):
- self.fields = fields
- def contribute_to_class(self, cls, name):
- self.name = name
- self.model = cls
- cls._meta.private_fields.append(self)
- signals.class_prepared.connect(self.validate_fields, sender=cls)
- setattr(cls, name, self)
- def validate_fields(self, sender, **kwargs):
- cls = sender
- model_fields = dict(
- [(f.name, f) for f in cls._meta.fields if f.name in self.fields]
- )
- all_flags = sum([model_fields[f].flags for f in self.fields], [])
- if len(all_flags) != len(set(all_flags)):
- raise ValueError("BitField flags must be unique.")
- def __get__(self, instance, instance_type=None):
- fields = [getattr(instance, f) for f in self.fields]
- return CompositeBitFieldWrapper(fields)
- def __set__(self, *args, **kwargs):
- raise NotImplementedError("CompositeBitField cannot be set.")
|