from contextlib import contextmanager, suppress
from typing import Generator, List, Optional, Union

from django.core.exceptions import SuspiciousOperation
from django.db import DEFAULT_DB_ALIAS, Error, connections


@contextmanager
def postgres_set_local(
    *,
    using: str = DEFAULT_DB_ALIAS,
    **options: Optional[Union[str, int, float, List[str]]],
) -> Generator[None, None, None]:
    """Sets the specified PostgreSQL options using SET LOCAL so that they
    apply to the current transaction only.

    The effect is undone when the context manager exits, including when
    the wrapped block raises.

    Each keyword argument is the name of a setting. A value of `None`
    resets the setting to its default for the duration of the block. A
    list or tuple is sent as a comma separated list of values. Calling
    this without any options is a no-op.

    Raises `SuspiciousOperation` when the connection is not inside an
    atomic block.

    See https://www.postgresql.org/docs/current/runtime-config-client.html
    for an overview of all available options.
    """

    connection = connections[using]
    qn = connection.ops.quote_name

    if not connection.in_atomic_block:
        raise SuspiciousOperation(
            "SET LOCAL makes no sense outside a transaction. Start a transaction first."
        )

    # Nothing to set means nothing to restore. Executing an
    # empty statement would be rejected by the database driver.
    if not options:
        yield
        return

    statements = []
    params: List[Union[str, int, float]] = []
    for name, value in options.items():
        if value is None:
            statements.append(f"SET LOCAL {qn(name)} TO DEFAULT")
            continue

        # Settings that accept a list of values are actually
        # stored as string lists. We cannot just pass a list
        # of values. We have to create the comma separated
        # string ourselves.
        if isinstance(value, (list, tuple)):
            placeholder = ", ".join(["%s"] * len(value))
            params.extend(value)
        else:
            placeholder = "%s"
            params.append(value)

        statements.append(f"SET LOCAL {qn(name)} = {placeholder}")

    with connection.cursor() as cursor:
        # Setting names are matched case-insensitively by PostgreSQL,
        # but pg_settings lists some in mixed case (e.g. "TimeZone").
        # Compare lower-cased names so those are found as well.
        cursor.execute(
            "SELECT lower(name), setting FROM pg_settings WHERE lower(name) = ANY(%s)",
            ([name.lower() for name in options],),
        )
        original_values = dict(cursor.fetchall())
        cursor.execute("; ".join(statements), params)

    def restore() -> None:
        # Put everything back to how it was. DEFAULT is
        # not good enough as an outer SET LOCAL might
        # have set a different value.
        with connection.cursor() as restore_cursor:
            for name in options:
                original_value = original_values.get(name.lower())
                if original_value:
                    # set_config() takes the value as text and parses it
                    # the way SET would. Passing it as a parameter avoids
                    # splicing an unquoted value into the statement.
                    restore_cursor.execute(
                        "SELECT set_config(%s, %s, true)",
                        (name, original_value),
                    )
                else:
                    restore_cursor.execute(
                        f"SET LOCAL {qn(name)} TO DEFAULT"
                    )

    try:
        yield
    except Exception:
        # Best effort: the transaction might be in a failed state, in
        # which case the restore statements are rejected. The rollback
        # undoes SET LOCAL anyway and the original error must win.
        with suppress(Error):
            restore()
        raise

    restore()


@contextmanager
def postgres_set_local_search_path(
    search_path: List[str], *, using: str = DEFAULT_DB_ALIAS
) -> Generator[None, None, None]:
    """Sets the search path to the specified schemas.

    Thin wrapper around `postgres_set_local`; the same rules apply.
    """

    with postgres_set_local(search_path=search_path, using=using):
        yield


@contextmanager
def postgres_prepend_local_search_path(
    search_path: List[str], *, using: str = DEFAULT_DB_ALIAS
) -> Generator[None, None, None]:
    """Prepends the current local search path with the specified schemas.

    The search path that was active on entry is put back when the
    context manager exits, including when the wrapped block raises.
    Passing an empty list leaves the search path untouched.
    """

    connection = connections[using]

    # Prepending nothing is a no-op. It would also produce
    # an invalid statement (a list starting with a comma).
    if not search_path:
        yield
        return

    with connection.cursor() as cursor:
        cursor.execute("SHOW search_path")
        [
            original_search_path,
        ] = cursor.fetchone()

        new_search_path = ", ".join(["%s"] * len(search_path))
        if original_search_path:
            # The current search path is already a valid comma separated
            # list of identifiers and is appended as-is. Literal percent
            # signs are doubled because this statement goes through the
            # driver's parameter formatting.
            escaped_search_path = original_search_path.replace("%", "%%")
            new_search_path = f"{new_search_path}, {escaped_search_path}"

        cursor.execute(
            f"SET LOCAL search_path = {new_search_path}",
            tuple(search_path),
        )

        def restore() -> None:
            # set_config() accepts the text form reported by SHOW,
            # including an empty search path, and keeps the value
            # out of the SQL text.
            cursor.execute(
                "SELECT set_config('search_path', %s, true)",
                (original_search_path,),
            )

        try:
            yield
        except Exception:
            # Best effort: a failed transaction rejects the statement,
            # but its rollback undoes SET LOCAL and the original error
            # must win.
            with suppress(Error):
                restore()
            raise

        restore()


@contextmanager
def postgres_reset_local_search_path(
    *, using: str = DEFAULT_DB_ALIAS
) -> Generator[None, None, None]:
    """Resets the local search path to the default.

    Thin wrapper around `postgres_set_local`; the same rules apply.
    """

    with postgres_set_local(search_path=None, using=using):
        yield