diff --git a/src/python/jw/pkg/lib/ec/ssh/AsyncSSH.py b/src/python/jw/pkg/lib/ec/ssh/AsyncSSH.py index 6aae15dd..5faab1f2 100644 --- a/src/python/jw/pkg/lib/ec/ssh/AsyncSSH.py +++ b/src/python/jw/pkg/lib/ec/ssh/AsyncSSH.py @@ -29,68 +29,37 @@ class AsyncSSH(Base): **kwargs ) - self.client_keys = client_keys - self.known_hosts = known_hosts - self.term_type = term_type or os.environ.get('TERM', 'xterm') - self.connect_timeout = connect_timeout + self.__client_keys = client_keys + self.__known_hosts = known_hosts + self.__term_type = term_type or os.environ.get('TERM', 'xterm') + self.__connect_timeout = connect_timeout self.__conn: asyncssh.SSHClientConnection|None = None - def _connect_kwargs(self, hide_secrets: bool=False) -> dict: + async def _close(self) -> None: + if self.__conn is not None: + try: + self.__conn.close() + await self.__conn.wait_closed() + except Exception as e: + log(DEBUG, f'Failed to close connection ({str(e)}, ignored)') + self.__conn = None + def _connect_kwargs(self, hide_secrets: bool=False) -> dict: kwargs: dict = { 'host': self.hostname, 'port': self.port, 'username': self.username, 'password': self.password, - 'client_keys': self.client_keys, - 'connect_timeout': self.connect_timeout, + 'client_keys': self.__client_keys, + 'connect_timeout': self.__connect_timeout, } - - if self.known_hosts is not _USE_DEFAULT_KNOWN_HOSTS: - kwargs['known_hosts'] = self.known_hosts - + if self.__known_hosts is not _USE_DEFAULT_KNOWN_HOSTS: + kwargs['known_hosts'] = self.__known_hosts ret = {k: v for k, v in kwargs.items() if v is not None} if hide_secrets and 'password' in kwargs: kwargs['password'] = '' return ret - @staticmethod - def _build_remote_command(cmd: list[str], wd: str | None) -> str: - - if not cmd: - raise ValueError('cmd must not be empty') - - inner = f'exec {join_cmd(cmd)}' - if wd is not None: - inner = f'cd {shlex.quote(wd)} && {inner}' - - return f'/bin/sh -lc {shlex.quote(inner)}' - - @staticmethod - def _has_local_tty() -> bool: - try: - return sys.stdin.isatty() and sys.stdout.isatty() - except Exception: - return False - - @staticmethod - def _get_local_term_size() -> tuple[int, int, int, int]: - cols, rows = shutil.get_terminal_size(fallback=(80, 24)) - xpixel = ypixel = 0 - - try: - import fcntl, termios, struct - - packed = fcntl.ioctl(sys.stdout.fileno(), termios.TIOCGWINSZ, b'\0' * 8) - rows2, cols2, xpixel, ypixel = struct.unpack('HHHH', packed) - - if cols2 > 0 and rows2 > 0: - cols, rows = cols2, rows2 - except Exception: - pass - - return (cols, rows, xpixel, ypixel) - @property async def _conn(self) -> asyncssh.SSHClientConnection: if self.__conn is None: @@ -105,14 +74,33 @@ class AsyncSSH(Base): raise return self.__conn - async def _close(self) -> None: - if self.__conn is not None: - try: - self.__conn.close() - await self.__conn.wait_closed() - except Exception as e: - log(DEBUG, f'Failed to close connection ({str(e)}, ignored)') - self.__conn = None + @staticmethod + def _build_remote_command(cmd: list[str], wd: str | None) -> str: + inner = f'exec {join_cmd(cmd)}' + if wd is not None: + inner = f'cd {shlex.quote(wd)} && {inner}' + return f'/bin/sh -lc {shlex.quote(inner)}' + + @staticmethod + def _has_local_tty() -> bool: + try: + return sys.stdin.isatty() and sys.stdout.isatty() + except Exception: + return False + + @staticmethod + def _get_local_term_size() -> tuple[int, int, int, int]: + cols, rows = shutil.get_terminal_size(fallback=(80, 24)) + xpixel = ypixel = 0 + try: + import fcntl, termios, struct + packed = fcntl.ioctl(sys.stdout.fileno(), termios.TIOCGWINSZ, b'\0' * 8) + rows2, cols2, xpixel, ypixel = struct.unpack('HHHH', packed) + if cols2 > 0 and rows2 > 0: + cols, rows = cols2, rows2 + except Exception: + pass + return (cols, rows, xpixel, ypixel) async def _read_stream( self, @@ -125,20 +113,16 @@ class AsyncSSH(Base): log_enc: str, ) -> None: buf = b'' - while True: chunk = await stream.read(4096) if not chunk: break - collector.append(chunk) - if verbose: buf += chunk while b'\n' in buf: line, buf = buf.split(b'\n', 1) log(prio, log_prefix, line.decode(log_enc, errors='replace')) - if verbose and buf: log(prio, log_prefix, buf.decode(log_enc, errors='replace')) @@ -162,7 +146,7 @@ class AsyncSSH(Base): stderr = asyncssh.STDOUT, encoding = None, request_pty = 'force', - term_type = self.term_type, + term_type = self.__term_type, term_size = self._get_local_term_size(), ) @@ -324,7 +308,7 @@ class AsyncSSH(Base): stderr = asyncssh.STDOUT, encoding = None, request_pty = 'force', - term_type = self.term_type, + term_type = self.__term_type, ) task = asyncio.create_task( @@ -353,96 +337,6 @@ class AsyncSSH(Base): stdout = b''.join(stdout_parts) if stdout_parts else None return Result(stdout, None, exit_code) - async def _run_on_conn( - self, - cmd: list[str], - wd: str | None, - verbose: bool, - cmd_input: bytes | None, - mod_env: dict[str, str] | None, - interactive: bool, - log_prefix: str, - ) -> Result: - if interactive: - if self._has_local_tty(): - return await self._run_interactive_on_conn( - cmd = cmd, - wd = wd, - cmd_input = cmd_input, - mod_env = mod_env, - ) - - return await self._run_captured_pty_on_conn( - cmd = cmd, - wd = wd, - verbose = verbose, - cmd_input = cmd_input, - mod_env = mod_env, - log_prefix = log_prefix, - ) - - command = self._build_remote_command(cmd, wd) - - stdout_parts: list[bytes] = [] - stderr_parts: list[bytes] = [] - - stdout_log_enc = sys.stdout.encoding or 'utf-8' - stderr_log_enc = sys.stderr.encoding or 'utf-8' - - stdin_mode = asyncssh.PIPE if cmd_input is not None else asyncssh.DEVNULL - - conn = await self._conn - - proc = await conn.create_process( - command = command, - env = mod_env, - stdin = stdin_mode, - stdout = asyncssh.PIPE, - stderr = asyncssh.PIPE, - encoding = None, - request_pty = False, - ) - - tasks = [ - asyncio.create_task( - self._read_stream( - proc.stdout, - NOTICE, - stdout_parts, - verbose = verbose, - log_prefix = log_prefix, - log_enc = stdout_log_enc, - ) - ), - asyncio.create_task( - self._read_stream( - proc.stderr, - ERR, - stderr_parts, - verbose = verbose, - log_prefix = log_prefix, - log_enc = stderr_log_enc, - ) - ), - ] - - if cmd_input is not None and proc.stdin is not None: - proc.stdin.write(cmd_input) - await proc.stdin.drain() - proc.stdin.write_eof() - - completed = await proc.wait(check=False) - await asyncio.gather(*tasks) - - stdout = b''.join(stdout_parts) if stdout_parts else None - stderr = b''.join(stderr_parts) if stderr_parts else None - - exit_code = completed.exit_status - if exit_code is None: - exit_code = completed.returncode if completed.returncode is not None else -1 - - return Result(stdout, stderr, exit_code) - async def _run_ssh( self, cmd: list[str], @@ -453,16 +347,88 @@ class AsyncSSH(Base): interactive: bool, log_prefix: str, ) -> Result: + try: - return await self._run_on_conn( - cmd = cmd, - wd = wd, - verbose = verbose, - cmd_input = cmd_input, - mod_env = mod_env, - interactive = interactive, - log_prefix = log_prefix, + + if interactive: + if self._has_local_tty(): + return await self._run_interactive_on_conn( + cmd = cmd, + wd = wd, + cmd_input = cmd_input, + mod_env = mod_env, + ) + return await self._run_captured_pty_on_conn( + cmd = cmd, + wd = wd, + verbose = verbose, + cmd_input = cmd_input, + mod_env = mod_env, + log_prefix = log_prefix, + ) + + command = self._build_remote_command(cmd, wd) + + stdout_parts: list[bytes] = [] + stderr_parts: list[bytes] = [] + + stdout_log_enc = sys.stdout.encoding or 'utf-8' + stderr_log_enc = sys.stderr.encoding or 'utf-8' + + stdin_mode = asyncssh.PIPE if cmd_input is not None else asyncssh.DEVNULL + + conn = await self._conn + + proc = await conn.create_process( + command = command, + env = mod_env, + stdin = stdin_mode, + stdout = asyncssh.PIPE, + stderr = asyncssh.PIPE, + encoding = None, + request_pty = False, ) + + tasks = [ + asyncio.create_task( + self._read_stream( + proc.stdout, + NOTICE, + stdout_parts, + verbose = verbose, + log_prefix = log_prefix, + log_enc = stdout_log_enc, + ) + ), + asyncio.create_task( + self._read_stream( + proc.stderr, + ERR, + stderr_parts, + verbose = verbose, + log_prefix = log_prefix, + log_enc = stderr_log_enc, + ) + ), + ] + + if cmd_input is not None and proc.stdin is not None: + proc.stdin.write(cmd_input) + await proc.stdin.drain() + proc.stdin.write_eof() + + completed = await proc.wait(check=False) + await asyncio.gather(*tasks) + + stdout = b''.join(stdout_parts) if stdout_parts else None + stderr = b''.join(stderr_parts) if stderr_parts else None + + exit_code = completed.exit_status + if exit_code is None: + exit_code = completed.returncode if completed.returncode is not None else -1 + + return Result(stdout, stderr, exit_code) + except Exception as e: - log(ERR, f'Failed to run command {" ".join(cmd)} ({e})' + log(ERR, f'Failed to run command {" ".join(cmd)} ({e})') raise