fields.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. from collections import OrderedDict
  2. from urllib.parse import parse_qs
  3. from typing import List
  4. import re
  5. from rest_framework import serializers
  6. from rest_framework.exceptions import ValidationError, ErrorDetail
  7. class ErrorValueDetail(ErrorDetail):
  8. """Extended ErrorDetail with validation value"""
  9. value = None
  10. def __new__(cls, string, code=None, value=None):
  11. self = super().__new__(cls, string, code)
  12. self.value = value
  13. return self
  14. def __repr__(self):
  15. return "ErrorDetail(string=%r, code=%r, value=%r)" % (
  16. str(self),
  17. self.code,
  18. self.value,
  19. )
  20. class GenericField(serializers.Field):
  21. def to_internal_value(self, data):
  22. return data
  23. class ForgivingFieldMixin:
  24. def update_handled_errors_context(self, errors: List[ErrorValueDetail]):
  25. if errors:
  26. handled_errors = self.context.get("handled_errors", {})
  27. self.context["handled_errors"] = handled_errors | {self.field_name: errors}
  28. class ForgivingHStoreField(ForgivingFieldMixin, serializers.HStoreField):
  29. def run_child_validation(self, data):
  30. result = {}
  31. errors: List[ErrorValueDetail] = []
  32. for key, value in data.items():
  33. if value is None:
  34. continue
  35. key = str(key)
  36. try:
  37. result[key] = self.child.run_validation(value)
  38. except ValidationError as e:
  39. for detail in e.detail:
  40. errors.append(ErrorValueDetail(str(detail), detail.code, value))
  41. if errors:
  42. self.update_handled_errors_context(errors)
  43. return result
  44. class ForgivingDisallowRegexField(ForgivingFieldMixin, serializers.CharField):
  45. """ Disallow bad matches, set disallow_regex kwarg to use """
  46. def __init__(self, **kwargs):
  47. self.disallow_regex = kwargs.pop("disallow_regex", None)
  48. super().__init__(**kwargs)
  49. def to_internal_value(self, data):
  50. data = super().to_internal_value(data)
  51. if self.disallow_regex:
  52. pattern = re.compile(self.disallow_regex)
  53. if pattern.match(data) is None:
  54. error = ErrorValueDetail(
  55. "invalid characters in string", "invalid_data", data
  56. )
  57. self.update_handled_errors_context([error])
  58. return None
  59. return data
  60. class QueryStringField(serializers.ListField):
  61. """
  62. Can be given as unparsed string, dictionary, or list of tuples
  63. Should store as List[List[str]] where inner List is always of length 2
  64. """
  65. child = serializers.ListField(child=serializers.CharField())
  66. def to_internal_value(self, data):
  67. if isinstance(data, str) and data:
  68. qs = parse_qs(data)
  69. result = []
  70. for key, values in qs.items():
  71. for value in values:
  72. result.append([key, value])
  73. return result
  74. elif isinstance(data, dict):
  75. return [[key, value] for key, value in data.items()]
  76. elif isinstance(data, list):
  77. result = []
  78. for item in data:
  79. if isinstance(item, list) and len(item) >= 2:
  80. result.append(item[:2])
  81. return result
  82. return None