lib.ec.ssh.AsyncSSH: Reuse connection

With CmdCopy as test case and ExecContext.close() in place, we can
actually implement connection reuse, so do it for AsyncSSH.

Signed-off-by: Jan Lindemann <jan@janware.com>
This commit is contained in:
Jan Lindemann 2026-04-15 21:49:40 +02:00
commit 04fef1e67a

View file

@ -11,6 +11,7 @@ from .util import join_cmd
_USE_DEFAULT_KNOWN_HOSTS = object() _USE_DEFAULT_KNOWN_HOSTS = object()
class AsyncSSH(Base): class AsyncSSH(Base):
def __init__( def __init__(
self, self,
uri: str, uri: str,
@ -32,6 +33,7 @@ class AsyncSSH(Base):
self.known_hosts = known_hosts self.known_hosts = known_hosts
self.term_type = term_type or os.environ.get("TERM", "xterm") self.term_type = term_type or os.environ.get("TERM", "xterm")
self.connect_timeout = connect_timeout self.connect_timeout = connect_timeout
self.__conn: asyncssh.SSHClientConnection|None = None
def _connect_kwargs(self) -> dict: def _connect_kwargs(self) -> dict:
kwargs: dict = { kwargs: dict = {
@ -81,7 +83,6 @@ class AsyncSSH(Base):
args[3] = merged_env or None args[3] = merged_env or None
else: else:
kwargs["env"] = dict(mod_env) if mod_env else None kwargs["env"] = dict(mod_env) if mod_env else None
return tuple(args), kwargs return tuple(args), kwargs
@staticmethod @staticmethod
@ -95,20 +96,31 @@ class AsyncSSH(Base):
def _get_local_term_size() -> tuple[int, int, int, int]: def _get_local_term_size() -> tuple[int, int, int, int]:
cols, rows = shutil.get_terminal_size(fallback=(80, 24)) cols, rows = shutil.get_terminal_size(fallback=(80, 24))
xpixel = ypixel = 0 xpixel = ypixel = 0
try: try:
import fcntl, termios, struct import fcntl, termios, struct
packed = fcntl.ioctl(sys.stdout.fileno(), termios.TIOCGWINSZ, b"\0" * 8) packed = fcntl.ioctl(sys.stdout.fileno(), termios.TIOCGWINSZ, b"\0" * 8)
rows2, cols2, xpixel, ypixel = struct.unpack("HHHH", packed) rows2, cols2, xpixel, ypixel = struct.unpack("HHHH", packed)
if cols2 > 0 and rows2 > 0: if cols2 > 0 and rows2 > 0:
cols, rows = cols2, rows2 cols, rows = cols2, rows2
except Exception: except Exception:
pass pass
return (cols, rows, xpixel, ypixel) return (cols, rows, xpixel, ypixel)
@property
async def _conn(self) -> asyncssh.SSHClientConnection:
if self.__conn is None:
self.__conn = await asyncssh.connect(**self._connect_kwargs())
return self.__conn
async def _close(self) -> None:
if self.__conn is not None:
try:
self.__conn.close()
await self.__conn.wait_closed()
except:
pass
self.__conn = None
async def _read_stream( async def _read_stream(
self, self,
stream, stream,
@ -125,26 +137,24 @@ class AsyncSSH(Base):
chunk = await stream.read(4096) chunk = await stream.read(4096)
if not chunk: if not chunk:
break break
collector.append(chunk) collector.append(chunk)
if verbose: if verbose:
buf += chunk buf += chunk
while b"\n" in buf: while b"\n" in buf:
line, buf = buf.split(b"\n", 1) line, buf = buf.split(b"\n", 1)
log(prio, log_prefix, line.decode(log_enc, errors="replace")) log(prio, log_prefix, line.decode(log_enc, errors="replace"))
if verbose and buf: if verbose and buf:
log(prio, log_prefix, buf.decode(log_enc, errors="replace")) log(prio, log_prefix, buf.decode(log_enc, errors="replace"))
async def _run_interactive_on_conn( async def _run_interactive_on_conn(
self, self,
conn: asyncssh.SSHClientConnection,
cmd: list[str], cmd: list[str],
wd: str | None, wd: str | None,
cmd_input: bytes | None, cmd_input: bytes | None,
env: dict[str, str] | None, env: dict[str, str] | None,
) -> Result: ) -> Result:
conn = await self._conn
command = self._build_remote_command(cmd, wd) command = self._build_remote_command(cmd, wd)
stdout_parts: list[bytes] = [] stdout_parts: list[bytes] = []
@ -194,10 +204,8 @@ class AsyncSSH(Base):
if cmd_input is not None and proc.stdin is not None: if cmd_input is not None and proc.stdin is not None:
proc.stdin.write(cmd_input) proc.stdin.write(cmd_input)
await proc.stdin.drain() await proc.stdin.drain()
while True: while True:
data = await stdin_queue.get() data = await stdin_queue.get()
if data is None: if data is None:
if proc.stdin is not None: if proc.stdin is not None:
try: try:
@ -205,10 +213,8 @@ class AsyncSSH(Base):
except (BrokenPipeError, OSError): except (BrokenPipeError, OSError):
pass pass
return return
if proc.stdin is None: if proc.stdin is None:
return return
proc.stdin.write(data) proc.stdin.write(data)
await proc.stdin.drain() await proc.stdin.drain()
@ -217,11 +223,11 @@ class AsyncSSH(Base):
chunk = await proc.stdout.read(4096) chunk = await proc.stdout.read(4096)
if not chunk: if not chunk:
break break
stdout_parts.append(chunk) stdout_parts.append(chunk)
_write_local(chunk) _write_local(chunk)
def _on_winch(*_args) -> None: def _on_winch(*_args) -> None:
try: try:
proc.change_terminal_size(*self._get_local_term_size()) proc.change_terminal_size(*self._get_local_term_size())
except Exception: except Exception:
@ -300,7 +306,6 @@ class AsyncSSH(Base):
async def _run_captured_pty_on_conn( async def _run_captured_pty_on_conn(
self, self,
conn: asyncssh.SSHClientConnection,
cmd: list[str], cmd: list[str],
wd: str | None, wd: str | None,
verbose: bool, verbose: bool,
@ -308,6 +313,8 @@ class AsyncSSH(Base):
env: dict[str, str] | None, env: dict[str, str] | None,
log_prefix: str, log_prefix: str,
) -> Result: ) -> Result:
conn = await self._conn
command = self._build_remote_command(cmd, wd) command = self._build_remote_command(cmd, wd)
stdout_parts: list[bytes] = [] stdout_parts: list[bytes] = []
@ -352,7 +359,6 @@ class AsyncSSH(Base):
async def _run_on_conn( async def _run_on_conn(
self, self,
conn: asyncssh.SSHClientConnection,
cmd: list[str], cmd: list[str],
wd: str | None, wd: str | None,
verbose: bool, verbose: bool,
@ -361,24 +367,25 @@ class AsyncSSH(Base):
interactive: bool, interactive: bool,
log_prefix: str, log_prefix: str,
) -> Result: ) -> Result:
conn = await self._conn
if interactive: if interactive:
if self._has_local_tty(): if self._has_local_tty():
return await self._run_interactive_on_conn( return await self._run_interactive_on_conn(
conn, cmd=cmd,
cmd, wd=wd,
wd, cmd_input=cmd_input,
cmd_input, env=env,
env,
) )
return await self._run_captured_pty_on_conn( return await self._run_captured_pty_on_conn(
conn, cmd=cmd,
cmd, wd=wd,
wd, verbose=verbose,
verbose, cmd_input=cmd_input,
cmd_input, env=env,
env, log_prefix=log_prefix,
log_prefix,
) )
command = self._build_remote_command(cmd, wd) command = self._build_remote_command(cmd, wd)
@ -451,9 +458,7 @@ class AsyncSSH(Base):
interactive: bool, interactive: bool,
log_prefix: str, log_prefix: str,
) -> Result: ) -> Result:
async with asyncssh.connect(**self._connect_kwargs()) as conn:
return await self._run_on_conn( return await self._run_on_conn(
conn,
cmd, cmd,
wd, wd,
verbose, verbose,
@ -471,10 +476,11 @@ class AsyncSSH(Base):
*args, *args,
**kwargs, **kwargs,
) -> Result: ) -> Result:
args, kwargs = self._merge_env_into_forwarded_args(args, kwargs, mod_env) args, kwargs = self._merge_env_into_forwarded_args(args, kwargs, mod_env)
async with asyncssh.connect(**self._connect_kwargs()) as conn: conn = await self._conn
uid_result = await conn.run("id -u", check=False) uid_result = conn.run("id -u", check=False)
is_root = ( is_root = (
uid_result.exit_status == 0 uid_result.exit_status == 0
and isinstance(uid_result.stdout, str) and isinstance(uid_result.stdout, str)
@ -491,4 +497,4 @@ class AsyncSSH(Base):
cmdline.extend(cmd) cmdline.extend(cmd)
return await self._run_on_conn(conn, cmdline, *args, **kwargs) return await self._run_on_conn(cmdline, *args, **kwargs)