visitor.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. """Generic visitor pattern implementation for Python objects."""
  2. import enum
  3. class Visitor(object):
  4. defaultStop = False
  5. @classmethod
  6. def _register(celf, clazzes_attrs):
  7. assert celf != Visitor, "Subclass Visitor instead."
  8. if "_visitors" not in celf.__dict__:
  9. celf._visitors = {}
  10. def wrapper(method):
  11. assert method.__name__ == "visit"
  12. for clazzes, attrs in clazzes_attrs:
  13. if type(clazzes) != tuple:
  14. clazzes = (clazzes,)
  15. if type(attrs) == str:
  16. attrs = (attrs,)
  17. for clazz in clazzes:
  18. _visitors = celf._visitors.setdefault(clazz, {})
  19. for attr in attrs:
  20. assert attr not in _visitors, (
  21. "Oops, class '%s' has visitor function for '%s' defined already."
  22. % (clazz.__name__, attr)
  23. )
  24. _visitors[attr] = method
  25. return None
  26. return wrapper
  27. @classmethod
  28. def register(celf, clazzes):
  29. if type(clazzes) != tuple:
  30. clazzes = (clazzes,)
  31. return celf._register([(clazzes, (None,))])
  32. @classmethod
  33. def register_attr(celf, clazzes, attrs):
  34. clazzes_attrs = []
  35. if type(clazzes) != tuple:
  36. clazzes = (clazzes,)
  37. if type(attrs) == str:
  38. attrs = (attrs,)
  39. for clazz in clazzes:
  40. clazzes_attrs.append((clazz, attrs))
  41. return celf._register(clazzes_attrs)
  42. @classmethod
  43. def register_attrs(celf, clazzes_attrs):
  44. return celf._register(clazzes_attrs)
  45. @classmethod
  46. def _visitorsFor(celf, thing, _default={}):
  47. typ = type(thing)
  48. for celf in celf.mro():
  49. _visitors = getattr(celf, "_visitors", None)
  50. if _visitors is None:
  51. break
  52. for base in typ.mro():
  53. m = celf._visitors.get(base, None)
  54. if m is not None:
  55. return m
  56. return _default
  57. def visitObject(self, obj, *args, **kwargs):
  58. """Called to visit an object. This function loops over all non-private
  59. attributes of the objects and calls any user-registered (via
  60. @register_attr() or @register_attrs()) visit() functions.
  61. If there is no user-registered visit function, of if there is and it
  62. returns True, or it returns None (or doesn't return anything) and
  63. visitor.defaultStop is False (default), then the visitor will proceed
  64. to call self.visitAttr()"""
  65. keys = sorted(vars(obj).keys())
  66. _visitors = self._visitorsFor(obj)
  67. defaultVisitor = _visitors.get("*", None)
  68. for key in keys:
  69. if key[0] == "_":
  70. continue
  71. value = getattr(obj, key)
  72. visitorFunc = _visitors.get(key, defaultVisitor)
  73. if visitorFunc is not None:
  74. ret = visitorFunc(self, obj, key, value, *args, **kwargs)
  75. if ret == False or (ret is None and self.defaultStop):
  76. continue
  77. self.visitAttr(obj, key, value, *args, **kwargs)
  78. def visitAttr(self, obj, attr, value, *args, **kwargs):
  79. """Called to visit an attribute of an object."""
  80. self.visit(value, *args, **kwargs)
  81. def visitList(self, obj, *args, **kwargs):
  82. """Called to visit any value that is a list."""
  83. for value in obj:
  84. self.visit(value, *args, **kwargs)
  85. def visitDict(self, obj, *args, **kwargs):
  86. """Called to visit any value that is a dictionary."""
  87. for value in obj.values():
  88. self.visit(value, *args, **kwargs)
  89. def visitLeaf(self, obj, *args, **kwargs):
  90. """Called to visit any value that is not an object, list,
  91. or dictionary."""
  92. pass
  93. def visit(self, obj, *args, **kwargs):
  94. """This is the main entry to the visitor. The visitor will visit object
  95. obj.
  96. The visitor will first determine if there is a registered (via
  97. @register()) visit function for the type of object. If there is, it
  98. will be called, and (visitor, obj, *args, **kwargs) will be passed to
  99. the user visit function.
  100. If there is no user-registered visit function, of if there is and it
  101. returns True, or it returns None (or doesn't return anything) and
  102. visitor.defaultStop is False (default), then the visitor will proceed
  103. to dispatch to one of self.visitObject(), self.visitList(),
  104. self.visitDict(), or self.visitLeaf() (any of which can be overriden in
  105. a subclass)."""
  106. visitorFunc = self._visitorsFor(obj).get(None, None)
  107. if visitorFunc is not None:
  108. ret = visitorFunc(self, obj, *args, **kwargs)
  109. if ret == False or (ret is None and self.defaultStop):
  110. return
  111. if hasattr(obj, "__dict__") and not isinstance(obj, enum.Enum):
  112. self.visitObject(obj, *args, **kwargs)
  113. elif isinstance(obj, list):
  114. self.visitList(obj, *args, **kwargs)
  115. elif isinstance(obj, dict):
  116. self.visitDict(obj, *args, **kwargs)
  117. else:
  118. self.visitLeaf(obj, *args, **kwargs)