base.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. import logging
  2. from typing import TYPE_CHECKING
  3. from django.conf import settings
  4. from django.db import ProgrammingError
  5. from . import base_impl
  6. from .introspection import PostgresIntrospection
  7. from .operations import PostgresOperations
  8. from .schema import PostgresSchemaEditor
  9. from django.db.backends.postgresql.base import ( # isort:skip
  10. DatabaseWrapper as PostgresDatabaseWrapper,
  11. )
  12. logger = logging.getLogger(__name__)
  13. if TYPE_CHECKING:
  14. class Wrapper(PostgresDatabaseWrapper):
  15. pass
  16. else:
  17. Wrapper = base_impl.backend()
  18. class DatabaseWrapper(Wrapper):
  19. """Wraps the standard PostgreSQL database back-end.
  20. Overrides the schema editor with our custom schema editor and makes
  21. sure the `hstore` extension is enabled.
  22. """
  23. SchemaEditorClass = PostgresSchemaEditor # type: ignore[assignment]
  24. introspection_class = PostgresIntrospection
  25. ops_class = PostgresOperations
  26. def __init__(self, *args, **kwargs):
  27. super().__init__(*args, **kwargs)
  28. # Some base back-ends such as the PostGIS back-end don't properly
  29. # set `ops_class` and `introspection_class` and initialize these
  30. # classes themselves.
  31. #
  32. # This can lead to broken functionality. We fix this automatically.
  33. if not isinstance(self.introspection, self.introspection_class):
  34. self.introspection = self.introspection_class(self)
  35. if not isinstance(self.ops, self.ops_class):
  36. self.ops = self.ops_class(self)
  37. for expected_compiler_class in self.ops.compiler_classes:
  38. compiler_class = self.ops.compiler(expected_compiler_class.__name__)
  39. if not issubclass(compiler_class, expected_compiler_class):
  40. logger.warning(
  41. "Compiler '%s.%s' is not properly deriving from '%s.%s'."
  42. % (
  43. compiler_class.__module__,
  44. compiler_class.__name__,
  45. expected_compiler_class.__module__,
  46. expected_compiler_class.__name__,
  47. )
  48. )
  49. def prepare_database(self):
  50. """Ran to prepare the configured database.
  51. This is where we enable the `hstore` extension if it wasn't
  52. enabled yet.
  53. """
  54. super().prepare_database()
  55. setup_ext = getattr(
  56. settings, "POSTGRES_EXTRA_AUTO_EXTENSION_SET_UP", True
  57. )
  58. if not setup_ext:
  59. return False
  60. with self.cursor() as cursor:
  61. try:
  62. cursor.execute("CREATE EXTENSION IF NOT EXISTS hstore")
  63. except ProgrammingError: # permission denied
  64. logger.warning(
  65. 'Failed to create "hstore" extension. '
  66. "Tables with hstore columns may fail to migrate. "
  67. "If hstore is needed, make sure you are connected "
  68. "to the database as a superuser "
  69. "or add the extension manually.",
  70. exc_info=True,
  71. )