misc: Add load_(classes|class_names|object)()

Add some functions to aid in dynamically loading objects.

Signed-off-by: Jan Lindemann <jan@janware.com>
This commit is contained in:
Jan Lindemann 2020-04-22 14:02:35 +02:00
commit 86e25a2dfb

View file

@ -3,6 +3,7 @@ import atexit
import tempfile import tempfile
import filecmp import filecmp
import inspect import inspect
import importlib
from typing import Set from typing import Set
from jwutils import log from jwutils import log
@ -50,7 +51,7 @@ def object_builtin_name(o, full=True): # export
return o.__class__.__name__ # Avoid reporting __builtin__ return o.__class__.__name__ # Avoid reporting __builtin__
return module + '.' + o.__class__.__name__ return module + '.' + o.__class__.__name__
def get_derived_classes(mod, base): # export def get_derived_classes(mod, base, flt=None): # export
members = inspect.getmembers(mod, inspect.isclass) members = inspect.getmembers(mod, inspect.isclass)
r = [] r = []
for name, c in members: for name, c in members:
@ -61,9 +62,43 @@ def get_derived_classes(mod, base): # export
if not base in inspect.getmro(c): if not base in inspect.getmro(c):
log.slog(log.DEBUG, " is not derived from", base, "only", inspect.getmro(c)) log.slog(log.DEBUG, " is not derived from", base, "only", inspect.getmro(c))
continue continue
if flt and not re.match(flt, name):
slog(DEBUG, ' "{}.{}" has wrong name'.format(mod, name))
continue
r.append(c) r.append(c)
return r return r
def load_classes(path, baseclass, flt=None): # export
r = []
for p in path.split(':'):
mod = importlib.import_module(path)
r.extend(get_derived_classes(mod, baseclass, flt))
return r
def load_class_names(path, baseclass, flt=None, remove_flt=False): # export
classes = load_classes(path, baseclass, flt)
r = []
for c in classes:
name = c.__name__
if flt and remove_flt:
name = re.subst(flt, "", name)
if name not in r:
r.append(name)
else:
pass
#log.slog(log.WARNING, "{} is already in in {}".format(name, r))
return r
# TODO: handling of "name" / "flt" is awkward
def load_object(path, baseclass, name = None, *args, **kwargs): # export
mod = importlib.import_module(path)
classes = get_derived_classes(mod, baseclass, flt=None)
if len(classes) == 0:
raise Exception("no class {} found".format(name))
if len(classes) > 1:
raise Exception("{} classes matching {} found".format(len(classes), name))
return classes[0](*args, **kwargs)
def commit_tmpfile(tmp: str, path: str) -> None: # export def commit_tmpfile(tmp: str, path: str) -> None: # export
caller = log.get_caller_pos() caller = log.get_caller_pos()
if os.path.isfile(path) and filecmp.cmp(tmp, path): if os.path.isfile(path) and filecmp.cmp(tmp, path):