sql.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. from collections import OrderedDict
  2. from typing import List, Optional, Tuple
  3. import django
  4. from django.core.exceptions import SuspiciousOperation
  5. from django.db import connections, models
  6. from django.db.models import sql
  7. from django.db.models.constants import LOOKUP_SEP
  8. from .compiler import PostgresInsertOnConflictCompiler
  9. from .compiler import SQLUpdateCompiler as PostgresUpdateCompiler
  10. from .expressions import HStoreColumn
  11. from .fields import HStoreField
  12. from .types import ConflictAction
  13. class PostgresQuery(sql.Query):
  14. def chain(self, klass=None):
  15. """Chains this query to another.
  16. We override this so that we can make sure our subclassed query
  17. classes are used.
  18. """
  19. if klass == sql.UpdateQuery:
  20. return super().chain(PostgresUpdateQuery)
  21. if klass == sql.InsertQuery:
  22. return super().chain(PostgresInsertQuery)
  23. return super().chain(klass)
  24. def rename_annotations(self, annotations) -> None:
  25. """Renames the aliases for the specified annotations:
  26. .annotate(myfield=F('somestuf__myfield'))
  27. .rename_annotations(myfield='field')
  28. Arguments:
  29. annotations:
  30. The annotations to rename. Mapping the
  31. old name to the new name.
  32. """
  33. # safety check only, make sure there are no renames
  34. # left that cannot be mapped back to the original name
  35. for old_name, new_name in annotations.items():
  36. annotation = self.annotations.get(old_name)
  37. if not annotation:
  38. raise SuspiciousOperation(
  39. (
  40. 'Cannot rename annotation "{old_name}" to "{new_name}", because there'
  41. ' is no annotation named "{old_name}".'
  42. ).format(old_name=old_name, new_name=new_name)
  43. )
  44. # rebuild the annotations according to the original order
  45. new_annotations = OrderedDict()
  46. for old_name, annotation in self.annotations.items():
  47. new_name = annotations.get(old_name)
  48. new_annotations[new_name or old_name] = annotation
  49. if new_name and self.annotation_select_mask:
  50. self.annotation_select_mask.discard(old_name)
  51. self.annotation_select_mask.add(new_name)
  52. self.annotations.clear()
  53. self.annotations.update(new_annotations)
  54. def add_fields(self, field_names: List[str], *args, **kwargs) -> None:
  55. """Adds the given (model) fields to the select set.
  56. The field names are added in the order specified. This overrides
  57. the base class's add_fields method. This is called by the
  58. .values() or .values_list() method of the query set. It
  59. instructs the ORM to only select certain values. A lot of
  60. processing is neccesarry because it can be used to easily do
  61. joins. For example, `my_fk__name` pulls in the `name` field in
  62. foreign key `my_fk`. In our case, we want to be able to do
  63. `title__en`, where `title` is a HStoreField and `en` a key. This
  64. doesn't really involve a join. We iterate over the specified
  65. field names and filter out the ones that refer to HStoreField
  66. and compile it into an expression which is added to the list of
  67. to be selected fields using `self.add_select`.
  68. """
  69. # django knows how to do all of this natively from v2.1
  70. # see: https://github.com/django/django/commit/20bab2cf9d02a5c6477d8aac066a635986e0d3f3
  71. if django.VERSION >= (2, 1):
  72. return super().add_fields(field_names, *args, **kwargs)
  73. select = []
  74. field_names_without_hstore = []
  75. for name in field_names:
  76. parts = name.split(LOOKUP_SEP)
  77. # it cannot be a special hstore thing if there's no __ in it
  78. if len(parts) > 1:
  79. column_name, hstore_key = parts[:2]
  80. is_hstore, field = self._is_hstore_field(column_name)
  81. if is_hstore:
  82. select.append(
  83. HStoreColumn(
  84. self.model._meta.db_table or self.model.name,
  85. field,
  86. hstore_key,
  87. )
  88. )
  89. continue
  90. field_names_without_hstore.append(name)
  91. super().add_fields(field_names_without_hstore, *args, **kwargs)
  92. if len(select) > 0:
  93. self.set_select(self.select + tuple(select))
  94. def _is_hstore_field(
  95. self, field_name: str
  96. ) -> Tuple[bool, Optional[models.Field]]:
  97. """Gets whether the field with the specified name is a HStoreField.
  98. Returns A tuple of a boolean indicating whether the field
  99. with the specified name is a HStoreField, and the field
  100. instance.
  101. """
  102. field_instance = None
  103. for field in self.model._meta.local_concrete_fields:
  104. if field.name == field_name or field.column == field_name:
  105. field_instance = field
  106. break
  107. return isinstance(field_instance, HStoreField), field_instance
  108. class PostgresInsertQuery(sql.InsertQuery):
  109. """Insert query using PostgreSQL."""
  110. def __init__(self, *args, **kwargs):
  111. """Initializes a new instance :see:PostgresInsertQuery."""
  112. super(PostgresInsertQuery, self).__init__(*args, **kwargs)
  113. self.conflict_target = []
  114. self.conflict_action = ConflictAction.UPDATE
  115. self.conflict_update_condition = None
  116. self.index_predicate = None
  117. self.update_fields = []
  118. def values(self, objs: List, insert_fields: List, update_fields: List = []):
  119. """Sets the values to be used in this query.
  120. Insert fields are fields that are definitely
  121. going to be inserted, and if an existing row
  122. is found, are going to be overwritten with the
  123. specified value.
  124. Update fields are fields that should be overwritten
  125. in case an update takes place rather than an insert.
  126. If we're dealing with a INSERT, these will not be used.
  127. Arguments:
  128. objs:
  129. The objects to apply this query to.
  130. insert_fields:
  131. The fields to use in the INSERT statement
  132. update_fields:
  133. The fields to only use in the UPDATE statement.
  134. """
  135. self.insert_values(insert_fields, objs, raw=False)
  136. self.update_fields = update_fields
  137. def get_compiler(self, using=None, connection=None):
  138. if using:
  139. connection = connections[using]
  140. return PostgresInsertOnConflictCompiler(self, connection, using)
  141. class PostgresUpdateQuery(sql.UpdateQuery):
  142. """Update query using PostgreSQL."""
  143. def get_compiler(self, using=None, connection=None):
  144. if using:
  145. connection = connections[using]
  146. return PostgresUpdateCompiler(self, connection, using)