diff --git a/src/python/jw/pkg/lib/ec/SSHClient.py b/src/python/jw/pkg/lib/ec/SSHClient.py index d944d0be..d6aaffd3 100644 --- a/src/python/jw/pkg/lib/ec/SSHClient.py +++ b/src/python/jw/pkg/lib/ec/SSHClient.py @@ -129,9 +129,11 @@ class SSHClient(ExecContext): def ssh_client(*args, **kwargs) -> SSHClient: # export from importlib import import_module errors: list[str] = [] - for name in ['Paramiko', 'Exec']: + for name in ['AsyncSSH', 'Paramiko', 'Exec']: try: - return getattr(import_module(f'jw.pkg.lib.ec.ssh.{name}'), name)(*args, **kwargs) + ret = getattr(import_module(f'jw.pkg.lib.ec.ssh.{name}'), name)(*args, **kwargs) + log(INFO, f'Using SSH-client "{name}"') + return ret except Exception as e: msg = f'Can\'t instantiate SSH client class {name} ({str(e)})' errors.append(msg) diff --git a/src/python/jw/pkg/lib/ec/ssh/AsyncSSH.py b/src/python/jw/pkg/lib/ec/ssh/AsyncSSH.py new file mode 100644 index 00000000..43fc0444 --- /dev/null +++ b/src/python/jw/pkg/lib/ec/ssh/AsyncSSH.py @@ -0,0 +1,240 @@ +# -*- coding: utf-8 -*- + +import os, sys, shlex, asyncio, asyncssh + +from ...log import * +from ...ExecContext import Result +from ..SSHClient import SSHClient as Base + +_USE_DEFAULT_KNOWN_HOSTS = object() + +class AsyncSSH(Base): + def __init__( + self, + uri: str, + *, + client_keys: list[str] | None = None, + known_hosts=_USE_DEFAULT_KNOWN_HOSTS, + term_type: str | None = None, + connect_timeout: float | None = 30.0, + **kwargs, + ) -> None: + + super().__init__( + uri, + caps=self.Caps.LogOutput | self.Caps.Wd, + **kwargs + ) + + self.client_keys = client_keys + self.known_hosts = known_hosts + self.term_type = term_type or os.environ.get("TERM", "xterm") + self.connect_timeout = connect_timeout + + def _connect_kwargs(self) -> dict: + kwargs: dict = { + "host": self.hostname, + "port": self.port, + "username": self.username, + "password": self.password, + "client_keys": self.client_keys, + "connect_timeout": self.connect_timeout, + } + + if self.known_hosts is not _USE_DEFAULT_KNOWN_HOSTS: + kwargs["known_hosts"] = self.known_hosts + + return {k: v for k, v in kwargs.items() if v is not None} + + @staticmethod + def _build_remote_command(cmd: list[str], wd: str | None) -> str: + + if not cmd: + raise ValueError("cmd must not be empty") + + inner = f"exec {shlex.join(cmd)}" + if wd is not None: + inner = f"cd {shlex.quote(wd)} && {inner}" + + return f"/bin/sh -lc {shlex.quote(inner)}" + + @staticmethod + def _merge_env_into_forwarded_args( + args: tuple, + kwargs: dict, + mod_env: dict[str, str], + ) -> tuple[tuple, dict]: + args = list(args) + kwargs = dict(kwargs) + + if "env" in kwargs: + base_env = kwargs["env"] + merged_env = dict(base_env or {}) + merged_env.update(mod_env) + kwargs["env"] = merged_env or None + elif len(args) >= 4: + base_env = args[3] + merged_env = dict(base_env or {}) + merged_env.update(mod_env) + args[3] = merged_env or None + else: + kwargs["env"] = dict(mod_env) if mod_env else None + + return tuple(args), kwargs + + async def _read_stream( + self, + stream, + prio, + collector: list[bytes], + *, + verbose: bool, + log_prefix: str, + log_enc: str, + ) -> None: + buf = b"" + + while True: + chunk = await stream.read(4096) + if not chunk: + break + + collector.append(chunk) + + if verbose: + buf += chunk + while b"\n" in buf: + line, buf = buf.split(b"\n", 1) + log(prio, log_prefix, line.decode(log_enc, errors="replace")) + + if verbose and buf: + log(prio, log_prefix, buf.decode(log_enc, errors="replace")) + + async def _run_on_conn( + self, + conn: asyncssh.SSHClientConnection, + cmd: list[str], + wd: str | None, + verbose: bool, + cmd_input: str | None, + env: dict[str, str] | None, + interactive: bool, + log_prefix: str, + ) -> Result: + command = self._build_remote_command(cmd, wd) + + stdout_parts: list[bytes] = [] + stderr_parts: list[bytes] = [] + + stdout_log_enc = sys.stdout.encoding or "utf-8" + stderr_log_enc = sys.stderr.encoding or "utf-8" + + stdin_mode = asyncssh.PIPE if cmd_input is not None else asyncssh.DEVNULL + stderr_mode = asyncssh.STDOUT if interactive else asyncssh.PIPE + + proc = await conn.create_process( + command=command, + env=env, + stdin=stdin_mode, + stdout=asyncssh.PIPE, + stderr=stderr_mode, + encoding=None, + request_pty="force" if interactive else False, + term_type=self.term_type if interactive else None, + ) + + tasks = [ + asyncio.create_task( + self._read_stream( + proc.stdout, + NOTICE, + stdout_parts, + verbose=verbose, + log_prefix=log_prefix, + log_enc=stdout_log_enc, + ) + ) + ] + + if not interactive: + tasks.append( + asyncio.create_task( + self._read_stream( + proc.stderr, + ERR, + stderr_parts, + verbose=verbose, + log_prefix=log_prefix, + log_enc=stderr_log_enc, + ) + ) + ) + + if cmd_input is not None and proc.stdin is not None: + proc.stdin.write(cmd_input.encode(sys.stdout.encoding or "utf-8")) + await proc.stdin.drain() + proc.stdin.write_eof() + + completed = await proc.wait(check=False) + await asyncio.gather(*tasks) + + stdout = b"".join(stdout_parts) if stdout_parts else None + stderr = None if interactive else (b"".join(stderr_parts) if stderr_parts else None) + + exit_code = completed.exit_status + if exit_code is None: + exit_code = completed.returncode if completed.returncode is not None else -1 + + return Result(stdout, stderr, exit_code) + + async def _run_ssh( + self, + cmd: list[str], + wd: str | None, + verbose: bool, + cmd_input: str | None, + env: dict[str, str] | None, + interactive: bool, + log_prefix: str, + ) -> Result: + async with asyncssh.connect(**self._connect_kwargs()) as conn: + return await self._run_on_conn( + conn, + cmd, + wd, + verbose, + cmd_input, + env, + interactive, + log_prefix, + ) + + async def _sudo( + self, + cmd: list[str], + mod_env: dict[str, str], + opts: list[str], + *args, + **kwargs, + ) -> Result: + args, kwargs = self._merge_env_into_forwarded_args(args, kwargs, mod_env) + + async with asyncssh.connect(**self._connect_kwargs()) as conn: + uid_result = await conn.run("id -u", check=False) + is_root = ( + uid_result.exit_status == 0 + and isinstance(uid_result.stdout, str) + and uid_result.stdout.strip() == "0" + ) + + cmdline: list[str] = [] + + if not is_root: + cmdline.append("/usr/bin/sudo") + if mod_env: + cmdline.append("--preserve-env=" + ",".join(mod_env.keys())) + cmdline.extend(opts) + + cmdline.extend(cmd) + + return await self._run_on_conn(conn, cmdline, *args, **kwargs)