diff --git a/src/python/jw/pkg/lib/FileContext.py b/src/python/jw/pkg/lib/FileContext.py index 76abe761..de6fbd5d 100644 --- a/src/python/jw/pkg/lib/FileContext.py +++ b/src/python/jw/pkg/lib/FileContext.py @@ -24,6 +24,7 @@ class FileContext(abc.ABC): assert verbose_default is not None async def __aenter__(self): + await self.open() return self async def __aexit__(self, exc_type, exc, tb): @@ -38,6 +39,18 @@ class FileContext(abc.ABC): return self.__root + path return self.__root + '/' + path + async def _open(self) -> None: + pass + + async def open(self) -> None: + await self._open() + + async def _close(self) -> None: + pass + + async def close(self) -> None: + await self._close() + @classmethod def schema_from_uri(cls, uri: str) -> str: tokens = re.split(r'://', uri) @@ -225,12 +238,6 @@ class FileContext(abc.ABC): async def is_dir(self, path: str) -> bool: return self._is_dir(self._chroot(path)) - async def _close(self) -> None: - pass - - async def close(self) -> None: - await self._close() - @classmethod def create(cls, uri: str, *args, **kwargs) -> Self: match cls.schema_from_uri(uri): diff --git a/src/python/jw/pkg/lib/ec/ssh/AsyncSSH.py b/src/python/jw/pkg/lib/ec/ssh/AsyncSSH.py index 5faab1f2..b0159e3c 100644 --- a/src/python/jw/pkg/lib/ec/ssh/AsyncSSH.py +++ b/src/python/jw/pkg/lib/ec/ssh/AsyncSSH.py @@ -35,6 +35,10 @@ class AsyncSSH(Base): self.__connect_timeout = connect_timeout self.__conn: asyncssh.SSHClientConnection|None = None + async def _open(self) -> None: + await super()._open() + await self._conn + async def _close(self) -> None: if self.__conn is not None: try: diff --git a/src/python/jw/pkg/lib/ec/ssh/Paramiko.py b/src/python/jw/pkg/lib/ec/ssh/Paramiko.py index 176ffb08..b60be30c 100644 --- a/src/python/jw/pkg/lib/ec/ssh/Paramiko.py +++ b/src/python/jw/pkg/lib/ec/ssh/Paramiko.py @@ -10,6 +10,9 @@ from ..SSHClient import SSHClient as Base from .util import join_cmd +if TYPE_CHECKING: + from typing import Any + class Paramiko(Base): def __init__(self, uri, *args, **kwargs) -> None: @@ -20,34 +23,40 @@ class Paramiko(Base): **kwargs ) self.__timeout: float|None = None # Untested - self.___ssh: Any|None = None - - def __ssh_connect(self): - ret = paramiko.SSHClient() - ret.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - try: - ret.connect( - hostname=self.hostname, - username=self.username, - allow_agent=True - ) - except Exception as e: - log(ERR, f'Failed to connect to {self.hostname} ({str(e)})') - raise - s = ret.get_transport().open_session() - # set up the agent request handler to handle agent requests from the server - paramiko.agent.AgentRequestHandler(s) - return ret + self.___client: Any|None = None @property - def __ssh(self): - if self.___ssh is None: - self.___ssh = self.__ssh_connect() - return self.___ssh + def __client(self) -> Any: + if self.___client is None: + ret = paramiko.SSHClient() + ret.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + try: + ret.connect( + hostname = self.hostname, + username = self.username, + allow_agent = True + ) + except Exception as e: + log(ERR, f'Failed to connect to {self.hostname} ({str(e)})') + raise + s = ret.get_transport().open_session() + # set up the agent request handler to handle agent requests from the server + paramiko.agent.AgentRequestHandler(s) + self.___client = ret + return self.___client @property - def __scp(self): - return SCPClient(self.__ssh.get_transport()) + def __scp(self) -> Any: + return SCPClient(self.__client.get_transport()) + + async def _open(self) -> None: + await super()._open() + self.__client + + async def _close(self) -> None: + if self.___client is not None: + self.___client.close() + self.___client = None async def _run_ssh( self, @@ -63,7 +72,7 @@ class Paramiko(Base): kwargs: [str, Any] = {} if mod_env is not None: kwargs['environment'] = mod_env - stdin, stdout, stderr = self.__ssh.exec_command( + stdin, stdout, stderr = self.__client.exec_command( join_cmd(cmd), timeout=self.__timeout, **kwargs, @@ -75,5 +84,3 @@ class Paramiko(Base): stdin.write(cmd_input) exit_status = stdout.channel.recv_exit_status() return Result(stdout.read(), stderr.read(), exit_status) - -