From 737cbc3e24bb9cf3a29ff9a5b00946fd81aac0b7 Mon Sep 17 00:00:00 2001 From: Jan Lindemann Date: Sat, 21 Mar 2026 04:29:58 +0100 Subject: [PATCH] lib.ec.ssh.AsyncSSH: Add class Add a SSHClient implementation using AsyncSSH. This is the first and currently only class derived from SSHClient which implements SSHClient.Cap.LogOutput, designed to consume and log command output as it streams in. It felt like the lower hanging fruit not to do that with Paramiko: Paramiko doesn't provide a native async API, so it would need to spawn additional worker threads. I think. Signed-off-by: Jan Lindemann --- src/python/jw/pkg/lib/ec/SSHClient.py | 6 +- src/python/jw/pkg/lib/ec/ssh/AsyncSSH.py | 240 +++++++++++++++++++++++ 2 files changed, 244 insertions(+), 2 deletions(-) create mode 100644 src/python/jw/pkg/lib/ec/ssh/AsyncSSH.py diff --git a/src/python/jw/pkg/lib/ec/SSHClient.py b/src/python/jw/pkg/lib/ec/SSHClient.py index d944d0be..d6aaffd3 100644 --- a/src/python/jw/pkg/lib/ec/SSHClient.py +++ b/src/python/jw/pkg/lib/ec/SSHClient.py @@ -129,9 +129,11 @@ class SSHClient(ExecContext): def ssh_client(*args, **kwargs) -> SSHClient: # export from importlib import import_module errors: list[str] = [] - for name in ['Paramiko', 'Exec']: + for name in ['AsyncSSH', 'Paramiko', 'Exec']: try: - return getattr(import_module(f'jw.pkg.lib.ec.ssh.{name}'), name)(*args, **kwargs) + ret = getattr(import_module(f'jw.pkg.lib.ec.ssh.{name}'), name)(*args, **kwargs) + log(INFO, f'Using SSH-client "{name}"') + return ret except Exception as e: msg = f'Can\'t instantiate SSH client class {name} ({str(e)})' errors.append(msg) diff --git a/src/python/jw/pkg/lib/ec/ssh/AsyncSSH.py b/src/python/jw/pkg/lib/ec/ssh/AsyncSSH.py new file mode 100644 index 00000000..43fc0444 --- /dev/null +++ b/src/python/jw/pkg/lib/ec/ssh/AsyncSSH.py @@ -0,0 +1,240 @@ +# -*- coding: utf-8 -*- + +import os, sys, shlex, asyncio, asyncssh + +from ...log import * +from ...ExecContext import Result +from ..SSHClient import SSHClient as Base + +_USE_DEFAULT_KNOWN_HOSTS = object() + +class AsyncSSH(Base): + def __init__( + self, + uri: str, + *, + client_keys: list[str] | None = None, + known_hosts=_USE_DEFAULT_KNOWN_HOSTS, + term_type: str | None = None, + connect_timeout: float | None = 30.0, + **kwargs, + ) -> None: + + super().__init__( + uri, + caps=self.Caps.LogOutput | self.Caps.Wd, + **kwargs + ) + + self.client_keys = client_keys + self.known_hosts = known_hosts + self.term_type = term_type or os.environ.get("TERM", "xterm") + self.connect_timeout = connect_timeout + + def _connect_kwargs(self) -> dict: + kwargs: dict = { + "host": self.hostname, + "port": self.port, + "username": self.username, + "password": self.password, + "client_keys": self.client_keys, + "connect_timeout": self.connect_timeout, + } + + if self.known_hosts is not _USE_DEFAULT_KNOWN_HOSTS: + kwargs["known_hosts"] = self.known_hosts + + return {k: v for k, v in kwargs.items() if v is not None} + + @staticmethod + def _build_remote_command(cmd: list[str], wd: str | None) -> str: + + if not cmd: + raise ValueError("cmd must not be empty") + + inner = f"exec {shlex.join(cmd)}" + if wd is not None: + inner = f"cd {shlex.quote(wd)} && {inner}" + + return f"/bin/sh -lc {shlex.quote(inner)}" + + @staticmethod + def _merge_env_into_forwarded_args( + args: tuple, + kwargs: dict, + mod_env: dict[str, str], + ) -> tuple[tuple, dict]: + args = list(args) + kwargs = dict(kwargs) + + if "env" in kwargs: + base_env = kwargs["env"] + merged_env = dict(base_env or {}) + merged_env.update(mod_env) + kwargs["env"] = merged_env or None + elif len(args) >= 4: + base_env = args[3] + merged_env = dict(base_env or {}) + merged_env.update(mod_env) + args[3] = merged_env or None + else: + kwargs["env"] = dict(mod_env) if mod_env else None + + return tuple(args), kwargs + + async def _read_stream( + self, + stream, + prio, + collector: list[bytes], + *, + verbose: bool, + log_prefix: str, + log_enc: str, + ) -> None: + buf = b"" + + while True: + chunk = await stream.read(4096) + if not chunk: + break + + collector.append(chunk) + + if verbose: + buf += chunk + while b"\n" in buf: + line, buf = buf.split(b"\n", 1) + log(prio, log_prefix, line.decode(log_enc, errors="replace")) + + if verbose and buf: + log(prio, log_prefix, buf.decode(log_enc, errors="replace")) + + async def _run_on_conn( + self, + conn: asyncssh.SSHClientConnection, + cmd: list[str], + wd: str | None, + verbose: bool, + cmd_input: str | None, + env: dict[str, str] | None, + interactive: bool, + log_prefix: str, + ) -> Result: + command = self._build_remote_command(cmd, wd) + + stdout_parts: list[bytes] = [] + stderr_parts: list[bytes] = [] + + stdout_log_enc = sys.stdout.encoding or "utf-8" + stderr_log_enc = sys.stderr.encoding or "utf-8" + + stdin_mode = asyncssh.PIPE if cmd_input is not None else asyncssh.DEVNULL + stderr_mode = asyncssh.STDOUT if interactive else asyncssh.PIPE + + proc = await conn.create_process( + command=command, + env=env, + stdin=stdin_mode, + stdout=asyncssh.PIPE, + stderr=stderr_mode, + encoding=None, + request_pty="force" if interactive else False, + term_type=self.term_type if interactive else None, + ) + + tasks = [ + asyncio.create_task( + self._read_stream( + proc.stdout, + NOTICE, + stdout_parts, + verbose=verbose, + log_prefix=log_prefix, + log_enc=stdout_log_enc, + ) + ) + ] + + if not interactive: + tasks.append( + asyncio.create_task( + self._read_stream( + proc.stderr, + ERR, + stderr_parts, + verbose=verbose, + log_prefix=log_prefix, + log_enc=stderr_log_enc, + ) + ) + ) + + if cmd_input is not None and proc.stdin is not None: + proc.stdin.write(cmd_input.encode(sys.stdout.encoding or "utf-8")) + await proc.stdin.drain() + proc.stdin.write_eof() + + completed = await proc.wait(check=False) + await asyncio.gather(*tasks) + + stdout = b"".join(stdout_parts) if stdout_parts else None + stderr = None if interactive else (b"".join(stderr_parts) if stderr_parts else None) + + exit_code = completed.exit_status + if exit_code is None: + exit_code = completed.returncode if completed.returncode is not None else -1 + + return Result(stdout, stderr, exit_code) + + async def _run_ssh( + self, + cmd: list[str], + wd: str | None, + verbose: bool, + cmd_input: str | None, + env: dict[str, str] | None, + interactive: bool, + log_prefix: str, + ) -> Result: + async with asyncssh.connect(**self._connect_kwargs()) as conn: + return await self._run_on_conn( + conn, + cmd, + wd, + verbose, + cmd_input, + env, + interactive, + log_prefix, + ) + + async def _sudo( + self, + cmd: list[str], + mod_env: dict[str, str], + opts: list[str], + *args, + **kwargs, + ) -> Result: + args, kwargs = self._merge_env_into_forwarded_args(args, kwargs, mod_env) + + async with asyncssh.connect(**self._connect_kwargs()) as conn: + uid_result = await conn.run("id -u", check=False) + is_root = ( + uid_result.exit_status == 0 + and isinstance(uid_result.stdout, str) + and uid_result.stdout.strip() == "0" + ) + + cmdline: list[str] = [] + + if not is_root: + cmdline.append("/usr/bin/sudo") + if mod_env: + cmdline.append("--preserve-env=" + ",".join(mod_env.keys())) + cmdline.extend(opts) + + cmdline.extend(cmd) + + return await self._run_on_conn(conn, cmdline, *args, **kwargs)