lib.FileContext.open(): Add method

Add an async open() method which should allow to do what __init__()
couldn't, because it's not async, and to match the already existing
.close(). It's called by __aenter__() __aexit__() if the FileContext
is instantiated as context manager, or at will when the user finds it
a good idea.

Signed-off-by: Jan Lindemann <jan@janware.com>
This commit is contained in:
Jan Lindemann 2026-04-22 07:55:44 +02:00
commit 58f7997bc6
Signed by: jan
GPG key ID: 3750640C9E25DD61
3 changed files with 51 additions and 33 deletions

View file

@ -24,6 +24,7 @@ class FileContext(abc.ABC):
assert verbose_default is not None assert verbose_default is not None
async def __aenter__(self): async def __aenter__(self):
await self.open()
return self return self
async def __aexit__(self, exc_type, exc, tb): async def __aexit__(self, exc_type, exc, tb):
@ -38,6 +39,18 @@ class FileContext(abc.ABC):
return self.__root + path return self.__root + path
return self.__root + '/' + path return self.__root + '/' + path
async def _open(self) -> None:
pass
async def open(self) -> None:
await self._open()
async def _close(self) -> None:
pass
async def close(self) -> None:
await self._close()
@classmethod @classmethod
def schema_from_uri(cls, uri: str) -> str: def schema_from_uri(cls, uri: str) -> str:
tokens = re.split(r'://', uri) tokens = re.split(r'://', uri)
@ -225,12 +238,6 @@ class FileContext(abc.ABC):
async def is_dir(self, path: str) -> bool: async def is_dir(self, path: str) -> bool:
return self._is_dir(self._chroot(path)) return self._is_dir(self._chroot(path))
async def _close(self) -> None:
pass
async def close(self) -> None:
await self._close()
@classmethod @classmethod
def create(cls, uri: str, *args, **kwargs) -> Self: def create(cls, uri: str, *args, **kwargs) -> Self:
match cls.schema_from_uri(uri): match cls.schema_from_uri(uri):

View file

@ -35,6 +35,10 @@ class AsyncSSH(Base):
self.__connect_timeout = connect_timeout self.__connect_timeout = connect_timeout
self.__conn: asyncssh.SSHClientConnection|None = None self.__conn: asyncssh.SSHClientConnection|None = None
async def _open(self) -> None:
await super()._open()
await self._conn
async def _close(self) -> None: async def _close(self) -> None:
if self.__conn is not None: if self.__conn is not None:
try: try:

View file

@ -10,6 +10,9 @@ from ..SSHClient import SSHClient as Base
from .util import join_cmd from .util import join_cmd
if TYPE_CHECKING:
from typing import Any
class Paramiko(Base): class Paramiko(Base):
def __init__(self, uri, *args, **kwargs) -> None: def __init__(self, uri, *args, **kwargs) -> None:
@ -20,34 +23,40 @@ class Paramiko(Base):
**kwargs **kwargs
) )
self.__timeout: float|None = None # Untested self.__timeout: float|None = None # Untested
self.___ssh: Any|None = None self.___client: Any|None = None
def __ssh_connect(self):
ret = paramiko.SSHClient()
ret.set_missing_host_key_policy(paramiko.AutoAddPolicy())
try:
ret.connect(
hostname=self.hostname,
username=self.username,
allow_agent=True
)
except Exception as e:
log(ERR, f'Failed to connect to {self.hostname} ({str(e)})')
raise
s = ret.get_transport().open_session()
# set up the agent request handler to handle agent requests from the server
paramiko.agent.AgentRequestHandler(s)
return ret
@property @property
def __ssh(self): def __client(self) -> Any:
if self.___ssh is None: if self.___client is None:
self.___ssh = self.__ssh_connect() ret = paramiko.SSHClient()
return self.___ssh ret.set_missing_host_key_policy(paramiko.AutoAddPolicy())
try:
ret.connect(
hostname = self.hostname,
username = self.username,
allow_agent = True
)
except Exception as e:
log(ERR, f'Failed to connect to {self.hostname} ({str(e)})')
raise
s = ret.get_transport().open_session()
# set up the agent request handler to handle agent requests from the server
paramiko.agent.AgentRequestHandler(s)
self.___client = ret
return self.___client
@property @property
def __scp(self): def __scp(self) -> Any:
return SCPClient(self.__ssh.get_transport()) return SCPClient(self.__client.get_transport())
async def _open(self) -> None:
await super()._open()
self.__client
async def _close(self) -> None:
if self.___client is not None:
self.___client.close()
self.___client = None
async def _run_ssh( async def _run_ssh(
self, self,
@ -63,7 +72,7 @@ class Paramiko(Base):
kwargs: [str, Any] = {} kwargs: [str, Any] = {}
if mod_env is not None: if mod_env is not None:
kwargs['environment'] = mod_env kwargs['environment'] = mod_env
stdin, stdout, stderr = self.__ssh.exec_command( stdin, stdout, stderr = self.__client.exec_command(
join_cmd(cmd), join_cmd(cmd),
timeout=self.__timeout, timeout=self.__timeout,
**kwargs, **kwargs,
@ -75,5 +84,3 @@ class Paramiko(Base):
stdin.write(cmd_input) stdin.write(cmd_input)
exit_status = stdout.channel.recv_exit_status() exit_status = stdout.channel.recv_exit_status()
return Result(stdout.read(), stderr.read(), exit_status) return Result(stdout.read(), stderr.read(), exit_status)