|
@@ -11,7 +11,7 @@ from prompt_toolkit.application.run_in_terminal import run_in_terminal
|
|
|
from prompt_toolkit.data_structures import Size
|
|
|
from prompt_toolkit.eventloop import get_event_loop
|
|
|
from prompt_toolkit.formatted_text import AnyFormattedText, to_formatted_text
|
|
|
-from prompt_toolkit.input import create_pipe_input
|
|
|
+from prompt_toolkit.input import PipeInput, create_pipe_input
|
|
|
from prompt_toolkit.output.vt100 import Vt100_Output
|
|
|
from prompt_toolkit.renderer import print_formatted_text as print_formatted_text
|
|
|
from prompt_toolkit.styles import BaseStyle, DummyStyle
|
|
@@ -87,6 +87,7 @@ class _ConnectionStdout:
|
|
|
self._connection = connection
|
|
|
self._errors = "strict"
|
|
|
self._buffer: List[bytes] = []
|
|
|
+ self._closed = False
|
|
|
|
|
|
def write(self, data: str) -> None:
|
|
|
data = data.replace("\n", "\r\n")
|
|
@@ -98,12 +99,16 @@ class _ConnectionStdout:
|
|
|
|
|
|
def flush(self) -> None:
|
|
|
try:
|
|
|
- self._connection.send(b"".join(self._buffer))
|
|
|
+ if not self._closed:
|
|
|
+ self._connection.send(b"".join(self._buffer))
|
|
|
except OSError as e:
|
|
|
logger.warning("Couldn't send data over socket: %s" % e)
|
|
|
|
|
|
self._buffer = []
|
|
|
|
|
|
+ def close(self) -> None:
|
|
|
+ self._closed = True
|
|
|
+
|
|
|
@property
|
|
|
def encoding(self) -> str:
|
|
|
return self._encoding
|
|
@@ -126,6 +131,7 @@ class TelnetConnection:
|
|
|
server: "TelnetServer",
|
|
|
encoding: str,
|
|
|
style: Optional[BaseStyle],
|
|
|
+ vt100_input: PipeInput,
|
|
|
) -> None:
|
|
|
|
|
|
self.conn = conn
|
|
@@ -136,6 +142,7 @@ class TelnetConnection:
|
|
|
self.style = style
|
|
|
self._closed = False
|
|
|
self._ready = asyncio.Event()
|
|
|
+ self.vt100_input = vt100_input
|
|
|
self.vt100_output = None
|
|
|
|
|
|
# Create "Output" object.
|
|
@@ -144,9 +151,6 @@ class TelnetConnection:
|
|
|
# Initialize.
|
|
|
_initialize_telnet(conn)
|
|
|
|
|
|
- # Create input.
|
|
|
- self.vt100_input = create_pipe_input()
|
|
|
-
|
|
|
# Create output.
|
|
|
def get_size() -> Size:
|
|
|
return self.size
|
|
@@ -160,8 +164,8 @@ class TelnetConnection:
|
|
|
def size_received(rows: int, columns: int) -> None:
|
|
|
"""TelnetProtocolParser 'size_received' callback"""
|
|
|
self.size = Size(rows=rows, columns=columns)
|
|
|
- if self.vt100_output is not None:
|
|
|
- get_app()._on_resize()
|
|
|
+ if self.vt100_output is not None and self.context:
|
|
|
+ self.context.run(lambda: get_app()._on_resize())
|
|
|
|
|
|
def ttype_received(ttype: str) -> None:
|
|
|
"""TelnetProtocolParser 'ttype_received' callback"""
|
|
@@ -197,12 +201,6 @@ class TelnetConnection:
|
|
|
with create_app_session(input=self.vt100_input, output=self.vt100_output):
|
|
|
self.context = contextvars.copy_context()
|
|
|
await self.interact(self)
|
|
|
- except Exception as e:
|
|
|
- print("Got %s" % type(e).__name__, e)
|
|
|
- import traceback
|
|
|
-
|
|
|
- traceback.print_exc()
|
|
|
- raise
|
|
|
finally:
|
|
|
self.close()
|
|
|
|
|
@@ -222,6 +220,7 @@ class TelnetConnection:
|
|
|
self.vt100_input.close()
|
|
|
get_event_loop().remove_reader(self.conn)
|
|
|
self.conn.close()
|
|
|
+ self.stdout.close()
|
|
|
|
|
|
def send(self, formatted_text: AnyFormattedText) -> None:
|
|
|
"""
|
|
@@ -336,22 +335,43 @@ class TelnetServer:
|
|
|
conn, addr = self._listen_socket.accept()
|
|
|
logger.info("New connection %r %r", *addr)
|
|
|
|
|
|
- connection = TelnetConnection(
|
|
|
- conn, addr, self.interact, self, encoding=self.encoding, style=self.style
|
|
|
- )
|
|
|
- self.connections.add(connection)
|
|
|
-
|
|
|
# Run application for this connection.
|
|
|
async def run() -> None:
|
|
|
- logger.info("Starting interaction %r %r", *addr)
|
|
|
try:
|
|
|
- await connection.run_application()
|
|
|
- except Exception as e:
|
|
|
- print(e)
|
|
|
+ with create_pipe_input() as vt100_input:
|
|
|
+ connection = TelnetConnection(
|
|
|
+ conn,
|
|
|
+ addr,
|
|
|
+ self.interact,
|
|
|
+ self,
|
|
|
+ encoding=self.encoding,
|
|
|
+ style=self.style,
|
|
|
+ vt100_input=vt100_input,
|
|
|
+ )
|
|
|
+ self.connections.add(connection)
|
|
|
+
|
|
|
+ logger.info("Starting interaction %r %r", *addr)
|
|
|
+ try:
|
|
|
+ await connection.run_application()
|
|
|
+ finally:
|
|
|
+ self.connections.remove(connection)
|
|
|
+ logger.info("Stopping interaction %r %r", *addr)
|
|
|
+ except EOFError:
|
|
|
+ # Happens either when the connection is closed by the client
|
|
|
+ # (e.g., when the user types 'control-]', then 'quit' in the
|
|
|
+ # telnet client) or when the user types control-d in a prompt
|
|
|
+ # and this is not handled by the interact function.
|
|
|
+ logger.info("Unhandled EOFError in telnet application.")
|
|
|
+ except KeyboardInterrupt:
|
|
|
+ # Unhandled control-c propagated by a prompt.
|
|
|
+ logger.info("Unhandled KeyboardInterrupt in telnet application.")
|
|
|
+ except BaseException as e:
|
|
|
+ print("Got %s" % type(e).__name__, e)
|
|
|
+ import traceback
|
|
|
+
|
|
|
+ traceback.print_exc()
|
|
|
finally:
|
|
|
- self.connections.remove(connection)
|
|
|
self._application_tasks.remove(task)
|
|
|
- logger.info("Stopping interaction %r %r", *addr)
|
|
|
|
|
|
task = get_event_loop().create_task(run())
|
|
|
self._application_tasks.append(task)
|