from __future__ import annotations import os from typing import TYPE_CHECKING from ...base import InputMode from ...util import run_cmd from ..SSHClient import SSHClient as Base from .util import join_cmd if TYPE_CHECKING: from ...base import Input, Result class Exec(Base): def __init__(self, uri, *args, **kwargs) -> None: self.__askpass: str | None = None self.__askpass_orig: dict[str, str | None] = dict() super().__init__(uri = uri, caps = self.Caps.ModEnv, **kwargs) def __del__(self): for key, val in self.__askpass_orig.items(): if val is None: del os.environ[key] else: os.environ[key] = val if self.__askpass is not None: os.remove(self.__askpass) def __init_askpass(self): if self.__askpass is None and self.password is not None: import sys import tempfile prefix = os.path.basename(sys.argv[0]) + '-' f = tempfile.NamedTemporaryFile( mode = 'w+t', prefix = prefix, delete = False ) os.chmod(f.name, 0o0700) self.__askpass = f.name f.write(f'#!/bin/bash\n\necho -n "{self.password}\n"') f.close() for key, val in { 'SSH_ASKPASS': self.__askpass, 'SSH_ASKPASS_REQUIRE': 'force', }.items(): self.__askpass_orig[key] = os.getenv(key) os.environ[key] = val async def _run_ssh( self, cmd: list[str], wd: str | None, verbose: bool, cmd_input: bytes | None, mod_env: dict[str, str] | None, interactive: bool, log_prefix: str, ) -> Result: def __pub_cmd_input(cmd_input: bytes | None) -> Input: if cmd_input is None: if interactive: return InputMode.Interactive return InputMode.NonInteractive return cmd_input self.__init_askpass() opts: list[str] = [] if mod_env: for key, val in mod_env.items(): opts.extend(['-o', f'SetEnv {key}="{val}"']) if self.username: opts.extend(['-l', self.username]) if self.port is not None: opts.extend(['-p', str(self.port)]) return await run_cmd( ['ssh', *opts, self.hostname, join_cmd(cmd)], cmd_input = __pub_cmd_input(cmd_input), throw = False, )