mirror of
ssh://git.janware.com/janware/proj/jw-pkg
synced 2026-04-24 09:13:37 +02:00
lib.ec.ssh.AsyncSSH: Add class
Add an SSHClient implementation using AsyncSSH. This is the first and currently the only class derived from SSHClient that implements SSHClient.Cap.LogOutput, designed to consume and log command output as it streams in. Doing the same with Paramiko looked harder: Paramiko doesn't provide a native async API, so it would likely need to spawn additional worker threads. Signed-off-by: Jan Lindemann <jan@janware.com>
This commit is contained in:
parent
279b7789e2
commit
737cbc3e24
2 changed files with 244 additions and 2 deletions
|
|
@ -129,9 +129,11 @@ class SSHClient(ExecContext):
|
|||
def ssh_client(*args, **kwargs) -> SSHClient:  # export
    """Instantiate the first usable SSHClient backend.

    Backends are tried in order of preference: AsyncSSH, Paramiko, and
    finally the plain Exec fallback.  Constructor arguments are forwarded
    unchanged to the backend class.

    Raises:
        RuntimeError: if no backend could be instantiated; the message
            lists the failure reason for every backend tried.
    """
    from importlib import import_module

    errors: list[str] = []
    for name in ['AsyncSSH', 'Paramiko', 'Exec']:
        try:
            ret = getattr(import_module(f'jw.pkg.lib.ec.ssh.{name}'), name)(*args, **kwargs)
        except Exception as e:
            # Keep trying the remaining backends, but remember why this
            # one failed so the final error message is actionable.
            errors.append(f'Can\'t instantiate SSH client class {name} ({str(e)})')
        else:
            log(INFO, f'Using SSH-client "{name}"')
            return ret
    # Previously the function silently fell through and returned None
    # (violating the declared return type) and the collected errors were
    # never reported; fail loudly instead.
    raise RuntimeError('No usable SSH client backend:\n' + '\n'.join(errors))
|
||||
|
|
|
|||
240
src/python/jw/pkg/lib/ec/ssh/AsyncSSH.py
Normal file
240
src/python/jw/pkg/lib/ec/ssh/AsyncSSH.py
Normal file
|
|
@ -0,0 +1,240 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os, sys, shlex, asyncio, asyncssh
|
||||
|
||||
from ...log import *
|
||||
from ...ExecContext import Result
|
||||
from ..SSHClient import SSHClient as Base
|
||||
|
||||
# Sentinel default for the known_hosts parameter: lets the code tell
# "argument omitted" apart from an explicit known_hosts=None.
_USE_DEFAULT_KNOWN_HOSTS = object()
|
||||
|
||||
class AsyncSSH(Base):
    """SSH client backed by the asyncssh library.

    First backend to advertise the LogOutput capability: command output
    is consumed and logged while it streams in.
    """

    def __init__(
        self,
        uri: str,
        *,
        client_keys: list[str] | None = None,
        known_hosts=_USE_DEFAULT_KNOWN_HOSTS,
        term_type: str | None = None,
        connect_timeout: float | None = 30.0,
        **kwargs,
    ) -> None:
        """Create the client.

        Args:
            uri: target host URI, parsed by the base class.
            client_keys: private key files to offer, or None for defaults.
            known_hosts: forwarded to asyncssh; the sentinel default means
                "let asyncssh decide".
            term_type: terminal type for interactive sessions; defaults to
                $TERM, falling back to "xterm".
            connect_timeout: connection timeout in seconds.
            kwargs: forwarded to the base class.
        """
        super().__init__(
            uri,
            caps=self.Caps.LogOutput | self.Caps.Wd,
            **kwargs
        )

        self.client_keys = client_keys
        self.known_hosts = known_hosts
        self.connect_timeout = connect_timeout
        if term_type:
            self.term_type = term_type
        else:
            self.term_type = os.environ.get("TERM", "xterm")
|
||||
|
||||
def _connect_kwargs(self) -> dict:
    """Assemble the keyword arguments for asyncssh.connect().

    None-valued entries are stripped so asyncssh falls back to its own
    defaults — except known_hosts, where an explicit None is meaningful
    to asyncssh (it disables host-key checking) and must be kept.

    Returns:
        dict: keyword arguments suitable for asyncssh.connect().
    """
    kwargs: dict = {
        "host": self.hostname,
        "port": self.port,
        "username": self.username,
        "password": self.password,
        "client_keys": self.client_keys,
        "connect_timeout": self.connect_timeout,
    }

    # Drop unset options first ...
    kwargs = {k: v for k, v in kwargs.items() if v is not None}

    # ... and only then add known_hosts, so that an explicitly passed
    # known_hosts=None survives.  The previous order filtered the None
    # away, silently re-enabling asyncssh's default host-key lookup and
    # defeating the _USE_DEFAULT_KNOWN_HOSTS sentinel.
    if self.known_hosts is not _USE_DEFAULT_KNOWN_HOSTS:
        kwargs["known_hosts"] = self.known_hosts

    return kwargs
|
||||
|
||||
@staticmethod
def _build_remote_command(cmd: list[str], wd: str | None) -> str:
    """Render *cmd* as a single /bin/sh command line.

    The command is exec'd inside a login shell; when *wd* is given, the
    shell first changes into that directory.  All components are quoted
    for safe transport over SSH.

    Raises:
        ValueError: if *cmd* is empty.
    """
    if len(cmd) == 0:
        raise ValueError("cmd must not be empty")

    parts = [f"exec {shlex.join(cmd)}"]
    if wd is not None:
        parts.insert(0, f"cd {shlex.quote(wd)}")
    script = " && ".join(parts)

    return f"/bin/sh -lc {shlex.quote(script)}"
|
||||
|
||||
@staticmethod
def _merge_env_into_forwarded_args(
    args: tuple,
    kwargs: dict,
    mod_env: dict[str, str],
) -> tuple[tuple, dict]:
    """Overlay *mod_env* onto the env already present in args/kwargs.

    The environment may travel either as kwargs['env'] or as the fourth
    positional argument; whichever slot is present gets the overlay,
    with *mod_env* entries taking precedence.  An empty merged mapping
    is normalised to None.

    Returns:
        The (possibly updated) args tuple and kwargs dict.
    """
    out_args = list(args)
    out_kwargs = dict(kwargs)

    def overlay(base):
        # Merge mod_env over base; empty result becomes None.
        merged = dict(base or {})
        merged.update(mod_env)
        return merged or None

    if "env" in out_kwargs:
        out_kwargs["env"] = overlay(out_kwargs["env"])
    elif len(out_args) >= 4:
        out_args[3] = overlay(out_args[3])
    else:
        out_kwargs["env"] = dict(mod_env) if mod_env else None

    return tuple(out_args), out_kwargs
|
||||
|
||||
async def _read_stream(
    self,
    stream,
    prio,
    collector: list[bytes],
    *,
    verbose: bool,
    log_prefix: str,
    log_enc: str,
) -> None:
    """Drain *stream* into *collector*, optionally logging line by line.

    Raw chunks are always appended to *collector*.  When *verbose*, the
    data is additionally split on newlines and each complete line is
    logged at *prio* with *log_prefix*; a trailing partial line is
    logged once the stream ends.
    """
    pending = b""

    while chunk := await stream.read(4096):
        collector.append(chunk)
        if not verbose:
            continue
        pending += chunk
        # Emit every complete line; keep the unterminated tail buffered.
        *lines, pending = pending.split(b"\n")
        for line in lines:
            log(prio, log_prefix, line.decode(log_enc, errors="replace"))

    if verbose and pending:
        log(prio, log_prefix, pending.decode(log_enc, errors="replace"))
|
||||
|
||||
async def _run_on_conn(
    self,
    conn: asyncssh.SSHClientConnection,
    cmd: list[str],
    wd: str | None,
    verbose: bool,
    cmd_input: str | None,
    env: dict[str, str] | None,
    interactive: bool,
    log_prefix: str,
) -> Result:
    """Execute *cmd* on an already-established SSH connection.

    Streams stdout (and, when not interactive, stderr) concurrently
    while the command runs, collecting the raw bytes and — when
    *verbose* — logging them line by line via _read_stream().

    Args:
        conn: open asyncssh connection to run the command on.
        cmd: argv-style command; wrapped by _build_remote_command().
        wd: remote working directory, or None for the default.
        verbose: when True, output is also logged as it arrives.
        cmd_input: text fed to the remote stdin, or None for no input.
        env: environment variables passed to the remote process.
        interactive: when True, force a PTY; stderr is then redirected
            into stdout (asyncssh.STDOUT).
        log_prefix: prefix for logged output lines.

    Returns:
        Result with collected stdout/stderr bytes (None when empty)
        and the remote exit status.
    """
    command = self._build_remote_command(cmd, wd)

    stdout_parts: list[bytes] = []
    stderr_parts: list[bytes] = []

    # Log in the local console's encoding; fall back to UTF-8 when
    # stdout/stderr are not attached to a real terminal.
    stdout_log_enc = sys.stdout.encoding or "utf-8"
    stderr_log_enc = sys.stderr.encoding or "utf-8"

    # Only open a stdin pipe when there is input to send; otherwise
    # attach the null device so the remote command sees immediate EOF.
    stdin_mode = asyncssh.PIPE if cmd_input is not None else asyncssh.DEVNULL
    stderr_mode = asyncssh.STDOUT if interactive else asyncssh.PIPE

    proc = await conn.create_process(
        command=command,
        env=env,
        stdin=stdin_mode,
        stdout=asyncssh.PIPE,
        stderr=stderr_mode,
        encoding=None,  # binary streams; decoding happens only for logging
        request_pty="force" if interactive else False,
        term_type=self.term_type if interactive else None,
    )

    # Drain stdout (and stderr, when it is a separate pipe) concurrently
    # so neither stream can stall while we wait for the process.
    tasks = [
        asyncio.create_task(
            self._read_stream(
                proc.stdout,
                NOTICE,
                stdout_parts,
                verbose=verbose,
                log_prefix=log_prefix,
                log_enc=stdout_log_enc,
            )
        )
    ]

    if not interactive:
        tasks.append(
            asyncio.create_task(
                self._read_stream(
                    proc.stderr,
                    ERR,
                    stderr_parts,
                    verbose=verbose,
                    log_prefix=log_prefix,
                    log_enc=stderr_log_enc,
                )
            )
        )

    if cmd_input is not None and proc.stdin is not None:
        # NOTE(review): stdin is encoded with the *local stdout*
        # encoding — confirm this matches what the remote expects.
        proc.stdin.write(cmd_input.encode(sys.stdout.encoding or "utf-8"))
        await proc.stdin.drain()
        proc.stdin.write_eof()

    completed = await proc.wait(check=False)
    # The reader tasks finish once the channel reports EOF.
    await asyncio.gather(*tasks)

    stdout = b"".join(stdout_parts) if stdout_parts else None
    stderr = None if interactive else (b"".join(stderr_parts) if stderr_parts else None)

    # exit_status can be unset (e.g. no clean exit reported); fall back
    # to returncode, then -1 as a last resort.
    exit_code = completed.exit_status
    if exit_code is None:
        exit_code = completed.returncode if completed.returncode is not None else -1

    return Result(stdout, stderr, exit_code)
|
||||
|
||||
async def _run_ssh(
    self,
    cmd: list[str],
    wd: str | None,
    verbose: bool,
    cmd_input: str | None,
    env: dict[str, str] | None,
    interactive: bool,
    log_prefix: str,
) -> Result:
    """Open a dedicated connection, run *cmd* on it, return the Result.

    Thin wrapper around _run_on_conn(); the connection lives only for
    the duration of this single command.
    """
    connect_args = self._connect_kwargs()
    async with asyncssh.connect(**connect_args) as conn:
        return await self._run_on_conn(
            conn, cmd, wd, verbose, cmd_input, env, interactive, log_prefix
        )
|
||||
|
||||
async def _sudo(
    self,
    cmd: list[str],
    mod_env: dict[str, str],
    opts: list[str],
    *args,
    **kwargs,
) -> Result:
    """Run *cmd* with elevated privileges on a fresh connection.

    If the remote user is already root, *cmd* runs as-is and *opts*
    are ignored; otherwise it is prefixed with /usr/bin/sudo, the
    given sudo *opts*, and --preserve-env for the keys of *mod_env*.
    Remaining arguments are forwarded to _run_on_conn(), with
    *mod_env* merged into whatever env they carry.
    """
    args, kwargs = self._merge_env_into_forwarded_args(args, kwargs, mod_env)

    async with asyncssh.connect(**self._connect_kwargs()) as conn:
        # Probe the effective uid; sudo is skipped when already root.
        probe = await conn.run("id -u", check=False)
        running_as_root = (
            probe.exit_status == 0
            and isinstance(probe.stdout, str)
            and probe.stdout.strip() == "0"
        )

        if running_as_root:
            cmdline = list(cmd)
        else:
            prefix = ["/usr/bin/sudo"]
            if mod_env:
                prefix.append("--preserve-env=" + ",".join(mod_env.keys()))
            cmdline = prefix + list(opts) + list(cmd)

        return await self._run_on_conn(conn, cmdline, *args, **kwargs)
|
||||
Loading…
Add table
Add a link
Reference in a new issue