diff --git a/src/python/jw/pkg/lib/ec/ssh/AsyncSSH.py b/src/python/jw/pkg/lib/ec/ssh/AsyncSSH.py index 01a75c1a..6aae15dd 100644 --- a/src/python/jw/pkg/lib/ec/ssh/AsyncSSH.py +++ b/src/python/jw/pkg/lib/ec/ssh/AsyncSSH.py @@ -33,6 +33,7 @@ class AsyncSSH(Base): self.known_hosts = known_hosts self.term_type = term_type or os.environ.get('TERM', 'xterm') self.connect_timeout = connect_timeout + self.__conn: asyncssh.SSHClientConnection|None = None def _connect_kwargs(self, hide_secrets: bool=False) -> dict: @@ -90,6 +91,29 @@ class AsyncSSH(Base): return (cols, rows, xpixel, ypixel) + @property + async def _conn(self) -> asyncssh.SSHClientConnection: + if self.__conn is None: + try: + self.__conn = await asyncssh.connect(**self._connect_kwargs()) + except Exception as e: + msg = f'-------------------- Failed to connect ({str(e)})' + log(ERR, ',', msg) + for key, val in self._connect_kwargs(hide_secrets=True).items(): + log(ERR, f'| {key:<20} = {val}') + log(ERR, '`', msg) + raise + return self.__conn + + async def _close(self) -> None: + if self.__conn is not None: + try: + self.__conn.close() + await self.__conn.wait_closed() + except Exception as e: + log(DEBUG, f'Failed to close connection ({str(e)}, ignored)') + self.__conn = None + async def _read_stream( self, stream, @@ -120,13 +144,13 @@ class AsyncSSH(Base): async def _run_interactive_on_conn( self, - conn: asyncssh.SSHClientConnection, cmd: list[str], wd: str | None, cmd_input: bytes | None, mod_env: dict[str, str] | None, ) -> Result: + conn = await self._conn command = self._build_remote_command(cmd, wd) stdout_parts: list[bytes] = [] @@ -278,7 +302,6 @@ class AsyncSSH(Base): async def _run_captured_pty_on_conn( self, - conn: asyncssh.SSHClientConnection, cmd: list[str], wd: str | None, verbose: bool, @@ -286,6 +309,8 @@ class AsyncSSH(Base): mod_env: dict[str, str] | None, log_prefix: str, ) -> Result: + + conn = await self._conn command = self._build_remote_command(cmd, wd) stdout_parts: list[bytes] = [] @@ -330,7 +355,6 @@ class AsyncSSH(Base): async def _run_on_conn( self, - conn: asyncssh.SSHClientConnection, cmd: list[str], wd: str | None, verbose: bool, @@ -342,7 +366,6 @@ class AsyncSSH(Base): if interactive: if self._has_local_tty(): return await self._run_interactive_on_conn( - conn = conn, cmd = cmd, wd = wd, cmd_input = cmd_input, @@ -350,7 +373,6 @@ class AsyncSSH(Base): ) return await self._run_captured_pty_on_conn( - conn = conn, cmd = cmd, wd = wd, verbose = verbose, @@ -369,6 +391,8 @@ class AsyncSSH(Base): stdin_mode = asyncssh.PIPE if cmd_input is not None else asyncssh.DEVNULL + conn = await self._conn + proc = await conn.create_process( command = command, env = mod_env, @@ -430,21 +454,15 @@ class AsyncSSH(Base): log_prefix: str, ) -> Result: try: - async with asyncssh.connect(**self._connect_kwargs()) as conn: - return await self._run_on_conn( - conn = conn, - cmd = cmd, - wd = wd, - verbose = verbose, - cmd_input = cmd_input, - mod_env = mod_env, - interactive = interactive, - log_prefix = log_prefix, - ) + return await self._run_on_conn( + cmd = cmd, + wd = wd, + verbose = verbose, + cmd_input = cmd_input, + mod_env = mod_env, + interactive = interactive, + log_prefix = log_prefix, + ) except Exception as e: - msg = f'-------------------- Failed to run command {" ".join(cmd)} ({e})' - log(ERR, ',', msg) - for key, val in self._connect_kwargs(hide_secrets=True).items(): - log(ERR, f'| {key:<20} = {val}') - log(ERR, '`', msg) + log(ERR, f'Failed to run command {" ".join(cmd)} ({e})' raise