hstore_required.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. from psqlextra.fields import HStoreField
  2. class HStoreRequiredSchemaEditorSideEffect:
  3. sql_hstore_required_create = (
  4. "ALTER TABLE {table} "
  5. "ADD CONSTRAINT {name} "
  6. "CHECK (({field}->'{key}') "
  7. "IS NOT NULL)"
  8. )
  9. sql_hstore_required_rename = (
  10. "ALTER TABLE {table} "
  11. "RENAME CONSTRAINT "
  12. "{old_name} "
  13. "TO "
  14. "{new_name}"
  15. )
  16. sql_hstore_required_drop = (
  17. "ALTER TABLE {table} " "DROP CONSTRAINT IF EXISTS {name}"
  18. )
  19. def create_model(self, model):
  20. """Ran when a new model is created."""
  21. for field in model._meta.local_fields:
  22. if not isinstance(field, HStoreField):
  23. continue
  24. self.add_field(model, field)
  25. def delete_model(self, model):
  26. """Ran when a model is being deleted."""
  27. for field in model._meta.local_fields:
  28. if not isinstance(field, HStoreField):
  29. continue
  30. self.remove_field(model, field)
  31. def alter_db_table(self, model, old_db_table, new_db_table):
  32. """Ran when the name of a model is changed."""
  33. for field in model._meta.local_fields:
  34. if not isinstance(field, HStoreField):
  35. continue
  36. for key in self._iterate_required_keys(field):
  37. self._rename_hstore_required(
  38. old_db_table, new_db_table, field, field, key
  39. )
  40. def add_field(self, model, field):
  41. """Ran when a field is added to a model."""
  42. for key in self._iterate_required_keys(field):
  43. self._create_hstore_required(model._meta.db_table, field, key)
  44. def remove_field(self, model, field):
  45. """Ran when a field is removed from a model."""
  46. for key in self._iterate_required_keys(field):
  47. self._drop_hstore_required(model._meta.db_table, field, key)
  48. def alter_field(self, model, old_field, new_field, strict=False):
  49. """Ran when the configuration on a field changed."""
  50. is_old_field_hstore = isinstance(old_field, HStoreField)
  51. is_new_field_hstore = isinstance(new_field, HStoreField)
  52. if not is_old_field_hstore and not is_new_field_hstore:
  53. return
  54. old_required = getattr(old_field, "required", []) or []
  55. new_required = getattr(new_field, "required", []) or []
  56. # handle field renames before moving on
  57. if str(old_field.column) != str(new_field.column):
  58. for key in self._iterate_required_keys(old_field):
  59. self._rename_hstore_required(
  60. model._meta.db_table,
  61. model._meta.db_table,
  62. old_field,
  63. new_field,
  64. key,
  65. )
  66. # drop the constraints for keys that have been removed
  67. for key in old_required:
  68. if key not in new_required:
  69. self._drop_hstore_required(model._meta.db_table, old_field, key)
  70. # create new constraints for keys that have been added
  71. for key in new_required:
  72. if key not in old_required:
  73. self._create_hstore_required(
  74. model._meta.db_table, new_field, key
  75. )
  76. def _create_hstore_required(self, table_name, field, key):
  77. """Creates a REQUIRED CONSTRAINT for the specified hstore key."""
  78. name = self._required_constraint_name(table_name, field, key)
  79. sql = self.sql_hstore_required_create.format(
  80. name=self.quote_name(name),
  81. table=self.quote_name(table_name),
  82. field=self.quote_name(field.column),
  83. key=key,
  84. )
  85. self.execute(sql)
  86. def _rename_hstore_required(
  87. self, old_table_name, new_table_name, old_field, new_field, key
  88. ):
  89. """Renames an existing REQUIRED CONSTRAINT for the specified hstore
  90. key."""
  91. old_name = self._required_constraint_name(
  92. old_table_name, old_field, key
  93. )
  94. new_name = self._required_constraint_name(
  95. new_table_name, new_field, key
  96. )
  97. sql = self.sql_hstore_required_rename.format(
  98. table=self.quote_name(new_table_name),
  99. old_name=self.quote_name(old_name),
  100. new_name=self.quote_name(new_name),
  101. )
  102. self.execute(sql)
  103. def _drop_hstore_required(self, table_name, field, key):
  104. """Drops a REQUIRED CONSTRAINT for the specified hstore key."""
  105. name = self._required_constraint_name(table_name, field, key)
  106. sql = self.sql_hstore_required_drop.format(
  107. table=self.quote_name(table_name), name=self.quote_name(name)
  108. )
  109. self.execute(sql)
  110. @staticmethod
  111. def _required_constraint_name(table: str, field, key):
  112. """Gets the name for a CONSTRAINT that applies to a single hstore key.
  113. Arguments:
  114. table:
  115. The name of the table the field is
  116. a part of.
  117. field:
  118. The hstore field to create a
  119. UNIQUE INDEX for.
  120. key:
  121. The name of the hstore key
  122. to create the name for.
  123. Returns:
  124. The name for the UNIQUE index.
  125. """
  126. return "{table}_{field}_required_{postfix}".format(
  127. table=table, field=field.column, postfix=key
  128. )
  129. @staticmethod
  130. def _iterate_required_keys(field):
  131. """Iterates over the keys marked as "required" in the specified field.
  132. Arguments:
  133. field:
  134. The field of which key's to
  135. iterate over.
  136. """
  137. required_keys = getattr(field, "required", None)
  138. if not required_keys:
  139. return
  140. for key in required_keys:
  141. yield key