mirror of
ssh://git.janware.com/janware/proj/jw-pkg
synced 2026-04-25 17:45:55 +02:00
lib.ec.ssh.AsyncSSH: Add interactivity
Request a remote PTY from AsyncSSH, and wire the local terminal's stdin up with it if interactive == True. This gives a real interactive session if local stdin belongs to a terminal. Also, thanks to AsyncSSH understanding that, forward terminal size changes to the remote end. Signed-off-by: Jan Lindemann <jan@janware.com>
This commit is contained in:
parent
737cbc3e24
commit
b21d2d1c21
1 changed files with 274 additions and 22 deletions
|
|
@ -1,6 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
import os, sys, shlex, asyncio, asyncssh
|
import os, sys, shlex, asyncio, asyncssh, shutil, signal
|
||||||
|
|
||||||
from ...log import *
|
from ...log import *
|
||||||
from ...ExecContext import Result
|
from ...ExecContext import Result
|
||||||
|
|
@ -22,7 +22,7 @@ class AsyncSSH(Base):
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
uri,
|
uri,
|
||||||
caps=self.Caps.LogOutput | self.Caps.Wd,
|
caps=self.Caps.LogOutput | self.Caps.Wd | self.Caps.Interactive,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -82,6 +82,31 @@ class AsyncSSH(Base):
|
||||||
|
|
||||||
return tuple(args), kwargs
|
return tuple(args), kwargs
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _has_local_tty() -> bool:
|
||||||
|
try:
|
||||||
|
return sys.stdin.isatty() and sys.stdout.isatty()
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_local_term_size() -> tuple[int, int, int, int]:
|
||||||
|
cols, rows = shutil.get_terminal_size(fallback=(80, 24))
|
||||||
|
xpixel = ypixel = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
import fcntl, termios, struct
|
||||||
|
|
||||||
|
packed = fcntl.ioctl(sys.stdout.fileno(), termios.TIOCGWINSZ, b"\0" * 8)
|
||||||
|
rows2, cols2, xpixel, ypixel = struct.unpack("HHHH", packed)
|
||||||
|
|
||||||
|
if cols2 > 0 and rows2 > 0:
|
||||||
|
cols, rows = cols2, rows2
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return (cols, rows, xpixel, ypixel)
|
||||||
|
|
||||||
async def _read_stream(
|
async def _read_stream(
|
||||||
self,
|
self,
|
||||||
stream,
|
stream,
|
||||||
|
|
@ -110,6 +135,219 @@ class AsyncSSH(Base):
|
||||||
if verbose and buf:
|
if verbose and buf:
|
||||||
log(prio, log_prefix, buf.decode(log_enc, errors="replace"))
|
log(prio, log_prefix, buf.decode(log_enc, errors="replace"))
|
||||||
|
|
||||||
|
async def _run_interactive_on_conn(
|
||||||
|
self,
|
||||||
|
conn: asyncssh.SSHClientConnection,
|
||||||
|
cmd: list[str],
|
||||||
|
wd: str | None,
|
||||||
|
cmd_input: str | None,
|
||||||
|
env: dict[str, str] | None,
|
||||||
|
) -> Result:
|
||||||
|
command = self._build_remote_command(cmd, wd)
|
||||||
|
stdout_parts: list[bytes] = []
|
||||||
|
|
||||||
|
proc = await conn.create_process(
|
||||||
|
command=command,
|
||||||
|
env=env,
|
||||||
|
stdin=asyncssh.PIPE,
|
||||||
|
stdout=asyncssh.PIPE,
|
||||||
|
stderr=asyncssh.STDOUT,
|
||||||
|
encoding=None,
|
||||||
|
request_pty="force",
|
||||||
|
term_type=self.term_type,
|
||||||
|
term_size=self._get_local_term_size(),
|
||||||
|
)
|
||||||
|
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
stdin_fd = sys.stdin.fileno()
|
||||||
|
|
||||||
|
stdin_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
|
||||||
|
old_tty_state = None
|
||||||
|
old_winch_handler = None
|
||||||
|
stdin_reader_installed = False
|
||||||
|
|
||||||
|
def _write_local(data: bytes) -> None:
|
||||||
|
try:
|
||||||
|
sys.stdout.buffer.write(data)
|
||||||
|
sys.stdout.buffer.flush()
|
||||||
|
except AttributeError:
|
||||||
|
os.write(sys.stdout.fileno(), data)
|
||||||
|
|
||||||
|
def _on_stdin_ready() -> None:
|
||||||
|
try:
|
||||||
|
data = os.read(stdin_fd, 4096)
|
||||||
|
except OSError:
|
||||||
|
data = b""
|
||||||
|
|
||||||
|
if data:
|
||||||
|
stdin_queue.put_nowait(data)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
loop.remove_reader(stdin_fd)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
stdin_queue.put_nowait(None)
|
||||||
|
|
||||||
|
async def _pump_stdin() -> None:
|
||||||
|
if cmd_input is not None and proc.stdin is not None:
|
||||||
|
proc.stdin.write(cmd_input.encode(sys.stdout.encoding or "utf-8"))
|
||||||
|
await proc.stdin.drain()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
data = await stdin_queue.get()
|
||||||
|
|
||||||
|
if data is None:
|
||||||
|
if proc.stdin is not None:
|
||||||
|
try:
|
||||||
|
proc.stdin.write_eof()
|
||||||
|
except (BrokenPipeError, OSError):
|
||||||
|
pass
|
||||||
|
return
|
||||||
|
|
||||||
|
if proc.stdin is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
proc.stdin.write(data)
|
||||||
|
await proc.stdin.drain()
|
||||||
|
|
||||||
|
async def _pump_stdout() -> None:
|
||||||
|
while True:
|
||||||
|
chunk = await proc.stdout.read(4096)
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
|
||||||
|
stdout_parts.append(chunk)
|
||||||
|
_write_local(chunk)
|
||||||
|
|
||||||
|
def _on_winch(*_args) -> None:
|
||||||
|
try:
|
||||||
|
proc.change_terminal_size(*self._get_local_term_size())
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
sys.stdout.flush()
|
||||||
|
sys.stderr.flush()
|
||||||
|
|
||||||
|
try:
|
||||||
|
import termios, tty
|
||||||
|
|
||||||
|
old_tty_state = termios.tcgetattr(stdin_fd)
|
||||||
|
tty.setraw(stdin_fd)
|
||||||
|
except Exception:
|
||||||
|
old_tty_state = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
loop.add_reader(stdin_fd, _on_stdin_ready)
|
||||||
|
stdin_reader_installed = True
|
||||||
|
except (NotImplementedError, RuntimeError):
|
||||||
|
stdin_queue.put_nowait(None)
|
||||||
|
|
||||||
|
if hasattr(signal, "SIGWINCH"):
|
||||||
|
try:
|
||||||
|
old_winch_handler = signal.getsignal(signal.SIGWINCH)
|
||||||
|
signal.signal(signal.SIGWINCH, _on_winch)
|
||||||
|
except Exception:
|
||||||
|
old_winch_handler = None
|
||||||
|
|
||||||
|
stdin_task = asyncio.create_task(_pump_stdin())
|
||||||
|
stdout_task = asyncio.create_task(_pump_stdout())
|
||||||
|
|
||||||
|
completed = await proc.wait(check=False)
|
||||||
|
await stdout_task
|
||||||
|
|
||||||
|
if not stdin_task.done():
|
||||||
|
stdin_task.cancel()
|
||||||
|
try:
|
||||||
|
await stdin_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
exit_code = completed.exit_status
|
||||||
|
if exit_code is None:
|
||||||
|
exit_code = completed.returncode if completed.returncode is not None else -1
|
||||||
|
|
||||||
|
stdout = b"".join(stdout_parts) if stdout_parts else None
|
||||||
|
return Result(stdout, None, exit_code)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if stdin_reader_installed:
|
||||||
|
try:
|
||||||
|
loop.remove_reader(stdin_fd)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if old_winch_handler is not None and hasattr(signal, "SIGWINCH"):
|
||||||
|
try:
|
||||||
|
signal.signal(signal.SIGWINCH, old_winch_handler)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if old_tty_state is not None:
|
||||||
|
try:
|
||||||
|
import termios
|
||||||
|
termios.tcsetattr(stdin_fd, termios.TCSADRAIN, old_tty_state)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
sys.stdout.flush()
|
||||||
|
sys.stderr.flush()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _run_captured_pty_on_conn(
|
||||||
|
self,
|
||||||
|
conn: asyncssh.SSHClientConnection,
|
||||||
|
cmd: list[str],
|
||||||
|
wd: str | None,
|
||||||
|
verbose: bool,
|
||||||
|
cmd_input: str | None,
|
||||||
|
env: dict[str, str] | None,
|
||||||
|
log_prefix: str,
|
||||||
|
) -> Result:
|
||||||
|
command = self._build_remote_command(cmd, wd)
|
||||||
|
|
||||||
|
stdout_parts: list[bytes] = []
|
||||||
|
stdout_log_enc = sys.stdout.encoding or "utf-8"
|
||||||
|
|
||||||
|
proc = await conn.create_process(
|
||||||
|
command=command,
|
||||||
|
env=env,
|
||||||
|
stdin=asyncssh.PIPE if cmd_input is not None else asyncssh.DEVNULL,
|
||||||
|
stdout=asyncssh.PIPE,
|
||||||
|
stderr=asyncssh.STDOUT,
|
||||||
|
encoding=None,
|
||||||
|
request_pty="force",
|
||||||
|
term_type=self.term_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
task = asyncio.create_task(
|
||||||
|
self._read_stream(
|
||||||
|
proc.stdout,
|
||||||
|
NOTICE,
|
||||||
|
stdout_parts,
|
||||||
|
verbose=verbose,
|
||||||
|
log_prefix=log_prefix,
|
||||||
|
log_enc=stdout_log_enc,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if cmd_input is not None and proc.stdin is not None:
|
||||||
|
proc.stdin.write(cmd_input.encode(sys.stdout.encoding or "utf-8"))
|
||||||
|
await proc.stdin.drain()
|
||||||
|
proc.stdin.write_eof()
|
||||||
|
|
||||||
|
completed = await proc.wait(check=False)
|
||||||
|
await task
|
||||||
|
|
||||||
|
exit_code = completed.exit_status
|
||||||
|
if exit_code is None:
|
||||||
|
exit_code = completed.returncode if completed.returncode is not None else -1
|
||||||
|
|
||||||
|
stdout = b"".join(stdout_parts) if stdout_parts else None
|
||||||
|
return Result(stdout, None, exit_code)
|
||||||
|
|
||||||
async def _run_on_conn(
|
async def _run_on_conn(
|
||||||
self,
|
self,
|
||||||
conn: asyncssh.SSHClientConnection,
|
conn: asyncssh.SSHClientConnection,
|
||||||
|
|
@ -121,6 +359,26 @@ class AsyncSSH(Base):
|
||||||
interactive: bool,
|
interactive: bool,
|
||||||
log_prefix: str,
|
log_prefix: str,
|
||||||
) -> Result:
|
) -> Result:
|
||||||
|
if interactive:
|
||||||
|
if self._has_local_tty():
|
||||||
|
return await self._run_interactive_on_conn(
|
||||||
|
conn,
|
||||||
|
cmd,
|
||||||
|
wd,
|
||||||
|
cmd_input,
|
||||||
|
env,
|
||||||
|
)
|
||||||
|
|
||||||
|
return await self._run_captured_pty_on_conn(
|
||||||
|
conn,
|
||||||
|
cmd,
|
||||||
|
wd,
|
||||||
|
verbose,
|
||||||
|
cmd_input,
|
||||||
|
env,
|
||||||
|
log_prefix,
|
||||||
|
)
|
||||||
|
|
||||||
command = self._build_remote_command(cmd, wd)
|
command = self._build_remote_command(cmd, wd)
|
||||||
|
|
||||||
stdout_parts: list[bytes] = []
|
stdout_parts: list[bytes] = []
|
||||||
|
|
@ -130,17 +388,15 @@ class AsyncSSH(Base):
|
||||||
stderr_log_enc = sys.stderr.encoding or "utf-8"
|
stderr_log_enc = sys.stderr.encoding or "utf-8"
|
||||||
|
|
||||||
stdin_mode = asyncssh.PIPE if cmd_input is not None else asyncssh.DEVNULL
|
stdin_mode = asyncssh.PIPE if cmd_input is not None else asyncssh.DEVNULL
|
||||||
stderr_mode = asyncssh.STDOUT if interactive else asyncssh.PIPE
|
|
||||||
|
|
||||||
proc = await conn.create_process(
|
proc = await conn.create_process(
|
||||||
command=command,
|
command=command,
|
||||||
env=env,
|
env=env,
|
||||||
stdin=stdin_mode,
|
stdin=stdin_mode,
|
||||||
stdout=asyncssh.PIPE,
|
stdout=asyncssh.PIPE,
|
||||||
stderr=stderr_mode,
|
stderr=asyncssh.PIPE,
|
||||||
encoding=None,
|
encoding=None,
|
||||||
request_pty="force" if interactive else False,
|
request_pty=False,
|
||||||
term_type=self.term_type if interactive else None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
tasks = [
|
tasks = [
|
||||||
|
|
@ -153,11 +409,7 @@ class AsyncSSH(Base):
|
||||||
log_prefix=log_prefix,
|
log_prefix=log_prefix,
|
||||||
log_enc=stdout_log_enc,
|
log_enc=stdout_log_enc,
|
||||||
)
|
)
|
||||||
)
|
),
|
||||||
]
|
|
||||||
|
|
||||||
if not interactive:
|
|
||||||
tasks.append(
|
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
self._read_stream(
|
self._read_stream(
|
||||||
proc.stderr,
|
proc.stderr,
|
||||||
|
|
@ -167,8 +419,8 @@ class AsyncSSH(Base):
|
||||||
log_prefix=log_prefix,
|
log_prefix=log_prefix,
|
||||||
log_enc=stderr_log_enc,
|
log_enc=stderr_log_enc,
|
||||||
)
|
)
|
||||||
)
|
),
|
||||||
)
|
]
|
||||||
|
|
||||||
if cmd_input is not None and proc.stdin is not None:
|
if cmd_input is not None and proc.stdin is not None:
|
||||||
proc.stdin.write(cmd_input.encode(sys.stdout.encoding or "utf-8"))
|
proc.stdin.write(cmd_input.encode(sys.stdout.encoding or "utf-8"))
|
||||||
|
|
@ -179,7 +431,7 @@ class AsyncSSH(Base):
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
stdout = b"".join(stdout_parts) if stdout_parts else None
|
stdout = b"".join(stdout_parts) if stdout_parts else None
|
||||||
stderr = None if interactive else (b"".join(stderr_parts) if stderr_parts else None)
|
stderr = b"".join(stderr_parts) if stderr_parts else None
|
||||||
|
|
||||||
exit_code = completed.exit_status
|
exit_code = completed.exit_status
|
||||||
if exit_code is None:
|
if exit_code is None:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue