diff --git a/src/python/jw/pkg/lib/ec/ssh/AsyncSSH.py b/src/python/jw/pkg/lib/ec/ssh/AsyncSSH.py index 0b827c1b..52985c0a 100644 --- a/src/python/jw/pkg/lib/ec/ssh/AsyncSSH.py +++ b/src/python/jw/pkg/lib/ec/ssh/AsyncSSH.py @@ -11,6 +11,7 @@ from .util import join_cmd _USE_DEFAULT_KNOWN_HOSTS = object() class AsyncSSH(Base): + def __init__( self, uri: str, @@ -32,6 +33,7 @@ class AsyncSSH(Base): self.known_hosts = known_hosts self.term_type = term_type or os.environ.get("TERM", "xterm") self.connect_timeout = connect_timeout + self.__conn: asyncssh.SSHClientConnection|None = None def _connect_kwargs(self) -> dict: kwargs: dict = { @@ -81,7 +83,6 @@ class AsyncSSH(Base): args[3] = merged_env or None else: kwargs["env"] = dict(mod_env) if mod_env else None - return tuple(args), kwargs @staticmethod @@ -95,20 +96,31 @@ class AsyncSSH(Base): def _get_local_term_size() -> tuple[int, int, int, int]: cols, rows = shutil.get_terminal_size(fallback=(80, 24)) xpixel = ypixel = 0 - try: import fcntl, termios, struct - packed = fcntl.ioctl(sys.stdout.fileno(), termios.TIOCGWINSZ, b"\0" * 8) rows2, cols2, xpixel, ypixel = struct.unpack("HHHH", packed) - if cols2 > 0 and rows2 > 0: cols, rows = cols2, rows2 except Exception: pass - return (cols, rows, xpixel, ypixel) + @property + async def _conn(self) -> asyncssh.SSHClientConnection: + if self.__conn is None: + self.__conn = await asyncssh.connect(**self._connect_kwargs()) + return self.__conn + + async def _close(self) -> None: + if self.__conn is not None: + try: + self.__conn.close() + await self.__conn.wait_closed() + except: + pass + self.__conn = None + async def _read_stream( self, stream, @@ -125,26 +137,24 @@ class AsyncSSH(Base): chunk = await stream.read(4096) if not chunk: break - collector.append(chunk) - if verbose: buf += chunk while b"\n" in buf: line, buf = buf.split(b"\n", 1) log(prio, log_prefix, line.decode(log_enc, errors="replace")) - if verbose and buf: log(prio, log_prefix, buf.decode(log_enc, errors="replace")) async def _run_interactive_on_conn( self, - conn: asyncssh.SSHClientConnection, cmd: list[str], wd: str | None, cmd_input: bytes | None, env: dict[str, str] | None, ) -> Result: + + conn = await self._conn command = self._build_remote_command(cmd, wd) stdout_parts: list[bytes] = [] @@ -194,10 +204,8 @@ class AsyncSSH(Base): if cmd_input is not None and proc.stdin is not None: proc.stdin.write(cmd_input) await proc.stdin.drain() - while True: data = await stdin_queue.get() - if data is None: if proc.stdin is not None: try: @@ -205,10 +213,8 @@ class AsyncSSH(Base): except (BrokenPipeError, OSError): pass return - if proc.stdin is None: return - proc.stdin.write(data) await proc.stdin.drain() @@ -217,11 +223,11 @@ class AsyncSSH(Base): chunk = await proc.stdout.read(4096) if not chunk: break - stdout_parts.append(chunk) _write_local(chunk) def _on_winch(*_args) -> None: + try: proc.change_terminal_size(*self._get_local_term_size()) except Exception: @@ -300,7 +306,6 @@ class AsyncSSH(Base): async def _run_captured_pty_on_conn( self, - conn: asyncssh.SSHClientConnection, cmd: list[str], wd: str | None, verbose: bool, @@ -308,6 +313,8 @@ class AsyncSSH(Base): env: dict[str, str] | None, log_prefix: str, ) -> Result: + + conn = await self._conn command = self._build_remote_command(cmd, wd) stdout_parts: list[bytes] = [] @@ -352,7 +359,6 @@ class AsyncSSH(Base): async def _run_on_conn( self, - conn: asyncssh.SSHClientConnection, cmd: list[str], wd: str | None, verbose: bool, @@ -361,24 +367,25 @@ class AsyncSSH(Base): interactive: bool, log_prefix: str, ) -> Result: + + conn = await self._conn + if interactive: if self._has_local_tty(): return await self._run_interactive_on_conn( - conn, - cmd, - wd, - cmd_input, - env, + cmd=cmd, + wd=wd, + cmd_input=cmd_input, + env=env, ) return await self._run_captured_pty_on_conn( - conn, - cmd, - wd, - verbose, - cmd_input, - env, - log_prefix, + cmd=cmd, + wd=wd, + verbose=verbose, + cmd_input=cmd_input, + env=env, + log_prefix=log_prefix, ) command = self._build_remote_command(cmd, wd) @@ -451,17 +458,15 @@ class AsyncSSH(Base): interactive: bool, log_prefix: str, ) -> Result: - async with asyncssh.connect(**self._connect_kwargs()) as conn: - return await self._run_on_conn( - conn, - cmd, - wd, - verbose, - cmd_input, - env, - interactive, - log_prefix, - ) + return await self._run_on_conn( + cmd, + wd, + verbose, + cmd_input, + env, + interactive, + log_prefix, + ) async def _sudo( self, @@ -471,24 +476,25 @@ class AsyncSSH(Base): *args, **kwargs, ) -> Result: + args, kwargs = self._merge_env_into_forwarded_args(args, kwargs, mod_env) - async with asyncssh.connect(**self._connect_kwargs()) as conn: - uid_result = await conn.run("id -u", check=False) - is_root = ( - uid_result.exit_status == 0 - and isinstance(uid_result.stdout, str) - and uid_result.stdout.strip() == "0" - ) + conn = await self._conn + uid_result = conn.run("id -u", check=False) + is_root = ( + uid_result.exit_status == 0 + and isinstance(uid_result.stdout, str) + and uid_result.stdout.strip() == "0" + ) - cmdline: list[str] = [] + cmdline: list[str] = [] - if not is_root: - cmdline.append("/usr/bin/sudo") - if mod_env: - cmdline.append("--preserve-env=" + ",".join(mod_env.keys())) - cmdline.extend(opts) + if not is_root: + cmdline.append("/usr/bin/sudo") + if mod_env: + cmdline.append("--preserve-env=" + ",".join(mod_env.keys())) + cmdline.extend(opts) - cmdline.extend(cmd) + cmdline.extend(cmd) - return await self._run_on_conn(conn, cmdline, *args, **kwargs) + return await self._run_on_conn(cmdline, *args, **kwargs)