lib.ec.ssh.AsyncSSH: Add class

Add a SSHClient implementation using AsyncSSH. This is the first and
currently only class derived from SSHClient which implements
SSHClient.Cap.LogOutput, designed to consume and log command output
as it streams in. It felt like the lower hanging fruit not to do that
with Paramiko: Paramiko doesn't provide a native async API, so it
would need to spawn additional worker threads. I think.

Signed-off-by: Jan Lindemann <jan@janware.com>
This commit is contained in:
Jan Lindemann 2026-03-21 04:29:58 +01:00
commit 737cbc3e24
2 changed files with 244 additions and 2 deletions

View file

@ -129,9 +129,11 @@ class SSHClient(ExecContext):
def ssh_client(*args, **kwargs) -> SSHClient: # export def ssh_client(*args, **kwargs) -> SSHClient: # export
from importlib import import_module from importlib import import_module
errors: list[str] = [] errors: list[str] = []
for name in ['Paramiko', 'Exec']: for name in ['AsyncSSH', 'Paramiko', 'Exec']:
try: try:
return getattr(import_module(f'jw.pkg.lib.ec.ssh.{name}'), name)(*args, **kwargs) ret = getattr(import_module(f'jw.pkg.lib.ec.ssh.{name}'), name)(*args, **kwargs)
log(INFO, f'Using SSH-client "{name}"')
return ret
except Exception as e: except Exception as e:
msg = f'Can\'t instantiate SSH client class {name} ({str(e)})' msg = f'Can\'t instantiate SSH client class {name} ({str(e)})'
errors.append(msg) errors.append(msg)

View file

@ -0,0 +1,240 @@
# -*- coding: utf-8 -*-
import os, sys, shlex, asyncio, asyncssh
from ...log import *
from ...ExecContext import Result
from ..SSHClient import SSHClient as Base
_USE_DEFAULT_KNOWN_HOSTS = object()
class AsyncSSH(Base):
def __init__(
self,
uri: str,
*,
client_keys: list[str] | None = None,
known_hosts=_USE_DEFAULT_KNOWN_HOSTS,
term_type: str | None = None,
connect_timeout: float | None = 30.0,
**kwargs,
) -> None:
super().__init__(
uri,
caps=self.Caps.LogOutput | self.Caps.Wd,
**kwargs
)
self.client_keys = client_keys
self.known_hosts = known_hosts
self.term_type = term_type or os.environ.get("TERM", "xterm")
self.connect_timeout = connect_timeout
def _connect_kwargs(self) -> dict:
kwargs: dict = {
"host": self.hostname,
"port": self.port,
"username": self.username,
"password": self.password,
"client_keys": self.client_keys,
"connect_timeout": self.connect_timeout,
}
if self.known_hosts is not _USE_DEFAULT_KNOWN_HOSTS:
kwargs["known_hosts"] = self.known_hosts
return {k: v for k, v in kwargs.items() if v is not None}
@staticmethod
def _build_remote_command(cmd: list[str], wd: str | None) -> str:
if not cmd:
raise ValueError("cmd must not be empty")
inner = f"exec {shlex.join(cmd)}"
if wd is not None:
inner = f"cd {shlex.quote(wd)} && {inner}"
return f"/bin/sh -lc {shlex.quote(inner)}"
@staticmethod
def _merge_env_into_forwarded_args(
args: tuple,
kwargs: dict,
mod_env: dict[str, str],
) -> tuple[tuple, dict]:
args = list(args)
kwargs = dict(kwargs)
if "env" in kwargs:
base_env = kwargs["env"]
merged_env = dict(base_env or {})
merged_env.update(mod_env)
kwargs["env"] = merged_env or None
elif len(args) >= 4:
base_env = args[3]
merged_env = dict(base_env or {})
merged_env.update(mod_env)
args[3] = merged_env or None
else:
kwargs["env"] = dict(mod_env) if mod_env else None
return tuple(args), kwargs
async def _read_stream(
self,
stream,
prio,
collector: list[bytes],
*,
verbose: bool,
log_prefix: str,
log_enc: str,
) -> None:
buf = b""
while True:
chunk = await stream.read(4096)
if not chunk:
break
collector.append(chunk)
if verbose:
buf += chunk
while b"\n" in buf:
line, buf = buf.split(b"\n", 1)
log(prio, log_prefix, line.decode(log_enc, errors="replace"))
if verbose and buf:
log(prio, log_prefix, buf.decode(log_enc, errors="replace"))
async def _run_on_conn(
self,
conn: asyncssh.SSHClientConnection,
cmd: list[str],
wd: str | None,
verbose: bool,
cmd_input: str | None,
env: dict[str, str] | None,
interactive: bool,
log_prefix: str,
) -> Result:
command = self._build_remote_command(cmd, wd)
stdout_parts: list[bytes] = []
stderr_parts: list[bytes] = []
stdout_log_enc = sys.stdout.encoding or "utf-8"
stderr_log_enc = sys.stderr.encoding or "utf-8"
stdin_mode = asyncssh.PIPE if cmd_input is not None else asyncssh.DEVNULL
stderr_mode = asyncssh.STDOUT if interactive else asyncssh.PIPE
proc = await conn.create_process(
command=command,
env=env,
stdin=stdin_mode,
stdout=asyncssh.PIPE,
stderr=stderr_mode,
encoding=None,
request_pty="force" if interactive else False,
term_type=self.term_type if interactive else None,
)
tasks = [
asyncio.create_task(
self._read_stream(
proc.stdout,
NOTICE,
stdout_parts,
verbose=verbose,
log_prefix=log_prefix,
log_enc=stdout_log_enc,
)
)
]
if not interactive:
tasks.append(
asyncio.create_task(
self._read_stream(
proc.stderr,
ERR,
stderr_parts,
verbose=verbose,
log_prefix=log_prefix,
log_enc=stderr_log_enc,
)
)
)
if cmd_input is not None and proc.stdin is not None:
proc.stdin.write(cmd_input.encode(sys.stdout.encoding or "utf-8"))
await proc.stdin.drain()
proc.stdin.write_eof()
completed = await proc.wait(check=False)
await asyncio.gather(*tasks)
stdout = b"".join(stdout_parts) if stdout_parts else None
stderr = None if interactive else (b"".join(stderr_parts) if stderr_parts else None)
exit_code = completed.exit_status
if exit_code is None:
exit_code = completed.returncode if completed.returncode is not None else -1
return Result(stdout, stderr, exit_code)
async def _run_ssh(
self,
cmd: list[str],
wd: str | None,
verbose: bool,
cmd_input: str | None,
env: dict[str, str] | None,
interactive: bool,
log_prefix: str,
) -> Result:
async with asyncssh.connect(**self._connect_kwargs()) as conn:
return await self._run_on_conn(
conn,
cmd,
wd,
verbose,
cmd_input,
env,
interactive,
log_prefix,
)
async def _sudo(
self,
cmd: list[str],
mod_env: dict[str, str],
opts: list[str],
*args,
**kwargs,
) -> Result:
args, kwargs = self._merge_env_into_forwarded_args(args, kwargs, mod_env)
async with asyncssh.connect(**self._connect_kwargs()) as conn:
uid_result = await conn.run("id -u", check=False)
is_root = (
uid_result.exit_status == 0
and isinstance(uid_result.stdout, str)
and uid_result.stdout.strip() == "0"
)
cmdline: list[str] = []
if not is_root:
cmdline.append("/usr/bin/sudo")
if mod_env:
cmdline.append("--preserve-env=" + ",".join(mod_env.keys()))
cmdline.extend(opts)
cmdline.extend(cmd)
return await self._run_on_conn(conn, cmdline, *args, **kwargs)