compiler.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491
  1. import inspect
  2. import os
  3. import sys
  4. from collections.abc import Iterable
  5. from typing import TYPE_CHECKING, Tuple, Union, cast
  6. import django
  7. from django.conf import settings
  8. from django.core.exceptions import SuspiciousOperation
  9. from django.db.models import Expression, Model, Q
  10. from django.db.models.fields.related import RelatedField
  11. from django.db.models.sql import compiler as django_compiler
  12. from .expressions import HStoreValue
  13. from .types import ConflictAction
  14. if TYPE_CHECKING:
  15. from .sql import PostgresInsertQuery
  16. def append_caller_to_sql(sql):
  17. """Append the caller to SQL queries.
  18. Adds the calling file and function as an SQL comment to each query.
  19. Examples:
  20. INSERT INTO "tests_47ee19d1" ("id", "title")
  21. VALUES (1, 'Test')
  22. RETURNING "tests_47ee19d1"."id"
  23. /* 998020 test_append_caller_to_sql_crud .../django-postgres-extra/tests/test_append_caller_to_sql.py 55 */
  24. SELECT "tests_47ee19d1"."id", "tests_47ee19d1"."title"
  25. FROM "tests_47ee19d1"
  26. WHERE "tests_47ee19d1"."id" = 1
  27. LIMIT 1
  28. /* 998020 test_append_caller_to_sql_crud .../django-postgres-extra/tests/test_append_caller_to_sql.py 69 */
  29. UPDATE "tests_47ee19d1"
  30. SET "title" = 'success'
  31. WHERE "tests_47ee19d1"."id" = 1
  32. /* 998020 test_append_caller_to_sql_crud .../django-postgres-extra/tests/test_append_caller_to_sql.py 64 */
  33. DELETE FROM "tests_47ee19d1"
  34. WHERE "tests_47ee19d1"."id" IN (1)
  35. /* 998020 test_append_caller_to_sql_crud .../django-postgres-extra/tests/test_append_caller_to_sql.py 74 */
  36. Slow and blocking queries could be easily tracked down to their originator
  37. within the source code using the "pg_stat_activity" table.
  38. Enable "POSTGRES_EXTRA_ANNOTATE_SQL" within the database settings to enable this feature.
  39. """
  40. if not getattr(settings, "POSTGRES_EXTRA_ANNOTATE_SQL", None):
  41. return sql
  42. try:
  43. # Search for the first non-Django caller
  44. stack = inspect.stack()
  45. for stack_frame in stack[1:]:
  46. frame_filename = stack_frame[1]
  47. frame_line = stack_frame[2]
  48. frame_function = stack_frame[3]
  49. if "/django/" in frame_filename or "/psqlextra/" in frame_filename:
  50. continue
  51. return f"{sql} /* {os.getpid()} {frame_function} {frame_filename} {frame_line} */"
  52. # Django internal commands (like migrations) end up here
  53. return f"{sql} /* {os.getpid()} {sys.argv[0]} */"
  54. except Exception:
  55. # Don't break anything because this convinence function runs into an unexpected situation
  56. return sql
  57. class SQLCompiler(django_compiler.SQLCompiler): # type: ignore [attr-defined]
  58. def as_sql(self, *args, **kwargs):
  59. sql, params = super().as_sql(*args, **kwargs)
  60. return append_caller_to_sql(sql), params
  61. class SQLDeleteCompiler(django_compiler.SQLDeleteCompiler): # type: ignore [name-defined]
  62. def as_sql(self, *args, **kwargs):
  63. sql, params = super().as_sql(*args, **kwargs)
  64. return append_caller_to_sql(sql), params
  65. class SQLAggregateCompiler(django_compiler.SQLAggregateCompiler): # type: ignore [name-defined]
  66. def as_sql(self, *args, **kwargs):
  67. sql, params = super().as_sql(*args, **kwargs)
  68. return append_caller_to_sql(sql), params
  69. class SQLUpdateCompiler(django_compiler.SQLUpdateCompiler): # type: ignore [name-defined]
  70. """Compiler for SQL UPDATE statements that allows us to use expressions
  71. inside HStore values.
  72. Like:
  73. .update(name=dict(en=F('test')))
  74. """
  75. def as_sql(self, *args, **kwargs):
  76. self._prepare_query_values()
  77. sql, params = super().as_sql(*args, **kwargs)
  78. return append_caller_to_sql(sql), params
  79. def _prepare_query_values(self):
  80. """Extra prep on query values by converting dictionaries into
  81. :see:HStoreValue expressions.
  82. This allows putting expressions in a dictionary. The
  83. :see:HStoreValue will take care of resolving the expressions
  84. inside the dictionary.
  85. """
  86. if not self.query.values:
  87. return
  88. new_query_values = []
  89. for field, model, val in self.query.values:
  90. if not isinstance(val, dict):
  91. new_query_values.append((field, model, val))
  92. continue
  93. if not self._does_dict_contain_expression(val):
  94. new_query_values.append((field, model, val))
  95. continue
  96. expression = HStoreValue(dict(val))
  97. new_query_values.append((field, model, expression))
  98. self.query.values = new_query_values
  99. @staticmethod
  100. def _does_dict_contain_expression(data: dict) -> bool:
  101. """Gets whether the specified dictionary contains any expressions that
  102. need to be resolved."""
  103. for value in data.values():
  104. if hasattr(value, "resolve_expression"):
  105. return True
  106. if hasattr(value, "as_sql"):
  107. return True
  108. return False
  109. class SQLInsertCompiler(django_compiler.SQLInsertCompiler): # type: ignore [name-defined]
  110. """Compiler for SQL INSERT statements."""
  111. def as_sql(self, *args, **kwargs):
  112. """Builds the SQL INSERT statement."""
  113. queries = [
  114. (append_caller_to_sql(sql), params)
  115. for sql, params in super().as_sql(*args, **kwargs)
  116. ]
  117. return queries
  118. class PostgresInsertOnConflictCompiler(django_compiler.SQLInsertCompiler): # type: ignore [name-defined]
  119. """Compiler for SQL INSERT statements."""
  120. query: "PostgresInsertQuery"
  121. def __init__(self, *args, **kwargs):
  122. """Initializes a new instance of
  123. :see:PostgresInsertOnConflictCompiler."""
  124. super().__init__(*args, **kwargs)
  125. self.qn = self.connection.ops.quote_name
  126. def as_sql(self, return_id=False, *args, **kwargs):
  127. """Builds the SQL INSERT statement."""
  128. queries = [
  129. self._rewrite_insert(sql, params, return_id)
  130. for sql, params in super().as_sql(*args, **kwargs)
  131. ]
  132. return queries
  133. def _rewrite_insert(self, sql, params, return_id=False):
  134. """Rewrites a formed SQL INSERT query to include the ON CONFLICT
  135. clause.
  136. Arguments:
  137. sql:
  138. The SQL INSERT query to rewrite.
  139. params:
  140. The parameters passed to the query.
  141. return_id:
  142. Whether to only return the ID or all
  143. columns.
  144. Returns:
  145. A tuple of the rewritten SQL query and new params.
  146. """
  147. returning = (
  148. self.qn(self.query.model._meta.pk.attname) if return_id else "*"
  149. )
  150. (sql, params) = self._rewrite_insert_on_conflict(
  151. sql, params, self.query.conflict_action.value, returning
  152. )
  153. return append_caller_to_sql(sql), params
  154. def _rewrite_insert_on_conflict(
  155. self, sql, params, conflict_action: ConflictAction, returning
  156. ):
  157. """Rewrites a normal SQL INSERT query to add the 'ON CONFLICT'
  158. clause."""
  159. # build the conflict target, the columns to watch
  160. # for conflicts
  161. on_conflict_clause = self._build_on_conflict_clause()
  162. index_predicate = self.query.index_predicate # type: ignore[attr-defined]
  163. update_condition = self.query.conflict_update_condition # type: ignore[attr-defined]
  164. rewritten_sql = f"{sql} {on_conflict_clause}"
  165. if index_predicate:
  166. expr_sql, expr_params = self._compile_expression(index_predicate)
  167. rewritten_sql += f" WHERE {expr_sql}"
  168. params += tuple(expr_params)
  169. # Fallback in case the user didn't specify any update values. We can still
  170. # make the query work if we switch to ConflictAction.NOTHING
  171. if (
  172. conflict_action == ConflictAction.UPDATE.value
  173. and not self.query.update_values
  174. ):
  175. conflict_action = ConflictAction.NOTHING
  176. rewritten_sql += f" DO {conflict_action}"
  177. if conflict_action == ConflictAction.UPDATE.value:
  178. set_sql, sql_params = self._build_set_statement()
  179. rewritten_sql += f" SET {set_sql}"
  180. params += sql_params
  181. if update_condition:
  182. expr_sql, expr_params = self._compile_expression(
  183. update_condition
  184. )
  185. rewritten_sql += f" WHERE {expr_sql}"
  186. params += tuple(expr_params)
  187. rewritten_sql += f" RETURNING {returning}"
  188. return (rewritten_sql, params)
  189. def _build_set_statement(self) -> Tuple[str, tuple]:
  190. """Builds the SET statement for the ON CONFLICT DO UPDATE clause.
  191. This uses the update compiler to provide full compatibility with
  192. the standard Django's `update(...)`.
  193. """
  194. # Local import to work around the circular dependency between
  195. # the compiler and the queries.
  196. from .sql import PostgresUpdateQuery
  197. query = cast(PostgresUpdateQuery, self.query.chain(PostgresUpdateQuery))
  198. query.add_update_values(self.query.update_values)
  199. sql, params = query.get_compiler(self.connection.alias).as_sql()
  200. return sql.split("SET")[1].split(" WHERE")[0], tuple(params)
  201. def _build_on_conflict_clause(self):
  202. if django.VERSION >= (2, 2):
  203. from django.db.models.constraints import BaseConstraint
  204. from django.db.models.indexes import Index
  205. if isinstance(
  206. self.query.conflict_target, BaseConstraint
  207. ) or isinstance(self.query.conflict_target, Index):
  208. return "ON CONFLICT ON CONSTRAINT %s" % self.qn(
  209. self.query.conflict_target.name
  210. )
  211. conflict_target = self._build_conflict_target()
  212. return f"ON CONFLICT {conflict_target}"
  213. def _build_conflict_target(self):
  214. """Builds the `conflict_target` for the ON CONFLICT clause."""
  215. if not isinstance(self.query.conflict_target, Iterable):
  216. raise SuspiciousOperation(
  217. (
  218. "%s is not a valid conflict target, specify "
  219. "a list of column names, or tuples with column "
  220. "names and hstore key."
  221. )
  222. % str(self.query.conflict_target)
  223. )
  224. conflict_target = self._build_conflict_target_by_index()
  225. if conflict_target:
  226. return conflict_target
  227. return self._build_conflict_target_by_fields()
  228. def _build_conflict_target_by_fields(self):
  229. """Builds the `conflict_target` for the ON CONFLICT clauses by matching
  230. the fields specified in the specified conflict target against the
  231. model's fields.
  232. This requires some special handling because the fields names
  233. might not be same as the column names.
  234. """
  235. conflict_target = []
  236. for field_name in self.query.conflict_target:
  237. self._assert_valid_field(field_name)
  238. # special handling for hstore keys
  239. if isinstance(field_name, tuple):
  240. conflict_target.append(
  241. "(%s->'%s')"
  242. % (self._format_field_name(field_name), field_name[1])
  243. )
  244. else:
  245. conflict_target.append(self._format_field_name(field_name))
  246. return "(%s)" % ",".join(conflict_target)
  247. def _build_conflict_target_by_index(self):
  248. """Builds the `conflict_target` for the ON CONFLICT clause by trying to
  249. find an index that matches the specified conflict target on the query.
  250. Conflict targets must match some unique constraint, usually this
  251. is a `UNIQUE INDEX`.
  252. """
  253. matching_index = next(
  254. (
  255. index
  256. for index in self.query.model._meta.indexes
  257. if list(index.fields) == list(self.query.conflict_target)
  258. ),
  259. None,
  260. )
  261. if not matching_index:
  262. return None
  263. with self.connection.schema_editor() as schema_editor:
  264. stmt = matching_index.create_sql(self.query.model, schema_editor)
  265. return "(%s)" % stmt.parts["columns"]
  266. def _get_model_field(self, name: str):
  267. """Gets the field on a model with the specified name.
  268. Arguments:
  269. name:
  270. The name of the field to look for.
  271. This can be both the actual field name, or
  272. the name of the column, both will work :)
  273. Returns:
  274. The field with the specified name or None if
  275. no such field exists.
  276. """
  277. field_name = self._normalize_field_name(name)
  278. if not self.query.model:
  279. return None
  280. # 'pk' has special meaning and always refers to the primary
  281. # key of a model, we have to respect this de-facto standard behaviour
  282. if field_name == "pk" and self.query.model._meta.pk:
  283. return self.query.model._meta.pk
  284. for field in self.query.model._meta.local_concrete_fields: # type: ignore[attr-defined]
  285. if field.name == field_name or field.column == field_name:
  286. return field
  287. return None
  288. def _format_field_name(self, field_name) -> str:
  289. """Formats a field's name for usage in SQL.
  290. Arguments:
  291. field_name:
  292. The field name to format.
  293. Returns:
  294. The specified field name formatted for
  295. usage in SQL.
  296. """
  297. field = self._get_model_field(field_name)
  298. return self.qn(field.column)
  299. def _format_field_value(self, field_name) -> str:
  300. """Formats a field's value for usage in SQL.
  301. Arguments:
  302. field_name:
  303. The name of the field to format
  304. the value of.
  305. Returns:
  306. The field's value formatted for usage
  307. in SQL.
  308. """
  309. field_name = self._normalize_field_name(field_name)
  310. field = self._get_model_field(field_name)
  311. value = getattr(self.query.objs[0], field.attname)
  312. if isinstance(field, RelatedField) and isinstance(value, Model):
  313. value = value.pk
  314. return django_compiler.SQLInsertCompiler.prepare_value( # type: ignore[attr-defined]
  315. self,
  316. field,
  317. # Note: this deliberately doesn't use `pre_save_val` as we don't
  318. # want things like auto_now on DateTimeField (etc.) to change the
  319. # value. We rely on pre_save having already been done by the
  320. # underlying compiler so that things like FileField have already had
  321. # the opportunity to save out their data.
  322. value,
  323. )
  324. def _compile_expression(
  325. self, expression: Union[Expression, Q, str]
  326. ) -> Tuple[str, Union[tuple, list]]:
  327. """Compiles an expression, Q object or raw SQL string into SQL and
  328. tuple of parameters."""
  329. if isinstance(expression, Q):
  330. if django.VERSION < (3, 1):
  331. raise SuspiciousOperation(
  332. "Q objects in psqlextra can only be used with Django 3.1 and newer"
  333. )
  334. return self.query.build_where(expression).as_sql(
  335. self, self.connection
  336. )
  337. elif isinstance(expression, Expression):
  338. return self.compile(expression)
  339. return expression, tuple()
  340. def _assert_valid_field(self, field_name: str):
  341. """Asserts that a field with the specified name exists on the model and
  342. raises :see:SuspiciousOperation if it does not."""
  343. field_name = self._normalize_field_name(field_name)
  344. if self._get_model_field(field_name):
  345. return
  346. raise SuspiciousOperation(
  347. (
  348. "%s is not a valid conflict target, specify "
  349. "a list of column names, or tuples with column "
  350. "names and hstore key."
  351. )
  352. % str(field_name)
  353. )
  354. @staticmethod
  355. def _normalize_field_name(field_name: str) -> str:
  356. """Normalizes a field name into a string by extracting the field name
  357. if it was specified as a reference to a HStore key (as a tuple).
  358. Arguments:
  359. field_name:
  360. The field name to normalize.
  361. Returns:
  362. The normalized field name.
  363. """
  364. if isinstance(field_name, tuple):
  365. field_name, _ = field_name
  366. return field_name