From 381514ab2cd9b42fd836fc704e70cbd8898e3878 Mon Sep 17 00:00:00 2001 From: Jan Lindemann Date: Thu, 16 Jan 2025 10:51:28 +0100 Subject: [PATCH] misc: Add load_class() Add load_class() to complement load_object(), returning the class as opposed to instantiated object. Signed-off-by: Jan Lindemann --- tools/python/jwutils/misc.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/tools/python/jwutils/misc.py b/tools/python/jwutils/misc.py index 108f428..3b2ea05 100644 --- a/tools/python/jwutils/misc.py +++ b/tools/python/jwutils/misc.py @@ -4,7 +4,10 @@ import tempfile import filecmp import inspect import importlib +import re + from typing import Set, Iterable + from jwutils import log _tmpfiles: Set[str] = set() @@ -63,7 +66,7 @@ def get_derived_classes(mod, base, flt=None): # export log.slog(log.DEBUG, " is not derived from", base, "only", inspect.getmro(c)) continue if flt and not re.match(flt, name): - slog(DEBUG, ' "{}.{}" has wrong name'.format(mod, name)) + log.slog(log.DEBUG, ' "{}.{}" has wrong name'.format(mod, name)) continue r.append(c) return r @@ -75,6 +78,15 @@ def load_classes(path, baseclass, flt=None): # export r.extend(get_derived_classes(mod, baseclass, flt)) return r +def load_class(module_path, baseclass, class_name_filter=None): # export + mod = importlib.import_module(module_path) + classes = get_derived_classes(mod, baseclass, flt=class_name_filter) + if len(classes) == 0: + raise Exception(f'no class matching "{class_name_filter}" of type "{baseclass}" found in module "{module_path}"') + if len(classes) > 1: + raise Exception(f'{len(classes)} classes matching "{class_name_filter}" of type "{baseclass}" found in module "{module_path}"') + return classes[0] + def load_class_names(path, baseclass, flt=None, remove_flt=False): # export classes = load_classes(path, baseclass, flt) r = [] @@ -89,15 +101,8 @@ def load_class_names(path, baseclass, flt=None, remove_flt=False): # export #log.slog(log.WARNING, "{} is already in in {}".format(name, r)) return r -# TODO: handling of "name" / "flt" is awkward -def load_object(path, baseclass, name = None, *args, **kwargs): # export - mod = importlib.import_module(path) - classes = get_derived_classes(mod, baseclass, flt=None) - if len(classes) == 0: - raise Exception("no class {} found".format(name)) - if len(classes) > 1: - raise Exception("{} classes matching {} found".format(len(classes), name)) - return classes[0](*args, **kwargs) +def load_object(module_path, baseclass, class_name_filter=None, *args, **kwargs): # export + return load_class(module_path, baseclass, class_name_filter=class_name_filter)(*args, **kwargs) def commit_tmpfile(tmp: str, path: str) -> None: # export caller = log.get_caller_pos()