mirror of
ssh://git.janware.com/srv/git/janware/proj/jw-python
synced 2026-01-15 01:52:56 +01:00
misc: Add load_class()
Add load_class() to complement load_object(), returning the class as opposed to instantiated object. Signed-off-by: Jan Lindemann <jan@janware.com>
This commit is contained in:
parent
3a7fb50979
commit
381514ab2c
1 changed files with 15 additions and 10 deletions
|
|
@ -4,7 +4,10 @@ import tempfile
|
|||
import filecmp
|
||||
import inspect
|
||||
import importlib
|
||||
import re
|
||||
|
||||
from typing import Set, Iterable
|
||||
|
||||
from jwutils import log
|
||||
|
||||
_tmpfiles: Set[str] = set()
|
||||
|
|
@ -63,7 +66,7 @@ def get_derived_classes(mod, base, flt=None): # export
|
|||
log.slog(log.DEBUG, " is not derived from", base, "only", inspect.getmro(c))
|
||||
continue
|
||||
if flt and not re.match(flt, name):
|
||||
slog(DEBUG, ' "{}.{}" has wrong name'.format(mod, name))
|
||||
log.slog(log.DEBUG, ' "{}.{}" has wrong name'.format(mod, name))
|
||||
continue
|
||||
r.append(c)
|
||||
return r
|
||||
|
|
@ -75,6 +78,15 @@ def load_classes(path, baseclass, flt=None): # export
|
|||
r.extend(get_derived_classes(mod, baseclass, flt))
|
||||
return r
|
||||
|
||||
def load_class(module_path, baseclass, class_name_filter=None): # export
|
||||
mod = importlib.import_module(module_path)
|
||||
classes = get_derived_classes(mod, baseclass, flt=class_name_filter)
|
||||
if len(classes) == 0:
|
||||
raise Exception(f'no class matching "{class_name_filter}" of type "{baseclass}" found in module "{module_path}"')
|
||||
if len(classes) > 1:
|
||||
raise Exception(f'{len(classes)} classes matching "{class_name_filter}" of type "{baseclass}" found in module "{module_path}"')
|
||||
return classes[0]
|
||||
|
||||
def load_class_names(path, baseclass, flt=None, remove_flt=False): # export
|
||||
classes = load_classes(path, baseclass, flt)
|
||||
r = []
|
||||
|
|
@ -89,15 +101,8 @@ def load_class_names(path, baseclass, flt=None, remove_flt=False): # export
|
|||
#log.slog(log.WARNING, "{} is already in in {}".format(name, r))
|
||||
return r
|
||||
|
||||
# TODO: handling of "name" / "flt" is awkward
|
||||
def load_object(path, baseclass, name = None, *args, **kwargs): # export
|
||||
mod = importlib.import_module(path)
|
||||
classes = get_derived_classes(mod, baseclass, flt=None)
|
||||
if len(classes) == 0:
|
||||
raise Exception("no class {} found".format(name))
|
||||
if len(classes) > 1:
|
||||
raise Exception("{} classes matching {} found".format(len(classes), name))
|
||||
return classes[0](*args, **kwargs)
|
||||
def load_object(module_path, baseclass, class_name_filter=None, *args, **kwargs): # export
|
||||
return load_class(module_path, baseclass, class_name_filter=class_name_filter)(*args, **kwargs)
|
||||
|
||||
def commit_tmpfile(tmp: str, path: str) -> None: # export
|
||||
caller = log.get_caller_pos()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue