context.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import logging
  2. import re
  3. from typing import Optional, Dict, Union, Any
  # Module-level logger named after this module.
  logger = logging.getLogger(__name__)

  # Shared empty mapping used as the "nothing configured" value for the format tables below.
  # NOTE(review): contexts alias this single dict rather than copying it, so it must never be
  # mutated through an instance attribute -- that would leak entries into every context.
  _empty_map = {}


  # pylint: disable=too-many-instance-attributes
  class BaseQueryContext:
      """Hold per-query settings and resolve the read format that applies to each column.

      Formats come from two sources:

      * ``query_formats`` -- keyed by type name, applied to every column of a matching type.
      * ``column_formats`` -- keyed by column name; each value is either a single format string
        used for that column regardless of type, or a ``{type_name: format}`` dict consulted
        for that column only.

      Type-name keys may use ``*`` as a wildcard for any run of characters; every other
      character is matched literally. Matching is case-insensitive and anchored only at the
      start of the type name, so a key also matches any type name it is a prefix of.
      """

      def __init__(self,
                   settings: Optional[Dict[str, Any]] = None,
                   query_formats: Optional[Dict[str, str]] = None,
                   column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None,
                   encoding: Optional[str] = None,
                   use_extended_dtypes: bool = False,
                   use_numpy: bool = False):
          """
          :param settings: Settings for the query, stored as-is (a falsy value becomes ``{}``).
          :param query_formats: Mapping of (wildcard) type name -> format for all columns.
          :param column_formats: Mapping of column name -> either a format string, or a
            ``{(wildcard) type name: format}`` dict that applies to that column only.
          :param encoding: Optional encoding name; stored unchanged, not used by this class.
          :param use_extended_dtypes: Flag stored unchanged; not used by this class.
          :param use_numpy: Flag stored unchanged; not used by this class.
          """
          self.settings = settings or {}
          if query_formats is None:
              self.type_formats = _empty_map
          else:
              self.type_formats = self._compile_type_formats(query_formats)
          if column_formats is None:
              self.col_simple_formats = _empty_map
              self.col_type_formats = _empty_map
          else:
              # Split per-column settings in one pass: plain strings apply to the whole column,
              # dicts are per-type overrides for that column and are compiled like query_formats.
              self.col_simple_formats = {}
              self.col_type_formats = {}
              for col_name, col_fmt in column_formats.items():
                  if isinstance(col_fmt, str):
                      self.col_simple_formats[col_name] = col_fmt
                  else:
                      self.col_type_formats[col_name] = self._compile_type_formats(col_fmt)
          # The raw, uncompiled inputs are kept alongside the compiled lookup tables.
          self.query_formats = query_formats or {}
          self.column_formats = column_formats or {}
          self.encoding = encoding
          self.use_numpy = use_numpy
          self.use_extended_dtypes = use_extended_dtypes
          # Formats for the column currently being processed; replaced by start_column().
          self._active_col_fmt = None
          self._active_col_type_fmts = _empty_map

      @staticmethod
      def _compile_type_formats(type_formats: Dict[str, str]) -> Dict[re.Pattern, str]:
          """Compile ``{type name: format}`` into ``{case-insensitive pattern: format}``.

          ``*`` in a type name matches any run of characters. Every other character, including
          regex metacharacters such as the parentheses in ``Nullable(String)``, is literal.
          """
          # re.escape turns '*' into r'\*'; swap only that back to a wildcard.
          return {re.compile(re.escape(type_name).replace(r'\*', '.*'), re.IGNORECASE): fmt
                  for type_name, fmt in type_formats.items()}

      def start_column(self, name: str) -> None:
          """Make ``name`` the active column so later active_fmt() calls use its formats."""
          self._active_col_fmt = self.col_simple_formats.get(name)
          self._active_col_type_fmts = self.col_type_formats.get(name, _empty_map)

      def active_fmt(self, ch_type: str) -> Optional[str]:
          """Return the format for a value of type ``ch_type`` in the active column, or None.

          Precedence: the active column's plain format string, then the active column's
          per-type formats, then the query-wide type formats. Within each table the first
          matching pattern in insertion order wins. A falsy (e.g. empty) column format string
          is treated as unset.
          """
          if self._active_col_fmt:
              return self._active_col_fmt
          for type_formats in (self._active_col_type_fmts, self.type_formats):
              for type_pattern, fmt in type_formats.items():
                  if type_pattern.match(ch_type):
                      return fmt
          return None