From 86e25a2dfbc5ad83289bf7f0a1c372f3116f54cd Mon Sep 17 00:00:00 2001 From: Jan Lindemann Date: Wed, 22 Apr 2020 14:02:35 +0200 Subject: [PATCH] misc: Add load_(classes|class_names|object)() Add some functions to aid in dynamically loading objects. Signed-off-by: Jan Lindemann --- tools/python/jwutils/misc.py | 37 +++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/tools/python/jwutils/misc.py b/tools/python/jwutils/misc.py index 7bd2882..94997c1 100644 --- a/tools/python/jwutils/misc.py +++ b/tools/python/jwutils/misc.py @@ -3,6 +3,7 @@ import atexit import tempfile import filecmp import inspect +import importlib from typing import Set from jwutils import log @@ -50,7 +51,7 @@ def object_builtin_name(o, full=True): # export return o.__class__.__name__ # Avoid reporting __builtin__ return module + '.' + o.__class__.__name__ -def get_derived_classes(mod, base): # export +def get_derived_classes(mod, base, flt=None): # export members = inspect.getmembers(mod, inspect.isclass) r = [] for name, c in members: @@ -61,9 +62,43 @@ def get_derived_classes(mod, base): # export if not base in inspect.getmro(c): log.slog(log.DEBUG, " is not derived from", base, "only", inspect.getmro(c)) continue + if flt and not re.match(flt, name): + slog(DEBUG, ' "{}.{}" has wrong name'.format(mod, name)) + continue r.append(c) return r +def load_classes(path, baseclass, flt=None): # export + r = [] + for p in path.split(':'): + mod = importlib.import_module(path) + r.extend(get_derived_classes(mod, baseclass, flt)) + return r + +def load_class_names(path, baseclass, flt=None, remove_flt=False): # export + classes = load_classes(path, baseclass, flt) + r = [] + for c in classes: + name = c.__name__ + if flt and remove_flt: + name = re.subst(flt, "", name) + if name not in r: + r.append(name) + else: + pass + #log.slog(log.WARNING, "{} is already in in {}".format(name, r)) + return r + +# TODO: handling of "name" / "flt" is awkward +def load_object(path, baseclass, name = None, *args, **kwargs): # export + mod = importlib.import_module(path) + classes = get_derived_classes(mod, baseclass, flt=None) + if len(classes) == 0: + raise Exception("no class {} found".format(name)) + if len(classes) > 1: + raise Exception("{} classes matching {} found".format(len(classes), name)) + return classes[0](*args, **kwargs) + def commit_tmpfile(tmp: str, path: str) -> None: # export caller = log.get_caller_pos() if os.path.isfile(path) and filecmp.cmp(tmp, path):