context.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
import logging
import re
from datetime import datetime, tzinfo
from typing import Optional, Dict, Union, Any

import pytz
  6. logger = logging.getLogger(__name__)
  7. _empty_map = {}
  8. # pylint: disable=too-many-instance-attributes
  9. class BaseQueryContext:
  10. local_tz: pytz.timezone
  11. def __init__(self,
  12. settings: Optional[Dict[str, Any]] = None,
  13. query_formats: Optional[Dict[str, str]] = None,
  14. column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None,
  15. encoding: Optional[str] = None,
  16. use_extended_dtypes: bool = False,
  17. use_numpy: bool = False):
  18. self.settings = settings or {}
  19. if query_formats is None:
  20. self.type_formats = _empty_map
  21. else:
  22. self.type_formats = {re.compile(type_name.replace('*', '.*'), re.IGNORECASE): fmt
  23. for type_name, fmt in query_formats.items()}
  24. if column_formats is None:
  25. self.col_simple_formats = _empty_map
  26. self.col_type_formats = _empty_map
  27. else:
  28. self.col_simple_formats = {col_name: fmt for col_name, fmt in column_formats.items() if
  29. isinstance(fmt, str)}
  30. self.col_type_formats = {}
  31. for col_name, fmt in column_formats.items():
  32. if not isinstance(fmt, str):
  33. self.col_type_formats[col_name] = {re.compile(type_name.replace('*', '.*'), re.IGNORECASE): fmt
  34. for type_name, fmt in fmt.items()}
  35. self.query_formats = query_formats or {}
  36. self.column_formats = column_formats or {}
  37. self.encoding = encoding
  38. self.use_numpy = use_numpy
  39. self.use_extended_dtypes = use_extended_dtypes
  40. self._active_col_fmt = None
  41. self._active_col_type_fmts = _empty_map
  42. def start_column(self, name: str):
  43. self._active_col_fmt = self.col_simple_formats.get(name)
  44. self._active_col_type_fmts = self.col_type_formats.get(name, _empty_map)
  45. def active_fmt(self, ch_type):
  46. if self._active_col_fmt:
  47. return self._active_col_fmt
  48. for type_pattern, fmt in self._active_col_type_fmts.items():
  49. if type_pattern.match(ch_type):
  50. return fmt
  51. for type_pattern, fmt in self.type_formats.items():
  52. if type_pattern.match(ch_type):
  53. return fmt
  54. return None
  55. def _init_context_cls():
  56. local_tz = datetime.now().astimezone().tzinfo
  57. if local_tz.tzname(datetime.now()) in ('UTC', 'GMT', 'Universal', 'GMT-0', 'Zulu', 'Greenwich'):
  58. local_tz = pytz.UTC
  59. BaseQueryContext.local_tz = local_tz
  60. _init_context_cls()