mirror of
ssh://git.janware.com/janware/proj/jw-pkg
synced 2026-04-24 09:13:37 +02:00
lib.ec.ssh.AsyncSSH: Add interactivity
Request a remote PTY from AsyncSSH, and wire the local terminal's stdin up with it if interactive == True. This gives a real interactive session if local stdin belongs to a terminal. Also, thanks to AsyncSSH understanding that, forward terminal size changes to the remote end. Signed-off-by: Jan Lindemann <jan@janware.com>
This commit is contained in:
parent
737cbc3e24
commit
b21d2d1c21
1 changed files with 274 additions and 22 deletions
|
|
@ -1,6 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os, sys, shlex, asyncio, asyncssh
|
||||
import os, sys, shlex, asyncio, asyncssh, shutil, signal
|
||||
|
||||
from ...log import *
|
||||
from ...ExecContext import Result
|
||||
|
|
@ -22,7 +22,7 @@ class AsyncSSH(Base):
|
|||
|
||||
super().__init__(
|
||||
uri,
|
||||
caps=self.Caps.LogOutput | self.Caps.Wd,
|
||||
caps=self.Caps.LogOutput | self.Caps.Wd | self.Caps.Interactive,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
|
@ -82,6 +82,31 @@ class AsyncSSH(Base):
|
|||
|
||||
return tuple(args), kwargs
|
||||
|
||||
@staticmethod
|
||||
def _has_local_tty() -> bool:
|
||||
try:
|
||||
return sys.stdin.isatty() and sys.stdout.isatty()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _get_local_term_size() -> tuple[int, int, int, int]:
|
||||
cols, rows = shutil.get_terminal_size(fallback=(80, 24))
|
||||
xpixel = ypixel = 0
|
||||
|
||||
try:
|
||||
import fcntl, termios, struct
|
||||
|
||||
packed = fcntl.ioctl(sys.stdout.fileno(), termios.TIOCGWINSZ, b"\0" * 8)
|
||||
rows2, cols2, xpixel, ypixel = struct.unpack("HHHH", packed)
|
||||
|
||||
if cols2 > 0 and rows2 > 0:
|
||||
cols, rows = cols2, rows2
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return (cols, rows, xpixel, ypixel)
|
||||
|
||||
async def _read_stream(
|
||||
self,
|
||||
stream,
|
||||
|
|
@ -110,6 +135,219 @@ class AsyncSSH(Base):
|
|||
if verbose and buf:
|
||||
log(prio, log_prefix, buf.decode(log_enc, errors="replace"))
|
||||
|
||||
async def _run_interactive_on_conn(
|
||||
self,
|
||||
conn: asyncssh.SSHClientConnection,
|
||||
cmd: list[str],
|
||||
wd: str | None,
|
||||
cmd_input: str | None,
|
||||
env: dict[str, str] | None,
|
||||
) -> Result:
|
||||
command = self._build_remote_command(cmd, wd)
|
||||
stdout_parts: list[bytes] = []
|
||||
|
||||
proc = await conn.create_process(
|
||||
command=command,
|
||||
env=env,
|
||||
stdin=asyncssh.PIPE,
|
||||
stdout=asyncssh.PIPE,
|
||||
stderr=asyncssh.STDOUT,
|
||||
encoding=None,
|
||||
request_pty="force",
|
||||
term_type=self.term_type,
|
||||
term_size=self._get_local_term_size(),
|
||||
)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
stdin_fd = sys.stdin.fileno()
|
||||
|
||||
stdin_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
|
||||
old_tty_state = None
|
||||
old_winch_handler = None
|
||||
stdin_reader_installed = False
|
||||
|
||||
def _write_local(data: bytes) -> None:
|
||||
try:
|
||||
sys.stdout.buffer.write(data)
|
||||
sys.stdout.buffer.flush()
|
||||
except AttributeError:
|
||||
os.write(sys.stdout.fileno(), data)
|
||||
|
||||
def _on_stdin_ready() -> None:
|
||||
try:
|
||||
data = os.read(stdin_fd, 4096)
|
||||
except OSError:
|
||||
data = b""
|
||||
|
||||
if data:
|
||||
stdin_queue.put_nowait(data)
|
||||
else:
|
||||
try:
|
||||
loop.remove_reader(stdin_fd)
|
||||
except Exception:
|
||||
pass
|
||||
stdin_queue.put_nowait(None)
|
||||
|
||||
async def _pump_stdin() -> None:
|
||||
if cmd_input is not None and proc.stdin is not None:
|
||||
proc.stdin.write(cmd_input.encode(sys.stdout.encoding or "utf-8"))
|
||||
await proc.stdin.drain()
|
||||
|
||||
while True:
|
||||
data = await stdin_queue.get()
|
||||
|
||||
if data is None:
|
||||
if proc.stdin is not None:
|
||||
try:
|
||||
proc.stdin.write_eof()
|
||||
except (BrokenPipeError, OSError):
|
||||
pass
|
||||
return
|
||||
|
||||
if proc.stdin is None:
|
||||
return
|
||||
|
||||
proc.stdin.write(data)
|
||||
await proc.stdin.drain()
|
||||
|
||||
async def _pump_stdout() -> None:
|
||||
while True:
|
||||
chunk = await proc.stdout.read(4096)
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
stdout_parts.append(chunk)
|
||||
_write_local(chunk)
|
||||
|
||||
def _on_winch(*_args) -> None:
|
||||
try:
|
||||
proc.change_terminal_size(*self._get_local_term_size())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
sys.stdout.flush()
|
||||
sys.stderr.flush()
|
||||
|
||||
try:
|
||||
import termios, tty
|
||||
|
||||
old_tty_state = termios.tcgetattr(stdin_fd)
|
||||
tty.setraw(stdin_fd)
|
||||
except Exception:
|
||||
old_tty_state = None
|
||||
|
||||
try:
|
||||
loop.add_reader(stdin_fd, _on_stdin_ready)
|
||||
stdin_reader_installed = True
|
||||
except (NotImplementedError, RuntimeError):
|
||||
stdin_queue.put_nowait(None)
|
||||
|
||||
if hasattr(signal, "SIGWINCH"):
|
||||
try:
|
||||
old_winch_handler = signal.getsignal(signal.SIGWINCH)
|
||||
signal.signal(signal.SIGWINCH, _on_winch)
|
||||
except Exception:
|
||||
old_winch_handler = None
|
||||
|
||||
stdin_task = asyncio.create_task(_pump_stdin())
|
||||
stdout_task = asyncio.create_task(_pump_stdout())
|
||||
|
||||
completed = await proc.wait(check=False)
|
||||
await stdout_task
|
||||
|
||||
if not stdin_task.done():
|
||||
stdin_task.cancel()
|
||||
try:
|
||||
await stdin_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
exit_code = completed.exit_status
|
||||
if exit_code is None:
|
||||
exit_code = completed.returncode if completed.returncode is not None else -1
|
||||
|
||||
stdout = b"".join(stdout_parts) if stdout_parts else None
|
||||
return Result(stdout, None, exit_code)
|
||||
|
||||
finally:
|
||||
if stdin_reader_installed:
|
||||
try:
|
||||
loop.remove_reader(stdin_fd)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if old_winch_handler is not None and hasattr(signal, "SIGWINCH"):
|
||||
try:
|
||||
signal.signal(signal.SIGWINCH, old_winch_handler)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if old_tty_state is not None:
|
||||
try:
|
||||
import termios
|
||||
termios.tcsetattr(stdin_fd, termios.TCSADRAIN, old_tty_state)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
sys.stdout.flush()
|
||||
sys.stderr.flush()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _run_captured_pty_on_conn(
|
||||
self,
|
||||
conn: asyncssh.SSHClientConnection,
|
||||
cmd: list[str],
|
||||
wd: str | None,
|
||||
verbose: bool,
|
||||
cmd_input: str | None,
|
||||
env: dict[str, str] | None,
|
||||
log_prefix: str,
|
||||
) -> Result:
|
||||
command = self._build_remote_command(cmd, wd)
|
||||
|
||||
stdout_parts: list[bytes] = []
|
||||
stdout_log_enc = sys.stdout.encoding or "utf-8"
|
||||
|
||||
proc = await conn.create_process(
|
||||
command=command,
|
||||
env=env,
|
||||
stdin=asyncssh.PIPE if cmd_input is not None else asyncssh.DEVNULL,
|
||||
stdout=asyncssh.PIPE,
|
||||
stderr=asyncssh.STDOUT,
|
||||
encoding=None,
|
||||
request_pty="force",
|
||||
term_type=self.term_type,
|
||||
)
|
||||
|
||||
task = asyncio.create_task(
|
||||
self._read_stream(
|
||||
proc.stdout,
|
||||
NOTICE,
|
||||
stdout_parts,
|
||||
verbose=verbose,
|
||||
log_prefix=log_prefix,
|
||||
log_enc=stdout_log_enc,
|
||||
)
|
||||
)
|
||||
|
||||
if cmd_input is not None and proc.stdin is not None:
|
||||
proc.stdin.write(cmd_input.encode(sys.stdout.encoding or "utf-8"))
|
||||
await proc.stdin.drain()
|
||||
proc.stdin.write_eof()
|
||||
|
||||
completed = await proc.wait(check=False)
|
||||
await task
|
||||
|
||||
exit_code = completed.exit_status
|
||||
if exit_code is None:
|
||||
exit_code = completed.returncode if completed.returncode is not None else -1
|
||||
|
||||
stdout = b"".join(stdout_parts) if stdout_parts else None
|
||||
return Result(stdout, None, exit_code)
|
||||
|
||||
async def _run_on_conn(
|
||||
self,
|
||||
conn: asyncssh.SSHClientConnection,
|
||||
|
|
@ -121,6 +359,26 @@ class AsyncSSH(Base):
|
|||
interactive: bool,
|
||||
log_prefix: str,
|
||||
) -> Result:
|
||||
if interactive:
|
||||
if self._has_local_tty():
|
||||
return await self._run_interactive_on_conn(
|
||||
conn,
|
||||
cmd,
|
||||
wd,
|
||||
cmd_input,
|
||||
env,
|
||||
)
|
||||
|
||||
return await self._run_captured_pty_on_conn(
|
||||
conn,
|
||||
cmd,
|
||||
wd,
|
||||
verbose,
|
||||
cmd_input,
|
||||
env,
|
||||
log_prefix,
|
||||
)
|
||||
|
||||
command = self._build_remote_command(cmd, wd)
|
||||
|
||||
stdout_parts: list[bytes] = []
|
||||
|
|
@ -130,17 +388,15 @@ class AsyncSSH(Base):
|
|||
stderr_log_enc = sys.stderr.encoding or "utf-8"
|
||||
|
||||
stdin_mode = asyncssh.PIPE if cmd_input is not None else asyncssh.DEVNULL
|
||||
stderr_mode = asyncssh.STDOUT if interactive else asyncssh.PIPE
|
||||
|
||||
proc = await conn.create_process(
|
||||
command=command,
|
||||
env=env,
|
||||
stdin=stdin_mode,
|
||||
stdout=asyncssh.PIPE,
|
||||
stderr=stderr_mode,
|
||||
stderr=asyncssh.PIPE,
|
||||
encoding=None,
|
||||
request_pty="force" if interactive else False,
|
||||
term_type=self.term_type if interactive else None,
|
||||
request_pty=False,
|
||||
)
|
||||
|
||||
tasks = [
|
||||
|
|
@ -153,22 +409,18 @@ class AsyncSSH(Base):
|
|||
log_prefix=log_prefix,
|
||||
log_enc=stdout_log_enc,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
if not interactive:
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
self._read_stream(
|
||||
proc.stderr,
|
||||
ERR,
|
||||
stderr_parts,
|
||||
verbose=verbose,
|
||||
log_prefix=log_prefix,
|
||||
log_enc=stderr_log_enc,
|
||||
)
|
||||
),
|
||||
asyncio.create_task(
|
||||
self._read_stream(
|
||||
proc.stderr,
|
||||
ERR,
|
||||
stderr_parts,
|
||||
verbose=verbose,
|
||||
log_prefix=log_prefix,
|
||||
log_enc=stderr_log_enc,
|
||||
)
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
if cmd_input is not None and proc.stdin is not None:
|
||||
proc.stdin.write(cmd_input.encode(sys.stdout.encoding or "utf-8"))
|
||||
|
|
@ -179,7 +431,7 @@ class AsyncSSH(Base):
|
|||
await asyncio.gather(*tasks)
|
||||
|
||||
stdout = b"".join(stdout_parts) if stdout_parts else None
|
||||
stderr = None if interactive else (b"".join(stderr_parts) if stderr_parts else None)
|
||||
stderr = b"".join(stderr_parts) if stderr_parts else None
|
||||
|
||||
exit_code = completed.exit_status
|
||||
if exit_code is None:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue