jw-python/tools/python/jwutils/ldap.py
Jan Lindemann 988d420e44 ldap.Connection: Support uri config option
Currently the configuration passed to the Connection constructor
needs to contain an ldap_uri entry. Add "uri" as an alias, because
in many contexts the configuration section is already LDAP-specific,
which makes the "ldap_" prefix redundant, so it is usually left out.

Signed-off-by: Jan Lindemann <jan@janware.com>
2026-04-03 18:00:24 +02:00

314 lines
11 KiB
Python

# -*- coding: utf-8 -*-
import ldap, getpass, pathlib, copy
from ldap.schema.models import ObjectClass
from enum import Flag, auto
import networkx as nx
from typing import Any, Self
from collections.abc import Callable
from .Config import Config as BaseConfig
from .log import *
class Config:
    """LDAP connection settings, resolved lazily and overridable per instance.

    Each setting (``ldap_uri``, ``bind_dn``, ``bind_pw``, ``base_dn``) is looked
    up on first access in the wrapped external configuration, if any, and
    otherwise falls back to a built-in default. The resolved value is cached;
    assigning to a property replaces it.
    """

    def __init__(self, external: BaseConfig|None=None):
        """Wrap ``external``; with ``None`` only built-in defaults are used.

        Args:
            external: Configuration queried via ``external.value(key, default=...)``.
        """
        self.__external = external
        # Cached settings; None means "not resolved yet".
        self.__ldap_uri: str|None = None
        self.__bind_dn: str|None = None
        self.__bind_pw: str|None = None
        self.__base_dn: str|None = None

    def __get(self, key: str, default: str|None):
        """Return ``key`` from the external configuration, or ``default``.

        ``default`` is returned directly when there is no (or a falsy)
        external configuration.
        """
        if not self.__external:
            return default
        return self.__external.value(key, default=default)

    def __get_first(self, keys: list[str], default: str|None=None):
        """Return the value of the first key in ``keys`` that is set, else ``default``."""
        for key in keys:
            val = self.__get(key, default=None)
            if val is not None:
                return val
        return default

    @property
    def ldap_uri(self) -> str:
        """URI of the LDAP server.

        Read from ``ldap_uri`` or, as an alias, ``uri``; defaults to
        ``ldap://ldap.janware.com``.
        """
        if self.__ldap_uri is None:
            self.__ldap_uri = self.__get_first(['ldap_uri', 'uri'], default='ldap://ldap.janware.com')
        return self.__ldap_uri

    @ldap_uri.setter
    def ldap_uri(self, rhs):
        self.__ldap_uri = rhs

    @property
    def bind_dn(self) -> str:
        """DN to bind as; defaults to ``uid=<login name>,ou=users,dc=jannet,dc=de``."""
        if self.__bind_dn is None:
            self.__bind_dn = self.__get('bind_dn', default=f'uid={getpass.getuser()},ou=users,dc=jannet,dc=de')
        return self.__bind_dn

    @bind_dn.setter
    def bind_dn(self, rhs):
        self.__bind_dn = rhs

    @property
    def bind_pw(self) -> str:
        """Bind password.

        Read from ``bind_pw`` or, as an alias, ``password``. If neither is set,
        it is read from the file named by ``secret_file`` (default
        ``~/.ldap.secret``), with surrounding whitespace stripped.

        Raises:
            OSError: If the password has to be read from a file that cannot be opened.
        """
        if self.__bind_pw is None:
            pw = self.__get_first(['bind_pw', 'password'])
            if pw is None:
                secret_file = self.__get('secret_file', f'{pathlib.Path.home()}/.ldap.secret')
                with open(pathlib.Path(secret_file).expanduser(), 'r', encoding='utf-8') as f:
                    # Drop the trailing newline that editors usually append.
                    pw = f.read().strip()
            self.__bind_pw = pw
        return self.__bind_pw

    @bind_pw.setter
    def bind_pw(self, rhs):
        self.__bind_pw = rhs

    @property
    def base_dn(self) -> str:
        """Search base DN; defaults to ``dc=jannet,dc=de``."""
        if self.__base_dn is None:
            self.__base_dn = self.__get('base_dn', default='dc=jannet,dc=de')
        return self.__base_dn

    @base_dn.setter
    def base_dn(self, rhs):
        self.__base_dn = rhs
class Connection: # export
    """Bound LDAP connection with search helpers and schema introspection.

    The constructor opens the connection, upgrades it with STARTTLS and binds
    with the credentials from the given configuration. Schema data (object
    classes and their inheritance graph) is fetched on first use and cached
    for the lifetime of the object.
    """

    class AttrType(Flag):
        """Which attribute sets of an object class to collect; combinable with ``|``."""
        Must = auto()
        May = auto()

    def __init__(self, conf: Config|BaseConfig|None=None, backtrace=False):
        """Connect, start TLS and bind.

        Args:
            conf: A :class:`Config`, an external configuration (wrapped in a
                :class:`Config`), or ``None`` for built-in defaults.
            backtrace: If true, exceptions raised by :meth:`walk` callbacks are
                logged as errors and re-raised instead of logged as warnings
                and skipped.

        Raises:
            Whatever ``ldap.initialize``, ``start_tls_s`` or ``bind_s`` raise;
            the failure is logged first.
        """
        c = conf if isinstance(conf, Config) else Config(conf)
        # Config.ldap_uri already honours the "uri" alias, so no fallback is needed.
        uri = c.ldap_uri
        conn = None
        try:
            conn = ldap.initialize(uri)
            conn.start_tls_s()
        except Exception as e:
            slog(ERR, f'Failed to initialize LDAP connection to "{uri}" ({str(e)})')
            self.__unbind_quietly(conn)
            raise
        try:
            conn.bind_s(c.bind_dn, c.bind_pw)
        except Exception as e:
            slog(ERR, f'Failed to bind to "{uri}" with dn "{c.bind_dn}" ({str(e)})')
            self.__unbind_quietly(conn)
            raise
        self.__ldap = conn
        self.__backtrace = backtrace
        self.__object_classes_by_oid: dict[str, ObjectClass]|None = None
        self.__object_class_tree: nx.Graph|None = None
        self.__object_classes_by_name: dict[str, ObjectClass]|None = None

    @staticmethod
    def __unbind_quietly(conn) -> None:
        """Best-effort release of a half-set-up handle; errors are ignored so the
        caller's original exception is the one that propagates."""
        if conn is None:
            return
        try:
            conn.unbind_s()
        except Exception:
            pass

    @property
    def ldap(self):
        """The underlying python-ldap connection object."""
        return self.__ldap

    def add(self, attrs: dict[str, bytes], dn: str|None=None):
        """Add a new entry.

        Args:
            attrs: Attribute name -> value(s) of the new entry. A ``'dn'`` key
                is never sent as an attribute.
            dn: DN of the new entry; if ``None`` it is taken from ``attrs['dn']``.

        Raises:
            Exception: If ``dn`` is ``None`` and ``attrs`` has no ``'dn'`` key.
            Whatever ``add_s`` raises (logged, then re-raised).
        """
        # "import ldap" does not guarantee that the modlist submodule is loaded.
        import ldap.modlist
        if dn is None:
            if 'dn' not in attrs:
                raise Exception('No DN to add an LDAP entry to')
            dn = attrs['dn']
            # Entry dicts often carry values as bytes or [bytes]; the DN must be a str.
            if isinstance(dn, (list, tuple)):
                dn = dn[0]
            if isinstance(dn, bytes):
                dn = dn.decode('utf-8')
        # "dn" names the entry, it is not an attribute type.
        attrs = {key: val for key, val in attrs.items() if key != 'dn'}
        try:
            slog(INFO, f'LDAP: Add [{dn}] -> {attrs}')
            self.__ldap.add_s(dn, ldap.modlist.addModlist(attrs))
        except Exception as e:
            slog(ERR, f'{dn}: Failed to add entry {attrs} ({e})')
            raise

    def delete(self, dn: str, recursive=False, force_existence: bool=False):
        """Delete ``dn``, optionally including its whole subtree.

        Args:
            dn: Entry to delete.
            recursive: Delete all descendants first (depth-first).
            force_existence: Re-raise ``ldap.NO_SUCH_OBJECT`` instead of
                silently ignoring a missing entry.
        """
        def _walk_cb_delete(conn: Connection, entry, context):
            # Children must be gone before their parent can be deleted.
            self.walk(_walk_cb_delete, base=entry[0], scope=ldap.SCOPE_ONELEVEL, context=context)
            self.__ldap.delete_s(entry[0])
        try:
            if recursive:
                self.walk(_walk_cb_delete, dn, scope=ldap.SCOPE_ONELEVEL)
            self.__ldap.delete_s(dn)
        except ldap.NO_SUCH_OBJECT:
            if force_existence:
                raise
        except Exception as e:
            slog(ERR, f'Failed to delete {dn} ({e})')
            raise

    def walk(self,
            callback: Callable[[Self, Any, Any], None],
            base: str,
            scope,
            context=None,
            filterstr=None,
            attrlist=None,
            attrsonly=0,
            serverctrls=None,
            clientctrls=None,
            timeout=-1,
            sizelimit=0):
        """Search and call ``callback(self, entry, context)`` for every result entry.

        Search arguments are passed to ``search_ext``. Exceptions raised by
        ``callback`` are logged and skipped, or logged and re-raised when the
        connection was created with ``backtrace=True``.
        """
        msgid = self.__ldap.search_ext(
            base,
            scope,
            filterstr=filterstr,
            attrlist=attrlist,
            attrsonly=attrsonly,
            serverctrls=serverctrls,
            clientctrls=clientctrls,
            timeout=timeout,
            sizelimit=sizelimit,
        )
        while True:
            # all=0: fetch results one message at a time.
            result_type, result_data = self.__ldap.result(msgid, 0)
            if not result_data or result_type == ldap.RES_SEARCH_RESULT:
                break
            if result_type != ldap.RES_SEARCH_ENTRY:
                continue
            for entry in result_data:
                try:
                    callback(self, entry, context)
                except Exception as e:
                    msg = f'Exception {e}'
                    if self.__backtrace:
                        slog(ERR, msg)
                        raise
                    slog(WARNING, msg)

    def find(self,
            base: str,
            scope,
            filterstr=None,
            attrlist=None,
            attrsonly=0,
            serverctrls=None,
            clientctrls=None,
            timeout=-1,
            sizelimit=0,
            assert_unique=False,
            assert_not_empty=False,
            ):
        """Search and return the list of result entries.

        Args:
            assert_unique: Raise if more than one entry is found.
            assert_not_empty: Raise if no entry is found.
            Other arguments: as for :meth:`walk`.
        """
        result: list[Any] = []

        def _collect(conn: Connection, entry: Any, context: Any):
            result.append(entry)

        def _describe():
            return f'{base} -> "{filterstr}"'

        try:
            self.walk(
                _collect,
                base,
                scope=scope,
                filterstr=filterstr,
                attrlist=attrlist,
                attrsonly=attrsonly,
                serverctrls=serverctrls,
                clientctrls=clientctrls,
                timeout=timeout,
                sizelimit=sizelimit,
            )
        except Exception as e:
            slog(ERR, f'Failed search {_describe()} ({e})')
            raise
        if assert_not_empty and not result:
            raise Exception(f'Empty result for search {_describe()}')
        if assert_unique and len(result) > 1:
            raise Exception(f'Found {len(result)} results for search {_describe()}')
        return result

    @property
    def object_classes(self) -> dict[str, ObjectClass]:
        """Object classes of the server schema keyed by OID (fetched once, then cached)."""
        if self.__object_classes_by_oid is None:
            # The root DSE names the subschema subentry.
            res = self.find(base='', scope=ldap.SCOPE_BASE, filterstr='(objectClass=*)', attrlist=['subschemaSubentry'])
            dn = res[0][1]['subschemaSubentry'][0].decode('utf-8') # Usually yields cn=Subschema
            # '*' and '+' request user and operational attributes.
            res = self.find(base=dn, scope=ldap.SCOPE_BASE, filterstr='(objectClass=*)', attrlist=['*', '+'])
            subschema = ldap.schema.SubSchema(ldap.cidict.cidict(res[0][1]))
            self.__object_classes_by_oid = {
                oid: subschema.get_obj(ObjectClass, oid)
                for oid in subschema.listall(ObjectClass)
            }
        return self.__object_classes_by_oid

    @property
    def object_class_by_name(self) -> dict[str, ObjectClass]:
        """Object classes keyed by OID, by every name and by every lower-cased name (cached)."""
        if self.__object_classes_by_name is None:
            by_name: dict[str, ObjectClass] = {}
            for oid, oc in self.object_classes.items():
                by_name[oid] = oc
                for name in oc.names:
                    by_name[name] = oc
                    by_name[name.lower()] = oc
            self.__object_classes_by_name = by_name
        return self.__object_classes_by_name

    def __oc_recurse_to_top(self, cur: str|ObjectClass, cb, context):
        """Call ``cb(oc, context)`` for ``cur`` and all its superclasses, superclasses first.

        A class reachable through several inheritance paths is visited once per path.
        """
        cur_oc = cur if isinstance(cur, ObjectClass) else self.object_class_by_name[cur.lower()]
        for sup in cur_oc.sup:
            self.__oc_recurse_to_top(sup, cb, context)
        cb(cur_oc, context)

    def object_class_path(self, leaf: str|ObjectClass):
        """Return an iterator over ``leaf`` and its superclasses, from ``leaf`` up to the root."""
        path: list[ObjectClass] = []
        self.__oc_recurse_to_top(leaf, lambda oc, context: path.append(oc), None)
        return reversed(path)

    @property
    def object_class_tree(self) -> nx.Graph:
        """Directed inheritance graph of all object classes (cached).

        Nodes are OIDs carrying ``oid``, ``name``, ``kind`` and the comma-joined,
        sorted ``must``/``may`` attributes including inherited ones. Edges point
        from a class to each of its direct superclasses.
        """
        if self.__object_class_tree is None:
            def collect(root, attr):
                vals_all = set()
                def cb(oc, attr):
                    vals = getattr(oc, attr)
                    if vals is None:
                        return
                    vals_all.update(vals)
                self.__oc_recurse_to_top(root, cb, attr)
                return vals_all
            kind = {
                0: 'STRUCTURAL',
                1: 'ABSTRACT',
                2: 'AUXILIARY'
            }
            tree = nx.DiGraph()
            for oid, oc in self.object_classes.items():
                tree.add_node(
                    oid,
                    oid=oid,
                    # Fall back to the OID for classes declared without a NAME.
                    name=oc.names[0] if oc.names else oid,
                    kind=kind[oc.kind],
                    must=', '.join(sorted(collect(oc, 'must'))),
                    may=', '.join(sorted(collect(oc, 'may')))
                )
                for base_class in oc.sup:
                    try:
                        tree.add_edge(oid, self.object_class_by_name[base_class.lower()].oid)
                    except Exception as e:
                        slog(WARNING, f'Failed to add edge {oid}:{oc.names} -> {base_class} ({e})')
            self.__object_class_tree = tree
        return self.__object_class_tree

    def object_class_attrs(self, oc: str|ObjectClass, required: AttrType = AttrType.Must, origins: bool=False) -> dict[ObjectClass, set[str]]|set[str]:
        """Collect the attribute types of ``oc``, including inherited ones.

        Args:
            oc: Object class, or its name/OID (case-insensitive).
            required: Which attribute sets to include (``Must``, ``May`` or both).
            origins: If true, return a dict mapping each class on the path to
                the root that contributes attributes to the attributes it
                declares; otherwise return their union.
        """
        all_attrs: set[str] = set()
        attrs_by_origin: dict[ObjectClass, set[str]] = {}
        for cls in self.object_class_path(oc):
            cur: set[str] = set()
            if required & self.AttrType.Must:
                cur |= set(cls.must or ())
            if required & self.AttrType.May:
                cur |= set(cls.may or ())
            if cur:
                all_attrs |= cur
                attrs_by_origin[cls] = cur
        return attrs_by_origin if origins else all_attrs

    def is_derived_class(self, name, base_candidate):
        """Return True if object class ``name`` has a path to ``base_candidate`` in
        :attr:`object_class_tree`, i.e. inherits from it directly or indirectly.

        Names are matched case-insensitively; unknown names raise ``KeyError``.
        """
        by_name = self.object_class_by_name
        return nx.has_path(self.object_class_tree, by_name[name.lower()].oid, by_name[base_candidate.lower()].oid)
def default_config() -> Config: # export
    """Return a Config without an external source, i.e. built-in defaults only."""
    cfg = Config(external=None)
    return cfg
def bind(conf: Config|BaseConfig|None=None) -> Connection:
    """Create and return a new Connection for ``conf`` (``None`` means built-in defaults)."""
    connection = Connection(conf=conf)
    return connection