deepreload.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. # -*- coding: utf-8 -*-
  2. """
  3. Provides a reload() function that acts recursively.
  4. Python's normal :func:`python:reload` function only reloads the module that it's
  5. passed. The :func:`reload` function in this module also reloads everything
  6. imported from that module, which is useful when you're changing files deep
  7. inside a package.
  8. To use this as your default reload function, type this::
  9. import builtins
  10. from IPython.lib import deepreload
  11. builtins.reload = deepreload.reload
  12. A reference to the original :func:`python:reload` is stored in this module as
  13. :data:`original_reload`, so you can restore it later.
  14. This code is almost entirely based on knee.py, which is a Python
  15. re-implementation of hierarchical module import.
  16. """
  17. #*****************************************************************************
  18. # Copyright (C) 2001 Nathaniel Gray <n8gray@caltech.edu>
  19. #
  20. # Distributed under the terms of the BSD License. The full license is in
  21. # the file COPYING, distributed as part of this software.
  22. #*****************************************************************************
  23. import builtins as builtin_mod
  24. from contextlib import contextmanager
  25. import importlib
  26. import sys
  27. from types import ModuleType
  28. from warnings import warn
  29. import types
  # The interpreter's ``__import__`` as it was when this module was loaded.
  original_import = getattr(builtin_mod, '__import__')


  @contextmanager
  def replace_import_hook(new_import):
      """Temporarily install ``new_import`` as ``builtins.__import__``.

      Usage::

          with replace_import_hook(my_hook):
              ...  # import statements executed here go through my_hook

      Whatever hook was active on entry is put back on exit, whether the
      ``with`` body finishes normally or raises.
      """
      previous_hook = getattr(builtin_mod, '__import__')
      setattr(builtin_mod, '__import__', new_import)
      try:
          yield
      finally:
          setattr(builtin_mod, '__import__', previous_hook)
  def get_parent(globals, level):
      """Locate the package an import is being performed from.

      parent, name = get_parent(globals, level)

      If ``globals`` comes from the module foo.bar.bat (not itself a
      package), this returns the ``sys.modules`` entry for foo.bar. If
      ``globals`` is from a package's ``__init__.py``, the package's own
      entry is returned. Each level above 1 strips one more trailing
      component from the package name.

      Returns ``(None, '')`` when ``level`` is 0, ``globals`` is not a dict,
      or the importing module is not inside a package. When ``__package__``
      is missing or None in ``globals`` it is computed and written back.

      Raises:
          ValueError: ``__package__`` is not a string, a relative import is
              attempted outside a package, or the import climbs above the
              top-level package.
          SystemError: a relative import (``level >= 1``) needs a parent
              package that is not in ``sys.modules``. For ``level < 1`` a
              warning is emitted instead and ``parent`` is None.
      """
      orig_level = level
      if not level or not isinstance(globals, dict):
          return None, ''
      pkgname = globals.get('__package__', None)
      if pkgname is not None:
          # __package__ is set, so use it
          if not hasattr(pkgname, 'rindex'):
              raise ValueError('__package__ set to non-string')
          if len(pkgname) == 0:
              if level > 0:
                  raise ValueError('Attempted relative import in non-package')
              return None, ''
          name = pkgname
      else:
          # __package__ not set, so figure it out and set it
          if '__name__' not in globals:
              return None, ''
          modname = globals['__name__']
          if '__path__' in globals:
              # __path__ is set, so modname is already the package name
              globals['__package__'] = name = modname
          else:
              # Normal module, so work out the package name if any
              lastdot = modname.rfind('.')
              if lastdot < 0 < level:
                  raise ValueError("Attempted relative import in non-package")
              if lastdot < 0:
                  globals['__package__'] = None
                  return None, ''
              globals['__package__'] = name = modname[:lastdot]

      # Every level beyond the first strips one more trailing component.
      dot = len(name)
      for _ in range(level, 1, -1):
          try:
              dot = name.rindex('.', 0, dot)
          except ValueError as e:
              raise ValueError("attempted relative import beyond top-level "
                               "package") from e
      name = name[:dot]

      # A missing sys.modules entry can only surface as KeyError; don't
      # intercept unrelated exceptions such as KeyboardInterrupt.
      try:
          parent = sys.modules[name]
      except KeyError as e:
          if orig_level < 1:
              warn("Parent module '%.200s' not found while handling absolute "
                   "import" % name)
              parent = None
          else:
              raise SystemError("Parent module '%.200s' not loaded, cannot "
                                "perform relative import" % name) from e

      # We expect, but can't guarantee, if parent != None, that:
      # - parent.__name__ == name
      # - parent.__dict__ is globals
      # If this is violated... Who cares?
      return parent, name
  def load_next(mod, altmod, name, buf):
      """Import the first dotted component of ``name`` below ``mod``.

      mod, name, buf = load_next(mod, altmod, name, buf)

      ``buf`` is the dotted module name resolved so far; the component just
      imported is appended to it. Returns a triple of the imported module,
      the remainder of ``name`` after its first dot (None when there is no
      dot), and the updated ``buf``.

      ``altmod`` is either None or the same object as ``mod``. When it
      differs from ``mod`` and the lookup under ``mod`` yields None, the
      component is retried as a top-level name.

      Raises:
          ValueError: ``name`` starts with a dot.
          ImportError: the component could not be found.
      """
      if not len(name):
          # completely empty module name should only happen in
          # 'from . import' (or '__import__("")')
          return mod, None, buf

      first, sep, rest = name.partition('.')
      if sep and not first:
          raise ValueError('Empty module name')
      remainder = rest if sep else None

      qualified = first if buf == '' else buf + '.' + first
      result = import_submodule(mod, first, qualified)
      if result is None and mod != altmod:
          # Retry as a top-level module; on success the accumulated prefix
          # is dropped.
          result = import_submodule(altmod, first, first)
          if result is not None:
              qualified = first

      if result is None:
          raise ImportError("No module named %.200s" % name)
      return result, remainder, qualified
  # Need to keep track of what we've already reloaded to prevent cyclic evil.
  # Maps each dotted module name handled during the current deep reload to 1.
  found_now = {}


  def import_submodule(mod, subname, fullname):
      """Reload (or import for the first time) ``fullname``; attach it to ``mod``.

      m = import_submodule(mod, subname, fullname)

      Requires: if ``mod`` is None, ``subname == fullname``; otherwise
      ``mod.__name__ + "." + subname == fullname``.

      A name already recorded in ``found_now`` that is still present in
      ``sys.modules`` is returned as-is. Each module is therefore reloaded
      at most once per deep reload, which also stops import cycles.
      Otherwise the existing module is reloaded, or imported if it was never
      loaded. The result is then bound as attribute ``subname`` on ``mod``.

      Any exception raised while reloading/importing propagates. Before it
      does, the previous module object (if there was one) is put back into
      ``sys.modules``.
      """
      global found_now
      if fullname in found_now and fullname in sys.modules:
          return sys.modules[fullname]

      print('Reloading', fullname)
      found_now[fullname] = 1
      oldm = sys.modules.get(fullname, None)
      try:
          if oldm is not None:
              m = importlib.reload(oldm)
          else:
              # Import by fully qualified name: import_module() ignores its
              # ``package`` argument unless the name starts with a dot.
              # import_module(subname, mod) would therefore load a top-level
              # module called ``subname`` instead of the submodule of ``mod``.
              m = importlib.import_module(fullname)
      except BaseException:
          # A failed load may have removed the name from sys.modules.
          # Put back the original module object before re-raising.
          if oldm is not None:
              sys.modules[fullname] = oldm
          raise

      add_submodule(mod, m, fullname, subname)
      return m


  def add_submodule(mod, submod, fullname, subname):
      """Bind ``submod`` as attribute ``subname`` of ``mod`` (mod.{subname} = submod).

      Does nothing when ``mod`` is None. If ``submod`` is None, the module
      registered in ``sys.modules`` under ``fullname`` is bound instead.
      """
      if mod is None:
          return  # Nothing to do here.
      if submod is None:
          submod = sys.modules[fullname]
      setattr(mod, subname, submod)
  def ensure_fromlist(mod, fromlist, buf, recursive):
      """Handle 'from module import a, b, c' imports.

      Only packages (modules with ``__path__``) are processed. Each name in
      ``fromlist`` that is not yet an attribute of ``mod`` is imported as the
      submodule ``buf + '.' + name``.

      A ``'*'`` entry is expanded to ``mod.__all__`` when that attribute
      exists. ``recursive`` is true while such an expansion is processed;
      further ``'*'`` entries are then skipped to avoid endless recursion.
      Items listed after ``'*'`` are still processed.

      Raises:
          TypeError: an item in ``fromlist`` is not a string.
      """
      if not hasattr(mod, '__path__'):
          return
      for item in fromlist:
          if not hasattr(item, 'rindex'):
              raise TypeError("Item in ``from list'' not a string")
          if item == '*':
              if recursive:
                  continue  # avoid endless recursion
              try:
                  public_names = mod.__all__
              except AttributeError:
                  continue
              # Import everything named in __all__, then carry on with the
              # remaining fromlist items.
              ensure_fromlist(mod, public_names, buf, 1)
          elif not hasattr(mod, item):
              import_submodule(mod, item, buf + '.' + item)
  def deep_import_hook(name, globals=None, locals=None, fromlist=None, level=-1):
      """Replacement for __import__() used while a deep reload is running.

      The enclosing package is resolved with ``get_parent`` from ``globals``
      and ``level``. The dotted ``name`` is then walked one component at a
      time with ``load_next``. Returns the first module of the chain for a
      plain import. When ``fromlist`` is non-empty, it is resolved against
      the last module of the chain, and that module is returned.
      ``locals`` is accepted for signature compatibility but is not used.

      Raises:
          ValueError: both the parent lookup and ``name`` were empty.
      """
      parent, buf = get_parent(globals, level)
      # A negative level passes altmod=None, which lets load_next retry a
      # component as a top-level name when the lookup under parent fails.
      if level < 0:
          fallback = None
      else:
          fallback = parent
      head, remaining, buf = load_next(parent, fallback, name, buf)

      tail = head
      while remaining:
          tail, remaining, buf = load_next(tail, tail, remaining, buf)

      # If tail is None, both get_parent and load_next found
      # an empty module name: someone called __import__("") or
      # doctored faulty bytecode
      if tail is None:
          raise ValueError('Empty module name')

      if fromlist:
          ensure_fromlist(tail, fromlist, buf, 0)
          return tail
      return head
  # Modules whose reload is in progress in deep_reload_hook (name -> module).
  modules_reloading = {}


  def deep_reload_hook(m):
      """Replacement for reload(): reload module ``m`` and return the result.

      The ``types`` module is returned untouched. If a reload of the same
      module name is already in progress, the module recorded for it is
      returned instead of starting another one.

      Raises:
          TypeError: ``m`` is not a module.
          ImportError: ``m.__name__`` is not in ``sys.modules``.

      Any exception from ``importlib.reload`` propagates after the original
      module object has been put back into ``sys.modules``.
      """
      # Hardcode this one as it would raise a NotImplementedError from the
      # bowels of Python and screw up the import machinery after.
      # unlike other imports the `exclude` list already in place is not enough.
      if m is types:
          return m
      if not isinstance(m, ModuleType):
          raise TypeError("reload() argument must be module")
      name = m.__name__
      if name not in sys.modules:
          raise ImportError("reload(): module %.200s not in sys.modules" % name)

      global modules_reloading
      # Use an explicit membership test instead of a bare ``except:`` around
      # the lookup, which would also intercept KeyboardInterrupt & co.
      if name in modules_reloading:
          return modules_reloading[name]
      modules_reloading[name] = m

      try:
          newm = importlib.reload(m)
      except BaseException:
          # A failed reload may have removed or replaced the entry.
          sys.modules[name] = m
          raise
      finally:
          modules_reloading.clear()
      return newm
  # Save the original hooks
  original_reload = importlib.reload


  # Replacement for reload()
  def reload(
      module,
      exclude=(
          *sys.builtin_module_names,
          "sys",
          "os.path",
          "builtins",
          "__main__",
          "numpy",
          "numpy._globals",
      ),
  ):
      """Recursively reload all modules used in the given module.

      Optionally takes a list of modules to exclude from reloading. The
      default exclude list contains the modules in
      ``sys.builtin_module_names`` plus sys, os.path, builtins, __main__,
      numpy and numpy._globals, to prevent, e.g., resetting display,
      exception, and io hooks.

      While the reload runs, ``builtins.__import__`` is replaced by
      ``deep_import_hook``. Modules imported along the way are reloaded as
      well, each at most once. Returns the reloaded module.
      """
      global found_now
      try:
          # Mark excluded names as already handled, so that import_submodule
          # reuses an already-loaded module of that name instead of
          # reloading it.
          for excluded in exclude:
              found_now[excluded] = 1
          with replace_import_hook(deep_import_hook):
              return deep_reload_hook(module)
      finally:
          # Always start the next deep reload from a clean slate, even when
          # registering the exclusions or the reload itself failed.
          found_now = {}