From b21d2d1c2168fda89d6cb2cd75fc6d80a95154e3 Mon Sep 17 00:00:00 2001 From: Jan Lindemann Date: Sat, 21 Mar 2026 11:34:35 +0100 Subject: [PATCH] lib.ec.ssh.AsyncSSH: Add interactivity Request a remote PTY from AsyncSSH, and wire the local terminal's stdin up with it if interactive == True. This gives a real interactive session if local stdin belongs to a terminal. Also, thanks to AsyncSSH understanding that, forward terminal size changes to the remote end. Signed-off-by: Jan Lindemann --- src/python/jw/pkg/lib/ec/ssh/AsyncSSH.py | 296 +++++++++++++++++++++-- 1 file changed, 274 insertions(+), 22 deletions(-) diff --git a/src/python/jw/pkg/lib/ec/ssh/AsyncSSH.py b/src/python/jw/pkg/lib/ec/ssh/AsyncSSH.py index 43fc0444..67b110a3 100644 --- a/src/python/jw/pkg/lib/ec/ssh/AsyncSSH.py +++ b/src/python/jw/pkg/lib/ec/ssh/AsyncSSH.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -import os, sys, shlex, asyncio, asyncssh +import os, sys, shlex, asyncio, asyncssh, shutil, signal from ...log import * from ...ExecContext import Result @@ -22,7 +22,7 @@ class AsyncSSH(Base): super().__init__( uri, - caps=self.Caps.LogOutput | self.Caps.Wd, + caps=self.Caps.LogOutput | self.Caps.Wd | self.Caps.Interactive, **kwargs ) @@ -82,6 +82,31 @@ class AsyncSSH(Base): return tuple(args), kwargs + @staticmethod + def _has_local_tty() -> bool: + try: + return sys.stdin.isatty() and sys.stdout.isatty() + except Exception: + return False + + @staticmethod + def _get_local_term_size() -> tuple[int, int, int, int]: + cols, rows = shutil.get_terminal_size(fallback=(80, 24)) + xpixel = ypixel = 0 + + try: + import fcntl, termios, struct + + packed = fcntl.ioctl(sys.stdout.fileno(), termios.TIOCGWINSZ, b"\0" * 8) + rows2, cols2, xpixel, ypixel = struct.unpack("HHHH", packed) + + if cols2 > 0 and rows2 > 0: + cols, rows = cols2, rows2 + except Exception: + pass + + return (cols, rows, xpixel, ypixel) + async def _read_stream( self, stream, @@ -110,6 +135,219 @@ class AsyncSSH(Base): if verbose and buf: log(prio, log_prefix, buf.decode(log_enc, errors="replace")) + async def _run_interactive_on_conn( + self, + conn: asyncssh.SSHClientConnection, + cmd: list[str], + wd: str | None, + cmd_input: str | None, + env: dict[str, str] | None, + ) -> Result: + command = self._build_remote_command(cmd, wd) + stdout_parts: list[bytes] = [] + + proc = await conn.create_process( + command=command, + env=env, + stdin=asyncssh.PIPE, + stdout=asyncssh.PIPE, + stderr=asyncssh.STDOUT, + encoding=None, + request_pty="force", + term_type=self.term_type, + term_size=self._get_local_term_size(), + ) + + loop = asyncio.get_running_loop() + stdin_fd = sys.stdin.fileno() + + stdin_queue: asyncio.Queue[bytes | None] = asyncio.Queue() + old_tty_state = None + old_winch_handler = None + stdin_reader_installed = False + + def _write_local(data: bytes) -> None: + try: + sys.stdout.buffer.write(data) + sys.stdout.buffer.flush() + except AttributeError: + os.write(sys.stdout.fileno(), data) + + def _on_stdin_ready() -> None: + try: + data = os.read(stdin_fd, 4096) + except OSError: + data = b"" + + if data: + stdin_queue.put_nowait(data) + else: + try: + loop.remove_reader(stdin_fd) + except Exception: + pass + stdin_queue.put_nowait(None) + + async def _pump_stdin() -> None: + if cmd_input is not None and proc.stdin is not None: + proc.stdin.write(cmd_input.encode(sys.stdout.encoding or "utf-8")) + await proc.stdin.drain() + + while True: + data = await stdin_queue.get() + + if data is None: + if proc.stdin is not None: + try: + proc.stdin.write_eof() + except (BrokenPipeError, OSError): + pass + return + + if proc.stdin is None: + return + + proc.stdin.write(data) + await proc.stdin.drain() + + async def _pump_stdout() -> None: + while True: + chunk = await proc.stdout.read(4096) + if not chunk: + break + + stdout_parts.append(chunk) + _write_local(chunk) + + def _on_winch(*_args) -> None: + try: + proc.change_terminal_size(*self._get_local_term_size()) + except Exception: + pass + + try: + sys.stdout.flush() + sys.stderr.flush() + + try: + import termios, tty + + old_tty_state = termios.tcgetattr(stdin_fd) + tty.setraw(stdin_fd) + except Exception: + old_tty_state = None + + try: + loop.add_reader(stdin_fd, _on_stdin_ready) + stdin_reader_installed = True + except (NotImplementedError, RuntimeError): + stdin_queue.put_nowait(None) + + if hasattr(signal, "SIGWINCH"): + try: + old_winch_handler = signal.getsignal(signal.SIGWINCH) + signal.signal(signal.SIGWINCH, _on_winch) + except Exception: + old_winch_handler = None + + stdin_task = asyncio.create_task(_pump_stdin()) + stdout_task = asyncio.create_task(_pump_stdout()) + + completed = await proc.wait(check=False) + await stdout_task + + if not stdin_task.done(): + stdin_task.cancel() + try: + await stdin_task + except asyncio.CancelledError: + pass + + exit_code = completed.exit_status + if exit_code is None: + exit_code = completed.returncode if completed.returncode is not None else -1 + + stdout = b"".join(stdout_parts) if stdout_parts else None + return Result(stdout, None, exit_code) + + finally: + if stdin_reader_installed: + try: + loop.remove_reader(stdin_fd) + except Exception: + pass + + if old_winch_handler is not None and hasattr(signal, "SIGWINCH"): + try: + signal.signal(signal.SIGWINCH, old_winch_handler) + except Exception: + pass + + if old_tty_state is not None: + try: + import termios + termios.tcsetattr(stdin_fd, termios.TCSADRAIN, old_tty_state) + except Exception: + pass + + try: + sys.stdout.flush() + sys.stderr.flush() + except Exception: + pass + + async def _run_captured_pty_on_conn( + self, + conn: asyncssh.SSHClientConnection, + cmd: list[str], + wd: str | None, + verbose: bool, + cmd_input: str | None, + env: dict[str, str] | None, + log_prefix: str, + ) -> Result: + command = self._build_remote_command(cmd, wd) + + stdout_parts: list[bytes] = [] + stdout_log_enc = sys.stdout.encoding or "utf-8" + + proc = await conn.create_process( + command=command, + env=env, + stdin=asyncssh.PIPE if cmd_input is not None else asyncssh.DEVNULL, + stdout=asyncssh.PIPE, + stderr=asyncssh.STDOUT, + encoding=None, + request_pty="force", + term_type=self.term_type, + ) + + task = asyncio.create_task( + self._read_stream( + proc.stdout, + NOTICE, + stdout_parts, + verbose=verbose, + log_prefix=log_prefix, + log_enc=stdout_log_enc, + ) + ) + + if cmd_input is not None and proc.stdin is not None: + proc.stdin.write(cmd_input.encode(sys.stdout.encoding or "utf-8")) + await proc.stdin.drain() + proc.stdin.write_eof() + + completed = await proc.wait(check=False) + await task + + exit_code = completed.exit_status + if exit_code is None: + exit_code = completed.returncode if completed.returncode is not None else -1 + + stdout = b"".join(stdout_parts) if stdout_parts else None + return Result(stdout, None, exit_code) + async def _run_on_conn( self, conn: asyncssh.SSHClientConnection, @@ -121,6 +359,26 @@ class AsyncSSH(Base): interactive: bool, log_prefix: str, ) -> Result: + if interactive: + if self._has_local_tty(): + return await self._run_interactive_on_conn( + conn, + cmd, + wd, + cmd_input, + env, + ) + + return await self._run_captured_pty_on_conn( + conn, + cmd, + wd, + verbose, + cmd_input, + env, + log_prefix, + ) + command = self._build_remote_command(cmd, wd) stdout_parts: list[bytes] = [] @@ -130,17 +388,15 @@ class AsyncSSH(Base): stderr_log_enc = sys.stderr.encoding or "utf-8" stdin_mode = asyncssh.PIPE if cmd_input is not None else asyncssh.DEVNULL - stderr_mode = asyncssh.STDOUT if interactive else asyncssh.PIPE proc = await conn.create_process( command=command, env=env, stdin=stdin_mode, stdout=asyncssh.PIPE, - stderr=stderr_mode, + stderr=asyncssh.PIPE, encoding=None, - request_pty="force" if interactive else False, - term_type=self.term_type if interactive else None, + request_pty=False, ) tasks = [ @@ -153,22 +409,18 @@ class AsyncSSH(Base): log_prefix=log_prefix, log_enc=stdout_log_enc, ) - ) - ] - - if not interactive: - tasks.append( - asyncio.create_task( - self._read_stream( - proc.stderr, - ERR, - stderr_parts, - verbose=verbose, - log_prefix=log_prefix, - log_enc=stderr_log_enc, - ) + ), + asyncio.create_task( + self._read_stream( + proc.stderr, + ERR, + stderr_parts, + verbose=verbose, + log_prefix=log_prefix, + log_enc=stderr_log_enc, ) - ) + ), + ] if cmd_input is not None and proc.stdin is not None: proc.stdin.write(cmd_input.encode(sys.stdout.encoding or "utf-8")) @@ -179,7 +431,7 @@ class AsyncSSH(Base): await asyncio.gather(*tasks) stdout = b"".join(stdout_parts) if stdout_parts else None - stderr = None if interactive else (b"".join(stderr_parts) if stderr_parts else None) + stderr = b"".join(stderr_parts) if stderr_parts else None exit_code = completed.exit_status if exit_code is None: