shimmodule.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. """A shim module for deprecated imports
  2. """
  3. # Copyright (c) IPython Development Team.
  4. # Distributed under the terms of the Modified BSD License.
import sys
import types
from importlib import import_module

from .importstring import import_item
  8. class ShimWarning(Warning):
  9. """A warning to show when a module has moved, and a shim is in its place."""
  10. class ShimImporter(object):
  11. """Import hook for a shim.
  12. This ensures that submodule imports return the real target module,
  13. not a clone that will confuse `is` and `isinstance` checks.
  14. """
  15. def __init__(self, src, mirror):
  16. self.src = src
  17. self.mirror = mirror
  18. def _mirror_name(self, fullname):
  19. """get the name of the mirrored module"""
  20. return self.mirror + fullname[len(self.src):]
  21. def find_module(self, fullname, path=None):
  22. """Return self if we should be used to import the module."""
  23. if fullname.startswith(self.src + '.'):
  24. mirror_name = self._mirror_name(fullname)
  25. try:
  26. mod = import_item(mirror_name)
  27. except ImportError:
  28. return
  29. else:
  30. if not isinstance(mod, types.ModuleType):
  31. # not a module
  32. return None
  33. return self
  34. def load_module(self, fullname):
  35. """Import the mirrored module, and insert it into sys.modules"""
  36. mirror_name = self._mirror_name(fullname)
  37. mod = import_item(mirror_name)
  38. sys.modules[fullname] = mod
  39. return mod
  40. class ShimModule(types.ModuleType):
  41. def __init__(self, *args, **kwargs):
  42. self._mirror = kwargs.pop("mirror")
  43. src = kwargs.pop("src", None)
  44. if src:
  45. kwargs['name'] = src.rsplit('.', 1)[-1]
  46. super(ShimModule, self).__init__(*args, **kwargs)
  47. # add import hook for descendent modules
  48. if src:
  49. sys.meta_path.append(
  50. ShimImporter(src=src, mirror=self._mirror)
  51. )
  52. @property
  53. def __path__(self):
  54. return []
  55. @property
  56. def __spec__(self):
  57. """Don't produce __spec__ until requested"""
  58. return __import__(self._mirror).__spec__
  59. def __dir__(self):
  60. return dir(__import__(self._mirror))
  61. @property
  62. def __all__(self):
  63. """Ensure __all__ is always defined"""
  64. mod = __import__(self._mirror)
  65. try:
  66. return mod.__all__
  67. except AttributeError:
  68. return [name for name in dir(mod) if not name.startswith('_')]
  69. def __getattr__(self, key):
  70. # Use the equivalent of import_item(name), see below
  71. name = "%s.%s" % (self._mirror, key)
  72. try:
  73. return import_item(name)
  74. except ImportError:
  75. raise AttributeError(key)