lib.ec.ssh.AsyncSSH: Add interactivity

Request a remote PTY from AsyncSSH, and wire the local terminal's
stdin up with it if interactive == True. This gives a real
interactive session if local stdin belongs to a terminal. Also,
thanks to AsyncSSH understanding that, forward terminal size changes
to the remote end.

Signed-off-by: Jan Lindemann <jan@janware.com>
This commit is contained in:
Jan Lindemann 2026-03-21 11:34:35 +01:00
commit b21d2d1c21

View file

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import os, sys, shlex, asyncio, asyncssh import os, sys, shlex, asyncio, asyncssh, shutil, signal
from ...log import * from ...log import *
from ...ExecContext import Result from ...ExecContext import Result
@ -22,7 +22,7 @@ class AsyncSSH(Base):
super().__init__( super().__init__(
uri, uri,
caps=self.Caps.LogOutput | self.Caps.Wd, caps=self.Caps.LogOutput | self.Caps.Wd | self.Caps.Interactive,
**kwargs **kwargs
) )
@ -82,6 +82,31 @@ class AsyncSSH(Base):
return tuple(args), kwargs return tuple(args), kwargs
@staticmethod
def _has_local_tty() -> bool:
try:
return sys.stdin.isatty() and sys.stdout.isatty()
except Exception:
return False
@staticmethod
def _get_local_term_size() -> tuple[int, int, int, int]:
cols, rows = shutil.get_terminal_size(fallback=(80, 24))
xpixel = ypixel = 0
try:
import fcntl, termios, struct
packed = fcntl.ioctl(sys.stdout.fileno(), termios.TIOCGWINSZ, b"\0" * 8)
rows2, cols2, xpixel, ypixel = struct.unpack("HHHH", packed)
if cols2 > 0 and rows2 > 0:
cols, rows = cols2, rows2
except Exception:
pass
return (cols, rows, xpixel, ypixel)
async def _read_stream( async def _read_stream(
self, self,
stream, stream,
@ -110,6 +135,219 @@ class AsyncSSH(Base):
if verbose and buf: if verbose and buf:
log(prio, log_prefix, buf.decode(log_enc, errors="replace")) log(prio, log_prefix, buf.decode(log_enc, errors="replace"))
async def _run_interactive_on_conn(
self,
conn: asyncssh.SSHClientConnection,
cmd: list[str],
wd: str | None,
cmd_input: str | None,
env: dict[str, str] | None,
) -> Result:
command = self._build_remote_command(cmd, wd)
stdout_parts: list[bytes] = []
proc = await conn.create_process(
command=command,
env=env,
stdin=asyncssh.PIPE,
stdout=asyncssh.PIPE,
stderr=asyncssh.STDOUT,
encoding=None,
request_pty="force",
term_type=self.term_type,
term_size=self._get_local_term_size(),
)
loop = asyncio.get_running_loop()
stdin_fd = sys.stdin.fileno()
stdin_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
old_tty_state = None
old_winch_handler = None
stdin_reader_installed = False
def _write_local(data: bytes) -> None:
try:
sys.stdout.buffer.write(data)
sys.stdout.buffer.flush()
except AttributeError:
os.write(sys.stdout.fileno(), data)
def _on_stdin_ready() -> None:
try:
data = os.read(stdin_fd, 4096)
except OSError:
data = b""
if data:
stdin_queue.put_nowait(data)
else:
try:
loop.remove_reader(stdin_fd)
except Exception:
pass
stdin_queue.put_nowait(None)
async def _pump_stdin() -> None:
if cmd_input is not None and proc.stdin is not None:
proc.stdin.write(cmd_input.encode(sys.stdout.encoding or "utf-8"))
await proc.stdin.drain()
while True:
data = await stdin_queue.get()
if data is None:
if proc.stdin is not None:
try:
proc.stdin.write_eof()
except (BrokenPipeError, OSError):
pass
return
if proc.stdin is None:
return
proc.stdin.write(data)
await proc.stdin.drain()
async def _pump_stdout() -> None:
while True:
chunk = await proc.stdout.read(4096)
if not chunk:
break
stdout_parts.append(chunk)
_write_local(chunk)
def _on_winch(*_args) -> None:
try:
proc.change_terminal_size(*self._get_local_term_size())
except Exception:
pass
try:
sys.stdout.flush()
sys.stderr.flush()
try:
import termios, tty
old_tty_state = termios.tcgetattr(stdin_fd)
tty.setraw(stdin_fd)
except Exception:
old_tty_state = None
try:
loop.add_reader(stdin_fd, _on_stdin_ready)
stdin_reader_installed = True
except (NotImplementedError, RuntimeError):
stdin_queue.put_nowait(None)
if hasattr(signal, "SIGWINCH"):
try:
old_winch_handler = signal.getsignal(signal.SIGWINCH)
signal.signal(signal.SIGWINCH, _on_winch)
except Exception:
old_winch_handler = None
stdin_task = asyncio.create_task(_pump_stdin())
stdout_task = asyncio.create_task(_pump_stdout())
completed = await proc.wait(check=False)
await stdout_task
if not stdin_task.done():
stdin_task.cancel()
try:
await stdin_task
except asyncio.CancelledError:
pass
exit_code = completed.exit_status
if exit_code is None:
exit_code = completed.returncode if completed.returncode is not None else -1
stdout = b"".join(stdout_parts) if stdout_parts else None
return Result(stdout, None, exit_code)
finally:
if stdin_reader_installed:
try:
loop.remove_reader(stdin_fd)
except Exception:
pass
if old_winch_handler is not None and hasattr(signal, "SIGWINCH"):
try:
signal.signal(signal.SIGWINCH, old_winch_handler)
except Exception:
pass
if old_tty_state is not None:
try:
import termios
termios.tcsetattr(stdin_fd, termios.TCSADRAIN, old_tty_state)
except Exception:
pass
try:
sys.stdout.flush()
sys.stderr.flush()
except Exception:
pass
async def _run_captured_pty_on_conn(
self,
conn: asyncssh.SSHClientConnection,
cmd: list[str],
wd: str | None,
verbose: bool,
cmd_input: str | None,
env: dict[str, str] | None,
log_prefix: str,
) -> Result:
command = self._build_remote_command(cmd, wd)
stdout_parts: list[bytes] = []
stdout_log_enc = sys.stdout.encoding or "utf-8"
proc = await conn.create_process(
command=command,
env=env,
stdin=asyncssh.PIPE if cmd_input is not None else asyncssh.DEVNULL,
stdout=asyncssh.PIPE,
stderr=asyncssh.STDOUT,
encoding=None,
request_pty="force",
term_type=self.term_type,
)
task = asyncio.create_task(
self._read_stream(
proc.stdout,
NOTICE,
stdout_parts,
verbose=verbose,
log_prefix=log_prefix,
log_enc=stdout_log_enc,
)
)
if cmd_input is not None and proc.stdin is not None:
proc.stdin.write(cmd_input.encode(sys.stdout.encoding or "utf-8"))
await proc.stdin.drain()
proc.stdin.write_eof()
completed = await proc.wait(check=False)
await task
exit_code = completed.exit_status
if exit_code is None:
exit_code = completed.returncode if completed.returncode is not None else -1
stdout = b"".join(stdout_parts) if stdout_parts else None
return Result(stdout, None, exit_code)
async def _run_on_conn( async def _run_on_conn(
self, self,
conn: asyncssh.SSHClientConnection, conn: asyncssh.SSHClientConnection,
@ -121,6 +359,26 @@ class AsyncSSH(Base):
interactive: bool, interactive: bool,
log_prefix: str, log_prefix: str,
) -> Result: ) -> Result:
if interactive:
if self._has_local_tty():
return await self._run_interactive_on_conn(
conn,
cmd,
wd,
cmd_input,
env,
)
return await self._run_captured_pty_on_conn(
conn,
cmd,
wd,
verbose,
cmd_input,
env,
log_prefix,
)
command = self._build_remote_command(cmd, wd) command = self._build_remote_command(cmd, wd)
stdout_parts: list[bytes] = [] stdout_parts: list[bytes] = []
@ -130,17 +388,15 @@ class AsyncSSH(Base):
stderr_log_enc = sys.stderr.encoding or "utf-8" stderr_log_enc = sys.stderr.encoding or "utf-8"
stdin_mode = asyncssh.PIPE if cmd_input is not None else asyncssh.DEVNULL stdin_mode = asyncssh.PIPE if cmd_input is not None else asyncssh.DEVNULL
stderr_mode = asyncssh.STDOUT if interactive else asyncssh.PIPE
proc = await conn.create_process( proc = await conn.create_process(
command=command, command=command,
env=env, env=env,
stdin=stdin_mode, stdin=stdin_mode,
stdout=asyncssh.PIPE, stdout=asyncssh.PIPE,
stderr=stderr_mode, stderr=asyncssh.PIPE,
encoding=None, encoding=None,
request_pty="force" if interactive else False, request_pty=False,
term_type=self.term_type if interactive else None,
) )
tasks = [ tasks = [
@ -153,11 +409,7 @@ class AsyncSSH(Base):
log_prefix=log_prefix, log_prefix=log_prefix,
log_enc=stdout_log_enc, log_enc=stdout_log_enc,
) )
) ),
]
if not interactive:
tasks.append(
asyncio.create_task( asyncio.create_task(
self._read_stream( self._read_stream(
proc.stderr, proc.stderr,
@ -167,8 +419,8 @@ class AsyncSSH(Base):
log_prefix=log_prefix, log_prefix=log_prefix,
log_enc=stderr_log_enc, log_enc=stderr_log_enc,
) )
) ),
) ]
if cmd_input is not None and proc.stdin is not None: if cmd_input is not None and proc.stdin is not None:
proc.stdin.write(cmd_input.encode(sys.stdout.encoding or "utf-8")) proc.stdin.write(cmd_input.encode(sys.stdout.encoding or "utf-8"))
@ -179,7 +431,7 @@ class AsyncSSH(Base):
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
stdout = b"".join(stdout_parts) if stdout_parts else None stdout = b"".join(stdout_parts) if stdout_parts else None
stderr = None if interactive else (b"".join(stderr_parts) if stderr_parts else None) stderr = b"".join(stderr_parts) if stderr_parts else None
exit_code = completed.exit_status exit_code = completed.exit_status
if exit_code is None: if exit_code is None: