Makefile: Include py-topdir.mk

Include py-topdir.mk, which entails loads of fallout from make check. Fix it.

Signed-off-by: Jan Lindemann <jan@janware.com>
This commit is contained in:
Jan Lindemann 2026-06-10 13:28:35 +02:00
commit e9845b5a1f
Signed by: Jan Lindemann
GPG key ID: 3750640C9E25DD61
45 changed files with 1796 additions and 1191 deletions

View file

@ -2,3 +2,4 @@ TOPDIR = .
include $(TOPDIR)/make/proj.mk include $(TOPDIR)/make/proj.mk
include $(JWBDIR)/make/topdir.mk include $(JWBDIR)/make/topdir.mk
include $(JWBDIR)/make/py-topdir.mk

View file

@ -1,20 +1,17 @@
# -*- coding: utf-8 -*-
import argparse import argparse
from collections import OrderedDict from collections import OrderedDict
from .log import slog from .log import get_caller_pos, slog
class ArgsContainer: # export class ArgsContainer: # export
__args: OrderedDict[str, str] = OrderedDict() __args: OrderedDict[str, str] = OrderedDict()
__kwargs: OrderedDict[str, str] = OrderedDict() __kwargs: OrderedDict[str, dict[str, str]] = OrderedDict()
__values: dict[str, str] = {} __values: dict[str, str] = {}
__specified_args: list[str] = list() __specified_args: list[str] = list()
def __getattr__(self, name): def __getattr__(self, name: str) -> str:
values = self.__values
if name in self.__values: if name in self.__values:
return self.__values[name] return self.__values[name]
if name in self.__kwargs.keys(): if name in self.__kwargs.keys():
@ -25,7 +22,7 @@ class ArgsContainer: # export
raise Exception(f'No argument "{name}" defined') raise Exception(f'No argument "{name}" defined')
def __setattr__(self, name, value): def __setattr__(self, name, value):
if not name in self.__kwargs.keys(): if name not in self.__kwargs.keys():
raise Exception(f'No argument "{name}" defined') raise Exception(f'No argument "{name}" defined')
self.__values[name] = value self.__values[name] = value
self.__specified_args.append(name) self.__specified_args.append(name)
@ -41,7 +38,7 @@ class ArgsContainer: # export
else: else:
raise Exception('Missing argument name') raise Exception('Missing argument name')
name = name.replace('-', '_') name = name.replace('-', '_')
self.__args[name] = args self.__args[name] = arg
self.__kwargs[name] = kwargs self.__kwargs[name] = kwargs
def keys(self): def keys(self):
@ -50,7 +47,7 @@ class ArgsContainer: # export
def args(self, name) -> str: def args(self, name) -> str:
return self.__args[name] return self.__args[name]
def kwargs(self, name) -> str: def kwargs(self, name) -> dict[str, str]:
return self.__kwargs[name] return self.__kwargs[name]
def dump(self, prio, *args, **kwargs): def dump(self, prio, *args, **kwargs):
@ -59,15 +56,17 @@ class ArgsContainer: # export
val = None val = None
try: try:
val = self.__getattr__(name) val = self.__getattr__(name)
except: except Exception:
pass pass
slog(prio, f'{name}: {val}', caller=caller) slog(prio, f'{name}: {val}', caller = caller)
@property @property
def specified_args(self): def specified_args(self):
return self.__specified_args return self.__specified_args
def add_argument(p: argparse.ArgumentParser|ArgsContainer, name: str, *args, **kwargs): # export def add_argument( # export
p: argparse.ArgumentParser | ArgsContainer, name: str, *args, **kwargs
):
key = name.strip('--').replace('-', '_') key = name.strip('--').replace('-', '_')
if isinstance(p, ArgsContainer): if isinstance(p, ArgsContainer):

View file

@ -1,8 +1,6 @@
# -*- coding: utf-8 -*-
from typing import Any from typing import Any
class Bunch: # export class Bunch: # export
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.__dict__.update(kwargs) self.__dict__.update(kwargs)

View file

@ -1,39 +1,49 @@
# -*- coding: utf-8 -*-
from __future__ import annotations from __future__ import annotations
import inspect, sys, re, abc, argparse
from argparse import ArgumentParser, _SubParsersAction import abc
import argparse
import inspect
import re
import sys
from argparse import ArgumentParser
from typing import TYPE_CHECKING
from . import log from . import log
if TYPE_CHECKING:
from .Cmds import Cmds
# full blown example of one level of nested subcommands # full blown example of one level of nested subcommands
# git -C project remote -v show -n myremote # git -C project remote -v show -n myremote
class Cmd(abc.ABC): # export class Cmd(abc.ABC): # export
@abc.abstractmethod @abc.abstractmethod
async def run(self, args): async def run(self, args):
pass pass
def __init__(self, name: str, help: str) -> None: def __init__(self, name: str, help: str) -> None:
from . import Cmds
self.name = name self.name = name
self.help = help self.help = help
self.parent = None self.parent = None
self.children: list[Cmd] = [] self.children: list[Cmd] = []
self.child_classes: list[type[Cmd]] = [] self.child_classes: list[type[Cmd]] = []
self.app: Cmds|None = None self.app: Cmds | None = None
async def _run(self, args): async def _run(self, args):
pass pass
def add_parser(self, parsers) -> ArgumentParser: def add_parser(self, parsers) -> ArgumentParser:
r = parsers.add_parser(self.name, help=self.help, r = parsers.add_parser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter) self.name,
r.set_defaults(func=self.run) help = self.help,
formatter_class = argparse.ArgumentDefaultsHelpFormatter
)
r.set_defaults(func = self.run)
return r return r
def add_subcommands(self, cmd: str|type[Cmd]|list[type[Cmd]]) -> None: def add_subcommands(self, cmd: str | type[Cmd] | list[type[Cmd]]) -> None:
if isinstance(cmd, str): if isinstance(cmd, str):
sc = [] sc = []
for name, obj in inspect.getmembers(sys.modules[self.__class__.__module__]): for name, obj in inspect.getmembers(sys.modules[self.__class__.__module__]):
@ -54,7 +64,7 @@ class Cmd(abc.ABC): # export
def add_arguments(self, parser: ArgumentParser) -> None: def add_arguments(self, parser: ArgumentParser) -> None:
pass pass
def conf_value(self, path, default=None): def conf_value(self, path, default = None):
ret = None if self.app is None else self.app.conf_value(path, default) ret = None if self.app is None else self.app.conf_value(path, default)
if ret is None and default is not None: if ret is None and default is not None:
return default return default

View file

@ -1,13 +1,21 @@
# -*- coding: utf-8 -*- import argparse
import asyncio
import cProfile
import importlib
import inspect
import os
import re
import sys
import os, sys, argcomplete, argparse, importlib, inspect, re, pickle, asyncio, cProfile
from argparse import ArgumentParser from argparse import ArgumentParser
from pathlib import Path, PurePath from pathlib import Path, PurePath
from .log import * import argcomplete
from .log import DEBUG, ERR, NOTICE, add_log_file, set_flags, set_level, slog
from .stree import serdes from .stree import serdes
class Cmds: # export class Cmds: # export
def __instantiate(self, cls): def __instantiate(self, cls):
try: try:
@ -15,7 +23,7 @@ class Cmds: # export
except Exception as e: except Exception as e:
slog(ERR, f'Failed to instantiate command of type {cls}: {e}') slog(ERR, f'Failed to instantiate command of type {cls}: {e}')
raise raise
r.cmds = self # TODO: Rename Cmds class to App, "Cmds" isn't very self-explanatory r.cmds = self # TODO: Rename Cmds class to App, "Cmds" isn't self-explanatory
r.app = self r.app = self
return r return r
@ -26,7 +34,9 @@ class Cmds: # export
for c in cmd.child_classes: for c in cmd.child_classes:
cmd.children.append(self.__instantiate(c)) cmd.children.append(self.__instantiate(c))
if len(cmd.children) > 0: if len(cmd.children) > 0:
subparsers = parser.add_subparsers(title='Available subcommands of ' + cmd.name, metavar='') subparsers = parser.add_subparsers(
title = 'Available subcommands of ' + cmd.name, metavar = ''
)
for sub_cmd in cmd.children: for sub_cmd in cmd.children:
self.__add_cmd_to_parser(sub_cmd, subparsers) self.__add_cmd_to_parser(sub_cmd, subparsers)
@ -38,7 +48,13 @@ class Cmds: # export
slog(DEBUG, 'Reading configuration "{}"'.format(path)) slog(DEBUG, 'Reading configuration "{}"'.format(path))
return serdes.read(path, ''), [path] return serdes.read(path, ''), [path]
def __init__(self, description: str = '', filter: str = '^Cmd.*', modules: None=None, eloop: None=None) -> None: def __init__(
self,
description: str = '',
filter: str = '^Cmd.*',
modules: None = None,
eloop: None = None
) -> None:
self.__description = description self.__description = description
self.__filter = filter self.__filter = filter
self.__modules = modules self.__modules = modules
@ -68,18 +84,34 @@ class Cmds: # export
set_flags(log_flags) set_flags(log_flags)
set_level(log_level) set_level(log_level)
slog(DEBUG, "set log level to {}".format(log_level)) slog(DEBUG, "set log level to {}".format(log_level))
self.__parser = argparse.ArgumentParser(usage=os.path.basename(sys.argv[0]) + ' [options]', self.__parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter, description=self.__description) usage = os.path.basename(sys.argv[0]) + ' [options]',
self.__parser.add_argument('--log-flags', help='Log flags', default=log_flags) formatter_class = argparse.ArgumentDefaultsHelpFormatter,
self.__parser.add_argument('--log-level', help='Log level', default=log_level) description = self.__description
self.__parser.add_argument('--backtrace', help='Show exception backtraces', action='store_true', default=False) )
self.__parser.add_argument('--write-profile', help='Profile code and store output to file', default=None) self.__parser.add_argument(
self.__parser.add_argument('--log-file', help='Log file', default=log_file) '--log-flags', help = 'Log flags', default = log_flags
if self.__modules == None: )
self.__modules = [ '__main__' ] self.__parser.add_argument(
'--log-level', help = 'Log level', default = log_level
)
self.__parser.add_argument(
'--backtrace',
help = 'Show exception backtraces',
action = 'store_true',
default = False
)
self.__parser.add_argument(
'--write-profile',
help = 'Profile code and store output to file',
default = None
)
self.__parser.add_argument('--log-file', help = 'Log file', default = log_file)
if self.__modules is None:
self.__modules = ['__main__']
subcmds = set() subcmds = set()
slog(DEBUG, '-- searching for commands') slog(DEBUG, '-- searching for commands')
for m in self.__modules: # type: ignore for m in self.__modules: # type: ignore
if m != '__main__': if m != '__main__':
importlib.import_module(m) importlib.import_module(m)
for name, c in inspect.getmembers(sys.modules[m], inspect.isclass): for name, c in inspect.getmembers(sys.modules[m], inspect.isclass):
@ -96,24 +128,27 @@ class Cmds: # export
subcmds.update(cmd.child_classes) subcmds.update(cmd.child_classes)
cmds = [cmd for cmd in self.__cmds if type(cmd) not in subcmds] cmds = [cmd for cmd in self.__cmds if type(cmd) not in subcmds]
subparsers = self.__parser.add_subparsers(title='Available commands', metavar='') subparsers = self.__parser.add_subparsers(
title = 'Available commands', metavar = ''
)
for cmd in cmds: for cmd in cmds:
slog(DEBUG, f'Adding top-level command {cmd} to parser') slog(DEBUG, f'Adding top-level command {cmd} to parser')
self.__add_cmd_to_parser(cmd, subparsers) self.__add_cmd_to_parser(cmd, subparsers)
# Run all sub-commands. Overwrite if you want to do anything before or after # Run all sub-commands. Overwrite if you want to do anything before or after
async def _run(self, argv=None): async def _run(self, argv = None):
return await self.args.func(self.args) return await self.args.func(self.args)
async def __run(self, argv=None): async def __run(self, argv = None):
argcomplete.autocomplete(self.__parser) argcomplete.autocomplete(self.__parser)
self.args = self.__parser.parse_args(args=argv) self.args = self.__parser.parse_args(args = argv)
set_flags(self.args.log_flags) set_flags(self.args.log_flags)
set_level(self.args.log_level) set_level(self.args.log_level)
self.__back_trace = self.args.backtrace self.__back_trace = self.args.backtrace
exit_status = 0 exit_status = 0
# This is the toplevel parser, i.e. no func member has been added to the args via # This is the toplevel parser, i.e. no func member has been added to the args
# yet, via
# #
# Cmds.__init__() # Cmds.__init__()
# Cmds.__add_cmd_to_parser(cmd, subparsers) # Cmds.__add_cmd_to_parser(cmd, subparsers)
@ -135,17 +170,15 @@ class Cmds: # export
add_log_file(self.args.log_file) add_log_file(self.args.log_file)
try: try:
ret = await self._run(self.args) await self._run(self.args)
except Exception as e: except Exception as e:
if hasattr(e, 'message'): slog(ERR, f'Exception: {type(e)}: {str(e)}')
slog(ERR, e.message)
else:
slog(ERR, f'Exception: {type(e)}: {e}')
exit_status = 1 exit_status = 1
if self.__back_trace: if self.__back_trace:
raise raise
finally: finally:
if pr is not None: if pr is not None:
assert self.args.write_profile is not None, 'args.write_profile'
pr.disable() pr.disable()
slog(NOTICE, f'Writing profile statistics to {self.args.write_profile}') slog(NOTICE, f'Writing profile statistics to {self.args.write_profile}')
pr.dump_stats(self.args.write_profile) pr.dump_stats(self.args.write_profile)
@ -160,7 +193,7 @@ class Cmds: # export
self.eloop = None self.eloop = None
self.__own_eloop = False self.__own_eloop = False
def conf_value(self, path, default=None): def conf_value(self, path, default = None):
ret = None if self.__conf is None else self.__conf.value(path) ret = None if self.__conf is None else self.__conf.value(path)
if ret is None and default is not None: if ret is None and default is not None:
return default return default
@ -169,10 +202,12 @@ class Cmds: # export
def parser(self) -> ArgumentParser: def parser(self) -> ArgumentParser:
return self.__parser return self.__parser
def run(self, argv=None) -> None: def run(self, argv = None) -> None:
#return self.__run() #return self.__run()
return self.eloop.run_until_complete(self.__run(argv)) # type: ignore return self.eloop.run_until_complete(self.__run(argv)) # type: ignore
def run_sub_commands(description = '', filter = '^Cmd.*', modules=None, argv=None): # export def run_sub_commands( # export
description = '', filter = '^Cmd.*', modules = None, argv = None
):
cmds = Cmds(description, filter, modules) cmds = Cmds(description, filter, modules)
return cmds.run(argv=argv) return cmds.run(argv = argv)

View file

@ -1,14 +1,16 @@
# -*- coding: utf-8 -*- import glob
import os
import re
import sys
from typing import Optional, Dict, cast from pathlib import Path
import os, re, glob, sys from typing import Dict, Optional, cast
from pathlib import Path, PosixPath
from . import stree from .stree import serdes
from .log import DEBUG, ERR, slog, get_caller_pos
from .stree.StringTree import StringTree from .stree.StringTree import StringTree
from .log import *
class Config(): # export class Config(): # export
def __load(self, search_dirs, glob_paths, refuse_mode_mask): def __load(self, search_dirs, glob_paths, refuse_mode_mask):
@ -33,15 +35,15 @@ class Config(): # export
for path in glob_paths: for path in glob_paths:
dirs = search_dirs dirs = search_dirs
if dirs is None: if dirs is None:
dirs = [''] if __is_abs(path) else [ str(Path.home()), str(Path.cwd()) ] dirs = [''] if __is_abs(path) else [str(Path.home()), str(Path.cwd())]
for d in dirs: for d in dirs:
g = d + '/' + path if len(d) else path g = d + '/' + path if len(d) else path
slog(DEBUG, 'Looking for config "{}"'.format(g)) slog(DEBUG, 'Looking for config "{}"'.format(g))
for f in glob.glob(g): for f in glob.glob(g):
slog(DEBUG, 'Reading config "{}"'.format(f)) slog(DEBUG, 'Reading config "{}"'.format(f))
paths_buf = [] paths_buf = []
tree = stree.read(f, paths_buf=paths_buf) tree = serdes.read(f, paths_buf = paths_buf)
assert(len(paths_buf)) assert (len(paths_buf))
if refuse_mode_mask is not None: if refuse_mode_mask is not None:
for p in paths_buf: for p in paths_buf:
st = os.stat(p) st = os.stat(p)
@ -49,37 +51,44 @@ class Config(): # export
for item in tree.child_list(): for item in tree.child_list():
if item.content is None: if item.content is None:
continue continue
if not re.search('password|secret', cast(str, item.content), flags=re.IGNORECASE): if not re.search('password|secret',
cast('str', item.content),
flags = re.IGNORECASE):
continue continue
msg = "Config files define secret, but at least one has file permissions open for world" msg = (
'Config files define secret, but at least one '
'has file permissions open for world'
)
slog(ERR, f'{msg}:') slog(ERR, f'{msg}:')
for pp in paths_buf: for pp in paths_buf:
slog(ERR, f' {((os.stat(pp).st_mode) & 0o7777):o} {pp}') mode = (os.stat(pp).st_mode) & 0o7777
slog(ERR, f' {mode:o} {pp}')
raise Exception(msg) raise Exception(msg)
tree.dump(DEBUG, f) tree.dump(DEBUG, f)
ret.add("", tree) ret.add("", tree)
return ret return ret
def __init__(self, def __init__(
search_dirs: Optional[list[str]]=None, self,
glob_paths: Optional[list[str]]=None, search_dirs: Optional[list[str]] = None,
glob_paths_env_key: Optional[str]=None, glob_paths: Optional[list[str]] = None,
defaults: Optional[Dict[str, str]]=None, glob_paths_env_key: Optional[str] = None,
tree: Optional[StringTree]=None, defaults: Optional[Dict[str, str]] = None,
parent=None, tree: Optional[StringTree] = None,
root_section=None, parent = None,
refuse_mode_mask=0o0027 root_section = None,
) -> None: refuse_mode_mask = 0o0027
) -> None:
self.__parent = parent self.__parent = parent
if tree is not None: if tree is not None:
assert(search_dirs is None) assert (search_dirs is None)
assert(glob_paths is None) assert (glob_paths is None)
assert(glob_paths_env_key is None) assert (glob_paths_env_key is None)
self.__conf = tree self.__conf = tree
else: else:
assert(tree is None) assert (tree is None)
if glob_paths_env_key is not None: if glob_paths_env_key is not None:
glob_paths_env = os.getenv(glob_paths_env_key) glob_paths_env = os.getenv(glob_paths_env_key)
if glob_paths_env is not None: if glob_paths_env is not None:
@ -87,8 +96,11 @@ class Config(): # export
glob_paths = [] glob_paths = []
glob_paths.extend(glob_paths_env.split(':')) glob_paths.extend(glob_paths_env.split(':'))
self.__conf = self.__load(search_dirs=search_dirs, glob_paths=glob_paths, self.__conf = self.__load(
refuse_mode_mask=refuse_mode_mask) search_dirs = search_dirs,
glob_paths = glob_paths,
refuse_mode_mask = refuse_mode_mask
)
if root_section is not None: if root_section is not None:
tmp = self.__conf.get(root_section) tmp = self.__conf.get(root_section)
@ -141,7 +153,11 @@ class Config(): # export
def value(self, key: str, default = None) -> Optional[str]: def value(self, key: str, default = None) -> Optional[str]:
return self.get(key, default) return self.get(key, default)
def branch(self, path: str, throw: bool=True): # type: ignore # Optional[Config]: FIXME: Don't know how to get hold of this type here def branch(
self,
path: str,
throw: bool = True
): # type: ignore # Optional[Config]: FIXME: Don't know how to get hold of this type here
if self.__conf: if self.__conf:
tree = self.__conf.get(path) tree = self.__conf.get(path)
if tree is None: if tree is None:
@ -151,19 +167,24 @@ class Config(): # export
return None return None
self.dump(ERR, msg) self.dump(ERR, msg)
raise Exception(msg) raise Exception(msg)
return Config(tree=tree, parent=self) # type: ignore return Config(tree = tree, parent = self) # type: ignore
return None return None
def dump(self, prio: int, *args, **kwargs) -> None: def dump(self, prio: int, *args, **kwargs) -> None:
caller = get_caller_pos(1, kwargs) caller = get_caller_pos(1, kwargs)
self.__conf.dump(prio, caller=caller, *args, **kwargs) self.__conf.dump(prio, caller = caller, *args, **kwargs)
@property @property
def name(self): def name(self):
return self.__conf.content return self.__conf.content
def find(self, key: str|None, val: str|None, match:StringTree.Match=StringTree.Match.Equal) -> list[str]: def find(
return self.__conf.find(key, val, match=match) self,
key: str | None,
val: str | None,
match: StringTree.Match = StringTree.Match.Equal
) -> list[str]:
return self.__conf.find(key, val, match = match)
#def __getattr__(self, name: str): #def __getattr__(self, name: str):
# return getattr(self.__conf, name) # return getattr(self.__conf, name)

View file

@ -1,6 +1,4 @@
# -*- coding: utf-8 -*- class CppState: # export
class CppState: # export
def __init__(self): def __init__(self):
self.__pair_square = ['[', ']'] self.__pair_square = ['[', ']']
@ -33,37 +31,39 @@ class CppState: # export
self.things.append(self.__pair_square) self.things.append(self.__pair_square)
elif tok == ']': elif tok == ']':
self.square -= 1 self.square -= 1
assert(self.things.pop() == self.__pair_square) assert (self.things.pop() == self.__pair_square)
elif tok == '{': elif tok == '{':
self.curly += 1 self.curly += 1
self.things.append(self.__pair_curly) self.things.append(self.__pair_curly)
elif tok == '}': elif tok == '}':
self.curly -= 1 self.curly -= 1
assert(self.things.pop() == self.__pair_curly) assert (self.things.pop() == self.__pair_curly)
elif tok == '(': elif tok == '(':
self.paren += 1 self.paren += 1
self.things.append(self.__pair_paren) self.things.append(self.__pair_paren)
elif tok == ')': elif tok == ')':
self.paren -= 1 self.paren -= 1
assert(self.things.pop() == self.__pair_paren) assert (self.things.pop() == self.__pair_paren)
elif tok == '<': elif tok == '<':
self.ext += 1 self.ext += 1
self.things.append(self.__pair_ext) self.things.append(self.__pair_ext)
elif tok == '>': elif tok == '>':
self.ext -= 1 self.ext -= 1
assert(self.things.pop() == self.__pair_ext) assert (self.things.pop() == self.__pair_ext)
elif tok == '?': elif tok == '?':
if not self.in_special: if not self.in_special:
self.in_special = True self.in_special = True
self.things.append(self.__pair_special) self.things.append(self.__pair_special)
else: else:
self.in_special = False self.in_special = False
assert(self.things.pop() == self.__pair_special) assert (self.things.pop() == self.__pair_special)
elif tok == '/*': elif tok == '/*':
self.in_c_comment = True self.in_c_comment = True
self.things.append(self.__pair_c_comment) self.things.append(self.__pair_c_comment)
elif tok == '*/': elif tok == '*/':
raise Exception("Unmatched closing C-style comment mark", tok, "in line", line) raise Exception(
"Unmatched closing C-style comment mark", tok, "in line", line
)
else: else:
if self.in_cpp_comment: if self.in_cpp_comment:
if tok == '\n': if tok == '\n':
@ -72,7 +72,7 @@ class CppState: # export
if tok == '/*': if tok == '/*':
raise Exception("Nested C-style comment", tok, "in line", line) raise Exception("Nested C-style comment", tok, "in line", line)
elif tok == '*/': elif tok == '*/':
assert(self.things.pop() == self.__pair_c_comment) assert (self.things.pop() == self.__pair_c_comment)
self.in_c_comment = False self.in_c_comment = False
if self.curly < 0 or self.square < 0 or self.ext < 0 or self.paren < 0: if self.curly < 0 or self.square < 0 or self.ext < 0 or self.paren < 0:
@ -101,4 +101,3 @@ class CppState: # export
def is_optional(self): def is_optional(self):
return self.in_list() or self.in_option() return self.in_list() or self.in_option()

View file

@ -1,6 +1,4 @@
TOPDIR = ../../../.. TOPDIR = ../../../..
PY_UPDATE_INIT_PY ?= false
include $(TOPDIR)/make/proj.mk include $(TOPDIR)/make/proj.mk
include $(JWBDIR)/make/py-mod.mk include $(JWBDIR)/make/py-mod.mk

View file

@ -1,16 +1,14 @@
# -*- coding: utf-8 -*-
from __future__ import print_function from __future__ import print_function
from . import log from . import log
class Object(object): # export class Object(object): # export
def __init__(self): def __init__(self):
self.log_level = log.level self.log_level = log.log_level()
def log(self, prio, *args): def log(self, prio, *args):
if self.log_level == log.level: if self.log_level == log.log_level():
log.slog(prio, args) log.slog(prio, args)
return return
if prio <= self.log_level: if prio <= self.log_level:

View file

@ -1,11 +1,13 @@
import re
import json import json
from collections import OrderedDict import re
from .log import *
import shlex import shlex
import traceback import traceback
class Options: # export from collections import OrderedDict
from .log import ERR, get_caller_pos, slog, slog_m
class Options: # export
class OrderedData: class OrderedData:
@ -30,8 +32,8 @@ class Options: # export
if spec[0] != '{': if spec[0] != '{':
spec = '{' + spec + '}' spec = '{' + spec + '}'
try: try:
return json.loads(spec, object_pairs_hook=cls) return json.loads(spec, object_pairs_hook = cls)
except: except Exception:
pass pass
return None return None
@ -42,7 +44,7 @@ class Options: # export
r = cls() r = cls()
try: try:
opt_strs = shlex.split(opts_str) opt_strs = shlex.split(opts_str)
except Exception as e: except Exception:
slog_m(ERR, traceback.format_exc()) slog_m(ERR, traceback.format_exc())
slog(ERR, 'Failed to split options string >{}<'.format(opts_str)) slog(ERR, 'Failed to split options string >{}<'.format(opts_str))
raise raise
@ -52,7 +54,7 @@ class Options: # export
lhs = sides[0].strip() lhs = sides[0].strip()
if not len(lhs): if not len(lhs):
continue continue
if self.__allowed_keys and not lhs in self.__allowed_keys: if self.__allowed_keys and lhs not in self.__allowed_keys:
raise Exception('Field "{}" not supported'.format(lhs)) raise Exception('Field "{}" not supported'.format(lhs))
rhs = ' '.join(sides[1:]).strip() if len(sides) > 1 else self.__true_val rhs = ' '.join(sides[1:]).strip() if len(sides) > 1 else self.__true_val
if cls == OrderedDict: if cls == OrderedDict:
@ -82,7 +84,7 @@ class Options: # export
self.__str = self.__str__() self.__str = self.__str__()
def __getitem__(self, key): def __getitem__(self, key):
if not key in self.__dict.keys(): if key not in self.__dict.keys():
return None return None
return self.__dict[key] return self.__dict[key]
@ -99,35 +101,38 @@ class Options: # export
return len(self.__data.pairs) return len(self.__data.pairs)
def __contains__(self, keys): def __contains__(self, keys):
if not type(keys) in [list, set]: if type(keys) not in [list, set]:
return keys in self.__dict.keys() return keys in self.__dict.keys()
for key in keys: for key in keys:
if not key in self.__dict.keys(): if key not in self.__dict.keys():
return False return False
return True return True
def __iter__(self): def __iter__(self):
return iter(self.__list) return iter(self.__list)
def __next__(self): #def __next__(self):
return next(self.__list) # return next(self.__list)
def __init__(self, spec=None, delimiter=',', allowed_keys=None, true_val=True): def __init__(
self, spec = None, delimiter = ',', allowed_keys = None, true_val = True
):
self.__true_val = true_val self.__true_val = true_val
self.__allowed_keys = None self.__allowed_keys = None
self.__delimiter = delimiter self.__delimiter = delimiter
self.__data = self.OrderedData() if spec is None else self.__parse(spec, self.OrderedData) self.__data = self.OrderedData(
) if spec is None else self.__parse(spec, self.OrderedData)
self.__dict = {} self.__dict = {}
#self.__dict = OrderedDict() if spec is None else self.__parse(spec, OrderedDict) #self.__dict = OrderedDict() if spec is None else self.__parse(spec,OrderedDict)
self.__list = [] self.__list = []
self.__str = None self.__str = None
self.__recache() self.__recache()
def dump(self, prio, caller=None): def dump(self, prio, caller = None):
if caller is None: if caller is None:
caller = get_caller_pos() caller = get_caller_pos()
for key, val in self.__data.pairs: for key, val in self.__data.pairs:
slog(prio, "{}=\"{}\"".format(key, val), caller=caller) slog(prio, "{}=\"{}\"".format(key, val), caller = caller)
def keys(self): def keys(self):
return self.__dict.keys() return self.__dict.keys()
@ -136,22 +141,28 @@ class Options: # export
#return self.__dict.items() #return self.__dict.items()
return self.__data.pairs return self.__data.pairs
def get(self, key, default=None, by_index=False): def get(self, key, default = None, by_index = False):
if by_index: if by_index:
if type(key) != int: if isinstance(key, int):
raise KeyError('Tried to get value from options string with ' + raise KeyError(
'index {} of type "{}": {}'.format(key, type(key), str(self))) 'Tried to get value from options string with ' +
'index {} of type "{}": {}'.format(key, type(key), str(self))
)
if key >= len(self.__data.pairs): if key >= len(self.__data.pairs):
if default is not None: if default is not None:
return default return default
raise KeyError('Tried to get value from options string with ' + raise KeyError(
'index {} of {}: {}'.format(key, len(self.__data.pairs), str(self))) 'Tried to get value from options string with ' +
'index {} of {}: {}'.format(key, len(self.__data.pairs), str(self))
)
return self.__list[key] return self.__list[key]
if key in self.__dict.keys(): if key in self.__dict.keys():
return self.__dict[key] return self.__dict[key]
if default is not None: if default is not None:
return default return default
raise KeyError('Key "{}" is not present in options string: {}'.format(key, str(self))) raise KeyError(
'Key "{}" is not present in options string: {}'.format(key, str(self))
)
def update(self, rhs): def update(self, rhs):
if hasattr(rhs, 'items'): if hasattr(rhs, 'items'):
@ -159,9 +170,13 @@ class Options: # export
self.__dict[key] = val self.__dict[key] = val
return return
if isinstance(rhs, str): if isinstance(rhs, str):
self.update(self.__parse(rhs)) self.update(self.__parse(rhs, self.OrderedData))
return return
raise Exception('Tried to update options with object of incompatible type {}'.format(type(rhs))) raise Exception(
'Tried to update options with object of incompatible type {}'.format(
type(rhs)
)
)
def append_to(self, obj): def append_to(self, obj):
for opt in self.__list: for opt in self.__list:

View file

@ -1,41 +1,46 @@
# -*- coding: utf-8 -*-
from __future__ import annotations from __future__ import annotations
import signal
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum, Flag, auto from enum import Enum, Flag, auto
from typing import List from typing import TYPE_CHECKING
from .log import ERR, slog
if TYPE_CHECKING:
from .Signals import Signals
def _sigchld_handler(signum, process): def _sigchld_handler(signum, process):
if not signum == signal.SIGCHLD: if not signum == signal.SIGCHLD:
return return
Process.propagate_signal(signum) Process.propagate_signal(signum)
class Process(ABC): # export class Process(ABC): # export
__processes: List[Process] = [] __processes: set[Process] = set()
class State(Enum): class State(Enum):
Running = auto() Running = auto()
Shutdown = auto() Shutdown = auto()
Done = auto() Done = auto()
class Flags(Flag): class Flags(Flag):
FailOnExitWithoutShutdown = auto() FailOnExitWithoutShutdown = auto()
def __init__(self): def __init__(self):
self.__state = Running self.__state = Process.State.Running
self.__flags = Flags.FailOnExitWithoutShutdown self.__flags = self.Flags.FailOnExitWithoutShutdown
if len(self.__processes) == 0: if len(self.__processes) == 0:
self._signals().add_handler(signals.SIGCHLD, _sigchld_handler) self.signals().add_handler(signal.SIGCHLD, _sigchld_handler)
self.__processes.add(self) self.__processes.add(self)
@classmethod @classmethod
def propagate_signal(cls, signum): def propagate_signal(cls, signum):
for p in cls.__processes: cls.signals().propagate(signum)
p.__signal(signum)
def signal(self, signum): def signal(self, signum):
if signum == signals.SIGCHLD: if signum == signal.SIGCHLD:
self.exited() self.exited()
@abstractmethod @abstractmethod
@ -44,7 +49,7 @@ class Process(ABC): # export
@classmethod @classmethod
@abstractmethod @abstractmethod
def signals(cls): def signals(cls) -> Signals:
pass pass
# to be reimplemented # to be reimplemented
@ -56,17 +61,15 @@ class Process(ABC): # export
return str(self._pid()) return str(self._pid())
def request_shutdown(self): def request_shutdown(self):
if not self.__state == Shutdown: if not self.__state == Process.State.Shutdown:
self.__state = Shutdown self.__state = Process.State.Shutdown
self._request_shutdown() self._request_shutdown()
def exited(self): def exited(self):
if self.__state == Process.State.Running: if self.__state == Process.State.Running:
slog(ERR, 'process "{}" exited unexpectedly'.format(process.name())) slog(ERR, 'process exited unexpectedly')
if __flags & Process.Flags.FailOnExitWithoutShutdown: if self.__flags & Process.Flags.FailOnExitWithoutShutdown:
slog(ERR, 'exiting') slog(ERR, 'exiting')
exit(1) exit(1)
self.__state = Process.State.Done self.__state = Process.State.Done
self.__processes.erase(self) self.__processes.remove(self)
if len(self.__processes) == 0:
self._signals().remove_handler(signals.SIGCHLD) # FIXME: broken logic

View file

@ -1,13 +1,13 @@
# -*- coding: utf-8 -*-
from __future__ import print_function from __future__ import print_function
import os, io, sys, traceback import os
import io
import sys
from fcntl import fcntl, F_GETFL, F_SETFL from fcntl import fcntl, F_GETFL, F_SETFL
class RedirectStdIO: # export class RedirectStdIO: # export
def __init__(self, stderr='on', stdout='off'): def __init__(self, stderr = 'on', stdout = 'off'):
self.__stderr = stderr self.__stderr = stderr
self.__stdout = stdout self.__stdout = stdout
# TODO: arguments not fully implemented, # TODO: arguments not fully implemented,
@ -30,12 +30,12 @@ class RedirectStdIO: # export
sys.stdout.flush() sys.stdout.flush()
os.dup2(self.real_stdout_fd, 1) os.dup2(self.real_stdout_fd, 1)
if type is not None: if type is not None:
#print("-------- Error while stdio was suppressed --------") #print("-------- Error while stdio was suppressed --------")
#traceback.print_stack() #traceback.print_stack()
#print(traceback) #print(traceback)
print("-------- Captured output --------") print("-------- Captured output --------")
print(*self.rfile.readlines()) print(*self.rfile.readlines())
self.rfile.close() self.rfile.close()
#print('type = ' + str(type)) #print('type = ' + str(type))
#print('value = ' + str(value)) #print('value = ' + str(value))
raise type(value) raise type(value)

View file

@ -1,19 +1,10 @@
# -*- coding: utf-8 -*- from abc import abstractmethod
from typing import Dict
from typing import Dict, Callable
from abc import ABC, abstractmethod
_handled_signals: Dict[int, Callable] = {}
def _signal_handler(signal, frame):
if not signal in _handled_signals.keys():
return
for h in _handled_signals[signal]:
h.func(signal, *h.args)
class Signals: class Signals:
class Handler: class Handler:
def __init__(self, func, args): def __init__(self, func, args):
self.func = func self.func = func
self.args = args self.args = args
@ -23,15 +14,27 @@ class Signals:
@classmethod @classmethod
@abstractmethod @abstractmethod
def _add_handler(self, signal, handler): def _add_handler(cls, signal, handler):
raise Exception("_add_handler() is not reimplemented") raise Exception("_add_handler() is not reimplemented")
@classmethod @classmethod
def add_handler(cls, signals, handler, *args): def add_handler(cls, signals, handler, *args):
for signal in signals: for signal in signals:
h = Signals.Handler(handler, args) h = Signals.Handler(handler, args)
if not signal in _handled_signals.keys(): if signal not in _handled_signals.keys():
_handled_signals[signal] = [h] _handled_signals[signal] = [h]
cls._add_signal_handler(signal, _signal_handler) cls._add_handler(signal, _signal_handler)
else: else:
_handled_signals[signal].add(h) _handled_signals[signal].append(h)
@classmethod
def propagate(cls, signal):
_signal_handler(signal, None)
_handled_signals: Dict[int, list[Signals.Handler]] = {}
def _signal_handler(signal, frame):
if signal not in _handled_signals.keys():
return
for h in _handled_signals[signal]:
h.func(signal, *h.args)

View file

@ -1,12 +1,10 @@
# -*- coding: utf-8 -*-
from datetime import datetime from datetime import datetime
from .log import * from .log import get_caller_pos, slog
class StopWatch: # export class StopWatch: # export
def __init__(self, name=''): def __init__(self, name = ''):
self.__start = datetime.now() self.__start = datetime.now()
self.__last = self.__start self.__last = self.__start
self.name = name self.name = name
@ -21,5 +19,9 @@ class StopWatch: # export
else: else:
msg = '------------------ ' msg = '------------------ '
caller = kwargs['caller'] if 'caller' in kwargs.keys() else get_caller_pos(1) caller = kwargs['caller'] if 'caller' in kwargs.keys() else get_caller_pos(1)
slog(prio, '{} {} {}'.format(self.name, str(now - self.__last), msg), caller=caller) slog(
prio,
'{} {} {}'.format(self.name, str(now - self.__last), msg),
caller = caller
)
self.__last = now self.__last = now

View file

@ -1,3 +0,0 @@
from pkgutil import extend_path
__path__ = extend_path(__path__, __name__)

View file

@ -1,16 +1,15 @@
# -*- coding: utf-8 -*- import re
import re, shlex
from collections import namedtuple from collections import namedtuple
from ..log import * from ..log import DEBUG, get_caller_pos, prio_gets_logged, slog
L, R = 'Left Right'.split() L, R = 'Left Right'.split()
ARG, KEYW, QUOTED, LPAREN, RPAREN = 'arg kw quoted ( )'.split() ARG, KEYW, QUOTED, LPAREN, RPAREN = 'arg kw quoted ( )'.split()
class Operator: # export class Operator: # export
def __init__(self, func=None, nargs=2, precedence=3, assoc=L): def __init__(self, func = None, nargs = 2, precedence = 3, assoc = L):
self.func = func self.func = func
self.nargs = nargs self.nargs = nargs
self.prec = precedence self.prec = precedence
@ -18,7 +17,7 @@ class Operator: # export
class Stack: class Stack:
def __init__(self, itemlist=[]): def __init__(self, itemlist = []):
self.items = itemlist self.items = itemlist
def __repr__(self): def __repr__(self):
@ -39,7 +38,7 @@ class Stack:
self.items.append(item) self.items.append(item)
return 0 return 0
class ShuntingYard(object): # export class ShuntingYard(object): # export
def __init__(self, operators = None): def __init__(self, operators = None):
self.do_debug = prio_gets_logged(DEBUG) self.do_debug = prio_gets_logged(DEBUG)
@ -54,7 +53,7 @@ class ShuntingYard(object): # export
for count, thing in enumerate(args): for count, thing in enumerate(args):
msg += ' ' + str(thing) msg += ' ' + str(thing)
if len(msg): if len(msg):
slog(DEBUG, msg[1:], caller=get_caller_pos()) slog(DEBUG, msg[1:], caller = get_caller_pos())
def operator(self, key: str) -> Operator: def operator(self, key: str) -> Operator:
return self.__ops[key] return self.__ops[key]
@ -65,7 +64,7 @@ class ShuntingYard(object): # export
v = self.__ops[k] v = self.__ops[k]
buf = ", \"" + k buf = ", \"" + k
if v.nargs == 1: if v.nargs == 1:
if k[len(k)-1].isalnum(): if k[len(k) - 1].isalnum():
buf = buf + ' ' buf = buf + ' '
buf = buf + "xxx" buf = buf + "xxx"
buf = buf + "\"" buf = buf + "\""
@ -83,16 +82,20 @@ class ShuntingYard(object): # export
regex = regex[1:] regex = regex[1:]
scanner = re.Scanner([ scanner = re.Scanner( # pyright: ignore[reportAttributeAccessIssue]
(regex, lambda scanner,token:(KEYW, token)), [
(r"\"[^\"]*\"|'[^']*'", lambda scanner,token:(QUOTED, token[1:-1])), (regex, lambda scanner, token: (KEYW, token)),
(r"[^\s()]+", lambda scanner,token:(ARG, token)), (r"\"[^\"]*\"|'[^']*'", lambda scanner, token: (QUOTED, token[1:-1])),
(r"\s+", None), # None == skip token. (r"[^\s()]+", lambda scanner, token: (ARG, token)),
]) (r"\s+", None), # None == skip token.
]
)
tokens, remainder = scanner.scan(spec) tokens, remainder = scanner.scan(spec)
if len(remainder)>0: if len(remainder) > 0:
raise Exception("Failed to tokenize " + spec + ", remaining bit is ", remainder) raise Exception(
"Failed to tokenize " + spec + ", remaining bit is ", remainder
)
#self.debug(tokens) #self.debug(tokens)
return tokens return tokens
@ -112,14 +115,14 @@ class ShuntingYard(object): # export
tokenized = self.tokenize(infix) tokenized = self.tokenize(infix)
self.debug("tokenized = ", tokenized) self.debug("tokenized = ", tokenized)
outq, stack = [], [] outq, stack = [], []
table = ['TOKEN,ACTION,RPN OUTPUT,OP STACK,NOTES'.split(',')] table = ['TOKEN', 'ACTION', 'RPN OUTPUT', ('OP STACK', ), 'NOTES']
for toktype, token in tokenized: for toktype, token in tokenized:
self.debug("Checking token", token) self.debug("Checking token", token)
note = action = '' note = action = ''
if toktype in [ ARG, QUOTED ]: if toktype in [ARG, QUOTED]:
action = 'Add arg to output' action = 'Add arg to output'
outq.append(token) outq.append(token)
table.append( (token, action, outq, (s[0] for s in stack), note) ) table.append((token, action, outq, (s[0] for s in stack), note))
elif toktype == KEYW: elif toktype == KEYW:
val = self.__ops[token] val = self.__ops[token]
t1, op1 = token, val t1, op1 = token, val
@ -127,7 +130,9 @@ class ShuntingYard(object): # export
note = 'Pop ops from stack to output' note = 'Pop ops from stack to output'
while stack: while stack:
t2, op2 = stack[-1] t2, op2 = stack[-1]
if (op1.assoc == L and op1.prec <= op2.prec) or (op1.assoc == R and op1.prec < op2.prec): if (op1.assoc == L
and op1.prec <= op2.prec) or (op1.assoc == R
and op1.prec < op2.prec):
if t1 != RPAREN: if t1 != RPAREN:
if t2 != LPAREN: if t2 != LPAREN:
stack.pop() stack.pop()
@ -143,9 +148,11 @@ class ShuntingYard(object): # export
else: else:
stack.pop() stack.pop()
action = '(Pop & discard "(")' action = '(Pop & discard "(")'
table.append( (v, action, outq, (s[0] for s in stack), note) ) table.append(
(v, action, outq, (s[0] for s in stack), note)
)
break break
table.append( (v, action, (outq), (s[0] for s in stack), note) ) table.append((v, action, (outq), (s[0] for s in stack), note))
v = note = '' v = note = ''
else: else:
note = '' note = ''
@ -157,7 +164,7 @@ class ShuntingYard(object): # export
action = 'Push op token to stack' action = 'Push op token to stack'
else: else:
action = 'Discard ")"' action = 'Discard ")"'
table.append( (v, action, (outq), (s[0] for s in stack), note) ) table.append((v, action, (outq), (s[0] for s in stack), note))
note = 'Drain stack to output' note = 'Drain stack to output'
while stack: while stack:
v = '' v = ''
@ -165,15 +172,27 @@ class ShuntingYard(object): # export
action = '(Pop op)' action = '(Pop op)'
stack.pop() stack.pop()
outq.append(t2) outq.append(t2)
table.append( (v, action, outq, (s[0] for s in stack), note) ) table.append((v, action, outq, (s[0] for s in stack), note))
v = note = '' v = note = ''
if self.do_debug: if self.do_debug:
maxcolwidths = [len(max(x, key=len)) for x in [zip(*table)]] maxcolwidths = [len(max(x, key = len)) for x in [zip(*table)]]
caller = get_caller_pos() get_caller_pos()
row = table[0] row = table[0]
slog(DEBUG, ' '.join('{cell:^{width}}'.format(width=width, cell=cell) for (width, cell) in zip(maxcolwidths, row))) slog(
DEBUG,
' '.join(
'{cell:^{width}}'.format(width = width, cell = cell)
for (width, cell) in zip(maxcolwidths, row)
)
)
for row in table[1:]: for row in table[1:]:
slog(DEBUG, ' '.join('{cell:<{width}}'.format(width=width, cell=cell) for (width, cell) in zip(maxcolwidths, row))) slog(
DEBUG,
' '.join(
'{cell:<{width}}'.format(width = width, cell = cell)
for (width, cell) in zip(maxcolwidths, row)
)
)
return table[-1][2] return table[-1][2]
def infix_to_postfix_orig(self, infix): def infix_to_postfix_orig(self, infix):
@ -185,7 +204,7 @@ class ShuntingYard(object): # export
for tokinfo in tokens: for tokinfo in tokens:
self.debug(tokinfo) self.debug(tokinfo)
toktype, token = tokinfo[0], tokinfo[1] _toktype, token = tokinfo[0], tokinfo[1]
self.debug("Checking token ", token) self.debug("Checking token ", token)
@ -204,7 +223,8 @@ class ShuntingYard(object): # export
topToken = s.pop() topToken = s.pop()
continue continue
while (not s.isEmpty()) and (self.__ops[s.peek()].prec >= self.__ops[token].prec): while (not s.isEmpty()) and (self.__ops[s.peek()].prec
>= self.__ops[token].prec):
#self.debug(token) #self.debug(token)
r.append(s.pop()) r.append(s.pop())
#self.debug(r) #self.debug(r)
@ -240,7 +260,9 @@ class ShuntingYard(object): # export
args.append(vals.pop()) args.append(vals.pop())
#self.debug("running %s(%s)" % (token, ', '.join(reversed(args)))) #self.debug("running %s(%s)" % (token, ', '.join(reversed(args))))
val = op.func(*reversed(args)) val = op.func(*reversed(args))
self.debug("%s(%s) = %s" % (token, ', '.join(map(str, reversed(args))), val)) self.debug(
"%s(%s) = %s" % (token, ', '.join(map(str, reversed(args))), val)
)
vals.push(val) vals.push(val)
return vals.pop() return vals.pop()
@ -266,27 +288,27 @@ if __name__ == '__main__':
# return string.split() # return string.split()
def f_mult(self, a, b): def f_mult(self, a, b):
return str(atof(a) * atof(b)); return str(atof(a) * atof(b))
def f_div(self, a, b): def f_div(self, a, b):
return str(atof(a) / atof(b)); return str(atof(a) / atof(b))
def f_add(self, a, b): def f_add(self, a, b):
return str(atof(a) + atof(b)); return str(atof(a) + atof(b))
def f_sub(self, a, b): def f_sub(self, a, b):
return str(atof(a) - atof(b)); return str(atof(a) - atof(b))
def __init__(self): def __init__(self):
Op = Operator Op = Operator
operators = { operators = {
'^': Op(None, 2, 4, R), '^': Op(None, 2, 4, R),
'*': Op(self.f_mult, 2, 3, L), '*': Op(self.f_mult, 2, 3, L),
'/': Op(self.f_div, 2, 3, L), '/': Op(self.f_div, 2, 3, L),
'+': Op(self.f_add, 2, 2, L), '+': Op(self.f_add, 2, 2, L),
'-': Op(self.f_sub, 2, 2, L), '-': Op(self.f_sub, 2, 2, L),
'(': Op(None, 0, 9, L), '(': Op(None, 0, 9, L),
')': Op(None, 0, 0, L), ')': Op(None, 0, 0, L),
} }
super(Calculator, self).__init__(operators) super(Calculator, self).__init__(operators)
@ -295,7 +317,7 @@ if __name__ == '__main__':
# ------------- testbed match object # ------------- testbed match object
Object = namedtuple("Object", [ "Name", "Label" ]) Object = namedtuple("Object", ["Name", "Label"])
class Matcher(ShuntingYard): class Matcher(ShuntingYard):
@ -324,14 +346,14 @@ if __name__ == '__main__':
def __init__(self, obj): def __init__(self, obj):
Op = Operator Op = Operator
operators = { operators = {
'(': Op(None, 2, 9, L), '(': Op(None, 2, 9, L),
')': Op(None, 2, 0, L), ')': Op(None, 2, 0, L),
'name=': Op(self.f_is_name, 1, 3, R), 'name=': Op(self.f_is_name, 1, 3, R),
'and': Op(self.f_and, 2, 3, L), 'and': Op(self.f_and, 2, 3, L),
'label~=': Op(self.f_matches_label, 1, 3, R), 'label~=': Op(self.f_matches_label, 1, 3, R),
'False': Op(None, 0, 3, L), 'False': Op(None, 0, 3, L),
'True': Op(None, 0, 3, L), 'True': Op(None, 0, 3, L),
'not': Op(self.f_is_not, 1, 3, R), 'not': Op(self.f_is_not, 1, 3, R),
} }
super(Matcher, self).__init__(operators) super(Matcher, self).__init__(operators)

View file

@ -1,16 +1,14 @@
# -*- coding: utf-8 -*-
from abc import abstractmethod from abc import abstractmethod
from ..Process import Process as ProcessBase from ..Process import Process as ProcessBase
from .Signals import Signals from .Signals import Signals
class Process(ProcessBase): # export class Process(ProcessBase): # export
__signals = Signals() __signals = Signals()
def __init__(self, aio): def __init__(self, aio):
super().__init() super().__init__()
self.aio = aio self.aio = aio
@classmethod @classmethod

View file

@ -1,9 +1,11 @@
import asyncio import asyncio
from ..log import * import re
from ..log import DEBUG, ERR, INFO, WARNING, slog
# FIXME: Derive this from Process, or merge the classes entirely # FIXME: Derive this from Process, or merge the classes entirely
class ShellCmd: # export class ShellCmd: # export
class SubprocessProtocol(asyncio.SubprocessProtocol): class SubprocessProtocol(asyncio.SubprocessProtocol):
@ -26,12 +28,12 @@ class ShellCmd: # export
self.process.exited() self.process.exited()
class ShutdownState: class ShutdownState:
Running = 1 Running = 1
Triggered = 2 Triggered = 2
Completed = 3 Completed = 3
Unnecessary = 4 Unnecessary = 4
def __init__(self, cmdline, eloop=None, name=None): def __init__(self, cmdline, eloop = None, name = None):
if eloop is None: if eloop is None:
eloop = asyncio.get_running_loop() eloop = asyncio.get_running_loop()
self.__eloop = eloop self.__eloop = eloop
@ -56,12 +58,19 @@ class ShellCmd: # export
return r[1:] return r[1:]
try: try:
slog(INFO, "Running shell command [{}]: {}".format(self.__name, format_cmdline(self.__cmdline))) slog(
INFO,
"Running shell command [{}]: {}".format(
self.__name, format_cmdline(self.__cmdline)
)
)
self.__transport, self.__protocol = await self.__eloop.subprocess_exec( self.__transport, self.__protocol = await self.__eloop.subprocess_exec(
lambda: self.SubprocessProtocol(self, self.__name), lambda: self.SubprocessProtocol(self, self.__name),
*self.__cmdline, *self.__cmdline,
) )
self.__proc = self.__transport.get_extra_info('subprocess') # Popen instance self.__proc = self.__transport.get_extra_info(
'subprocess'
) # Popen instance
except: except:
slog(ERR, "Failed to run process [{}]".format(self.__name)) slog(ERR, "Failed to run process [{}]".format(self.__name))
raise raise
@ -69,7 +78,8 @@ class ShellCmd: # export
def __reap(self): def __reap(self):
if self.__rc is None and self.__transport: if self.__rc is None and self.__transport:
self.__transport = None self.__transport = None
self.__rc = self.__proc.wait() if self.__proc is not None:
self.__rc = self.__proc.wait()
# to be called from SubprocessProtocol / SIGCHLD handler # to be called from SubprocessProtocol / SIGCHLD handler
def exited(self): def exited(self):
@ -78,13 +88,24 @@ class ShellCmd: # export
async def __cleanup(self): async def __cleanup(self):
pid = self.__reap() pid = self.__reap()
sd_fine = self.__shutdown in [ self.ShutdownState.Unnecessary, self.ShutdownState.Completed ] sd_fine = self.__shutdown in [
self.ShutdownState.Unnecessary, self.ShutdownState.Completed
]
if self.__rc == 0 and sd_fine: if self.__rc == 0 and sd_fine:
slog(INFO, "The shell command [{}], pid {}, has exited cleanly".format(self.__name, self.__proc.pid)) assert self.__proc is not None
slog(
INFO,
"The shell command [{}], pid {}, has exited cleanly".format(
self.__name, self.__proc.pid
)
)
self.monitor = self.console = self.__protocol = self.__task = None self.monitor = self.console = self.__protocol = self.__task = None
return 0 return 0
slog(ERR, "The process ([{}], pid {}) has exited {}with status code {}, aborting".format( slog(
self.__name, pid, "" if sd_fine else "prematurely ", self.__rc)) ERR,
"The process ([{}], pid {}) has exited {}with status code {}, aborting".
format(self.__name, pid, "" if sd_fine else "prematurely ", self.__rc)
)
exit(1) exit(1)
async def init(self): async def init(self):
@ -100,9 +121,9 @@ class ShellCmd: # export
if __name__ == '__main__': if __name__ == '__main__':
from .. import log from .. import log
log.set_level('info') log.set_level('info')
async def run(): async def run():
sp = ShellCmd([ 'echo', 'hello world!' ]) sp = ShellCmd(['echo', 'hello world!'])
await sp.run() await sp.run()
asyncio.run(run()) asyncio.run(run())

View file

@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-
import asyncio import asyncio
from ..Signals import Signals as SignalsBase from ..Signals import Signals as SignalsBase
@ -10,4 +8,4 @@ class Signals(SignalsBase):
@classmethod @classmethod
def _add_handler(cls, signal, handler): def _add_handler(cls, signal, handler):
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
loop.add_signal_handler(signal, handler, None) # None = *args loop.add_signal_handler(signal, handler, None) # None = *args

View file

@ -1,27 +1,28 @@
# -*- coding: utf-8 -*- from __future__ import annotations
from typing import Optional, Union, Self
import abc import abc
from enum import Flag, Enum, auto from enum import Enum, Flag, auto
from typing import TYPE_CHECKING, Optional, Self, Union
from ..log import * from ..log import ERR
from ..Config import Config
from ..misc import load_object from ..misc import load_object
class Access(Enum): # export if TYPE_CHECKING:
Read = auto() from ..Config import Config
class Access(Enum): # export
Read = auto()
Modify = auto() Modify = auto()
Create = auto() Create = auto()
Delete = auto() Delete = auto()
class ProjectFlags(Flag): # export class ProjectFlags(Flag): # export
NoFlags = auto() NoFlags = auto()
Contributing = auto() Contributing = auto()
Active = auto() Active = auto()
class Group: # export class Group: # export
def __repr__(self): def __repr__(self):
return f'Group({self.name})' return f'Group({self.name})'
@ -34,7 +35,7 @@ class Group: # export
def name(self) -> str: def name(self) -> str:
return self._name() return self._name()
class User: # export class User: # export
def __repr__(self): def __repr__(self):
return f'User({self.name})' return f'User({self.name})'
@ -70,14 +71,14 @@ class User: # export
def email(self) -> str: def email(self) -> str:
return self._email() return self._email()
class Auth(abc.ABC): # export class Auth(abc.ABC): # export
@classmethod @classmethod
def load(cls, conf: Config, tp: str='') -> Self: def load(cls, conf: Config, tp: str = '') -> Self:
if tp == '': if tp == '':
val = conf.get('type') val = conf.get('type')
if val is None: if val is None:
msg = f'No type specified in auth configuration' msg = 'No type specified in auth configuration'
conf.dump(ERR, msg) conf.dump(ERR, msg)
raise Exception(msg) raise Exception(msg)
tp = val tp = val
@ -92,10 +93,17 @@ class Auth(abc.ABC): # export
return self.__conf return self.__conf
@abc.abstractmethod @abc.abstractmethod
def _access(self, what: str, access_type: Optional[Access], who: User|Group|None) -> bool: def _access(
self, what: str, access_type: Optional[Access], who: User | Group | None
) -> bool:
raise NotImplementedError raise NotImplementedError
def access(self, what: str, access_type: Optional[Access]=None, who: Optional[Union[User|Group]]=None) -> bool: def access(
self,
what: str,
access_type: Optional[Access] = None,
who: Optional[Union[User, Group]] = None
) -> bool:
return self._access(what, access_type, who) return self._access(what, access_type, who)
@abc.abstractmethod @abc.abstractmethod

View file

@ -1,16 +1,18 @@
# -*- coding: utf-8 -*- from __future__ import annotations
from typing import Optional, Union from typing import TYPE_CHECKING, Optional
from ...log import * from ...log import WARNING, slog
from ... import Config from ..Auth import Access
from .. import Access from ..Auth import Auth as AuthBase
from .. import Auth as AuthBase from ..Auth import Group as GroupBase
from .. import Group as GroupBase from ..Auth import ProjectFlags
from .. import User as UserBase from ..Auth import User as UserBase
from .. import ProjectFlags
class Group(GroupBase): # export if TYPE_CHECKING:
from ...Config import Config
class Group(GroupBase): # export
def __init__(self, auth: AuthBase, name: str): def __init__(self, auth: AuthBase, name: str):
self.__name = name self.__name = name
@ -19,7 +21,7 @@ class Group(GroupBase): # export
def _name(self) -> str: def _name(self) -> str:
return self.__name return self.__name
class User(UserBase): # export class User(UserBase): # export
def __init__(self, auth: AuthBase, name: str, conf: Config): def __init__(self, auth: AuthBase, name: str, conf: Config):
self.__name = name self.__name = name
@ -47,13 +49,13 @@ class User(UserBase): # export
def _email(self) -> str: def _email(self) -> str:
return self.__email return self.__email
class Auth(AuthBase): # export class Auth(AuthBase): # export
def __init__(self, conf: Config): def __init__(self, conf: Config):
super().__init__(conf) super().__init__(conf)
self.___users: Optional[dict[str, UserBase]] = None self.___users: Optional[dict[str, UserBase]] = None
self.__groups = None self.__groups = None
self.__current_user: UserBase|None = None self.__current_user: UserBase | None = None
self.__user_by_email: Optional[dict[str, UserBase]] = None self.__user_by_email: Optional[dict[str, UserBase]] = None
@property @property
@ -62,12 +64,18 @@ class Auth(AuthBase): # export
ret: dict[str, UserBase] = {} ret: dict[str, UserBase] = {}
for name in self.conf.entries('user'): for name in self.conf.entries('user'):
conf = self.conf.branch('user.' + name) conf = self.conf.branch('user.' + name)
assert conf is not None, 'Config is None'
ret[name] = User(self, name, conf) ret[name] = User(self, name, conf)
self.___users = ret self.___users = ret
return self.___users return self.___users
def _access(self, what: str, access_type: Optional[Access], who: User|GroupBase|None) -> bool: # type: ignore def _access(
slog(WARNING, f'Returning False for {access_type} access to resource {what} by {who}') self, what: str, access_type: Access | None, who: UserBase | GroupBase | None
) -> bool: # type: ignore
slog(
WARNING,
f'Returning False for {access_type} access to resource {what} by {who}'
)
return False return False
def _user(self, name) -> UserBase: def _user(self, name) -> UserBase:

View file

@ -1,19 +1,21 @@
# -*- coding: utf-8 -*- from __future__ import annotations
from typing import Optional, Union from typing import TYPE_CHECKING, Optional
import ldap import ldap # type: ignore[import-untyped]
from ...log import *
from ...ldap import bind from ...ldap import bind
from ...Config import Config from ...log import DEBUG, ERR, WARNING, slog
from .. import Access from ..Auth import Access
from .. import Auth as AuthBase from ..Auth import Auth as AuthBase
from .. import Group as GroupBase from ..Auth import Group as GroupBase
from .. import User as UserBase from ..Auth import ProjectFlags
from .. import ProjectFlags from ..Auth import User as UserBase
class Group(GroupBase): # export if TYPE_CHECKING:
from ...Config import Config
class Group(GroupBase): # export
def __init__(self, auth: AuthBase, name: str): def __init__(self, auth: AuthBase, name: str):
self.__name = name self.__name = name
@ -24,13 +26,7 @@ class Group(GroupBase): # export
class User(UserBase): class User(UserBase):
def __init__( def __init__(self, auth: AuthBase, name: str, cn: str, email: str):
self,
auth: AuthBase,
name: str,
cn: str,
email: str
):
self.__auth = auth self.__auth = auth
self.__name = name self.__name = name
@ -50,14 +46,14 @@ class User(UserBase):
def _display_name(self) -> str: def _display_name(self) -> str:
return self.__cn return self.__cn
class Auth(AuthBase): # export class Auth(AuthBase): # export
def __init__(self, conf: Config): def __init__(self, conf: Config):
super().__init__(conf) super().__init__(conf)
self.___users: Optional[dict[str, UserBase]] = None self.___users: Optional[dict[str, UserBase]] = None
self.___user_by_email: Optional[dict[str, User]] = None self.___user_by_email: Optional[dict[str, User]] = None
self.__groups = None self.__groups = None
self.__current_user: User|None = None self.__current_user: User | None = None
self.__user_base_dn = conf['user_base_dn'] self.__user_base_dn = conf['user_base_dn']
self.__conn = self.__bind() self.__conn = self.__bind()
self.__dummy = self.load(conf, 'dummy') self.__dummy = self.load(conf, 'dummy')
@ -72,18 +68,16 @@ class Auth(AuthBase): # export
ret_by_email: dict[str, User] = {} ret_by_email: dict[str, User] = {}
for res in self.__conn.find( for res in self.__conn.find(
self.__user_base_dn, self.__user_base_dn,
ldap.SCOPE_SUBTREE, ldap.SCOPE_SUBTREE, # pyright: ignore[reportAttributeAccessIssue]
"objectClass=inetOrgPerson", "objectClass=inetOrgPerson",
('uid', 'cn', 'uidNumber', 'mail', 'maildrop') ('uid', 'cn', 'uidNumber', 'mail', 'maildrop')):
):
try: try:
display_name = None
if 'displayName' in res[1]: if 'displayName' in res[1]:
cn = res[1]['displayName'][0].decode('utf-8') cn = res[1]['displayName'][0].decode('utf-8')
else: else:
cn = res[1]['cn'][0].decode('utf-8') cn = res[1]['cn'][0].decode('utf-8')
uid = res[1]['uid'][0].decode('utf-8') uid = res[1]['uid'][0].decode('utf-8')
uidNumber = res[1]['uidNumber'][0].decode('utf-8') res[1]['uidNumber'][0].decode('utf-8')
emails = [] emails = []
#for attr in ['mail', 'maildrop']: #for attr in ['mail', 'maildrop']:
for attr in ['mail']: for attr in ['mail']:
@ -93,7 +87,7 @@ class Auth(AuthBase): # export
if not emails: if not emails:
slog(DEBUG, f'No email for user "{uid}", skipping') slog(DEBUG, f'No email for user "{uid}", skipping')
continue continue
user = User(self, name=uid, cn=cn, email=emails[0]) user = User(self, name = uid, cn = cn, email = emails[0])
ret[uid] = user ret[uid] = user
for email in emails: for email in emails:
ret_by_email[email] = user ret_by_email[email] = user
@ -111,10 +105,18 @@ class Auth(AuthBase): # export
def __user_by_email(self) -> dict[str, UserBase]: def __user_by_email(self) -> dict[str, UserBase]:
if self.___user_by_email is None: if self.___user_by_email is None:
self.__users self.__users
return self.___user_by_email # type: ignore # We are sure that ___user_by_email is not None at this point return self.___user_by_email # type: ignore # We are sure that ___user_by_email is not None at this point
def _access(self, what: str, access_type: Optional[Access], who: User|GroupBase|None) -> bool: # type: ignore def _access(
slog(WARNING, f'Returning False for {access_type} access to resource {what} by {who}') self,
what: str,
access_type: Optional[Access],
who: UserBase | GroupBase | None
) -> bool: # type: ignore
slog(
WARNING,
f'Returning False for {access_type} access to resource {what} by {who}'
)
return False return False
def _user(self, name) -> UserBase: def _user(self, name) -> UserBase:
@ -136,5 +138,9 @@ class Auth(AuthBase): # export
def _projects(self, name, flags: ProjectFlags) -> list[str]: def _projects(self, name, flags: ProjectFlags) -> list[str]:
if flags & ProjectFlags.Contributing: if flags & ProjectFlags.Contributing:
# TODO: Ask LDAP # TODO: Ask LDAP
slog(WARNING, f'Querying LDAP for projects a user contributes to is not implemented, ignoring') slog(
WARNING,
'Querying LDAP for projects a user contributes to is not '
'implemented, ignoring'
)
return [] return []

View file

@ -1,110 +1,130 @@
# -*- coding: utf-8 -*- import os
import pytimeparse, os
from datetime import datetime, timedelta
from collections import OrderedDict from collections import OrderedDict
from datetime import timedelta
from .log import * import pytimeparse # type: ignore[import-untyped]
from .log import DEBUG, WARNING, slog
_int_chars = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] _int_chars = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
def _strip(s_, throw=True, log_level=ERR): def _strip(s_) -> str:
s = s_.strip() s = s_.strip()
if len(s) != 0: if len(s) != 0:
return s return s
msg = f'Tried to strip empty string "{s_}" to int' raise Exception(f'Tried to strip empty string "{s_}"')
if throw:
raise Exception(msg)
slog(log_level, msg)
return None
def cast_str_to_timedelta(s_: str, throw=True, log_level=DEBUG): # export def cast_str_to_timedelta(s_: str): # export
s = _strip(s_, throw=throw, log_level=log_level) s = _strip(s_)
try: seconds = pytimeparse.parse(s)
return (True, timedelta(seconds=pytimeparse.parse(s_))) if seconds is None:
except Exception as e: raise Exception(f'Failed to convert {s} to timedelta')
msg = f'Could not convert string "{s_}" to time ({e})' return timedelta(seconds = seconds)
if throw:
raise Exception(msg)
slog(log_level, msg)
return (False, None)
def cast_str_to_int(s_: str, throw=True, log_level=DEBUG): # export def cast_str_to_int(s_: str): # export
s = _strip(s_, throw=throw, log_level=log_level) s = _strip(s_)
if s[0] == '-': if s[0] == '-':
s = s[1:] s = s[1:]
for c in s: for c in s:
if not c in _int_chars: if c not in _int_chars:
break raise Exception(f'Could not convert string "{s}" to int')
else: return int(s)
return (True, int(s_))
msg = f'Could not convert string "{s_}" to int'
if throw:
raise Exception(msg)
slog(log_level, msg)
return (False, None)
def cast_str_to_bool(s_: str, throw=True, log_level=DEBUG): # export def cast_str_to_bool(s_: str): # export
s = _strip(s_, throw=throw, log_level=log_level).lower() s = _strip(s_).lower()
if s in ['true', 'yes', '1']: if s in ['true', 'yes', '1']:
return (True, True) return True
if s in ['false', 'no', '0']: if s in ['false', 'no', '0']:
return (True, False) return False
msg = f'Could not convert string "{s_}" to bool' raise Exception(f'Could not convert string "{s_}" to bool')
if throw:
raise Exception(msg)
slog(log_level, msg)
return (False, None)
_str_cast_functions = OrderedDict({ _str_cast_functions = OrderedDict(
bool: cast_str_to_bool, {
int: cast_str_to_int, bool: cast_str_to_bool, int: cast_str_to_int, timedelta: cast_str_to_timedelta
timedelta: cast_str_to_timedelta }
)
}) def guess_type(s: str, default = None, log_level = DEBUG, throw = False): # export
def guess_type(s: str, default=None, log_level=DEBUG, throw=False): # export
if s is None: if s is None:
raise Exception('None string passed to guess_type()') raise Exception('None string passed to guess_type()')
for tp, func in _str_cast_functions.items(): for tp, func in _str_cast_functions.items():
try: try:
success, value = func(s, log_level=OFF, throw=False) func(s)
if success: except Exception:
return tp
except:
continue continue
return tp
msg = f'Failed to guess type of string "{s}"' msg = f'Failed to guess type of string "{s}"'
if throw: if throw:
raise Exception(msg) raise Exception(msg)
slog(log_level, msg) slog(log_level, msg)
return default return default
def from_str(s: str, target_type=None, default_type=None, throw=True, log_level=WARNING, caller=None): # export def from_str( # export
s: str,
target_type = None,
default_type = None,
throw = True,
log_level = WARNING,
caller = None
):
if target_type is None: if target_type is None:
target_type = guess_type(s, default_type) for tp, func in _str_cast_functions.items():
if target_type is None: try:
msg = f'Could not deduce type to cast to from string "{s}"' return func(s)
if throw: except Exception:
raise Exception(msg) continue
slog(log_level, msg) msg = f'Could not deduce type to cast to from string "{s}"'
return None if throw:
result = _str_cast_functions[target_type](s, throw=throw, log_level=log_level) raise Exception(msg)
if result[0]: slog(log_level, msg)
return result[1] return None
msg = f'Failed to cast string "{s}" to type {target_type}' try:
if throw: return _str_cast_functions[target_type](s)
raise Exception(msg) except Exception as e:
slog(log_level, msg) msg = f'Failed to cast string "{s}" to type {target_type} ({str(e)})'
if throw:
raise Exception(msg)
slog(log_level, msg)
return None return None
def from_env(key: str, default=None, target_type=None, default_type=None, throw=True, log_level=WARNING, caller=None): # export def from_env( # export
key: str,
default = None,
target_type = None,
default_type = None,
throw = True,
log_level = WARNING,
caller = None
):
val = os.getenv(key) val = os.getenv(key)
if val is None: if val is None:
return default return default
if target_type is None and default is not None: if target_type is None and default is not None:
target_type = type(default) target_type = type(default)
return from_str(val, target_type=target_type, default_type=default_type, throw=throw, log_level=log_level, caller=caller) return from_str(
val,
target_type = target_type,
default_type = default_type,
throw = throw,
log_level = log_level,
caller = caller
)
# deprecated name # deprecated name
def cast_str(s: str, target_type=None, default_type=None, throw=True, log_level=WARNING, caller=None): def cast_str(
return from_str(s, target_type=target_type, default_type=None, throw=True, log_level=WARNING, caller=None) s: str,
target_type = None,
default_type = None,
throw = True,
log_level = WARNING,
caller = None
):
return from_str(
s,
target_type = target_type,
default_type = None,
throw = True,
log_level = WARNING,
caller = None
)

View file

@ -1,15 +1,16 @@
# -*- coding: utf-8 -*- from __future__ import annotations
from typing import Any
import abc import abc
from contextlib import contextmanager
from ..Config import Config from contextlib import contextmanager
from .schema.Schema import Schema from typing import TYPE_CHECKING
from ..Cmds import Cmds
from .Session import Session from ..log import NOTICE
from ..log import *
if TYPE_CHECKING:
from ..Config import Config
from .schema.Schema import Schema
from .Session import Session
class DataBase(abc.ABC): class DataBase(abc.ABC):
@ -37,6 +38,7 @@ class DataBase(abc.ABC):
def session(self): def session(self):
ret = self._create_session() ret = self._create_session()
try: try:
yield ret yield ret
finally: finally:
self._delete_session(ret) if ret is not None:
self._delete_session(ret)

View file

@ -1,8 +1,6 @@
# -*- coding: utf-8 -*-
import abc import abc
class Session(abc.ABC): # export class Session(abc.ABC): # export
def __init__(self, db): def __init__(self, db):
self.__db = db self.__db = db

View file

@ -1,18 +1,17 @@
# -*- coding: utf-8 -*- from __future__ import annotations
from typing import Any, List, Union, Optional, Dict
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import re, csv, json from typing import TYPE_CHECKING, Any, Dict, Union
from ..log import * from ..log import ERR, INFO, OFF, slog, slog_m
from ..cast import cast_str from .rows import rows_check_not_null, rows_dump, rows_duplicates
from .schema.Schema import Schema
from .rows import * if TYPE_CHECKING:
from .schema.Schema import Schema
TType = Union[Any, Dict[str, Any]] TType = Union[Any, Dict[str, Any]]
class TableIoHandler(ABC): # export class TableIoHandler(ABC): # export
def __init__(self, schema: Schema): def __init__(self, schema: Schema):
self.__table_meta = None self.__table_meta = None
@ -22,7 +21,8 @@ class TableIoHandler(ABC): # export
def _table_meta(self): def _table_meta(self):
if self.__table_meta is None: if self.__table_meta is None:
self.__table_meta = self.__schema.table_by_model_name( self.__table_meta = self.__schema.table_by_model_name(
self.__class__.__name__, throw=True) self.__class__.__name__, throw = True
)
return self.__table_meta return self.__table_meta
@property @property
@ -35,24 +35,29 @@ class TableIoHandler(ABC): # export
def _check_non_nullable(self, rows): def _check_non_nullable(self, rows):
buf = [] buf = []
non_nullable = self.__table_meta.not_null_insertible_columns non_nullable = self._table_meta.not_null_insertible_columns
try: try:
rows_check_not_null(rows, non_nullable, buf=buf) rows_check_not_null(rows, non_nullable, buf = buf)
except: except:
cn = self.__class__.__name__ cn = self.__class__.__name__
tn = self._table_name tn = self._table_name
d = '=========================================================' d = '========================================================='
slog_m(ERR, f'{d} Null values in {cn}\n') slog_m(ERR, f'{d} Null values in {cn}\n')
for key in non_nullable: for key in non_nullable:
buf = rows_check_not_null(rows, key, log_prio=OFF, throw=False) buf = rows_check_not_null(rows, key, log_prio = OFF, throw = False)
if not buf: if not buf:
continue continue
slog_m(ERR, f'\n{d} Null values in {cn} / {tn}: "{key}"\n') slog_m(ERR, f'\n{d} Null values in {cn} / {tn}: "{key}"\n')
use_cols=self.log_columns use_cols = list(self.log_columns)
if key not in use_cols: if key not in use_cols:
use_cols.append(key) use_cols.append(key)
rows_dump(buf, use_cols=use_cols, log_prio=ERR) rows_dump(buf, use_cols = use_cols, log_prio = ERR)
rows_dump(buf, use_cols=use_cols, out_path=f'/tmp/missing_{key}_in_{tn}.html', heading=f'Missing "{key}" in table {tn}') rows_dump(
buf,
use_cols = use_cols,
out_path = f'/tmp/missing_{key}_in_{tn}.html',
heading = f'Missing "{key}" in table {tn}'
)
raise raise
@property @property
@ -67,7 +72,9 @@ class TableIoHandler(ABC): # export
def _store(self, uri: str, data: TType): def _store(self, uri: str, data: TType):
pass pass
def load(self, uri: str, reference, check_duplicates=False, write_csv=None) -> TType: def load(
self, uri: str, reference, check_duplicates = False, write_csv = None
) -> TType:
slog(INFO, f'Reading table "{self._table_name}" from "{uri}"') slog(INFO, f'Reading table "{self._table_name}" from "{uri}"')
ret = self._load(uri, reference) ret = self._load(uri, reference)
if check_duplicates: if check_duplicates:

View file

@ -1,18 +1,18 @@
# -*- coding: utf-8 -*- from __future__ import annotations
from typing import Any, TYPE_CHECKING
from typing import Any
import abc import abc
from ...log import * from ...log import slog, slog_m, ERR, INFO
from ...misc import load_classes from ...misc import load_classes
from ...Cmds import Cmds
from ..DataBase import DataBase from ..DataBase import DataBase
from ..schema.Schema import Schema
from .Query import Query as QueryBase from .Query import Query as QueryBase
from .QueryResult import QueryResult
class Queries(abc.ABC): # export if TYPE_CHECKING:
from ..schema.Schema import Schema
from .QueryResult import QueryResult
class Queries(abc.ABC): # export
class Query(QueryBase): class Query(QueryBase):
@ -38,7 +38,7 @@ class Queries(abc.ABC): # export
return self.__name return self.__name
def __init__(self, db: DataBase) -> None: def __init__(self, db: DataBase) -> None:
assert(isinstance(db, DataBase)) assert (isinstance(db, DataBase))
self.__db = db self.__db = db
self.__queries: dict[str, Any] = dict() self.__queries: dict[str, Any] = dict()
@ -57,7 +57,7 @@ class Queries(abc.ABC): # export
def db(self) -> DataBase: def db(self) -> DataBase:
return self.__db return self.__db
def load(self, modules: list[str], cls=QueryBase): def load(self, modules: list[str], cls = QueryBase):
for path in modules: for path in modules:
slog(INFO, f'Loading modules from {path}') slog(INFO, f'Loading modules from {path}')
for c in load_classes(path, cls): for c in load_classes(path, cls):
@ -69,8 +69,8 @@ class Queries(abc.ABC): # export
def add(self, query: QueryBase, query_name: str, location: str, func: Any): def add(self, query: QueryBase, query_name: str, location: str, func: Any):
slog(INFO, f'Adding query "{query_name}" on location "{location}"') slog(INFO, f'Adding query "{query_name}" on location "{location}"')
assert(isinstance(query_name, str)) assert (isinstance(query_name, str))
assert(isinstance(location, str)) assert (isinstance(location, str))
#ret = self.Query(query, func) #ret = self.Query(query, func)
ret = self.Query(query, query_name, location, func) ret = self.Query(query, query_name, location, func)
#setattr(ret, 'name', name) #setattr(ret, 'name', name)

View file

@ -1,18 +1,15 @@
# -*- coding: utf-8 -*- from __future__ import annotations
from typing import Any
import abc import abc
from ...log import * from typing import TYPE_CHECKING, Any
from ...misc import load_classes
from ...Cmds import Cmds
from ..DataBase import DataBase
from ..Session import Session
from .QueryResult import QueryResult
#from .Queries import Queries
class Query(abc.ABC): # export if TYPE_CHECKING:
from ..DataBase import DataBase
from ..Session import Session
from .QueryResult import QueryResult
class Query(abc.ABC): # export
def __init__(self, parent: Any) -> None: def __init__(self, parent: Any) -> None:
self.__parent = parent self.__parent = parent

View file

@ -1,23 +1,22 @@
# -*- coding: utf-8 -*- from __future__ import annotations
from typing import Any, Union
import abc import abc
from enum import Enum, auto from enum import Enum, auto
from typing import TYPE_CHECKING, Any, Union
from ...log import * if TYPE_CHECKING:
from ...Cmds import Cmds from ..DataBase import DataBase
from ..DataBase import DataBase from ..Session import Session
from ..Session import Session
class ResType(Enum): # export class ResType(Enum): # export
Statement = auto() Statement = auto()
Scalars = auto() Scalars = auto()
One = auto() One = auto()
First = auto() First = auto()
Pages = auto() Pages = auto()
class QueryResult(abc.ABC): # export class QueryResult(abc.ABC): # export
def __init__(self, session: Session, query: Any) -> None: def __init__(self, session: Session, query: Any) -> None:
self.__query = query self.__query = query
@ -42,8 +41,8 @@ class QueryResult(abc.ABC): # export
def rows(self) -> list[Any]: def rows(self) -> list[Any]:
return self._cast(ResType.Scalars) return self._cast(ResType.Scalars)
def pages(self, per_page=20, page=1) -> Any: def pages(self, per_page = 20, page = 1) -> Any:
return self._cast(ResType.Pages, per_page=per_page, page=page) return self._cast(ResType.Pages, per_page = per_page, page = page)
def one(self) -> Any: def one(self) -> Any:
return self._cast(ResType.One) return self._cast(ResType.One)
@ -58,5 +57,5 @@ class QueryResult(abc.ABC): # export
# -- pure virtuals # -- pure virtuals
@abc.abstractmethod @abc.abstractmethod
def _cast(self, res_type: ResType, **kwargs) -> Union[Any|list[Any]]: def _cast(self, res_type: ResType, **kwargs) -> Union[Any, list[Any]]:
pass pass

View file

@ -1,19 +1,24 @@
# -*- coding: utf-8 -*- import csv
import io
import json
import os
import re
import textwrap
import io, os, re, textwrap, json, csv from tabulate import TableFormat, tabulate # type: ignore
from tabulate import tabulate # type: ignore
from ..log import * from ..log import (ERR, INFO, WARNING, get_caller_pos, prio_gets_logged, slog, slog_m)
def rows_pretty(rows): # export def rows_pretty(rows): # export
if type(rows) == dict: if isinstance(rows, dict):
rows = [rows] rows = [rows]
out = [] out = []
for row in rows: for row in rows:
out.append(json.dumps(row, sort_keys=True, indent=4, default=str)) out.append(json.dumps(row, sort_keys = True, indent = 4, default = str))
return '\n'.join(out) return '\n'.join(out)
def rows_duplicates(rows, log_prio=INFO, caller=None): # export def rows_duplicates(rows, log_prio = INFO, caller = None): # export
def __equal(r1, r2): def __equal(r1, r2):
for col in set(r1.keys()) | set(r2.keys()): for col in set(r1.keys()) | set(r2.keys()):
if col in r1: if col in r1:
@ -25,11 +30,12 @@ def rows_duplicates(rows, log_prio=INFO, caller=None): # export
if r1[col] != r2[col]: if r1[col] != r2[col]:
return False return False
return True return True
ret = [] ret = []
last = len(rows) - 1 last = len(rows) - 1
i = last i = last
while last > 0: while last > 0:
for i in reversed(range(0, last-1)): for i in reversed(range(0, last - 1)):
if __equal(rows[last], rows[i]): if __equal(rows[last], rows[i]):
ret.append(last) ret.append(last)
last -= 1 last -= 1
@ -37,12 +43,15 @@ def rows_duplicates(rows, log_prio=INFO, caller=None): # export
last -= 1 last -= 1
return ret return ret
def rows_remove(rows, callback=None, candidates=None, log_prio=INFO, caller=None): # export def rows_remove( # export
rows, callback = None, candidates = None, log_prio = INFO, caller = None
):
def __is_remove_candidate(row): def __is_remove_candidate(row):
assert candidates is not None, 'Candidates is None'
for remove_row in candidates: for remove_row in candidates:
for col, val in row.items(): for col, val in row.items():
if not col in remove_row.keys(): if col not in remove_row.keys():
break break
if val != remove_row[col]: if val != remove_row[col]:
break break
@ -65,14 +74,14 @@ def rows_remove(rows, callback=None, candidates=None, log_prio=INFO, caller=None
remove.append(index) remove.append(index)
continue continue
for index in reversed(remove): for index in reversed(remove):
slog(log_prio, f'Removing row {rows[index]}', caller=caller) slog(log_prio, f'Removing row {rows[index]}', caller = caller)
del rows[index] del rows[index]
def rows_select(rows, rules): # export def rows_select(rows, rules): # export
ret = [] ret = []
for row in rows: for row in rows:
for rule in rules: for rule in rules:
if type(rule) == tuple(): if isinstance(rule, tuple):
search_rule = rule[0] search_rule = rule[0]
else: else:
search_rule = rule search_rule = rule
@ -84,7 +93,7 @@ def rows_select(rows, rules): # export
break break
return ret return ret
def rows_rewrite_regex(rows, rules): # export def rows_rewrite_regex(rows, rules): # export
for row in rows: for row in rows:
for rule in rules: for rule in rules:
try: try:
@ -93,14 +102,25 @@ def rows_rewrite_regex(rows, rules): # export
break break
else: else:
for exec_col_name, exec_val in rule[1].items(): for exec_col_name, exec_val in rule[1].items():
slog(INFO, f'Rewriting {row} {row.get(exec_col_name)} -> {exec_val}') slog(
INFO,
f'Rewriting {row} {row.get(exec_col_name)} -> {exec_val}'
)
row[exec_col_name] = exec_val row[exec_col_name] = exec_val
except Exception as e: except Exception as e:
slog(ERR, f'Failed to run rule {rule} against {row} ({e})') slog(ERR, f'Failed to run rule {rule} against {row} ({e})')
raise raise
def rows_check_not_null(rows, keys, log_prio=WARNING, buf=None, stat_key=None, throw=True, caller=None): # export def rows_check_not_null( # export
if type(keys) == str: rows,
keys,
log_prio = WARNING,
buf = None,
stat_key = None,
throw = True,
caller = None
):
if isinstance(keys, str):
keys = [keys] keys = [keys]
if caller is None: if caller is None:
caller = get_caller_pos() caller = get_caller_pos()
@ -113,11 +133,11 @@ def rows_check_not_null(rows, keys, log_prio=WARNING, buf=None, stat_key=None, t
for row in rows: for row in rows:
for key in keys: for key in keys:
if row.get(key) is None: if row.get(key) is None:
slog(log_prio, f'{key} is missing in row {row}', caller=caller) slog(log_prio, f'{key} is missing in row {row}', caller = caller)
buf.append(row) buf.append(row)
if stat_key is not None: if stat_key is not None:
stat_val = row[stat_key] stat_val = row[stat_key]
if not stat_val in stats.keys(): if stat_val not in stats.keys():
stats[stat_val] = 0 stats[stat_val] = 0
stats[stat_val] += 1 stats[stat_val] += 1
count += 1 count += 1
@ -125,14 +145,27 @@ def rows_check_not_null(rows, keys, log_prio=WARNING, buf=None, stat_key=None, t
if count > 0: if count > 0:
if stat_key is not None: if stat_key is not None:
i = 0 i = 0
for k, v in reversed(sorted(stats.items(), key=lambda item: item[1])): for k, v in reversed(sorted(stats.items(), key = lambda item: item[1])):
i += 1 i += 1
slog(ERR, f'{i:>3}. {k:<23}: {v}', caller=caller) slog(ERR, f'{i:>3}. {k:<23}: {v}', caller = caller)
if throw: if throw:
raise Exception(f'Found {count} rows violating null-constraint for keys {keys}') raise Exception(
f'Found {count} rows violating null-constraint for keys {keys}'
)
return buf return buf
def rows_dumps(rows, log_prio=INFO, caller=None, use_cols=None, skip_cols=None, table_name=None, out_path='log', heading=None, lead=None, tablefmt=None): # export def rows_dumps( # export
rows,
log_prio = INFO,
caller = None,
use_cols = None,
skip_cols = None,
table_name = None,
out_path = 'log',
heading = None,
lead = None,
tablefmt = None
):
headers = 'keys' headers = 'keys'
dump_rows = rows dump_rows = rows
@ -152,20 +185,21 @@ def rows_dumps(rows, log_prio=INFO, caller=None, use_cols=None, skip_cols=None,
new_row[col] = val new_row[col] = val
new_dump_rows.append(new_row) new_dump_rows.append(new_row)
dump_rows = new_dump_rows dump_rows = new_dump_rows
out = header = footer = "" header = footer = ""
match tablefmt: match tablefmt:
case 'html': case 'html':
if heading is not None: if heading is not None:
heading = f'<h1>{heading}</h1>\n' heading = f'<h1>{heading}</h1>\n'
if type(lead) == str: if isinstance(lead, str):
lead = f'<div class="lead">\n {lead}\n</div>\n' lead = f'<div class="lead">\n {lead}\n</div>\n'
elif type(lead) == list: elif isinstance(lead, list):
l = '<ul>\n' lst = '<ul>\n'
for li in lead: for li in lead:
l += f'<li>{li}</li>\n' lst += f'<li>{li}</li>\n'
l += '</ul>\n' lst += '</ul>\n'
lead = l lead = lst
header=textwrap.dedent('''\ header = textwrap.dedent(
'''\
<html> <html>
<head> <head>
<script src="https://ajax.googleapis.com/ajax/libs/jquery/2.1.1/jquery.min.js"></script> <script src="https://ajax.googleapis.com/ajax/libs/jquery/2.1.1/jquery.min.js"></script>
@ -185,30 +219,47 @@ def rows_dumps(rows, log_prio=INFO, caller=None, use_cols=None, skip_cols=None,
</style> </style>
</head> </head>
<body> <body>
''') '''
footer = textwrap.dedent(''' )
footer = textwrap.dedent(
'''
</body> </body>
</html> </html>
''') '''
)
case _: case _:
if type(heading) == str: if isinstance(heading, str):
heading = '\n' + heading heading = '\n' + heading
if type(lead) == str: if isinstance(lead, str):
pass pass
elif type(lead) == list: elif isinstance(lead, list):
l ='' lst = ''
for li in lead: for li in lead:
l += f' - {li}\n' lst += f' - {li}\n'
lead = '\n\n' + l + '\n' lead = '\n\n' + lst + '\n'
if heading is None: if heading is None:
heading = '' heading = ''
if lead is None: if lead is None:
lead = '' lead = ''
return header + heading + lead + tabulate(dump_rows, headers=headers, tablefmt=tablefmt) + footer assert isinstance(tablefmt, str) or isinstance(tablefmt, TableFormat), 'tablefmt'
return header + heading + lead + tabulate(
dump_rows, headers = headers, tablefmt = tablefmt
) + footer
def rows_dump(rows, log_prio=INFO, caller=None, use_cols=None, skip_cols=None, table_name=None, out_path='log', heading=None, lead=None, tablefmt=None): # export def rows_dump( # export
rows,
log_prio = INFO,
caller = None,
use_cols = None,
skip_cols = None,
table_name = None,
out_path = 'log',
heading = None,
lead = None,
tablefmt = None
):
if not prio_gets_logged(log_prio): if not prio_gets_logged(log_prio):
return return
@ -218,18 +269,35 @@ def rows_dump(rows, log_prio=INFO, caller=None, use_cols=None, skip_cols=None, t
if tablefmt is None and out_path: if tablefmt is None and out_path:
tablefmt = os.path.splitext(out_path)[1][1:] tablefmt = os.path.splitext(out_path)[1][1:]
out = rows_dumps(rows, log_prio=log_prio, caller=caller, use_cols=use_cols, skip_cols=skip_cols, table_name=table_name, heading=heading, lead=lead, tablefmt=tablefmt) out = rows_dumps(
rows,
log_prio = log_prio,
caller = caller,
use_cols = use_cols,
skip_cols = skip_cols,
table_name = table_name,
heading = heading,
lead = lead,
tablefmt = tablefmt
)
match out_path: match out_path:
case 'log': case 'log':
slog_m(log_prio, out, caller=caller) slog_m(log_prio, out, caller = caller)
case _: case _:
with open(out_path, 'w') as fp: with open(out_path, 'w') as fp:
fp.write(out) fp.write(out)
def rows_to_csv(rows, use_tmpfile=False): # export def rows_to_csv(rows, use_tmpfile = False): # export
def __write(rows, out): def __write(rows, out):
writer = csv.DictWriter(out, fieldnames=field_names, delimiter=';', quotechar='"', quoting=csv.QUOTE_NONNUMERIC) writer = csv.DictWriter(
out,
fieldnames = field_names,
delimiter = ';',
quotechar = '"',
quoting = csv.QUOTE_NONNUMERIC
)
writer.writeheader() writer.writeheader()
for row in rows: for row in rows:
writer.writerow(row) writer.writerow(row)
@ -244,7 +312,7 @@ def rows_to_csv(rows, use_tmpfile=False): # export
__write(rows, out) __write(rows, out)
return out.getvalue() return out.getvalue()
import tempfile import tempfile
with tempfile.TemporaryFile(mode='w', newline='') as out: with tempfile.TemporaryFile(mode = 'w', newline = '') as out:
__write(rows, out) __write(rows, out)
out.seek(0) out.seek(0)
return out.read() return out.read()

View file

@ -1,17 +1,20 @@
# -*- coding: utf-8 -*- from __future__ import annotations
from typing import Optional, Any
import abc import abc
from .DataType import DataType from typing import TYPE_CHECKING, Any, Optional
from ...log import *
class Column(abc.ABC): # export from ...log import ERR, throw
def __init__(self, table, name, data_type: DataType): if TYPE_CHECKING:
from .DataType import DataType
from .Table import Table
class Column(abc.ABC): # export
def __init__(self, table: Table, name: str, data_type: DataType) -> None:
self.__name: str = name self.__name: str = name
self.__table: Any = table self.__table: Table = table
self.__is_nullable: Optional[bool] = None self.__is_nullable: Optional[bool] = None
self.__is_null_insertible: Optional[bool] = None self.__is_null_insertible: Optional[bool] = None
self.__is_primary_key: Optional[bool] = None self.__is_primary_key: Optional[bool] = None
@ -39,18 +42,18 @@ class Column(abc.ABC): # export
return False return False
return True return True
throw(ERR, f'Tried to compare column {self} to type {type(rhs)}: {rhs}') throw(ERR, f'Tried to compare column {self} to type {type(rhs)}: {rhs}')
return False # Unreachable but requested by mypy return False # Unreachable but requested by mypy
@property @property
def name(self) -> str: def name(self) -> str:
return self.__name return self.__name
@property @property
def data_type(self): def data_type(self) -> DataType:
return self.__data_type return self.__data_type
@property @property
def table(self) -> str: def table(self) -> Table:
return self.__table return self.__table
@property @property
@ -60,7 +63,7 @@ class Column(abc.ABC): # export
return self.__is_nullable return self.__is_nullable
@property @property
def is_null_insertible(self): def is_null_insertible(self) -> bool:
if self.__is_null_insertible is None: if self.__is_null_insertible is None:
ret = False ret = False
if self.is_nullable: if self.is_nullable:
@ -81,7 +84,9 @@ class Column(abc.ABC): # export
@property @property
def is_auto_increment(self) -> bool: def is_auto_increment(self) -> bool:
if self.__is_auto_increment is None: if self.__is_auto_increment is None:
self.__is_auto_increment = self.__name in self.__table.auto_increment_columns self.__is_auto_increment = (
self.__name in self.__table.auto_increment_columns
)
return self.__is_auto_increment return self.__is_auto_increment
@property @property
@ -113,8 +118,8 @@ class Column(abc.ABC): # export
def foreign_key(self, table) -> Optional[Any]: def foreign_key(self, table) -> Optional[Any]:
if self.__foreign_keys_by_table is None: if self.__foreign_keys_by_table is None:
self.__foreign_keys_by_table = dict() self.__foreign_keys_by_table = dict()
for col in self.foreign_keys: # type: ignore # Any not iterable for col in self.foreign_keys: # type: ignore # Any not iterable
assert(col.table.name not in self.__foreign_keys_by_table) assert (col.table.name not in self.__foreign_keys_by_table)
self.__foreign_keys_by_table[col.table.name] = col self.__foreign_keys_by_table[col.table.name] = col
table_name = table if isinstance(table, str) else table.name table_name = table if isinstance(table, str) else table.name
return self.__foreign_keys_by_table.get(table_name) return self.__foreign_keys_by_table.get(table_name)

View file

@ -1,20 +1,24 @@
# -*- coding: utf-8 -*- from typing import Optional, Any
from typing import Optional, Iterable, Any class ColumnSet: # export
class ColumnSet: # export def __init__(
self,
def __init__(self, *args: list[Any], columns: list[Any]=[], table: Optional[Any]=None, names: Optional[list[str]]=None): *args: list[Any],
columns: list[Any] = [],
table: Optional[Any] = None,
names: Optional[list[str]] = None
):
self.__columns: list[Any] = [*args] self.__columns: list[Any] = [*args]
self.__columns.extend(columns) self.__columns.extend(columns)
self.__table = table self.__table = table
if names is not None: if names is not None:
assert(table is not None) assert (table is not None)
for name in names: for name in names:
self.__columns.append(table.column(name)) self.__columns.append(table.column(name))
if self.__table is not None: if self.__table is not None:
for col in columns: for col in columns:
assert(col.table == self.__table) assert (col.table == self.__table)
def __len__(self): def __len__(self):
return len(self.__columns) return len(self.__columns)

View file

@ -1,15 +1,18 @@
# -*- coding: utf-8 -*- from __future__ import annotations
from typing import Optional, Any from typing import TYPE_CHECKING, Any, Optional
from ...log import * from ...log import WARNING, slog
from .ColumnSet import ColumnSet
from .SingleForeignKey import SingleForeignKey from .SingleForeignKey import SingleForeignKey
class CompositeForeignKey: # export if TYPE_CHECKING:
from .ColumnSet import ColumnSet
def __init__(self, child_col_set: ColumnSet, parent_col_set: ColumnSet): # TODO: Implement alternative ways to construct class CompositeForeignKey: # export
def __init__(
self, child_col_set: ColumnSet, parent_col_set: ColumnSet
): # TODO: Implement alternative ways to construct
def __table(s): def __table(s):
ret = None ret = None
@ -17,8 +20,8 @@ class CompositeForeignKey: # export
if ret is None: if ret is None:
ret = c.table ret = c.table
else: else:
assert(ret == c.table) assert (ret == c.table)
assert(ret is not None) assert (ret is not None)
return ret return ret
self.__child_col_set = child_col_set self.__child_col_set = child_col_set
@ -26,7 +29,7 @@ class CompositeForeignKey: # export
self.__child_table = __table(child_col_set) self.__child_table = __table(child_col_set)
self.__parent_table = __table(parent_col_set) self.__parent_table = __table(parent_col_set)
assert(len(self.__child_col_set) == len(self.__parent_col_set)) assert (len(self.__child_col_set) == len(self.__parent_col_set))
self.__len = len(self.__child_col_set) self.__len = len(self.__child_col_set)
self.__column_relations: Optional[list[SingleForeignKey]] = None self.__column_relations: Optional[list[SingleForeignKey]] = None
self.__parent_columns_by_child_column: Optional[dict[str, Any]] = None self.__parent_columns_by_child_column: Optional[dict[str, Any]] = None
@ -46,7 +49,12 @@ class CompositeForeignKey: # export
def __repr__(self): def __repr__(self):
ret = self.__table_rel_str() ret = self.__table_rel_str()
ret += ': ' + ', '.join([self.__cols_rel_str(rel.child_column, rel.parent_column) for rel in self.column_relations]) ret += ': ' + ', '.join(
[
self.__cols_rel_str(rel.child_column, rel.parent_column)
for rel in self.column_relations
]
)
return ret return ret
def __eq__(self, rhs): def __eq__(self, rhs):
@ -73,21 +81,25 @@ class CompositeForeignKey: # export
return self.__parent_col_set return self.__parent_col_set
def parent_column(self, child_column) -> Any: def parent_column(self, child_column) -> Any:
child_column_name = child_column if isinstance(child_column, str) else child_column.name child_column if isinstance(child_column, str) else child_column.name
if self.__parent_columns_by_child_column is None: if self.__parent_columns_by_child_column is None:
d: dict[str, Any] = {} d: dict[str, Any] = {}
assert(len(self.__child_col_set) == len(self.__parent_col_set)) assert (len(self.__child_col_set) == len(self.__parent_col_set))
for i in range(0, len(self.__child_col_set)): for i in range(0, len(self.__child_col_set)):
d[self.__child_col_set[i].name] = self.__parent_col_set[i] d[self.__child_col_set[i].name] = self.__parent_col_set[i]
self.__parent_columns_by_child_column = d self.__parent_columns_by_child_column = d
return self.__parent_columns_by_child_column[child_column] return self.__parent_columns_by_child_column[child_column]
def child_column(self, parent_column) -> Any: def child_column(self, parent_column) -> Any:
slog(WARNING, f'{self}: Looking for child column belonging to parent column "{parent_column}"') slog(
parent_column_name = parent_column if isinstance(parent_column, str) else parent_column.name WARNING,
f'{self}: Looking for child column belonging to parent column '
f'"{parent_column}"'
)
parent_column if isinstance(parent_column, str) else parent_column.name
if self.__child_columns_by_parent_column is None: if self.__child_columns_by_parent_column is None:
d: dict[str, Any] = {} d: dict[str, Any] = {}
assert(len(self.__parent_col_set) == len(self.__child_col_set)) assert (len(self.__parent_col_set) == len(self.__child_col_set))
for i in range(0, len(self.__parent_col_set)): for i in range(0, len(self.__parent_col_set)):
d[self.__parent_col_set[i].name] = self.__child_col_set[i] d[self.__parent_col_set[i].name] = self.__child_col_set[i]
self.__child_columns_by_parent_column = d self.__child_columns_by_parent_column = d
@ -98,6 +110,8 @@ class CompositeForeignKey: # export
ret = [] ret = []
if self.__column_relations is None: if self.__column_relations is None:
for i in range(0, self.__len): for i in range(0, self.__len):
ret.append(SingleForeignKey(self.__child_col_set[i], self.__parent_col_set[i])) ret.append(
SingleForeignKey(self.__child_col_set[i], self.__parent_col_set[i])
)
self.__column_relations = ret self.__column_relations = ret
return self.__column_relations return self.__column_relations

View file

@ -1,10 +1,8 @@
# -*- coding: utf-8 -*-
from typing import Optional
from enum import Enum, auto
from datetime import datetime from datetime import datetime
from enum import Enum, auto
from typing import Optional
from ...log import * from ...log import ERR, throw
class Id(Enum): class Id(Enum):
Integer = auto() Integer = auto()
@ -16,7 +14,7 @@ class Id(Enum):
Text = auto() Text = auto()
Invalid = auto() Invalid = auto()
def py_type(type_id: Id) -> type: # export def py_type(type_id: Id) -> type: # export
match type_id: match type_id:
case Id.Integer: case Id.Integer:
@ -38,14 +36,17 @@ def py_type(type_id: Id) -> type: # export
raise Exception(f'Unknown column type-id "{type_id}"') raise Exception(f'Unknown column type-id "{type_id}"')
class DataType: # export class DataType: # export
def __init__(self, type_id: Id, size: Optional[int]=None): def __init__(self, type_id: Id, size: Optional[int] = None):
if not isinstance(type_id, Id): if not isinstance(type_id, Id):
throw(ERR, f'Passed type id "{type_id}" with unsupported data type {type(type_id)}') throw(
ERR,
f'Passed type id "{type_id}" with unsupported data type {type(type_id)}'
)
if size is not None: if size is not None:
assert(isinstance(size, int)) assert (isinstance(size, int))
assert(size > 0) assert (size > 0)
self.__id = type_id self.__id = type_id
self.__size = size self.__size = size
@ -80,4 +81,4 @@ class DataType: # export
@property @property
def py_type_annotation(self) -> str: def py_type_annotation(self) -> str:
return self.py_type_str # FIXME: This is not always correct return self.py_type_str # FIXME: This is not always correct

View file

@ -1,20 +1,20 @@
# -*- coding: utf-8 -*- from __future__ import annotations
from typing import Optional, Iterable
import abc import abc
from ...log import * from typing import TYPE_CHECKING, Iterable, Optional
from .Table import Table from ...log import DEBUG, ERR, slog, throw
from .Column import Column
from .DataType import DataType
from .CompositeForeignKey import CompositeForeignKey
class Schema(abc.ABC): # export if TYPE_CHECKING:
from .Column import Column
from .CompositeForeignKey import CompositeForeignKey
from .Table import Table
class Schema(abc.ABC): # export
def __init__(self) -> None: def __init__(self) -> None:
self.___tables: Optional[list[Table]] = None self.___tables: Optional[dict[str, Table]] = None
self.__foreign_keys: Optional[list[CompositeForeignKey]] = None self.__foreign_keys: Optional[list[CompositeForeignKey]] = None
self.__access_defining_columns: Optional[list[str]] = None self.__access_defining_columns: Optional[list[str]] = None
@ -24,7 +24,7 @@ class Schema(abc.ABC): # export
ret = dict() ret = dict()
for name in self._table_names(): for name in self._table_names():
slog(DEBUG, f'Caching metadata for table "{name}"') slog(DEBUG, f'Caching metadata for table "{name}"')
assert(isinstance(name, str)) assert (isinstance(name, str))
ret[name] = self._table(name) ret[name] = self._table(name)
self.___tables = ret self.___tables = ret
return self.___tables return self.___tables
@ -39,7 +39,7 @@ class Schema(abc.ABC): # export
@abc.abstractmethod @abc.abstractmethod
def _table(self, name: str) -> Table: def _table(self, name: str) -> Table:
throw(ERR, "Called pure virtual base class method") throw(ERR, "Called pure virtual base class method")
return None # type: ignore return None # type: ignore
@abc.abstractmethod @abc.abstractmethod
def _foreign_keys(self) -> list[CompositeForeignKey]: def _foreign_keys(self) -> list[CompositeForeignKey]:
@ -50,7 +50,7 @@ class Schema(abc.ABC): # export
pass pass
@abc.abstractmethod @abc.abstractmethod
def _model_module_search_paths(self) -> list[tuple[str, type]]: def _model_module_search_paths(self) -> list[tuple[str, type]]:
pass pass
# ------ API to be called # ------ API to be called
@ -62,7 +62,7 @@ class Schema(abc.ABC): # export
yield from self.__tables.values() yield from self.__tables.values()
def __repr__(self): def __repr__(self):
return '|'.join([table.name for table in self.__tables]) return '|'.join([table.name for table in self.__tables.values()])
def __getitem__(self, index): def __getitem__(self, index):
return self.__tables[index] return self.__tables[index]
@ -90,13 +90,13 @@ class Schema(abc.ABC): # export
def table(self, name: str) -> Table: def table(self, name: str) -> Table:
return self.__tables[name] return self.__tables[name]
def table_by_model_name(self, name: str, throw=False) -> Table: def table_by_model_name(self, name: str, throw = False) -> Table:
for table in self.__tables.values(): for table in self.__tables.values():
if table.model_name == name: if table.model_name == name:
return table return table
if throw: if throw:
raise Exception(f'Table "{name}" not found in database metadata') raise Exception(f'Table "{name}" not found in database metadata')
return None # type: ignore return None # type: ignore
def primary_keys(self, table_name: str) -> Iterable[str]: def primary_keys(self, table_name: str) -> Iterable[str]:
return self.__tables[table_name].primary_keys return self.__tables[table_name].primary_keys
@ -105,5 +105,5 @@ class Schema(abc.ABC): # export
return self.__tables[table_name].columns return self.__tables[table_name].columns
@property @property
def model_module_search_paths(self) -> list[tuple[str, type]]: def model_module_search_paths(self) -> list[tuple[str, type]]:
return self._model_module_search_paths() return self._model_module_search_paths()

View file

@ -1,9 +1,9 @@
# -*- coding: utf-8 -*- from __future__ import annotations
from typing import Optional, Any from typing import TYPE_CHECKING
from .Column import Column if TYPE_CHECKING:
from .ColumnSet import ColumnSet from .Column import Column
class SingleForeignKey: class SingleForeignKey:
@ -19,7 +19,10 @@ class SingleForeignKey:
yield from self.__iterable yield from self.__iterable
def __repr__(self): def __repr__(self):
return f'{self.__child_col.table.name}.{self.__child_col.name} -> {self.__parent_col.table.name}.{self.__parent_col.name}' return (
f'{self.__child_col.table.name}.{self.__child_col.name} -> '
'{self.__parent_col.table.name}.{self.__parent_col.name}'
)
def __getitem__(self, index): def __getitem__(self, index):
return self.__iterable[index] return self.__iterable[index]

View file

@ -1,47 +1,51 @@
# -*- coding: utf-8 -*- from __future__ import annotations
from typing import Optional, Union, Iterable, Self, Any # TODO: Need any for many things, as I can't figure out how to avoid circular imports from here
import abc import abc
import re
from collections import OrderedDict from collections import OrderedDict
from typing import TYPE_CHECKING
from urllib.parse import quote_plus from urllib.parse import quote_plus
from ...log import * from ...log import ERR, WARNING, slog, throw
from ...misc import load_class from ...misc import load_class
from .ColumnSet import ColumnSet
from .DataType import DataType
from .CompositeForeignKey import CompositeForeignKey
from .Column import Column from .Column import Column
from .ColumnSet import ColumnSet
class Table(abc.ABC): # export if TYPE_CHECKING:
from typing import Any, Iterable, Optional, Self, Union
from .CompositeForeignKey import CompositeForeignKey
from .DataType import DataType
class Table(abc.ABC): # export
def __init__(self, schema, name: str): def __init__(self, schema, name: str):
assert(isinstance(name, str)) assert (isinstance(name, str))
self.__schema = schema self.__schema = schema
self.__name = name self.__name = name
self.___columns: Optional[OrderedDict[str, Any]] = None self.___columns: Optional[OrderedDict[str, Any]] = None
self.___foreign_key_parent_tables: Optional[OrderedDict[str, Any]] = None self.___foreign_key_parent_tables: Optional[OrderedDict[str, Any]] = None
self.__primary_keys: Optional[Iterable[str]] = None self.__primary_keys: Optional[Iterable[str]] = None
self.__unique_constraints: Optional[list[ColumnSet]] = None self.__unique_constraints: Optional[list[ColumnSet]] = None
self.__foreign_key_constraints: Optional[list[CompositeForeignKey]] = None self.__foreign_key_constraints: Optional[list[CompositeForeignKey]] = None
self.__nullable_columns: Optional[Iterable[str]] = None self.__nullable_columns: Optional[Iterable[str]] = None
self.__non_nullable_columns: Optional[Iterable[str]] = None self.__non_nullable_columns: Optional[Iterable[str]] = None
self.__null_insertible_columns: Optional[Iterable[str]] = None self.__null_insertible_columns: Optional[Iterable[str]] = None
self.__not_null_insertible_columns: Optional[Iterable[str]] = None self.__not_null_insertible_columns: Optional[Iterable[str]] = None
self.__log_columns: Optional[Iterable[str]] = None self.__log_columns: Optional[Iterable[str]] = None
self.__edit_columns: Optional[Iterable[str]] = None self.__edit_columns: Optional[Iterable[str]] = None
self.__translate_columns: Optional[Iterable[str]] = None self.__translate_columns: Optional[Iterable[str]] = None
self.__display_columns: Optional[Iterable[str]] = None self.__display_columns: Optional[Iterable[str]] = None
self.__default_sort_columns: Optional[Iterable[str]] = None self.__default_sort_columns: Optional[Iterable[str]] = None
self.__column_default: Optional[dict[str, Any]] = None self.__column_default: Optional[dict[str, Any]] = None
self.__base_location_rule: Optional[Iterable[str]] = None self.__base_location_rule: Optional[Iterable[str]] = None
self.__location_rule: Optional[Iterable[str]] = None self.__location_rule: Optional[Iterable[str]] = None
self.__row_location_rule: Optional[Iterable[str]] = None self.__row_location_rule: Optional[Iterable[str]] = None
self.__add_row_location_rule: Optional[Iterable[str]] = None self.__add_row_location_rule: Optional[Iterable[str]] = None
self.___add_child_row_location_rules: Optional[dict[str, str]] = None self.___add_child_row_location_rules: Optional[dict[str, str]] = None
self.__foreign_keys_to_parent_table: Optional[OrderedDict[str, Any]] = None self.__foreign_keys_to_parent_table: Optional[OrderedDict[str, Any]] = None
self.__relationships: Optional[list[tuple[str, Self]]] = None self.__relationships: Optional[list[tuple[str, Self]]] = None
self.__model_class: Optional[Any] = None self.__model_class: Optional[Any] = None
@ -61,7 +65,8 @@ class Table(abc.ABC): # export
if self.___foreign_key_parent_tables is None: if self.___foreign_key_parent_tables is None:
self.___foreign_key_parent_tables = OrderedDict() self.___foreign_key_parent_tables = OrderedDict()
for cfk in self.foreign_key_constraints: for cfk in self.foreign_key_constraints:
self.___foreign_key_parent_tables[cfk.parent_table.name] = cfk.parent_table self.___foreign_key_parent_tables[cfk.parent_table.name
] = cfk.parent_table
return self.___foreign_key_parent_tables return self.___foreign_key_parent_tables
@property @property
@ -77,12 +82,12 @@ class Table(abc.ABC): # export
def __add_child_row_location_rules(self) -> dict[str, str]: def __add_child_row_location_rules(self) -> dict[str, str]:
if self.___add_child_row_location_rules is None: if self.___add_child_row_location_rules is None:
ret: dict[str, str] = {} ret: dict[str, str] = {}
for foreign_table_name, foreign_table in self.__relationship_by_foreign_table.items(): for table_name, table in self.__relationship_by_foreign_table.items():
if len([self.foreign_keys_to_parent_table(foreign_table)]): if len([self.foreign_keys_to_parent_table(table)]):
rule = self._add_child_row_location_rule(foreign_table_name) rule = self._add_child_row_location_rule(table_name)
if rule is None: if rule is None:
continue continue
ret[foreign_table_name] = rule ret[table_name] = rule
self.___add_child_row_location_rules = ret self.___add_child_row_location_rules = ret
return self.___add_child_row_location_rules return self.___add_child_row_location_rules
@ -108,7 +113,7 @@ class Table(abc.ABC): # export
return False return False
return True return True
throw(ERR, f'Tried to compare table {self} to type {type(rhs)}: {rhs}') throw(ERR, f'Tried to compare table {self} to type {type(rhs)}: {rhs}')
return False # Unreachable but requested by mypy return False # Unreachable but requested by mypy
def __hash__(self) -> int: def __hash__(self) -> int:
return hash(self.name) return hash(self.name)
@ -167,8 +172,9 @@ class Table(abc.ABC): # export
slog(WARNING, f'Returning None model name for table {self.name}') slog(WARNING, f'Returning None model name for table {self.name}')
return None return None
def _model_module_search_paths(self) -> list[tuple[str, type]]: def _model_module_search_paths(self) -> list[tuple[str, type]]:
return self.schema.model_module_search_paths # Fall back to Schema-global default # Fall back to Schema-global default
return self.schema.model_module_search_paths
@abc.abstractmethod @abc.abstractmethod
def _query_name(self) -> str: def _query_name(self) -> str:
@ -190,7 +196,9 @@ class Table(abc.ABC): # export
for col in self.__schema.access_defining_columns: for col in self.__schema.access_defining_columns:
if col in self.primary_keys: if col in self.primary_keys:
ret += f'/<{col}>' ret += f'/<{col}>'
ret += self.base_location_rule base = self.base_location_rule
if base is not None:
ret += base if isinstance(base, str) else 'what-goes-here?'.join(base)
return ret return ret
def _row_location_rule(self) -> Optional[str]: def _row_location_rule(self) -> Optional[str]:
@ -261,7 +269,7 @@ class Table(abc.ABC): # export
return None return None
pattern = r'^' + model_name + '$' pattern = r'^' + model_name + '$'
for module_path, base_class in self._model_module_search_paths(): for module_path, base_class in self._model_module_search_paths():
ret = load_class(module_path, base_class, class_name_filter=pattern) ret = load_class(module_path, base_class, class_name_filter = pattern)
if ret is not None: if ret is not None:
self.__model_class = ret self.__model_class = ret
break break
@ -288,8 +296,8 @@ class Table(abc.ABC): # export
return self.__location_rule return self.__location_rule
def location(self, *args, **kwargs): def location(self, *args, **kwargs):
ret = self.location_rule ret = str(self.location_rule)
for token, val in kwargs.items(): # FIXME: Poor man's row location assembly for token, val in kwargs.items(): # FIXME: Poor man's row location assembly
ret = re.sub(f'<{token}>', quote_plus(quote_plus(str(val))), ret) ret = re.sub(f'<{token}>', quote_plus(quote_plus(str(val))), ret)
return ret return ret
@ -300,9 +308,9 @@ class Table(abc.ABC): # export
return self.__row_location_rule return self.__row_location_rule
def row_location(self, *args, **kwargs): def row_location(self, *args, **kwargs):
ret = self.row_location_rule ret = str(self.row_location_rule)
for col in self.primary_keys: for col in self.primary_keys:
if col in kwargs: # FIXME: Poor man's row location assembly if col in kwargs: # FIXME: Poor man's row location assembly
ret = re.sub(f'<{col}>', quote_plus(quote_plus(str(kwargs[col]))), ret) ret = re.sub(f'<{col}>', quote_plus(quote_plus(str(kwargs[col]))), ret)
return ret return ret
@ -313,9 +321,9 @@ class Table(abc.ABC): # export
return self.__add_row_location_rule return self.__add_row_location_rule
def add_row_location(self, *args, **kwargs) -> Optional[str]: def add_row_location(self, *args, **kwargs) -> Optional[str]:
ret = self.add_row_location_rule ret = str(self.add_row_location_rule)
for col in self.primary_keys: for col in self.primary_keys:
if col in kwargs: # FIXME: Poor man's row location assembly if col in kwargs: # FIXME: Poor man's row location assembly
ret = re.sub(f'<{col}>', quote_plus(quote_plus(str(kwargs[col]))), ret) ret = re.sub(f'<{col}>', quote_plus(quote_plus(str(kwargs[col]))), ret)
return ret return ret
@ -323,12 +331,14 @@ class Table(abc.ABC): # export
def add_child_row_location_rules(self) -> Iterable[str]: def add_child_row_location_rules(self) -> Iterable[str]:
return self.__add_child_row_location_rules.values() return self.__add_child_row_location_rules.values()
def add_child_row_location_rule(self, child_table: Union[Self, str]) -> Optional[str]: def add_child_row_location_rule(self, child_table: Union[Self,
str]) -> Optional[str]:
if isinstance(child_table, Table): if isinstance(child_table, Table):
child_table = child_table.name child_table = child_table.name
return self.__add_child_row_location_rules.get(child_table) return self.__add_child_row_location_rules.get(child_table)
def add_child_row_location(self, parent_table: Union[Self, str], **kwargs) -> Optional[str]: def add_child_row_location(self, parent_table: Union[Self, str],
**kwargs) -> Optional[str]:
ret = self.add_child_row_location_rule(parent_table) ret = self.add_child_row_location_rule(parent_table)
if isinstance(parent_table, str): if isinstance(parent_table, str):
parent_table = self.schema[parent_table] parent_table = self.schema[parent_table]
@ -337,7 +347,11 @@ class Table(abc.ABC): # export
for cfk in self.foreign_keys_to_parent_table(parent_table): for cfk in self.foreign_keys_to_parent_table(parent_table):
for fk in cfk: for fk in cfk:
if fk.parent_column.name in kwargs: if fk.parent_column.name in kwargs:
ret = re.sub(f'<{fk.child_column.name}>', quote_plus(quote_plus(str(kwargs[fk.parent_column.name]))), ret) ret = re.sub(
f'<{fk.child_column.name}>',
quote_plus(quote_plus(str(kwargs[fk.parent_column.name]))),
ret
)
return ret return ret
@property @property
@ -425,7 +439,7 @@ class Table(abc.ABC): # export
impl = self._unique_constraints() impl = self._unique_constraints()
if impl is not None: if impl is not None:
for columns in impl: for columns in impl:
ret.append(ColumnSet(columns=columns)) ret.append(ColumnSet(columns = columns))
self.__unique_constraints = ret self.__unique_constraints = ret
return self.__unique_constraints return self.__unique_constraints
@ -443,7 +457,8 @@ class Table(abc.ABC): # export
def foreign_key_parent_tables(self): def foreign_key_parent_tables(self):
return self.__foreign_key_parent_tables.values() return self.__foreign_key_parent_tables.values()
def foreign_keys_to_parent_table(self, parent_table) -> Iterable[CompositeForeignKey]: def foreign_keys_to_parent_table(self,
parent_table) -> Iterable[CompositeForeignKey]:
if self.__foreign_keys_to_parent_table is None: if self.__foreign_keys_to_parent_table is None:
self.__foreign_keys_to_parent_table = OrderedDict() self.__foreign_keys_to_parent_table = OrderedDict()
for cfk in self.foreign_key_constraints: for cfk in self.foreign_key_constraints:
@ -451,8 +466,12 @@ class Table(abc.ABC): # export
if pt not in self.__foreign_keys_to_parent_table: if pt not in self.__foreign_keys_to_parent_table:
self.__foreign_keys_to_parent_table[pt] = [] self.__foreign_keys_to_parent_table[pt] = []
self.__foreign_keys_to_parent_table[pt].append(cfk) self.__foreign_keys_to_parent_table[pt].append(cfk)
parent_table_name = parent_table if isinstance(parent_table, str) else parent_table.name parent_table_name = parent_table if isinstance(
return self.__foreign_keys_to_parent_table[parent_table_name] if parent_table_name in self.__foreign_keys_to_parent_table else [] parent_table, str
) else parent_table.name
return self.__foreign_keys_to_parent_table[
parent_table_name
] if parent_table_name in self.__foreign_keys_to_parent_table else []
@property @property
def relationships(self) -> list[tuple[str, Self]]: def relationships(self) -> list[tuple[str, Self]]:

View file

@ -1,12 +1,18 @@
# -*- coding: utf-8 -*- from __future__ import annotations
from .Schema import Schema from typing import TYPE_CHECKING
from ...log import * from ...log import NOTICE, slog
def check_schema(schema: Schema): # export if TYPE_CHECKING:
from .Schema import Schema
def check_schema(schema: Schema): # export
slog(NOTICE, f'There are {len(schema)} tables in the database') slog(NOTICE, f'There are {len(schema)} tables in the database')
for cfk in schema.foreign_key_constraints: for cfk in schema.foreign_key_constraints:
for fk in cfk: for fk in cfk:
if fk.child_column.data_type != fk.parent_column.data_type: if fk.child_column.data_type != fk.parent_column.data_type:
raise Exception(f'Type mismatch in foreign key {fk}: {fk.child_column.data_type} != {fk.parent_column.data_type}') raise Exception(
f'Type mismatch in foreign key {fk}: {fk.child_column.data_type} '
f'!= {fk.parent_column.data_type}'
)

View file

@ -1,29 +1,35 @@
# -*- coding: utf-8 -*- from __future__ import annotations
from collections.abc import Callable
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from ...log import * from typing import TYPE_CHECKING
class MapAttr2Shape: # export if TYPE_CHECKING:
from collections.abc import Callable
from typing import Any
def __init__(self, mappings: dict[str, str|Callable[[dict[str, str]], str]]|None=None) -> None: class MapAttr2Shape: # export
def __init__(
self,
mappings: dict[str, str | Callable[[dict[str, str]], str]] | None = None
) -> None:
self.__mappings = mappings if mappings is not None else {} self.__mappings = mappings if mappings is not None else {}
self.__shape_node_key = 'd25' self.__shape_node_key = 'd25'
self.__ns_gml = "http://graphml.graphdrawing.org/xmlns" self.__ns_gml = "http://graphml.graphdrawing.org/xmlns"
self.__ns = { self.__ns = {
# -- Standard GraphML # -- Standard GraphML
"": self.__ns_gml, "": self.__ns_gml,
"xsi": "http://www.w3.org/2001/XMLSchema-instance", "xsi": "http://www.w3.org/2001/XMLSchema-instance",
"xsi:schemaLocation": "http://graphml.graphdrawing.org/xmlns http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd", "xsi:schemaLocation":
"http://graphml.graphdrawing.org/xmlns http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd",
# -- YWorks GraphML # -- YWorks GraphML
"java": "http://www.yworks.com/xml/yfiles-common/1.0/java", "java": "http://www.yworks.com/xml/yfiles-common/1.0/java",
"sys": "http://www.yworks.com/xml/yfiles-common/markup/primitives/2.0", "sys": "http://www.yworks.com/xml/yfiles-common/markup/primitives/2.0",
"x": "http://www.yworks.com/xml/yfiles-common/markup/2.0", "x": "http://www.yworks.com/xml/yfiles-common/markup/2.0",
"y": "http://www.yworks.com/xml/graphml", "y": "http://www.yworks.com/xml/graphml",
"yed": "http://www.yworks.com/xml/yed/3", "yed": "http://www.yworks.com/xml/yed/3",
} }
# https://stackoverflow.com/questions/4997848/ # https://stackoverflow.com/questions/4997848/
for name, url in self.__ns.items(): for name, url in self.__ns.items():
@ -72,7 +78,7 @@ class MapAttr2Shape: # export
ns, tag = tag.split(':') ns, tag = tag.split(':')
tag = '{' + self.__ns[ns] + '}' + tag tag = '{' + self.__ns[ns] + '}' + tag
attrib = content.get('a') or {} attrib = content.get('a') or {}
el = ET.Element(tag, attrib=attrib) el = ET.Element(tag, attrib = attrib)
text = content.get('t') text = content.get('t')
if text is not None: if text is not None:
el.text = text el.text = text
@ -81,10 +87,7 @@ class MapAttr2Shape: # export
if children is not None: if children is not None:
__add(el, children) __add(el, children)
default_values = { default_values = {'color': '#FFCC00', 'text': ''}
'color': '#FFCC00',
'text': ''
}
values = {} values = {}
for key, default in default_values.items(): for key, default in default_values.items():
@ -98,11 +101,11 @@ class MapAttr2Shape: # export
continue continue
mapped = mapping(self.__attribs(node, keys)) mapped = mapping(self.__attribs(node, keys))
values[key] = mapped or default values[key] = mapped or default
except: except Exception:
pass pass
color = values['color'] color = values['color']
text = values['text'] text = values['text']
has_text = 'true' if text else 'false' has_text = 'true' if text else 'false'
width_text = round(len(text) * 5.75, 5) if text else 0 width_text = round(len(text) * 5.75, 5) if text else 0
@ -110,61 +113,89 @@ class MapAttr2Shape: # export
shape = { shape = {
'data': { 'data': {
'a': {'key': self.__shape_node_key}, 'a': {
'key': self.__shape_node_key
},
'c': { 'c': {
'y:ShapeNode': { 'y:ShapeNode': {
'a': {}, 'a': {},
'c': { 'c': {
'y:Geometry': {'a': {'height': '30.0', 'width': str(width_box), 'x': str(-(width_box / 2)), 'y':' -15.0'}}, 'y:Geometry': {
'y:Fill': {'a': {'color': color, 'transparent': 'false'}}, 'a': {
'y:BorderStyle': {'a': {'color': '#000000', 'raised': 'false', 'type': 'line', 'width': '1.0'}}, 'height': '30.0',
'width': str(width_box),
'x': str(-(width_box / 2)),
'y': ' -15.0'
}
},
'y:Fill': {
'a': {
'color': color, 'transparent': 'false'
}
},
'y:BorderStyle': {
'a': {
'color': '#000000',
'raised': 'false',
'type': 'line',
'width': '1.0'
}
},
'y:NodeLabel': { 'y:NodeLabel': {
'a': { 'a': {
'alignment': 'center', 'alignment': 'center',
'autoSizePolicy': 'content', 'autoSizePolicy': 'content',
'fontFamily': 'Dialog', 'fontFamily': 'Dialog',
'fontSize': '12', 'fontSize': '12',
'fontStyle': 'plain', 'fontStyle': 'plain',
'hasBackgroundColor': 'false', 'hasBackgroundColor': 'false',
'hasLineColor': 'false', 'hasLineColor': 'false',
'hasText': has_text, 'hasText': has_text,
'height': '18', 'height': '18',
'horizontalTextPosition': 'center', 'horizontalTextPosition': 'center',
'iconTextGap': '4', 'iconTextGap': '4',
'modelName': 'custom', 'modelName': 'custom',
'textColor': '#000000', 'textColor': '#000000',
'verticalTextPosition': 'bottom', 'verticalTextPosition': 'bottom',
'visible': 'true', 'visible': 'true',
'width': str(width_text), 'width': str(width_text),
'x': '13.0', 'x': '13.0',
'y': '13.0', 'y': '13.0',
}, },
'c': { 'c': {
'y:LabelModel': { 'y:LabelModel': {
'c': { 'c': {
'y:SmartNodeLabelModel': {'a': {'distance': '4.0'}} 'y:SmartNodeLabelModel': {
}, 'a': {
}, 'distance': '4.0'
'y:ModelParameter': { }
'c': { }
'y:SmartNodeLabelModelParameter': { },
'a': { },
'labelRatioX':'0.0', 'y:ModelParameter': {
'labelRatioY': '0.0', 'c': {
'nodeRatioX': '0.0', 'y:SmartNodeLabelModelParameter': {
'nodeRatioY': '0.0', 'a': {
'offsetX': '0.0', 'labelRatioX': '0.0',
'offsetY': '0.0', 'labelRatioY': '0.0',
'upX': '0.0', 'nodeRatioX': '0.0',
'upY': '-1.0', 'nodeRatioY': '0.0',
} 'offsetX': '0.0',
} 'offsetY': '0.0',
} 'upX': '0.0',
} 'upY': '-1.0',
}
}
}
}
}, },
't': text 't': text
}, },
'y:Shape': {'a': {'type': 'rectangle'}} 'y:Shape': {
'a': {
'type': 'rectangle'
}
}
} }
} }
} }
@ -175,17 +206,17 @@ class MapAttr2Shape: # export
def __massage_nodes(self, root) -> None: def __massage_nodes(self, root) -> None:
keys = self.__keys(root) keys = self.__keys(root)
graph = root.find(f'graph', self.__ns) graph = root.find('graph', self.__ns)
for node in graph: for node in graph:
self.__massage_node(node, keys) self.__massage_node(node, keys)
def run(self, path_in: str, path_out: str) -> None: def run(self, path_in: str, path_out: str) -> None:
parser = ET.XMLParser(encoding="utf-8") parser = ET.XMLParser(encoding = "utf-8")
tree = ET.parse(path_in, parser=parser) tree = ET.parse(path_in, parser = parser)
root = tree.getroot() root = tree.getroot()
self.__add_key_nodegraphics(root) self.__add_key_nodegraphics(root)
self.__massage_nodes(root) self.__massage_nodes(root)
ET.indent(tree, space=' ', level=0) ET.indent(tree, space = ' ', level = 0)
tree.write(path_out, xml_declaration=True, encoding='utf-8') tree.write(path_out, xml_declaration = True, encoding = 'utf-8')

View file

@ -1,36 +1,47 @@
# -*- coding: utf-8 -*- from __future__ import annotations
import copy
import getpass
import pathlib
import ldap, getpass, pathlib, copy
from ldap.schema.models import ObjectClass
from enum import Flag, auto from enum import Flag, auto
import networkx as nx from typing import TYPE_CHECKING, Any, Self
from typing import Any, Self
from collections.abc import Callable
from .Config import Config as BaseConfig import ldap # type: ignore[import-untyped]
from .log import * import networkx as nx # type: ignore[import-untyped]
from ldap.schema.models import ObjectClass # type: ignore[import-untyped]
from .log import ERR, INFO, WARNING, slog
if TYPE_CHECKING:
from collections.abc import Callable
from .Config import Config as BaseConfig
class Config: class Config:
def __init__(self, external: BaseConfig|None=None):
def __init__(self, external: BaseConfig | None = None):
self.__external = external self.__external = external
for attr in ['ldap_uri', 'bind_dn', 'bind_pw', 'base_dn']: for attr in ['ldap_uri', 'bind_dn', 'bind_pw', 'base_dn']:
setattr(self, '_Config__' + attr, None) setattr(self, '_Config__' + attr, None)
def __get(self, key: str, default: str): def __get(self, key: str, default: str | None):
if not self.__external: if not self.__external:
return default return default
return self.__external.value(key, default=default) return self.__external.value(key, default = default)
@property @property
def ldap_uri(self): def ldap_uri(self):
if self.__ldap_uri is None: if self.__ldap_uri is None:
for key in ['ldap_uri', 'uri']: for key in ['ldap_uri', 'uri']:
self.__ldap_uri = self.__get(key, default=None) self.__ldap_uri = self.__get(key, default = None)
if self.__ldap_uri is not None: if self.__ldap_uri is not None:
break break
else: else:
self.__ldap_uri = 'ldap://ldap.janware.com' self.__ldap_uri = 'ldap://ldap.janware.com'
return self.__ldap_uri return self.__ldap_uri
@ldap_uri.setter @ldap_uri.setter
def ldap_uri(self, rhs): def ldap_uri(self, rhs):
self.__ldap_uri = rhs self.__ldap_uri = rhs
@ -38,8 +49,12 @@ class Config:
@property @property
def bind_dn(self): def bind_dn(self):
if self.__bind_dn is None: if self.__bind_dn is None:
self.__bind_dn = self.__get('bind_dn', default=f'uid={getpass.getuser()},ou=users,dc=jannet,dc=de') self.__bind_dn = self.__get(
'bind_dn',
default = f'uid={getpass.getuser()},ou=users,dc=jannet,dc=de'
)
return self.__bind_dn return self.__bind_dn
@bind_dn.setter @bind_dn.setter
def bind_dn(self, rhs): def bind_dn(self, rhs):
self.__bind_dn = rhs self.__bind_dn = rhs
@ -48,17 +63,21 @@ class Config:
def bind_pw(self): def bind_pw(self):
if self.__bind_pw is None: if self.__bind_pw is None:
for key in ['bind_pw', 'password']: for key in ['bind_pw', 'password']:
ret = self.__get(key, default=None) ret = self.__get(key, default = None)
if ret is not None: if ret is not None:
break break
if ret is None: if ret is None:
ldap_secret_file = self.__get('secret_file', f'{pathlib.Path.home()}/.ldap.secret') ldap_secret_file = self.__get(
'secret_file', f'{pathlib.Path.home()}/.ldap.secret'
)
assert ldap_secret_file is not None, 'ldap_secret_file'
with open(ldap_secret_file, 'r') as file: with open(ldap_secret_file, 'r') as file:
ret = file.read() ret = file.read()
file.closed file.closed
ret = ret.strip() ret = ret.strip()
self.__bind_pw = ret self.__bind_pw = ret
return self.__bind_pw return self.__bind_pw
@bind_pw.setter @bind_pw.setter
def bind_pw(self, rhs): def bind_pw(self, rhs):
self.__bind_pw = rhs self.__bind_pw = rhs
@ -66,25 +85,28 @@ class Config:
@property @property
def base_dn(self): def base_dn(self):
if self.__base_dn is None: if self.__base_dn is None:
self.__base_dn = self.__get('base_dn', default=f'dc=jannet,dc=de') self.__base_dn = self.__get('base_dn', default = 'dc=jannet,dc=de')
return self.__base_dn return self.__base_dn
@base_dn.setter @base_dn.setter
def base_dn(self, rhs): def base_dn(self, rhs):
self.__base_dn = rhs self.__base_dn = rhs
class Connection: # export class Connection: # export
class AttrType(Flag): class AttrType(Flag):
Must = auto() Must = auto()
May = auto() May = auto()
def __init__(self, conf: Config|BaseConfig|None=None, backtrace=False): def __init__(self, conf: Config | BaseConfig | None = None, backtrace = False):
uri: str|None = None uri: str | None = None
c = conf if isinstance(conf, Config) else Config(conf) c = conf if isinstance(conf, Config) else Config(conf)
try: try:
uri = c.ldap_uri uri = c.ldap_uri
except: except Exception:
uri = c.uri # mypy says: E: "Config" has no attribute "uri" [attr-defined]
# FIXME: Who adds .uri?
uri = c.uri # type: ignore
try: try:
ret = ldap.initialize(uri) ret = ldap.initialize(uri)
ret.start_tls_s() ret.start_tls_s()
@ -92,46 +114,60 @@ class Connection: # export
slog(ERR, f'Failed to initialize LDAP connection to "{uri}" ({str(e)})') slog(ERR, f'Failed to initialize LDAP connection to "{uri}" ({str(e)})')
raise raise
try: try:
rr = ret.bind_s(c.bind_dn, c.bind_pw) # method) ret.bind_s(c.bind_dn, c.bind_pw) # method)
except Exception as e: except Exception as e:
slog(ERR, f'Failed to bind to "{uri}" with dn "{c.bind_dn}" ({str(e)})') slog(ERR, f'Failed to bind to "{uri}" with dn "{c.bind_dn}" ({str(e)})')
raise raise
self.__ldap = ret self.__ldap = ret
self.__backtrace = backtrace self.__backtrace = backtrace
self.__object_classes_by_oid: dict[str, ObjectClass]|None = None self.__object_classes_by_oid: dict[str, ObjectClass] | None = None
self.__object_class_tree: nx.Graph|None = None self.__object_class_tree: nx.Graph | None = None
self.__object_classes_by_name: dict[str, ObjectClass]|None = None self.__object_classes_by_name: dict[str, ObjectClass] | None = None
@property @property
def ldap(self): def ldap(self):
return self.__ldap return self.__ldap
def add(self, attrs: dict[str, bytes], dn: str|None=None): def add(self, attrs: dict[str, bytes], dn: str | None = None):
if dn is None: if dn is None:
if not 'dn' in attrs: if 'dn' not in attrs:
raise Exception('No DN to add an LDAP entry to') raise Exception('No DN to add an LDAP entry to')
attrs = copy.deepcopy(attrs) attrs = copy.deepcopy(attrs)
del attrs['dn'] del attrs['dn']
try: try:
slog(INFO, f'LDAP: Add [{dn}] -> {attrs}') slog(INFO, f'LDAP: Add [{dn}] -> {attrs}')
self.__ldap.add_s(dn, ldap.modlist.addModlist(attrs)) ml =ldap.modlist.addModlist( # pyright: ignore[reportAttributeAccessIssue]
attrs
)
self.__ldap.add_s(dn, ml)
except Exception as e: except Exception as e:
slog(ERR, f'{dn}: Failed to add entry {attrs} ({e})') slog(ERR, f'{dn}: Failed to add entry {attrs} ({e})')
raise raise
def delete(self, dn: str, recursive=False, force_existence: bool=False): def delete(self, dn: str, recursive = False, force_existence: bool = False):
def __walk_cb_delete(conn: Connection, entry, context): def __walk_cb_delete(conn: Connection, entry, context):
self.walk(__walk_cb_delete, base=entry[0], scope=ldap.SCOPE_ONELEVEL, context=context) self.walk(
__walk_cb_delete,
base = entry[0],
scope = ldap.
SCOPE_ONELEVEL, # pyright: ignore[reportAttributeAccessIssue]
context = context
)
self.__ldap.delete_s(entry[0]) self.__ldap.delete_s(entry[0])
try: try:
if recursive: if recursive:
self.walk(__walk_cb_delete, dn, scope=ldap.SCOPE_ONELEVEL) self.walk(
__walk_cb_delete,
dn,
scope = ldap.
SCOPE_ONELEVEL # pyright: ignore[reportAttributeAccessIssue]
)
self.__ldap.delete_s(dn) self.__ldap.delete_s(dn)
else: else:
self.__ldap.delete_s(dn) self.__ldap.delete_s(dn)
except ldap.NO_SUCH_OBJECT as e: except ldap.NO_SUCH_OBJECT: # pyright: ignore[reportAttributeAccessIssue]
if force_existence: if force_existence:
raise raise
except Exception as e: except Exception as e:
@ -139,37 +175,42 @@ class Connection: # export
raise raise
def walk( def walk(
self, self,
callback: Callable[[Self, Any, Any], None], callback: Callable[[Self, Any, Any], None],
base: str, base: str,
scope, scope,
context=None, context = None,
filterstr=None, filterstr = None,
attrlist=None, attrlist = None,
attrsonly=0, attrsonly = 0,
serverctrls=None, serverctrls = None,
clientctrls=None, clientctrls = None,
timeout=-1, timeout = -1,
sizelimit=0, sizelimit = 0,
decode: bool=False, decode: bool = False,
unroll: bool=False unroll: bool = False
): ):
# TODO: Support ignored arguments # TODO: Support ignored arguments
search_return = self.__ldap.search(base=base, search_return = self.__ldap.search(
scope=scope, base = base,
filterstr=filterstr, scope = scope,
attrlist=attrlist, filterstr = filterstr,
attrsonly=attrsonly) attrlist = attrlist,
attrsonly = attrsonly
)
while True: while True:
result_type, result_data = self.__ldap.result(search_return, 0) result_type, result_data = self.__ldap.result(search_return, 0)
if (result_data == []): if (not result_data):
break break
if result_type != ldap.RES_SEARCH_ENTRY: if result_type != ldap.RES_SEARCH_ENTRY: # pyright: ignore[reportAttributeAccessIssue]
continue continue
for entry in result_data: for entry in result_data:
if decode: if decode:
entry = entry[0], {key: [val.decode() for val in vals] for key, vals in entry[1].items()} entry = entry[0], {
key: [val.decode() for val in vals]
for key, vals in entry[1].items()
}
if unroll and False: if unroll and False:
entry = entry[0], {key: val[0] for key, val in entry[1].items()} entry = entry[0], {key: val[0] for key, val in entry[1].items()}
try: try:
@ -182,19 +223,20 @@ class Connection: # export
slog(WARNING, msg) slog(WARNING, msg)
continue continue
def find(self, def find(
base: str, self,
scope, base: str,
filterstr=None, scope,
attrlist=None, filterstr = None,
attrsonly=0, attrlist = None,
serverctrls=None, attrsonly = 0,
clientctrls=None, serverctrls = None,
timeout=-1, clientctrls = None,
sizelimit=0, timeout = -1,
assert_unique=False, sizelimit = 0,
assert_not_empty=False, assert_unique = False,
): assert_not_empty = False,
):
def __walk_cb_find(conn: Connection, entry: Any, context: Any): def __walk_cb_find(conn: Connection, entry: Any, context: Any):
result.append(entry) result.append(entry)
@ -204,7 +246,13 @@ class Connection: # export
try: try:
result: list[Any] = [] result: list[Any] = []
self.walk(__walk_cb_find, base, scope=scope, filterstr=filterstr, attrlist=attrlist) self.walk(
__walk_cb_find,
base,
scope = scope,
filterstr = filterstr,
attrlist = attrlist
)
except Exception as e: except Exception as e:
slog(ERR, f'Failed search {__search()} ({e})') slog(ERR, f'Failed search {__search()} ({e})')
raise raise
@ -216,17 +264,34 @@ class Connection: # export
@property @property
def object_classes(self) -> dict[str, ObjectClass]: def object_classes(self) -> dict[str, ObjectClass]:
#def object_classes(self): #def object_classes(self):
if self.__object_classes_by_oid is None: if self.__object_classes_by_oid is None:
res = self.find(base='', scope=ldap.SCOPE_BASE, filterstr='(objectClass=*)', attrlist=['subschemaSubentry']) res = self.find(
dn = res[0][1]['subschemaSubentry'][0].decode('utf-8') # Usually yields cn=Subschema base = '',
res = self.find(base=dn, scope=ldap.SCOPE_BASE, filterstr='(objectClass=*)', attrlist=['*', '+']) scope = ldap.SCOPE_BASE, # pyright: ignore[reportAttributeAccessIssue]
filterstr = '(objectClass=*)',
attrlist = ['subschemaSubentry']
)
dn = res[0][1]['subschemaSubentry'][0].decode(
'utf-8'
) # Usually yields cn=Subschema
res = self.find(
base = dn,
scope = ldap.SCOPE_BASE, # pyright: ignore[reportAttributeAccessIssue]
filterstr = '(objectClass=*)',
attrlist = ['*', '+']
)
subschema_entry = res[0] subschema_entry = res[0]
subschema_subentry = ldap.cidict.cidict(subschema_entry[1]) subschema_subentry = ldap.cidict.cidict( # pyright: ignore[reportAttributeAccessIssue]
subschema = ldap.schema.SubSchema(subschema_subentry) subschema_entry[1]
)
subschema = ldap.schema.SubSchema( # pyright: ignore[reportAttributeAccessIssue]
subschema_subentry
)
object_class_oids = subschema.listall(ObjectClass) object_class_oids = subschema.listall(ObjectClass)
self.__object_classes_by_oid = { self.__object_classes_by_oid = {
oid: subschema.get_obj(ObjectClass, oid) for oid in object_class_oids oid: subschema.get_obj(ObjectClass, oid)
for oid in object_class_oids
} }
return self.__object_classes_by_oid return self.__object_classes_by_oid
@ -242,15 +307,18 @@ class Connection: # export
ret[name.lower()] = oc ret[name.lower()] = oc
return self.__object_classes_by_name return self.__object_classes_by_name
def __oc_recurse_to_top(self, cur: str|ObjectClass, cb, context): def __oc_recurse_to_top(self, cur: str | ObjectClass, cb, context):
cur_oc = cur if isinstance(cur, ObjectClass) else self.object_class_by_name[cur.lower()] cur_oc = cur if isinstance(cur, ObjectClass) else self.object_class_by_name[
cur.lower()]
for s in cur_oc.sup: for s in cur_oc.sup:
self.__oc_recurse_to_top(s, cb, context) self.__oc_recurse_to_top(s, cb, context)
cb(cur_oc, context) cb(cur_oc, context)
def object_class_path(self, leaf: str|ObjectClass): def object_class_path(self, leaf: str | ObjectClass):
def cb(oc, context): def cb(oc, context):
ret.append(oc) ret.append(oc)
ret: list[str] = [] ret: list[str] = []
self.__oc_recurse_to_top(leaf, cb, None) self.__oc_recurse_to_top(leaf, cb, None)
return reversed(ret) return reversed(ret)
@ -262,47 +330,55 @@ class Connection: # export
def collect(root, attr): def collect(root, attr):
ret = set() ret = set()
def cb(oc, attr): def cb(oc, attr):
vals = getattr(oc, attr) vals = getattr(oc, attr)
if vals is None: if vals is None:
return return
for val in vals: for val in vals:
ret.add(val) ret.add(val)
self.__oc_recurse_to_top(root, cb, attr) self.__oc_recurse_to_top(root, cb, attr)
return ret return ret
kind = { kind = {0: 'STRUCTURAL', 1: 'ABSTRACT', 2: 'AUXILIARY'}
0: 'STRUCTURAL',
1: 'ABSTRACT',
2: 'AUXILIARY'
}
ret = nx.DiGraph() ret = nx.DiGraph()
for oid, oc in self.object_classes.items(): for oid, oc in self.object_classes.items():
ret.add_node( ret.add_node(
oid, oid,
oid=oid, oid = oid,
name=oc.names[0], name = oc.names[0],
kind=kind[oc.kind], kind = kind[oc.kind],
must=', '.join(collect(oc, 'must')), must = ', '.join(collect(oc, 'must')),
may=', '.join(collect(oc, 'may')) may = ', '.join(collect(oc, 'may'))
) )
for base_class in oc.sup: for base_class in oc.sup:
try: try:
ret.add_edge(oid, self.object_class_by_name[base_class.lower()].oid) ret.add_edge(
oid, self.object_class_by_name[base_class.lower()].oid
)
except Exception as e: except Exception as e:
slog(WARNING, f'Failed to add edge {oid}:{oc.names} -> {base_class} ({e})') slog(
WARNING,
f'Failed to add edge {oid}:{oc.names} -> {base_class} ({e})'
)
self.__object_class_tree = ret self.__object_class_tree = ret
return self.__object_class_tree return self.__object_class_tree
def object_class_attrs(self, oc: str|ObjectClass, required: AttrType = AttrType.Must, origins: bool=False) -> dict[str, set[str]]|set[str]: def object_class_attrs(
self,
oc: str | ObjectClass,
required: AttrType = AttrType.Must,
origins: bool = False
) -> dict[str, set[str]] | set[str]:
all_attrs: set[str] = set() all_attrs: set[str] = set()
attrs_by_origin: dict[str, set[str]] = {} attrs_by_origin: dict[str, set[str]] = {}
for oc in self.object_class_path(oc): for oc in self.object_class_path(oc):
cur = set() cur = set()
if required & self.AttrType.Must: if required & self.AttrType.Must:
cur |= set(oc.must) cur |= set(oc.must) # pyright: ignore[reportAttributeAccessIssue]
if required & self.AttrType.May: if required & self.AttrType.May:
cur |= set(oc.may) cur |= set(oc.may) # pyright: ignore[reportAttributeAccessIssue]
if cur: if cur:
all_attrs |= cur all_attrs |= cur
attrs_by_origin[oc] = cur attrs_by_origin[oc] = cur
@ -313,10 +389,14 @@ class Connection: # export
#base_oid = self.object_class_by_name[base_candidate].oid #base_oid = self.object_class_by_name[base_candidate].oid
#if base_oid in [oc.oid for oc in self.object_class_path(name)]: #if base_oid in [oc.oid for oc in self.object_class_path(name)]:
# return True # return True
return nx.has_path(self.object_class_tree, self.object_class_by_name[name.lower()].oid, self.object_class_by_name[base_candidate.lower()].oid) return nx.has_path(
self.object_class_tree,
self.object_class_by_name[name.lower()].oid,
self.object_class_by_name[base_candidate.lower()].oid
)
def default_config() -> Config: # export def default_config() -> Config: # export
return Config() return Config()
def bind(conf: Config|BaseConfig|None=None) -> Connection: def bind(conf: Config | BaseConfig | None = None) -> Connection:
return Connection(conf) return Connection(conf)

View file

@ -1,22 +1,28 @@
# -*- coding: utf-8 -*- from __future__ import annotations, print_function
from __future__ import print_function import inspect
import re
import sys
import syslog
import unicodedata
from typing import List, Tuple, Optional, Any
import sys, re, io, syslog, inspect, unicodedata
from os.path import basename
from datetime import datetime from datetime import datetime
from os.path import basename
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
from . import misc from . import misc
if TYPE_CHECKING:
import io
# --- python 2 / 3 compatibility stuff # --- python 2 / 3 compatibility stuff
try: try:
basestring # type: ignore basestring # type: ignore
except NameError: except NameError:
basestring = str basestring = str
# fmt: disable # don't conflate
# yapf: disable # don't conflate
_special_chars = { _special_chars = {
'\a' : '\\a', '\a' : '\\a',
'\b' : '\\b', '\b' : '\\b',
@ -26,12 +32,20 @@ _special_chars = {
'\f' : '\\f', '\f' : '\\f',
'\r' : '\\r', '\r' : '\\r',
} }
# yapf: enable
# fmt: enable
_special_char_regex = re.compile("(%s)" % "|".join(map(re.escape, _special_chars.keys()))) _special_char_regex = re.compile(
"(%s)" % "|".join(map(re.escape, _special_chars.keys()))
)
_all_control_chars = ''.join(chr(c) for c in range(sys.maxunicode) if unicodedata.category(chr(c)) in {'Cc'}) _all_control_chars = ''.join(
_clean_str_regex = re.compile(r'(\033\[[0-9]*m|[%s])' % re.escape(_all_control_chars)) chr(c) for c in range(sys.maxunicode) if unicodedata.category(chr(c)) in {'Cc'}
)
_clean_str_regex = re.compile(r'(\033\[[0-9]*m|[%s])' % re.escape(_all_control_chars))
# fmt: disable # don't conflate
# yapf: disable # don't conflate
EMERG = int(syslog.LOG_EMERG) EMERG = int(syslog.LOG_EMERG)
ALERT = int(syslog.LOG_ALERT) ALERT = int(syslog.LOG_ALERT)
CRIT = int(syslog.LOG_CRIT) CRIT = int(syslog.LOG_CRIT)
@ -98,44 +112,49 @@ _prio_colors = {
EMERG : [ CONSOLE_FONT_BOLD + CONSOLE_FONT_MAGENTA, CONSOLE_FONT_OFF ], EMERG : [ CONSOLE_FONT_BOLD + CONSOLE_FONT_MAGENTA, CONSOLE_FONT_OFF ],
} }
# yapf: enable
# fmt: enable
class Stream: class Stream:
def __init__(self, stream, flags): def __init__(self, stream, flags):
self.stream = stream self.stream = stream
self.flags = flags self.flags = flags
_streams: dict[int, Stream] = dict() _streams: dict[int, Stream] = dict()
_stream_descriptors = [reversed(range(1, 16))] _stream_descriptors = list(reversed(range(1, 16)))
def add_capture_stream(stream, flags=0x0): def add_capture_stream(stream, flags = 0x0):
ret = _stream_descriptors.pop() ret = _stream_descriptors.pop()
_streams[ret] = Stream(stream=stream, flags=flags) _streams[ret] = Stream(stream = stream, flags = flags)
return ret return ret
def rm_capture_stream(sd): def rm_capture_stream(sd):
del _streams[sd] del _streams[sd]
_stream_descriptors.append(sd) _stream_descriptors.append(sd)
def prio_gets_logged(prio: int) -> bool: # export def prio_gets_logged(prio: int) -> bool: # export
if prio > _level: if prio > _level:
return False return False
return True return True
def log_level(s: Optional[str]=None) -> int: # export def log_level(s: Optional[str] = None) -> int: # export
if s is None: if s is None:
return _level return _level
return parse_log_prio_str(s) return parse_log_prio_str(s)
def get_caller_pos(up: int = 1, kwargs: Optional[dict[str, Any]] = None) -> Tuple[str, str, int]: def get_caller_pos(up: int = 1,
kwargs: Optional[dict[str, Any]] = None) -> Tuple[str, str, int]:
if kwargs and 'caller' in kwargs: if kwargs and 'caller' in kwargs:
r = kwargs['caller'] r = kwargs['caller']
del kwargs['caller'] del kwargs['caller']
return r return r
caller = inspect.stack()[up+1] caller = inspect.stack()[up + 1]
mod = inspect.getmodule(caller[0]) mod = inspect.getmodule(caller[0])
mod_name = '' if mod is None else mod.__name__ mod_name = '' if mod is None else mod.__name__
return (mod_name, basename(caller.filename), caller.lineno) return (mod_name, basename(caller.filename), caller.lineno)
def slog_m(prio: int, *args, **kwargs) -> None: # export def slog_m(prio: int, *args, **kwargs) -> None: # export
if prio > _level: if prio > _level:
return return
if len(args): if len(args):
@ -151,9 +170,9 @@ def slog_m(prio: int, *args, **kwargs) -> None: # export
caller = kwargs['caller'] caller = kwargs['caller']
del kwargs['caller'] del kwargs['caller']
for line in margs[1:].split('\n'): for line in margs[1:].split('\n'):
slog(prio, line, **kwargs, caller=caller) slog(prio, line, **kwargs, caller = caller)
def slog(prio: int, *args, only_printable: bool=False, **kwargs) -> None: # export def slog(prio: int, *args, only_printable: bool = False, **kwargs) -> None: # export
if prio > _level: if prio > _level:
return return
@ -188,11 +207,13 @@ def slog(prio: int, *args, only_printable: bool=False, **kwargs) -> None: # expo
for a in args: for a in args:
margs += ' ' + str(a) margs += ' ' + str(a)
if only_printable: if only_printable:
margs = _special_char_regex.sub(lambda mo: _special_chars[mo.string[mo.start():mo.end()]], margs) margs = _special_char_regex.sub(
lambda mo: _special_chars[mo.string[mo.start():mo.end()]], margs
)
margs = re.sub('[\x01-\x1f]', '.', margs) margs = re.sub('[\x01-\x1f]', '.', margs)
for file in _log_file_streams: for file in _log_file_streams:
print(msg + _clean_log_prefix + margs, file=file) print(msg + _clean_log_prefix + margs, file = file)
msg += _log_prefix msg += _log_prefix
@ -215,24 +236,26 @@ def slog(prio: int, *args, only_printable: bool=False, **kwargs) -> None: # expo
files.append(sys.stderr) files.append(sys.stderr)
if not len(files): if not len(files):
files = [ sys.stdout ] files = [sys.stdout]
for file in files: for file in files:
print(msg, file=file) print(msg, file = file)
def throw(*args, prio=ERR, caller=None, **kwargs) -> None: def throw(*args, prio = ERR, caller = None, **kwargs) -> None:
if caller is None: if caller is None:
caller = get_caller_pos(1) caller = get_caller_pos(1)
msg = ' '.join([str(arg) for arg in args]) msg = ' '.join([str(arg) for arg in args])
slog(prio, msg, caller=caller) slog(prio, msg, caller = caller)
raise Exception(msg) raise Exception(msg)
def parse_log_prio_str(prio: str) -> int: # export def parse_log_prio_str(prio: str) -> int: # export
try: try:
r = int(prio) r = int(prio)
if r < 0 or r > DEVEL: if r < 0 or r > DEVEL:
raise Exception("Invalid log priority ", prio) raise Exception("Invalid log priority ", prio)
except ValueError: except ValueError:
# fmt: disable # don't conflate
# yapf: disable # don't conflate
map_prio_str_to_val = { map_prio_str_to_val = {
"EMERG" : EMERG, "EMERG" : EMERG,
"emerg" : EMERG, "emerg" : EMERG,
@ -255,23 +278,25 @@ def parse_log_prio_str(prio: str) -> int: # export
"OFF" : OFF, "OFF" : OFF,
"off" : OFF, "off" : OFF,
} }
# yapf: enable
# fmt: enable
if prio in map_prio_str_to_val: if prio in map_prio_str_to_val:
return map_prio_str_to_val[prio] return map_prio_str_to_val[prio]
raise Exception("Unknown priority string \"", prio, "\"") raise Exception("Unknown priority string \"", prio, "\"")
def console_color_chars(prio: int) -> List[str]: # export def console_color_chars(prio: int) -> List[str]: # export
if not sys.stdout.isatty(): if not sys.stdout.isatty():
return [ '', '' ] return ['', '']
return _prio_colors[prio] return _prio_colors[prio]
def set_level(level_: str) -> None: # export def set_level(level_: str) -> None: # export
global _level global _level
if isinstance(level_, basestring): if isinstance(level_, basestring):
_level = parse_log_prio_str(level_) _level = parse_log_prio_str(level_)
return return
_level = level_ _level = level_
def set_flags(flags: str|None) -> str: # export def set_flags(flags: str | None) -> str: # export
global _flags global _flags
ret = ','.join(_flags) ret = ','.join(_flags)
if flags is not None: if flags is not None:
@ -293,7 +318,7 @@ def set_flags(flags: str|None) -> str: # export
#pid #pid
#highlight_first_error #highlight_first_error
def append_to_prefix(prefix: str) -> str: # export def append_to_prefix(prefix: str) -> str: # export
global _log_prefix global _log_prefix
global _clean_log_prefix global _clean_log_prefix
r = _log_prefix r = _log_prefix
@ -302,7 +327,7 @@ def append_to_prefix(prefix: str) -> str: # export
_clean_log_prefix = _clean_str_regex.sub('', _log_prefix) _clean_log_prefix = _clean_str_regex.sub('', _log_prefix)
return r return r
def remove_from_prefix(count) -> str: # export def remove_from_prefix(count) -> str: # export
if isinstance(count, str): if isinstance(count, str):
count = len(count) count = len(count)
global _log_prefix global _log_prefix
@ -312,21 +337,21 @@ def remove_from_prefix(count) -> str: # export
_clean_log_prefix = _clean_str_regex.sub('', _log_prefix) _clean_log_prefix = _clean_str_regex.sub('', _log_prefix)
return r return r
def set_filename_length(l: int) -> int: # export def set_filename_length(length: int) -> int: # export
global _file_name_len global _file_name_len
r = _file_name_len r = _file_name_len
if l: if length:
_file_name_len = l _file_name_len = length
return r return r
def set_module_name_length(l: int) -> int: # export def set_module_name_length(length: int) -> int: # export
global _module_name_len global _module_name_len
r = _module_name_len r = _module_name_len
if l: if length:
_module_name_len = l _module_name_len = length
return r return r
def add_log_file(path: str) -> None: # export def add_log_file(path: str) -> None: # export
global _log_file_streams global _log_file_streams
fd = open(path, 'w', buffering=1) fd = open(path, 'w', buffering = 1)
_log_file_streams.append(fd) _log_file_streams.append(fd)

View file

@ -1,6 +1,11 @@
# -*- coding: utf-8 -*- import atexit
import errno
import os, errno, atexit, tempfile, filecmp, inspect, importlib, re import filecmp
import importlib
import inspect
import os
import re
import tempfile
from typing import Iterable from typing import Iterable
@ -12,12 +17,12 @@ def _cleanup():
for f in _tmpfiles: for f in _tmpfiles:
silentremove(f) silentremove(f)
def silentremove(filename): #export def silentremove(filename): #export
try: try:
os.remove(filename) os.remove(filename)
except OSError as e: except OSError as e:
if e.errno != errno.ENOENT: if e.errno != errno.ENOENT:
raise # re-raise exception if a different error occurred raise # re-raise exception if a different error occurred
def update_symlink(target, link_name): def update_symlink(target, link_name):
try: try:
@ -38,12 +43,14 @@ def pad(token: str, total_size: int, right_align: bool = False) -> str:
return space + token return space + token
return token + space return token + space
def atomic_store(contents, path): # export def atomic_store(contents, path): # export
if path[0:3] == '/dev': if path[0:3] == '/dev':
with open(path, 'w') as outfile: with open(path, 'w') as outfile:
outfile.write(contents) outfile.write(contents)
return return
outfile = tempfile.NamedTemporaryFile(prefix=os.path.basename(path), delete=False, dir=os.path.dirname(path)) outfile = tempfile.NamedTemporaryFile(
prefix = os.path.basename(path), delete = False, dir = os.path.dirname(path)
)
name = outfile.name name = outfile.name
_tmpfiles.add(name) _tmpfiles.add(name)
outfile.write(contents) outfile.write(contents)
@ -52,7 +59,7 @@ def atomic_store(contents, path): # export
_tmpfiles.remove(name) _tmpfiles.remove(name)
# see https://stackoverflow.com/questions/2020014 # see https://stackoverflow.com/questions/2020014
def object_builtin_name(o, full=True): # export def object_builtin_name(o, full = True): # export
#if not full: #if not full:
# return o.__class__.__name__ # return o.__class__.__name__
module = o.__class__.__module__ module = o.__class__.__module__
@ -60,7 +67,7 @@ def object_builtin_name(o, full=True): # export
return o.__class__.__name__ # Avoid reporting __builtin__ return o.__class__.__name__ # Avoid reporting __builtin__
return module + '.' + o.__class__.__name__ return module + '.' + o.__class__.__name__
def get_derived_classes(mod, base, flt=None): # export def get_derived_classes(mod, base, flt = None): # export
members = inspect.getmembers(mod, inspect.isclass) members = inspect.getmembers(mod, inspect.isclass)
r = [] r = []
for name, c in members: for name, c in members:
@ -68,8 +75,10 @@ def get_derived_classes(mod, base, flt=None): # export
if inspect.isabstract(c): if inspect.isabstract(c):
log.slog(log.DEBUG, " is abstract") log.slog(log.DEBUG, " is abstract")
continue continue
if not base in inspect.getmro(c): if base not in inspect.getmro(c):
log.slog(log.DEBUG, " is not derived from", base, "only", inspect.getmro(c)) log.slog(
log.DEBUG, " is not derived from", base, "only", inspect.getmro(c)
)
continue continue
if flt and not re.match(flt, name): if flt and not re.match(flt, name):
log.slog(log.DEBUG, ' "{}.{}" has wrong name'.format(mod, name)) log.slog(log.DEBUG, ' "{}.{}" has wrong name'.format(mod, name))
@ -77,7 +86,7 @@ def get_derived_classes(mod, base, flt=None): # export
r.append(c) r.append(c)
return r return r
def load_classes(path, baseclass, flt=None): # export def load_classes(path, baseclass, flt = None): # export
r = [] r = []
for p in path.split(':'): for p in path.split(':'):
mod = importlib.import_module(path) mod = importlib.import_module(path)
@ -85,22 +94,28 @@ def load_classes(path, baseclass, flt=None): # export
r.extend(get_derived_classes(mod, baseclass, flt)) r.extend(get_derived_classes(mod, baseclass, flt))
return r return r
def load_class(module_path, baseclass, class_name_filter=None): # export def load_class(module_path, baseclass, class_name_filter = None): # export
mod = importlib.import_module(module_path) mod = importlib.import_module(module_path)
classes = get_derived_classes(mod, baseclass, flt=class_name_filter) classes = get_derived_classes(mod, baseclass, flt = class_name_filter)
if len(classes) == 0: if len(classes) == 0:
raise Exception(f'no class matching "{class_name_filter}" of type "{baseclass}" found in module "{module_path}"') raise Exception(
f'no class matching "{class_name_filter}" of type "{baseclass}" '
f'found in module "{module_path}"'
)
if len(classes) > 1: if len(classes) > 1:
raise Exception(f'{len(classes)} classes matching "{class_name_filter}" of type "{baseclass}" found in module "{module_path}"') raise Exception(
f'{len(classes)} classes matching "{class_name_filter}" of type '
f'"{baseclass}" found in module "{module_path}"'
)
return classes[0] return classes[0]
def load_class_names(path, baseclass, flt=None, remove_flt=False): # export def load_class_names(path, baseclass, flt = None, remove_flt = False): # export
classes = load_classes(path, baseclass, flt) classes = load_classes(path, baseclass, flt)
r = [] r = []
for c in classes: for c in classes:
name = c.__name__ name = c.__name__
if flt and remove_flt: if flt and remove_flt:
name = re.subst(flt, "", name) name = re.sub(flt, '', name)
if name not in r: if name not in r:
r.append(name) r.append(name)
else: else:
@ -108,64 +123,72 @@ def load_class_names(path, baseclass, flt=None, remove_flt=False): # export
#log.slog(log.WARNING, "{} is already in in {}".format(name, r)) #log.slog(log.WARNING, "{} is already in in {}".format(name, r))
return r return r
def load_object(module_path, baseclass, class_name_filter=None, *args, **kwargs): # export def load_object( # export
return load_class(module_path, baseclass, class_name_filter=class_name_filter)(*args, **kwargs) module_path, baseclass, class_name_filter = None, *args, **kwargs
):
return load_class(
module_path, baseclass, class_name_filter = class_name_filter
)(*args, **kwargs)
def load_function(module_path, name): # export def load_function(module_path, name): # export
mod = importlib.import_module(module_path) mod = importlib.import_module(module_path)
return getattr(mod, name) return getattr(mod, name)
def commit_tmpfile(tmp: str, path: str) -> None: # export def commit_tmpfile(tmp: str, path: str) -> None: # export
caller = log.get_caller_pos() caller = log.get_caller_pos()
if os.path.isfile(path) and filecmp.cmp(tmp, path): if os.path.isfile(path) and filecmp.cmp(tmp, path):
log.slog(log.INFO, "{} is up to date".format(path), caller=caller) log.slog(log.INFO, "{} is up to date".format(path), caller = caller)
os.unlink(tmp) os.unlink(tmp)
else: else:
log.slog(log.NOTICE, "saving {}".format(path), caller=caller) log.slog(log.NOTICE, "saving {}".format(path), caller = caller)
os.rename(path + '.tmp', path) os.rename(path + '.tmp', path)
def multi_regex_edit(spec, strings): # export def multi_regex_edit(spec, strings): # export
for cmd in spec: for cmd in spec:
if len(cmd) < 2: if len(cmd) < 2:
raise Exception('Invalid command in multi_regex_edit(): {}'.format(str(cmd))) raise Exception(
'Invalid command in multi_regex_edit(): {}'.format(str(cmd))
)
if cmd[0] == 'sub': if cmd[0] == 'sub':
rx = re.compile(cmd[1]) rx = re.compile(cmd[1])
replacement = cmd[2] replacement = cmd[2]
r = [] r = []
for l in strings: for string in strings:
r.append(re.sub(rx, replacement, l)) r.append(re.sub(rx, replacement, string))
strings = r strings = r
continue continue
if cmd[0] == 'del': if cmd[0] == 'del':
rx = re.compile(cmd[1]) rx = re.compile(cmd[1])
r = [] r = []
for l in strings: for string in strings:
if rx.search(l) is not None: if rx.search(string) is not None:
continue continue
r.append(l) r.append(string)
strings = r strings = r
continue continue
if cmd[0] == 'match': if cmd[0] == 'match':
rx = re.compile(cmd[1]) rx = re.compile(cmd[1])
r = [] r = []
for l in strings: for string in strings:
if rx.search(l) is not None: if rx.search(string) is not None:
r.append(l) r.append(string)
strings = r strings = r
continue continue
raise Exception('Invalid command in multi_regex_edit(): {}'.format(str(cmd))) raise Exception('Invalid command in multi_regex_edit(): {}'.format(str(cmd)))
return strings return strings
def dump(prio: int, objects: Iterable, *args, **kwargs) -> None: # export def dump(prio: int, objects: Iterable, *args, **kwargs) -> None: # export
caller = log.get_caller_pos(kwargs=kwargs) caller = log.get_caller_pos(kwargs = kwargs)
log.slog(prio, ",---------- {}".format(' '.join(args)), caller=caller) log.slog(prio, ",---------- {}".format(' '.join(args)), caller = caller)
prefix = " | " prefix = " | "
log.append_to_prefix(prefix) log.append_to_prefix(prefix)
i = 1 i = 1
for o in objects: for o in objects:
o.dump(prio, "{} ({})".format(i, o.__class__.__name__), caller=caller, **kwargs) o.dump(
prio, "{} ({})".format(i, o.__class__.__name__), caller = caller, **kwargs
)
i += 1 i += 1
log.remove_from_prefix(prefix) log.remove_from_prefix(prefix)
log.slog(prio, "`---------- {}".format(' '.join(args)), caller=caller) log.slog(prio, "`---------- {}".format(' '.join(args)), caller = caller)
atexit.register(_cleanup) atexit.register(_cleanup)

View file

@ -1,5 +1,4 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
''' '''
Created on 26 May 2013 Created on 26 May 2013
@ -14,65 +13,74 @@ ___________________________________
Permission is hereby granted, free of charge, to any person obtaining a copy of this Permission is hereby granted, free of charge, to any person obtaining a copy of this
software and associated documentation files (the "Software"), to deal in the Software software and associated documentation files (the "Software"), to deal in the Software
without restriction, including without limitation the rights to use, copy, modify, merge, without restriction, including without limitation the rights to use, copy, modify,
publish, distribute, sub-license, and/or sell copies of the Software, and to permit persons merge, publish, distribute, sub-license, and/or sell copies of the Software, and to
to whom the Software is furnished to do so, subject to the following conditions: permit persons to whom the Software is furnished to do so, subject to the following
conditions:
- The above copyright notice and this permission notice shall be included in all copies - The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software. or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
DEALINGS IN THE SOFTWARE. OTHER DEALINGS IN THE SOFTWARE.
''' '''
import platform import platform
_python3 = int(platform.python_version_tuple()[0]) >= 3 _python3 = int(platform.python_version_tuple()[0]) >= 3
class multi_key_dict(object): class multi_key_dict(object):
""" The purpose of this type is to provide a multi-key dictionary. """
This kind of dictionary has a similar interface to the standard dictionary, and indeed if used The purpose of this type is to provide a multi-key dictionary. This kind of
dictionary has a similar interface to the standard dictionary, and indeed if used
with single key key elements - it's behaviour is the same as for a standard dict(). with single key key elements - it's behaviour is the same as for a standard dict().
However it also allows for creation of elements using multiple keys (using tuples/lists). However it also allows for creation of elements using multiple keys (using
Such elements can be accessed using either of those keys (e.g read/updated/deleted). tuples/lists). Such elements can be accessed using either of those keys (e.g
Dictionary provides also an extended interface for iterating over items and keys by the key type. read/updated/deleted). Dictionary provides also an extended interface for iterating
This can be useful e.g.: when creating dictionaries with (index,name) allowing one to iterate over over items and keys by the key type. This can be useful e.g.: when creating
items using either: names or indexes. It can be useful for many many other similar use-cases, dictionaries with (index,name) allowing one to iterate over items using either:
and there is no limit to the number of keys used to map to the value. names or indexes. It can be useful for many many other similar use-cases, and there
is no limit to the number of keys used to map to the value.
There are also methods to find other keys mapping to the same value as the specified keys etc. There are also methods to find other keys mapping to the same value as the specified
Refer to examples and test code to see it in action. keys etc. Refer to examples and test code to see it in action.
simple example: simple example:
k = multi_key_dict() k = multi_key_dict()
k[100] = 'hundred' # add item to the dictionary (as for normal dictionary) k[100] = 'hundred' # add item to the dictionary (as for normal dictionary)
# but also: # but also:
# below creates entry with two possible key types: int and str, # below creates entry with two possible key types: int and str,
# mapping all keys to the assigned value # mapping all keys to the assigned value
k[1000, 'kilo', 'k'] = 'kilo (x1000)' k[1000, 'kilo', 'k'] = 'kilo (x1000)'
print k[1000] # will print 'kilo (x1000)' print k[1000] # will print 'kilo (x1000)'
print k['k'] # will also print 'kilo (x1000)' print k['k'] # will also print 'kilo (x1000)'
# the same way objects can be updated, and if an object is updated using one key, the new value will # the same way objects can be updated, and if an object is updated using one
# be accessible using any other key, e.g. for example above: # key, the new value will be accessible using any other key, e.g. for example
# above:
k['kilo'] = 'kilo' k['kilo'] = 'kilo'
print k[1000] # will print 'kilo' as value was updated print k[1000] # will print 'kilo' as value was updated
""" """
def __init__(self, mapping_or_iterable=None, **kwargs): def __init__(self, mapping_or_iterable = None, **kwargs):
""" Initializes dictionary from an optional positional argument and a possibly empty set of keyword arguments.""" """ Initializes dictionary from an optional positional argument and a possibly
empty set of keyword arguments."""
self.items_dict = {} self.items_dict = {}
if mapping_or_iterable is not None: if mapping_or_iterable is not None:
if type(mapping_or_iterable) is dict: if type(mapping_or_iterable) is dict:
mapping_or_iterable = mapping_or_iterable.items() mapping_or_iterable = mapping_or_iterable.items()
for kv in mapping_or_iterable: for kv in mapping_or_iterable:
if len(kv) != 2: if len(kv) != 2:
raise Exception('Iterable should contain tuples with exactly two values but specified: {0}.'.format(kv)) raise Exception(
'Iterable should contain tuples with exactly two values '
'but specified: {0}.'.format(kv)
)
self[kv[0]] = kv[1] self[kv[0]] = kv[1]
for keys, value in kwargs.items(): for keys, value in kwargs.items():
self[keys] = value self[keys] = value
@ -83,20 +91,19 @@ class multi_key_dict(object):
def __setitem__(self, keys, value): def __setitem__(self, keys, value):
""" Set the value at index (or list of indexes) specified as keys. """ Set the value at index (or list of indexes) specified as keys.
Note, that if multiple key list is specified, either: Note, that if multiple key list is specified, either:
- none of keys should map to an existing item already (item creation), or - none of keys should map to an existing item already (item creation), or
- all of keys should map to exactly the same item (as previously created) - all of keys should map to exactly the same item (as previously created)
(item update) (item update)
If this is not the case - KeyError is raised. """ If this is not the case - KeyError is raised. """
if(type(keys) in [tuple, list]): if (type(keys) in [tuple, list]):
at_least_one_key_exists = False
num_of_keys_we_have = 0 num_of_keys_we_have = 0
for x in keys: for x in keys:
try: try:
self.__getitem__(x) self.__getitem__(x)
num_of_keys_we_have += 1 num_of_keys_we_have += 1
except Exception as err: except Exception:
continue continue
if num_of_keys_we_have: if num_of_keys_we_have:
@ -112,36 +119,37 @@ class multi_key_dict(object):
if new != direct_key: if new != direct_key:
all_select_same_item = False all_select_same_item = False
break break
except Exception as err: except Exception:
all_select_same_item = False all_select_same_item = False
break; break
if not all_select_same_item: if not all_select_same_item:
raise KeyError(', '.join(str(key) for key in keys)) raise KeyError(', '.join(str(key) for key in keys))
first_key = keys[0] # combination if keys is allowed, simply use the first one first_key = keys[
0] # combination if keys is allowed, simply use the first one
else: else:
first_key = keys first_key = keys
key_type = str(type(first_key)) # find the intermediate dictionary.. key_type = str(type(first_key)) # find the intermediate dictionary..
if first_key in self: if first_key in self:
self.items_dict[self.__dict__[key_type][first_key]] = value # .. and update the object if it exists.. self.items_dict[self.__dict__[key_type][first_key]
] = value # .. and update the object if it exists..
else: else:
if(type(keys) not in [tuple, list]): if (type(keys) not in [tuple, list]):
key = keys key = keys
keys = [keys] keys = [keys]
self.__add_item(value, keys) # .. or create it - if it doesn't self.__add_item(value, keys) # .. or create it - if it doesn't
def __delitem__(self, key): def __delitem__(self, key):
""" Called to implement deletion of self[key].""" """ Called to implement deletion of self[key]."""
key_type = str(type(key)) key_type = str(type(key))
if (key in self and if (key in self and self.items_dict
self.items_dict and and (self.__dict__[key_type][key] in self.items_dict)):
(self.__dict__[key_type][key] in self.items_dict) ):
intermediate_key = self.__dict__[key_type][key] intermediate_key = self.__dict__[key_type][key]
# remove the item in main dictionary # remove the item in main dictionary
del self.items_dict[intermediate_key] del self.items_dict[intermediate_key]
# and remove all references (if there were other keys) # and remove all references (if there were other keys)
@ -166,10 +174,10 @@ class multi_key_dict(object):
""" Returns True if this object contains an item referenced by the key.""" """ Returns True if this object contains an item referenced by the key."""
return key in self return key in self
def get_other_keys(self, key, including_current=False): def get_other_keys(self, key, including_current = False):
""" Returns list of other keys that are mapped to the same value as specified key. """ Returns list of other keys that are mapped to the same value as specified
@param key - key for which other keys should be returned. key. @param key - key for which other keys should be returned. @param
@param including_current if set to True - key will also appear on this list.""" including_current if set to True - key will also appear on this list."""
other_keys = [] other_keys = []
if key in self: if key in self:
other_keys.extend(self.__dict__[str(type(key))][key]) other_keys.extend(self.__dict__[str(type(key))][key])
@ -177,12 +185,17 @@ class multi_key_dict(object):
other_keys.remove(key) other_keys.remove(key)
return other_keys return other_keys
def iteritems(self, key_type=None, return_all_keys=False): def iteritems(self, key_type = None, return_all_keys = False):
""" Returns an iterator over the dictionary's (key, value) pairs. """ Returns an iterator over the dictionary's (key, value) pairs.
@param key_type if specified, iterator will be returning only (key,value) pairs for this type of key.
Otherwise (if not specified) ((keys,...), value) @param key_type if specified, iterator will be returning only (key,value)
i.e. (tuple of keys, values) pairs for all items in this dictionary will be generated. pairs for this type of key.
@param return_all_keys if set to True - tuple of keys is retuned instead of a key of this type."""
Otherwise (if not specified) ((keys,...), value) i.e. (tuple of keys,
values) pairs for all items in this dictionary will be generated.
@param return_all_keys if set to True - tuple of keys is retuned instead of
a key of this type."""
if key_type is None: if key_type is None:
for item in self.items_dict.items(): for item in self.items_dict.items():
@ -200,29 +213,34 @@ class multi_key_dict(object):
keys = tuple(k for k in keys if isinstance(k, key_type)) keys = tuple(k for k in keys if isinstance(k, key_type))
yield keys, value yield keys, value
def iterkeys(self, key_type=None, return_all_keys=False): def iterkeys(self, key_type = None, return_all_keys = False):
""" Returns an iterator over the dictionary's keys. """ Returns an iterator over the dictionary's keys.
@param key_type if specified, iterator for a dictionary of this type will be used. @param key_type if specified, iterator for a dictionary of this type will
be used.
Otherwise (if not specified) tuples containing all (multiple) keys Otherwise (if not specified) tuples containing all (multiple) keys
for this dictionary will be generated. for this dictionary will be generated.
@param return_all_keys if set to True - tuple of keys is retuned instead of a key of this type.""" @param return_all_keys if set to True - tuple of keys is retuned instead of
if(key_type is not None): a key of this type."""
if (key_type is not None):
the_key = str(key_type) the_key = str(key_type)
if the_key in self.__dict__: if the_key in self.__dict__:
for key in self.__dict__[the_key].keys(): for key in self.__dict__[the_key].keys():
if return_all_keys: if return_all_keys:
yield self.__dict__[the_key][key] yield self.__dict__[the_key][key]
else: else:
yield key yield key
else: else:
for keys in self.items_dict.keys(): for keys in self.items_dict.keys():
yield keys yield keys
def itervalues(self, key_type=None): def itervalues(self, key_type = None):
""" Returns an iterator over the dictionary's values. """ Returns an iterator over the dictionary's values.
@param key_type if specified, iterator will be returning only values pointed by keys of this type. @param key_type if specified, iterator will be returning only values pointed
Otherwise (if not specified) all values in this dictinary will be generated.""" by keys of this type.
if(key_type is not None): Otherwise (if not specified) all values in this dictinary will be
generated."""
if (key_type is not None):
intermediate_key = str(key_type) intermediate_key = str(key_type)
if intermediate_key in self.__dict__: if intermediate_key in self.__dict__:
for direct_key in self.__dict__[intermediate_key].values(): for direct_key in self.__dict__[intermediate_key].values():
@ -232,37 +250,42 @@ class multi_key_dict(object):
yield value yield value
if _python3: if _python3:
items = iteritems items = iteritems # type: ignore
else: else:
def items(self, key_type=None, return_all_keys=False):
def items(self, key_type = None, return_all_keys = False):
return list(self.iteritems(key_type, return_all_keys)) return list(self.iteritems(key_type, return_all_keys))
items.__doc__ = iteritems.__doc__ items.__doc__ = iteritems.__doc__
def keys(self, key_type=None): def keys(self, key_type = None):
""" Returns a copy of the dictionary's keys. """ Returns a copy of the dictionary's keys.
@param key_type if specified, only keys for this type will be returned. @param key_type if specified, only keys for this type will be returned.
Otherwise list of tuples containing all (multiple) keys will be returned.""" Otherwise list of tuples containing all (multiple) keys will be
returned."""
if key_type is not None: if key_type is not None:
intermediate_key = str(key_type) intermediate_key = str(key_type)
if intermediate_key in self.__dict__: if intermediate_key in self.__dict__:
return self.__dict__[intermediate_key].keys() return self.__dict__[intermediate_key].keys()
else: else:
all_keys = {} # in order to preserve keys() type (dict_keys for python3) all_keys = {} # in order to preserve keys() type (dict_keys for python3)
for keys in self.items_dict.keys(): for keys in self.items_dict.keys():
all_keys[keys] = None all_keys[keys] = None
return all_keys.keys() return all_keys.keys()
def values(self, key_type=None): def values(self, key_type = None):
""" Returns a copy of the dictionary's values. """ Returns a copy of the dictionary's values.
@param key_type if specified, only values pointed by keys of this type will be returned. @param key_type if specified, only values pointed by keys of this type
Otherwise list of all values contained in this dictionary will be returned.""" will be returned
if(key_type is not None): Otherwise list of all values contained in this dictionary will be
all_items = {} # in order to preserve keys() type (dict_values for python3) returned."""
if (key_type is not None):
all_items = {} # in order to preserve keys() type (dict_values for python3)
keys_used = set() keys_used = set()
direct_key = str(key_type) direct_key = str(key_type)
if direct_key in self.__dict__: if direct_key in self.__dict__:
for intermediate_key in self.__dict__[direct_key].values(): for intermediate_key in self.__dict__[direct_key].values():
if not intermediate_key in keys_used: if intermediate_key not in keys_used:
all_items[intermediate_key] = self.items_dict[intermediate_key] all_items[intermediate_key] = self.items_dict[intermediate_key]
keys_used.add(intermediate_key) keys_used.add(intermediate_key)
return all_items.values() return all_items.values()
@ -276,26 +299,29 @@ class multi_key_dict(object):
length = len(self.items_dict) length = len(self.items_dict)
return length return length
def __add_item(self, item, keys=None): def __add_item(self, item, keys = None):
""" Internal method to add an item to the multi-key dictionary""" """ Internal method to add an item to the multi-key dictionary"""
if(not keys or not len(keys)): if (not keys or not len(keys)):
raise Exception('Error in %s.__add_item(%s, keys=tuple/list of items): need to specify a tuple/list containing at least one key!' raise Exception(
% (self.__class__.__name__, str(item))) 'Error in %s.__add_item(%s, keys=tuple/list of items): need to specify'
direct_key = tuple(keys) # put all keys in a tuple, and use it as a key 'a tuple/list containing at least one key!' %
(self.__class__.__name__, str(item))
)
direct_key = tuple(keys) # put all keys in a tuple, and use it as a key
for key in keys: for key in keys:
key_type = str(type(key)) key_type = str(type(key))
# store direct key as a value in an intermediate dictionary # store direct key as a value in an intermediate dictionary
if(not key_type in self.__dict__): if (key_type not in self.__dict__):
self.__setattr__(key_type, dict()) self.__setattr__(key_type, dict())
self.__dict__[key_type][key] = direct_key self.__dict__[key_type][key] = direct_key
# store the value in the actual dictionary # store the value in the actual dictionary
if(not 'items_dict' in self.__dict__): if ('items_dict' not in self.__dict__):
self.items_dict = dict() self.items_dict = dict()
self.items_dict[direct_key] = item self.items_dict[direct_key] = item
def get(self, key, default=None): def get(self, key, default = None):
""" Return the value at index specified as key.""" """ Return the value at index specified as key."""
if key in self: if key in self:
return self.items_dict[self.__dict__[str(type(key))][key]] return self.items_dict[self.__dict__[str(type(key))][key]]
@ -304,74 +330,91 @@ class multi_key_dict(object):
def __str__(self): def __str__(self):
items = [] items = []
str_repr = lambda x: '\'%s\'' % x if type(x) == str else str(x)
def str_repr(x):
return '\'%s\'' % x if isinstance(x, str) else str(x)
if hasattr(self, 'items_dict'): if hasattr(self, 'items_dict'):
for (keys, value) in self.items(): for (keys, value) in self.items():
keys_str = [str_repr(k) for k in keys] keys_str = [str_repr(k) for k in keys]
items.append('(%s): %s' % (', '.join(keys_str), items.append('(%s): %s' % (', '.join(keys_str), str_repr(value)))
str_repr(value))) dict_str = '{%s}' % (', '.join(items))
dict_str = '{%s}' % ( ', '.join(items))
return dict_str return dict_str
def test_multi_key_dict(): def test_multi_key_dict():
contains_all = lambda cont, in_items: not (False in [c in cont for c in in_items])
def contains_all(cont, in_items):
return False not in [c in cont for c in in_items]
m = multi_key_dict() m = multi_key_dict()
assert( len(m) == 0 ), 'expected len(m) == 0' assert (len(m) == 0), 'expected len(m) == 0'
all_keys = list() all_keys = list()
m['aa', 12, 32, 'mmm'] = 123 # create a value with multiple keys.. m['aa', 12, 32, 'mmm'] = 123 # create a value with multiple keys..
assert( len(m) == 1 ), 'expected len(m) == 1' assert (len(m) == 1), 'expected len(m) == 1'
all_keys.append(('aa', 'mmm', 32, 12)) # store it for later all_keys.append(('aa', 'mmm', 32, 12)) # store it for later
# try retrieving other keys mapped to the same value using one of them # try retrieving other keys mapped to the same value using one of them
res = m.get_other_keys('aa') res = m.get_other_keys('aa')
expected = ['mmm', 32, 12] expected = ['mmm', 32, 12]
assert(set(res) == set(expected)), 'get_other_keys(\'aa\'): {0} other than expected: {1} '.format(res, expected)
# try retrieving other keys mapped to the same value using one of them: also include this key assert (set(res) == set(expected)), (
'get_other_keys(\'aa\'): {0} other '
'than expected: {1} '.format(res, expected)
)
# try retrieving other keys mapped to the same value using one of them: also include
# this key
res = m.get_other_keys(32, True) res = m.get_other_keys(32, True)
expected = ['aa', 'mmm', 32, 12] expected = ['aa', 'mmm', 32, 12]
assert(set(res) == set(expected)), 'get_other_keys(32): {0} other than expected: {1} '.format(res, expected) assert (set(res) == set(expected)), (
'get_other_keys(32): {0} other than expected: '
'{1} '.format(res, expected)
)
assert( m.has_key('aa') == True ), 'expected m.has_key(\'aa\') == True' assert (m.has_key('aa')), 'expected m.has_key(\'aa\') == True'
assert( m.has_key('aab') == False ), 'expected m.has_key(\'aab\') == False' assert (not m.has_key('aab')), 'expected m.has_key(\'aab\') == False'
assert( m.has_key(12) == True ), 'expected m.has_key(12) == True' assert (m.has_key(12)), 'expected m.has_key(12) == True'
assert( m.has_key(13) == False ), 'expected m.has_key(13) == False' assert (not m.has_key(13)), 'expected m.has_key(13) == False'
assert( m.has_key(32) == True ), 'expected m.has_key(32) == True' assert (m.has_key(32)), 'expected m.has_key(32) == True'
m['something else'] = 'abcd' m['something else'] = 'abcd'
assert( len(m) == 2 ), 'expected len(m) == 2' assert (len(m) == 2), 'expected len(m) == 2'
all_keys.append(('something else',)) # store for later all_keys.append(('something else', )) # store for later
m[23] = 0 m[23] = 0
assert( len(m) == 3 ), 'expected len(m) == 3' assert (len(m) == 3), 'expected len(m) == 3'
all_keys.append((23,)) # store for later all_keys.append((23, )) # store for later
# check if it's possible to read this value back using either of keys # check if it's possible to read this value back using either of keys
assert( m['aa'] == 123 ), 'expected m[\'aa\'] == 123' assert (m['aa'] == 123), 'expected m[\'aa\'] == 123'
assert( m[12] == 123 ), 'expected m[12] == 123' assert (m[12] == 123), 'expected m[12] == 123'
assert( m[32] == 123 ), 'expected m[32] == 123' assert (m[32] == 123), 'expected m[32] == 123'
assert( m['mmm'] == 123 ), 'expected m[\'mmm\'] == 123' assert (m['mmm'] == 123), 'expected m[\'mmm\'] == 123'
# now update value and again - confirm it back - using different keys.. # now update value and again - confirm it back - using different keys..
m['aa'] = 45 m['aa'] = 45
assert( m['aa'] == 45 ), 'expected m[\'aa\'] == 45' assert (m['aa'] == 45), 'expected m[\'aa\'] == 45'
assert( m[12] == 45 ), 'expected m[12] == 45' assert (m[12] == 45), 'expected m[12] == 45'
assert( m[32] == 45 ), 'expected m[32] == 45' assert (m[32] == 45), 'expected m[32] == 45'
assert( m['mmm'] == 45 ), 'expected m[\'mmm\'] == 45' assert (m['mmm'] == 45), 'expected m[\'mmm\'] == 45'
m[12] = '4' m[12] = '4'
assert( m['aa'] == '4' ), 'expected m[\'aa\'] == \'4\'' assert (m['aa'] == '4'), 'expected m[\'aa\'] == \'4\''
assert( m[12] == '4' ), 'expected m[12] == \'4\'' assert (m[12] == '4'), 'expected m[12] == \'4\''
# test __str__ # test __str__
m_str_exp = '{(23): 0, (\'aa\', \'mmm\', 32, 12): \'4\', (\'something else\'): \'abcd\'}' m_str_exp = (
'{(23): 0, (\'aa\', \'mmm\', 32, 12): \'4\', '
'(\'something else\'): \'abcd\'}'
)
m_str = str(m) m_str = str(m)
assert (len(m_str) > 0), 'str(m) should not be empty!' assert (len(m_str) > 0), 'str(m) should not be empty!'
assert (m_str[0] == '{'), 'str(m) should start with \'{\', but does with \'%c\'' % m_str[0] assert (m_str[0] == '{'
assert (m_str[-1] == '}'), 'str(m) should end with \'}\', but does with \'%c\'' % m_str[-1] ), ('str(m) should start with \'{\', but does with \'%c\'' % m_str[0])
assert (m_str[-1] == '}'
), ('str(m) should end with \'}\', but does with \'%c\'' % m_str[-1])
# check if all key-values are there as expected. They might be sorted differently # check if all key-values are there as expected. They might be sorted differently
def get_values_from_str(dict_str): def get_values_from_str(dict_str):
@ -381,41 +424,52 @@ def test_multi_key_dict():
keys = tuple(sorted([k.strip() for k in keys.split(',')])) keys = tuple(sorted([k.strip() for k in keys.split(',')]))
sorted_keys_and_values.append((keys, val)) sorted_keys_and_values.append((keys, val))
return sorted_keys_and_values return sorted_keys_and_values
exp = get_values_from_str(m_str_exp) exp = get_values_from_str(m_str_exp)
act = get_values_from_str(m_str) act = get_values_from_str(m_str)
assert (set(act) == set(exp)), 'str(m) values: \'{0}\' are not {1} '.format(act, exp) assert (set(act) == set(exp)
), ('str(m) values: \'{0}\' are not {1} '.format(act, exp))
# try accessing / creating new (keys)-> value mapping whilst one of these # try accessing / creating new (keys)-> value mapping whilst one of these
# keys already maps to a value in this dictionaries # keys already maps to a value in this dictionaries
try: try:
m['aa', 'bb'] = 'something new' m['aa', 'bb'] = 'something new'
assert(False), 'Should not allow adding multiple-keys when one of keys (\'aa\') already exists!' assert(False), (
except KeyError as err: 'Should not allow adding multiple-keys when one of keys '
'(\'aa\') already exists!'
)
except KeyError:
pass pass
# now check if we can get all possible keys (formed in a list of tuples) # now check if we can get all possible keys (formed in a list of tuples)
# each tuple containing all keys) # each tuple containing all keys)
res = sorted([sorted([str(x) for x in k]) for k in m.keys()]) res = sorted([sorted([str(x) for x in k]) for k in m.keys()]) # type: ignore
expected = sorted([sorted([str(x) for x in k]) for k in all_keys]) expected = sorted([sorted([str(x) for x in k]) for k in all_keys])
assert(res == expected), 'unexpected values from m.keys(), got:\n%s\n expected:\n%s' %(res, expected) assert (res == expected), (
'unexpected values from m.keys(), got:\n%s\n expected:\n%s' % (res, expected)
)
# check default items (which will unpack tupe with key(s) and value) # check default items (which will unpack tupe with key(s) and value)
num_of_elements = 0 num_of_elements = 0
for keys, value in m.items(): for keys, value in m.items():
sorted_keys = sorted([str(k) for k in keys]) sorted_keys = sorted([str(k) for k in keys])
num_of_elements += 1 num_of_elements += 1
assert(sorted_keys in expected), 'm.items(): unexpected keys: %s' % (sorted_keys) assert (sorted_keys
assert(m[keys[0]] == value), 'm.items(): unexpected value: %s (keys: %s)' % (value, keys) in expected), ('m.items(): unexpected keys: %s' % (sorted_keys))
assert(num_of_elements > 0), 'm.items() returned generator that did not produce anything' assert (m[keys[0]] == value
), ('m.items(): unexpected value: %s (keys: %s)' % (value, keys))
assert (num_of_elements
> 0), ('m.items() returned generator that did not produce anything')
# test default iterkeys() # test default iterkeys()
num_of_elements = 0 num_of_elements = 0
for keys in m.keys(): for keys in m.keys(): # type: ignore
num_of_elements += 1 num_of_elements += 1
keys_s = sorted([str(k) for k in keys]) keys_s = sorted([str(k) for k in keys])
assert(keys_s in expected), 'm.keys(): unexpected keys: {0}'.format(keys_s) assert (keys_s in expected), 'm.keys(): unexpected keys: {0}'.format(keys_s)
assert(num_of_elements > 0), 'm.iterkeys() returned generator that did not produce anything' assert (num_of_elements
> 0), ('m.iterkeys() returned generator that did not produce anything')
# test iterkeys(int, True): useful to get all info from the dictionary # test iterkeys(int, True): useful to get all info from the dictionary
# dictionary is iterated over the type specified, but all keys are returned. # dictionary is iterated over the type specified, but all keys are returned.
@ -423,75 +477,93 @@ def test_multi_key_dict():
for keys in m.iterkeys(int, True): for keys in m.iterkeys(int, True):
keys_s = sorted([str(k) for k in keys]) keys_s = sorted([str(k) for k in keys])
num_of_elements += 1 num_of_elements += 1
assert(keys_s in expected), 'm.iterkeys(int, True): unexpected keys: {0}'.format(keys_s) assert (keys_s in expected
assert(num_of_elements > 0), 'm.iterkeys(int, True) returned generator that did not produce anything' ), ('m.iterkeys(int, True): unexpected keys: {0}'.format(keys_s))
assert (num_of_elements > 0), (
'm.iterkeys(int, True) returned generator that did not produce anything'
)
# test values for different types of keys() # test values for different types of keys()
expected = set([0, '4']) expected = set([0, '4'])
res = set(m.values(int)) res = set(m.values(int))
assert (res == expected), 'm.values(int) are {0}, but expected: {1}.'.format(res, expected) assert (res == expected
), ('m.values(int) are {0}, but expected: {1}.'.format(res, expected))
expected = sorted(['4', 'abcd']) expected = sorted(['4', 'abcd'])
res = sorted(m.values(str)) res = sorted(m.values(str))
assert (res == expected), 'm.values(str) are {0}, but expected: {1}.'.format(res, expected) assert (res == expected
), ('m.values(str) are {0}, but expected: {1}.'.format(res, expected))
current_values = set([0, '4', 'abcd']) # default (should give all values) current_values = set([0, '4', 'abcd']) # default (should give all values)
res = set(m.values()) res = set(m.values())
assert (res == current_values), 'm.values() are {0}, but expected: {1}.'.format(res, current_values) assert (res == current_values
), ('m.values() are {0}, but expected: {1}.'.format(res, current_values))
#test itervalues() (default) - should return all values. (Itervalues for other types
# are tested below)
#test itervalues() (default) - should return all values. (Itervalues for other types are tested below)
vals = set() vals = set()
for value in m.itervalues(): for value in m.itervalues():
vals.add(value) vals.add(value)
assert (current_values == vals), 'itervalues(): expected {0}, but collected {1}'.format(current_values, vals) assert (current_values == vals), (
'itervalues(): expected {0}, but collected {1}'.format(current_values, vals)
)
#test items(int) #test items(int)
items_for_int = sorted([((12, 32), '4'), ((23,), 0)]) items_for_int = sorted([((12, 32), '4'), ((23, ), 0)])
assert (items_for_int == sorted(m.items(int))), 'items(int): expected {0}, but collected {1}'.format(items_for_int, assert (items_for_int == sorted(m.items(int))), (
sorted(m.items(int))) 'items(int): expected {0}, but collected {1}'.format(
items_for_int, sorted(m.items(int))
)
)
# test items(str) # test items(str)
items_for_str = set([(('aa','mmm'), '4'), (('something else',), 'abcd')]) items_for_str = set([(('aa', 'mmm'), '4'), (('something else', ), 'abcd')])
res = set(m.items(str)) res = set(m.items(str))
assert (set(res) == items_for_str), 'items(str): expected {0}, but collected {1}'.format(items_for_str, res) assert (set(res) == items_for_str), (
'items(str): expected {0}, but collected {1}'.format(items_for_str, res)
)
# test items() (default - all items) # test items() (default - all items)
# we tested keys(), values(), and __get_item__ above so here we'll re-create all_items using that # we tested keys(), values(), and __get_item__ above so here we'll re-create
# all_items using that
all_items = set() all_items = set()
keys = m.keys() keys = m.keys()
values = m.values() m.values()
for k in keys: for k in keys: # type: ignore
all_items.add( (tuple(k), m[k[0]]) ) all_items.add((tuple(k), m[k[0]]))
res = set(m.items()) res = set(m.items())
assert (all_items == res), 'items() (all items): expected {0},\n\t\t\t\tbut collected {1}'.format(all_items, res) assert (all_items == res), (
'items() (all items): expected {0},\n\t\t\t\tbut '
'collected {1}'.format(all_items, res)
)
# now test deletion.. # now test deletion..
curr_len = len(m) curr_len = len(m)
del m[12] del m[12]
assert( len(m) == curr_len - 1 ), 'expected len(m) == %d' % (curr_len - 1) assert (len(m) == curr_len - 1), 'expected len(m) == %d' % (curr_len - 1)
assert(not m.has_key(12)), 'expected deleted key to no longer be found!' assert (not m.has_key(12)), 'expected deleted key to no longer be found!'
# try again # try again
try: try:
del m['aa'] del m['aa']
assert(False), 'cant remove again: item m[\'aa\'] should not exist!' assert (False), 'cant remove again: item m[\'aa\'] should not exist!'
except KeyError as err: except KeyError:
pass pass
# try to access non-existing # try to access non-existing
try: try:
k = m['aa'] k = m['aa']
assert(False), 'removed item m[\'aa\'] should not exist!' assert (False), 'removed item m[\'aa\'] should not exist!'
except KeyError as err: except KeyError:
pass pass
# try to access non-existing with a different key # try to access non-existing with a different key
try: try:
k = m[12] k = m[12]
assert(False), 'removed item m[12] should not exist!' assert (False), 'removed item m[12] should not exist!'
except KeyError as err: except KeyError:
pass pass
# prepare for other tests (also testing creation of new items) # prepare for other tests (also testing creation of new items)
@ -499,11 +571,12 @@ def test_multi_key_dict():
m = multi_key_dict() m = multi_key_dict()
tst_range = list(range(10, 40)) + list(range(50, 70)) tst_range = list(range(10, 40)) + list(range(50, 70))
for i in tst_range: for i in tst_range:
m[i] = i # will create a dictionary, where keys are same as items m[i] = i # will create a dictionary, where keys are same as items
# test items() # test items()
for key, value in m.items(int): for key, value in m.items(int):
assert(key == (value,)), 'items(int): expected {0}, but received {1}'.format(key, value) assert (key == (value, )
), ('items(int): expected {0}, but received {1}'.format(key, value))
# test iterkeys() # test iterkeys()
num_of_elements = 0 num_of_elements = 0
@ -511,88 +584,103 @@ def test_multi_key_dict():
for key in m.iterkeys(int): for key in m.iterkeys(int):
returned_keys.add(key) returned_keys.add(key)
num_of_elements += 1 num_of_elements += 1
assert(num_of_elements > 0), 'm.iteritems(int) returned generator that did not produce anything' assert (num_of_elements
assert (returned_keys == set(tst_range)), 'iterkeys(int): expected {0}, but received {1}'.format(expected, key) > 0), ('m.iteritems(int) returned generator that did not produce anything')
assert (returned_keys == set(tst_range)
), ('iterkeys(int): expected {0}, but received {1}'.format(expected, key))
#test itervalues(int) #test itervalues(int)
num_of_elements = 0 num_of_elements = 0
returned_values = set() returned_values = set()
for value in m.itervalues(int): for value in m.itervalues(int):
returned_values.add(value) returned_values.add(value)
num_of_elements += 1 num_of_elements += 1
assert (num_of_elements > 0), 'm.itervalues(int) returned generator that did not produce anything' assert (num_of_elements > 0
assert (returned_values == set(tst_range)), 'itervalues(int): expected {0}, but received {1}'.format(expected, value) ), ('m.itervalues(int) returned generator that did not produce anything')
assert (returned_values == set(tst_range)), (
'itervalues(int): expected {0}, '
'but received {1}'.format(expected, value)
)
# test values(int) # test values(int)
res = sorted([x for x in m.values(int)]) res = sorted([x for x in m.values(int)])
assert (res == tst_range), 'm.values(int) is not as expected.' assert (res == tst_range), 'm.values(int) is not as expected.'
# test keys() # test keys()
assert (set(m.keys(int)) == set(tst_range)), 'm.keys(int) is not as expected.' assert (set(m.keys(int)) == set(tst_range)), 'm.keys(int) is not as expected.' # type: ignore
# test setitem with multiple keys # test setitem with multiple keys
m['xy', 999, 'abcd'] = 'teststr' m['xy', 999, 'abcd'] = 'teststr'
try: try:
m['xy', 998] = 'otherstr' m['xy', 998] = 'otherstr'
assert(False), 'creating / updating m[\'xy\', 998] should fail!' assert (False), 'creating / updating m[\'xy\', 998] should fail!'
except KeyError as err: except KeyError:
pass pass
# test setitem with multiple keys # test setitem with multiple keys
m['cd'] = 'somethingelse' m['cd'] = 'somethingelse'
try: try:
m['cd', 999] = 'otherstr' m['cd', 999] = 'otherstr'
assert(False), 'creating / updating m[\'cd\', 999] should fail!' assert (False), 'creating / updating m[\'cd\', 999] should fail!'
except KeyError as err: except KeyError:
pass pass
m['xy', 999] = 'otherstr' m['xy', 999] = 'otherstr'
assert (m['xy'] == 'otherstr'), 'm[\'xy\'] is not as expected.' assert (m['xy'] == 'otherstr'), 'm[\'xy\'] is not as expected.'
assert (m[999] == 'otherstr'), 'm[999] is not as expected.' assert (m[999] == 'otherstr'), 'm[999] is not as expected.'
assert (m['abcd'] == 'otherstr'), 'm[\'abcd\'] is not as expected.' assert (m['abcd'] == 'otherstr'), 'm[\'abcd\'] is not as expected.'
m['abcd', 'xy'] = 'another' m['abcd', 'xy'] = 'another'
assert (m['xy'] == 'another'), 'm[\'xy\'] is not == \'another\'.' assert (m['xy'] == 'another'), 'm[\'xy\'] is not == \'another\'.'
assert (m[999] == 'another'), 'm[999] is not == \'another\'' assert (m[999] == 'another'), 'm[999] is not == \'another\''
assert (m['abcd'] == 'another'), 'm[\'abcd\'] is not == \'another\'.' assert (m['abcd'] == 'another'), 'm[\'abcd\'] is not == \'another\'.'
# test get functionality of basic dictionaries # test get functionality of basic dictionaries
m['CanIGet'] = 'yes' m['CanIGet'] = 'yes'
assert (m.get('CanIGet') == 'yes') assert (m.get('CanIGet') == 'yes')
assert (m.get('ICantGet') == None) assert (m.get('ICantGet') is None)
assert (m.get('ICantGet', "Ok") == "Ok") assert (m.get('ICantGet', "Ok") == "Ok")
k = multi_key_dict() k = multi_key_dict()
k['1:12', 1] = 'key_has_:' k['1:12', 1] = 'key_has_:'
k.items() # should not cause any problems to have : in key k.items() # should not cause any problems to have : in key
assert (k[1] == 'key_has_:'), 'k[1] is not equal to \'abc:def:ghi\'' assert (k[1] == 'key_has_:'), 'k[1] is not equal to \'abc:def:ghi\''
import datetime import datetime
n = datetime.datetime.now() n = datetime.datetime.now()
l = multi_key_dict() d = multi_key_dict()
l[n] = 'now' # use datetime obj as a key d[n] = 'now' # use datetime obj as a key
#test keys.. #test keys..
res = [x for x in l.keys()][0] # for python3 keys() returns dict_keys dictionarly res = [
x for x in d.keys(
) # type: ignore # for python3 keys() returns dict_keys dictionarly
][0]
expected = n, expected = n,
assert(expected == res), 'Expected \"{0}\", but got: \"{1}\"'.format(expected, res) assert (expected == res), 'Expected \"{0}\", but got: \"{1}\"'.format(expected, res)
res = [x for x in l.keys(datetime.datetime)][0] res = [x for x in d.keys(datetime.datetime)][0] # type: ignore
assert(n == res), 'Expected {0} as a key, but got: {1}'.format(n, res) assert (n == res), 'Expected {0} as a key, but got: {1}'.format(n, res)
res = [x for x in l.values()] # for python3 keys() returns dict_values dictionarly res = [x for x in d.values()] # for python3 keys() returns dict_values dictionarly
expected = ['now'] expected = ['now']
assert(res == expected), 'Expected values: {0}, but got: {1}'.format(expected, res) assert (res == expected), 'Expected values: {0}, but got: {1}'.format(expected, res)
# test items.. # test items..
exp_items = [((n,), 'now')] exp_items = [((n, ), 'now')]
r = list(l.items()) r = list(d.items())
assert(r == exp_items), 'Expected for items(): tuple of keys: {0}, but got: {1}'.format(r, exp_items) assert (r == exp_items), (
assert(exp_items[0][1] == 'now'), 'Expected for items(): value: {0}, but got: {1}'.format('now', 'Expected for items(): tuple of keys: {0}, but got: {1}'.format(r, exp_items)
exp_items[0][1]) )
assert (exp_items[0][1] == 'now'), (
'Expected for items(): value: {0}, but got: {1}'.format('now', exp_items[0][1])
)
x = multi_key_dict({('k', 'kilo'):1000, ('M', 'MEGA', 1000000):1000000}, milli=0.01) x = multi_key_dict(
{
('k', 'kilo'): 1000, ('M', 'MEGA', 1000000): 1000000
}, milli = 0.01
)
assert (x['k'] == 1000), 'x[\'k\'] is not equal to 1000' assert (x['k'] == 1000), 'x[\'k\'] is not equal to 1000'
x['kilo'] = 'kilo' x['kilo'] = 'kilo'
assert (x['kilo'] == 'kilo'), 'x[\'kilo\'] is not equal to \'kilo\'' assert (x['kilo'] == 'kilo'), 'x[\'kilo\'] is not equal to \'kilo\''
@ -605,11 +693,13 @@ def test_multi_key_dict():
try: try:
y = multi_key_dict([(('two', 'duo'), 2), ('one', 'uno', 1), ('three', 3)]) y = multi_key_dict([(('two', 'duo'), 2), ('one', 'uno', 1), ('three', 3)])
assert(False), 'creating dictionary using iterable with tuples of size > 2 should fail!' assert (False), (
except: 'creating dictionary using iterable with tuples of size > 2 should fail!'
)
except Exception:
pass pass
print ('All test passed OK!') print('All test passed OK!')
__all__ = ["multi_key_dict"] __all__ = ["multi_key_dict"]
@ -617,5 +707,4 @@ if __name__ == '__main__':
try: try:
test_multi_key_dict() test_multi_key_dict()
except KeyboardInterrupt: except KeyboardInterrupt:
print ('\n(interrupted by user)') print('\n(interrupted by user)')

View file

@ -1,14 +1,16 @@
# -*- coding: utf-8 -*-
from __future__ import annotations from __future__ import annotations
from typing import Any, List, Optional, Union import fnmatch
import re
import re, fnmatch
from collections import OrderedDict from collections import OrderedDict
from enum import Enum, auto from enum import Enum, auto
from typing import TYPE_CHECKING
from ..log import * if TYPE_CHECKING:
from typing import List, Optional, Union, Any
from ..log import DEBUG, get_caller_pos, slog
def quote(s): def quote(s):
if is_quoted(s): if is_quoted(s):
@ -26,7 +28,7 @@ def is_quoted(s: str) -> bool:
if len(s) < 2: if len(s) < 2:
return False return False
d = s[0] d = s[0]
if d == s[-1] and d in [ '"', "'" ]: if d == s[-1] and d in ['"', "'"]:
return True return True
return False return False
@ -38,16 +40,18 @@ def cleanup_string(s: str) -> str:
return s[1:-1].replace('\\' + s[0], s[0]) return s[1:-1].replace('\\' + s[0], s[0])
return s return s
class StringTree: # export class StringTree: # export
def __init__(self, path: str, content: str, parent: StringTree|None=None) -> None: def __init__(
self, path: str, content: str, parent: StringTree | None = None
) -> None:
slog(DEBUG, f'Constructing StringTree(path="{path}", content="{content}")') slog(DEBUG, f'Constructing StringTree(path="{path}", content="{content}")')
self.__parent = parent self.__parent = parent
self.children: OrderedDict[str, StringTree] = OrderedDict() self.children: OrderedDict[str, StringTree] = OrderedDict()
self.content: Optional[str] = None self.content: Optional[str] = None
self.__set(path, content) self.__set(path, content)
assert(hasattr(self, "content")) assert (hasattr(self, "content"))
#assert self.content is not None #assert self.content is not None
# root (content = [ symbols ]) # root (content = [ symbols ])
@ -65,60 +69,89 @@ class StringTree: # export
#parent.dump(INFO, "These children are added") #parent.dump(INFO, "These children are added")
self.content = parent.content self.content = parent.content
for name, c in parent.children.items(): for name, c in parent.children.items():
if not name in self.children.keys(): if name not in self.children.keys():
slog(DEBUG, f'At {self.content}: Adding new child {c}') slog(DEBUG, f'At {self.content}: Adding new child {c}')
self.children[name] = c self.children[name] = c
else: else:
self.children[name].__adopt_children(c) self.children[name].__adopt_children(c)
def __set(self, path_, content, split=True): def __set(self, path_, content, split = True):
slog(DEBUG, ('At "{}": '.format(str(self.content)) if hasattr(self, "content") else "") + f'Setting "{path_}" -> "{content}"') slog(
#assert self.content != str(content) # Not sure what the idea behind this was. It often goes off, and all works fine without. DEBUG,
if content is not None and not type(content) in [str, StringTree]: ('At "{}": '.format(str(self.content)) if hasattr(self, "content") else "")
raise Exception("Tried to add content of unsupported type {}".format(type(content).__name__)) + f'Setting "{path_}" -> "{content}"'
)
# Not sure what the idea behind this was. It often goes off, and all
# works fine without.
#assert self.content != str(content)
if content is not None and type(content) not in [str, StringTree]:
raise Exception(
"Tried to add content of unsupported type {}".format(
type(content).__name__
)
)
if path_ is None: if path_ is None:
if isinstance(content, str): if isinstance(content, str):
self.content = cleanup_string(content) self.content = cleanup_string(content)
elif isinstance(content, StringTree): elif isinstance(content, StringTree):
self.__adopt_children(content) self.__adopt_children(content)
else: else:
raise Exception("Tried to add content of unsupported type {}".format(type(content).__name__)) raise Exception(
slog(DEBUG, " -- content = >" + str(content) + "<, self.content = >" + str(self.content) + "<") "Tried to add content of unsupported type {}".format(
type(content).__name__
)
)
slog(
DEBUG,
" -- content = >" + str(content) + "<, self.content = >" +
str(self.content) + "<"
)
return self return self
path = cleanup_string(path_) path = cleanup_string(path_)
components = path.split('.') if split else [ path ] components = path.split('.') if split else [path]
l = len(components) length = len(components)
if len(path) == 0 or l == 0: if len(path) == 0 or length == 0:
#assert self.content is None or (isinstance(content, StringTree) and content.content == self.content)
#assert self.content is None or (
# isinstance(content, StringTree) and content.content == self.content
#)
if isinstance(content, StringTree): if isinstance(content, StringTree):
#assert isinstance(content, StringTree), "Type: " + type(content).__name__ #assert isinstance(content, StringTree), (
# f'Type: {type(content).__name__ }'
#)
self.__adopt_children(content) self.__adopt_children(content)
else: else:
if self.content != content: if self.content != content:
#self.content = cleanup_string(content) #self.content = cleanup_string(content)
slog(DEBUG, f'Changing content: "{self.content}" ->"{content}"') slog(DEBUG, f'Changing content: "{self.content}" ->"{content}"')
assert(content != '"[a-zA-Z0-9+_*/-]"') assert (content != '"[a-zA-Z0-9+_*/-]"')
self.content = content self.content = content
#assert(content != "'antlr_doesnt_understand_vertical_tab'") #assert(content != "'antlr_doesnt_understand_vertical_tab'")
#self.children[content] = StringTree(None, content) #self.children[content] = StringTree(None, content)
return self return self
#assert self.content is not None, "tried to set empty content to {}".format(path_) #assert self.content is not None, f'Tried to set empty content to "{path_}"'
nibble = components[0] nibble = components[0]
rest = '.'.join(components[1:]) rest = '.'.join(components[1:])
if nibble not in self.children: if nibble not in self.children:
self.children[nibble] = StringTree('', content=nibble, parent=self) self.children[nibble] = StringTree('', content = nibble, parent = self)
if l > 1: if length > 1:
assert len(rest) > 0 assert len(rest) > 0
return self.children[nibble].__set(rest, content=content) return self.children[nibble].__set(rest, content = content)
# last component, a.k.a. leaf # last component, a.k.a. leaf
if content is not None: if content is not None:
gc = content if isinstance(content, StringTree) else StringTree('', content=content, parent=self.children[nibble]) gc = content if isinstance(content, StringTree) else StringTree(
'', content = content, parent = self.children[nibble]
)
# Make sure no existing grand child is updated. It would reside too # Make sure no existing grand child is updated. It would reside too
# far up in the grand child OrderedDict, we need it last # far up in the grand child OrderedDict, we need it last
if gc.content in self.children[nibble].children: if gc.content in self.children[nibble].children:
del self.children[nibble].children[gc.content] del self.children[nibble].children[gc.content]
assert gc.content is not None, 'Grand-child content is None'
self.children[nibble].children[gc.content] = gc self.children[nibble].children[gc.content] = gc
return self.children[nibble] return self.children[nibble]
@ -129,17 +162,17 @@ class StringTree: # export
r = self.get(path) r = self.get(path)
if r is None: if r is None:
raise KeyError(path) raise KeyError(path)
return r.value() # type: ignore return r.value() # type: ignore
def __setitem__(self, key, value): def __setitem__(self, key, value):
return self.__set(key, value) return self.__set(key, value)
def __dump(self, prio, indent=0, **kwargs): def __dump(self, prio, indent = 0, **kwargs):
caller = kwargs['caller'] if 'caller' in kwargs.keys() else get_caller_pos(1) caller = kwargs['caller'] if 'caller' in kwargs.keys() else get_caller_pos(1)
slog(prio, '|' + (' ' * indent) + str(self.content), caller=caller) slog(prio, '|' + (' ' * indent) + str(self.content), caller = caller)
indent += 2 indent += 2
for name, child in self.children.items(): for name, child in self.children.items():
child.__dump(prio, indent=indent, caller=caller) child.__dump(prio, indent = indent, caller = caller)
@property @property
def path(self): def path(self):
@ -164,7 +197,12 @@ class StringTree: # export
raise Exception("Tried to set empty content") raise Exception("Tried to set empty content")
self.content = content self.content = content
def add(self, path: str, content: Optional[Union[str, StringTree]] = None, split: bool = True) -> StringTree: def add(
self,
path: str,
content: Optional[Union[str, StringTree]] = None,
split: bool = True
) -> StringTree:
slog(DEBUG, f'-- At "{self.content}": Adding "{path}" -> "{content}"') slog(DEBUG, f'-- At "{self.content}": Adding "{path}" -> "{content}"')
return self.__set(path, content, split) return self.__set(path, content, split)
@ -176,7 +214,7 @@ class StringTree: # export
slog(DEBUG, "returning myself") slog(DEBUG, "returning myself")
return self return self
if is_quoted(path_): if is_quoted(path_):
if not path in self.children.keys(): if path not in self.children.keys():
return None return None
return self.children[path] return self.children[path]
components = path.split('.') components = path.split('.')
@ -185,7 +223,7 @@ class StringTree: # export
name = cleanup_string(components[0]) name = cleanup_string(components[0])
if not hasattr(self, "children"): if not hasattr(self, "children"):
return None return None
if not name in self.children.keys(): if name not in self.children.keys():
slog(DEBUG, "Name \"" + name + "\" is not in children of", self.content) slog(DEBUG, "Name \"" + name + "\" is not in children of", self.content)
for child in self.children: for child in self.children:
slog(DEBUG, "child = ", child) slog(DEBUG, "child = ", child)
@ -193,7 +231,7 @@ class StringTree: # export
relpath = '.'.join(components[1:]) relpath = '.'.join(components[1:])
return self.children[name].get(relpath) return self.children[name].get(relpath)
def value(self, path = None, default=None) -> Optional[str]: def value(self, path = None, default = None) -> Optional[str]:
if path: if path:
child = self.get(path) child = self.get(path)
if child is None: if child is None:
@ -204,7 +242,7 @@ class StringTree: # export
if len(self.children) == 0: if len(self.children) == 0:
raise Exception('tried to get value from leaf "{}"'.format(self.content)) raise Exception('tried to get value from leaf "{}"'.format(self.content))
slog(DEBUG, f'Returning value from children {self.children}') slog(DEBUG, f'Returning value from children {self.children}')
return self.children[next(reversed(self.children))].content # type: ignore return self.children[next(reversed(self.children))].content # type: ignore
@property @property
def parent(self): def parent(self):
@ -216,9 +254,12 @@ class StringTree: # export
return self return self
return self.__parent.root return self.__parent.root
def child_list(self, depth_first: bool=True) -> List[StringTree]: def child_list(self, depth_first: bool = True) -> List[StringTree]:
if depth_first == False: if not depth_first:
raise Exception("tried to retrieve child list with breadth-first search, not yet implemented") raise Exception(
'Tried to retrieve child list with breadth-first '
'search, not yet implemented'
)
r = [] r = []
for name, c in self.children.items(): for name, c in self.children.items():
r.append(c) r.append(c)
@ -230,32 +271,30 @@ class StringTree: # export
msg = '' msg = ''
if args is not None: if args is not None:
msg = ' ' + ' '.join(args) + ' ' msg = ' ' + ' '.join(args) + ' '
slog(prio, ",------------" + msg + "----------- >", caller=caller) slog(prio, ",------------" + msg + "----------- >", caller = caller)
self.__dump(prio, indent=0, caller=caller) self.__dump(prio, indent = 0, caller = caller)
slog(prio, "`------------" + msg + "----------- <", caller=caller) slog(prio, "`------------" + msg + "----------- <", caller = caller)
class Match(Enum): class Match(Enum):
Equal = auto() Equal = auto()
RegExArg = auto() RegExArg = auto()
RegExConf = auto() RegExConf = auto()
GlobArg = auto() GlobArg = auto()
GlobConf = auto() GlobConf = auto()
def __find(self, key: str|None, val: str|None, match: Match, depth_first: bool): def __find(self, key: str | None, val: str | None, m: Match, depth_first: bool):
def __children(): def __children():
for name, child in self.children.items(): for name, child in self.children.items():
ret.extend(child.__find(key, val, match, depth_first)) ret.extend(child.__find(key, val, m, depth_first))
def __self(): def __self():
_val = self.value() _val = self.value()
_content = self.content _content = self.content
try: try:
if ( if ((key == _content and matcher(val, _val))
(key == _content and matcher(val, _val)) or (key is None and matcher(val, _val))
or (key is None and matcher(val, _val)) or (key == _content and val is None)):
or (key == _content and val is None)
):
ret.append(self) ret.append(self)
except Exception as e: except Exception as e:
if isinstance(e, re.PatternError): if isinstance(e, re.PatternError):
@ -263,29 +302,33 @@ class StringTree: # export
else: else:
raise raise
def __debug_matcher(matcher, log_level=DEBUG): def __select_matcher(m: StringTree.Match) -> Any:
match m:
case self.Match.Equal:
return lambda x, y: x == y
case self.Match.RegExArg:
return lambda x, y: re.search(x, y) is not None
case self.Match.RegExConf:
return lambda x, y: re.search(y, x) is not None
case self.Match.GlobArg:
return lambda x, y: fnmatch.fnmatch(y, x)
case self.Match.GlobConf:
return lambda x, y: fnmatch.fnmatch(x, y)
case _:
raise NotImplementedError(f'Matcher {m} is not yet implemented')
def __debug_matcher(matcher, log_level = DEBUG):
def __matcher(x, y): def __matcher(x, y):
slog(log_level, f'Comparing "{x}" ~ "{y}"') slog(log_level, f'Comparing "{x}" ~ "{y}"')
return matcher(x, y) return matcher(x, y)
return __matcher return __matcher
if not self.children: if not self.children:
return [] return []
matcher = lambda x, y: x == y matcher = __select_matcher(m)
match match:
case self.Match.Equal:
pass
case self.Match.RegExArg:
matcher = lambda x, y: re.search(x, y) is not None
case self.Match.RegExConf:
matcher = lambda x, y: re.search(y, x) is not None
case self.Match.GlobArg:
matcher = lambda x, y: fnmatch.fnmatch(y, x)
case self.Match.GlobConf:
matcher = lambda x, y: fnmatch.fnmatch(x, y)
case _:
raise NotImplementedError(f'Matcher {match} is not yet implemented')
ret: list[StringTree] = [] ret: list[StringTree] = []
@ -298,5 +341,16 @@ class StringTree: # export
return ret return ret
def find(self, key: str|None=None, val: str|None=None, match: Match=Match.Equal, depth_first: bool=False): def find(
return [ node.parent.path for node in self.__find(key, val, match=match, depth_first=depth_first)] self,
key: str | None = None,
val: str | None = None,
match: Match = Match.Equal,
depth_first: bool = False
):
ret: list[str] = []
for node in self.__find(key, val, m = match, depth_first = depth_first):
if node.parent is None:
break
ret.append(node.parent.path)
return ret

View file

@ -1,9 +1,9 @@
# -*- coding: utf-8 -*- import glob
import os
import re
import os, glob from ..log import DEBUG, ERR, INFO, slog, slog_m
from .StringTree import StringTree, cleanup_string
from .StringTree import *
from ..log import *
def _cleanup_line(line: str) -> str: def _cleanup_line(line: str) -> str:
line = line.strip() line = line.strip()
@ -15,18 +15,22 @@ def _cleanup_line(line: str) -> str:
if c == in_quote: if c == in_quote:
in_quote = None in_quote = None
else: else:
if c in [ '"', "'" ]: if c in ['"', "'"]:
in_quote = c in_quote = c
elif in_quote is None and c == '#': elif in_quote is None and c == '#':
return r.strip() return r.strip()
r += c r += c
if len(r) >= 2 and r[0] in [ '"', "'" ] and r[-1] == r[0]: if len(r) >= 2 and r[0] in ['"', "'"] and r[-1] == r[0]:
return r[1:-1] return r[1:-1]
return r return r
def parse(s: str, allow_full_lines: bool=True, root_content: str='root') -> StringTree: # export def parse( # export
s: str,
allow_full_lines: bool = True,
root_content: str = 'root'
) -> StringTree:
slog_m(DEBUG, "--->--- parsing --->---\n" + s + "\n---<--- parsing ---<---\n") slog_m(DEBUG, "--->--- parsing --->---\n" + s + "\n---<--- parsing ---<---\n")
root = StringTree('', content=root_content) root = StringTree('', content = root_content)
sec = '' sec = ''
for line in s.splitlines(): for line in s.splitlines():
slog(DEBUG, f'Parsing: "{line}"') slog(DEBUG, f'Parsing: "{line}"')
@ -47,7 +51,7 @@ def parse(s: str, allow_full_lines: bool=True, root_content: str='root') -> Stri
root.add(sec) root.add(sec)
continue continue
elif line[0] == ']': elif line[0] == ']':
assert(len(sec) > 0) assert (len(sec) > 0)
sec = '.'.join(sec.split('.')[0:-1]) sec = '.'.join(sec.split('.')[0:-1])
continue continue
lhs = '' lhs = ''
@ -67,17 +71,19 @@ def parse(s: str, allow_full_lines: bool=True, root_content: str='root') -> Stri
raise Exception("failed to parse assignment", line) raise Exception("failed to parse assignment", line)
rhs = 'empty' rhs = 'empty'
split = False split = False
root.add(sec + '.' + cleanup_string(lhs), cleanup_string(rhs), split=split) root.add(sec + '.' + cleanup_string(lhs), cleanup_string(rhs), split = split)
return root return root
def _read_lines_from_one_path(path: str, throw=True, level=0, log_prio=INFO, paths_buf=None): def _read_lines_from_one_path(
path: str, throw = True, level = 0, log_prio = INFO, paths_buf = None
):
try: try:
with open(path, 'r') as infile: with open(path, 'r') as infile:
slog(log_prio, 'Reading {}"{}"'.format(' ' * level * 2, path)) slog(log_prio, 'Reading {}"{}"'.format(' ' * level * 2, path))
if paths_buf is not None: if paths_buf is not None:
paths_buf.append(path) paths_buf.append(path)
ret = [] ret = []
for line in infile: # lines are all trailed by \n for line in infile: # lines are all trailed by \n
m = re.search(r'^\s*(-)*include\s+(\S+)', line) m = re.search(r'^\s*(-)*include\s+(\S+)', line)
if m: if m:
optional = m.group(1) == '-' optional = m.group(1) == '-'
@ -86,7 +92,12 @@ def _read_lines_from_one_path(path: str, throw=True, level=0, log_prio=INFO, pat
dir_name = os.path.dirname(path) dir_name = os.path.dirname(path)
if len(dir_name): if len(dir_name):
include_path = dir_name + '/' + include_path include_path = dir_name + '/' + include_path
include_lines = _read_lines(include_path, throw=(not optional), level=level+1, paths_buf=paths_buf) include_lines = _read_lines(
include_path,
throw = (not optional),
level = level + 1,
paths_buf = paths_buf
)
if include_lines is None: if include_lines is None:
slog(DEBUG, f'{path}: Failed to process "{line}"') slog(DEBUG, f'{path}: Failed to process "{line}"')
continue continue
@ -100,17 +111,26 @@ def _read_lines_from_one_path(path: str, throw=True, level=0, log_prio=INFO, pat
raise raise
return None return None
def _read_lines(path: str, throw=True, level=0, log_prio=INFO, paths_buf=None): def _read_lines(path: str, throw = True, level = 0, log_prio = INFO, paths_buf = None):
paths = glob.glob(path) paths = glob.glob(path)
ret = [] ret = []
for p in paths: for p in paths:
rr = _read_lines_from_one_path(p, throw=throw, level=level, log_prio=log_prio, paths_buf=paths_buf) rr = _read_lines_from_one_path(
p, throw = throw, level = level, log_prio = log_prio, paths_buf = paths_buf
)
if rr is None: if rr is None:
return None return None
ret.extend(rr) ret.extend(rr)
return ret return ret
def read(path: str, root_content: str='root', log_prio=INFO, paths_buf=None) -> StringTree: # export def read( # export
lines = _read_lines_from_one_path(path, log_prio=log_prio, paths_buf=paths_buf) path: str,
root_content: str = 'root',
log_prio = INFO,
paths_buf = None
) -> StringTree:
lines = _read_lines_from_one_path(path, log_prio = log_prio, paths_buf = paths_buf)
if lines is None:
raise Exception(f'Could not read ini file from "{path}"')
s = ''.join(lines) s = ''.join(lines)
return parse(s, root_content=root_content) return parse(s, root_content = root_content)