# -*- coding: utf-8 -*-

from typing import Any
import os, abc, shlex, sys

from .util import run_cmd
from .log import *
from .ExecContext import ExecContext, Result


class SSHClient(ExecContext):
    """Base class for running commands on a remote host over SSH.

    Subclasses implement :meth:`_run_cmd`, which executes an argument list on
    the remote host and returns raw ``(stdout, stderr, status)``.  This class
    layers working-directory handling, error checking, output decoding and
    ``sudo`` support on top of it.
    """

    def __init__(self, hostname: str) -> None:
        super().__init__(interactive=False, verbose_default=False)
        self.__hostname = hostname
        self.__password: str|None = None
        # Initialised here so that `username` (and `_sudo`, which reads it)
        # do not raise AttributeError when set_username() was never called.
        self.__username: str|None = None

    @property
    def hostname(self) -> str:
        """Remote host name given at construction."""
        return self.__hostname

    def set_password(self, password: str) -> None:
        """Set the password used for authentication."""
        self.__password = password

    @property
    def password(self) -> str|None:
        """Configured password, or None if none was set."""
        return self.__password

    def set_username(self, username: str) -> None:
        """Set the remote login name."""
        self.__username = username

    @property
    def username(self) -> str|None:
        """Configured remote login name, or None if none was set."""
        return self.__username

    @abc.abstractmethod
    async def _run_cmd(self, cmd: list[str], cmd_input: str|None = None) -> Result:
        """Run *cmd* on the remote host, feeding *cmd_input* to its stdin.

        Implementations return raw (undecoded) stdout, stderr and the exit
        status; decoding is done by :meth:`_run`.
        """
        pass

    async def run_cmd(self, *args, **kwargs) -> Result:
        """Run a command remotely, filling in defaults for all options of `_run`."""
        kwargs.setdefault('wd', None)
        kwargs.setdefault('throw', True)
        kwargs.setdefault('verbose', False)
        kwargs.setdefault('cmd_input', None)
        kwargs.setdefault('env', None)
        kwargs.setdefault('title', None)
        kwargs.setdefault('output_encoding', None)
        return await self._run(*args, **kwargs)

    async def _run(
            self,
            args: list[str],
            wd: str|None,
            throw: bool,
            verbose: bool,
            cmd_input: str|None,
            env: dict[str, str]|None,
            title: str|None,
            output_encoding: str|None, # None => unchanged; "bytes" => return raw bytes
        ) -> Result:
        """Run *args* remotely and return ``(stdout, stderr, status)``.

        Args:
            args: command and arguments; each element is passed as one word.
            wd: remote directory to run the command in, or None.
            throw: raise if the remote command exits with a non-zero status.
            verbose: not implemented; only logs a warning.
            cmd_input: text fed to the command's stdin, or None.
                "mode:interactive" (or "mode:auto" on a TTY) is not supported.
            env: must be None; passing an environment is not implemented.
            title: unused here.
            output_encoding: "bytes" returns raw bytes; None decodes with the
                local stdout encoding (falling back to UTF-8); any other value
                is used as the codec name.

        Raises:
            NotImplementedError: interactive mode or *env* requested.
            Exception: *throw* is set and the command exited non-zero.
        """
        if wd is not None:
            # The subclasses send shlex.join(args) to the remote shell, which
            # would quote a literal '&&' into a plain argument of `cd`.  Let a
            # small POSIX shell script change directory instead: "$1" is wd,
            # and after `shift` "$@" is the original command.
            args = ['sh', '-c', 'cd -- "$1" && shift && "$@"', 'sh', wd, *args]
        if verbose:
            log(WARNING, 'Verbose SSH commands are not yet implemented')
        interactive = (
            cmd_input == "mode:interactive"
            or (cmd_input == "mode:auto" and sys.stdin is not None and sys.stdin.isatty())
        )
        if interactive:
            raise NotImplementedError('Interactive SSH is not yet implemented')
        if cmd_input == "mode:auto":
            # Non-interactive "auto" means there is nothing to feed; without
            # this the literal marker string would be written to stdin.
            cmd_input = None
        if env is not None:
            raise NotImplementedError('Passing an environment to SSH commands is not yet implemented')
        stdout_b, stderr_b, status = await self._run_cmd(args, cmd_input=cmd_input)
        if throw and status:
            raise Exception(f'SSH command returned error {status}')
        if output_encoding == 'bytes':
            return stdout_b, stderr_b, status
        if output_encoding is None:
            # sys.stdout may be None (e.g. pythonw) or lack an encoding.
            output_encoding = getattr(sys.stdout, 'encoding', None) or "utf-8"
        stdout_s = None if stdout_b is None else stdout_b.decode(output_encoding, errors="replace")
        stderr_s = None if stderr_b is None else stderr_b.decode(output_encoding, errors="replace")
        return stdout_s, stderr_s, status

    async def _sudo(self, cmd: list[str], mod_env: dict[str, str], opts: list[str], *args, **kwargs) -> Result:
        """Run *cmd* through ``sudo *opts`` unless logged in as root.

        *mod_env* is not supported over SSH and is ignored with a warning;
        remaining arguments are forwarded to :meth:`_run`.
        """
        if self.username != 'root':
            cmd = ['sudo', *opts, *cmd]
        if mod_env:
            log(WARNING, 'Modifying environment over SSH is not implemented, ignored')
        return await self._run(cmd, *args, **kwargs)


class SSHClientInternal(SSHClient): # export
    """SSH client implemented in-process with paramiko (imported lazily)."""

    def __init__(self, hostname: str) -> None:
        super().__init__(hostname=hostname)
        self.__timeout: float|None = None # Untested
        # paramiko.SSHClient, created on first use by the __ssh property.
        self.___ssh: Any|None = None

    def __ssh_connect(self):
        """Open and return a connected paramiko.SSHClient for `hostname`."""
        import paramiko # type: ignore # error: Library stubs not installed for "paramiko"
        ret = paramiko.SSHClient()
        ret.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        # Offer the default RSA key only if it exists, so a missing key file
        # (or unset $HOME) does not prevent agent/password authentication.
        path_to_key = os.path.expanduser(os.path.join('~', '.ssh', 'id_rsa'))
        key_filename = path_to_key if os.path.isfile(path_to_key) else None
        ret.connect(
            self.hostname,
            username=self.username,   # None => paramiko uses the local user name
            password=self.password,
            key_filename=key_filename,
            allow_agent=True,
        )
        # NOTE(review): this opens a separate session solely to register the
        # agent-forwarding handler; commands run via exec_command() use their
        # own sessions. Confirm forwarding actually reaches them.
        s = ret.get_transport().open_session()
        # set up the agent request handler to handle agent requests from the server
        paramiko.agent.AgentRequestHandler(s)
        return ret

    @property
    def __ssh(self):
        """Lazily connected paramiko client."""
        if self.___ssh is None:
            self.___ssh = self.__ssh_connect()
        return self.___ssh

    @property
    def __scp(self):
        # NOTE(review): SCPClient is not imported in this module as visible;
        # using this property raises NameError until it is.
        return SCPClient(self.__ssh.get_transport())

    async def _run_cmd(self, cmd: list[str], cmd_input: str|None = None) -> Result:
        """Execute *cmd* over the paramiko connection and collect raw output.

        NOTE: paramiko is synchronous, so this blocks the event loop until the
        remote command finishes.
        """
        stdin, stdout, stderr = self.__ssh.exec_command(shlex.join(cmd), timeout=self.__timeout)
        if cmd_input is not None:
            stdin.write(cmd_input)
            stdin.flush()
        # Send EOF so remote commands that read stdin can terminate.
        stdin.channel.shutdown_write()
        # Drain output before asking for the exit status: calling
        # recv_exit_status() first can hang once output exceeds the window.
        out = stdout.read()
        err = stderr.read()
        exit_status = stdout.channel.recv_exit_status()
        return Result(out, err, exit_status)


class SSHClientCmd(SSHClient): # export
    """SSH client that shells out to the ``ssh`` command-line program.

    If a password is set, it is supplied through a temporary SSH_ASKPASS
    script; the script and the modified environment variables are cleaned up
    when the object is destroyed.
    """

    def __init__(self, hostname: str) -> None:
        # Set before super().__init__() so __del__ always finds them.
        self.__askpass: str|None = None
        self.__askpass_orig: dict[str, str|None] = dict()
        super().__init__(hostname=hostname)

    def __del__(self):
        # Restore the environment variables overridden by __init_askpass.
        for key, val in self.__askpass_orig.items():
            if val is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = val
        if self.__askpass is not None:
            try:
                os.remove(self.__askpass)
            except OSError:
                pass  # best-effort cleanup; file may already be gone

    def __init_askpass(self):
        """Create the askpass script and point SSH_ASKPASS at it (once)."""
        if self.__askpass is None and self.password is not None:
            import tempfile
            prefix = os.path.basename(sys.argv[0]) + '-'
            with tempfile.NamedTemporaryFile(mode='w+t', prefix=prefix, delete=False) as f:
                os.chmod(f.name, 0o0700)
                self.__askpass = f.name
                # Quote the password so characters such as ", $, ` or \ are
                # printed literally instead of being interpreted by the shell.
                f.write(f"#!/bin/bash\n\nprintf '%s\\n' {shlex.quote(self.password)}\n")
            for key, val in {'SSH_ASKPASS': self.__askpass, 'SSH_ASKPASS_REQUIRE': 'force'}.items():
                self.__askpass_orig[key] = os.getenv(key)
                os.environ[key] = val

    async def _run_cmd(self, cmd: list[str], cmd_input: str|None = None) -> Result:
        """Run *cmd* via the ``ssh`` binary and return raw bytes output.

        NOTE(review): util.run_cmd may itself raise on a non-zero exit before
        SSHClient._run evaluates its `throw` flag — confirm its defaults.
        """
        self.__init_askpass()
        ssh_cmd = ['ssh']
        if self.username is not None:
            ssh_cmd += ['-l', self.username]
        ssh_cmd += [self.hostname, shlex.join(cmd)]
        return await run_cmd(ssh_cmd, output_encoding='bytes', cmd_input=cmd_input)


def ssh_client(*args, **kwargs) -> SSHClient: # export
    """Return an SSH client for the given host.

    Uses the in-process paramiko implementation when paramiko is importable,
    otherwise falls back to shelling out to the ``ssh`` binary.
    """
    try:
        import paramiko # type: ignore # noqa: F401
    except ImportError:
        return SSHClientCmd(*args, **kwargs)
    return SSHClientInternal(*args, **kwargs)