misc: Add load_(classes|class_names|object)()

Add some functions to aid in dynamically loading objects.

Signed-off-by: Jan Lindemann <jan@janware.com>
This commit is contained in:
Jan Lindemann 2020-04-22 14:02:35 +02:00
commit 86e25a2dfb

View file

@ -3,6 +3,7 @@ import atexit
import tempfile
import filecmp
import inspect
import importlib
from typing import Set
from jwutils import log
@ -50,7 +51,7 @@ def object_builtin_name(o, full=True): # export
return o.__class__.__name__ # Avoid reporting __builtin__
return module + '.' + o.__class__.__name__
def get_derived_classes(mod, base): # export
def get_derived_classes(mod, base, flt=None): # export
members = inspect.getmembers(mod, inspect.isclass)
r = []
for name, c in members:
@ -61,9 +62,43 @@ def get_derived_classes(mod, base): # export
if not base in inspect.getmro(c):
log.slog(log.DEBUG, " is not derived from", base, "only", inspect.getmro(c))
continue
if flt and not re.match(flt, name):
slog(DEBUG, ' "{}.{}" has wrong name'.format(mod, name))
continue
r.append(c)
return r
def load_classes(path, baseclass, flt=None): # export
r = []
for p in path.split(':'):
mod = importlib.import_module(path)
r.extend(get_derived_classes(mod, baseclass, flt))
return r
def load_class_names(path, baseclass, flt=None, remove_flt=False): # export
classes = load_classes(path, baseclass, flt)
r = []
for c in classes:
name = c.__name__
if flt and remove_flt:
name = re.subst(flt, "", name)
if name not in r:
r.append(name)
else:
pass
#log.slog(log.WARNING, "{} is already in in {}".format(name, r))
return r
# TODO: handling of "name" / "flt" is awkward
def load_object(path, baseclass, name = None, *args, **kwargs): # export
mod = importlib.import_module(path)
classes = get_derived_classes(mod, baseclass, flt=None)
if len(classes) == 0:
raise Exception("no class {} found".format(name))
if len(classes) > 1:
raise Exception("{} classes matching {} found".format(len(classes), name))
return classes[0](*args, **kwargs)
def commit_tmpfile(tmp: str, path: str) -> None: # export
caller = log.get_caller_pos()
if os.path.isfile(path) and filecmp.cmp(tmp, path):