_methodical.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545
  1. # -*- test-case-name: automat._test.test_methodical -*-
  2. from __future__ import annotations
  3. import collections
  4. import sys
  5. from dataclasses import dataclass, field
  6. from functools import wraps
  7. from inspect import getfullargspec as getArgsSpec
  8. from itertools import count
  9. from typing import Any, Callable, Hashable, Iterable, TypeVar
  10. if sys.version_info < (3, 10):
  11. from typing_extensions import TypeAlias
  12. else:
  13. from typing import TypeAlias
  14. from ._core import Automaton, OutputTracer, Tracer, Transitioner
  15. from ._introspection import preserveName
  16. ArgSpec = collections.namedtuple(
  17. "ArgSpec",
  18. [
  19. "args",
  20. "varargs",
  21. "varkw",
  22. "defaults",
  23. "kwonlyargs",
  24. "kwonlydefaults",
  25. "annotations",
  26. ],
  27. )
  28. def _getArgSpec(func):
  29. """
  30. Normalize inspect.ArgSpec across python versions
  31. and convert mutable attributes to immutable types.
  32. :param Callable func: A function.
  33. :return: The function's ArgSpec.
  34. :rtype: ArgSpec
  35. """
  36. spec = getArgsSpec(func)
  37. return ArgSpec(
  38. args=tuple(spec.args),
  39. varargs=spec.varargs,
  40. varkw=spec.varkw,
  41. defaults=spec.defaults if spec.defaults else (),
  42. kwonlyargs=tuple(spec.kwonlyargs),
  43. kwonlydefaults=(
  44. tuple(spec.kwonlydefaults.items()) if spec.kwonlydefaults else ()
  45. ),
  46. annotations=tuple(spec.annotations.items()),
  47. )
  48. def _getArgNames(spec):
  49. """
  50. Get the name of all arguments defined in a function signature.
  51. The name of * and ** arguments is normalized to "*args" and "**kwargs".
  52. :param ArgSpec spec: A function to interrogate for a signature.
  53. :return: The set of all argument names in `func`s signature.
  54. :rtype: Set[str]
  55. """
  56. return set(
  57. spec.args
  58. + spec.kwonlyargs
  59. + (("*args",) if spec.varargs else ())
  60. + (("**kwargs",) if spec.varkw else ())
  61. + spec.annotations
  62. )
  63. def _keywords_only(f):
  64. """
  65. Decorate a function so all its arguments must be passed by keyword.
  66. A useful utility for decorators that take arguments so that they don't
  67. accidentally get passed the thing they're decorating as their first
  68. argument.
  69. Only works for methods right now.
  70. """
  71. @wraps(f)
  72. def g(self, **kw):
  73. return f(self, **kw)
  74. return g
  75. @dataclass(frozen=True)
  76. class MethodicalState(object):
  77. """
  78. A state for a L{MethodicalMachine}.
  79. """
  80. machine: MethodicalMachine = field(repr=False)
  81. method: Callable[..., Any] = field()
  82. serialized: bool = field(repr=False)
  83. def upon(
  84. self,
  85. input: MethodicalInput,
  86. enter: MethodicalState | None = None,
  87. outputs: Iterable[MethodicalOutput] | None = None,
  88. collector: Callable[[Iterable[T]], object] = list,
  89. ) -> None:
  90. """
  91. Declare a state transition within the L{MethodicalMachine} associated
  92. with this L{MethodicalState}: upon the receipt of the `input`, enter
  93. the `state`, emitting each output in `outputs`.
  94. @param input: The input triggering a state transition.
  95. @param enter: The resulting state.
  96. @param outputs: The outputs to be triggered as a result of the declared
  97. state transition.
  98. @param collector: The function to be used when collecting output return
  99. values.
  100. @raises TypeError: if any of the `outputs` signatures do not match the
  101. `inputs` signature.
  102. @raises ValueError: if the state transition from `self` via `input` has
  103. already been defined.
  104. """
  105. if enter is None:
  106. enter = self
  107. if outputs is None:
  108. outputs = []
  109. inputArgs = _getArgNames(input.argSpec)
  110. for output in outputs:
  111. outputArgs = _getArgNames(output.argSpec)
  112. if not outputArgs.issubset(inputArgs):
  113. raise TypeError(
  114. "method {input} signature {inputSignature} "
  115. "does not match output {output} "
  116. "signature {outputSignature}".format(
  117. input=input.method.__name__,
  118. output=output.method.__name__,
  119. inputSignature=getArgsSpec(input.method),
  120. outputSignature=getArgsSpec(output.method),
  121. )
  122. )
  123. self.machine._oneTransition(self, input, enter, outputs, collector)
  124. def _name(self) -> str:
  125. return self.method.__name__
  126. def _transitionerFromInstance(
  127. oself: object,
  128. symbol: str,
  129. automaton: Automaton[MethodicalState, MethodicalInput, MethodicalOutput],
  130. ) -> Transitioner[MethodicalState, MethodicalInput, MethodicalOutput]:
  131. """
  132. Get a L{Transitioner}
  133. """
  134. transitioner = getattr(oself, symbol, None)
  135. if transitioner is None:
  136. transitioner = Transitioner(
  137. automaton,
  138. automaton.initialState,
  139. )
  140. setattr(oself, symbol, transitioner)
  141. return transitioner
# Bytecode reference for assertNoCode: a function whose body is a bare
# "pass".  assertNoCode compares other functions' co_code against this
# function's, so do not add a docstring or any statement to it -- only
# comments, which do not affect co_code.
def _empty():
    pass
# Bytecode reference for assertNoCode: a function whose body is only a
# docstring.  assertNoCode compares other functions' co_code against this
# function's, so keep the body exactly one docstring and nothing else.
def _docstring():
    """docstring"""
  146. def assertNoCode(f: Callable[..., Any]) -> None:
  147. # The function body must be empty, i.e. "pass" or "return None", which
  148. # both yield the same bytecode: LOAD_CONST (None), RETURN_VALUE. We also
  149. # accept functions with only a docstring, which yields slightly different
  150. # bytecode, because the "None" is put in a different constant slot.
  151. # Unfortunately, this does not catch function bodies that return a
  152. # constant value, e.g. "return 1", because their code is identical to a
  153. # "return None". They differ in the contents of their constant table, but
  154. # checking that would require us to parse the bytecode, find the index
  155. # being returned, then making sure the table has a None at that index.
  156. if f.__code__.co_code not in (_empty.__code__.co_code, _docstring.__code__.co_code):
  157. raise ValueError("function body must be empty")
  158. def _filterArgs(args, kwargs, inputSpec, outputSpec):
  159. """
  160. Filter out arguments that were passed to input that output won't accept.
  161. :param tuple args: The *args that input received.
  162. :param dict kwargs: The **kwargs that input received.
  163. :param ArgSpec inputSpec: The input's arg spec.
  164. :param ArgSpec outputSpec: The output's arg spec.
  165. :return: The args and kwargs that output will accept.
  166. :rtype: Tuple[tuple, dict]
  167. """
  168. named_args = tuple(zip(inputSpec.args[1:], args))
  169. if outputSpec.varargs:
  170. # Only return all args if the output accepts *args.
  171. return_args = args
  172. else:
  173. # Filter out arguments that don't appear
  174. # in the output's method signature.
  175. return_args = [v for n, v in named_args if n in outputSpec.args]
  176. # Get any of input's default arguments that were not passed.
  177. passed_arg_names = tuple(kwargs)
  178. for name, value in named_args:
  179. passed_arg_names += (name, value)
  180. defaults = zip(inputSpec.args[::-1], inputSpec.defaults[::-1])
  181. full_kwargs = {n: v for n, v in defaults if n not in passed_arg_names}
  182. full_kwargs.update(kwargs)
  183. if outputSpec.varkw:
  184. # Only pass all kwargs if the output method accepts **kwargs.
  185. return_kwargs = full_kwargs
  186. else:
  187. # Filter out names that the output method does not accept.
  188. all_accepted_names = outputSpec.args[1:] + outputSpec.kwonlyargs
  189. return_kwargs = {
  190. n: v for n, v in full_kwargs.items() if n in all_accepted_names
  191. }
  192. return return_args, return_kwargs
  193. T = TypeVar("T")
  194. R = TypeVar("R")
  195. @dataclass(eq=False)
  196. class MethodicalInput(object):
  197. """
  198. An input for a L{MethodicalMachine}.
  199. """
  200. automaton: Automaton[MethodicalState, MethodicalInput, MethodicalOutput] = field(
  201. repr=False
  202. )
  203. method: Callable[..., Any] = field()
  204. symbol: str = field(repr=False)
  205. collectors: dict[MethodicalState, Callable[[Iterable[T]], R]] = field(
  206. default_factory=dict, repr=False
  207. )
  208. argSpec: ArgSpec = field(init=False, repr=False)
  209. def __post_init__(self) -> None:
  210. self.argSpec = _getArgSpec(self.method)
  211. assertNoCode(self.method)
  212. def __get__(self, oself: object, type: None = None) -> object:
  213. """
  214. Return a function that takes no arguments and returns values returned
  215. by output functions produced by the given L{MethodicalInput} in
  216. C{oself}'s current state.
  217. """
  218. transitioner = _transitionerFromInstance(oself, self.symbol, self.automaton)
  219. @preserveName(self.method)
  220. @wraps(self.method)
  221. def doInput(*args: object, **kwargs: object) -> object:
  222. self.method(oself, *args, **kwargs)
  223. previousState = transitioner._state
  224. (outputs, outTracer) = transitioner.transition(self)
  225. collector = self.collectors[previousState]
  226. values = []
  227. for output in outputs:
  228. if outTracer is not None:
  229. outTracer(output)
  230. a, k = _filterArgs(args, kwargs, self.argSpec, output.argSpec)
  231. value = output(oself, *a, **k)
  232. values.append(value)
  233. return collector(values)
  234. return doInput
  235. def _name(self) -> str:
  236. return self.method.__name__
  237. @dataclass(frozen=True)
  238. class MethodicalOutput(object):
  239. """
  240. An output for a L{MethodicalMachine}.
  241. """
  242. machine: MethodicalMachine = field(repr=False)
  243. method: Callable[..., Any]
  244. argSpec: ArgSpec = field(init=False, repr=False, compare=False)
  245. def __post_init__(self) -> None:
  246. self.__dict__["argSpec"] = _getArgSpec(self.method)
  247. def __get__(self, oself, type=None):
  248. """
  249. Outputs are private, so raise an exception when we attempt to get one.
  250. """
  251. raise AttributeError(
  252. "{cls}.{method} is a state-machine output method; "
  253. "to produce this output, call an input method instead.".format(
  254. cls=type.__name__, method=self.method.__name__
  255. )
  256. )
  257. def __call__(self, oself, *args, **kwargs):
  258. """
  259. Call the underlying method.
  260. """
  261. return self.method(oself, *args, **kwargs)
  262. def _name(self) -> str:
  263. return self.method.__name__
  264. StringOutputTracer = Callable[[str], None]
  265. StringTracer: TypeAlias = "Callable[[str, str, str], StringOutputTracer | None]"
  266. def wrapTracer(
  267. wrapped: StringTracer | None,
  268. ) -> Tracer[MethodicalState, MethodicalInput, MethodicalOutput] | None:
  269. if wrapped is None:
  270. return None
  271. def tracer(
  272. state: MethodicalState,
  273. input: MethodicalInput,
  274. output: MethodicalState,
  275. ) -> OutputTracer[MethodicalOutput] | None:
  276. result = wrapped(state._name(), input._name(), output._name())
  277. if result is not None:
  278. return lambda out: result(out._name())
  279. return None
  280. return tracer
  281. @dataclass(eq=False)
  282. class MethodicalTracer(object):
  283. automaton: Automaton[MethodicalState, MethodicalInput, MethodicalOutput] = field(
  284. repr=False
  285. )
  286. symbol: str = field(repr=False)
  287. def __get__(
  288. self, oself: object, type: object = None
  289. ) -> Callable[[StringTracer], None]:
  290. transitioner = _transitionerFromInstance(oself, self.symbol, self.automaton)
  291. def setTrace(tracer: StringTracer | None) -> None:
  292. transitioner.setTrace(wrapTracer(tracer))
  293. return setTrace
  294. counter = count()
  295. def gensym():
  296. """
  297. Create a unique Python identifier.
  298. """
  299. return "_symbol_" + str(next(counter))
  300. class MethodicalMachine(object):
  301. """
  302. A L{MethodicalMachine} is an interface to an L{Automaton} that uses methods
  303. on a class.
  304. """
  305. def __init__(self):
  306. self._automaton = Automaton()
  307. self._reducers = {}
  308. self._symbol = gensym()
  309. def __get__(self, oself, type=None):
  310. """
  311. L{MethodicalMachine} is an implementation detail for setting up
  312. class-level state; applications should never need to access it on an
  313. instance.
  314. """
  315. if oself is not None:
  316. raise AttributeError("MethodicalMachine is an implementation detail.")
  317. return self
  318. @_keywords_only
  319. def state(
  320. self, initial: bool = False, terminal: bool = False, serialized: Hashable = None
  321. ):
  322. """
  323. Declare a state, possibly an initial state or a terminal state.
  324. This is a decorator for methods, but it will modify the method so as
  325. not to be callable any more.
  326. @param initial: is this state the initial state? Only one state on
  327. this L{automat.MethodicalMachine} may be an initial state; more
  328. than one is an error.
  329. @param terminal: Is this state a terminal state? i.e. a state that the
  330. machine can end up in? (This is purely informational at this
  331. point.)
  332. @param serialized: a serializable value to be used to represent this
  333. state to external systems. This value should be hashable; L{str}
  334. is a good type to use.
  335. """
  336. def decorator(stateMethod):
  337. state = MethodicalState(
  338. machine=self, method=stateMethod, serialized=serialized
  339. )
  340. if initial:
  341. self._automaton.initialState = state
  342. return state
  343. return decorator
  344. @_keywords_only
  345. def input(self):
  346. """
  347. Declare an input.
  348. This is a decorator for methods.
  349. """
  350. def decorator(inputMethod):
  351. return MethodicalInput(
  352. automaton=self._automaton, method=inputMethod, symbol=self._symbol
  353. )
  354. return decorator
  355. @_keywords_only
  356. def output(self):
  357. """
  358. Declare an output.
  359. This is a decorator for methods.
  360. This method will be called when the state machine transitions to this
  361. state as specified in the decorated `output` method.
  362. """
  363. def decorator(outputMethod):
  364. return MethodicalOutput(machine=self, method=outputMethod)
  365. return decorator
  366. def _oneTransition(self, startState, inputToken, endState, outputTokens, collector):
  367. """
  368. See L{MethodicalState.upon}.
  369. """
  370. # FIXME: tests for all of this (some of it is wrong)
  371. # if not isinstance(startState, MethodicalState):
  372. # raise NotImplementedError("start state {} isn't a state"
  373. # .format(startState))
  374. # if not isinstance(inputToken, MethodicalInput):
  375. # raise NotImplementedError("start state {} isn't an input"
  376. # .format(inputToken))
  377. # if not isinstance(endState, MethodicalState):
  378. # raise NotImplementedError("end state {} isn't a state"
  379. # .format(startState))
  380. # for output in outputTokens:
  381. # if not isinstance(endState, MethodicalState):
  382. # raise NotImplementedError("output state {} isn't a state"
  383. # .format(endState))
  384. self._automaton.addTransition(
  385. startState, inputToken, endState, tuple(outputTokens)
  386. )
  387. inputToken.collectors[startState] = collector
  388. @_keywords_only
  389. def serializer(self):
  390. """ """
  391. def decorator(decoratee):
  392. @wraps(decoratee)
  393. def serialize(oself):
  394. transitioner = _transitionerFromInstance(
  395. oself, self._symbol, self._automaton
  396. )
  397. return decoratee(oself, transitioner._state.serialized)
  398. return serialize
  399. return decorator
  400. @_keywords_only
  401. def unserializer(self):
  402. """ """
  403. def decorator(decoratee):
  404. @wraps(decoratee)
  405. def unserialize(oself, *args, **kwargs):
  406. state = decoratee(oself, *args, **kwargs)
  407. mapping = {}
  408. for eachState in self._automaton.states():
  409. mapping[eachState.serialized] = eachState
  410. transitioner = _transitionerFromInstance(
  411. oself, self._symbol, self._automaton
  412. )
  413. transitioner._state = mapping[state]
  414. return None # it's on purpose
  415. return unserialize
  416. return decorator
  417. @property
  418. def _setTrace(self) -> MethodicalTracer:
  419. return MethodicalTracer(self._automaton, self._symbol)
  420. def asDigraph(self):
  421. """
  422. Generate a L{graphviz.Digraph} that represents this machine's
  423. states and transitions.
  424. @return: L{graphviz.Digraph} object; for more information, please
  425. see the documentation for
  426. U{graphviz<https://graphviz.readthedocs.io/>}
  427. """
  428. from ._visualize import makeDigraph
  429. return makeDigraph(
  430. self._automaton,
  431. stateAsString=lambda state: state.method.__name__,
  432. inputAsString=lambda input: input.method.__name__,
  433. outputAsString=lambda output: output.method.__name__,
  434. )