# -*- coding: utf-8 -*-
"""SSH command execution clients.

Two interchangeable implementations of :class:`SSHClient` are provided:

* :class:`SSHClientInternal` talks to the remote host through paramiko.
* :class:`SSHClientCmd` shells out to the ``ssh`` executable.

:func:`ssh_client` picks the paramiko-based client when paramiko can be
imported and falls back to the command-line client otherwise.
"""

import abc
import asyncio
import os
import shlex
import sys
from typing import Any

from .util import run_cmd
from .ExecContext import Result


class SSHClient(abc.ABC):
    """Abstract base for clients that run commands on a remote host over SSH.

    Subclasses implement :meth:`_run_cmd`, which returns raw output;
    :meth:`run_cmd` takes care of decoding it.
    """

    def __init__(self, hostname: str) -> None:
        self.__hostname = hostname
        self.__password: str | None = None
        # Initialised here so reading ``username`` before ``set_username``
        # returns None instead of raising AttributeError.
        self.__username: str | None = None

    @property
    def hostname(self) -> str:
        """Remote host passed to the constructor."""
        return self.__hostname

    def set_password(self, password: str) -> None:
        """Set the password used for authentication."""
        self.__password = password

    @property
    def password(self) -> str | None:
        """Password set via :meth:`set_password`, or None if never set."""
        return self.__password

    def set_username(self, username: str) -> None:
        """Set the remote login name."""
        self.__username = username

    @property
    def username(self) -> str | None:
        """Login name set via :meth:`set_username`, or None if never set."""
        return self.__username

    @abc.abstractmethod
    async def _run_cmd(self, cmd: list[str]) -> Result:
        """Run *cmd* remotely and return ``(stdout, stderr, status)`` with raw bytes output."""

    async def run_cmd(
        self,
        cmd: list[str],
        output_encoding: str | None = None,
    ) -> Result:
        """Run *cmd* on the remote host.

        Args:
            cmd: Argument vector; implementations shell-quote it with
                :func:`shlex.join` before sending it to the remote side.
            output_encoding: ``'bytes'`` returns stdout/stderr undecoded; any
                other value is a codec name used to decode them (undecodable
                bytes are replaced). Defaults to the encoding of
                ``sys.stdout``, or UTF-8 if that is unavailable.

        Returns:
            ``(stdout, stderr, exit_status)``.
        """
        stdout_b, stderr_b, status = await self._run_cmd(cmd)
        if output_encoding == 'bytes':
            return stdout_b, stderr_b, status
        if output_encoding is None:
            # sys.stdout may be None (e.g. under pythonw) or lack ``encoding``.
            output_encoding = getattr(sys.stdout, 'encoding', None) or "utf-8"
        stdout_s = stdout_b.decode(output_encoding, errors="replace") if stdout_b is not None else None
        stderr_s = stderr_b.decode(output_encoding, errors="replace") if stderr_b is not None else None
        return stdout_s, stderr_s, status


class SSHClientInternal(SSHClient):  # export
    """:class:`SSHClient` implementation backed by paramiko.

    The connection is opened lazily on the first command and then reused.
    """

    def __init__(self, hostname: str) -> None:
        super().__init__(hostname=hostname)
        self.__timeout: float | None = None  # Untested
        self.___ssh: Any | None = None  # cached paramiko.SSHClient

    def __ssh_connect(self) -> Any:
        """Open and return a new, authenticated ``paramiko.SSHClient``."""
        import paramiko  # type: ignore # error: Library stubs not installed for "paramiko"

        client = paramiko.SSHClient()
        # NOTE(review): AutoAddPolicy silently trusts unknown host keys.
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        # expanduser also works when $HOME is unset (unlike os.environ['HOME']).
        key_path = os.path.join(os.path.expanduser('~'), '.ssh', 'id_rsa')
        try:
            client.connect(
                self.hostname,
                username=self.username,
                password=self.password,
                # Only offer the default key file if it exists; otherwise rely
                # on the agent (allow_agent=True) and the password, if any.
                key_filename=key_path if os.path.isfile(key_path) else None,
                allow_agent=True,
            )
            session = client.get_transport().open_session()
            # set up the agent request handler to handle agent requests from the server
            # NOTE(review): forwarding is requested on this extra session only;
            # channels opened later by exec_command do not request it — confirm
            # agent forwarding actually reaches remote commands.
            paramiko.agent.AgentRequestHandler(session)
        except BaseException:
            # Don't leak the socket / transport thread of a failed connection.
            client.close()
            raise
        return client

    @property
    def __ssh(self) -> Any:
        """Lazily opened paramiko connection, cached for reuse."""
        if self.___ssh is None:
            self.___ssh = self.__ssh_connect()
        return self.___ssh

    @property
    def __scp(self) -> Any:
        """SCP client running over the current SSH transport."""
        # NOTE(review): SCPClient is not imported in this module (presumably
        # scp.SCPClient from the third-party ``scp`` package), so accessing
        # this property raises NameError — add the import or remove it.
        return SCPClient(self.__ssh.get_transport())

    async def _run_cmd(self, cmd: list[str]) -> Result:
        # Connect (if needed) on the calling thread, as before, so concurrent
        # calls cannot race to open two connections.
        ssh = self.__ssh
        command = shlex.join(cmd)
        timeout = self.__timeout

        def run_blocking() -> Result:
            _stdin, stdout, stderr = ssh.exec_command(command, timeout=timeout)
            # Drain output before asking for the exit status: paramiko's
            # recv_exit_status() can hang when the output exceeds the channel
            # window and nothing is reading it.
            out = stdout.read()
            err = stderr.read()
            return out, err, stdout.channel.recv_exit_status()

        # paramiko's API is blocking; keep it off the event loop.
        return await asyncio.to_thread(run_blocking)


class SSHClientCmd(SSHClient):  # export
    """:class:`SSHClient` implementation that shells out to ``ssh``.

    If a password is set, it is supplied to ``ssh`` through a temporary
    ``SSH_ASKPASS`` helper script.
    """

    def __init__(self, hostname: str) -> None:
        self.__askpass: str | None = None  # path of the askpass helper script
        # Original values of the env vars overridden for askpass; None means
        # "was unset". Restored in __del__.
        self.__askpass_orig: dict[str, str | None] = dict()
        super().__init__(hostname=hostname)

    def __del__(self) -> None:
        """Restore the overridden environment and delete the askpass script."""
        for key, val in self.__askpass_orig.items():
            if val is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = val
        self.__askpass_orig.clear()
        if self.__askpass is not None:
            try:
                os.remove(self.__askpass)
            except FileNotFoundError:
                pass  # already gone; nothing left to clean up
            self.__askpass = None

    def __init_askpass(self) -> None:
        """Create the askpass helper and point ssh at it (once, if a password is set)."""
        if self.__askpass is not None or self.password is None:
            return
        import tempfile

        prefix = os.path.basename(sys.argv[0]) + '-'
        with tempfile.NamedTemporaryFile(mode='w+t', prefix=prefix, delete=False) as f:
            # Record the path first so __del__ removes it even if writing fails.
            self.__askpass = f.name
            os.chmod(f.name, 0o0700)
            # Quote the password so shell metacharacters in it (", $, `, \)
            # are printed literally instead of being interpreted by bash.
            f.write(f"#!/bin/bash\n\nprintf '%s\\n' {shlex.quote(self.password)}\n")
        # NOTE(review): this mutates the process-wide environment, so several
        # SSHClientCmd instances with different passwords clobber each other.
        for key, val in {'SSH_ASKPASS': self.__askpass, 'SSH_ASKPASS_REQUIRE': 'force'}.items():
            self.__askpass_orig[key] = os.getenv(key)
            os.environ[key] = val

    async def _run_cmd(self, cmd: list[str]) -> Result:
        self.__init_askpass()
        argv = ['ssh']
        if self.username is not None:
            argv += ['-l', self.username]
        # '--' ends option parsing so a hostname starting with '-' cannot be
        # interpreted as an ssh option.
        argv += ['--', self.hostname, shlex.join(cmd)]
        return await run_cmd(argv, output_encoding='bytes')


def ssh_client(*args, **kwargs) -> SSHClient:  # export
    """Return the best available :class:`SSHClient` implementation.

    Uses :class:`SSHClientInternal` when paramiko is importable, otherwise
    :class:`SSHClientCmd`. All arguments are forwarded to the constructor.
    """
    # SSHClientInternal imports paramiko lazily on first use, so constructing
    # it never fails; probe for paramiko explicitly to make the fallback work.
    try:
        import paramiko  # type: ignore # noqa: F401
    except ImportError:
        return SSHClientCmd(*args, **kwargs)
    return SSHClientInternal(*args, **kwargs)