readline_hook.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. from __future__ import print_function # PY2
  2. import sys
  3. import traceback
  4. import warnings
  5. import ctypes.util
  6. from ctypes import (pythonapi, cdll, cast,
  7. c_char_p, c_void_p, c_size_t, CFUNCTYPE)
  8. from .info import WINDOWS
  9. try:
  10. import pyreadline
  11. except ImportError:
  12. pyreadline = None
  13. def get_libc():
  14. if WINDOWS:
  15. path = "msvcrt"
  16. else:
  17. path = ctypes.util.find_library("c")
  18. if path is None:
  19. raise RuntimeError("cannot locate libc")
  20. return cdll[path]
  21. LIBC = get_libc()
  22. PyMem_Malloc = pythonapi.PyMem_Malloc
  23. PyMem_Malloc.restype = c_size_t
  24. PyMem_Malloc.argtypes = [c_size_t]
  25. strncpy = LIBC.strncpy
  26. strncpy.restype = c_char_p
  27. strncpy.argtypes = [c_char_p, c_char_p, c_size_t]
  28. HOOKFUNC = CFUNCTYPE(c_char_p, c_void_p, c_void_p, c_char_p)
  29. #PyOS_ReadlineFunctionPointer = c_void_p.in_dll(pythonapi, "PyOS_ReadlineFunctionPointer")
  30. def new_zero_terminated_string(b):
  31. p = PyMem_Malloc(len(b) + 1)
  32. strncpy(cast(p, c_char_p), b, len(b) + 1)
  33. return p
  34. def check_encodings():
  35. if sys.stdin.encoding != sys.stdout.encoding:
  36. # raise RuntimeError("sys.stdin.encoding != sys.stdout.encoding, readline hook doesn't know, which one to use to decode prompt")
  37. warnings.warn("sys.stdin.encoding == {!r}, whereas sys.stdout.encoding == {!r}, readline hook consumer may assume they are the same".format(sys.stdin.encoding, sys.stdout.encoding),
  38. RuntimeWarning, stacklevel=3)
  39. def stdio_readline(prompt=""):
  40. sys.stdout.write(prompt)
  41. sys.stdout.flush()
  42. return sys.stdin.readline()
  43. class ReadlineHookManager:
  44. def __init__(self):
  45. self.readline_wrapper_ref = HOOKFUNC(self.readline_wrapper)
  46. self.address = cast(self.readline_wrapper_ref, c_void_p).value
  47. #self.original_address = PyOS_ReadlineFunctionPointer.value
  48. self.readline_hook = None
  49. def readline_wrapper(self, stdin, stdout, prompt):
  50. try:
  51. try:
  52. check_encodings()
  53. except RuntimeError:
  54. traceback.print_exc(file=sys.stderr)
  55. try:
  56. prompt = prompt.decode("utf-8")
  57. except UnicodeDecodeError:
  58. prompt = ""
  59. else:
  60. prompt = prompt.decode(sys.stdout.encoding)
  61. try:
  62. line = self.readline_hook(prompt)
  63. except KeyboardInterrupt:
  64. return 0
  65. else:
  66. return new_zero_terminated_string(line.encode(sys.stdin.encoding))
  67. except:
  68. self.restore_original()
  69. print("Internal win_unicode_console error, disabling custom readline hook...", file=sys.stderr)
  70. traceback.print_exc(file=sys.stderr)
  71. return new_zero_terminated_string(b"\n")
  72. def install_hook(self, hook):
  73. self.readline_hook = hook
  74. PyOS_ReadlineFunctionPointer.value = self.address
  75. def restore_original(self):
  76. self.readline_hook = None
  77. PyOS_ReadlineFunctionPointer.value = self.original_address
  78. class PyReadlineManager:
  79. def __init__(self):
  80. self.original_codepage = pyreadline.unicode_helper.pyreadline_codepage
  81. def set_codepage(self, codepage):
  82. pyreadline.unicode_helper.pyreadline_codepage = codepage
  83. def restore_original(self):
  84. self.set_codepage(self.original_codepage)
  85. def pyreadline_is_active():
  86. if not pyreadline:
  87. return False
  88. ref = pyreadline.console.console.readline_ref
  89. if ref is None:
  90. return False
  91. return cast(ref, c_void_p).value == PyOS_ReadlineFunctionPointer.value
  92. manager = ReadlineHookManager()
  93. if pyreadline:
  94. pyreadline_manager = PyReadlineManager()
  95. # PY3 # def enable(*, use_pyreadline=True):
  96. def enable(use_pyreadline=True):
  97. check_encodings()
  98. if use_pyreadline and pyreadline:
  99. pyreadline_manager.set_codepage(sys.stdin.encoding)
  100. # pyreadline assumes that encoding of all sys.stdio objects is the same
  101. if not pyreadline_is_active():
  102. manager.install_hook(stdio_readline)
  103. else:
  104. manager.install_hook(stdio_readline)
  105. def disable():
  106. if pyreadline:
  107. pyreadline_manager.restore_original()
  108. else:
  109. manager.restore_original()