misc: Add load_class()

Add load_class() to complement load_object(), returning the class as
opposed to instantiated object.

Signed-off-by: Jan Lindemann <jan@janware.com>
This commit is contained in:
Jan Lindemann 2025-01-16 10:51:28 +01:00
commit 381514ab2c

View file

@ -4,7 +4,10 @@ import tempfile
import filecmp
import inspect
import importlib
import re
from typing import Set, Iterable
from jwutils import log
_tmpfiles: Set[str] = set()
@ -63,7 +66,7 @@ def get_derived_classes(mod, base, flt=None): # export
log.slog(log.DEBUG, " is not derived from", base, "only", inspect.getmro(c))
continue
if flt and not re.match(flt, name):
slog(DEBUG, ' "{}.{}" has wrong name'.format(mod, name))
log.slog(log.DEBUG, ' "{}.{}" has wrong name'.format(mod, name))
continue
r.append(c)
return r
@ -75,6 +78,15 @@ def load_classes(path, baseclass, flt=None): # export
r.extend(get_derived_classes(mod, baseclass, flt))
return r
def load_class(module_path, baseclass, class_name_filter=None): # export
mod = importlib.import_module(module_path)
classes = get_derived_classes(mod, baseclass, flt=class_name_filter)
if len(classes) == 0:
raise Exception(f'no class matching "{class_name_filter}" of type "{baseclass}" found in module "{module_path}"')
if len(classes) > 1:
raise Exception(f'{len(classes)} classes matching "{class_name_filter}" of type "{baseclass}" found in module "{module_path}"')
return classes[0]
def load_class_names(path, baseclass, flt=None, remove_flt=False): # export
classes = load_classes(path, baseclass, flt)
r = []
@ -89,15 +101,8 @@ def load_class_names(path, baseclass, flt=None, remove_flt=False): # export
#log.slog(log.WARNING, "{} is already in in {}".format(name, r))
return r
# TODO: handling of "name" / "flt" is awkward
def load_object(path, baseclass, name = None, *args, **kwargs): # export
mod = importlib.import_module(path)
classes = get_derived_classes(mod, baseclass, flt=None)
if len(classes) == 0:
raise Exception("no class {} found".format(name))
if len(classes) > 1:
raise Exception("{} classes matching {} found".format(len(classes), name))
return classes[0](*args, **kwargs)
def load_object(module_path, baseclass, class_name_filter=None, *args, **kwargs): # export
return load_class(module_path, baseclass, class_name_filter=class_name_filter)(*args, **kwargs)
def commit_tmpfile(tmp: str, path: str) -> None: # export
caller = log.get_caller_pos()