misc: Add load_class()

Add load_class() to complement load_object(), returning the class as
opposed to instantiated object.

Signed-off-by: Jan Lindemann <jan@janware.com>
This commit is contained in:
Jan Lindemann 2025-01-16 10:51:28 +01:00
commit 381514ab2c

View file

@ -4,7 +4,10 @@ import tempfile
import filecmp import filecmp
import inspect import inspect
import importlib import importlib
import re
from typing import Set, Iterable from typing import Set, Iterable
from jwutils import log from jwutils import log
_tmpfiles: Set[str] = set() _tmpfiles: Set[str] = set()
@ -63,7 +66,7 @@ def get_derived_classes(mod, base, flt=None): # export
log.slog(log.DEBUG, " is not derived from", base, "only", inspect.getmro(c)) log.slog(log.DEBUG, " is not derived from", base, "only", inspect.getmro(c))
continue continue
if flt and not re.match(flt, name): if flt and not re.match(flt, name):
slog(DEBUG, ' "{}.{}" has wrong name'.format(mod, name)) log.slog(log.DEBUG, ' "{}.{}" has wrong name'.format(mod, name))
continue continue
r.append(c) r.append(c)
return r return r
@ -75,6 +78,15 @@ def load_classes(path, baseclass, flt=None): # export
r.extend(get_derived_classes(mod, baseclass, flt)) r.extend(get_derived_classes(mod, baseclass, flt))
return r return r
def load_class(module_path, baseclass, class_name_filter=None): # export
mod = importlib.import_module(module_path)
classes = get_derived_classes(mod, baseclass, flt=class_name_filter)
if len(classes) == 0:
raise Exception(f'no class matching "{class_name_filter}" of type "{baseclass}" found in module "{module_path}"')
if len(classes) > 1:
raise Exception(f'{len(classes)} classes matching "{class_name_filter}" of type "{baseclass}" found in module "{module_path}"')
return classes[0]
def load_class_names(path, baseclass, flt=None, remove_flt=False): # export def load_class_names(path, baseclass, flt=None, remove_flt=False): # export
classes = load_classes(path, baseclass, flt) classes = load_classes(path, baseclass, flt)
r = [] r = []
@ -89,15 +101,8 @@ def load_class_names(path, baseclass, flt=None, remove_flt=False): # export
#log.slog(log.WARNING, "{} is already in in {}".format(name, r)) #log.slog(log.WARNING, "{} is already in in {}".format(name, r))
return r return r
# TODO: handling of "name" / "flt" is awkward def load_object(module_path, baseclass, class_name_filter=None, *args, **kwargs): # export
def load_object(path, baseclass, name = None, *args, **kwargs): # export return load_class(module_path, baseclass, class_name_filter=class_name_filter)(*args, **kwargs)
mod = importlib.import_module(path)
classes = get_derived_classes(mod, baseclass, flt=None)
if len(classes) == 0:
raise Exception("no class {} found".format(name))
if len(classes) > 1:
raise Exception("{} classes matching {} found".format(len(classes), name))
return classes[0](*args, **kwargs)
def commit_tmpfile(tmp: str, path: str) -> None: # export def commit_tmpfile(tmp: str, path: str) -> None: # export
caller = log.get_caller_pos() caller = log.get_caller_pos()