Makefile: Include py-topdir.mk

Include py-topdir.mk, which entails loads of fallout from make check. Fix it.

Signed-off-by: Jan Lindemann <jan@janware.com>
This commit is contained in:
Jan Lindemann 2026-06-10 13:28:35 +02:00
commit e9845b5a1f
Signed by: Jan Lindemann
GPG key ID: 3750640C9E25DD61
45 changed files with 1796 additions and 1191 deletions

View file

@ -2,3 +2,4 @@ TOPDIR = .
include $(TOPDIR)/make/proj.mk
include $(JWBDIR)/make/topdir.mk
include $(JWBDIR)/make/py-topdir.mk

View file

@ -1,20 +1,17 @@
# -*- coding: utf-8 -*-
import argparse
from collections import OrderedDict
from .log import slog
from .log import get_caller_pos, slog
class ArgsContainer: # export
__args: OrderedDict[str, str] = OrderedDict()
__kwargs: OrderedDict[str, str] = OrderedDict()
__kwargs: OrderedDict[str, dict[str, str]] = OrderedDict()
__values: dict[str, str] = {}
__specified_args: list[str] = list()
def __getattr__(self, name):
values = self.__values
def __getattr__(self, name: str) -> str:
if name in self.__values:
return self.__values[name]
if name in self.__kwargs.keys():
@ -25,7 +22,7 @@ class ArgsContainer: # export
raise Exception(f'No argument "{name}" defined')
def __setattr__(self, name, value):
if not name in self.__kwargs.keys():
if name not in self.__kwargs.keys():
raise Exception(f'No argument "{name}" defined')
self.__values[name] = value
self.__specified_args.append(name)
@ -41,7 +38,7 @@ class ArgsContainer: # export
else:
raise Exception('Missing argument name')
name = name.replace('-', '_')
self.__args[name] = args
self.__args[name] = arg
self.__kwargs[name] = kwargs
def keys(self):
@ -50,7 +47,7 @@ class ArgsContainer: # export
def args(self, name) -> str:
return self.__args[name]
def kwargs(self, name) -> str:
def kwargs(self, name) -> dict[str, str]:
return self.__kwargs[name]
def dump(self, prio, *args, **kwargs):
@ -59,7 +56,7 @@ class ArgsContainer: # export
val = None
try:
val = self.__getattr__(name)
except:
except Exception:
pass
slog(prio, f'{name}: {val}', caller = caller)
@ -67,7 +64,9 @@ class ArgsContainer: # export
def specified_args(self):
return self.__specified_args
def add_argument(p: argparse.ArgumentParser|ArgsContainer, name: str, *args, **kwargs): # export
def add_argument( # export
p: argparse.ArgumentParser | ArgsContainer, name: str, *args, **kwargs
):
key = name.strip('--').replace('-', '_')
if isinstance(p, ArgsContainer):

View file

@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-
from typing import Any
class Bunch: # export

View file

@ -1,11 +1,19 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import inspect, sys, re, abc, argparse
from argparse import ArgumentParser, _SubParsersAction
import abc
import argparse
import inspect
import re
import sys
from argparse import ArgumentParser
from typing import TYPE_CHECKING
from . import log
if TYPE_CHECKING:
from .Cmds import Cmds
# full blown example of one level of nested subcommands
# git -C project remote -v show -n myremote
@ -16,7 +24,6 @@ class Cmd(abc.ABC): # export
pass
def __init__(self, name: str, help: str) -> None:
from . import Cmds
self.name = name
self.help = help
self.parent = None
@ -28,8 +35,11 @@ class Cmd(abc.ABC): # export
pass
def add_parser(self, parsers) -> ArgumentParser:
r = parsers.add_parser(self.name, help=self.help,
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
r = parsers.add_parser(
self.name,
help = self.help,
formatter_class = argparse.ArgumentDefaultsHelpFormatter
)
r.set_defaults(func = self.run)
return r

View file

@ -1,10 +1,18 @@
# -*- coding: utf-8 -*-
import argparse
import asyncio
import cProfile
import importlib
import inspect
import os
import re
import sys
import os, sys, argcomplete, argparse, importlib, inspect, re, pickle, asyncio, cProfile
from argparse import ArgumentParser
from pathlib import Path, PurePath
from .log import *
import argcomplete
from .log import DEBUG, ERR, NOTICE, add_log_file, set_flags, set_level, slog
from .stree import serdes
class Cmds: # export
@ -15,7 +23,7 @@ class Cmds: # export
except Exception as e:
slog(ERR, f'Failed to instantiate command of type {cls}: {e}')
raise
r.cmds = self # TODO: Rename Cmds class to App, "Cmds" isn't very self-explanatory
r.cmds = self # TODO: Rename Cmds class to App, "Cmds" isn't self-explanatory
r.app = self
return r
@ -26,7 +34,9 @@ class Cmds: # export
for c in cmd.child_classes:
cmd.children.append(self.__instantiate(c))
if len(cmd.children) > 0:
subparsers = parser.add_subparsers(title='Available subcommands of ' + cmd.name, metavar='')
subparsers = parser.add_subparsers(
title = 'Available subcommands of ' + cmd.name, metavar = ''
)
for sub_cmd in cmd.children:
self.__add_cmd_to_parser(sub_cmd, subparsers)
@ -38,7 +48,13 @@ class Cmds: # export
slog(DEBUG, 'Reading configuration "{}"'.format(path))
return serdes.read(path, ''), [path]
def __init__(self, description: str = '', filter: str = '^Cmd.*', modules: None=None, eloop: None=None) -> None:
def __init__(
self,
description: str = '',
filter: str = '^Cmd.*',
modules: None = None,
eloop: None = None
) -> None:
self.__description = description
self.__filter = filter
self.__modules = modules
@ -68,14 +84,30 @@ class Cmds: # export
set_flags(log_flags)
set_level(log_level)
slog(DEBUG, "set log level to {}".format(log_level))
self.__parser = argparse.ArgumentParser(usage=os.path.basename(sys.argv[0]) + ' [options]',
formatter_class=argparse.ArgumentDefaultsHelpFormatter, description=self.__description)
self.__parser.add_argument('--log-flags', help='Log flags', default=log_flags)
self.__parser.add_argument('--log-level', help='Log level', default=log_level)
self.__parser.add_argument('--backtrace', help='Show exception backtraces', action='store_true', default=False)
self.__parser.add_argument('--write-profile', help='Profile code and store output to file', default=None)
self.__parser = argparse.ArgumentParser(
usage = os.path.basename(sys.argv[0]) + ' [options]',
formatter_class = argparse.ArgumentDefaultsHelpFormatter,
description = self.__description
)
self.__parser.add_argument(
'--log-flags', help = 'Log flags', default = log_flags
)
self.__parser.add_argument(
'--log-level', help = 'Log level', default = log_level
)
self.__parser.add_argument(
'--backtrace',
help = 'Show exception backtraces',
action = 'store_true',
default = False
)
self.__parser.add_argument(
'--write-profile',
help = 'Profile code and store output to file',
default = None
)
self.__parser.add_argument('--log-file', help = 'Log file', default = log_file)
if self.__modules == None:
if self.__modules is None:
self.__modules = ['__main__']
subcmds = set()
slog(DEBUG, '-- searching for commands')
@ -96,7 +128,9 @@ class Cmds: # export
subcmds.update(cmd.child_classes)
cmds = [cmd for cmd in self.__cmds if type(cmd) not in subcmds]
subparsers = self.__parser.add_subparsers(title='Available commands', metavar='')
subparsers = self.__parser.add_subparsers(
title = 'Available commands', metavar = ''
)
for cmd in cmds:
slog(DEBUG, f'Adding top-level command {cmd} to parser')
self.__add_cmd_to_parser(cmd, subparsers)
@ -113,7 +147,8 @@ class Cmds: # export
self.__back_trace = self.args.backtrace
exit_status = 0
# This is the toplevel parser, i.e. no func member has been added to the args via
# This is the toplevel parser, i.e. no func member has been added to the args
# yet, via
#
# Cmds.__init__()
# Cmds.__add_cmd_to_parser(cmd, subparsers)
@ -135,17 +170,15 @@ class Cmds: # export
add_log_file(self.args.log_file)
try:
ret = await self._run(self.args)
await self._run(self.args)
except Exception as e:
if hasattr(e, 'message'):
slog(ERR, e.message)
else:
slog(ERR, f'Exception: {type(e)}: {e}')
slog(ERR, f'Exception: {type(e)}: {str(e)}')
exit_status = 1
if self.__back_trace:
raise
finally:
if pr is not None:
assert self.args.write_profile is not None, 'args.write_profile'
pr.disable()
slog(NOTICE, f'Writing profile statistics to {self.args.write_profile}')
pr.dump_stats(self.args.write_profile)
@ -173,6 +206,8 @@ class Cmds: # export
#return self.__run()
return self.eloop.run_until_complete(self.__run(argv)) # type: ignore
def run_sub_commands(description = '', filter = '^Cmd.*', modules=None, argv=None): # export
def run_sub_commands( # export
description = '', filter = '^Cmd.*', modules = None, argv = None
):
cmds = Cmds(description, filter, modules)
return cmds.run(argv = argv)

View file

@ -1,12 +1,14 @@
# -*- coding: utf-8 -*-
import glob
import os
import re
import sys
from typing import Optional, Dict, cast
import os, re, glob, sys
from pathlib import Path, PosixPath
from pathlib import Path
from typing import Dict, Optional, cast
from . import stree
from .stree import serdes
from .log import DEBUG, ERR, slog, get_caller_pos
from .stree.StringTree import StringTree
from .log import *
class Config(): # export
@ -40,7 +42,7 @@ class Config(): # export
for f in glob.glob(g):
slog(DEBUG, 'Reading config "{}"'.format(f))
paths_buf = []
tree = stree.read(f, paths_buf=paths_buf)
tree = serdes.read(f, paths_buf = paths_buf)
assert (len(paths_buf))
if refuse_mode_mask is not None:
for p in paths_buf:
@ -49,18 +51,25 @@ class Config(): # export
for item in tree.child_list():
if item.content is None:
continue
if not re.search('password|secret', cast(str, item.content), flags=re.IGNORECASE):
if not re.search('password|secret',
cast('str', item.content),
flags = re.IGNORECASE):
continue
msg = "Config files define secret, but at least one has file permissions open for world"
msg = (
'Config files define secret, but at least one '
'has file permissions open for world'
)
slog(ERR, f'{msg}:')
for pp in paths_buf:
slog(ERR, f' {((os.stat(pp).st_mode) & 0o7777):o} {pp}')
mode = (os.stat(pp).st_mode) & 0o7777
slog(ERR, f' {mode:o} {pp}')
raise Exception(msg)
tree.dump(DEBUG, f)
ret.add("", tree)
return ret
def __init__(self,
def __init__(
self,
search_dirs: Optional[list[str]] = None,
glob_paths: Optional[list[str]] = None,
glob_paths_env_key: Optional[str] = None,
@ -87,8 +96,11 @@ class Config(): # export
glob_paths = []
glob_paths.extend(glob_paths_env.split(':'))
self.__conf = self.__load(search_dirs=search_dirs, glob_paths=glob_paths,
refuse_mode_mask=refuse_mode_mask)
self.__conf = self.__load(
search_dirs = search_dirs,
glob_paths = glob_paths,
refuse_mode_mask = refuse_mode_mask
)
if root_section is not None:
tmp = self.__conf.get(root_section)
@ -141,7 +153,11 @@ class Config(): # export
def value(self, key: str, default = None) -> Optional[str]:
return self.get(key, default)
def branch(self, path: str, throw: bool=True): # type: ignore # Optional[Config]: FIXME: Don't know how to get hold of this type here
def branch(
self,
path: str,
throw: bool = True
): # type: ignore # Optional[Config]: FIXME: Don't know how to get hold of this type here
if self.__conf:
tree = self.__conf.get(path)
if tree is None:
@ -162,7 +178,12 @@ class Config(): # export
def name(self):
return self.__conf.content
def find(self, key: str|None, val: str|None, match:StringTree.Match=StringTree.Match.Equal) -> list[str]:
def find(
self,
key: str | None,
val: str | None,
match: StringTree.Match = StringTree.Match.Equal
) -> list[str]:
return self.__conf.find(key, val, match = match)
#def __getattr__(self, name: str):

View file

@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-
class CppState: # export
def __init__(self):
@ -63,7 +61,9 @@ class CppState: # export
self.in_c_comment = True
self.things.append(self.__pair_c_comment)
elif tok == '*/':
raise Exception("Unmatched closing C-style comment mark", tok, "in line", line)
raise Exception(
"Unmatched closing C-style comment mark", tok, "in line", line
)
else:
if self.in_cpp_comment:
if tok == '\n':
@ -101,4 +101,3 @@ class CppState: # export
def is_optional(self):
return self.in_list() or self.in_option()

View file

@ -1,6 +1,4 @@
TOPDIR = ../../../..
PY_UPDATE_INIT_PY ?= false
include $(TOPDIR)/make/proj.mk
include $(JWBDIR)/make/py-mod.mk

View file

@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-
from __future__ import print_function
from . import log
@ -7,10 +5,10 @@ from . import log
class Object(object): # export
def __init__(self):
self.log_level = log.level
self.log_level = log.log_level()
def log(self, prio, *args):
if self.log_level == log.level:
if self.log_level == log.log_level():
log.slog(prio, args)
return
if prio <= self.log_level:

View file

@ -1,10 +1,12 @@
import re
import json
from collections import OrderedDict
from .log import *
import re
import shlex
import traceback
from collections import OrderedDict
from .log import ERR, get_caller_pos, slog, slog_m
class Options: # export
class OrderedData:
@ -31,7 +33,7 @@ class Options: # export
spec = '{' + spec + '}'
try:
return json.loads(spec, object_pairs_hook = cls)
except:
except Exception:
pass
return None
@ -42,7 +44,7 @@ class Options: # export
r = cls()
try:
opt_strs = shlex.split(opts_str)
except Exception as e:
except Exception:
slog_m(ERR, traceback.format_exc())
slog(ERR, 'Failed to split options string >{}<'.format(opts_str))
raise
@ -52,7 +54,7 @@ class Options: # export
lhs = sides[0].strip()
if not len(lhs):
continue
if self.__allowed_keys and not lhs in self.__allowed_keys:
if self.__allowed_keys and lhs not in self.__allowed_keys:
raise Exception('Field "{}" not supported'.format(lhs))
rhs = ' '.join(sides[1:]).strip() if len(sides) > 1 else self.__true_val
if cls == OrderedDict:
@ -82,7 +84,7 @@ class Options: # export
self.__str = self.__str__()
def __getitem__(self, key):
if not key in self.__dict.keys():
if key not in self.__dict.keys():
return None
return self.__dict[key]
@ -99,24 +101,27 @@ class Options: # export
return len(self.__data.pairs)
def __contains__(self, keys):
if not type(keys) in [list, set]:
if type(keys) not in [list, set]:
return keys in self.__dict.keys()
for key in keys:
if not key in self.__dict.keys():
if key not in self.__dict.keys():
return False
return True
def __iter__(self):
return iter(self.__list)
def __next__(self):
return next(self.__list)
#def __next__(self):
# return next(self.__list)
def __init__(self, spec=None, delimiter=',', allowed_keys=None, true_val=True):
def __init__(
self, spec = None, delimiter = ',', allowed_keys = None, true_val = True
):
self.__true_val = true_val
self.__allowed_keys = None
self.__delimiter = delimiter
self.__data = self.OrderedData() if spec is None else self.__parse(spec, self.OrderedData)
self.__data = self.OrderedData(
) if spec is None else self.__parse(spec, self.OrderedData)
self.__dict = {}
#self.__dict = OrderedDict() if spec is None else self.__parse(spec,OrderedDict)
self.__list = []
@ -138,20 +143,26 @@ class Options: # export
def get(self, key, default = None, by_index = False):
if by_index:
if type(key) != int:
raise KeyError('Tried to get value from options string with ' +
'index {} of type "{}": {}'.format(key, type(key), str(self)))
if isinstance(key, int):
raise KeyError(
'Tried to get value from options string with ' +
'index {} of type "{}": {}'.format(key, type(key), str(self))
)
if key >= len(self.__data.pairs):
if default is not None:
return default
raise KeyError('Tried to get value from options string with ' +
'index {} of {}: {}'.format(key, len(self.__data.pairs), str(self)))
raise KeyError(
'Tried to get value from options string with ' +
'index {} of {}: {}'.format(key, len(self.__data.pairs), str(self))
)
return self.__list[key]
if key in self.__dict.keys():
return self.__dict[key]
if default is not None:
return default
raise KeyError('Key "{}" is not present in options string: {}'.format(key, str(self)))
raise KeyError(
'Key "{}" is not present in options string: {}'.format(key, str(self))
)
def update(self, rhs):
if hasattr(rhs, 'items'):
@ -159,9 +170,13 @@ class Options: # export
self.__dict[key] = val
return
if isinstance(rhs, str):
self.update(self.__parse(rhs))
self.update(self.__parse(rhs, self.OrderedData))
return
raise Exception('Tried to update options with object of incompatible type {}'.format(type(rhs)))
raise Exception(
'Tried to update options with object of incompatible type {}'.format(
type(rhs)
)
)
def append_to(self, obj):
for opt in self.__list:

View file

@ -1,9 +1,15 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import signal
from abc import ABC, abstractmethod
from enum import Enum, Flag, auto
from typing import List
from typing import TYPE_CHECKING
from .log import ERR, slog
if TYPE_CHECKING:
from .Signals import Signals
def _sigchld_handler(signum, process):
if not signum == signal.SIGCHLD:
@ -12,7 +18,7 @@ def _sigchld_handler(signum, process):
class Process(ABC): # export
__processes: List[Process] = []
__processes: set[Process] = set()
class State(Enum):
Running = auto()
@ -23,19 +29,18 @@ class Process(ABC): # export
FailOnExitWithoutShutdown = auto()
def __init__(self):
self.__state = Running
self.__flags = Flags.FailOnExitWithoutShutdown
self.__state = Process.State.Running
self.__flags = self.Flags.FailOnExitWithoutShutdown
if len(self.__processes) == 0:
self._signals().add_handler(signals.SIGCHLD, _sigchld_handler)
self.signals().add_handler(signal.SIGCHLD, _sigchld_handler)
self.__processes.add(self)
@classmethod
def propagate_signal(cls, signum):
for p in cls.__processes:
p.__signal(signum)
cls.signals().propagate(signum)
def signal(self, signum):
if signum == signals.SIGCHLD:
if signum == signal.SIGCHLD:
self.exited()
@abstractmethod
@ -44,7 +49,7 @@ class Process(ABC): # export
@classmethod
@abstractmethod
def signals(cls):
def signals(cls) -> Signals:
pass
# to be reimplemented
@ -56,17 +61,15 @@ class Process(ABC): # export
return str(self._pid())
def request_shutdown(self):
if not self.__state == Shutdown:
self.__state = Shutdown
if not self.__state == Process.State.Shutdown:
self.__state = Process.State.Shutdown
self._request_shutdown()
def exited(self):
if self.__state == Process.State.Running:
slog(ERR, 'process "{}" exited unexpectedly'.format(process.name()))
if __flags & Process.Flags.FailOnExitWithoutShutdown:
slog(ERR, 'process exited unexpectedly')
if self.__flags & Process.Flags.FailOnExitWithoutShutdown:
slog(ERR, 'exiting')
exit(1)
self.__state = Process.State.Done
self.__processes.erase(self)
if len(self.__processes) == 0:
self._signals().remove_handler(signals.SIGCHLD) # FIXME: broken logic
self.__processes.remove(self)

View file

@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
from __future__ import print_function
import os, io, sys, traceback
import os
import io
import sys
from fcntl import fcntl, F_GETFL, F_SETFL
class RedirectStdIO: # export

View file

@ -1,19 +1,10 @@
# -*- coding: utf-8 -*-
from typing import Dict, Callable
from abc import ABC, abstractmethod
_handled_signals: Dict[int, Callable] = {}
def _signal_handler(signal, frame):
if not signal in _handled_signals.keys():
return
for h in _handled_signals[signal]:
h.func(signal, *h.args)
from abc import abstractmethod
from typing import Dict
class Signals:
class Handler:
def __init__(self, func, args):
self.func = func
self.args = args
@ -23,15 +14,27 @@ class Signals:
@classmethod
@abstractmethod
def _add_handler(self, signal, handler):
def _add_handler(cls, signal, handler):
raise Exception("_add_handler() is not reimplemented")
@classmethod
def add_handler(cls, signals, handler, *args):
for signal in signals:
h = Signals.Handler(handler, args)
if not signal in _handled_signals.keys():
if signal not in _handled_signals.keys():
_handled_signals[signal] = [h]
cls._add_signal_handler(signal, _signal_handler)
cls._add_handler(signal, _signal_handler)
else:
_handled_signals[signal].add(h)
_handled_signals[signal].append(h)
@classmethod
def propagate(cls, signal):
_signal_handler(signal, None)
_handled_signals: Dict[int, list[Signals.Handler]] = {}
def _signal_handler(signal, frame):
if signal not in _handled_signals.keys():
return
for h in _handled_signals[signal]:
h.func(signal, *h.args)

View file

@ -1,8 +1,6 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from .log import *
from .log import get_caller_pos, slog
class StopWatch: # export
@ -21,5 +19,9 @@ class StopWatch: # export
else:
msg = '------------------ '
caller = kwargs['caller'] if 'caller' in kwargs.keys() else get_caller_pos(1)
slog(prio, '{} {} {}'.format(self.name, str(now - self.__last), msg), caller=caller)
slog(
prio,
'{} {} {}'.format(self.name, str(now - self.__last), msg),
caller = caller
)
self.__last = now

View file

@ -1,3 +0,0 @@
from pkgutil import extend_path
__path__ = extend_path(__path__, __name__)

View file

@ -1,9 +1,8 @@
# -*- coding: utf-8 -*-
import re
import re, shlex
from collections import namedtuple
from ..log import *
from ..log import DEBUG, get_caller_pos, prio_gets_logged, slog
L, R = 'Left Right'.split()
ARG, KEYW, QUOTED, LPAREN, RPAREN = 'arg kw quoted ( )'.split()
@ -83,16 +82,20 @@ class ShuntingYard(object): # export
regex = regex[1:]
scanner = re.Scanner([
scanner = re.Scanner( # pyright: ignore[reportAttributeAccessIssue]
[
(regex, lambda scanner, token: (KEYW, token)),
(r"\"[^\"]*\"|'[^']*'", lambda scanner, token: (QUOTED, token[1:-1])),
(r"[^\s()]+", lambda scanner, token: (ARG, token)),
(r"\s+", None), # None == skip token.
])
]
)
tokens, remainder = scanner.scan(spec)
if len(remainder) > 0:
raise Exception("Failed to tokenize " + spec + ", remaining bit is ", remainder)
raise Exception(
"Failed to tokenize " + spec + ", remaining bit is ", remainder
)
#self.debug(tokens)
return tokens
@ -112,7 +115,7 @@ class ShuntingYard(object): # export
tokenized = self.tokenize(infix)
self.debug("tokenized = ", tokenized)
outq, stack = [], []
table = ['TOKEN,ACTION,RPN OUTPUT,OP STACK,NOTES'.split(',')]
table = ['TOKEN', 'ACTION', 'RPN OUTPUT', ('OP STACK', ), 'NOTES']
for toktype, token in tokenized:
self.debug("Checking token", token)
note = action = ''
@ -127,7 +130,9 @@ class ShuntingYard(object): # export
note = 'Pop ops from stack to output'
while stack:
t2, op2 = stack[-1]
if (op1.assoc == L and op1.prec <= op2.prec) or (op1.assoc == R and op1.prec < op2.prec):
if (op1.assoc == L
and op1.prec <= op2.prec) or (op1.assoc == R
and op1.prec < op2.prec):
if t1 != RPAREN:
if t2 != LPAREN:
stack.pop()
@ -143,7 +148,9 @@ class ShuntingYard(object): # export
else:
stack.pop()
action = '(Pop & discard "(")'
table.append( (v, action, outq, (s[0] for s in stack), note) )
table.append(
(v, action, outq, (s[0] for s in stack), note)
)
break
table.append((v, action, (outq), (s[0] for s in stack), note))
v = note = ''
@ -169,11 +176,23 @@ class ShuntingYard(object): # export
v = note = ''
if self.do_debug:
maxcolwidths = [len(max(x, key = len)) for x in [zip(*table)]]
caller = get_caller_pos()
get_caller_pos()
row = table[0]
slog(DEBUG, ' '.join('{cell:^{width}}'.format(width=width, cell=cell) for (width, cell) in zip(maxcolwidths, row)))
slog(
DEBUG,
' '.join(
'{cell:^{width}}'.format(width = width, cell = cell)
for (width, cell) in zip(maxcolwidths, row)
)
)
for row in table[1:]:
slog(DEBUG, ' '.join('{cell:<{width}}'.format(width=width, cell=cell) for (width, cell) in zip(maxcolwidths, row)))
slog(
DEBUG,
' '.join(
'{cell:<{width}}'.format(width = width, cell = cell)
for (width, cell) in zip(maxcolwidths, row)
)
)
return table[-1][2]
def infix_to_postfix_orig(self, infix):
@ -185,7 +204,7 @@ class ShuntingYard(object): # export
for tokinfo in tokens:
self.debug(tokinfo)
toktype, token = tokinfo[0], tokinfo[1]
_toktype, token = tokinfo[0], tokinfo[1]
self.debug("Checking token ", token)
@ -204,7 +223,8 @@ class ShuntingYard(object): # export
topToken = s.pop()
continue
while (not s.isEmpty()) and (self.__ops[s.peek()].prec >= self.__ops[token].prec):
while (not s.isEmpty()) and (self.__ops[s.peek()].prec
>= self.__ops[token].prec):
#self.debug(token)
r.append(s.pop())
#self.debug(r)
@ -240,7 +260,9 @@ class ShuntingYard(object): # export
args.append(vals.pop())
#self.debug("running %s(%s)" % (token, ', '.join(reversed(args))))
val = op.func(*reversed(args))
self.debug("%s(%s) = %s" % (token, ', '.join(map(str, reversed(args))), val))
self.debug(
"%s(%s) = %s" % (token, ', '.join(map(str, reversed(args))), val)
)
vals.push(val)
return vals.pop()
@ -266,16 +288,16 @@ if __name__ == '__main__':
# return string.split()
def f_mult(self, a, b):
return str(atof(a) * atof(b));
return str(atof(a) * atof(b))
def f_div(self, a, b):
return str(atof(a) / atof(b));
return str(atof(a) / atof(b))
def f_add(self, a, b):
return str(atof(a) + atof(b));
return str(atof(a) + atof(b))
def f_sub(self, a, b):
return str(atof(a) - atof(b));
return str(atof(a) - atof(b))
def __init__(self):
Op = Operator

View file

@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-
from abc import abstractmethod
from ..Process import Process as ProcessBase
@ -10,7 +8,7 @@ class Process(ProcessBase): # export
__signals = Signals()
def __init__(self, aio):
super().__init()
super().__init__()
self.aio = aio
@classmethod

View file

@ -1,5 +1,7 @@
import asyncio
from ..log import *
import re
from ..log import DEBUG, ERR, INFO, WARNING, slog
# FIXME: Derive this from Process, or merge the classes entirely
@ -56,12 +58,19 @@ class ShellCmd: # export
return r[1:]
try:
slog(INFO, "Running shell command [{}]: {}".format(self.__name, format_cmdline(self.__cmdline)))
slog(
INFO,
"Running shell command [{}]: {}".format(
self.__name, format_cmdline(self.__cmdline)
)
)
self.__transport, self.__protocol = await self.__eloop.subprocess_exec(
lambda: self.SubprocessProtocol(self, self.__name),
*self.__cmdline,
)
self.__proc = self.__transport.get_extra_info('subprocess') # Popen instance
self.__proc = self.__transport.get_extra_info(
'subprocess'
) # Popen instance
except:
slog(ERR, "Failed to run process [{}]".format(self.__name))
raise
@ -69,6 +78,7 @@ class ShellCmd: # export
def __reap(self):
if self.__rc is None and self.__transport:
self.__transport = None
if self.__proc is not None:
self.__rc = self.__proc.wait()
# to be called from SubprocessProtocol / SIGCHLD handler
@ -78,13 +88,24 @@ class ShellCmd: # export
async def __cleanup(self):
pid = self.__reap()
sd_fine = self.__shutdown in [ self.ShutdownState.Unnecessary, self.ShutdownState.Completed ]
sd_fine = self.__shutdown in [
self.ShutdownState.Unnecessary, self.ShutdownState.Completed
]
if self.__rc == 0 and sd_fine:
slog(INFO, "The shell command [{}], pid {}, has exited cleanly".format(self.__name, self.__proc.pid))
assert self.__proc is not None
slog(
INFO,
"The shell command [{}], pid {}, has exited cleanly".format(
self.__name, self.__proc.pid
)
)
self.monitor = self.console = self.__protocol = self.__task = None
return 0
slog(ERR, "The process ([{}], pid {}) has exited {}with status code {}, aborting".format(
self.__name, pid, "" if sd_fine else "prematurely ", self.__rc))
slog(
ERR,
"The process ([{}], pid {}) has exited {}with status code {}, aborting".
format(self.__name, pid, "" if sd_fine else "prematurely ", self.__rc)
)
exit(1)
async def init(self):
@ -100,9 +121,9 @@ class ShellCmd: # export
if __name__ == '__main__':
from .. import log
log.set_level('info')
async def run():
sp = ShellCmd(['echo', 'hello world!'])
await sp.run()
asyncio.run(run())

View file

@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-
import asyncio
from ..Signals import Signals as SignalsBase

View file

@ -1,15 +1,16 @@
# -*- coding: utf-8 -*-
from typing import Optional, Union, Self
from __future__ import annotations
import abc
from enum import Flag, Enum, auto
from enum import Enum, Flag, auto
from typing import TYPE_CHECKING, Optional, Self, Union
from ..log import *
from ..Config import Config
from ..log import ERR
from ..misc import load_object
if TYPE_CHECKING:
from ..Config import Config
class Access(Enum): # export
Read = auto()
Modify = auto()
@ -77,7 +78,7 @@ class Auth(abc.ABC): # export
if tp == '':
val = conf.get('type')
if val is None:
msg = f'No type specified in auth configuration'
msg = 'No type specified in auth configuration'
conf.dump(ERR, msg)
raise Exception(msg)
tp = val
@ -92,10 +93,17 @@ class Auth(abc.ABC): # export
return self.__conf
@abc.abstractmethod
def _access(self, what: str, access_type: Optional[Access], who: User|Group|None) -> bool:
def _access(
self, what: str, access_type: Optional[Access], who: User | Group | None
) -> bool:
raise NotImplementedError
def access(self, what: str, access_type: Optional[Access]=None, who: Optional[Union[User|Group]]=None) -> bool:
def access(
self,
what: str,
access_type: Optional[Access] = None,
who: Optional[Union[User, Group]] = None
) -> bool:
return self._access(what, access_type, who)
@abc.abstractmethod

View file

@ -1,14 +1,16 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from typing import Optional, Union
from typing import TYPE_CHECKING, Optional
from ...log import *
from ... import Config
from .. import Access
from .. import Auth as AuthBase
from .. import Group as GroupBase
from .. import User as UserBase
from .. import ProjectFlags
from ...log import WARNING, slog
from ..Auth import Access
from ..Auth import Auth as AuthBase
from ..Auth import Group as GroupBase
from ..Auth import ProjectFlags
from ..Auth import User as UserBase
if TYPE_CHECKING:
from ...Config import Config
class Group(GroupBase): # export
@ -62,12 +64,18 @@ class Auth(AuthBase): # export
ret: dict[str, UserBase] = {}
for name in self.conf.entries('user'):
conf = self.conf.branch('user.' + name)
assert conf is not None, 'Config is None'
ret[name] = User(self, name, conf)
self.___users = ret
return self.___users
def _access(self, what: str, access_type: Optional[Access], who: User|GroupBase|None) -> bool: # type: ignore
slog(WARNING, f'Returning False for {access_type} access to resource {what} by {who}')
def _access(
self, what: str, access_type: Access | None, who: UserBase | GroupBase | None
) -> bool: # type: ignore
slog(
WARNING,
f'Returning False for {access_type} access to resource {what} by {who}'
)
return False
def _user(self, name) -> UserBase:

View file

@ -1,17 +1,19 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from typing import Optional, Union
from typing import TYPE_CHECKING, Optional
import ldap
import ldap # type: ignore[import-untyped]
from ...log import *
from ...ldap import bind
from ...log import DEBUG, ERR, WARNING, slog
from ..Auth import Access
from ..Auth import Auth as AuthBase
from ..Auth import Group as GroupBase
from ..Auth import ProjectFlags
from ..Auth import User as UserBase
if TYPE_CHECKING:
from ...Config import Config
from .. import Access
from .. import Auth as AuthBase
from .. import Group as GroupBase
from .. import User as UserBase
from .. import ProjectFlags
class Group(GroupBase): # export
@ -24,13 +26,7 @@ class Group(GroupBase): # export
class User(UserBase):
def __init__(
self,
auth: AuthBase,
name: str,
cn: str,
email: str
):
def __init__(self, auth: AuthBase, name: str, cn: str, email: str):
self.__auth = auth
self.__name = name
@ -72,18 +68,16 @@ class Auth(AuthBase): # export
ret_by_email: dict[str, User] = {}
for res in self.__conn.find(
self.__user_base_dn,
ldap.SCOPE_SUBTREE,
ldap.SCOPE_SUBTREE, # pyright: ignore[reportAttributeAccessIssue]
"objectClass=inetOrgPerson",
('uid', 'cn', 'uidNumber', 'mail', 'maildrop')
):
('uid', 'cn', 'uidNumber', 'mail', 'maildrop')):
try:
display_name = None
if 'displayName' in res[1]:
cn = res[1]['displayName'][0].decode('utf-8')
else:
cn = res[1]['cn'][0].decode('utf-8')
uid = res[1]['uid'][0].decode('utf-8')
uidNumber = res[1]['uidNumber'][0].decode('utf-8')
res[1]['uidNumber'][0].decode('utf-8')
emails = []
#for attr in ['mail', 'maildrop']:
for attr in ['mail']:
@ -113,8 +107,16 @@ class Auth(AuthBase): # export
self.__users
return self.___user_by_email # type: ignore # We are sure that ___user_by_email is not None at this point
def _access(self, what: str, access_type: Optional[Access], who: User|GroupBase|None) -> bool: # type: ignore
slog(WARNING, f'Returning False for {access_type} access to resource {what} by {who}')
def _access(
self,
what: str,
access_type: Optional[Access],
who: UserBase | GroupBase | None
) -> bool: # type: ignore
slog(
WARNING,
f'Returning False for {access_type} access to resource {what} by {who}'
)
return False
def _user(self, name) -> UserBase:
@ -136,5 +138,9 @@ class Auth(AuthBase): # export
def _projects(self, name, flags: ProjectFlags) -> list[str]:
if flags & ProjectFlags.Contributing:
# TODO: Ask LDAP
slog(WARNING, f'Querying LDAP for projects a user contributes to is not implemented, ignoring')
slog(
WARNING,
'Querying LDAP for projects a user contributes to is not '
'implemented, ignoring'
)
return []

View file

@ -1,110 +1,130 @@
# -*- coding: utf-8 -*-
import os
import pytimeparse, os
from datetime import datetime, timedelta
from collections import OrderedDict
from datetime import timedelta
from .log import *
import pytimeparse # type: ignore[import-untyped]
from .log import DEBUG, WARNING, slog
_int_chars = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
def _strip(s_, throw=True, log_level=ERR):
def _strip(s_) -> str:
s = s_.strip()
if len(s) != 0:
return s
msg = f'Tried to strip empty string "{s_}" to int'
if throw:
raise Exception(msg)
slog(log_level, msg)
return None
raise Exception(f'Tried to strip empty string "{s_}"')
def cast_str_to_timedelta(s_: str, throw=True, log_level=DEBUG): # export
s = _strip(s_, throw=throw, log_level=log_level)
try:
return (True, timedelta(seconds=pytimeparse.parse(s_)))
except Exception as e:
msg = f'Could not convert string "{s_}" to time ({e})'
if throw:
raise Exception(msg)
slog(log_level, msg)
return (False, None)
def cast_str_to_timedelta(s_: str): # export
s = _strip(s_)
seconds = pytimeparse.parse(s)
if seconds is None:
raise Exception(f'Failed to convert {s} to timedelta')
return timedelta(seconds = seconds)
def cast_str_to_int(s_: str, throw=True, log_level=DEBUG): # export
s = _strip(s_, throw=throw, log_level=log_level)
def cast_str_to_int(s_: str): # export
s = _strip(s_)
if s[0] == '-':
s = s[1:]
for c in s:
if not c in _int_chars:
break
else:
return (True, int(s_))
msg = f'Could not convert string "{s_}" to int'
if throw:
raise Exception(msg)
slog(log_level, msg)
return (False, None)
if c not in _int_chars:
raise Exception(f'Could not convert string "{s}" to int')
return int(s)
def cast_str_to_bool(s_: str, throw=True, log_level=DEBUG): # export
s = _strip(s_, throw=throw, log_level=log_level).lower()
def cast_str_to_bool(s_: str): # export
s = _strip(s_).lower()
if s in ['true', 'yes', '1']:
return (True, True)
return True
if s in ['false', 'no', '0']:
return (True, False)
msg = f'Could not convert string "{s_}" to bool'
if throw:
raise Exception(msg)
slog(log_level, msg)
return (False, None)
return False
raise Exception(f'Could not convert string "{s_}" to bool')
_str_cast_functions = OrderedDict({
bool: cast_str_to_bool,
int: cast_str_to_int,
timedelta: cast_str_to_timedelta
})
_str_cast_functions = OrderedDict(
{
bool: cast_str_to_bool, int: cast_str_to_int, timedelta: cast_str_to_timedelta
}
)
def guess_type(s: str, default = None, log_level = DEBUG, throw = False): # export
if s is None:
raise Exception('None string passed to guess_type()')
for tp, func in _str_cast_functions.items():
try:
success, value = func(s, log_level=OFF, throw=False)
if success:
return tp
except:
func(s)
except Exception:
continue
return tp
msg = f'Failed to guess type of string "{s}"'
if throw:
raise Exception(msg)
slog(log_level, msg)
return default
def from_str(s: str, target_type=None, default_type=None, throw=True, log_level=WARNING, caller=None): # export
if target_type is None:
target_type = guess_type(s, default_type)
def from_str( # export
s: str,
target_type = None,
default_type = None,
throw = True,
log_level = WARNING,
caller = None
):
if target_type is None:
for tp, func in _str_cast_functions.items():
try:
return func(s)
except Exception:
continue
msg = f'Could not deduce type to cast to from string "{s}"'
if throw:
raise Exception(msg)
slog(log_level, msg)
return None
result = _str_cast_functions[target_type](s, throw=throw, log_level=log_level)
if result[0]:
return result[1]
msg = f'Failed to cast string "{s}" to type {target_type}'
try:
return _str_cast_functions[target_type](s)
except Exception as e:
msg = f'Failed to cast string "{s}" to type {target_type} ({str(e)})'
if throw:
raise Exception(msg)
slog(log_level, msg)
return None
def from_env(key: str, default=None, target_type=None, default_type=None, throw=True, log_level=WARNING, caller=None): # export
def from_env( # export
key: str,
default = None,
target_type = None,
default_type = None,
throw = True,
log_level = WARNING,
caller = None
):
val = os.getenv(key)
if val is None:
return default
if target_type is None and default is not None:
target_type = type(default)
return from_str(val, target_type=target_type, default_type=default_type, throw=throw, log_level=log_level, caller=caller)
return from_str(
val,
target_type = target_type,
default_type = default_type,
throw = throw,
log_level = log_level,
caller = caller
)
# deprecated name
def cast_str(s: str, target_type=None, default_type=None, throw=True, log_level=WARNING, caller=None):
return from_str(s, target_type=target_type, default_type=None, throw=True, log_level=WARNING, caller=None)
def cast_str(
s: str,
target_type = None,
default_type = None,
throw = True,
log_level = WARNING,
caller = None
):
return from_str(
s,
target_type = target_type,
default_type = None,
throw = True,
log_level = WARNING,
caller = None
)

View file

@ -1,15 +1,16 @@
# -*- coding: utf-8 -*-
from typing import Any
from __future__ import annotations
import abc
from contextlib import contextmanager
from contextlib import contextmanager
from typing import TYPE_CHECKING
from ..log import NOTICE
if TYPE_CHECKING:
from ..Config import Config
from .schema.Schema import Schema
from ..Cmds import Cmds
from .Session import Session
from ..log import *
class DataBase(abc.ABC):
@ -39,4 +40,5 @@ class DataBase(abc.ABC):
try:
yield ret
finally:
if ret is not None:
self._delete_session(ret)

View file

@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-
import abc
class Session(abc.ABC): # export

View file

@ -1,15 +1,14 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from typing import Any, List, Union, Optional, Dict
from abc import ABC, abstractmethod
import re, csv, json
from typing import TYPE_CHECKING, Any, Dict, Union
from ..log import *
from ..cast import cast_str
from ..log import ERR, INFO, OFF, slog, slog_m
from .rows import rows_check_not_null, rows_dump, rows_duplicates
if TYPE_CHECKING:
from .schema.Schema import Schema
from .rows import *
TType = Union[Any, Dict[str, Any]]
class TableIoHandler(ABC): # export
@ -22,7 +21,8 @@ class TableIoHandler(ABC): # export
def _table_meta(self):
if self.__table_meta is None:
self.__table_meta = self.__schema.table_by_model_name(
self.__class__.__name__, throw=True)
self.__class__.__name__, throw = True
)
return self.__table_meta
@property
@ -35,7 +35,7 @@ class TableIoHandler(ABC): # export
def _check_non_nullable(self, rows):
buf = []
non_nullable = self.__table_meta.not_null_insertible_columns
non_nullable = self._table_meta.not_null_insertible_columns
try:
rows_check_not_null(rows, non_nullable, buf = buf)
except:
@ -48,11 +48,16 @@ class TableIoHandler(ABC): # export
if not buf:
continue
slog_m(ERR, f'\n{d} Null values in {cn} / {tn}: "{key}"\n')
use_cols=self.log_columns
use_cols = list(self.log_columns)
if key not in use_cols:
use_cols.append(key)
rows_dump(buf, use_cols = use_cols, log_prio = ERR)
rows_dump(buf, use_cols=use_cols, out_path=f'/tmp/missing_{key}_in_{tn}.html', heading=f'Missing "{key}" in table {tn}')
rows_dump(
buf,
use_cols = use_cols,
out_path = f'/tmp/missing_{key}_in_{tn}.html',
heading = f'Missing "{key}" in table {tn}'
)
raise
@property
@ -67,7 +72,9 @@ class TableIoHandler(ABC): # export
def _store(self, uri: str, data: TType):
pass
def load(self, uri: str, reference, check_duplicates=False, write_csv=None) -> TType:
def load(
self, uri: str, reference, check_duplicates = False, write_csv = None
) -> TType:
slog(INFO, f'Reading table "{self._table_name}" from "{uri}"')
ret = self._load(uri, reference)
if check_duplicates:

View file

@ -1,15 +1,15 @@
# -*- coding: utf-8 -*-
from typing import Any
from __future__ import annotations
from typing import Any, TYPE_CHECKING
import abc
from ...log import *
from ...log import slog, slog_m, ERR, INFO
from ...misc import load_classes
from ...Cmds import Cmds
from ..DataBase import DataBase
from ..schema.Schema import Schema
from .Query import Query as QueryBase
if TYPE_CHECKING:
from ..schema.Schema import Schema
from .QueryResult import QueryResult
class Queries(abc.ABC): # export

View file

@ -1,16 +1,13 @@
# -*- coding: utf-8 -*-
from typing import Any
from __future__ import annotations
import abc
from ...log import *
from ...misc import load_classes
from ...Cmds import Cmds
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from ..DataBase import DataBase
from ..Session import Session
from .QueryResult import QueryResult
#from .Queries import Queries
class Query(abc.ABC): # export

View file

@ -1,12 +1,11 @@
# -*- coding: utf-8 -*-
from typing import Any, Union
from __future__ import annotations
import abc
from enum import Enum, auto
from ...log import *
from ...Cmds import Cmds
from enum import Enum, auto
from typing import TYPE_CHECKING, Any, Union
if TYPE_CHECKING:
from ..DataBase import DataBase
from ..Session import Session
@ -58,5 +57,5 @@ class QueryResult(abc.ABC): # export
# -- pure virtuals
@abc.abstractmethod
def _cast(self, res_type: ResType, **kwargs) -> Union[Any|list[Any]]:
def _cast(self, res_type: ResType, **kwargs) -> Union[Any, list[Any]]:
pass

View file

@ -1,12 +1,16 @@
# -*- coding: utf-8 -*-
import csv
import io
import json
import os
import re
import textwrap
import io, os, re, textwrap, json, csv
from tabulate import tabulate # type: ignore
from tabulate import TableFormat, tabulate # type: ignore
from ..log import *
from ..log import (ERR, INFO, WARNING, get_caller_pos, prio_gets_logged, slog, slog_m)
def rows_pretty(rows): # export
if type(rows) == dict:
if isinstance(rows, dict):
rows = [rows]
out = []
for row in rows:
@ -14,6 +18,7 @@ def rows_pretty(rows): # export
return '\n'.join(out)
def rows_duplicates(rows, log_prio = INFO, caller = None): # export
def __equal(r1, r2):
for col in set(r1.keys()) | set(r2.keys()):
if col in r1:
@ -25,6 +30,7 @@ def rows_duplicates(rows, log_prio=INFO, caller=None): # export
if r1[col] != r2[col]:
return False
return True
ret = []
last = len(rows) - 1
i = last
@ -37,12 +43,15 @@ def rows_duplicates(rows, log_prio=INFO, caller=None): # export
last -= 1
return ret
def rows_remove(rows, callback=None, candidates=None, log_prio=INFO, caller=None): # export
def rows_remove( # export
rows, callback = None, candidates = None, log_prio = INFO, caller = None
):
def __is_remove_candidate(row):
assert candidates is not None, 'Candidates is None'
for remove_row in candidates:
for col, val in row.items():
if not col in remove_row.keys():
if col not in remove_row.keys():
break
if val != remove_row[col]:
break
@ -72,7 +81,7 @@ def rows_select(rows, rules): # export
ret = []
for row in rows:
for rule in rules:
if type(rule) == tuple():
if isinstance(rule, tuple):
search_rule = rule[0]
else:
search_rule = rule
@ -93,14 +102,25 @@ def rows_rewrite_regex(rows, rules): # export
break
else:
for exec_col_name, exec_val in rule[1].items():
slog(INFO, f'Rewriting {row} {row.get(exec_col_name)} -> {exec_val}')
slog(
INFO,
f'Rewriting {row} {row.get(exec_col_name)} -> {exec_val}'
)
row[exec_col_name] = exec_val
except Exception as e:
slog(ERR, f'Failed to run rule {rule} against {row} ({e})')
raise
def rows_check_not_null(rows, keys, log_prio=WARNING, buf=None, stat_key=None, throw=True, caller=None): # export
if type(keys) == str:
def rows_check_not_null( # export
rows,
keys,
log_prio = WARNING,
buf = None,
stat_key = None,
throw = True,
caller = None
):
if isinstance(keys, str):
keys = [keys]
if caller is None:
caller = get_caller_pos()
@ -117,7 +137,7 @@ def rows_check_not_null(rows, keys, log_prio=WARNING, buf=None, stat_key=None, t
buf.append(row)
if stat_key is not None:
stat_val = row[stat_key]
if not stat_val in stats.keys():
if stat_val not in stats.keys():
stats[stat_val] = 0
stats[stat_val] += 1
count += 1
@ -129,10 +149,23 @@ def rows_check_not_null(rows, keys, log_prio=WARNING, buf=None, stat_key=None, t
i += 1
slog(ERR, f'{i:>3}. {k:<23}: {v}', caller = caller)
if throw:
raise Exception(f'Found {count} rows violating null-constraint for keys {keys}')
raise Exception(
f'Found {count} rows violating null-constraint for keys {keys}'
)
return buf
def rows_dumps(rows, log_prio=INFO, caller=None, use_cols=None, skip_cols=None, table_name=None, out_path='log', heading=None, lead=None, tablefmt=None): # export
def rows_dumps( # export
rows,
log_prio = INFO,
caller = None,
use_cols = None,
skip_cols = None,
table_name = None,
out_path = 'log',
heading = None,
lead = None,
tablefmt = None
):
headers = 'keys'
dump_rows = rows
@ -152,20 +185,21 @@ def rows_dumps(rows, log_prio=INFO, caller=None, use_cols=None, skip_cols=None,
new_row[col] = val
new_dump_rows.append(new_row)
dump_rows = new_dump_rows
out = header = footer = ""
header = footer = ""
match tablefmt:
case 'html':
if heading is not None:
heading = f'<h1>{heading}</h1>\n'
if type(lead) == str:
if isinstance(lead, str):
lead = f'<div class="lead">\n {lead}\n</div>\n'
elif type(lead) == list:
l = '<ul>\n'
elif isinstance(lead, list):
lst = '<ul>\n'
for li in lead:
l += f'<li>{li}</li>\n'
l += '</ul>\n'
lead = l
header=textwrap.dedent('''\
lst += f'<li>{li}</li>\n'
lst += '</ul>\n'
lead = lst
header = textwrap.dedent(
'''\
<html>
<head>
<script src="https://ajax.googleapis.com/ajax/libs/jquery/2.1.1/jquery.min.js"></script>
@ -185,30 +219,47 @@ def rows_dumps(rows, log_prio=INFO, caller=None, use_cols=None, skip_cols=None,
</style>
</head>
<body>
''')
footer = textwrap.dedent('''
'''
)
footer = textwrap.dedent(
'''
</body>
</html>
''')
'''
)
case _:
if type(heading) == str:
if isinstance(heading, str):
heading = '\n' + heading
if type(lead) == str:
if isinstance(lead, str):
pass
elif type(lead) == list:
l =''
elif isinstance(lead, list):
lst = ''
for li in lead:
l += f' - {li}\n'
lead = '\n\n' + l + '\n'
lst += f' - {li}\n'
lead = '\n\n' + lst + '\n'
if heading is None:
heading = ''
if lead is None:
lead = ''
return header + heading + lead + tabulate(dump_rows, headers=headers, tablefmt=tablefmt) + footer
assert isinstance(tablefmt, str) or isinstance(tablefmt, TableFormat), 'tablefmt'
return header + heading + lead + tabulate(
dump_rows, headers = headers, tablefmt = tablefmt
) + footer
def rows_dump(rows, log_prio=INFO, caller=None, use_cols=None, skip_cols=None, table_name=None, out_path='log', heading=None, lead=None, tablefmt=None): # export
def rows_dump( # export
rows,
log_prio = INFO,
caller = None,
use_cols = None,
skip_cols = None,
table_name = None,
out_path = 'log',
heading = None,
lead = None,
tablefmt = None
):
if not prio_gets_logged(log_prio):
return
@ -218,7 +269,17 @@ def rows_dump(rows, log_prio=INFO, caller=None, use_cols=None, skip_cols=None, t
if tablefmt is None and out_path:
tablefmt = os.path.splitext(out_path)[1][1:]
out = rows_dumps(rows, log_prio=log_prio, caller=caller, use_cols=use_cols, skip_cols=skip_cols, table_name=table_name, heading=heading, lead=lead, tablefmt=tablefmt)
out = rows_dumps(
rows,
log_prio = log_prio,
caller = caller,
use_cols = use_cols,
skip_cols = skip_cols,
table_name = table_name,
heading = heading,
lead = lead,
tablefmt = tablefmt
)
match out_path:
case 'log':
@ -228,8 +289,15 @@ def rows_dump(rows, log_prio=INFO, caller=None, use_cols=None, skip_cols=None, t
fp.write(out)
def rows_to_csv(rows, use_tmpfile = False): # export
def __write(rows, out):
writer = csv.DictWriter(out, fieldnames=field_names, delimiter=';', quotechar='"', quoting=csv.QUOTE_NONNUMERIC)
writer = csv.DictWriter(
out,
fieldnames = field_names,
delimiter = ';',
quotechar = '"',
quoting = csv.QUOTE_NONNUMERIC
)
writer.writeheader()
for row in rows:
writer.writerow(row)

View file

@ -1,17 +1,20 @@
# -*- coding: utf-8 -*-
from typing import Optional, Any
from __future__ import annotations
import abc
from typing import TYPE_CHECKING, Any, Optional
from ...log import ERR, throw
if TYPE_CHECKING:
from .DataType import DataType
from ...log import *
from .Table import Table
class Column(abc.ABC): # export
def __init__(self, table, name, data_type: DataType):
def __init__(self, table: Table, name: str, data_type: DataType) -> None:
self.__name: str = name
self.__table: Any = table
self.__table: Table = table
self.__is_nullable: Optional[bool] = None
self.__is_null_insertible: Optional[bool] = None
self.__is_primary_key: Optional[bool] = None
@ -46,11 +49,11 @@ class Column(abc.ABC): # export
return self.__name
@property
def data_type(self):
def data_type(self) -> DataType:
return self.__data_type
@property
def table(self) -> str:
def table(self) -> Table:
return self.__table
@property
@ -60,7 +63,7 @@ class Column(abc.ABC): # export
return self.__is_nullable
@property
def is_null_insertible(self):
def is_null_insertible(self) -> bool:
if self.__is_null_insertible is None:
ret = False
if self.is_nullable:
@ -81,7 +84,9 @@ class Column(abc.ABC): # export
@property
def is_auto_increment(self) -> bool:
if self.__is_auto_increment is None:
self.__is_auto_increment = self.__name in self.__table.auto_increment_columns
self.__is_auto_increment = (
self.__name in self.__table.auto_increment_columns
)
return self.__is_auto_increment
@property

View file

@ -1,10 +1,14 @@
# -*- coding: utf-8 -*-
from typing import Optional, Iterable, Any
from typing import Optional, Any
class ColumnSet: # export
def __init__(self, *args: list[Any], columns: list[Any]=[], table: Optional[Any]=None, names: Optional[list[str]]=None):
def __init__(
self,
*args: list[Any],
columns: list[Any] = [],
table: Optional[Any] = None,
names: Optional[list[str]] = None
):
self.__columns: list[Any] = [*args]
self.__columns.extend(columns)
self.__table = table

View file

@ -1,15 +1,18 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from typing import Optional, Any
from typing import TYPE_CHECKING, Any, Optional
from ...log import *
from .ColumnSet import ColumnSet
from ...log import WARNING, slog
from .SingleForeignKey import SingleForeignKey
if TYPE_CHECKING:
from .ColumnSet import ColumnSet
class CompositeForeignKey: # export
def __init__(self, child_col_set: ColumnSet, parent_col_set: ColumnSet): # TODO: Implement alternative ways to construct
def __init__(
self, child_col_set: ColumnSet, parent_col_set: ColumnSet
): # TODO: Implement alternative ways to construct
def __table(s):
ret = None
@ -46,7 +49,12 @@ class CompositeForeignKey: # export
def __repr__(self):
ret = self.__table_rel_str()
ret += ': ' + ', '.join([self.__cols_rel_str(rel.child_column, rel.parent_column) for rel in self.column_relations])
ret += ': ' + ', '.join(
[
self.__cols_rel_str(rel.child_column, rel.parent_column)
for rel in self.column_relations
]
)
return ret
def __eq__(self, rhs):
@ -73,7 +81,7 @@ class CompositeForeignKey: # export
return self.__parent_col_set
def parent_column(self, child_column) -> Any:
child_column_name = child_column if isinstance(child_column, str) else child_column.name
child_column if isinstance(child_column, str) else child_column.name
if self.__parent_columns_by_child_column is None:
d: dict[str, Any] = {}
assert (len(self.__child_col_set) == len(self.__parent_col_set))
@ -83,8 +91,12 @@ class CompositeForeignKey: # export
return self.__parent_columns_by_child_column[child_column]
def child_column(self, parent_column) -> Any:
slog(WARNING, f'{self}: Looking for child column belonging to parent column "{parent_column}"')
parent_column_name = parent_column if isinstance(parent_column, str) else parent_column.name
slog(
WARNING,
f'{self}: Looking for child column belonging to parent column '
f'"{parent_column}"'
)
parent_column if isinstance(parent_column, str) else parent_column.name
if self.__child_columns_by_parent_column is None:
d: dict[str, Any] = {}
assert (len(self.__parent_col_set) == len(self.__child_col_set))
@ -98,6 +110,8 @@ class CompositeForeignKey: # export
ret = []
if self.__column_relations is None:
for i in range(0, self.__len):
ret.append(SingleForeignKey(self.__child_col_set[i], self.__parent_col_set[i]))
ret.append(
SingleForeignKey(self.__child_col_set[i], self.__parent_col_set[i])
)
self.__column_relations = ret
return self.__column_relations

View file

@ -1,10 +1,8 @@
# -*- coding: utf-8 -*-
from typing import Optional
from enum import Enum, auto
from datetime import datetime
from enum import Enum, auto
from typing import Optional
from ...log import *
from ...log import ERR, throw
class Id(Enum):
Integer = auto()
@ -42,7 +40,10 @@ class DataType: # export
def __init__(self, type_id: Id, size: Optional[int] = None):
if not isinstance(type_id, Id):
throw(ERR, f'Passed type id "{type_id}" with unsupported data type {type(type_id)}')
throw(
ERR,
f'Passed type id "{type_id}" with unsupported data type {type(type_id)}'
)
if size is not None:
assert (isinstance(size, int))
assert (size > 0)

View file

@ -1,20 +1,20 @@
# -*- coding: utf-8 -*-
from typing import Optional, Iterable
from __future__ import annotations
import abc
from ...log import *
from typing import TYPE_CHECKING, Iterable, Optional
from .Table import Table
from ...log import DEBUG, ERR, slog, throw
if TYPE_CHECKING:
from .Column import Column
from .DataType import DataType
from .CompositeForeignKey import CompositeForeignKey
from .Table import Table
class Schema(abc.ABC): # export
def __init__(self) -> None:
self.___tables: Optional[list[Table]] = None
self.___tables: Optional[dict[str, Table]] = None
self.__foreign_keys: Optional[list[CompositeForeignKey]] = None
self.__access_defining_columns: Optional[list[str]] = None
@ -62,7 +62,7 @@ class Schema(abc.ABC): # export
yield from self.__tables.values()
def __repr__(self):
return '|'.join([table.name for table in self.__tables])
return '|'.join([table.name for table in self.__tables.values()])
def __getitem__(self, index):
return self.__tables[index]

View file

@ -1,9 +1,9 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from typing import Optional, Any
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .Column import Column
from .ColumnSet import ColumnSet
class SingleForeignKey:
@ -19,7 +19,10 @@ class SingleForeignKey:
yield from self.__iterable
def __repr__(self):
return f'{self.__child_col.table.name}.{self.__child_col.name} -> {self.__parent_col.table.name}.{self.__parent_col.name}'
return (
f'{self.__child_col.table.name}.{self.__child_col.name} -> '
'{self.__parent_col.table.name}.{self.__parent_col.name}'
)
def __getitem__(self, index):
return self.__iterable[index]

View file

@ -1,18 +1,22 @@
# -*- coding: utf-8 -*-
from typing import Optional, Union, Iterable, Self, Any # TODO: Need any for many things, as I can't figure out how to avoid circular imports from here
from __future__ import annotations
import abc
import re
from collections import OrderedDict
from typing import TYPE_CHECKING
from urllib.parse import quote_plus
from ...log import *
from ...log import ERR, WARNING, slog, throw
from ...misc import load_class
from .ColumnSet import ColumnSet
from .DataType import DataType
from .CompositeForeignKey import CompositeForeignKey
from .Column import Column
from .ColumnSet import ColumnSet
if TYPE_CHECKING:
from typing import Any, Iterable, Optional, Self, Union
from .CompositeForeignKey import CompositeForeignKey
from .DataType import DataType
class Table(abc.ABC): # export
@ -61,7 +65,8 @@ class Table(abc.ABC): # export
if self.___foreign_key_parent_tables is None:
self.___foreign_key_parent_tables = OrderedDict()
for cfk in self.foreign_key_constraints:
self.___foreign_key_parent_tables[cfk.parent_table.name] = cfk.parent_table
self.___foreign_key_parent_tables[cfk.parent_table.name
] = cfk.parent_table
return self.___foreign_key_parent_tables
@property
@ -77,12 +82,12 @@ class Table(abc.ABC): # export
def __add_child_row_location_rules(self) -> dict[str, str]:
if self.___add_child_row_location_rules is None:
ret: dict[str, str] = {}
for foreign_table_name, foreign_table in self.__relationship_by_foreign_table.items():
if len([self.foreign_keys_to_parent_table(foreign_table)]):
rule = self._add_child_row_location_rule(foreign_table_name)
for table_name, table in self.__relationship_by_foreign_table.items():
if len([self.foreign_keys_to_parent_table(table)]):
rule = self._add_child_row_location_rule(table_name)
if rule is None:
continue
ret[foreign_table_name] = rule
ret[table_name] = rule
self.___add_child_row_location_rules = ret
return self.___add_child_row_location_rules
@ -168,7 +173,8 @@ class Table(abc.ABC): # export
return None
def _model_module_search_paths(self) -> list[tuple[str, type]]:
return self.schema.model_module_search_paths # Fall back to Schema-global default
# Fall back to Schema-global default
return self.schema.model_module_search_paths
@abc.abstractmethod
def _query_name(self) -> str:
@ -190,7 +196,9 @@ class Table(abc.ABC): # export
for col in self.__schema.access_defining_columns:
if col in self.primary_keys:
ret += f'/<{col}>'
ret += self.base_location_rule
base = self.base_location_rule
if base is not None:
ret += base if isinstance(base, str) else 'what-goes-here?'.join(base)
return ret
def _row_location_rule(self) -> Optional[str]:
@ -288,7 +296,7 @@ class Table(abc.ABC): # export
return self.__location_rule
def location(self, *args, **kwargs):
ret = self.location_rule
ret = str(self.location_rule)
for token, val in kwargs.items(): # FIXME: Poor man's row location assembly
ret = re.sub(f'<{token}>', quote_plus(quote_plus(str(val))), ret)
return ret
@ -300,7 +308,7 @@ class Table(abc.ABC): # export
return self.__row_location_rule
def row_location(self, *args, **kwargs):
ret = self.row_location_rule
ret = str(self.row_location_rule)
for col in self.primary_keys:
if col in kwargs: # FIXME: Poor man's row location assembly
ret = re.sub(f'<{col}>', quote_plus(quote_plus(str(kwargs[col]))), ret)
@ -313,7 +321,7 @@ class Table(abc.ABC): # export
return self.__add_row_location_rule
def add_row_location(self, *args, **kwargs) -> Optional[str]:
ret = self.add_row_location_rule
ret = str(self.add_row_location_rule)
for col in self.primary_keys:
if col in kwargs: # FIXME: Poor man's row location assembly
ret = re.sub(f'<{col}>', quote_plus(quote_plus(str(kwargs[col]))), ret)
@ -323,12 +331,14 @@ class Table(abc.ABC): # export
def add_child_row_location_rules(self) -> Iterable[str]:
return self.__add_child_row_location_rules.values()
def add_child_row_location_rule(self, child_table: Union[Self, str]) -> Optional[str]:
def add_child_row_location_rule(self, child_table: Union[Self,
str]) -> Optional[str]:
if isinstance(child_table, Table):
child_table = child_table.name
return self.__add_child_row_location_rules.get(child_table)
def add_child_row_location(self, parent_table: Union[Self, str], **kwargs) -> Optional[str]:
def add_child_row_location(self, parent_table: Union[Self, str],
**kwargs) -> Optional[str]:
ret = self.add_child_row_location_rule(parent_table)
if isinstance(parent_table, str):
parent_table = self.schema[parent_table]
@ -337,7 +347,11 @@ class Table(abc.ABC): # export
for cfk in self.foreign_keys_to_parent_table(parent_table):
for fk in cfk:
if fk.parent_column.name in kwargs:
ret = re.sub(f'<{fk.child_column.name}>', quote_plus(quote_plus(str(kwargs[fk.parent_column.name]))), ret)
ret = re.sub(
f'<{fk.child_column.name}>',
quote_plus(quote_plus(str(kwargs[fk.parent_column.name]))),
ret
)
return ret
@property
@ -443,7 +457,8 @@ class Table(abc.ABC): # export
def foreign_key_parent_tables(self):
return self.__foreign_key_parent_tables.values()
def foreign_keys_to_parent_table(self, parent_table) -> Iterable[CompositeForeignKey]:
def foreign_keys_to_parent_table(self,
parent_table) -> Iterable[CompositeForeignKey]:
if self.__foreign_keys_to_parent_table is None:
self.__foreign_keys_to_parent_table = OrderedDict()
for cfk in self.foreign_key_constraints:
@ -451,8 +466,12 @@ class Table(abc.ABC): # export
if pt not in self.__foreign_keys_to_parent_table:
self.__foreign_keys_to_parent_table[pt] = []
self.__foreign_keys_to_parent_table[pt].append(cfk)
parent_table_name = parent_table if isinstance(parent_table, str) else parent_table.name
return self.__foreign_keys_to_parent_table[parent_table_name] if parent_table_name in self.__foreign_keys_to_parent_table else []
parent_table_name = parent_table if isinstance(
parent_table, str
) else parent_table.name
return self.__foreign_keys_to_parent_table[
parent_table_name
] if parent_table_name in self.__foreign_keys_to_parent_table else []
@property
def relationships(self) -> list[tuple[str, Self]]:

View file

@ -1,12 +1,18 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from typing import TYPE_CHECKING
from ...log import NOTICE, slog
if TYPE_CHECKING:
from .Schema import Schema
from ...log import *
def check_schema(schema: Schema): # export
slog(NOTICE, f'There are {len(schema)} tables in the database')
for cfk in schema.foreign_key_constraints:
for fk in cfk:
if fk.child_column.data_type != fk.parent_column.data_type:
raise Exception(f'Type mismatch in foreign key {fk}: {fk.child_column.data_type} != {fk.parent_column.data_type}')
raise Exception(
f'Type mismatch in foreign key {fk}: {fk.child_column.data_type} '
f'!= {fk.parent_column.data_type}'
)

View file

@ -1,14 +1,19 @@
# -*- coding: utf-8 -*-
from collections.abc import Callable
from __future__ import annotations
import xml.etree.ElementTree as ET
from ...log import *
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Callable
from typing import Any
class MapAttr2Shape: # export
def __init__(self, mappings: dict[str, str|Callable[[dict[str, str]], str]]|None=None) -> None:
def __init__(
self,
mappings: dict[str, str | Callable[[dict[str, str]], str]] | None = None
) -> None:
self.__mappings = mappings if mappings is not None else {}
self.__shape_node_key = 'd25'
self.__ns_gml = "http://graphml.graphdrawing.org/xmlns"
@ -16,7 +21,8 @@ class MapAttr2Shape: # export
# -- Standard GraphML
"": self.__ns_gml,
"xsi": "http://www.w3.org/2001/XMLSchema-instance",
"xsi:schemaLocation": "http://graphml.graphdrawing.org/xmlns http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd",
"xsi:schemaLocation":
"http://graphml.graphdrawing.org/xmlns http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd",
# -- YWorks GraphML
"java": "http://www.yworks.com/xml/yfiles-common/1.0/java",
@ -81,10 +87,7 @@ class MapAttr2Shape: # export
if children is not None:
__add(el, children)
default_values = {
'color': '#FFCC00',
'text': ''
}
default_values = {'color': '#FFCC00', 'text': ''}
values = {}
for key, default in default_values.items():
@ -98,7 +101,7 @@ class MapAttr2Shape: # export
continue
mapped = mapping(self.__attribs(node, keys))
values[key] = mapped or default
except:
except Exception:
pass
color = values['color']
@ -110,14 +113,34 @@ class MapAttr2Shape: # export
shape = {
'data': {
'a': {'key': self.__shape_node_key},
'a': {
'key': self.__shape_node_key
},
'c': {
'y:ShapeNode': {
'a': {},
'c': {
'y:Geometry': {'a': {'height': '30.0', 'width': str(width_box), 'x': str(-(width_box / 2)), 'y':' -15.0'}},
'y:Fill': {'a': {'color': color, 'transparent': 'false'}},
'y:BorderStyle': {'a': {'color': '#000000', 'raised': 'false', 'type': 'line', 'width': '1.0'}},
'y:Geometry': {
'a': {
'height': '30.0',
'width': str(width_box),
'x': str(-(width_box / 2)),
'y': ' -15.0'
}
},
'y:Fill': {
'a': {
'color': color, 'transparent': 'false'
}
},
'y:BorderStyle': {
'a': {
'color': '#000000',
'raised': 'false',
'type': 'line',
'width': '1.0'
}
},
'y:NodeLabel': {
'a': {
'alignment': 'center',
@ -142,7 +165,11 @@ class MapAttr2Shape: # export
'c': {
'y:LabelModel': {
'c': {
'y:SmartNodeLabelModel': {'a': {'distance': '4.0'}}
'y:SmartNodeLabelModel': {
'a': {
'distance': '4.0'
}
}
},
},
'y:ModelParameter': {
@ -164,7 +191,11 @@ class MapAttr2Shape: # export
},
't': text
},
'y:Shape': {'a': {'type': 'rectangle'}}
'y:Shape': {
'a': {
'type': 'rectangle'
}
}
}
}
}
@ -175,7 +206,7 @@ class MapAttr2Shape: # export
def __massage_nodes(self, root) -> None:
keys = self.__keys(root)
graph = root.find(f'graph', self.__ns)
graph = root.find('graph', self.__ns)
for node in graph:
self.__massage_node(node, keys)

View file

@ -1,22 +1,32 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import copy
import getpass
import pathlib
import ldap, getpass, pathlib, copy
from ldap.schema.models import ObjectClass
from enum import Flag, auto
import networkx as nx
from typing import Any, Self
from typing import TYPE_CHECKING, Any, Self
import ldap # type: ignore[import-untyped]
import networkx as nx # type: ignore[import-untyped]
from ldap.schema.models import ObjectClass # type: ignore[import-untyped]
from .log import ERR, INFO, WARNING, slog
if TYPE_CHECKING:
from collections.abc import Callable
from .Config import Config as BaseConfig
from .log import *
class Config:
def __init__(self, external: BaseConfig | None = None):
self.__external = external
for attr in ['ldap_uri', 'bind_dn', 'bind_pw', 'base_dn']:
setattr(self, '_Config__' + attr, None)
def __get(self, key: str, default: str):
def __get(self, key: str, default: str | None):
if not self.__external:
return default
return self.__external.value(key, default = default)
@ -31,6 +41,7 @@ class Config:
else:
self.__ldap_uri = 'ldap://ldap.janware.com'
return self.__ldap_uri
@ldap_uri.setter
def ldap_uri(self, rhs):
self.__ldap_uri = rhs
@ -38,8 +49,12 @@ class Config:
@property
def bind_dn(self):
if self.__bind_dn is None:
self.__bind_dn = self.__get('bind_dn', default=f'uid={getpass.getuser()},ou=users,dc=jannet,dc=de')
self.__bind_dn = self.__get(
'bind_dn',
default = f'uid={getpass.getuser()},ou=users,dc=jannet,dc=de'
)
return self.__bind_dn
@bind_dn.setter
def bind_dn(self, rhs):
self.__bind_dn = rhs
@ -52,13 +67,17 @@ class Config:
if ret is not None:
break
if ret is None:
ldap_secret_file = self.__get('secret_file', f'{pathlib.Path.home()}/.ldap.secret')
ldap_secret_file = self.__get(
'secret_file', f'{pathlib.Path.home()}/.ldap.secret'
)
assert ldap_secret_file is not None, 'ldap_secret_file'
with open(ldap_secret_file, 'r') as file:
ret = file.read()
file.closed
ret = ret.strip()
self.__bind_pw = ret
return self.__bind_pw
@bind_pw.setter
def bind_pw(self, rhs):
self.__bind_pw = rhs
@ -66,8 +85,9 @@ class Config:
@property
def base_dn(self):
if self.__base_dn is None:
self.__base_dn = self.__get('base_dn', default=f'dc=jannet,dc=de')
self.__base_dn = self.__get('base_dn', default = 'dc=jannet,dc=de')
return self.__base_dn
@base_dn.setter
def base_dn(self, rhs):
self.__base_dn = rhs
@ -83,8 +103,10 @@ class Connection: # export
c = conf if isinstance(conf, Config) else Config(conf)
try:
uri = c.ldap_uri
except:
uri = c.uri
except Exception:
# mypy says: E: "Config" has no attribute "uri" [attr-defined]
# FIXME: Who adds .uri?
uri = c.uri # type: ignore
try:
ret = ldap.initialize(uri)
ret.start_tls_s()
@ -92,7 +114,7 @@ class Connection: # export
slog(ERR, f'Failed to initialize LDAP connection to "{uri}" ({str(e)})')
raise
try:
rr = ret.bind_s(c.bind_dn, c.bind_pw) # method)
ret.bind_s(c.bind_dn, c.bind_pw) # method)
except Exception as e:
slog(ERR, f'Failed to bind to "{uri}" with dn "{c.bind_dn}" ({str(e)})')
raise
@ -108,13 +130,16 @@ class Connection: # export
def add(self, attrs: dict[str, bytes], dn: str | None = None):
if dn is None:
if not 'dn' in attrs:
if 'dn' not in attrs:
raise Exception('No DN to add an LDAP entry to')
attrs = copy.deepcopy(attrs)
del attrs['dn']
try:
slog(INFO, f'LDAP: Add [{dn}] -> {attrs}')
self.__ldap.add_s(dn, ldap.modlist.addModlist(attrs))
ml =ldap.modlist.addModlist( # pyright: ignore[reportAttributeAccessIssue]
attrs
)
self.__ldap.add_s(dn, ml)
except Exception as e:
slog(ERR, f'{dn}: Failed to add entry {attrs} ({e})')
raise
@ -122,16 +147,27 @@ class Connection: # export
def delete(self, dn: str, recursive = False, force_existence: bool = False):
def __walk_cb_delete(conn: Connection, entry, context):
self.walk(__walk_cb_delete, base=entry[0], scope=ldap.SCOPE_ONELEVEL, context=context)
self.walk(
__walk_cb_delete,
base = entry[0],
scope = ldap.
SCOPE_ONELEVEL, # pyright: ignore[reportAttributeAccessIssue]
context = context
)
self.__ldap.delete_s(entry[0])
try:
if recursive:
self.walk(__walk_cb_delete, dn, scope=ldap.SCOPE_ONELEVEL)
self.walk(
__walk_cb_delete,
dn,
scope = ldap.
SCOPE_ONELEVEL # pyright: ignore[reportAttributeAccessIssue]
)
self.__ldap.delete_s(dn)
else:
self.__ldap.delete_s(dn)
except ldap.NO_SUCH_OBJECT as e:
except ldap.NO_SUCH_OBJECT: # pyright: ignore[reportAttributeAccessIssue]
if force_existence:
raise
except Exception as e:
@ -156,20 +192,25 @@ class Connection: # export
):
# TODO: Support ignored arguments
search_return = self.__ldap.search(base=base,
search_return = self.__ldap.search(
base = base,
scope = scope,
filterstr = filterstr,
attrlist = attrlist,
attrsonly=attrsonly)
attrsonly = attrsonly
)
while True:
result_type, result_data = self.__ldap.result(search_return, 0)
if (result_data == []):
if (not result_data):
break
if result_type != ldap.RES_SEARCH_ENTRY:
if result_type != ldap.RES_SEARCH_ENTRY: # pyright: ignore[reportAttributeAccessIssue]
continue
for entry in result_data:
if decode:
entry = entry[0], {key: [val.decode() for val in vals] for key, vals in entry[1].items()}
entry = entry[0], {
key: [val.decode() for val in vals]
for key, vals in entry[1].items()
}
if unroll and False:
entry = entry[0], {key: val[0] for key, val in entry[1].items()}
try:
@ -182,7 +223,8 @@ class Connection: # export
slog(WARNING, msg)
continue
def find(self,
def find(
self,
base: str,
scope,
filterstr = None,
@ -204,7 +246,13 @@ class Connection: # export
try:
result: list[Any] = []
self.walk(__walk_cb_find, base, scope=scope, filterstr=filterstr, attrlist=attrlist)
self.walk(
__walk_cb_find,
base,
scope = scope,
filterstr = filterstr,
attrlist = attrlist
)
except Exception as e:
slog(ERR, f'Failed search {__search()} ({e})')
raise
@ -218,15 +266,32 @@ class Connection: # export
def object_classes(self) -> dict[str, ObjectClass]:
#def object_classes(self):
if self.__object_classes_by_oid is None:
res = self.find(base='', scope=ldap.SCOPE_BASE, filterstr='(objectClass=*)', attrlist=['subschemaSubentry'])
dn = res[0][1]['subschemaSubentry'][0].decode('utf-8') # Usually yields cn=Subschema
res = self.find(base=dn, scope=ldap.SCOPE_BASE, filterstr='(objectClass=*)', attrlist=['*', '+'])
res = self.find(
base = '',
scope = ldap.SCOPE_BASE, # pyright: ignore[reportAttributeAccessIssue]
filterstr = '(objectClass=*)',
attrlist = ['subschemaSubentry']
)
dn = res[0][1]['subschemaSubentry'][0].decode(
'utf-8'
) # Usually yields cn=Subschema
res = self.find(
base = dn,
scope = ldap.SCOPE_BASE, # pyright: ignore[reportAttributeAccessIssue]
filterstr = '(objectClass=*)',
attrlist = ['*', '+']
)
subschema_entry = res[0]
subschema_subentry = ldap.cidict.cidict(subschema_entry[1])
subschema = ldap.schema.SubSchema(subschema_subentry)
subschema_subentry = ldap.cidict.cidict( # pyright: ignore[reportAttributeAccessIssue]
subschema_entry[1]
)
subschema = ldap.schema.SubSchema( # pyright: ignore[reportAttributeAccessIssue]
subschema_subentry
)
object_class_oids = subschema.listall(ObjectClass)
self.__object_classes_by_oid = {
oid: subschema.get_obj(ObjectClass, oid) for oid in object_class_oids
oid: subschema.get_obj(ObjectClass, oid)
for oid in object_class_oids
}
return self.__object_classes_by_oid
@ -243,14 +308,17 @@ class Connection: # export
return self.__object_classes_by_name
def __oc_recurse_to_top(self, cur: str | ObjectClass, cb, context):
cur_oc = cur if isinstance(cur, ObjectClass) else self.object_class_by_name[cur.lower()]
cur_oc = cur if isinstance(cur, ObjectClass) else self.object_class_by_name[
cur.lower()]
for s in cur_oc.sup:
self.__oc_recurse_to_top(s, cb, context)
cb(cur_oc, context)
def object_class_path(self, leaf: str | ObjectClass):
def cb(oc, context):
ret.append(oc)
ret: list[str] = []
self.__oc_recurse_to_top(leaf, cb, None)
return reversed(ret)
@ -262,20 +330,18 @@ class Connection: # export
def collect(root, attr):
ret = set()
def cb(oc, attr):
vals = getattr(oc, attr)
if vals is None:
return
for val in vals:
ret.add(val)
self.__oc_recurse_to_top(root, cb, attr)
return ret
kind = {
0: 'STRUCTURAL',
1: 'ABSTRACT',
2: 'AUXILIARY'
}
kind = {0: 'STRUCTURAL', 1: 'ABSTRACT', 2: 'AUXILIARY'}
ret = nx.DiGraph()
for oid, oc in self.object_classes.items():
ret.add_node(
@ -288,21 +354,31 @@ class Connection: # export
)
for base_class in oc.sup:
try:
ret.add_edge(oid, self.object_class_by_name[base_class.lower()].oid)
ret.add_edge(
oid, self.object_class_by_name[base_class.lower()].oid
)
except Exception as e:
slog(WARNING, f'Failed to add edge {oid}:{oc.names} -> {base_class} ({e})')
slog(
WARNING,
f'Failed to add edge {oid}:{oc.names} -> {base_class} ({e})'
)
self.__object_class_tree = ret
return self.__object_class_tree
def object_class_attrs(self, oc: str|ObjectClass, required: AttrType = AttrType.Must, origins: bool=False) -> dict[str, set[str]]|set[str]:
def object_class_attrs(
self,
oc: str | ObjectClass,
required: AttrType = AttrType.Must,
origins: bool = False
) -> dict[str, set[str]] | set[str]:
all_attrs: set[str] = set()
attrs_by_origin: dict[str, set[str]] = {}
for oc in self.object_class_path(oc):
cur = set()
if required & self.AttrType.Must:
cur |= set(oc.must)
cur |= set(oc.must) # pyright: ignore[reportAttributeAccessIssue]
if required & self.AttrType.May:
cur |= set(oc.may)
cur |= set(oc.may) # pyright: ignore[reportAttributeAccessIssue]
if cur:
all_attrs |= cur
attrs_by_origin[oc] = cur
@ -313,7 +389,11 @@ class Connection: # export
#base_oid = self.object_class_by_name[base_candidate].oid
#if base_oid in [oc.oid for oc in self.object_class_path(name)]:
# return True
return nx.has_path(self.object_class_tree, self.object_class_by_name[name.lower()].oid, self.object_class_by_name[base_candidate.lower()].oid)
return nx.has_path(
self.object_class_tree,
self.object_class_by_name[name.lower()].oid,
self.object_class_by_name[base_candidate.lower()].oid
)
def default_config() -> Config: # export
return Config()

View file

@ -1,22 +1,28 @@
# -*- coding: utf-8 -*-
from __future__ import annotations, print_function
from __future__ import print_function
import inspect
import re
import sys
import syslog
import unicodedata
from typing import List, Tuple, Optional, Any
import sys, re, io, syslog, inspect, unicodedata
from os.path import basename
from datetime import datetime
from os.path import basename
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
from . import misc
if TYPE_CHECKING:
import io
# --- python 2 / 3 compatibility stuff
try:
basestring # type: ignore
except NameError:
basestring = str
# fmt: disable # don't conflate
# yapf: disable # don't conflate
_special_chars = {
'\a' : '\\a',
'\b' : '\\b',
@ -26,12 +32,20 @@ _special_chars = {
'\f' : '\\f',
'\r' : '\\r',
}
# yapf: enable
# fmt: enable
_special_char_regex = re.compile("(%s)" % "|".join(map(re.escape, _special_chars.keys())))
_special_char_regex = re.compile(
"(%s)" % "|".join(map(re.escape, _special_chars.keys()))
)
_all_control_chars = ''.join(chr(c) for c in range(sys.maxunicode) if unicodedata.category(chr(c)) in {'Cc'})
_all_control_chars = ''.join(
chr(c) for c in range(sys.maxunicode) if unicodedata.category(chr(c)) in {'Cc'}
)
_clean_str_regex = re.compile(r'(\033\[[0-9]*m|[%s])' % re.escape(_all_control_chars))
# fmt: disable # don't conflate
# yapf: disable # don't conflate
EMERG = int(syslog.LOG_EMERG)
ALERT = int(syslog.LOG_ALERT)
CRIT = int(syslog.LOG_CRIT)
@ -98,13 +112,17 @@ _prio_colors = {
EMERG : [ CONSOLE_FONT_BOLD + CONSOLE_FONT_MAGENTA, CONSOLE_FONT_OFF ],
}
# yapf: enable
# fmt: enable
class Stream:
def __init__(self, stream, flags):
self.stream = stream
self.flags = flags
_streams: dict[int, Stream] = dict()
_stream_descriptors = [reversed(range(1, 16))]
_stream_descriptors = list(reversed(range(1, 16)))
def add_capture_stream(stream, flags = 0x0):
ret = _stream_descriptors.pop()
@ -125,7 +143,8 @@ def log_level(s: Optional[str]=None) -> int: # export
return _level
return parse_log_prio_str(s)
def get_caller_pos(up: int = 1, kwargs: Optional[dict[str, Any]] = None) -> Tuple[str, str, int]:
def get_caller_pos(up: int = 1,
kwargs: Optional[dict[str, Any]] = None) -> Tuple[str, str, int]:
if kwargs and 'caller' in kwargs:
r = kwargs['caller']
del kwargs['caller']
@ -188,7 +207,9 @@ def slog(prio: int, *args, only_printable: bool=False, **kwargs) -> None: # expo
for a in args:
margs += ' ' + str(a)
if only_printable:
margs = _special_char_regex.sub(lambda mo: _special_chars[mo.string[mo.start():mo.end()]], margs)
margs = _special_char_regex.sub(
lambda mo: _special_chars[mo.string[mo.start():mo.end()]], margs
)
margs = re.sub('[\x01-\x1f]', '.', margs)
for file in _log_file_streams:
@ -233,6 +254,8 @@ def parse_log_prio_str(prio: str) -> int: # export
if r < 0 or r > DEVEL:
raise Exception("Invalid log priority ", prio)
except ValueError:
# fmt: disable # don't conflate
# yapf: disable # don't conflate
map_prio_str_to_val = {
"EMERG" : EMERG,
"emerg" : EMERG,
@ -255,6 +278,8 @@ def parse_log_prio_str(prio: str) -> int: # export
"OFF" : OFF,
"off" : OFF,
}
# yapf: enable
# fmt: enable
if prio in map_prio_str_to_val:
return map_prio_str_to_val[prio]
raise Exception("Unknown priority string \"", prio, "\"")
@ -312,18 +337,18 @@ def remove_from_prefix(count) -> str: # export
_clean_log_prefix = _clean_str_regex.sub('', _log_prefix)
return r
def set_filename_length(l: int) -> int: # export
def set_filename_length(length: int) -> int: # export
global _file_name_len
r = _file_name_len
if l:
_file_name_len = l
if length:
_file_name_len = length
return r
def set_module_name_length(l: int) -> int: # export
def set_module_name_length(length: int) -> int: # export
global _module_name_len
r = _module_name_len
if l:
_module_name_len = l
if length:
_module_name_len = length
return r
def add_log_file(path: str) -> None: # export

View file

@ -1,6 +1,11 @@
# -*- coding: utf-8 -*-
import os, errno, atexit, tempfile, filecmp, inspect, importlib, re
import atexit
import errno
import filecmp
import importlib
import inspect
import os
import re
import tempfile
from typing import Iterable
@ -43,7 +48,9 @@ def atomic_store(contents, path): # export
with open(path, 'w') as outfile:
outfile.write(contents)
return
outfile = tempfile.NamedTemporaryFile(prefix=os.path.basename(path), delete=False, dir=os.path.dirname(path))
outfile = tempfile.NamedTemporaryFile(
prefix = os.path.basename(path), delete = False, dir = os.path.dirname(path)
)
name = outfile.name
_tmpfiles.add(name)
outfile.write(contents)
@ -68,8 +75,10 @@ def get_derived_classes(mod, base, flt=None): # export
if inspect.isabstract(c):
log.slog(log.DEBUG, " is abstract")
continue
if not base in inspect.getmro(c):
log.slog(log.DEBUG, " is not derived from", base, "only", inspect.getmro(c))
if base not in inspect.getmro(c):
log.slog(
log.DEBUG, " is not derived from", base, "only", inspect.getmro(c)
)
continue
if flt and not re.match(flt, name):
log.slog(log.DEBUG, ' "{}.{}" has wrong name'.format(mod, name))
@ -89,9 +98,15 @@ def load_class(module_path, baseclass, class_name_filter=None): # export
mod = importlib.import_module(module_path)
classes = get_derived_classes(mod, baseclass, flt = class_name_filter)
if len(classes) == 0:
raise Exception(f'no class matching "{class_name_filter}" of type "{baseclass}" found in module "{module_path}"')
raise Exception(
f'no class matching "{class_name_filter}" of type "{baseclass}" '
f'found in module "{module_path}"'
)
if len(classes) > 1:
raise Exception(f'{len(classes)} classes matching "{class_name_filter}" of type "{baseclass}" found in module "{module_path}"')
raise Exception(
f'{len(classes)} classes matching "{class_name_filter}" of type '
f'"{baseclass}" found in module "{module_path}"'
)
return classes[0]
def load_class_names(path, baseclass, flt = None, remove_flt = False): # export
@ -100,7 +115,7 @@ def load_class_names(path, baseclass, flt=None, remove_flt=False): # export
for c in classes:
name = c.__name__
if flt and remove_flt:
name = re.subst(flt, "", name)
name = re.sub(flt, '', name)
if name not in r:
r.append(name)
else:
@ -108,8 +123,12 @@ def load_class_names(path, baseclass, flt=None, remove_flt=False): # export
#log.slog(log.WARNING, "{} is already in in {}".format(name, r))
return r
def load_object(module_path, baseclass, class_name_filter=None, *args, **kwargs): # export
return load_class(module_path, baseclass, class_name_filter=class_name_filter)(*args, **kwargs)
def load_object( # export
module_path, baseclass, class_name_filter = None, *args, **kwargs
):
return load_class(
module_path, baseclass, class_name_filter = class_name_filter
)(*args, **kwargs)
def load_function(module_path, name): # export
mod = importlib.import_module(module_path)
@ -127,30 +146,32 @@ def commit_tmpfile(tmp: str, path: str) -> None: # export
def multi_regex_edit(spec, strings): # export
for cmd in spec:
if len(cmd) < 2:
raise Exception('Invalid command in multi_regex_edit(): {}'.format(str(cmd)))
raise Exception(
'Invalid command in multi_regex_edit(): {}'.format(str(cmd))
)
if cmd[0] == 'sub':
rx = re.compile(cmd[1])
replacement = cmd[2]
r = []
for l in strings:
r.append(re.sub(rx, replacement, l))
for string in strings:
r.append(re.sub(rx, replacement, string))
strings = r
continue
if cmd[0] == 'del':
rx = re.compile(cmd[1])
r = []
for l in strings:
if rx.search(l) is not None:
for string in strings:
if rx.search(string) is not None:
continue
r.append(l)
r.append(string)
strings = r
continue
if cmd[0] == 'match':
rx = re.compile(cmd[1])
r = []
for l in strings:
if rx.search(l) is not None:
r.append(l)
for string in strings:
if rx.search(string) is not None:
r.append(string)
strings = r
continue
raise Exception('Invalid command in multi_regex_edit(): {}'.format(str(cmd)))
@ -163,7 +184,9 @@ def dump(prio: int, objects: Iterable, *args, **kwargs) -> None: # export
log.append_to_prefix(prefix)
i = 1
for o in objects:
o.dump(prio, "{} ({})".format(i, o.__class__.__name__), caller=caller, **kwargs)
o.dump(
prio, "{} ({})".format(i, o.__class__.__name__), caller = caller, **kwargs
)
i += 1
log.remove_from_prefix(prefix)
log.slog(prio, "`---------- {}".format(' '.join(args)), caller = caller)

View file

@ -1,5 +1,4 @@
# -*- coding: utf-8 -*-
'''
Created on 26 May 2013
@ -14,38 +13,42 @@ ___________________________________
Permission is hereby granted, free of charge, to any person obtaining a copy of this
software and associated documentation files (the "Software"), to deal in the Software
without restriction, including without limitation the rights to use, copy, modify, merge,
publish, distribute, sub-license, and/or sell copies of the Software, and to permit persons
to whom the Software is furnished to do so, subject to the following conditions:
without restriction, including without limitation the rights to use, copy, modify,
merge, publish, distribute, sub-license, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to the following
conditions:
- The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT
OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
OTHER DEALINGS IN THE SOFTWARE.
'''
import platform
_python3 = int(platform.python_version_tuple()[0]) >= 3
class multi_key_dict(object):
""" The purpose of this type is to provide a multi-key dictionary.
This kind of dictionary has a similar interface to the standard dictionary, and indeed if used
"""
The purpose of this type is to provide a multi-key dictionary. This kind of
dictionary has a similar interface to the standard dictionary, and indeed if used
with single key key elements - it's behaviour is the same as for a standard dict().
However it also allows for creation of elements using multiple keys (using tuples/lists).
Such elements can be accessed using either of those keys (e.g read/updated/deleted).
Dictionary provides also an extended interface for iterating over items and keys by the key type.
This can be useful e.g.: when creating dictionaries with (index,name) allowing one to iterate over
items using either: names or indexes. It can be useful for many many other similar use-cases,
and there is no limit to the number of keys used to map to the value.
However it also allows for creation of elements using multiple keys (using
tuples/lists). Such elements can be accessed using either of those keys (e.g
read/updated/deleted). Dictionary provides also an extended interface for iterating
over items and keys by the key type. This can be useful e.g.: when creating
dictionaries with (index,name) allowing one to iterate over items using either:
names or indexes. It can be useful for many many other similar use-cases, and there
is no limit to the number of keys used to map to the value.
There are also methods to find other keys mapping to the same value as the specified keys etc.
Refer to examples and test code to see it in action.
There are also methods to find other keys mapping to the same value as the specified
keys etc. Refer to examples and test code to see it in action.
simple example:
k = multi_key_dict()
@ -58,21 +61,26 @@ class multi_key_dict(object):
print k[1000] # will print 'kilo (x1000)'
print k['k'] # will also print 'kilo (x1000)'
# the same way objects can be updated, and if an object is updated using one key, the new value will
# be accessible using any other key, e.g. for example above:
# the same way objects can be updated, and if an object is updated using one
# key, the new value will be accessible using any other key, e.g. for example
# above:
k['kilo'] = 'kilo'
print k[1000] # will print 'kilo' as value was updated
"""
def __init__(self, mapping_or_iterable = None, **kwargs):
""" Initializes dictionary from an optional positional argument and a possibly empty set of keyword arguments."""
""" Initializes dictionary from an optional positional argument and a possibly
empty set of keyword arguments."""
self.items_dict = {}
if mapping_or_iterable is not None:
if type(mapping_or_iterable) is dict:
mapping_or_iterable = mapping_or_iterable.items()
for kv in mapping_or_iterable:
if len(kv) != 2:
raise Exception('Iterable should contain tuples with exactly two values but specified: {0}.'.format(kv))
raise Exception(
'Iterable should contain tuples with exactly two values '
'but specified: {0}.'.format(kv)
)
self[kv[0]] = kv[1]
for keys, value in kwargs.items():
self[keys] = value
@ -89,14 +97,13 @@ class multi_key_dict(object):
(item update)
If this is not the case - KeyError is raised. """
if (type(keys) in [tuple, list]):
at_least_one_key_exists = False
num_of_keys_we_have = 0
for x in keys:
try:
self.__getitem__(x)
num_of_keys_we_have += 1
except Exception as err:
except Exception:
continue
if num_of_keys_we_have:
@ -112,20 +119,22 @@ class multi_key_dict(object):
if new != direct_key:
all_select_same_item = False
break
except Exception as err:
except Exception:
all_select_same_item = False
break;
break
if not all_select_same_item:
raise KeyError(', '.join(str(key) for key in keys))
first_key = keys[0] # combination if keys is allowed, simply use the first one
first_key = keys[
0] # combination if keys is allowed, simply use the first one
else:
first_key = keys
key_type = str(type(first_key)) # find the intermediate dictionary..
if first_key in self:
self.items_dict[self.__dict__[key_type][first_key]] = value # .. and update the object if it exists..
self.items_dict[self.__dict__[key_type][first_key]
] = value # .. and update the object if it exists..
else:
if (type(keys) not in [tuple, list]):
key = keys
@ -136,9 +145,8 @@ class multi_key_dict(object):
""" Called to implement deletion of self[key]."""
key_type = str(type(key))
if (key in self and
self.items_dict and
(self.__dict__[key_type][key] in self.items_dict) ):
if (key in self and self.items_dict
and (self.__dict__[key_type][key] in self.items_dict)):
intermediate_key = self.__dict__[key_type][key]
# remove the item in main dictionary
@ -167,9 +175,9 @@ class multi_key_dict(object):
return key in self
def get_other_keys(self, key, including_current = False):
""" Returns list of other keys that are mapped to the same value as specified key.
@param key - key for which other keys should be returned.
@param including_current if set to True - key will also appear on this list."""
""" Returns list of other keys that are mapped to the same value as specified
key. @param key - key for which other keys should be returned. @param
including_current if set to True - key will also appear on this list."""
other_keys = []
if key in self:
other_keys.extend(self.__dict__[str(type(key))][key])
@ -179,10 +187,15 @@ class multi_key_dict(object):
def iteritems(self, key_type = None, return_all_keys = False):
""" Returns an iterator over the dictionary's (key, value) pairs.
@param key_type if specified, iterator will be returning only (key,value) pairs for this type of key.
Otherwise (if not specified) ((keys,...), value)
i.e. (tuple of keys, values) pairs for all items in this dictionary will be generated.
@param return_all_keys if set to True - tuple of keys is retuned instead of a key of this type."""
@param key_type if specified, iterator will be returning only (key,value)
pairs for this type of key.
Otherwise (if not specified) ((keys,...), value) i.e. (tuple of keys,
values) pairs for all items in this dictionary will be generated.
@param return_all_keys if set to True - tuple of keys is retuned instead of
a key of this type."""
if key_type is None:
for item in self.items_dict.items():
@ -202,10 +215,12 @@ class multi_key_dict(object):
def iterkeys(self, key_type = None, return_all_keys = False):
""" Returns an iterator over the dictionary's keys.
@param key_type if specified, iterator for a dictionary of this type will be used.
@param key_type if specified, iterator for a dictionary of this type will
be used.
Otherwise (if not specified) tuples containing all (multiple) keys
for this dictionary will be generated.
@param return_all_keys if set to True - tuple of keys is retuned instead of a key of this type."""
@param return_all_keys if set to True - tuple of keys is retuned instead of
a key of this type."""
if (key_type is not None):
the_key = str(key_type)
if the_key in self.__dict__:
@ -220,8 +235,11 @@ class multi_key_dict(object):
def itervalues(self, key_type = None):
""" Returns an iterator over the dictionary's values.
@param key_type if specified, iterator will be returning only values pointed by keys of this type.
Otherwise (if not specified) all values in this dictinary will be generated."""
@param key_type if specified, iterator will be returning only values pointed
by keys of this type.
Otherwise (if not specified) all values in this dictinary will be
generated."""
if (key_type is not None):
intermediate_key = str(key_type)
if intermediate_key in self.__dict__:
@ -232,16 +250,19 @@ class multi_key_dict(object):
yield value
if _python3:
items = iteritems
items = iteritems # type: ignore
else:
def items(self, key_type = None, return_all_keys = False):
return list(self.iteritems(key_type, return_all_keys))
items.__doc__ = iteritems.__doc__
def keys(self, key_type = None):
""" Returns a copy of the dictionary's keys.
@param key_type if specified, only keys for this type will be returned.
Otherwise list of tuples containing all (multiple) keys will be returned."""
Otherwise list of tuples containing all (multiple) keys will be
returned."""
if key_type is not None:
intermediate_key = str(key_type)
if intermediate_key in self.__dict__:
@ -254,15 +275,17 @@ class multi_key_dict(object):
def values(self, key_type = None):
""" Returns a copy of the dictionary's values.
@param key_type if specified, only values pointed by keys of this type will be returned.
Otherwise list of all values contained in this dictionary will be returned."""
@param key_type if specified, only values pointed by keys of this type
will be returned
Otherwise list of all values contained in this dictionary will be
returned."""
if (key_type is not None):
all_items = {} # in order to preserve keys() type (dict_values for python3)
keys_used = set()
direct_key = str(key_type)
if direct_key in self.__dict__:
for intermediate_key in self.__dict__[direct_key].values():
if not intermediate_key in keys_used:
if intermediate_key not in keys_used:
all_items[intermediate_key] = self.items_dict[intermediate_key]
keys_used.add(intermediate_key)
return all_items.values()
@ -279,19 +302,22 @@ class multi_key_dict(object):
def __add_item(self, item, keys = None):
""" Internal method to add an item to the multi-key dictionary"""
if (not keys or not len(keys)):
raise Exception('Error in %s.__add_item(%s, keys=tuple/list of items): need to specify a tuple/list containing at least one key!'
% (self.__class__.__name__, str(item)))
raise Exception(
'Error in %s.__add_item(%s, keys=tuple/list of items): need to specify'
'a tuple/list containing at least one key!' %
(self.__class__.__name__, str(item))
)
direct_key = tuple(keys) # put all keys in a tuple, and use it as a key
for key in keys:
key_type = str(type(key))
# store direct key as a value in an intermediate dictionary
if(not key_type in self.__dict__):
if (key_type not in self.__dict__):
self.__setattr__(key_type, dict())
self.__dict__[key_type][key] = direct_key
# store the value in the actual dictionary
if(not 'items_dict' in self.__dict__):
if ('items_dict' not in self.__dict__):
self.items_dict = dict()
self.items_dict[direct_key] = item
@ -304,17 +330,21 @@ class multi_key_dict(object):
def __str__(self):
items = []
str_repr = lambda x: '\'%s\'' % x if type(x) == str else str(x)
def str_repr(x):
return '\'%s\'' % x if isinstance(x, str) else str(x)
if hasattr(self, 'items_dict'):
for (keys, value) in self.items():
keys_str = [str_repr(k) for k in keys]
items.append('(%s): %s' % (', '.join(keys_str),
str_repr(value)))
items.append('(%s): %s' % (', '.join(keys_str), str_repr(value)))
dict_str = '{%s}' % (', '.join(items))
return dict_str
def test_multi_key_dict():
contains_all = lambda cont, in_items: not (False in [c in cont for c in in_items])
def contains_all(cont, in_items):
return False not in [c in cont for c in in_items]
m = multi_key_dict()
assert (len(m) == 0), 'expected len(m) == 0'
@ -327,19 +357,27 @@ def test_multi_key_dict():
# try retrieving other keys mapped to the same value using one of them
res = m.get_other_keys('aa')
expected = ['mmm', 32, 12]
assert(set(res) == set(expected)), 'get_other_keys(\'aa\'): {0} other than expected: {1} '.format(res, expected)
# try retrieving other keys mapped to the same value using one of them: also include this key
assert (set(res) == set(expected)), (
'get_other_keys(\'aa\'): {0} other '
'than expected: {1} '.format(res, expected)
)
# try retrieving other keys mapped to the same value using one of them: also include
# this key
res = m.get_other_keys(32, True)
expected = ['aa', 'mmm', 32, 12]
assert(set(res) == set(expected)), 'get_other_keys(32): {0} other than expected: {1} '.format(res, expected)
assert (set(res) == set(expected)), (
'get_other_keys(32): {0} other than expected: '
'{1} '.format(res, expected)
)
assert( m.has_key('aa') == True ), 'expected m.has_key(\'aa\') == True'
assert( m.has_key('aab') == False ), 'expected m.has_key(\'aab\') == False'
assert (m.has_key('aa')), 'expected m.has_key(\'aa\') == True'
assert (not m.has_key('aab')), 'expected m.has_key(\'aab\') == False'
assert( m.has_key(12) == True ), 'expected m.has_key(12) == True'
assert( m.has_key(13) == False ), 'expected m.has_key(13) == False'
assert( m.has_key(32) == True ), 'expected m.has_key(32) == True'
assert (m.has_key(12)), 'expected m.has_key(12) == True'
assert (not m.has_key(13)), 'expected m.has_key(13) == False'
assert (m.has_key(32)), 'expected m.has_key(32) == True'
m['something else'] = 'abcd'
assert (len(m) == 2), 'expected len(m) == 2'
@ -367,11 +405,16 @@ def test_multi_key_dict():
assert (m[12] == '4'), 'expected m[12] == \'4\''
# test __str__
m_str_exp = '{(23): 0, (\'aa\', \'mmm\', 32, 12): \'4\', (\'something else\'): \'abcd\'}'
m_str_exp = (
'{(23): 0, (\'aa\', \'mmm\', 32, 12): \'4\', '
'(\'something else\'): \'abcd\'}'
)
m_str = str(m)
assert (len(m_str) > 0), 'str(m) should not be empty!'
assert (m_str[0] == '{'), 'str(m) should start with \'{\', but does with \'%c\'' % m_str[0]
assert (m_str[-1] == '}'), 'str(m) should end with \'}\', but does with \'%c\'' % m_str[-1]
assert (m_str[0] == '{'
), ('str(m) should start with \'{\', but does with \'%c\'' % m_str[0])
assert (m_str[-1] == '}'
), ('str(m) should end with \'}\', but does with \'%c\'' % m_str[-1])
# check if all key-values are there as expected. They might be sorted differently
def get_values_from_str(dict_str):
@ -381,41 +424,52 @@ def test_multi_key_dict():
keys = tuple(sorted([k.strip() for k in keys.split(',')]))
sorted_keys_and_values.append((keys, val))
return sorted_keys_and_values
exp = get_values_from_str(m_str_exp)
act = get_values_from_str(m_str)
assert (set(act) == set(exp)), 'str(m) values: \'{0}\' are not {1} '.format(act, exp)
assert (set(act) == set(exp)
), ('str(m) values: \'{0}\' are not {1} '.format(act, exp))
# try accessing / creating new (keys)-> value mapping whilst one of these
# keys already maps to a value in this dictionaries
try:
m['aa', 'bb'] = 'something new'
assert(False), 'Should not allow adding multiple-keys when one of keys (\'aa\') already exists!'
except KeyError as err:
assert(False), (
'Should not allow adding multiple-keys when one of keys '
'(\'aa\') already exists!'
)
except KeyError:
pass
# now check if we can get all possible keys (formed in a list of tuples)
# each tuple containing all keys)
res = sorted([sorted([str(x) for x in k]) for k in m.keys()])
res = sorted([sorted([str(x) for x in k]) for k in m.keys()]) # type: ignore
expected = sorted([sorted([str(x) for x in k]) for k in all_keys])
assert(res == expected), 'unexpected values from m.keys(), got:\n%s\n expected:\n%s' %(res, expected)
assert (res == expected), (
'unexpected values from m.keys(), got:\n%s\n expected:\n%s' % (res, expected)
)
# check default items (which will unpack tupe with key(s) and value)
num_of_elements = 0
for keys, value in m.items():
sorted_keys = sorted([str(k) for k in keys])
num_of_elements += 1
assert(sorted_keys in expected), 'm.items(): unexpected keys: %s' % (sorted_keys)
assert(m[keys[0]] == value), 'm.items(): unexpected value: %s (keys: %s)' % (value, keys)
assert(num_of_elements > 0), 'm.items() returned generator that did not produce anything'
assert (sorted_keys
in expected), ('m.items(): unexpected keys: %s' % (sorted_keys))
assert (m[keys[0]] == value
), ('m.items(): unexpected value: %s (keys: %s)' % (value, keys))
assert (num_of_elements
> 0), ('m.items() returned generator that did not produce anything')
# test default iterkeys()
num_of_elements = 0
for keys in m.keys():
for keys in m.keys(): # type: ignore
num_of_elements += 1
keys_s = sorted([str(k) for k in keys])
assert (keys_s in expected), 'm.keys(): unexpected keys: {0}'.format(keys_s)
assert(num_of_elements > 0), 'm.iterkeys() returned generator that did not produce anything'
assert (num_of_elements
> 0), ('m.iterkeys() returned generator that did not produce anything')
# test iterkeys(int, True): useful to get all info from the dictionary
# dictionary is iterated over the type specified, but all keys are returned.
@ -423,49 +477,67 @@ def test_multi_key_dict():
for keys in m.iterkeys(int, True):
keys_s = sorted([str(k) for k in keys])
num_of_elements += 1
assert(keys_s in expected), 'm.iterkeys(int, True): unexpected keys: {0}'.format(keys_s)
assert(num_of_elements > 0), 'm.iterkeys(int, True) returned generator that did not produce anything'
assert (keys_s in expected
), ('m.iterkeys(int, True): unexpected keys: {0}'.format(keys_s))
assert (num_of_elements > 0), (
'm.iterkeys(int, True) returned generator that did not produce anything'
)
# test values for different types of keys()
expected = set([0, '4'])
res = set(m.values(int))
assert (res == expected), 'm.values(int) are {0}, but expected: {1}.'.format(res, expected)
assert (res == expected
), ('m.values(int) are {0}, but expected: {1}.'.format(res, expected))
expected = sorted(['4', 'abcd'])
res = sorted(m.values(str))
assert (res == expected), 'm.values(str) are {0}, but expected: {1}.'.format(res, expected)
assert (res == expected
), ('m.values(str) are {0}, but expected: {1}.'.format(res, expected))
current_values = set([0, '4', 'abcd']) # default (should give all values)
res = set(m.values())
assert (res == current_values), 'm.values() are {0}, but expected: {1}.'.format(res, current_values)
assert (res == current_values
), ('m.values() are {0}, but expected: {1}.'.format(res, current_values))
#test itervalues() (default) - should return all values. (Itervalues for other types
# are tested below)
#test itervalues() (default) - should return all values. (Itervalues for other types are tested below)
vals = set()
for value in m.itervalues():
vals.add(value)
assert (current_values == vals), 'itervalues(): expected {0}, but collected {1}'.format(current_values, vals)
assert (current_values == vals), (
'itervalues(): expected {0}, but collected {1}'.format(current_values, vals)
)
#test items(int)
items_for_int = sorted([((12, 32), '4'), ((23, ), 0)])
assert (items_for_int == sorted(m.items(int))), 'items(int): expected {0}, but collected {1}'.format(items_for_int,
sorted(m.items(int)))
assert (items_for_int == sorted(m.items(int))), (
'items(int): expected {0}, but collected {1}'.format(
items_for_int, sorted(m.items(int))
)
)
# test items(str)
items_for_str = set([(('aa', 'mmm'), '4'), (('something else', ), 'abcd')])
res = set(m.items(str))
assert (set(res) == items_for_str), 'items(str): expected {0}, but collected {1}'.format(items_for_str, res)
assert (set(res) == items_for_str), (
'items(str): expected {0}, but collected {1}'.format(items_for_str, res)
)
# test items() (default - all items)
# we tested keys(), values(), and __get_item__ above so here we'll re-create all_items using that
# we tested keys(), values(), and __get_item__ above so here we'll re-create
# all_items using that
all_items = set()
keys = m.keys()
values = m.values()
for k in keys:
m.values()
for k in keys: # type: ignore
all_items.add((tuple(k), m[k[0]]))
res = set(m.items())
assert (all_items == res), 'items() (all items): expected {0},\n\t\t\t\tbut collected {1}'.format(all_items, res)
assert (all_items == res), (
'items() (all items): expected {0},\n\t\t\t\tbut '
'collected {1}'.format(all_items, res)
)
# now test deletion..
curr_len = len(m)
@ -477,21 +549,21 @@ def test_multi_key_dict():
try:
del m['aa']
assert (False), 'cant remove again: item m[\'aa\'] should not exist!'
except KeyError as err:
except KeyError:
pass
# try to access non-existing
try:
k = m['aa']
assert (False), 'removed item m[\'aa\'] should not exist!'
except KeyError as err:
except KeyError:
pass
# try to access non-existing with a different key
try:
k = m[12]
assert (False), 'removed item m[12] should not exist!'
except KeyError as err:
except KeyError:
pass
# prepare for other tests (also testing creation of new items)
@ -503,7 +575,8 @@ def test_multi_key_dict():
# test items()
for key, value in m.items(int):
assert(key == (value,)), 'items(int): expected {0}, but received {1}'.format(key, value)
assert (key == (value, )
), ('items(int): expected {0}, but received {1}'.format(key, value))
# test iterkeys()
num_of_elements = 0
@ -511,9 +584,10 @@ def test_multi_key_dict():
for key in m.iterkeys(int):
returned_keys.add(key)
num_of_elements += 1
assert(num_of_elements > 0), 'm.iteritems(int) returned generator that did not produce anything'
assert (returned_keys == set(tst_range)), 'iterkeys(int): expected {0}, but received {1}'.format(expected, key)
assert (num_of_elements
> 0), ('m.iteritems(int) returned generator that did not produce anything')
assert (returned_keys == set(tst_range)
), ('iterkeys(int): expected {0}, but received {1}'.format(expected, key))
#test itervalues(int)
num_of_elements = 0
@ -521,22 +595,26 @@ def test_multi_key_dict():
for value in m.itervalues(int):
returned_values.add(value)
num_of_elements += 1
assert (num_of_elements > 0), 'm.itervalues(int) returned generator that did not produce anything'
assert (returned_values == set(tst_range)), 'itervalues(int): expected {0}, but received {1}'.format(expected, value)
assert (num_of_elements > 0
), ('m.itervalues(int) returned generator that did not produce anything')
assert (returned_values == set(tst_range)), (
'itervalues(int): expected {0}, '
'but received {1}'.format(expected, value)
)
# test values(int)
res = sorted([x for x in m.values(int)])
assert (res == tst_range), 'm.values(int) is not as expected.'
# test keys()
assert (set(m.keys(int)) == set(tst_range)), 'm.keys(int) is not as expected.'
assert (set(m.keys(int)) == set(tst_range)), 'm.keys(int) is not as expected.' # type: ignore
# test setitem with multiple keys
m['xy', 999, 'abcd'] = 'teststr'
try:
m['xy', 998] = 'otherstr'
assert (False), 'creating / updating m[\'xy\', 998] should fail!'
except KeyError as err:
except KeyError:
pass
# test setitem with multiple keys
@ -544,7 +622,7 @@ def test_multi_key_dict():
try:
m['cd', 999] = 'otherstr'
assert (False), 'creating / updating m[\'cd\', 999] should fail!'
except KeyError as err:
except KeyError:
pass
m['xy', 999] = 'otherstr'
@ -560,7 +638,7 @@ def test_multi_key_dict():
# test get functionality of basic dictionaries
m['CanIGet'] = 'yes'
assert (m.get('CanIGet') == 'yes')
assert (m.get('ICantGet') == None)
assert (m.get('ICantGet') is None)
assert (m.get('ICantGet', "Ok") == "Ok")
k = multi_key_dict()
@ -570,29 +648,39 @@ def test_multi_key_dict():
import datetime
n = datetime.datetime.now()
l = multi_key_dict()
l[n] = 'now' # use datetime obj as a key
d = multi_key_dict()
d[n] = 'now' # use datetime obj as a key
#test keys..
res = [x for x in l.keys()][0] # for python3 keys() returns dict_keys dictionarly
res = [
x for x in d.keys(
) # type: ignore # for python3 keys() returns dict_keys dictionarly
][0]
expected = n,
assert (expected == res), 'Expected \"{0}\", but got: \"{1}\"'.format(expected, res)
res = [x for x in l.keys(datetime.datetime)][0]
res = [x for x in d.keys(datetime.datetime)][0] # type: ignore
assert (n == res), 'Expected {0} as a key, but got: {1}'.format(n, res)
res = [x for x in l.values()] # for python3 keys() returns dict_values dictionarly
res = [x for x in d.values()] # for python3 keys() returns dict_values dictionarly
expected = ['now']
assert (res == expected), 'Expected values: {0}, but got: {1}'.format(expected, res)
# test items..
exp_items = [((n, ), 'now')]
r = list(l.items())
assert(r == exp_items), 'Expected for items(): tuple of keys: {0}, but got: {1}'.format(r, exp_items)
assert(exp_items[0][1] == 'now'), 'Expected for items(): value: {0}, but got: {1}'.format('now',
exp_items[0][1])
r = list(d.items())
assert (r == exp_items), (
'Expected for items(): tuple of keys: {0}, but got: {1}'.format(r, exp_items)
)
assert (exp_items[0][1] == 'now'), (
'Expected for items(): value: {0}, but got: {1}'.format('now', exp_items[0][1])
)
x = multi_key_dict({('k', 'kilo'):1000, ('M', 'MEGA', 1000000):1000000}, milli=0.01)
x = multi_key_dict(
{
('k', 'kilo'): 1000, ('M', 'MEGA', 1000000): 1000000
}, milli = 0.01
)
assert (x['k'] == 1000), 'x[\'k\'] is not equal to 1000'
x['kilo'] = 'kilo'
assert (x['kilo'] == 'kilo'), 'x[\'kilo\'] is not equal to \'kilo\''
@ -605,8 +693,10 @@ def test_multi_key_dict():
try:
y = multi_key_dict([(('two', 'duo'), 2), ('one', 'uno', 1), ('three', 3)])
assert(False), 'creating dictionary using iterable with tuples of size > 2 should fail!'
except:
assert (False), (
'creating dictionary using iterable with tuples of size > 2 should fail!'
)
except Exception:
pass
print('All test passed OK!')
@ -618,4 +708,3 @@ if __name__ == '__main__':
test_multi_key_dict()
except KeyboardInterrupt:
print('\n(interrupted by user)')

View file

@ -1,14 +1,16 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from typing import Any, List, Optional, Union
import fnmatch
import re
import re, fnmatch
from collections import OrderedDict
from enum import Enum, auto
from typing import TYPE_CHECKING
from ..log import *
if TYPE_CHECKING:
from typing import List, Optional, Union, Any
from ..log import DEBUG, get_caller_pos, slog
def quote(s):
if is_quoted(s):
@ -40,7 +42,9 @@ def cleanup_string(s: str) -> str:
class StringTree: # export
def __init__(self, path: str, content: str, parent: StringTree|None=None) -> None:
def __init__(
self, path: str, content: str, parent: StringTree | None = None
) -> None:
slog(DEBUG, f'Constructing StringTree(path="{path}", content="{content}")')
self.__parent = parent
self.children: OrderedDict[str, StringTree] = OrderedDict()
@ -65,33 +69,59 @@ class StringTree: # export
#parent.dump(INFO, "These children are added")
self.content = parent.content
for name, c in parent.children.items():
if not name in self.children.keys():
if name not in self.children.keys():
slog(DEBUG, f'At {self.content}: Adding new child {c}')
self.children[name] = c
else:
self.children[name].__adopt_children(c)
def __set(self, path_, content, split = True):
slog(DEBUG, ('At "{}": '.format(str(self.content)) if hasattr(self, "content") else "") + f'Setting "{path_}" -> "{content}"')
#assert self.content != str(content) # Not sure what the idea behind this was. It often goes off, and all works fine without.
if content is not None and not type(content) in [str, StringTree]:
raise Exception("Tried to add content of unsupported type {}".format(type(content).__name__))
slog(
DEBUG,
('At "{}": '.format(str(self.content)) if hasattr(self, "content") else "")
+ f'Setting "{path_}" -> "{content}"'
)
# Not sure what the idea behind this was. It often goes off, and all
# works fine without.
#assert self.content != str(content)
if content is not None and type(content) not in [str, StringTree]:
raise Exception(
"Tried to add content of unsupported type {}".format(
type(content).__name__
)
)
if path_ is None:
if isinstance(content, str):
self.content = cleanup_string(content)
elif isinstance(content, StringTree):
self.__adopt_children(content)
else:
raise Exception("Tried to add content of unsupported type {}".format(type(content).__name__))
slog(DEBUG, " -- content = >" + str(content) + "<, self.content = >" + str(self.content) + "<")
raise Exception(
"Tried to add content of unsupported type {}".format(
type(content).__name__
)
)
slog(
DEBUG,
" -- content = >" + str(content) + "<, self.content = >" +
str(self.content) + "<"
)
return self
path = cleanup_string(path_)
components = path.split('.') if split else [path]
l = len(components)
if len(path) == 0 or l == 0:
#assert self.content is None or (isinstance(content, StringTree) and content.content == self.content)
length = len(components)
if len(path) == 0 or length == 0:
#assert self.content is None or (
# isinstance(content, StringTree) and content.content == self.content
#)
if isinstance(content, StringTree):
#assert isinstance(content, StringTree), "Type: " + type(content).__name__
#assert isinstance(content, StringTree), (
# f'Type: {type(content).__name__ }'
#)
self.__adopt_children(content)
else:
if self.content != content:
@ -103,22 +133,25 @@ class StringTree: # export
#self.children[content] = StringTree(None, content)
return self
#assert self.content is not None, "tried to set empty content to {}".format(path_)
#assert self.content is not None, f'Tried to set empty content to "{path_}"'
nibble = components[0]
rest = '.'.join(components[1:])
if nibble not in self.children:
self.children[nibble] = StringTree('', content = nibble, parent = self)
if l > 1:
if length > 1:
assert len(rest) > 0
return self.children[nibble].__set(rest, content = content)
# last component, a.k.a. leaf
if content is not None:
gc = content if isinstance(content, StringTree) else StringTree('', content=content, parent=self.children[nibble])
gc = content if isinstance(content, StringTree) else StringTree(
'', content = content, parent = self.children[nibble]
)
# Make sure no existing grand child is updated. It would reside too
# far up in the grand child OrderedDict, we need it last
if gc.content in self.children[nibble].children:
del self.children[nibble].children[gc.content]
assert gc.content is not None, 'Grand-child content is None'
self.children[nibble].children[gc.content] = gc
return self.children[nibble]
@ -164,7 +197,12 @@ class StringTree: # export
raise Exception("Tried to set empty content")
self.content = content
def add(self, path: str, content: Optional[Union[str, StringTree]] = None, split: bool = True) -> StringTree:
def add(
self,
path: str,
content: Optional[Union[str, StringTree]] = None,
split: bool = True
) -> StringTree:
slog(DEBUG, f'-- At "{self.content}": Adding "{path}" -> "{content}"')
return self.__set(path, content, split)
@ -176,7 +214,7 @@ class StringTree: # export
slog(DEBUG, "returning myself")
return self
if is_quoted(path_):
if not path in self.children.keys():
if path not in self.children.keys():
return None
return self.children[path]
components = path.split('.')
@ -185,7 +223,7 @@ class StringTree: # export
name = cleanup_string(components[0])
if not hasattr(self, "children"):
return None
if not name in self.children.keys():
if name not in self.children.keys():
slog(DEBUG, "Name \"" + name + "\" is not in children of", self.content)
for child in self.children:
slog(DEBUG, "child = ", child)
@ -217,8 +255,11 @@ class StringTree: # export
return self.__parent.root
def child_list(self, depth_first: bool = True) -> List[StringTree]:
if depth_first == False:
raise Exception("tried to retrieve child list with breadth-first search, not yet implemented")
if not depth_first:
raise Exception(
'Tried to retrieve child list with breadth-first '
'search, not yet implemented'
)
r = []
for name, c in self.children.items():
r.append(c)
@ -241,21 +282,19 @@ class StringTree: # export
GlobArg = auto()
GlobConf = auto()
def __find(self, key: str|None, val: str|None, match: Match, depth_first: bool):
def __find(self, key: str | None, val: str | None, m: Match, depth_first: bool):
def __children():
for name, child in self.children.items():
ret.extend(child.__find(key, val, match, depth_first))
ret.extend(child.__find(key, val, m, depth_first))
def __self():
_val = self.value()
_content = self.content
try:
if (
(key == _content and matcher(val, _val))
if ((key == _content and matcher(val, _val))
or (key is None and matcher(val, _val))
or (key == _content and val is None)
):
or (key == _content and val is None)):
ret.append(self)
except Exception as e:
if isinstance(e, re.PatternError):
@ -263,29 +302,33 @@ class StringTree: # export
else:
raise
def __select_matcher(m: StringTree.Match) -> Any:
match m:
case self.Match.Equal:
return lambda x, y: x == y
case self.Match.RegExArg:
return lambda x, y: re.search(x, y) is not None
case self.Match.RegExConf:
return lambda x, y: re.search(y, x) is not None
case self.Match.GlobArg:
return lambda x, y: fnmatch.fnmatch(y, x)
case self.Match.GlobConf:
return lambda x, y: fnmatch.fnmatch(x, y)
case _:
raise NotImplementedError(f'Matcher {m} is not yet implemented')
def __debug_matcher(matcher, log_level = DEBUG):
def __matcher(x, y):
slog(log_level, f'Comparing "{x}" ~ "{y}"')
return matcher(x, y)
return __matcher
if not self.children:
return []
matcher = lambda x, y: x == y
match match:
case self.Match.Equal:
pass
case self.Match.RegExArg:
matcher = lambda x, y: re.search(x, y) is not None
case self.Match.RegExConf:
matcher = lambda x, y: re.search(y, x) is not None
case self.Match.GlobArg:
matcher = lambda x, y: fnmatch.fnmatch(y, x)
case self.Match.GlobConf:
matcher = lambda x, y: fnmatch.fnmatch(x, y)
case _:
raise NotImplementedError(f'Matcher {match} is not yet implemented')
matcher = __select_matcher(m)
ret: list[StringTree] = []
@ -298,5 +341,16 @@ class StringTree: # export
return ret
def find(self, key: str|None=None, val: str|None=None, match: Match=Match.Equal, depth_first: bool=False):
return [ node.parent.path for node in self.__find(key, val, match=match, depth_first=depth_first)]
def find(
self,
key: str | None = None,
val: str | None = None,
match: Match = Match.Equal,
depth_first: bool = False
):
ret: list[str] = []
for node in self.__find(key, val, m = match, depth_first = depth_first):
if node.parent is None:
break
ret.append(node.parent.path)
return ret

View file

@ -1,9 +1,9 @@
# -*- coding: utf-8 -*-
import glob
import os
import re
import os, glob
from .StringTree import *
from ..log import *
from ..log import DEBUG, ERR, INFO, slog, slog_m
from .StringTree import StringTree, cleanup_string
def _cleanup_line(line: str) -> str:
line = line.strip()
@ -24,7 +24,11 @@ def _cleanup_line(line: str) -> str:
return r[1:-1]
return r
def parse(s: str, allow_full_lines: bool=True, root_content: str='root') -> StringTree: # export
def parse( # export
s: str,
allow_full_lines: bool = True,
root_content: str = 'root'
) -> StringTree:
slog_m(DEBUG, "--->--- parsing --->---\n" + s + "\n---<--- parsing ---<---\n")
root = StringTree('', content = root_content)
sec = ''
@ -70,7 +74,9 @@ def parse(s: str, allow_full_lines: bool=True, root_content: str='root') -> Stri
root.add(sec + '.' + cleanup_string(lhs), cleanup_string(rhs), split = split)
return root
def _read_lines_from_one_path(path: str, throw=True, level=0, log_prio=INFO, paths_buf=None):
def _read_lines_from_one_path(
path: str, throw = True, level = 0, log_prio = INFO, paths_buf = None
):
try:
with open(path, 'r') as infile:
slog(log_prio, 'Reading {}"{}"'.format(' ' * level * 2, path))
@ -86,7 +92,12 @@ def _read_lines_from_one_path(path: str, throw=True, level=0, log_prio=INFO, pat
dir_name = os.path.dirname(path)
if len(dir_name):
include_path = dir_name + '/' + include_path
include_lines = _read_lines(include_path, throw=(not optional), level=level+1, paths_buf=paths_buf)
include_lines = _read_lines(
include_path,
throw = (not optional),
level = level + 1,
paths_buf = paths_buf
)
if include_lines is None:
slog(DEBUG, f'{path}: Failed to process "{line}"')
continue
@ -104,13 +115,22 @@ def _read_lines(path: str, throw=True, level=0, log_prio=INFO, paths_buf=None):
paths = glob.glob(path)
ret = []
for p in paths:
rr = _read_lines_from_one_path(p, throw=throw, level=level, log_prio=log_prio, paths_buf=paths_buf)
rr = _read_lines_from_one_path(
p, throw = throw, level = level, log_prio = log_prio, paths_buf = paths_buf
)
if rr is None:
return None
ret.extend(rr)
return ret
def read(path: str, root_content: str='root', log_prio=INFO, paths_buf=None) -> StringTree: # export
def read( # export
path: str,
root_content: str = 'root',
log_prio = INFO,
paths_buf = None
) -> StringTree:
lines = _read_lines_from_one_path(path, log_prio = log_prio, paths_buf = paths_buf)
if lines is None:
raise Exception(f'Could not read ini file from "{path}"')
s = ''.join(lines)
return parse(s, root_content = root_content)