diff --git a/src/python/jw/pkg/lib/ec/SSHClient.py b/src/python/jw/pkg/lib/ec/SSHClient.py index 3505bd83..6df4a4d7 100644 --- a/src/python/jw/pkg/lib/ec/SSHClient.py +++ b/src/python/jw/pkg/lib/ec/SSHClient.py @@ -3,6 +3,7 @@ from typing import Any import os, abc, sys +from enum import Flag, auto from ..util import pretty_cmd from ..log import * @@ -11,8 +12,15 @@ from urllib.parse import urlparse class SSHClient(ExecContext): - def __init__(self, uri: str, *args, **kwargs) -> None: + class Caps(Flag): + LogOutput = auto() + Interactive = auto() + Env = auto() + Wd = auto() + + def __init__(self, uri: str, caps: Caps=Caps(0), *args, **kwargs) -> None: super().__init__(uri=uri, *args, **kwargs) + self.__caps = caps try: parsed = urlparse(uri) except Exception as e: @@ -23,7 +31,15 @@ class SSHClient(ExecContext): self.__username = parsed.username @abc.abstractmethod - async def _run_ssh(self, cmd: list[str]) -> Result: + async def _run_ssh( + cmd: list[str], + wd: str|None, + verbose: bool, + cmd_input: str|None, + env: dict[str, str]|None, + interactive: bool, + log_prefix: str + ) -> Result: pass async def _run( @@ -41,6 +57,8 @@ class SSHClient(ExecContext): log(prio, log_prefix, *args) def __log_block(prio: int, title: str, block: str): + if self.__caps & self.Caps.LogOutput: + return encoding = sys.stdout.encoding or 'utf-8' block = block.decode(encoding).strip() if not block: @@ -51,21 +69,31 @@ class SSHClient(ExecContext): __log(prio, '|', line) __log(prio, f'`{delim}') - if wd is not None: + if wd is not None and not self.__caps & self.Caps.Wd: cmd = ['cd', wd, '&&', *cmd] - if interactive: + if interactive and not self.__caps & self.Caps.Interactive: raise NotImplementedError('Interactive SSH is not yet implemented') - if env is not None: + if env is not None and not self.__caps & self.Caps.Env: raise NotImplementedError('Passing an environment to SSH commands is not yet implemented') - ret = await self._run_ssh(cmd, cmd_input=cmd_input) + ret = await self._run_ssh( + cmd=cmd, + wd=wd, + verbose=verbose, + cmd_input=cmd_input, + env=env, + interactive=interactive, + log_prefix=log_prefix + ) + if verbose: __log_block(NOTICE, 'stdout', ret.stdout) __log_block(NOTICE, 'stderr', ret.stderr) if ret.status != 0: __log(WARNING, f'Exit code {ret.status}') + return ret async def _sudo(self, cmd: list[str], mod_env: dict[str, str], opts: list[str], *args, **kwargs) -> Result: diff --git a/src/python/jw/pkg/lib/ec/ssh/Exec.py b/src/python/jw/pkg/lib/ec/ssh/Exec.py index aaca27f7..4fbbcd21 100644 --- a/src/python/jw/pkg/lib/ec/ssh/Exec.py +++ b/src/python/jw/pkg/lib/ec/ssh/Exec.py @@ -39,7 +39,7 @@ class Exec(Base): # export self.__askpass_orig[key] = os.getenv(key) os.environ[key] = val - async def _run_ssh(self, cmd: list[str], cmd_input: str|None) -> Result: + async def _run_ssh(self, cmd: list[str], cmd_input: str|None, *args, **kwargs) -> Result: self.__init_askpass() return await run_cmd(['ssh', self.hostname, shlex.join(cmd)], cmd_input=cmd_input) diff --git a/src/python/jw/pkg/lib/ec/ssh/Paramiko.py b/src/python/jw/pkg/lib/ec/ssh/Paramiko.py index 8c26d09e..8d67204f 100644 --- a/src/python/jw/pkg/lib/ec/ssh/Paramiko.py +++ b/src/python/jw/pkg/lib/ec/ssh/Paramiko.py @@ -43,7 +43,7 @@ class Paramiko(Base): # export def __scp(self): return SCPClient(self.__ssh.get_transport()) - async def _run_ssh(self, cmd: list[str], cmd_input: str|None) -> Result: + async def _run_ssh(self, cmd: list[str], cmd_input: str|None, *args, **kwargs) -> Result: try: stdin, stdout, stderr = self.__ssh.exec_command(shlex.join(cmd), timeout=self.__timeout) except Exception as e: