expressions.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. from typing import Union
  2. from django.db.models import CharField, Field, expressions
  3. class HStoreValue(expressions.Expression):
  4. """Represents a HStore value.
  5. The base PostgreSQL implementation Django provides, always
  6. represents HStore values as dictionaries, but this doesn't work if
  7. you want to use expressions inside hstore values.
  8. """
  9. def __init__(self, value):
  10. """Initializes a new instance."""
  11. self.value = value
  12. def resolve_expression(self, *args, **kwargs):
  13. """Resolves expressions inside the dictionary."""
  14. result = dict()
  15. for key, value in self.value.items():
  16. if hasattr(value, "resolve_expression"):
  17. result[key] = value.resolve_expression(*args, **kwargs)
  18. else:
  19. result[key] = value
  20. return HStoreValue(result)
  21. def as_sql(self, compiler, connection):
  22. """Compiles the HStore value into SQL.
  23. Compiles expressions contained in the values
  24. of HStore entries as well.
  25. Given a dictionary like:
  26. dict(key1='val1', key2='val2')
  27. The resulting SQL will be:
  28. hstore(hstore('key1', 'val1'), hstore('key2', 'val2'))
  29. """
  30. sql = []
  31. params = []
  32. for key, value in self.value.items():
  33. if hasattr(value, "as_sql"):
  34. inner_sql, inner_params = value.as_sql(compiler, connection)
  35. sql.append(f"hstore(%s, {inner_sql})")
  36. params.append(key)
  37. params.extend(inner_params)
  38. elif value is not None:
  39. sql.append("hstore(%s, %s)")
  40. params.append(key)
  41. params.append(str(value))
  42. else:
  43. sql.append("hstore(%s, NULL)")
  44. params.append(key)
  45. return " || ".join(sql), params
  46. class HStoreColumn(expressions.Col):
  47. """HStoreColumn expression.
  48. Generates expressions like:
  49. [db table].[column]->'[hstore key]'
  50. """
  51. contains_column_references = True
  52. def __init__(self, alias, target, hstore_key):
  53. """Initializes a new instance of :see:HStoreColumn.
  54. Arguments:
  55. alias:
  56. The table name.
  57. target:
  58. The field instance.
  59. hstore_key
  60. The name of the hstore key to include
  61. in the epxression.
  62. """
  63. super().__init__(alias, target, output_field=target)
  64. self.alias, self.target, self.hstore_key = alias, target, hstore_key
  65. def __repr__(self):
  66. """Gets a textual representation of this expresion."""
  67. return "{}({}, {}->'{}')".format(
  68. self.__class__.__name__, self.alias, self.target, self.hstore_key
  69. )
  70. def as_sql(self, compiler, connection):
  71. """Compiles this expression into SQL."""
  72. qn = compiler.quote_name_unless_alias
  73. return (
  74. "%s.%s->'%s'"
  75. % (qn(self.alias), qn(self.target.column), self.hstore_key),
  76. [],
  77. )
  78. def relabeled_clone(self, relabels):
  79. """Gets a re-labeled clone of this expression."""
  80. return self.__class__(
  81. relabels.get(self.alias, self.alias),
  82. self.target,
  83. self.hstore_key,
  84. self.output_field,
  85. )
  86. class HStoreRef(expressions.F):
  87. """Inline reference to a HStore key.
  88. Allows selecting individual keys in annotations.
  89. """
  90. def __init__(self, name: str, key: str):
  91. """Initializes a new instance of :see:HStoreRef.
  92. Arguments:
  93. name:
  94. The name of the column/field to resolve.
  95. key:
  96. The name of the HStore key to select.
  97. """
  98. super().__init__(name)
  99. self.key = key
  100. def resolve_expression(self, *args, **kwargs):
  101. """Resolves the expression into a :see:HStoreColumn expression."""
  102. original_expression: expressions.Col = super().resolve_expression( # type: ignore[annotation-unchecked]
  103. *args, **kwargs
  104. )
  105. expression = HStoreColumn(
  106. original_expression.alias, original_expression.target, self.key
  107. )
  108. return expression
  109. class DateTimeEpochColumn(expressions.Col):
  110. """Gets the date/time column as a UNIX epoch timestamp."""
  111. contains_column_references = True
  112. def as_sql(self, compiler, connection):
  113. """Compiles this expression into SQL."""
  114. sql, params = super().as_sql(compiler, connection)
  115. return "EXTRACT(epoch FROM {})".format(sql), params
  116. def get_group_by_cols(self):
  117. return []
  118. class DateTimeEpoch(expressions.F):
  119. """Gets the date/time column as a UNIX epoch timestamp."""
  120. contains_aggregate = False
  121. def resolve_expression(self, *args, **kwargs):
  122. original_expression = super().resolve_expression(*args, **kwargs)
  123. expression = DateTimeEpochColumn(
  124. original_expression.alias, original_expression.target
  125. )
  126. return expression
  127. def IsNotNone(*fields, default=None):
  128. """Selects whichever field is not None, in the specified order.
  129. Arguments:
  130. fields:
  131. The fields to attempt to get a value from,
  132. in order.
  133. default:
  134. The value to return in case all values are None.
  135. Returns:
  136. A Case-When expression that tries each field and
  137. returns the specified default value when all of
  138. them are None.
  139. """
  140. when_clauses = [
  141. expressions.When(
  142. ~expressions.Q(**{field: None}), then=expressions.F(field)
  143. )
  144. for field in reversed(fields)
  145. ]
  146. return expressions.Case(
  147. *when_clauses,
  148. default=expressions.Value(default),
  149. output_field=CharField(),
  150. )
  151. class ExcludedCol(expressions.Expression):
  152. """References a column in PostgreSQL's special EXCLUDED column, which is
  153. used in upserts to refer to the data about to be inserted/updated.
  154. See: https://www.postgresql.org/docs/current/sql-insert.html#SQL-ON-CONFLICT
  155. """
  156. def __init__(self, field_or_name: Union[Field, str]):
  157. # We support both field classes or just field names here. We prefer
  158. # fields because when the expression is compiled, it might need
  159. # the field information to figure out the correct placeholder.
  160. # Even though that isn't require for this particular expression.
  161. if isinstance(field_or_name, Field):
  162. super().__init__(field_or_name)
  163. self.name = field_or_name.column
  164. else:
  165. super().__init__(None)
  166. self.name = field_or_name
  167. def as_sql(self, compiler, connection):
  168. quoted_name = connection.ops.quote_name(self.name)
  169. return f"EXCLUDED.{quoted_name}", tuple()