diff --git a/src/python/jw/pkg/lib/ec/SSHClient.py b/src/python/jw/pkg/lib/ec/SSHClient.py index 72e02de5..3505bd83 100644 --- a/src/python/jw/pkg/lib/ec/SSHClient.py +++ b/src/python/jw/pkg/lib/ec/SSHClient.py @@ -2,9 +2,9 @@ from typing import Any -import os, abc, shlex, sys +import os, abc, sys -from ..util import run_cmd, pretty_cmd +from ..util import pretty_cmd from ..log import * from ..ExecContext import ExecContext, Result from urllib.parse import urlparse @@ -19,7 +19,7 @@ class SSHClient(ExecContext): log(ERR, f'Failed to parse SSH URI "{uri}"') raise self.__hostname = parsed.hostname - self.__password: parsed.password + self.__password = parsed.password self.__username = parsed.username @abc.abstractmethod @@ -93,89 +93,18 @@ class SSHClient(ExecContext): def username(self) -> str: return self.__username -class SSHClientInternal(SSHClient): # export - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.__timeout: float|None = None # Untested - self.___ssh: Any|None = None - - def __ssh_connect(self): - import paramiko # type: ignore # error: Library stubs not installed for "paramiko" - ret = paramiko.SSHClient() - ret.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - try: - ret.connect( - hostname=self.hostname, - username=self.username, - allow_agent=True - ) - except Exception as e: - log(ERR, f'Failed to connect to {self.hostname} with key file {path_to_key} ({str(e)})') - raise - s = ret.get_transport().open_session() - # set up the agent request handler to handle agent requests from the server - paramiko.agent.AgentRequestHandler(s) - return ret - - @property - def __ssh(self): - if self.___ssh is None: - self.___ssh = self.__ssh_connect() - return self.___ssh - - @property - def __scp(self): - return SCPClient(self.__ssh.get_transport()) - - async def _run_ssh(self, cmd: list[str], cmd_input: str|None) -> Result: - try: - stdin, stdout, stderr = self.__ssh.exec_command(shlex.join(cmd), timeout=self.__timeout) - except Exception as e: - log(ERR, f'Command failed for {self.uri}: "{shlex.join(cmd)}"') - raise - if cmd_input is not None: - stdin.write(cmd_input) - exit_status = stdout.channel.recv_exit_status() - return Result(stdout.read(), stderr.read(), exit_status) - -class SSHClientCmd(SSHClient): # export - - def __init__(self, *args, **kwargs) -> None: - self.__askpass: str|None = None - self.__askpass_orig: dict[str, str|None] = dict() - super().__init__(*args, **kwargs) - - def __del__(self): - for key, val in self.__askpass_orig.items(): - if val is None: - del os.environ[key] - else: - os.environ[key] = val - if self.__askpass is not None: - os.remove(self.__askpass) - - def __init_askpass(self): - if self.__askpass is None and self.password is not None: - import sys, tempfile - prefix = os.path.basename(sys.argv[0]) + '-' - f = tempfile.NamedTemporaryFile(mode='w+t', prefix=prefix, delete=False) - os.chmod(f.name, 0o0700) - self.__askpass = f.name - f.write(f'#!/bin/bash\n\necho -n "{self.password}\n"') - f.close() - for key, val in {'SSH_ASKPASS': self.__askpass, 'SSH_ASKPASS_REQUIRE': 'force'}.items(): - self.__askpass_orig[key] = os.getenv(key) - os.environ[key] = val - - async def _run_ssh(self, cmd: list[str], cmd_input: str|None) -> Result: - self.__init_askpass() - return await run_cmd(['ssh', self.hostname, shlex.join(cmd)], - output_encoding='bytes', cmd_input=cmd_input) - def ssh_client(*args, **kwargs) -> SSHClient: # export - try: - return SSHClientInternal(*args, **kwargs) - except: - pass - return SSHClientCmd(*args, **kwargs) + from importlib import import_module + errors: list[str] = [] + for name in ['Paramiko', 'Exec']: + try: + return getattr(import_module(f'jw.pkg.lib.ec.ssh.{name}'), name)(*args, **kwargs) + except Exception as e: + msg = f'Can\'t instantiate SSH client class {name} ({str(e)})' + errors.append(msg) + log(DEBUG, f'{msg}, trying next') + msg = f'No working SSH clients for {" ".join(args)}' + log(ERR, f'----- {msg}') + for error in errors: + log(ERR, error) + raise Exception(msg) diff --git a/src/python/jw/pkg/lib/ec/ssh/Exec.py b/src/python/jw/pkg/lib/ec/ssh/Exec.py new file mode 100644 index 00000000..aaca27f7 --- /dev/null +++ b/src/python/jw/pkg/lib/ec/ssh/Exec.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import shlex + +from ...util import run_cmd +from ..SSHClient import SSHClient as Base + +if TYPE_CHECKING: + from ...ExecContext import Result + +class Exec(Base): # export + + def __init__(self, *args, **kwargs) -> None: + self.__askpass: str|None = None + self.__askpass_orig: dict[str, str|None] = dict() + super().__init__(*args, **kwargs) + + def __del__(self): + for key, val in self.__askpass_orig.items(): + if val is None: + del os.environ[key] + else: + os.environ[key] = val + if self.__askpass is not None: + os.remove(self.__askpass) + + def __init_askpass(self): + if self.__askpass is None and self.password is not None: + import sys, tempfile + prefix = os.path.basename(sys.argv[0]) + '-' + f = tempfile.NamedTemporaryFile(mode='w+t', prefix=prefix, delete=False) + os.chmod(f.name, 0o0700) + self.__askpass = f.name + f.write(f'#!/bin/bash\n\necho -n "{self.password}\n"') + f.close() + for key, val in {'SSH_ASKPASS': self.__askpass, 'SSH_ASKPASS_REQUIRE': 'force'}.items(): + self.__askpass_orig[key] = os.getenv(key) + os.environ[key] = val + + async def _run_ssh(self, cmd: list[str], cmd_input: str|None) -> Result: + self.__init_askpass() + return await run_cmd(['ssh', self.hostname, shlex.join(cmd)], cmd_input=cmd_input) + diff --git a/src/python/jw/pkg/lib/ec/ssh/Makefile b/src/python/jw/pkg/lib/ec/ssh/Makefile new file mode 100644 index 00000000..19fedac9 --- /dev/null +++ b/src/python/jw/pkg/lib/ec/ssh/Makefile @@ -0,0 +1,4 @@ +TOPDIR = ../../../../../../.. + +include $(TOPDIR)/make/proj.mk +include $(JWBDIR)/make/py-mod.mk diff --git a/src/python/jw/pkg/lib/ec/ssh/Paramiko.py b/src/python/jw/pkg/lib/ec/ssh/Paramiko.py new file mode 100644 index 00000000..8c26d09e --- /dev/null +++ b/src/python/jw/pkg/lib/ec/ssh/Paramiko.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import paramiko # type: ignore # error: Library stubs not installed for "paramiko" +import shlex + +from ...log import * +from ...ExecContext import Result +from ..SSHClient import SSHClient as Base + +class Paramiko(Base): # export + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.__timeout: float|None = None # Untested + self.___ssh: Any|None = None + + def __ssh_connect(self): + ret = paramiko.SSHClient() + ret.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + try: + ret.connect( + hostname=self.hostname, + username=self.username, + allow_agent=True + ) + except Exception as e: + log(ERR, f'Failed to connect to {self.hostname} with key file {path_to_key} ({str(e)})') + raise + s = ret.get_transport().open_session() + # set up the agent request handler to handle agent requests from the server + paramiko.agent.AgentRequestHandler(s) + return ret + + @property + def __ssh(self): + if self.___ssh is None: + self.___ssh = self.__ssh_connect() + return self.___ssh + + @property + def __scp(self): + return SCPClient(self.__ssh.get_transport()) + + async def _run_ssh(self, cmd: list[str], cmd_input: str|None) -> Result: + try: + stdin, stdout, stderr = self.__ssh.exec_command(shlex.join(cmd), timeout=self.__timeout) + except Exception as e: + log(ERR, f'Command failed for {self.uri}: "{shlex.join(cmd)}"') + raise + if cmd_input is not None: + stdin.write(cmd_input) + exit_status = stdout.channel.recv_exit_status() + return Result(stdout.read(), stderr.read(), exit_status) + +