deepreload.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
# -*- coding: utf-8 -*-
"""
Provides a reload() function that acts recursively.

Python's normal :func:`python:reload` function only reloads the module that it's
passed. The :func:`reload` function in this module also reloads everything
imported from that module, which is useful when you're changing files deep
inside a package.

To use this as your default reload function, type this for Python 2::

    import __builtin__
    from IPython.lib import deepreload
    __builtin__.reload = deepreload.reload

Or this for Python 3::

    import builtins
    from IPython.lib import deepreload
    builtins.reload = deepreload.reload

A reference to the original :func:`python:reload` is stored in this module as
:data:`original_reload`, so you can restore it later.

This code is almost entirely based on knee.py, which is a Python
re-implementation of hierarchical module import.
"""
  21. from __future__ import print_function
  22. #*****************************************************************************
  23. # Copyright (C) 2001 Nathaniel Gray <n8gray@caltech.edu>
  24. #
  25. # Distributed under the terms of the BSD License. The full license is in
  26. # the file COPYING, distributed as part of this software.
  27. #*****************************************************************************
  28. from contextlib import contextmanager
  29. import imp
  30. import sys
  31. from types import ModuleType
  32. from warnings import warn
  33. from IPython.utils.py3compat import builtin_mod, builtin_mod_name
  34. original_import = builtin_mod.__import__
  35. @contextmanager
  36. def replace_import_hook(new_import):
  37. saved_import = builtin_mod.__import__
  38. builtin_mod.__import__ = new_import
  39. try:
  40. yield
  41. finally:
  42. builtin_mod.__import__ = saved_import
  43. def get_parent(globals, level):
  44. """
  45. parent, name = get_parent(globals, level)
  46. Return the package that an import is being performed in. If globals comes
  47. from the module foo.bar.bat (not itself a package), this returns the
  48. sys.modules entry for foo.bar. If globals is from a package's __init__.py,
  49. the package's entry in sys.modules is returned.
  50. If globals doesn't come from a package or a module in a package, or a
  51. corresponding entry is not found in sys.modules, None is returned.
  52. """
  53. orig_level = level
  54. if not level or not isinstance(globals, dict):
  55. return None, ''
  56. pkgname = globals.get('__package__', None)
  57. if pkgname is not None:
  58. # __package__ is set, so use it
  59. if not hasattr(pkgname, 'rindex'):
  60. raise ValueError('__package__ set to non-string')
  61. if len(pkgname) == 0:
  62. if level > 0:
  63. raise ValueError('Attempted relative import in non-package')
  64. return None, ''
  65. name = pkgname
  66. else:
  67. # __package__ not set, so figure it out and set it
  68. if '__name__' not in globals:
  69. return None, ''
  70. modname = globals['__name__']
  71. if '__path__' in globals:
  72. # __path__ is set, so modname is already the package name
  73. globals['__package__'] = name = modname
  74. else:
  75. # Normal module, so work out the package name if any
  76. lastdot = modname.rfind('.')
  77. if lastdot < 0 < level:
  78. raise ValueError("Attempted relative import in non-package")
  79. if lastdot < 0:
  80. globals['__package__'] = None
  81. return None, ''
  82. globals['__package__'] = name = modname[:lastdot]
  83. dot = len(name)
  84. for x in range(level, 1, -1):
  85. try:
  86. dot = name.rindex('.', 0, dot)
  87. except ValueError:
  88. raise ValueError("attempted relative import beyond top-level "
  89. "package")
  90. name = name[:dot]
  91. try:
  92. parent = sys.modules[name]
  93. except:
  94. if orig_level < 1:
  95. warn("Parent module '%.200s' not found while handling absolute "
  96. "import" % name)
  97. parent = None
  98. else:
  99. raise SystemError("Parent module '%.200s' not loaded, cannot "
  100. "perform relative import" % name)
  101. # We expect, but can't guarantee, if parent != None, that:
  102. # - parent.__name__ == name
  103. # - parent.__dict__ is globals
  104. # If this is violated... Who cares?
  105. return parent, name
  106. def load_next(mod, altmod, name, buf):
  107. """
  108. mod, name, buf = load_next(mod, altmod, name, buf)
  109. altmod is either None or same as mod
  110. """
  111. if len(name) == 0:
  112. # completely empty module name should only happen in
  113. # 'from . import' (or '__import__("")')
  114. return mod, None, buf
  115. dot = name.find('.')
  116. if dot == 0:
  117. raise ValueError('Empty module name')
  118. if dot < 0:
  119. subname = name
  120. next = None
  121. else:
  122. subname = name[:dot]
  123. next = name[dot+1:]
  124. if buf != '':
  125. buf += '.'
  126. buf += subname
  127. result = import_submodule(mod, subname, buf)
  128. if result is None and mod != altmod:
  129. result = import_submodule(altmod, subname, subname)
  130. if result is not None:
  131. buf = subname
  132. if result is None:
  133. raise ImportError("No module named %.200s" % name)
  134. return result, next, buf
  135. # Need to keep track of what we've already reloaded to prevent cyclic evil
  136. found_now = {}
  137. def import_submodule(mod, subname, fullname):
  138. """m = import_submodule(mod, subname, fullname)"""
  139. # Require:
  140. # if mod == None: subname == fullname
  141. # else: mod.__name__ + "." + subname == fullname
  142. global found_now
  143. if fullname in found_now and fullname in sys.modules:
  144. m = sys.modules[fullname]
  145. else:
  146. print('Reloading', fullname)
  147. found_now[fullname] = 1
  148. oldm = sys.modules.get(fullname, None)
  149. if mod is None:
  150. path = None
  151. elif hasattr(mod, '__path__'):
  152. path = mod.__path__
  153. else:
  154. return None
  155. try:
  156. # This appears to be necessary on Python 3, because imp.find_module()
  157. # tries to import standard libraries (like io) itself, and we don't
  158. # want them to be processed by our deep_import_hook.
  159. with replace_import_hook(original_import):
  160. fp, filename, stuff = imp.find_module(subname, path)
  161. except ImportError:
  162. return None
  163. try:
  164. m = imp.load_module(fullname, fp, filename, stuff)
  165. except:
  166. # load_module probably removed name from modules because of
  167. # the error. Put back the original module object.
  168. if oldm:
  169. sys.modules[fullname] = oldm
  170. raise
  171. finally:
  172. if fp: fp.close()
  173. add_submodule(mod, m, fullname, subname)
  174. return m
  175. def add_submodule(mod, submod, fullname, subname):
  176. """mod.{subname} = submod"""
  177. if mod is None:
  178. return #Nothing to do here.
  179. if submod is None:
  180. submod = sys.modules[fullname]
  181. setattr(mod, subname, submod)
  182. return
  183. def ensure_fromlist(mod, fromlist, buf, recursive):
  184. """Handle 'from module import a, b, c' imports."""
  185. if not hasattr(mod, '__path__'):
  186. return
  187. for item in fromlist:
  188. if not hasattr(item, 'rindex'):
  189. raise TypeError("Item in ``from list'' not a string")
  190. if item == '*':
  191. if recursive:
  192. continue # avoid endless recursion
  193. try:
  194. all = mod.__all__
  195. except AttributeError:
  196. pass
  197. else:
  198. ret = ensure_fromlist(mod, all, buf, 1)
  199. if not ret:
  200. return 0
  201. elif not hasattr(mod, item):
  202. import_submodule(mod, item, buf + '.' + item)
  203. def deep_import_hook(name, globals=None, locals=None, fromlist=None, level=-1):
  204. """Replacement for __import__()"""
  205. parent, buf = get_parent(globals, level)
  206. head, name, buf = load_next(parent, None if level < 0 else parent, name, buf)
  207. tail = head
  208. while name:
  209. tail, name, buf = load_next(tail, tail, name, buf)
  210. # If tail is None, both get_parent and load_next found
  211. # an empty module name: someone called __import__("") or
  212. # doctored faulty bytecode
  213. if tail is None:
  214. raise ValueError('Empty module name')
  215. if not fromlist:
  216. return head
  217. ensure_fromlist(tail, fromlist, buf, 0)
  218. return tail
  219. modules_reloading = {}
  220. def deep_reload_hook(m):
  221. """Replacement for reload()."""
  222. if not isinstance(m, ModuleType):
  223. raise TypeError("reload() argument must be module")
  224. name = m.__name__
  225. if name not in sys.modules:
  226. raise ImportError("reload(): module %.200s not in sys.modules" % name)
  227. global modules_reloading
  228. try:
  229. return modules_reloading[name]
  230. except:
  231. modules_reloading[name] = m
  232. dot = name.rfind('.')
  233. if dot < 0:
  234. subname = name
  235. path = None
  236. else:
  237. try:
  238. parent = sys.modules[name[:dot]]
  239. except KeyError:
  240. modules_reloading.clear()
  241. raise ImportError("reload(): parent %.200s not in sys.modules" % name[:dot])
  242. subname = name[dot+1:]
  243. path = getattr(parent, "__path__", None)
  244. try:
  245. # This appears to be necessary on Python 3, because imp.find_module()
  246. # tries to import standard libraries (like io) itself, and we don't
  247. # want them to be processed by our deep_import_hook.
  248. with replace_import_hook(original_import):
  249. fp, filename, stuff = imp.find_module(subname, path)
  250. finally:
  251. modules_reloading.clear()
  252. try:
  253. newm = imp.load_module(name, fp, filename, stuff)
  254. except:
  255. # load_module probably removed name from modules because of
  256. # the error. Put back the original module object.
  257. sys.modules[name] = m
  258. raise
  259. finally:
  260. if fp: fp.close()
  261. modules_reloading.clear()
  262. return newm
  263. # Save the original hooks
  264. try:
  265. original_reload = builtin_mod.reload
  266. except AttributeError:
  267. original_reload = imp.reload # Python 3
  268. # Replacement for reload()
  269. def reload(module, exclude=('sys', 'os.path', builtin_mod_name, '__main__',
  270. 'numpy', 'numpy._globals')):
  271. """Recursively reload all modules used in the given module. Optionally
  272. takes a list of modules to exclude from reloading. The default exclude
  273. list contains sys, __main__, and __builtin__, to prevent, e.g., resetting
  274. display, exception, and io hooks.
  275. """
  276. global found_now
  277. for i in exclude:
  278. found_now[i] = 1
  279. try:
  280. with replace_import_hook(deep_import_hook):
  281. return deep_reload_hook(module)
  282. finally:
  283. found_now = {}
  284. def _dreload(module, **kwargs):
  285. """
  286. **deprecated**
  287. import reload explicitly from `IPython.lib.deepreload` to use it
  288. """
  289. # this was marked as deprecated and for 5.0 removal, but
  290. # IPython.core_builtin_trap have a Deprecation warning for 6.0, so cannot
  291. # remove that now.
  292. warn("""
  293. injecting `dreload` in interactive namespace is deprecated since IPython 4.0.
  294. Please import `reload` explicitly from `IPython.lib.deepreload`.
  295. """, DeprecationWarning, stacklevel=2)
  296. reload(module, **kwargs)