# -*- coding: utf-8 -*- from __future__ import annotations import re, os, stat, pwd, grp, tempfile from contextlib import suppress from pathlib import Path from dataclasses import dataclass from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Iterable from ...lib.Distro import Distro from ..CmdDistro import CmdDistro from ...lib.log import * from ..Cmd import Cmd as Base class Cmd(Base): # export @dataclass class Attrs: mode: int | None = None owner: str | None = None group: str | None = None conf: str | None = None def __read_key_value_file(self, path: str) -> dict[str, str]: ret: dict[str, str] = {} with open(path, "r", encoding="utf-8") as f: for line in f: line = line.strip() if not line or line.startswith("#"): continue if "=" not in line: continue key, val = line.split("=", 1) key = key.strip() val = val.strip() if key: ret[key] = val return ret def __parse_attributes(self, content: str) -> Attrs: first_line = content.splitlines()[0] if content else "" if not re.match(r"^\s*#\s*conf\s*:", first_line): return None ret = Attrs() ret.conf = first_line m = re.match(r"^\s*#\s*conf\s*:\s*(.*?)\s*$", first_line) if not m: return ret for part in m.group(1).split(";"): part = part.strip() if not part or "=" not in part: continue key, val = part.split("=", 1) key = key.strip() val = val.strip() if key == "owner": ret.owner = val or None elif key == "group": ret.group = val or None elif key == "mode": if val: try: if re.fullmatch(r"0[0-7]+", val): ret.mode = int(val, 8) else: ret.mode = int(val, 0) except ValueError: ret.mode = None return ret def __read_attributes(self, paths: Iterable[str]) -> Attrs|None: for path in paths: try: with open(path, 'r') as f: return self.__parse_attributes(f.read(path)) except: continue def __format_metadata(self, uid: int, gid: int, mode: int) -> str: return f"{pwd.getpwuid(uid).pw_name}:{grp.getgrgid(gid).gr_name} {mode:o}" def __copy_securely( self, src: str, dst: str, default_attrs: Attrs | None, replace: dict[str, str] = [], ) -> None: owner = "root" group = "root" mode = 0o400 with open(src, "r", encoding="utf-8", newline="") as f: content = f.read() attrs = self.__parse_attributes(content) if attrs is None: attrs = default_attrs if attrs is not None: if attrs.owner is not None: owner = attrs.owner if attrs.group is not None: group = attrs.group if attrs.mode is not None: mode = attrs.mode new_uid = pwd.getpwnam(owner).pw_uid new_gid = grp.getgrnam(group).gr_gid new_meta = self.__format_metadata(new_uid, new_gid, mode) for key, val in replace.items(): content = content.replace(key, val) dst_dir = os.path.dirname(os.path.abspath(dst)) tmp_fd, tmp_path = tempfile.mkstemp( prefix=f".{os.path.basename(dst)}.", dir=dst_dir, ) try: os.fchown(tmp_fd, new_uid, new_gid) os.fchmod(tmp_fd, mode) with os.fdopen(tmp_fd, "w", encoding="utf-8", newline="") as f: tmp_fd = None f.write(content) f.flush() os.fsync(f.fileno()) content_changed = True metadata_changed = True old_meta = "" try: st = os.stat(dst) except FileNotFoundError: pass else: old_mode = stat.S_IMODE(st.st_mode) old_meta = self.__format_metadata(st.st_uid, st.st_gid, old_mode) with open(dst, "r", encoding="utf-8", newline="") as f: old_content = f.read() content_changed = old_content != content metadata_changed = ( st.st_uid != new_uid or st.st_gid != new_gid or old_mode != mode ) changes = [] if content_changed: changes.append("@content") if metadata_changed: changes.append(f"@metadata ({old_meta} -> {new_meta})") details = ", ".join(changes) if changes else "no changes" log(NOTICE, f"Applying macros in {src} to {dst}: {details}") if not changes: os.unlink(tmp_path) tmp_path = None return os.replace(tmp_path, dst) tmp_path = None dir_fd = os.open(dst_dir, os.O_DIRECTORY) try: os.fsync(dir_fd) finally: os.close(dir_fd) finally: if tmp_fd is not None: os.close(tmp_fd) if tmp_path is not None: with suppress(FileNotFoundError): os.unlink(tmp_path) async def _match_files(self, packages: Iterable[str], pattern: str) -> list[str]: ret: list[str] = [] for package_name in packages: for path in await self.distro.pkg_files(package_name): if re.match(pattern, path): ret.append(path) return ret async def _list_template_files(self, packages: Iterable[str]) -> list[str]: return await self._match_files(packages, pattern=r'.*\.jw-tmpl$') async def _list_secret_paths(self, packages: Iterable[str]) -> list[str]: return [str(Path(f).with_suffix(".jw-secret")) for f in await self._list_template_files(packages)] async def _list_compilation_targets(self, packages: Iterable[str]) -> list[str]: return [f.removesuffix('.jw-tmpl') for f in await self._list_template_files(packages)] async def _remove_compilation_targets(self, packages: Iterable[str]) -> list[str]: for path in await self._list_compilation_targets(packages): if os.path.exists(path): log(NOTICE, f'Removing {path}') os.unlink(path) async def _compile_template_files(self, packages: Iterable[str], attrs: str|Attrs=None) -> list[str]: for target in await self._list_compilation_targets(packages): default_attrs = self.__read_attributes([target + '.jw-tmpl']) if default_attrs is None: default_attrs = attrs secret = target + '.jw-secret' replace = [] if not os.path.exists(secret) else self.__read_key_value_file(secret) for ext in [ '.jw-secret-file', '.jw-tmpl' ]: src = target + ext if os.path.exists(src): self.__copy_securely(src=src, dst=target, default_attrs=default_attrs, replace=replace) break else: log(WARNING, f'No secret found for target {target}, not compiling') def __init__(self, parent: CmdDistro, name: str, help: str) -> None: super().__init__(parent, name, help) def add_arguments(self, parser: ArgumentParser) -> None: super().add_arguments(parser) parser.add_argument("packages", nargs='*', help="Package names")