123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235 |
- from typing import Union
- from django.db.models import CharField, Field, expressions
class HStoreValue(expressions.Expression):
    """Represents a HStore value.

    The base PostgreSQL implementation Django provides always represents
    HStore values as dictionaries, but that doesn't work if you want to
    use expressions inside hstore values.
    """

    def __init__(self, value):
        """Initializes a new instance.

        Arguments:
            value:
                Dictionary mapping hstore keys to values. A value may be
                an expression, None (stored as SQL NULL) or any other
                value (stored as its ``str()`` representation).
        """
        super().__init__()
        self.value = value

    def resolve_expression(self, *args, **kwargs):
        """Resolves expressions inside the dictionary.

        Returns:
            A new :see:HStoreValue in which every value that has a
            ``resolve_expression`` method has been resolved; all other
            values are copied unchanged.
        """
        result = {
            key: (
                value.resolve_expression(*args, **kwargs)
                if hasattr(value, "resolve_expression")
                else value
            )
            for key, value in self.value.items()
        }
        return HStoreValue(result)

    def as_sql(self, compiler, connection):
        """Compiles the HStore value into SQL.

        Compiles expressions contained in the values of HStore entries
        as well. Given a dictionary like:

            dict(key1='val1', key2='val2')

        The resulting SQL will be:

            hstore(%s, %s) || hstore(%s, %s)

        with params ['key1', 'val1', 'key2', 'val2']. Expression values
        are compiled inline, None becomes SQL NULL and an empty
        dictionary compiles to an empty hstore.
        """
        if not self.value:
            # Joining zero fragments would produce an empty string, which
            # is not valid SQL where a value is expected.
            return "''::hstore", []

        sql = []
        params = []
        for key, value in self.value.items():
            if hasattr(value, "as_sql"):
                # compiler.compile() honours vendor-specific as_<vendor>()
                # overrides, unlike calling as_sql() directly.
                inner_sql, inner_params = compiler.compile(value)
                sql.append(f"hstore(%s, {inner_sql})")
                params.append(key)
                params.extend(inner_params)
            elif value is not None:
                sql.append("hstore(%s, %s)")
                params.append(key)
                params.append(str(value))
            else:
                sql.append("hstore(%s, NULL)")
                params.append(key)

        return " || ".join(sql), params
class HStoreColumn(expressions.Col):
    """HStoreColumn expression.

    Generates expressions like:

        ([db table].[column]->%s)

    where the hstore key is passed as a query parameter.
    """

    contains_column_references = True

    def __init__(self, alias, target, hstore_key):
        """Initializes a new instance of :see:HStoreColumn.

        Arguments:
            alias:
                The table name.

            target:
                The field instance.

            hstore_key:
                The name of the hstore key to include
                in the expression.
        """
        # Col.__init__ stores ``alias`` and ``target`` on the instance.
        super().__init__(alias, target, output_field=target)
        self.hstore_key = hstore_key

    def __repr__(self):
        """Gets a textual representation of this expression."""
        return "{}({}, {}->'{}')".format(
            self.__class__.__name__, self.alias, self.target, self.hstore_key
        )

    def as_sql(self, compiler, connection):
        """Compiles this expression into SQL.

        The hstore key is sent as a query parameter rather than being
        interpolated into the SQL text. This keeps keys containing quotes
        or ``%`` from breaking (or injecting into) the query. The result
        is parenthesized so that surrounding operators cannot
        re-associate with ``->``.
        """
        qn = compiler.quote_name_unless_alias
        column = qn(self.target.column)
        if self.alias:
            column = f"{qn(self.alias)}.{column}"
        return f"({column}->%s)", [self.hstore_key]

    def relabeled_clone(self, relabels):
        """Gets a re-labeled clone of this expression.

        Only the table alias is relabeled; target and hstore key are
        carried over. ``__init__`` takes exactly these three arguments
        (the output field is derived from ``target``).
        """
        return self.__class__(
            relabels.get(self.alias, self.alias),
            self.target,
            self.hstore_key,
        )
class HStoreRef(expressions.F):
    """Inline reference to a HStore key.

    Allows selecting individual keys in annotations.
    """

    def __init__(self, name: str, key: str):
        """Initializes a new instance of :see:HStoreRef.

        Arguments:
            name:
                The name of the column/field to resolve.

            key:
                The name of the HStore key to select.
        """
        super().__init__(name)
        self.key = key

    def __repr__(self):
        """Gets a textual representation including the selected key."""
        return "{}({}, {})".format(self.__class__.__name__, self.name, self.key)

    def __eq__(self, other):
        """Compares class, field name and hstore key.

        The inherited F.__eq__ only compares ``name``, which would make
        references to different keys of the same column equal.
        """
        return (
            self.__class__ == other.__class__
            and self.name == other.name
            and self.key == other.key
        )

    def __hash__(self):
        """Hashes consistently with :meth:`__eq__`."""
        return hash((self.name, self.key))

    def resolve_expression(self, *args, **kwargs):
        """Resolves the expression into a :see:HStoreColumn expression.

        The field reference is resolved by F first; its alias and target
        are then wrapped in an :see:HStoreColumn selecting ``self.key``.
        """
        original_expression: expressions.Col = super().resolve_expression(  # type: ignore[annotation-unchecked]
            *args, **kwargs
        )
        return HStoreColumn(
            original_expression.alias, original_expression.target, self.key
        )
class DateTimeEpochColumn(expressions.Col):
    """Gets the date/time column as a UNIX epoch timestamp.

    Compiles to ``EXTRACT(epoch FROM <column>)``.
    """

    contains_column_references = True

    def as_sql(self, compiler, connection):
        """Compiles this expression into SQL by wrapping the plain column
        SQL produced by Col in ``EXTRACT(epoch FROM ...)``."""
        sql, params = super().as_sql(compiler, connection)
        return "EXTRACT(epoch FROM {})".format(sql), params

    def get_group_by_cols(self, alias=None):
        """Contributes no columns to GROUP BY.

        ``alias`` is accepted and ignored so that Django versions which
        call this method with an ``alias`` keyword argument keep working.
        NOTE(review): confirm against the supported Django range.
        """
        return []
class DateTimeEpoch(expressions.F):
    """Reference to a date/time field, selected as a UNIX epoch timestamp."""

    contains_aggregate = False

    def resolve_expression(self, *args, **kwargs):
        """Resolves the field reference, then re-wraps the resulting
        column's alias and target in a :see:DateTimeEpochColumn."""
        resolved = super().resolve_expression(*args, **kwargs)
        alias, target = resolved.alias, resolved.target
        return DateTimeEpochColumn(alias, target)
def IsNotNone(*fields, default=None, output_field=None):
    """Builds a CASE expression that selects a field which is not None.

    Arguments:
        fields:
            The names of the fields to attempt to get a value from.

        default:
            The value to return in case all values are None.

        output_field:
            The field instance describing the type of the resulting
            expression. Defaults to ``CharField()``, which is what this
            function always used before the argument existed.

    Returns:
        A Case-When expression with one WHEN clause per field (each
        matching when that field is not None) that falls back to
        ``default`` when all of them are None.
    """
    # NOTE(review): the WHEN clauses are generated from reversed(fields).
    # In a CASE expression the first matching WHEN wins, so when several
    # fields are non-NULL the LAST one listed in ``fields`` is selected.
    # Kept as-is to avoid changing results for existing callers; confirm
    # whether the reversal is intentional.
    when_clauses = [
        expressions.When(
            ~expressions.Q(**{field: None}), then=expressions.F(field)
        )
        for field in reversed(fields)
    ]

    return expressions.Case(
        *when_clauses,
        default=expressions.Value(default),
        output_field=output_field if output_field is not None else CharField(),
    )
class ExcludedCol(expressions.Expression):
    """References a column of PostgreSQL's special EXCLUDED row.

    In an upsert, EXCLUDED holds the data that was about to be inserted
    and can be referred to from the conflict-handling clause.

    See: https://www.postgresql.org/docs/current/sql-insert.html#SQL-ON-CONFLICT
    """

    def __init__(self, field_or_name: Union[Field, str]):
        # Either a field instance or a plain column name is accepted. A
        # field is preferred because it is kept as the output field, which
        # compilation may need to pick the right placeholder (this
        # particular expression does not need it).
        if isinstance(field_or_name, Field):
            output_field, column_name = field_or_name, field_or_name.column
        else:
            output_field, column_name = None, field_or_name

        super().__init__(output_field)
        self.name = column_name

    def as_sql(self, compiler, connection):
        """Compiles to ``EXCLUDED.<quoted column name>`` with no params."""
        return "EXCLUDED.{}".format(connection.ops.quote_name(self.name)), ()
|