fields.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. from urllib.parse import parse_qs
  2. from typing import List
  3. import re
  4. from rest_framework import serializers
  5. from rest_framework.exceptions import ValidationError, ErrorDetail
  class ErrorValueDetail(ErrorDetail):
      """DRF ``ErrorDetail`` that also carries the rejected input value.

      Construction mirrors ``ErrorDetail(string, code)`` with one extra,
      optional ``value`` argument, stored on the instance as ``.value``.
      """

      # Class-level fallback; every instance created through __new__ sets its
      # own ``value``.
      value = None

      def __new__(cls, string, code=None, value=None):
          instance = super().__new__(cls, string, code)
          instance.value = value
          return instance

      def __repr__(self):
          # NOTE(review): the label reads "ErrorDetail" rather than the subclass
          # name; kept unchanged in case anything matches on this text.
          message = str(self)
          return (
              f"ErrorDetail(string={message!r}, "
              f"code={self.code!r}, value={self.value!r})"
          )
  class GenericField(serializers.Field):
      """Field whose internal value is the incoming data, left unchanged."""

      def to_internal_value(self, data):
          # Deliberately no parsing or type coercion of the input.
          return data
  class ForgivingFieldMixin:
      """Mixin for fields that record validation problems in ``self.context``.

      Fields using it report their errors through
      ``update_handled_errors_context`` instead of raising them.
      """

      def update_handled_errors_context(self, errors: List[ErrorValueDetail]):
          """Store *errors* under this field's name in ``context["handled_errors"]``.

          An empty *errors* list is ignored. A fresh mapping is written back on
          each call, and any entry previously stored for the same field name is
          replaced rather than extended.
          """
          if not errors:
              return
          previous = self.context.get("handled_errors", {})
          self.context["handled_errors"] = {**previous, self.field_name: errors}
  class ForgivingHStoreField(ForgivingFieldMixin, serializers.HStoreField):
      """HStore field that drops bad entries instead of failing validation.

      Entries whose value is ``None`` are skipped. Entries whose value fails the
      child field's validation are left out of the result; their error details
      are recorded, together with the rejected value, through
      ``update_handled_errors_context``.
      """

      def run_child_validation(self, data):
          cleaned = {}
          problems: List[ErrorValueDetail] = []
          for raw_key, raw_value in data.items():
              if raw_value is None:
                  continue  # null entries are dropped without an error
              key = str(raw_key)  # keys are always stored as strings
              try:
                  cleaned[key] = self.child.run_validation(raw_value)
              except ValidationError as exc:
                  # One bad entry must not reject the whole mapping: remember
                  # each error detail and move on to the next entry.
                  problems.extend(
                      ErrorValueDetail(str(item), item.code, raw_value)
                      for item in exc.detail
                  )
          if problems:
              self.update_handled_errors_context(problems)
          return cleaned
  class ForgivingDisallowRegexField(ForgivingFieldMixin, serializers.CharField):
      """CharField that quietly rejects values not matching ``disallow_regex``.

      Enable it by passing the ``disallow_regex`` keyword argument (anything
      accepted by ``re.compile``). Despite the name, the pattern describes the
      *accepted* form of the value: a value for which ``re.match`` finds no match
      is rejected. ``re.match`` only anchors at the start of the string, so end
      the pattern with ``$`` to constrain the whole value.

      A rejected value does not raise. An ``ErrorValueDetail`` (code
      ``"invalid_data"``) is recorded through ``update_handled_errors_context``
      and the field's value becomes ``None``. Without ``disallow_regex`` the
      field behaves like a plain ``CharField``.
      """

      def __init__(self, **kwargs):
          # Pop the custom kwarg so it is not forwarded to CharField.__init__.
          self.disallow_regex = kwargs.pop("disallow_regex", None)
          super().__init__(**kwargs)

      def to_internal_value(self, data):
          data = super().to_internal_value(data)
          if self.disallow_regex:
              pattern = re.compile(self.disallow_regex)
              if pattern.match(data) is None:
                  error = ErrorValueDetail(
                      "invalid characters in string", "invalid_data", data
                  )
                  self.update_handled_errors_context([error])
                  return None
          return data

      def run_validators(self, value):
          """Run the field validators, except for a value rejected above.

          ``to_internal_value`` returns ``None`` for a rejected value, and the
          rejection is already recorded in the handled-errors context. Running
          validators on it anyway would let the length validators added for
          ``max_length``/``min_length`` call ``len(None)`` and raise TypeError,
          defeating the forgiving behaviour.
          """
          if value is None:
              return
          super().run_validators(value)
  class QueryStringField(serializers.ListField):
      """Normalize a query string into a list of ``[key, value]`` pairs.

      Accepted input forms:

      * a non-empty, unparsed query string such as ``"a=1&b=2"``, parsed with
        ``urllib.parse.parse_qs``; a key repeated in the string yields one pair
        per value;
      * a dictionary, yielding one pair per item;
      * a list or tuple of pairs, each pair itself a list or tuple. Pairs with
        fewer than two elements are dropped, and longer ones are truncated to
        their first two elements.

      Every pair is returned as a two-element list. Any other input, including
      an empty string, yields ``None``. Values are passed through as given;
      ``to_internal_value`` below does not consult ``child``.
      """

      child = serializers.ListField(child=serializers.CharField())

      def to_internal_value(self, data):
          if isinstance(data, str) and data:
              # NOTE(review): parse_qs drops pairs with blank values by default
              # (keep_blank_values=False), so e.g. "flag" or "a=" disappears;
              # confirm this is intended.
              return [
                  [key, value]
                  for key, values in parse_qs(data).items()
                  for value in values
              ]
          if isinstance(data, dict):
              return [[key, value] for key, value in data.items()]
          if isinstance(data, (list, tuple)):
              # Tuples are accepted alongside lists at both levels, as the class
              # docstring describes; each pair is normalized to a list.
              return [
                  list(item[:2])
                  for item in data
                  if isinstance(item, (list, tuple)) and len(item) >= 2
              ]
          return None