jw-pkg/src/python/jw/pkg/lib/SSHClient.py

158 lines
5.2 KiB
Python
Raw Normal View History

# -*- coding: utf-8 -*-
from typing import Any
import os, abc, shlex, sys
from .util import run_cmd
from .log import *
from .ExecContext import Result
class SSHClient(abc.ABC):
    """Abstract base for running commands on a remote host over SSH.

    Subclasses implement :meth:`_run_cmd`, which executes an argument list on
    :attr:`hostname`. :meth:`run_cmd` layers working-directory handling,
    exit-status checking and output decoding on top of it.
    """

    def __init__(self, hostname: str) -> None:
        self.__hostname = hostname
        self.__password: str|None = None
        # Initialized so the ``username`` property yields None rather than
        # raising AttributeError before set_username() has been called.
        self.__username: str|None = None

    @property
    def hostname(self) -> str:
        """Remote host passed to the constructor."""
        return self.__hostname

    def set_password(self, password: str) -> None:
        self.__password = password

    @property
    def password(self) -> str|None:
        """Password set via set_password(), or None."""
        return self.__password

    def set_username(self, username: str) -> None:
        self.__username = username

    @property
    def username(self) -> str|None:
        """Login name set via set_username(), or None."""
        return self.__username

    @abc.abstractmethod
    async def _run_cmd(self, cmd: list[str], cmd_input: str|None = None) -> Result:
        """Execute *cmd* remotely and return ``(stdout, stderr, status)``.

        *stdout* and *stderr* are bytes (or None); :meth:`run_cmd` decodes
        them. *cmd_input*, if not None, is text for the command's stdin.
        """

    async def run_cmd(
        self,
        args: list[str],
        wd: str|None = None,
        throw: bool = True,
        verbose: bool = False,
        cmd_input: str|None = None,
        env: dict[str, str]|None = None,
        title: str|None = None,
        output_encoding: str|None = None, # None => sys.stdout's encoding; "bytes" => return raw bytes
    ) -> Result:
        """Run *args* on the remote host.

        Args:
            args: Command and its arguments.
            wd: Remote directory to change into before running the command.
            throw: Raise if the command exits with a non-zero status.
            verbose: Not implemented; only logs a warning.
            cmd_input: Text fed to the command's stdin, or one of the mode
                markers "mode:interactive" / "mode:auto".
            env: Not supported; must be None.
            title: Accepted for interface compatibility; unused here.
            output_encoding: None decodes with sys.stdout's encoding (UTF-8
                fallback); "bytes" returns the raw bytes; anything else is
                used as the codec name.

        Returns:
            ``(stdout, stderr, status)``.

        Raises:
            NotImplementedError: Interactive mode requested, or *env* given.
            Exception: *throw* is set and the command's status is non-zero.
        """
        if wd is not None:
            # Subclasses shlex.join() the list, which would turn a bare '&&'
            # into a quoted literal argument to cd. Let a remote "sh -c" do
            # the chaining so the separator is interpreted by a shell.
            args = ['sh', '-c', f'cd {shlex.quote(wd)} && {shlex.join(args)}']
        if verbose:
            log(WARNING, f'Verbose SSH commands are not yet implemented')
        interactive = (
            cmd_input == "mode:interactive"
            or (cmd_input == "mode:auto" and sys.stdin.isatty())
        )
        if interactive:
            raise NotImplementedError('Interactive SSH is not yet implemented')
        if cmd_input == "mode:auto":
            # Non-interactive "auto" means there is no input to send; the
            # marker itself must not end up on the remote command's stdin.
            cmd_input = None
        if env is not None:
            raise NotImplementedError('Passing an environment to SSH commands is not yet implemented')
        stdout_b, stderr_b, status = await self._run_cmd(args, cmd_input=cmd_input)
        if throw and status:
            raise Exception(f'SSH command returned error {status}')
        if output_encoding == 'bytes':
            return stdout_b, stderr_b, status
        if output_encoding is None:
            # sys.stdout can be None (e.g. no console) or lack an encoding.
            output_encoding = getattr(sys.stdout, 'encoding', None) or "utf-8"
        stdout_s = stdout_b.decode(output_encoding, errors="replace") if stdout_b is not None else None
        stderr_s = stderr_b.decode(output_encoding, errors="replace") if stderr_b is not None else None
        return stdout_s, stderr_s, status
class SSHClientInternal(SSHClient): # export
    """SSHClient backend that speaks SSH in-process via paramiko.

    paramiko is imported lazily. The connection is opened on first use and
    then reused for all subsequent commands.
    """

    def __init__(self, hostname: str) -> None:
        super().__init__(hostname=hostname)
        self.__timeout: float|None = None # Untested
        self.___ssh: Any|None = None  # cached paramiko.SSHClient, see __ssh

    def __login_name(self) -> str|None:
        # The base class may only store a username once set_username() was
        # called; treat "never set" as None so paramiko uses its default.
        try:
            return self.username
        except AttributeError:
            return None

    def __ssh_connect(self):
        """Open and return a connected paramiko.SSHClient for self.hostname."""
        import paramiko # type: ignore # error: Library stubs not installed for "paramiko"
        ret = paramiko.SSHClient()
        ret.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        # expanduser() also works when $HOME is unset. Only offer id_rsa if it
        # exists, so a missing key file doesn't abort connect() before
        # agent/password authentication get a chance.
        key_path = os.path.join(os.path.expanduser('~'), '.ssh', 'id_rsa')
        ret.connect(
            self.hostname,
            username=self.__login_name(),
            password=self.password,
            key_filename=key_path if os.path.exists(key_path) else None,
            allow_agent=True,
        )
        return ret

    @property
    def __ssh(self):
        """Lazily connected paramiko.SSHClient, created on first access."""
        if self.___ssh is None:
            self.___ssh = self.__ssh_connect()
        return self.___ssh

    @property
    def __scp(self):
        # NOTE(review): SCPClient has no visible import in this module; unless
        # it arrives via `from .log import *`, accessing this raises
        # NameError. TODO import it (e.g. from the `scp` package) or drop it.
        return SCPClient(self.__ssh.get_transport())

    async def _run_cmd(self, cmd: list[str], cmd_input: str|None = None) -> Result:
        """Run *cmd* on a fresh channel; return ``(stdout, stderr, status)`` as bytes.

        NOTE(review): the paramiko calls below block and run on the event
        loop thread.
        """
        import paramiko # type: ignore
        chan = self.__ssh.get_transport().open_session(timeout=self.__timeout)
        # Agent forwarding must be requested on the very channel that runs the
        # command, before exec_command().
        agent_handler = paramiko.agent.AgentRequestHandler(chan)
        try:
            chan.settimeout(self.__timeout)
            chan.exec_command(shlex.join(cmd))
            if cmd_input is not None:
                chan.sendall(cmd_input.encode('utf-8'))
            # Send EOF; otherwise a command that reads stdin never finishes.
            chan.shutdown_write()
            # Drain output before asking for the exit status: waiting first
            # deadlocks once the output exceeds the channel window.
            # NOTE(review): stdout is read to EOF before stderr, so a command
            # emitting a very large amount of stderr first could still stall.
            stdout = chan.makefile('rb').read()
            stderr = chan.makefile_stderr('rb').read()
            return stdout, stderr, chan.recv_exit_status()
        finally:
            agent_handler.close()
            chan.close()
class SSHClientCmd(SSHClient): # export
    """SSHClient backend that shells out to the ``ssh`` executable.

    If a password is set, a temporary SSH_ASKPASS helper script printing it is
    created on first command and removed in ``__del__``. The SSH_ASKPASS*
    environment variables changed for it are restored at the same time.
    """

    def __init__(self, hostname: str) -> None:
        # Assigned before super().__init__() so __del__ can always see them.
        self.__askpass: str|None = None
        self.__askpass_orig: dict[str, str|None] = dict()
        super().__init__(hostname=hostname)

    def __del__(self):
        # Restore the environment variables overwritten by __init_askpass().
        for key, val in self.__askpass_orig.items():
            if val is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = val
        if self.__askpass is not None:
            try:
                os.remove(self.__askpass)
            except OSError:
                pass  # best-effort cleanup; the file may already be gone

    def __init_askpass(self):
        """Create the askpass helper and point ssh at it (once, only with a password)."""
        if self.__askpass is not None or self.password is None:
            return
        import tempfile
        prefix = os.path.basename(sys.argv[0]) + '-'
        with tempfile.NamedTemporaryFile(mode='w+t', prefix=prefix, delete=False) as f:
            os.chmod(f.name, 0o0700)
            # Remember the path right away so __del__ removes it even if the
            # write below fails.
            self.__askpass = f.name
            # Quote the password for the shell: interpolating it raw into
            # "..." breaks on (or executes) characters like " $ ` \.
            f.write(f"#!/bin/bash\n\nprintf '%s\\n' {shlex.quote(self.password)}\n")
        for key, val in {'SSH_ASKPASS': self.__askpass, 'SSH_ASKPASS_REQUIRE': 'force'}.items():
            self.__askpass_orig[key] = os.getenv(key)
            os.environ[key] = val

    async def _run_cmd(self, cmd: list[str], cmd_input: str|None = None) -> Result:
        """Run *cmd* via ``ssh``; return ``(stdout, stderr, status)`` with raw bytes."""
        self.__init_askpass()
        try:
            username = self.username
        except AttributeError:  # base class only sets it in set_username()
            username = None
        ssh_args = ['ssh']
        if username is not None:
            ssh_args += ['-l', username]
        ssh_args += [self.hostname, shlex.join(cmd)]
        return await run_cmd(ssh_args, output_encoding='bytes', cmd_input=cmd_input)
def ssh_client(*args, **kwargs) -> SSHClient: # export
    """Return an SSH client for the given arguments.

    Prefers the in-process paramiko backend (SSHClientInternal). Falls back
    to the ``ssh`` command-line backend (SSHClientCmd) if paramiko cannot be
    imported or the in-process client cannot be constructed.
    """
    # SSHClientInternal imports paramiko only on first connect, so its
    # constructor cannot tell us whether paramiko is available; probe here.
    try:
        import paramiko # type: ignore # noqa: F401
    except ImportError:
        return SSHClientCmd(*args, **kwargs)
    try:
        return SSHClientInternal(*args, **kwargs)
    except Exception:
        return SSHClientCmd(*args, **kwargs)