lib.ExecContext: Support bytes-typed cmd_input

The Input instance passed as cmd_input to ExecContext.run() and
.sudo() currently may be of type str. Allow to pass bytes, too.

At the same time, disallow None to be passed as cmd_input. Force the
caller to be more explicit how it wants input to be handled, notably
with respect to interactivity.

Along the way fix a bug: Content in cmd_input should result in
CallContext.interactive == False but doesn't. Fix that.

Signed-off-by: Jan Lindemann <jan@janware.com>
This commit is contained in:
Jan Lindemann 2026-04-15 14:02:44 +02:00
commit 04b294917f
9 changed files with 46 additions and 29 deletions

View file

@ -120,7 +120,6 @@ class CmdBuild(Cmd): # export
wd=wd,
throw=True,
verbose=True,
cmd_input=None,
env=env,
title=title
)

View file

@ -43,7 +43,8 @@ class CmdListRepos(Cmd): # export
print('\n'.join(result.stdout.decode().splitlines()))
return
case 'https':
cmd_input = None
from jw.pkg.lib.ExecContext import InputMode
cmd_input = InputMode.NonInteractive
if re.match(r'https://github.com', args.base_url):
curl_args = [
'-f',

View file

@ -18,7 +18,7 @@ class InputMode(Enum):
OptInteractive = auto()
Auto = auto()
Input: TypeAlias = InputMode | None | str
Input: TypeAlias = InputMode | bytes | str
class Result(NamedTuple):
@ -61,6 +61,13 @@ class ExecContext(abc.ABC):
# -- At the end of this dance, interactive needs to be either True
# or False
interactive: bool|None = None
if not isinstance(cmd_input, InputMode):
interactive = False
self.__cmd_input = (
cmd_input if isinstance(cmd_input, bytes) else
cmd_input.encode(sys.stdout.encoding or "utf-8")
)
else:
match cmd_input:
case InputMode.Interactive:
interactive = True
@ -74,6 +81,7 @@ class ExecContext(abc.ABC):
interactive = parent.interactive
if interactive is None:
interactive = sys.stdin.isatty()
self.__cmd_input = None
assert interactive in [ True, False ]
self.__interactive = interactive
@ -106,7 +114,7 @@ class ExecContext(abc.ABC):
return self.__verbose
@property
def cmd_input(self) -> bool:
def cmd_input(self) -> bytes|None:
return self.__cmd_input
@property
@ -211,6 +219,10 @@ class ExecContext(abc.ABC):
In PTY mode stderr is always None because PTY merges stdout/stderr.
"""
# Note that in the calls to the wrapped method, cmd_input == None can
# be returned by CallContext and is very much allowed
assert cmd_input is not None
ret = Result(None, None, 1)
with self.CallContext(self, title=title, cmd=cmd, cmd_input=cmd_input, wd=wd,
log_prefix='|', throw=throw, verbose=verbose) as cc:
@ -246,6 +258,10 @@ class ExecContext(abc.ABC):
title: str=None,
) -> Result:
# Note that in the calls to the wrapped method, cmd_input == None can
# be returned by CallContext and is very much allowed
assert cmd_input is not None
ret = Result(None, None, 1)
if opts is None:
opts = {}

View file

@ -18,7 +18,7 @@ class Local(Base):
cmd: list[str],
wd: str|None,
verbose: bool,
cmd_input: str|None,
cmd_input: bytes|None,
env: dict[str, str]|None,
interactive: bool,
log_prefix: str
@ -116,7 +116,7 @@ class Local(Base):
]
if stdin is asyncio.subprocess.PIPE:
proc.stdin.write(cmd_input.encode(sys.stdout.encoding or "utf-8"))
proc.stdin.write(cmd_input)
await proc.stdin.drain()
proc.stdin.close()
@ -134,7 +134,7 @@ class Local(Base):
async def _sudo(self, cmd: list[str], mod_env: dict[str, str], opts: list[str], *args, **kwargs) -> Result:
env: dict[str, str]|None = None
cmd_input: str|None = None
cmd_input: bytes|None = None
if mod_env:
env = os.environ.copy()
env.update(mod_env)

View file

@ -39,7 +39,7 @@ class SSHClient(ExecContext):
cmd: list[str],
wd: str|None,
verbose: bool,
cmd_input: str|None,
cmd_input: bytes|None,
env: dict[str, str]|None,
interactive: bool,
log_prefix: str
@ -51,7 +51,7 @@ class SSHClient(ExecContext):
cmd: list[str],
wd: str|None,
verbose: bool,
cmd_input: str|None,
cmd_input: bytes|None,
env: dict[str, str]|None,
interactive: bool,
log_prefix: str

View file

@ -142,7 +142,7 @@ class AsyncSSH(Base):
conn: asyncssh.SSHClientConnection,
cmd: list[str],
wd: str | None,
cmd_input: str | None,
cmd_input: bytes | None,
env: dict[str, str] | None,
) -> Result:
command = self._build_remote_command(cmd, wd)
@ -192,7 +192,7 @@ class AsyncSSH(Base):
async def _pump_stdin() -> None:
if cmd_input is not None and proc.stdin is not None:
proc.stdin.write(cmd_input.encode(sys.stdout.encoding or "utf-8"))
proc.stdin.write(cmd_input)
await proc.stdin.drain()
while True:
@ -304,7 +304,7 @@ class AsyncSSH(Base):
cmd: list[str],
wd: str | None,
verbose: bool,
cmd_input: str | None,
cmd_input: bytes | None,
env: dict[str, str] | None,
log_prefix: str,
) -> Result:
@ -336,7 +336,7 @@ class AsyncSSH(Base):
)
if cmd_input is not None and proc.stdin is not None:
proc.stdin.write(cmd_input.encode(sys.stdout.encoding or "utf-8"))
proc.stdin.write(cmd_input)
await proc.stdin.drain()
proc.stdin.write_eof()
@ -356,7 +356,7 @@ class AsyncSSH(Base):
cmd: list[str],
wd: str | None,
verbose: bool,
cmd_input: str | None,
cmd_input: bytes | None,
env: dict[str, str] | None,
interactive: bool,
log_prefix: str,
@ -425,7 +425,7 @@ class AsyncSSH(Base):
]
if cmd_input is not None and proc.stdin is not None:
proc.stdin.write(cmd_input.encode(sys.stdout.encoding or "utf-8"))
proc.stdin.write(cmd_input)
await proc.stdin.drain()
proc.stdin.write_eof()

View file

@ -38,7 +38,7 @@ class Exec(Base):
self.__askpass_orig[key] = os.getenv(key)
os.environ[key] = val
async def _run_ssh(self, cmd: list[str], cmd_input: str|None, *args, **kwargs) -> Result:
async def _run_ssh(self, cmd: list[str], cmd_input: bytes|None, *args, **kwargs) -> Result:
self.__init_askpass()
return await run_cmd(['ssh', self.hostname, join_cmd(cmd)], cmd_input=cmd_input, interactive=self.interactive, throw=False)

View file

@ -44,7 +44,7 @@ class Paramiko(Base):
def __scp(self):
return SCPClient(self.__ssh.get_transport())
async def _run_ssh(self, cmd: list[str], cmd_input: str|None, *args, **kwargs) -> Result:
async def _run_ssh(self, cmd: list[str], cmd_input: bytes|None, *args, **kwargs) -> Result:
try:
stdin, stdout, stderr = self.__ssh.exec_command(join_cmd(cmd), timeout=self.__timeout)
except Exception as e:

View file

@ -15,6 +15,7 @@ from urllib.parse import urlparse
from enum import Enum, auto
from .log import *
from .ExecContext import InputMode
class AskpassKey(Enum):
Username = auto()
@ -40,7 +41,7 @@ async def run_cmd(*args, ec: ExecContext|None=None, verbose: bool|None=None, int
ec = Local(verbose_default=verbose, interactive=interactive)
return await ec.run(verbose=verbose, *args, **kwargs)
async def run_curl(args: list[str], parse_json: bool=False, wd=None, throw=None, verbose=None, cmd_input=None, ec: ExecContext|None=None, decode=False) -> dict|str: # export
async def run_curl(args: list[str], parse_json: bool=False, wd=None, throw=None, verbose=None, cmd_input=InputMode.NonInteractive, ec: ExecContext|None=None, decode=False) -> dict|str: # export
if verbose is None:
verbose = False if ec is None else ec.verbose_default
cmd = ['curl']