lib.ec.ssh.AsyncSSH: Reuse connection

With CmdCopy as test case and ExecContext.close() in place, we can
actually implement connection reuse, so do it for AsyncSSH.

Signed-off-by: Jan Lindemann <jan@janware.com>
This commit is contained in:
Jan Lindemann 2026-04-15 21:49:40 +02:00
commit 04fef1e67a

View file

@ -11,6 +11,7 @@ from .util import join_cmd
_USE_DEFAULT_KNOWN_HOSTS = object()
class AsyncSSH(Base):
def __init__(
self,
uri: str,
@ -32,6 +33,7 @@ class AsyncSSH(Base):
self.known_hosts = known_hosts
self.term_type = term_type or os.environ.get("TERM", "xterm")
self.connect_timeout = connect_timeout
self.__conn: asyncssh.SSHClientConnection|None = None
def _connect_kwargs(self) -> dict:
kwargs: dict = {
@ -81,7 +83,6 @@ class AsyncSSH(Base):
args[3] = merged_env or None
else:
kwargs["env"] = dict(mod_env) if mod_env else None
return tuple(args), kwargs
@staticmethod
@ -95,20 +96,31 @@ class AsyncSSH(Base):
def _get_local_term_size() -> tuple[int, int, int, int]:
cols, rows = shutil.get_terminal_size(fallback=(80, 24))
xpixel = ypixel = 0
try:
import fcntl, termios, struct
packed = fcntl.ioctl(sys.stdout.fileno(), termios.TIOCGWINSZ, b"\0" * 8)
rows2, cols2, xpixel, ypixel = struct.unpack("HHHH", packed)
if cols2 > 0 and rows2 > 0:
cols, rows = cols2, rows2
except Exception:
pass
return (cols, rows, xpixel, ypixel)
@property
async def _conn(self) -> asyncssh.SSHClientConnection:
if self.__conn is None:
self.__conn = await asyncssh.connect(**self._connect_kwargs())
return self.__conn
async def _close(self) -> None:
if self.__conn is not None:
try:
self.__conn.close()
await self.__conn.wait_closed()
except:
pass
self.__conn = None
async def _read_stream(
self,
stream,
@ -125,26 +137,24 @@ class AsyncSSH(Base):
chunk = await stream.read(4096)
if not chunk:
break
collector.append(chunk)
if verbose:
buf += chunk
while b"\n" in buf:
line, buf = buf.split(b"\n", 1)
log(prio, log_prefix, line.decode(log_enc, errors="replace"))
if verbose and buf:
log(prio, log_prefix, buf.decode(log_enc, errors="replace"))
async def _run_interactive_on_conn(
self,
conn: asyncssh.SSHClientConnection,
cmd: list[str],
wd: str | None,
cmd_input: bytes | None,
env: dict[str, str] | None,
) -> Result:
conn = await self._conn
command = self._build_remote_command(cmd, wd)
stdout_parts: list[bytes] = []
@ -194,10 +204,8 @@ class AsyncSSH(Base):
if cmd_input is not None and proc.stdin is not None:
proc.stdin.write(cmd_input)
await proc.stdin.drain()
while True:
data = await stdin_queue.get()
if data is None:
if proc.stdin is not None:
try:
@ -205,10 +213,8 @@ class AsyncSSH(Base):
except (BrokenPipeError, OSError):
pass
return
if proc.stdin is None:
return
proc.stdin.write(data)
await proc.stdin.drain()
@ -217,11 +223,11 @@ class AsyncSSH(Base):
chunk = await proc.stdout.read(4096)
if not chunk:
break
stdout_parts.append(chunk)
_write_local(chunk)
def _on_winch(*_args) -> None:
try:
proc.change_terminal_size(*self._get_local_term_size())
except Exception:
@ -300,7 +306,6 @@ class AsyncSSH(Base):
async def _run_captured_pty_on_conn(
self,
conn: asyncssh.SSHClientConnection,
cmd: list[str],
wd: str | None,
verbose: bool,
@ -308,6 +313,8 @@ class AsyncSSH(Base):
env: dict[str, str] | None,
log_prefix: str,
) -> Result:
conn = await self._conn
command = self._build_remote_command(cmd, wd)
stdout_parts: list[bytes] = []
@ -352,7 +359,6 @@ class AsyncSSH(Base):
async def _run_on_conn(
self,
conn: asyncssh.SSHClientConnection,
cmd: list[str],
wd: str | None,
verbose: bool,
@ -361,24 +367,25 @@ class AsyncSSH(Base):
interactive: bool,
log_prefix: str,
) -> Result:
conn = await self._conn
if interactive:
if self._has_local_tty():
return await self._run_interactive_on_conn(
conn,
cmd,
wd,
cmd_input,
env,
cmd=cmd,
wd=wd,
cmd_input=cmd_input,
env=env,
)
return await self._run_captured_pty_on_conn(
conn,
cmd,
wd,
verbose,
cmd_input,
env,
log_prefix,
cmd=cmd,
wd=wd,
verbose=verbose,
cmd_input=cmd_input,
env=env,
log_prefix=log_prefix,
)
command = self._build_remote_command(cmd, wd)
@ -451,9 +458,7 @@ class AsyncSSH(Base):
interactive: bool,
log_prefix: str,
) -> Result:
async with asyncssh.connect(**self._connect_kwargs()) as conn:
return await self._run_on_conn(
conn,
cmd,
wd,
verbose,
@ -471,10 +476,11 @@ class AsyncSSH(Base):
*args,
**kwargs,
) -> Result:
args, kwargs = self._merge_env_into_forwarded_args(args, kwargs, mod_env)
async with asyncssh.connect(**self._connect_kwargs()) as conn:
uid_result = await conn.run("id -u", check=False)
conn = await self._conn
uid_result = conn.run("id -u", check=False)
is_root = (
uid_result.exit_status == 0
and isinstance(uid_result.stdout, str)
@ -491,4 +497,4 @@ class AsyncSSH(Base):
cmdline.extend(cmd)
return await self._run_on_conn(conn, cmdline, *args, **kwargs)
return await self._run_on_conn(cmdline, *args, **kwargs)