From 1e43cdc7151890376a303fd5cb4ab0286931b943 Mon Sep 17 00:00:00 2001 From: Jan Lindemann Date: Sat, 26 Jul 2025 13:51:42 +0200 Subject: [PATCH] jwutils.ldap: Add module Signed-off-by: Jan Lindemann --- tools/python/jwutils/auth/ldap/Auth.py | 19 +- tools/python/jwutils/ldap.py | 298 +++++++++++++++++++++++++ 2 files changed, 300 insertions(+), 17 deletions(-) create mode 100644 tools/python/jwutils/ldap.py diff --git a/tools/python/jwutils/auth/ldap/Auth.py b/tools/python/jwutils/auth/ldap/Auth.py index e4c398a..27674f3 100644 --- a/tools/python/jwutils/auth/ldap/Auth.py +++ b/tools/python/jwutils/auth/ldap/Auth.py @@ -5,7 +5,7 @@ from typing import Optional, Union import ldap from ...log import * -from ... import Config +from ...ldap import bind from .. import Access from .. import Auth as AuthBase from .. import Group as GroupBase @@ -62,22 +62,7 @@ class Auth(AuthBase): # export self.__dummy = self.load('dummy', conf) def __bind(self): - ldap_uri = self.conf['ldap_uri'] - bind_dn = self.conf['bind_dn'] - bind_pw = self.conf.get('password') - if bind_pw is None: - with open(ldap_secret_file, 'r') as file: - bind_pw = file.read() - file.closed - bind_pw = bind_pw.strip() - ret = ldap.initialize(ldap_uri) - ret.start_tls_s() - try: - rr = ret.bind_s(bind_dn, bind_pw) # method) - except Exception as e: - #pw = f' (pw={bind_pw})' - raise Exception(f'Failed to bind to {ldap_uri} with dn {bind_dn} ({e})') - return ret + return bind(self.conf) @property def __users(self) -> User: diff --git a/tools/python/jwutils/ldap.py b/tools/python/jwutils/ldap.py new file mode 100644 index 0000000..d8d9946 --- /dev/null +++ b/tools/python/jwutils/ldap.py @@ -0,0 +1,298 @@ +# -*- coding: utf-8 -*- + +import ldap, getpass, pathlib, copy +from ldap.schema.models import ObjectClass +from enum import Flag, auto +import networkx as nx +from typing import Any, Self +from collections.abc import Callable + +from .Config import Config as BaseConfig +from .log import * + +class Config: + def __init__(self, external: BaseConfig|None=None): + self.__external = external + for attr in ['ldap_uri', 'bind_dn', 'bind_pw', 'base_dn']: + setattr(self, '_Config__' + attr, None) + + def __get(self, key: str, default: str): + if not self.__external: + return default + return self.__external.value(key, default=default) + + @property + def ldap_uri(self): + if self.__ldap_uri is None: + self.__ldap_uri = self.__get('ldap_uri', default='ldap://ldap.janware.com') + return self.__ldap_uri + @ldap_uri.setter + def ldap_uri(self, rhs): + self.__ldap_uri = rhs + + @property + def bind_dn(self): + if self.__bind_dn is None: + self.__bind_dn = self.__get('bind_dn', default=f'uid={getpass.getuser()},ou=users,dc=jannet,dc=de') + return self.__bind_dn + @bind_dn.setter + def bind_dn(self, rhs): + self.__bind_dn = rhs + + @property + def bind_pw(self): + if self.__bind_pw is None: + ret = self.__get('bind_pw', default=None) + if ret is None: + ldap_secret_file = self.__get('secret_file', f'{pathlib.Path.home()}/.ldap.secret') + with open(ldap_secret_file, 'r') as file: + ret = file.read() + file.closed + ret = ret.strip() + self.__bind_pw = ret + return self.__bind_pw + @bind_pw.setter + def bind_pw(self, rhs): + self.__bind_pw = rhs + + @property + def base_dn(self): + if self.__base_dn is None: + self.__base_dn = self.__get('base_dn', default=f'dc=jannet,dc=de') + return self.__base_dn + @base_dn.setter + def base_dn(self, rhs): + self.__base_dn = rhs + +class Connection: # export + + class AttrType(Flag): + Must = auto() + May = auto() + + def __init__(self, conf: Config|BaseConfig|None=None, backtrace=False): + c = conf if isinstance(conf, Config) else Config(conf) + ret = ldap.initialize(c.ldap_uri) + ret.start_tls_s() + try: + rr = ret.bind_s(c.bind_dn, c.bind_pw) # method) + except Exception as e: + raise Exception(f'Failed to bind to {c.ldap_uri} with dn {c.bind_dn} ({e})') + self.__ldap = ret + self.__backtrace = backtrace + self.__object_classes: dict[str, ObjectClass]|None = None + self.__object_class_tree: nx.Graph|None = None + self.__object_classes_by_name: dict[str, ObjectClass]|None = None + + @property + def ldap(self): + return self.__ldap + + def add(self, attrs: dict[str, bytes], dn: str|None=None): + if dn is None: + dn = attrs.get('dn') + if dn is None: + raise Exception('No DN to add an LDAP entry to') + else: + attrs = copy.deepcopy(attrs) + del attrs['dn'] + try: + slog(INFO, f'LDAP: Add [{dn}] -> {attrs}') + self.__ldap.add_s(dn, ldap.modlist.addModlist(attrs)) + except Exception as e: + slog(ERR, f'{dn}: Failed to add entry {attrs} ({e})') + raise + + def delete(self, dn: str, recursive=False, force_existence: bool=False): + + def __walk_cb_delete(conn: Connection, entry, context): + self.walk(__walk_cb_delete, base=entry[0], scope=ldap.SCOPE_ONELEVEL, context=context) + self.__ldap.delete_s(entry[0]) + + try: + if recursive: + self.walk(__walk_cb_delete, dn, scope=ldap.SCOPE_ONELEVEL) + self.__ldap.delete_s(dn) + else: + self.__ldap.delete_s(dn) + except ldap.NO_SUCH_OBJECT as e: + if force_existence: + raise + except Exception as e: + slog(ERR, f'Failed to delete {dn} ({e})') + raise + + def walk(self, + callback: Callable[[Self, Any, Any], None], + base: str, + scope, + context=None, + filterstr=None, + attrlist=None, + attrsonly=0, + serverctrls=None, + clientctrls=None, + timeout=-1, + sizelimit=0): + + # TODO: Support ignored arguments + search_return = self.__ldap.search(base=base, + scope=scope, + filterstr=filterstr, + attrlist=attrlist, + attrsonly=attrsonly) + while True: + result_type, result_data = self.__ldap.result(search_return, 0) + if (result_data == []): + break + if result_type != ldap.RES_SEARCH_ENTRY: + continue + for entry in result_data: + try: + callback(self, entry, context) + except Exception as e: + msg = f'Exception {e}' + if self.__backtrace: + slog(ERR, msg) + raise + slog(WARNING, msg) + continue + + def find(self, + base: str, + scope, + filterstr=None, + attrlist=None, + attrsonly=0, + serverctrls=None, + clientctrls=None, + timeout=-1, + sizelimit=0, + assert_unique=False, + assert_not_empty=False, + ): + + def __walk_cb_find(conn: Connection, entry, context): + result.append(entry) + + def __search(): + return f'{base} -> "{filterstr}"' + + try: + result = [] + self.walk(__walk_cb_find, base, scope=scope, filterstr=filterstr, attrlist=attrlist) + except Exception as e: + slog(ERR, f'Failed search {__search()} ({e})') + raise + if assert_not_empty and not result: + raise Exception(f'Empty result for search {__search()}') + if assert_unique and len(result) > 1: + raise Exception(f'Found {len(result)} results for search {__search()}') + return result + + @property + def object_classes(self) -> list[ObjectClass]: + #def object_classes(self): + if self.__object_classes is None: + res = self.find(base='', scope=ldap.SCOPE_BASE, filterstr='(objectClass=*)', attrlist=['subschemaSubentry']) + dn = res[0][1]['subschemaSubentry'][0].decode('utf-8') # Usually yields cn=Subschema + res = self.find(base=dn, scope=ldap.SCOPE_BASE, filterstr='(objectClass=*)', attrlist=['*', '+']) + subschema_entry = res[0] + subschema_subentry = ldap.cidict.cidict(subschema_entry[1]) + subschema = ldap.schema.SubSchema(subschema_subentry) + object_class_oids = subschema.listall(ObjectClass) + self.__object_classes = { + oid: subschema.get_obj(ObjectClass, oid) for oid in object_class_oids + } + return self.__object_classes + + @property + def object_class_by_name(self): + if self.__object_classes_by_name is None: + ret: dict[str, ObjectClass] = {} + self.__object_classes_by_name = ret + for oid, oc in self.object_classes.items(): + ret[oid] = oc + for name in oc.names: + ret[name] = oc + ret[name.lower()] = oc + return self.__object_classes_by_name + + def __oc_recurse_to_top(self, cur: str|ObjectClass, cb, context): + cur_oc = cur if isinstance(cur, ObjectClass) else self.object_class_by_name[cur.lower()] + for s in cur_oc.sup: + self.__oc_recurse_to_top(s, cb, context) + cb(cur_oc, context) + + def object_class_path(self, leaf: str|ObjectClass): + def cb(oc, context): + ret.append(oc) + ret = [] + self.__oc_recurse_to_top(leaf, cb, None) + return reversed(ret) + + @property + def object_class_tree(self) -> nx.Graph: + + if self.__object_class_tree is None: + + def collect(root, attr): + ret = set() + def cb(oc, attr): + vals = getattr(oc, attr) + if vals is None: + return + for val in vals: + ret.add(val) + self.__oc_recurse_to_top(root, cb, attr) + return ret + + kind = { + 0: 'STRUCTURAL', + 1: 'ABSTRACT', + 2: 'AUXILIARY' + } + ret = nx.DiGraph() + for oid, oc in self.object_classes.items(): + ret.add_node( + oid, + oid=oid, + name=oc.names[0], + kind=kind[oc.kind], + must=', '.join(collect(oc, 'must')), + may=', '.join(collect(oc, 'may')) + ) + for base_class in oc.sup: + try: + ret.add_edge(oid, self.object_class_by_name[base_class.lower()].oid) + except Exception as e: + slog(WARNING, f'Failed to add edge {oid}:{oc.names} -> {base_class} ({e})') + self.__object_class_tree = ret + return self.__object_class_tree + + def object_class_attrs(self, oc: str|ObjectClass, required: AttrType = AttrType.Must, origins: bool=False) -> dict[str, set[str]]|set[str]: + all_attrs: set[str] = set() + attrs_by_origin: dict[str, set[str]] = {} + for oc in self.object_class_path(oc): + cur = set() + if required & self.AttrType.Must: + cur |= set(oc.must) + if required & self.AttrType.May: + cur |= set(oc.may) + if cur: + all_attrs |= cur + attrs_by_origin[oc] = cur + return attrs_by_origin if origins else all_attrs + + def is_derived_class(self, name, base_candidate): + #oid = self.object_class_by_name[name].oid + #base_oid = self.object_class_by_name[base_candidate].oid + #if base_oid in [oc.oid for oc in self.object_class_path(name)]: + # return True + return nx.has_path(self.object_class_tree, self.object_class_by_name[name.lower()].oid, self.object_class_by_name[base_candidate.lower()].oid) + +def default_config() -> Config: # export + return Config() + +def bind(conf: Config|BaseConfig|None=None) -> Connection: + return Connection(conf)