lib.ec.ssh.AsyncSSH: Code beautification

- Remove _run_on_conn() because it doesn't add any value

  - Add verbose try-except block around connect()

  - Add try-except block around failing close

  - Prefix private member variables with "__"

Signed-off-by: Jan Lindemann <jan@janware.com>
This commit is contained in:
Jan Lindemann 2026-04-19 19:52:18 +02:00
commit 8210baa683

View file

@ -29,68 +29,37 @@ class AsyncSSH(Base):
**kwargs **kwargs
) )
self.client_keys = client_keys self.__client_keys = client_keys
self.known_hosts = known_hosts self.__known_hosts = known_hosts
self.term_type = term_type or os.environ.get('TERM', 'xterm') self.__term_type = term_type or os.environ.get('TERM', 'xterm')
self.connect_timeout = connect_timeout self.__connect_timeout = connect_timeout
self.__conn: asyncssh.SSHClientConnection|None = None self.__conn: asyncssh.SSHClientConnection|None = None
def _connect_kwargs(self, hide_secrets: bool=False) -> dict: async def _close(self) -> None:
if self.__conn is not None:
try:
self.__conn.close()
await self.__conn.wait_closed()
except Exception as e:
log(DEBUG, f'Failed to close connection ({str(e)}, ignored)')
self.__conn = None
def _connect_kwargs(self, hide_secrets: bool=False) -> dict:
kwargs: dict = { kwargs: dict = {
'host': self.hostname, 'host': self.hostname,
'port': self.port, 'port': self.port,
'username': self.username, 'username': self.username,
'password': self.password, 'password': self.password,
'client_keys': self.client_keys, 'client_keys': self.__client_keys,
'connect_timeout': self.connect_timeout, 'connect_timeout': self.__connect_timeout,
} }
if self.__known_hosts is not _USE_DEFAULT_KNOWN_HOSTS:
if self.known_hosts is not _USE_DEFAULT_KNOWN_HOSTS: kwargs['known_hosts'] = self.__known_hosts
kwargs['known_hosts'] = self.known_hosts
ret = {k: v for k, v in kwargs.items() if v is not None} ret = {k: v for k, v in kwargs.items() if v is not None}
if hide_secrets and 'password' in kwargs: if hide_secrets and 'password' in kwargs:
kwargs['password'] = '<hidden>' kwargs['password'] = '<hidden>'
return ret return ret
@staticmethod
def _build_remote_command(cmd: list[str], wd: str | None) -> str:
if not cmd:
raise ValueError('cmd must not be empty')
inner = f'exec {join_cmd(cmd)}'
if wd is not None:
inner = f'cd {shlex.quote(wd)} && {inner}'
return f'/bin/sh -lc {shlex.quote(inner)}'
@staticmethod
def _has_local_tty() -> bool:
try:
return sys.stdin.isatty() and sys.stdout.isatty()
except Exception:
return False
@staticmethod
def _get_local_term_size() -> tuple[int, int, int, int]:
cols, rows = shutil.get_terminal_size(fallback=(80, 24))
xpixel = ypixel = 0
try:
import fcntl, termios, struct
packed = fcntl.ioctl(sys.stdout.fileno(), termios.TIOCGWINSZ, b'\0' * 8)
rows2, cols2, xpixel, ypixel = struct.unpack('HHHH', packed)
if cols2 > 0 and rows2 > 0:
cols, rows = cols2, rows2
except Exception:
pass
return (cols, rows, xpixel, ypixel)
@property @property
async def _conn(self) -> asyncssh.SSHClientConnection: async def _conn(self) -> asyncssh.SSHClientConnection:
if self.__conn is None: if self.__conn is None:
@ -105,14 +74,33 @@ class AsyncSSH(Base):
raise raise
return self.__conn return self.__conn
async def _close(self) -> None: @staticmethod
if self.__conn is not None: def _build_remote_command(cmd: list[str], wd: str | None) -> str:
try: inner = f'exec {join_cmd(cmd)}'
self.__conn.close() if wd is not None:
await self.__conn.wait_closed() inner = f'cd {shlex.quote(wd)} && {inner}'
except Exception as e: return f'/bin/sh -lc {shlex.quote(inner)}'
log(DEBUG, f'Failed to close connection ({str(e)}, ignored)')
self.__conn = None @staticmethod
def _has_local_tty() -> bool:
try:
return sys.stdin.isatty() and sys.stdout.isatty()
except Exception:
return False
@staticmethod
def _get_local_term_size() -> tuple[int, int, int, int]:
cols, rows = shutil.get_terminal_size(fallback=(80, 24))
xpixel = ypixel = 0
try:
import fcntl, termios, struct
packed = fcntl.ioctl(sys.stdout.fileno(), termios.TIOCGWINSZ, b'\0' * 8)
rows2, cols2, xpixel, ypixel = struct.unpack('HHHH', packed)
if cols2 > 0 and rows2 > 0:
cols, rows = cols2, rows2
except Exception:
pass
return (cols, rows, xpixel, ypixel)
async def _read_stream( async def _read_stream(
self, self,
@ -125,20 +113,16 @@ class AsyncSSH(Base):
log_enc: str, log_enc: str,
) -> None: ) -> None:
buf = b'' buf = b''
while True: while True:
chunk = await stream.read(4096) chunk = await stream.read(4096)
if not chunk: if not chunk:
break break
collector.append(chunk) collector.append(chunk)
if verbose: if verbose:
buf += chunk buf += chunk
while b'\n' in buf: while b'\n' in buf:
line, buf = buf.split(b'\n', 1) line, buf = buf.split(b'\n', 1)
log(prio, log_prefix, line.decode(log_enc, errors='replace')) log(prio, log_prefix, line.decode(log_enc, errors='replace'))
if verbose and buf: if verbose and buf:
log(prio, log_prefix, buf.decode(log_enc, errors='replace')) log(prio, log_prefix, buf.decode(log_enc, errors='replace'))
@ -162,7 +146,7 @@ class AsyncSSH(Base):
stderr = asyncssh.STDOUT, stderr = asyncssh.STDOUT,
encoding = None, encoding = None,
request_pty = 'force', request_pty = 'force',
term_type = self.term_type, term_type = self.__term_type,
term_size = self._get_local_term_size(), term_size = self._get_local_term_size(),
) )
@ -324,7 +308,7 @@ class AsyncSSH(Base):
stderr = asyncssh.STDOUT, stderr = asyncssh.STDOUT,
encoding = None, encoding = None,
request_pty = 'force', request_pty = 'force',
term_type = self.term_type, term_type = self.__term_type,
) )
task = asyncio.create_task( task = asyncio.create_task(
@ -353,96 +337,6 @@ class AsyncSSH(Base):
stdout = b''.join(stdout_parts) if stdout_parts else None stdout = b''.join(stdout_parts) if stdout_parts else None
return Result(stdout, None, exit_code) return Result(stdout, None, exit_code)
async def _run_on_conn(
self,
cmd: list[str],
wd: str | None,
verbose: bool,
cmd_input: bytes | None,
mod_env: dict[str, str] | None,
interactive: bool,
log_prefix: str,
) -> Result:
if interactive:
if self._has_local_tty():
return await self._run_interactive_on_conn(
cmd = cmd,
wd = wd,
cmd_input = cmd_input,
mod_env = mod_env,
)
return await self._run_captured_pty_on_conn(
cmd = cmd,
wd = wd,
verbose = verbose,
cmd_input = cmd_input,
mod_env = mod_env,
log_prefix = log_prefix,
)
command = self._build_remote_command(cmd, wd)
stdout_parts: list[bytes] = []
stderr_parts: list[bytes] = []
stdout_log_enc = sys.stdout.encoding or 'utf-8'
stderr_log_enc = sys.stderr.encoding or 'utf-8'
stdin_mode = asyncssh.PIPE if cmd_input is not None else asyncssh.DEVNULL
conn = await self._conn
proc = await conn.create_process(
command = command,
env = mod_env,
stdin = stdin_mode,
stdout = asyncssh.PIPE,
stderr = asyncssh.PIPE,
encoding = None,
request_pty = False,
)
tasks = [
asyncio.create_task(
self._read_stream(
proc.stdout,
NOTICE,
stdout_parts,
verbose = verbose,
log_prefix = log_prefix,
log_enc = stdout_log_enc,
)
),
asyncio.create_task(
self._read_stream(
proc.stderr,
ERR,
stderr_parts,
verbose = verbose,
log_prefix = log_prefix,
log_enc = stderr_log_enc,
)
),
]
if cmd_input is not None and proc.stdin is not None:
proc.stdin.write(cmd_input)
await proc.stdin.drain()
proc.stdin.write_eof()
completed = await proc.wait(check=False)
await asyncio.gather(*tasks)
stdout = b''.join(stdout_parts) if stdout_parts else None
stderr = b''.join(stderr_parts) if stderr_parts else None
exit_code = completed.exit_status
if exit_code is None:
exit_code = completed.returncode if completed.returncode is not None else -1
return Result(stdout, stderr, exit_code)
async def _run_ssh( async def _run_ssh(
self, self,
cmd: list[str], cmd: list[str],
@ -453,16 +347,88 @@ class AsyncSSH(Base):
interactive: bool, interactive: bool,
log_prefix: str, log_prefix: str,
) -> Result: ) -> Result:
try: try:
return await self._run_on_conn(
cmd = cmd, if interactive:
wd = wd, if self._has_local_tty():
verbose = verbose, return await self._run_interactive_on_conn(
cmd_input = cmd_input, cmd = cmd,
mod_env = mod_env, wd = wd,
interactive = interactive, cmd_input = cmd_input,
log_prefix = log_prefix, mod_env = mod_env,
)
return await self._run_captured_pty_on_conn(
cmd = cmd,
wd = wd,
verbose = verbose,
cmd_input = cmd_input,
mod_env = mod_env,
log_prefix = log_prefix,
)
command = self._build_remote_command(cmd, wd)
stdout_parts: list[bytes] = []
stderr_parts: list[bytes] = []
stdout_log_enc = sys.stdout.encoding or 'utf-8'
stderr_log_enc = sys.stderr.encoding or 'utf-8'
stdin_mode = asyncssh.PIPE if cmd_input is not None else asyncssh.DEVNULL
conn = await self._conn
proc = await conn.create_process(
command = command,
env = mod_env,
stdin = stdin_mode,
stdout = asyncssh.PIPE,
stderr = asyncssh.PIPE,
encoding = None,
request_pty = False,
) )
tasks = [
asyncio.create_task(
self._read_stream(
proc.stdout,
NOTICE,
stdout_parts,
verbose = verbose,
log_prefix = log_prefix,
log_enc = stdout_log_enc,
)
),
asyncio.create_task(
self._read_stream(
proc.stderr,
ERR,
stderr_parts,
verbose = verbose,
log_prefix = log_prefix,
log_enc = stderr_log_enc,
)
),
]
if cmd_input is not None and proc.stdin is not None:
proc.stdin.write(cmd_input)
await proc.stdin.drain()
proc.stdin.write_eof()
completed = await proc.wait(check=False)
await asyncio.gather(*tasks)
stdout = b''.join(stdout_parts) if stdout_parts else None
stderr = b''.join(stderr_parts) if stderr_parts else None
exit_code = completed.exit_status
if exit_code is None:
exit_code = completed.returncode if completed.returncode is not None else -1
return Result(stdout, stderr, exit_code)
except Exception as e: except Exception as e:
log(ERR, f'Failed to run command {" ".join(cmd)} ({e})' log(ERR, f'Failed to run command {" ".join(cmd)} ({e})')
raise raise