lib.ec.SSHClient.__init__(): Add parameter caps

Add an optional caps ("capabilities") argument to the constructor of
SSHClient. It is meant to be used by derived classes in order to
declare that they don't want the base class to handle a default
behaviour for a certain capability, but that they want to implement
it themselves instead.

Also, give the _run_ssh() callbacks the necessary info as parameters,
so that the derived classes have the means to do so.

Signed-off-by: Jan Lindemann <jan@janware.com>
This commit is contained in:
Jan Lindemann 2026-03-21 03:41:10 +01:00
commit 3a84408436
3 changed files with 36 additions and 8 deletions

View file

@ -3,6 +3,7 @@
from typing import Any
import os, abc, sys
from enum import Flag, auto
from ..util import pretty_cmd
from ..log import *
@ -11,8 +12,15 @@ from urllib.parse import urlparse
class SSHClient(ExecContext):
def __init__(self, uri: str, *args, **kwargs) -> None:
class Caps(Flag):
LogOutput = auto()
Interactive = auto()
Env = auto()
Wd = auto()
def __init__(self, uri: str, caps: Caps=Caps(0), *args, **kwargs) -> None:
super().__init__(uri=uri, *args, **kwargs)
self.__caps = caps
try:
parsed = urlparse(uri)
except Exception as e:
@ -23,7 +31,15 @@ class SSHClient(ExecContext):
self.__username = parsed.username
@abc.abstractmethod
async def _run_ssh(self, cmd: list[str]) -> Result:
async def _run_ssh(
cmd: list[str],
wd: str|None,
verbose: bool,
cmd_input: str|None,
env: dict[str, str]|None,
interactive: bool,
log_prefix: str
) -> Result:
pass
async def _run(
@ -41,6 +57,8 @@ class SSHClient(ExecContext):
log(prio, log_prefix, *args)
def __log_block(prio: int, title: str, block: str):
if self.__caps & self.Caps.LogOutput:
return
encoding = sys.stdout.encoding or 'utf-8'
block = block.decode(encoding).strip()
if not block:
@ -51,21 +69,31 @@ class SSHClient(ExecContext):
__log(prio, '|', line)
__log(prio, f'`{delim}')
if wd is not None:
if wd is not None and not self.__caps & self.Caps.Wd:
cmd = ['cd', wd, '&&', *cmd]
if interactive:
if interactive and not self.__caps & self.Caps.Interactive:
raise NotImplementedError('Interactive SSH is not yet implemented')
if env is not None:
if env is not None and not self.__caps & self.Caps.Env:
raise NotImplementedError('Passing an environment to SSH commands is not yet implemented')
ret = await self._run_ssh(cmd, cmd_input=cmd_input)
ret = await self._run_ssh(
cmd=cmd,
wd=wd,
verbose=verbose,
cmd_input=cmd_input,
env=env,
interactive=interactive,
log_prefix=log_prefix
)
if verbose:
__log_block(NOTICE, 'stdout', ret.stdout)
__log_block(NOTICE, 'stderr', ret.stderr)
if ret.status != 0:
__log(WARNING, f'Exit code {ret.status}')
return ret
async def _sudo(self, cmd: list[str], mod_env: dict[str, str], opts: list[str], *args, **kwargs) -> Result:

View file

@ -39,7 +39,7 @@ class Exec(Base): # export
self.__askpass_orig[key] = os.getenv(key)
os.environ[key] = val
async def _run_ssh(self, cmd: list[str], cmd_input: str|None) -> Result:
async def _run_ssh(self, cmd: list[str], cmd_input: str|None, *args, **kwargs) -> Result:
self.__init_askpass()
return await run_cmd(['ssh', self.hostname, shlex.join(cmd)], cmd_input=cmd_input)

View file

@ -43,7 +43,7 @@ class Paramiko(Base): # export
def __scp(self):
return SCPClient(self.__ssh.get_transport())
async def _run_ssh(self, cmd: list[str], cmd_input: str|None) -> Result:
async def _run_ssh(self, cmd: list[str], cmd_input: str|None, *args, **kwargs) -> Result:
try:
stdin, stdout, stderr = self.__ssh.exec_command(shlex.join(cmd), timeout=self.__timeout)
except Exception as e: