David Burke 1 year ago
parent
commit
c115d82480
7 changed files with 87 additions and 63 deletions
  1. 2 2
      bitfield/__init__.py
  2. 11 10
      bitfield/admin.py
  3. 1 1
      bitfield/apps.py
  4. 11 10
      bitfield/forms.py
  5. 24 20
      bitfield/models.py
  6. 19 12
      bitfield/query.py
  7. 19 8
      bitfield/types.py

+ 2 - 2
bitfield/__init__.py

@@ -4,8 +4,8 @@ django-bitfield
 """
 from __future__ import absolute_import
 
-from bitfield.models import Bit, BitHandler, CompositeBitField, BitField  # NOQA
+from bitfield.models import Bit, BitField, BitHandler, CompositeBitField  # NOQA
 
-default_app_config = 'bitfield.apps.BitFieldAppConfig'
+default_app_config = "bitfield.apps.BitFieldAppConfig"
 
 VERSION = "2.2.0"

+ 11 - 10
bitfield/admin.py

@@ -1,12 +1,13 @@
 import django
 import six
-
 from django.core.exceptions import ValidationError
+
 if django.VERSION < (2, 0):
     from django.utils.translation import ugettext_lazy as _
 else:
     # Aliased since Django 2.0 https://github.com/django/django/blob/2.0/django/utils/translation/__init__.py#L80-L81
     from django.utils.translation import gettext_lazy as _
+
 from django.contrib.admin import FieldListFilter
 from django.contrib.admin.options import IncorrectLookupParameters
 
@@ -24,12 +25,12 @@ class BitFieldListFilter(FieldListFilter):
         self.flags = field.flags
         self.labels = field.labels
         super(BitFieldListFilter, self).__init__(
-            field, request, params, model, model_admin, field_path)
+            field, request, params, model, model_admin, field_path
+        )
 
     def queryset(self, request, queryset):
         filter_kwargs = dict(
-            (p, BitHandler(v, ()))
-            for p, v in six.iteritems(self.used_parameters)
+            (p, BitHandler(v, ())) for p, v in six.iteritems(self.used_parameters)
         )
         if not filter_kwargs:
             return queryset
@@ -43,14 +44,14 @@ class BitFieldListFilter(FieldListFilter):
 
     def choices(self, cl):
         yield {
-            'selected': self.lookup_val == 0,
-            'query_string': cl.get_query_string({}, [self.lookup_kwarg]),
-            'display': _('All'),
+            "selected": self.lookup_val == 0,
+            "query_string": cl.get_query_string({}, [self.lookup_kwarg]),
+            "display": _("All"),
         }
         for number, flag in enumerate(self.flags):
             bit_mask = Bit(number).mask
             yield {
-                'selected': self.lookup_val == bit_mask,
-                'query_string': cl.get_query_string({self.lookup_kwarg: bit_mask}),
-                'display': self.labels[number],
+                "selected": self.lookup_val == bit_mask,
+                "query_string": cl.get_query_string({self.lookup_kwarg: bit_mask}),
+                "display": self.labels[number],
             }

+ 1 - 1
bitfield/apps.py

@@ -2,5 +2,5 @@ from django.apps import AppConfig
 
 
 class BitFieldAppConfig(AppConfig):
-    name = 'bitfield'
+    name = "bitfield"
     verbose_name = "Bit Field"

+ 11 - 10
bitfield/forms.py

@@ -1,7 +1,6 @@
 from __future__ import absolute_import
 
 from django.forms import CheckboxSelectMultiple, IntegerField, ValidationError
-
 from django.utils.encoding import force_str
 
 from bitfield.types import BitHandler
@@ -14,14 +13,15 @@ class BitFieldCheckboxSelectMultiple(CheckboxSelectMultiple):
         elif isinstance(value, int):
             real_value = []
             div = 2
-            for (k, v) in self.choices:
+            for k, v in self.choices:
                 if value % div != 0:
                     real_value.append(k)
-                    value -= (value % div)
+                    value -= value % div
                 div *= 2
             value = real_value
         return super(BitFieldCheckboxSelectMultiple, self).render(
-            name, value, attrs=attrs)
+            name, value, attrs=attrs
+        )
 
     def has_changed(self, initial, data):
         if initial is None:
@@ -36,15 +36,16 @@ class BitFieldCheckboxSelectMultiple(CheckboxSelectMultiple):
 
 
 class BitFormField(IntegerField):
-    def __init__(self, choices=(), widget=BitFieldCheckboxSelectMultiple, *args, **kwargs):
-
-        if isinstance(kwargs['initial'], int):
-            iv = kwargs['initial']
+    def __init__(
+        self, choices=(), widget=BitFieldCheckboxSelectMultiple, *args, **kwargs
+    ):
+        if isinstance(kwargs["initial"], int):
+            iv = kwargs["initial"]
             iv_list = []
             for i in range(0, min(len(choices), 63)):
                 if (1 << i) & iv > 0:
                     iv_list += [choices[i][0]]
-            kwargs['initial'] = iv_list
+            kwargs["initial"] = iv_list
         self.widget = widget
         super(BitFormField, self).__init__(widget=widget, *args, **kwargs)
         self.choices = self.widget.choices = choices
@@ -59,5 +60,5 @@ class BitFormField(IntegerField):
             try:
                 setattr(result, str(k), True)
             except AttributeError:
-                raise ValidationError('Unknown choice: %r' % (k,))
+                raise ValidationError("Unknown choice: %r" % (k,))
         return int(result)

+ 24 - 20
bitfield/models.py

@@ -1,13 +1,12 @@
 from __future__ import absolute_import
 
 import six
-
 from django.db.models import signals
-from django.db.models.fields import Field, BigIntegerField
+from django.db.models.fields import BigIntegerField, Field
 
 from bitfield.forms import BitFormField
 from bitfield.query import BitQueryLookupWrapper
-from bitfield.types import BitHandler, Bit
+from bitfield.types import Bit, BitHandler
 
 # Count binary capacity. Truncate "0b" prefix from binary form.
 # Twice faster than bin(i)[2:] or math.floor(math.log(i))
@@ -17,7 +16,7 @@ MAX_FLAG_COUNT = int(len(bin(BigIntegerField.MAX_BIGINT)) - 2)
 class BitFieldFlags(object):
     def __init__(self, flags):
         if len(flags) > MAX_FLAG_COUNT:
-            raise ValueError('Too many flags')
+            raise ValueError("Too many flags")
         self._flags = flags
 
     def __repr__(self):
@@ -28,7 +27,7 @@ class BitFieldFlags(object):
             yield flag
 
     def __getattr__(self, key):
-        if key == '_flags':
+        if key == "_flags":
             # Since __getattr__ is for fallback, reaching here from Python
             # means that there's no '_flags' attribute in this object,
             # which may be caused by intermediate state while copying etc.
@@ -77,6 +76,7 @@ class BitFieldCreator(object):
     an older version of the instance and a newer version of the class is
     available (usually during deploys).
     """
+
     def __init__(self, field):
         self.field = field
 
@@ -94,7 +94,6 @@ class BitFieldCreator(object):
 
 
 class BitField(BigIntegerField):
-
     def contribute_to_class(self, cls, name, **kwargs):
         super(BitField, self).contribute_to_class(cls, name, **kwargs)
         setattr(cls, self.name, BitFieldCreator(self))
@@ -102,14 +101,18 @@ class BitField(BigIntegerField):
     def __init__(self, flags, default=None, *args, **kwargs):
         if isinstance(flags, dict):
             # Get only integer keys in correct range
-            valid_keys = (k for k in flags.keys() if isinstance(k, int) and (0 <= k < MAX_FLAG_COUNT))
+            valid_keys = (
+                k
+                for k in flags.keys()
+                if isinstance(k, int) and (0 <= k < MAX_FLAG_COUNT)
+            )
             if not valid_keys:
-                raise ValueError('Wrong keys or empty dictionary')
+                raise ValueError("Wrong keys or empty dictionary")
             # Fill list with values from dict or with empty values
-            flags = [flags.get(i, '') for i in range(max(valid_keys) + 1)]
+            flags = [flags.get(i, "") for i in range(max(valid_keys) + 1)]
 
         if len(flags) > MAX_FLAG_COUNT:
-            raise ValueError('Too many flags')
+            raise ValueError("Too many flags")
 
         self._arg_flags = flags
         flags = list(flags)
@@ -157,7 +160,7 @@ class BitField(BigIntegerField):
             if isinstance(value, six.integer_types) and value < 0:
                 new_value = 0
                 for bit_number, _ in enumerate(self.flags):
-                    new_value |= (value & (2 ** bit_number))
+                    new_value |= value & (2**bit_number)
                 value = new_value
 
             value = BitHandler(value, self.flags, self.labels)
@@ -180,16 +183,16 @@ class CompositeBitFieldWrapper(object):
         self.fields = fields
 
     def __getattr__(self, attr):
-        if attr == 'fields':
+        if attr == "fields":
             return super(CompositeBitFieldWrapper, self).__getattr__(attr)
 
         for field in self.fields:
             if hasattr(field, attr):
                 return getattr(field, attr)
-        raise AttributeError('%s is not a valid flag' % attr)
+        raise AttributeError("%s is not a valid flag" % attr)
 
     def __hasattr__(self, attr):
-        if attr == 'fields':
+        if attr == "fields":
             return super(CompositeBitFieldWrapper, self).__hasattr__(attr)
 
         for field in self.fields:
@@ -198,7 +201,7 @@ class CompositeBitFieldWrapper(object):
         return False
 
     def __setattr__(self, attr, value):
-        if attr == 'fields':
+        if attr == "fields":
             super(CompositeBitFieldWrapper, self).__setattr__(attr, value)
             return
 
@@ -206,7 +209,7 @@ class CompositeBitFieldWrapper(object):
             if hasattr(field, attr):
                 setattr(field, attr, value)
                 return
-        raise AttributeError('%s is not a valid flag' % attr)
+        raise AttributeError("%s is not a valid flag" % attr)
 
 
 class CompositeBitField(object):
@@ -228,15 +231,16 @@ class CompositeBitField(object):
 
     def validate_fields(self, sender, **kwargs):
         cls = sender
-        model_fields = dict([
-            (f.name, f) for f in cls._meta.fields if f.name in self.fields])
+        model_fields = dict(
+            [(f.name, f) for f in cls._meta.fields if f.name in self.fields]
+        )
         all_flags = sum([model_fields[f].flags for f in self.fields], [])
         if len(all_flags) != len(set(all_flags)):
-            raise ValueError('BitField flags must be unique.')
+            raise ValueError("BitField flags must be unique.")
 
     def __get__(self, instance, instance_type=None):
         fields = [getattr(instance, f) for f in self.fields]
         return CompositeBitFieldWrapper(fields)
 
     def __set__(self, *args, **kwargs):
-        raise NotImplementedError('CompositeBitField cannot be set.')
+        raise NotImplementedError("CompositeBitField cannot be set.")

+ 19 - 12
bitfield/query.py

@@ -1,18 +1,20 @@
 from __future__ import absolute_import
 
-from bitfield.types import Bit, BitHandler
 from django.db.models.lookups import Exact
 
+from bitfield.types import Bit, BitHandler
+
 
 class BitQueryLookupWrapper(Exact):  # NOQA
     def process_lhs(self, compiler, connection, lhs=None):
         lhs_sql, lhs_params = super(BitQueryLookupWrapper, self).process_lhs(
-            compiler, connection, lhs)
+            compiler, connection, lhs
+        )
 
         if not isinstance(self.rhs, (BitHandler, Bit)):
             return lhs_sql, lhs_params
 
-        op = ' & ' if self.rhs else ' | '
+        op = " & " if self.rhs else " | "
         rhs_sql, rhs_params = self.process_rhs(compiler, connection)
         params = list(lhs_params)
         params.extend(rhs_params)
@@ -37,16 +39,21 @@ class BitQuerySaveWrapper(BitQueryLookupWrapper):
 
         This will be called by Where.as_sql()
         """
-        engine = connection.settings_dict['ENGINE'].rsplit('.', -1)[-1]
-        if engine.startswith('postgres'):
-            XOR_OPERATOR = '#'
-        elif engine.startswith('sqlite'):
+        engine = connection.settings_dict["ENGINE"].rsplit(".", -1)[-1]
+        if engine.startswith("postgres"):
+            XOR_OPERATOR = "#"
+        elif engine.startswith("sqlite"):
             raise NotImplementedError
         else:
-            XOR_OPERATOR = '^'
+            XOR_OPERATOR = "^"
 
         if self.bit:
-            return ("%s.%s | %d" % (qn(self.table_alias), qn(self.column), self.bit.mask),
-                    [])
-        return ("%s.%s %s %d" % (qn(self.table_alias), qn(self.column), XOR_OPERATOR, self.bit.mask),
-                [])
+            return (
+                "%s.%s | %d" % (qn(self.table_alias), qn(self.column), self.bit.mask),
+                [],
+            )
+        return (
+            "%s.%s %s %d"
+            % (qn(self.table_alias), qn(self.column), XOR_OPERATOR, self.bit.mask),
+            [],
+        )

+ 19 - 8
bitfield/types.py

@@ -11,6 +11,7 @@ class Bit(object):
     """
     Represents a single Bit.
     """
+
     def __init__(self, number, is_set=True):
         self.number = number
         self.is_set = bool(is_set)
@@ -20,7 +21,11 @@ class Bit(object):
             self.mask = ~self.mask
 
     def __repr__(self):
-        return '<%s: number=%d, is_set=%s>' % (self.__class__.__name__, self.number, self.is_set)
+        return "<%s: number=%d, is_set=%s>" % (
+            self.__class__.__name__,
+            self.number,
+            self.is_set,
+        )
 
     # def __str__(self):
     #     if self.is_set:
@@ -117,6 +122,7 @@ class BitHandler(object):
     """
     Represents an array of bits, each as a ``Bit`` object.
     """
+
     def __init__(self, value, keys, labels=None):
         # TODO: change to bitarray?
         if value:
@@ -147,7 +153,12 @@ class BitHandler(object):
         return cmp(self._value, other)
 
     def __repr__(self):
-        return '<%s: %s>' % (self.__class__.__name__, ', '.join('%s=%s' % (k, self.get_bit(n).is_set) for n, k in enumerate(self._keys)),)
+        return "<%s: %s>" % (
+            self.__class__.__name__,
+            ", ".join(
+                "%s=%s" % (k, self.get_bit(n).is_set) for n, k in enumerate(self._keys)
+            ),
+        )
 
     def __str__(self):
         return str(self._value)
@@ -186,17 +197,17 @@ class BitHandler(object):
         return bool(self.get_bit(bit_number))
 
     def __getattr__(self, key):
-        if key.startswith('_'):
+        if key.startswith("_"):
             return object.__getattribute__(self, key)
         if key not in self._keys:
-            raise AttributeError('%s is not a valid flag' % key)
+            raise AttributeError("%s is not a valid flag" % key)
         return self.get_bit(self._keys.index(key))
 
     def __setattr__(self, key, value):
-        if key.startswith('_'):
+        if key.startswith("_"):
             return object.__setattr__(self, key, value)
         if key not in self._keys:
-            raise AttributeError('%s is not a valid flag' % key)
+            raise AttributeError("%s is not a valid flag" % key)
         self.set_bit(self._keys.index(key), value)
 
     def __iter__(self):
@@ -207,6 +218,7 @@ class BitHandler(object):
 
     def _get_mask(self):
         return self._value
+
     mask = property(_get_mask)
 
     def evaluate(self, evaluator, qn, connection):
@@ -221,7 +233,7 @@ class BitHandler(object):
         if true_or_false:
             self._value |= mask
         else:
-            self._value &= (~mask)
+            self._value &= ~mask
         return Bit(bit_number, self._value & mask != 0)
 
     def keys(self):
@@ -243,4 +255,3 @@ class BitHandler(object):
         if isinstance(flag, Bit):
             flag = flag.number
         return self._labels[flag]
-