# -*- coding: utf-8 -*-

import contextlib
import copy
import getpass
import pathlib
from collections.abc import Callable
from enum import Flag, auto
from typing import Any, Self

import ldap
import ldap.modlist  # not loaded implicitly by "import ldap"; needed by Connection.add()
import networkx as nx
from ldap.schema.models import ObjectClass

from .Config import Config as BaseConfig
from .log import *


class Config:
    """LDAP connection settings, resolved lazily from an optional external config.

    Each setting is resolved on first read: a non-None value assigned through the
    property setter is returned as is; otherwise the value is looked up in the
    external configuration (if one was given) with a built-in default, and the
    result is cached.
    """

    def __init__(self, external: BaseConfig|None=None):
        self.__external = external
        # Cached setting values; None means "not resolved yet".
        self.__ldap_uri: str|None = None
        self.__bind_dn: str|None = None
        self.__bind_pw: str|None = None
        self.__base_dn: str|None = None

    def __get(self, key: str, default: str|None) -> str|None:
        """Return `key` from the external config, or `default` if there is no external config."""
        if not self.__external:
            return default
        return self.__external.value(key, default=default)

    @property
    def ldap_uri(self):
        """URI of the LDAP server (default: 'ldap://ldap.janware.com')."""
        if self.__ldap_uri is None:
            self.__ldap_uri = self.__get('ldap_uri', default='ldap://ldap.janware.com')
        return self.__ldap_uri

    @ldap_uri.setter
    def ldap_uri(self, rhs):
        self.__ldap_uri = rhs

    @property
    def bind_dn(self):
        """DN to bind as (default: the current user below ou=users,dc=jannet,dc=de)."""
        if self.__bind_dn is None:
            self.__bind_dn = self.__get('bind_dn', default=f'uid={getpass.getuser()},ou=users,dc=jannet,dc=de')
        return self.__bind_dn

    @bind_dn.setter
    def bind_dn(self, rhs):
        self.__bind_dn = rhs

    @property
    def bind_pw(self):
        """Bind password, with surrounding whitespace stripped.

        Taken from the external config keys 'bind_pw' or 'password' (in that
        order); if neither is set, it is read from the file named by
        'secret_file' (default: ~/.ldap.secret).

        Raises:
            OSError: If the password has to be read from a file that cannot be read.
        """
        if self.__bind_pw is None:
            ret = None
            for key in ['bind_pw', 'password']:
                ret = self.__get(key, default=None)
                if ret is not None:
                    break
            if ret is None:
                ldap_secret_file = self.__get('secret_file', f'{pathlib.Path.home()}/.ldap.secret')
                ret = pathlib.Path(ldap_secret_file).expanduser().read_text(encoding='utf-8')
            self.__bind_pw = ret.strip()
        return self.__bind_pw

    @bind_pw.setter
    def bind_pw(self, rhs):
        self.__bind_pw = rhs

    @property
    def base_dn(self):
        """Base DN of the directory (default: 'dc=jannet,dc=de')."""
        if self.__base_dn is None:
            self.__base_dn = self.__get('base_dn', default='dc=jannet,dc=de')
        return self.__base_dn

    @base_dn.setter
    def base_dn(self, rhs):
        self.__base_dn = rhs


class Connection: # export
    """A bound python-ldap connection plus helpers for searching and schema inspection."""

    class AttrType(Flag):
        """Selects which attribute kinds object_class_attrs() collects."""
        Must = auto()
        May = auto()

    def __init__(self, conf: Config|BaseConfig|None=None, backtrace=False):
        """Connect to the configured server, start TLS and bind.

        Args:
            conf: A `Config`, an external configuration to wrap in a `Config`,
                or None to use the built-in defaults.
            backtrace: If True, exceptions raised by walk() callbacks are logged
                as errors and re-raised instead of logged as warnings and skipped.

        Raises:
            Exception: If binding fails (the original error is chained as cause).
                Errors from start_tls_s() propagate unchanged. In both cases the
                LDAP handle is unbound before raising.
        """
        c = conf if isinstance(conf, Config) else Config(conf)
        conn = ldap.initialize(c.ldap_uri)
        try:
            conn.start_tls_s()
            try:
                conn.bind_s(c.bind_dn, c.bind_pw)
            except Exception as e:
                raise Exception(f'Failed to bind to {c.ldap_uri} with dn {c.bind_dn} ({e})') from e
        except Exception:
            # Release the half-initialized handle now instead of whenever it gets collected.
            with contextlib.suppress(Exception):
                conn.unbind_s()
            raise
        self.__ldap = conn
        self.__backtrace = backtrace
        # Lazily computed schema caches, see the properties below.
        self.__object_classes_by_oid: dict[str, ObjectClass]|None = None
        self.__object_class_tree: nx.DiGraph|None = None
        self.__object_classes_by_name: dict[str, ObjectClass]|None = None

    @property
    def ldap(self):
        """The underlying python-ldap connection object."""
        return self.__ldap

    def add(self, attrs: dict[str, bytes], dn: str|None=None):
        """Add an LDAP entry.

        Args:
            attrs: Entry attributes in the form accepted by ldap.modlist.addModlist().
                If `dn` is None, the DN is taken from attrs['dn'] (bytes are
                decoded as UTF-8) and that key is left out of the new entry;
                the caller's dict is not modified.
            dn: DN of the entry to create.

        Raises:
            Exception: If no DN is given, neither as argument nor in `attrs`.
            ldap.LDAPError: If the server rejects the add (logged, then re-raised).
        """
        if dn is None:
            if 'dn' not in attrs:
                raise Exception('No DN to add an LDAP entry to')
            attrs = copy.deepcopy(attrs)
            dn = attrs.pop('dn')
            if isinstance(dn, bytes):
                dn = dn.decode('utf-8')
        try:
            slog(INFO, f'LDAP: Add [{dn}] -> {attrs}')
            self.__ldap.add_s(dn, ldap.modlist.addModlist(attrs))
        except Exception as e:
            slog(ERR, f'{dn}: Failed to add entry {attrs} ({e})')
            raise

    def delete(self, dn: str, recursive=False, force_existence: bool=False):
        """Delete the entry `dn`, optionally together with everything below it.

        Args:
            dn: DN of the entry to delete.
            recursive: If True, delete all descendants first (depth-first).
            force_existence: If True, re-raise ldap.NO_SUCH_OBJECT when `dn` does
                not exist; otherwise a missing entry is silently ignored.

        Raises:
            ldap.NO_SUCH_OBJECT: Only if `force_existence` is True.
            Exception: Any other failure, logged before being re-raised.
        """
        def walk_cb_delete(conn: Connection, entry, context):
            # Delete the entry's children before the entry itself: LDAP only
            # deletes leaf entries.
            self.walk(walk_cb_delete, base=entry[0], scope=ldap.SCOPE_ONELEVEL, context=context)
            self.__ldap.delete_s(entry[0])

        try:
            if recursive:
                self.walk(walk_cb_delete, dn, scope=ldap.SCOPE_ONELEVEL)
            self.__ldap.delete_s(dn)
        except ldap.NO_SUCH_OBJECT:
            if force_existence:
                raise
        except Exception as e:
            slog(ERR, f'Failed to delete {dn} ({e})')
            raise

    def walk(self, callback: Callable[[Self, Any, Any], None], base: str, scope, context=None,
            filterstr=None, attrlist=None, attrsonly=0, serverctrls=None, clientctrls=None,
            timeout=-1, sizelimit=0):
        """Search and call `callback(self, entry, context)` for each entry as it arrives.

        The search arguments are passed to LDAPObject.search_ext(); `entry` is a
        (dn, attributes) tuple as returned by python-ldap. Search references and
        other non-entry results are skipped.

        If the callback raises, the exception is logged as a warning and the
        entry skipped, unless this connection was created with backtrace=True,
        in which case it is logged as an error and re-raised.
        """
        msgid = self.__ldap.search_ext(base, scope, filterstr=filterstr, attrlist=attrlist,
            attrsonly=attrsonly, serverctrls=serverctrls, clientctrls=clientctrls,
            timeout=timeout, sizelimit=sizelimit)
        while True:
            # Fetch one result at a time; the final search result carries no data.
            result_type, result_data = self.__ldap.result(msgid, 0)
            if not result_data:
                break
            if result_type != ldap.RES_SEARCH_ENTRY:
                continue
            for entry in result_data:
                try:
                    callback(self, entry, context)
                except Exception as e:
                    msg = f'Exception {e}'
                    if self.__backtrace:
                        slog(ERR, msg)
                        raise
                    slog(WARNING, msg)

    def find(self, base: str, scope, filterstr=None, attrlist=None, attrsonly=0,
            serverctrls=None, clientctrls=None, timeout=-1, sizelimit=0,
            assert_unique=False, assert_not_empty=False,
        ):
        """Search and return all found entries as a list of (dn, attributes) tuples.

        The search arguments are forwarded to walk().

        Args:
            assert_unique: Raise if more than one entry is found.
            assert_not_empty: Raise if no entry is found.

        Raises:
            Exception: On a failed search (logged, then re-raised) or a violated assertion.
        """
        result: list[Any] = []

        def describe_search():
            return f'{base} -> "{filterstr}"'

        try:
            self.walk(lambda conn, entry, context: result.append(entry), base, scope=scope,
                filterstr=filterstr, attrlist=attrlist, attrsonly=attrsonly,
                serverctrls=serverctrls, clientctrls=clientctrls, timeout=timeout,
                sizelimit=sizelimit)
        except Exception as e:
            slog(ERR, f'Failed search {describe_search()} ({e})')
            raise
        if assert_not_empty and not result:
            raise Exception(f'Empty result for search {describe_search()}')
        if assert_unique and len(result) > 1:
            raise Exception(f'Found {len(result)} results for search {describe_search()}')
        return result

    @property
    def object_classes(self) -> dict[str, ObjectClass]:
        """All object classes of the server's subschema, keyed by OID (cached)."""
        if self.__object_classes_by_oid is None:
            # The root DSE names the subschema entry (usually cn=Subschema).
            res = self.find(base='', scope=ldap.SCOPE_BASE, filterstr='(objectClass=*)',
                attrlist=['subschemaSubentry'], assert_not_empty=True)
            dn = res[0][1]['subschemaSubentry'][0].decode('utf-8')
            res = self.find(base=dn, scope=ldap.SCOPE_BASE, filterstr='(objectClass=*)',
                attrlist=['*', '+'], assert_not_empty=True)
            subschema = ldap.schema.SubSchema(ldap.cidict.cidict(res[0][1]))
            self.__object_classes_by_oid = {
                oid: subschema.get_obj(ObjectClass, oid)
                for oid in subschema.listall(ObjectClass)
            }
        return self.__object_classes_by_oid

    @property
    def object_class_by_name(self) -> dict[str, ObjectClass]:
        """Object classes keyed by OID, by every name and by every lower-cased name (cached)."""
        if self.__object_classes_by_name is None:
            ret: dict[str, ObjectClass] = {}
            for oid, oc in self.object_classes.items():
                ret[oid] = oc
                for name in oc.names:
                    ret[name] = oc
                    ret[name.lower()] = oc
            # Cache only once complete, so a failed schema read isn't cached as empty.
            self.__object_classes_by_name = ret
        return self.__object_classes_by_name

    def __oc_recurse_to_top(self, cur: str|ObjectClass, cb, context):
        """Call cb(oc, context) for `cur` and all its superclasses, superclasses first.

        `cur` may be an ObjectClass or a name/OID (looked up case-insensitively).
        A superclass reachable via several paths is visited once per path.
        """
        cur_oc = cur if isinstance(cur, ObjectClass) else self.object_class_by_name[cur.lower()]
        for sup in cur_oc.sup:
            self.__oc_recurse_to_top(sup, cb, context)
        cb(cur_oc, context)

    def object_class_path(self, leaf: str|ObjectClass):
        """Return an iterator over `leaf` and its superclasses, with `leaf` first."""
        path: list[ObjectClass] = []
        self.__oc_recurse_to_top(leaf, lambda oc, context: path.append(oc), None)
        return reversed(path)

    @property
    def object_class_tree(self) -> nx.DiGraph:
        """Directed graph of the schema's object classes (cached).

        Nodes are OIDs with the attributes `oid`, `name` (first name, or the OID
        if the class has none), `kind` ('STRUCTURAL', 'ABSTRACT' or 'AUXILIARY'),
        and `must` / `may`: sorted, comma-separated attribute names including
        inherited ones. Edges point from a class to each direct superclass;
        edges that cannot be resolved are logged as warnings and skipped.
        """
        if self.__object_class_tree is None:
            kinds = {0: 'STRUCTURAL', 1: 'ABSTRACT', 2: 'AUXILIARY'}

            def inherited(root: ObjectClass, attr: str) -> str:
                # Union of `attr` ('must' or 'may') over root and all its superclasses.
                names: set[str] = set()

                def cb(oc, attr):
                    names.update(getattr(oc, attr) or ())

                self.__oc_recurse_to_top(root, cb, attr)
                return ', '.join(sorted(names))

            graph = nx.DiGraph()
            for oid, oc in self.object_classes.items():
                graph.add_node(
                    oid,
                    oid=oid,
                    name=oc.names[0] if oc.names else oid,
                    kind=kinds[oc.kind],
                    must=inherited(oc, 'must'),
                    may=inherited(oc, 'may'),
                )
                for base_class in oc.sup:
                    try:
                        graph.add_edge(oid, self.object_class_by_name[base_class.lower()].oid)
                    except Exception as e:
                        slog(WARNING, f'Failed to add edge {oid}:{oc.names} -> {base_class} ({e})')
            self.__object_class_tree = graph
        return self.__object_class_tree

    def object_class_attrs(self, oc: str|ObjectClass, required: AttrType = AttrType.Must,
            origins: bool=False) -> dict[ObjectClass, set[str]]|set[str]:
        """Collect the attribute names of `oc` including inherited ones.

        Args:
            oc: Object class or its name/OID.
            required: Which attribute kinds to include (Must, May, or both).
            origins: If True, return a dict mapping each contributing class
                (an ObjectClass object) to the attributes it declares; otherwise
                return the union of all of them.
        """
        all_attrs: set[str] = set()
        attrs_by_origin: dict[ObjectClass, set[str]] = {}
        for cls in self.object_class_path(oc):
            declared: set[str] = set()
            if required & self.AttrType.Must:
                declared.update(cls.must)
            if required & self.AttrType.May:
                declared.update(cls.may)
            if declared:
                all_attrs |= declared
                attrs_by_origin[cls] = declared
        return attrs_by_origin if origins else all_attrs

    def is_derived_class(self, name, base_candidate):
        """Return True if object class `name` is `base_candidate` or inherits from it.

        Both names are looked up case-insensitively; unknown names raise KeyError.
        """
        by_name = self.object_class_by_name
        return nx.has_path(self.object_class_tree,
            by_name[name.lower()].oid, by_name[base_candidate.lower()].oid)


def default_config() -> Config: # export
    """Return a Config without external configuration, i.e. using the built-in defaults."""
    return Config()


def bind(conf: Config|BaseConfig|None=None) -> Connection:
    """Return a new, bound Connection for `conf` (see Connection.__init__)."""
    return Connection(conf)