jw-python/tools/python/jwutils/ldap.py

298 lines
10 KiB
Python
Raw Normal View History

# -*- coding: utf-8 -*-
import ldap, getpass, pathlib, copy
from ldap.schema.models import ObjectClass
from enum import Flag, auto
import networkx as nx
from typing import Any, Self
from collections.abc import Callable
from .Config import Config as BaseConfig
from .log import *
class Config:
def __init__(self, external: BaseConfig|None=None):
self.__external = external
for attr in ['ldap_uri', 'bind_dn', 'bind_pw', 'base_dn']:
setattr(self, '_Config__' + attr, None)
def __get(self, key: str, default: str):
if not self.__external:
return default
return self.__external.value(key, default=default)
@property
def ldap_uri(self):
if self.__ldap_uri is None:
self.__ldap_uri = self.__get('ldap_uri', default='ldap://ldap.janware.com')
return self.__ldap_uri
@ldap_uri.setter
def ldap_uri(self, rhs):
self.__ldap_uri = rhs
@property
def bind_dn(self):
if self.__bind_dn is None:
self.__bind_dn = self.__get('bind_dn', default=f'uid={getpass.getuser()},ou=users,dc=jannet,dc=de')
return self.__bind_dn
@bind_dn.setter
def bind_dn(self, rhs):
self.__bind_dn = rhs
@property
def bind_pw(self):
if self.__bind_pw is None:
ret = self.__get('bind_pw', default=None)
if ret is None:
ldap_secret_file = self.__get('secret_file', f'{pathlib.Path.home()}/.ldap.secret')
with open(ldap_secret_file, 'r') as file:
ret = file.read()
file.closed
ret = ret.strip()
self.__bind_pw = ret
return self.__bind_pw
@bind_pw.setter
def bind_pw(self, rhs):
self.__bind_pw = rhs
@property
def base_dn(self):
if self.__base_dn is None:
self.__base_dn = self.__get('base_dn', default=f'dc=jannet,dc=de')
return self.__base_dn
@base_dn.setter
def base_dn(self, rhs):
self.__base_dn = rhs
class Connection: # export
    """An authenticated LDAP connection with search/add/delete helpers.

    It also provides cached access to the server's object-class schema
    (``object_classes``, ``object_class_by_name``, ``object_class_tree``).
    These are fetched from the server on first use and kept on the instance.
    """

    class AttrType(Flag):
        """Which attribute kinds ``object_class_attrs()`` collects."""
        Must = auto()  # attributes from an object class's ``must`` list
        May = auto()   # attributes from an object class's ``may`` list

    def __init__(self, conf: Config|BaseConfig|None=None, backtrace=False):
        """Open a connection to the configured server, start TLS and bind.

        Args:
            conf: A ``Config``, a project config object to wrap in one, or
                None for built-in defaults.
            backtrace: If True, exceptions raised by ``walk()`` callbacks are
                logged as errors and re-raised. If False, they are logged as
                warnings and skipped.

        Raises:
            Exception: if obtaining the password or binding fails. The
                underlying error is chained as ``__cause__``.
        """
        c = conf if isinstance(conf, Config) else Config(conf)
        conn = ldap.initialize(c.ldap_uri)
        conn.start_tls_s()
        try:
            conn.bind_s(c.bind_dn, c.bind_pw)
        except Exception as e:
            raise Exception(f'Failed to bind to {c.ldap_uri} with dn {c.bind_dn} ({e})') from e
        self.__ldap = conn
        self.__backtrace = backtrace
        # Schema caches, filled lazily by the corresponding properties.
        self.__object_classes: dict[str, ObjectClass]|None = None
        self.__object_class_tree: nx.Graph|None = None
        self.__object_classes_by_name: dict[str, ObjectClass]|None = None

    @property
    def ldap(self):
        """The underlying python-ldap connection object."""
        return self.__ldap

    def add(self, attrs: dict[str, bytes], dn: str|None=None):
        """Add a new entry.

        Args:
            attrs: Attributes of the new entry, in the form accepted by
                ``ldap.modlist.addModlist()``. A ``'dn'`` key is never sent as
                an attribute; it supplies the DN when ``dn`` is not given.
                The caller's dict is not modified.
            dn: DN of the new entry; defaults to ``attrs['dn']``.

        Raises:
            Exception: if no DN is available, or re-raised after logging if
                the add operation fails.
        """
        # `import ldap` alone does not load the modlist submodule.
        import ldap.modlist
        if dn is None:
            dn = attrs.get('dn')
            if dn is None:
                raise Exception('No DN to add an LDAP entry to')
        if 'dn' in attrs:
            # Strip the DN pseudo-attribute on a copy, leaving the caller's dict intact.
            attrs = copy.deepcopy(attrs)
            del attrs['dn']
        try:
            slog(INFO, f'LDAP: Add [{dn}] -> {attrs}')
            self.__ldap.add_s(dn, ldap.modlist.addModlist(attrs))
        except Exception as e:
            slog(ERR, f'{dn}: Failed to add entry {attrs} ({e})')
            raise

    def delete(self, dn: str, recursive=False, force_existence: bool=False):
        """Delete the entry ``dn``.

        Args:
            dn: DN of the entry to delete.
            recursive: If True, first delete everything below ``dn``, depth
                first, with children removed before their parent. A failing
                child delete goes through ``walk()``'s callback error
                handling, so it is only logged unless ``backtrace`` is set.
            force_existence: If True, a missing ``dn`` raises
                ``ldap.NO_SUCH_OBJECT``; otherwise that case is ignored.

        Raises:
            ldap.NO_SUCH_OBJECT: if ``dn`` is missing and ``force_existence`` is set.
            Exception: any other failure, re-raised after logging.
        """
        def delete_subtree_cb(conn, entry, context):
            # Remove the entry's children first, then the entry itself.
            self.walk(delete_subtree_cb, base=entry[0], scope=ldap.SCOPE_ONELEVEL, context=context)
            self.__ldap.delete_s(entry[0])

        try:
            if recursive:
                self.walk(delete_subtree_cb, dn, scope=ldap.SCOPE_ONELEVEL)
            self.__ldap.delete_s(dn)
        except ldap.NO_SUCH_OBJECT:
            if force_existence:
                raise
        except Exception as e:
            slog(ERR, f'Failed to delete {dn} ({e})')
            raise

    def walk(self,
             callback: Callable[[Self, Any, Any], None],
             base: str,
             scope,
             context=None,
             filterstr=None,
             attrlist=None,
             attrsonly=0,
             serverctrls=None,
             clientctrls=None,
             timeout=-1,
             sizelimit=0):
        """Run an asynchronous search and call ``callback`` once per result entry.

        ``callback(conn, entry, context)`` receives this connection, one
        ``(dn, attrs)`` result tuple and the caller-supplied ``context``.
        Exceptions raised by the callback are logged and skipped. If the
        connection was created with ``backtrace=True``, they are logged and
        re-raised instead.

        The search arguments (``base`` through ``sizelimit``) are passed
        unchanged to python-ldap's ``search_ext()``.
        """
        msgid = self.__ldap.search_ext(base,
                                       scope,
                                       filterstr=filterstr,
                                       attrlist=attrlist,
                                       attrsonly=attrsonly,
                                       serverctrls=serverctrls,
                                       clientctrls=clientctrls,
                                       timeout=timeout,
                                       sizelimit=sizelimit)
        while True:
            # all=0: receive one result message per call.
            result_type, result_data = self.__ldap.result(msgid, 0)
            if not result_data:
                # The terminating search-result message carries no entries.
                break
            if result_type != ldap.RES_SEARCH_ENTRY:
                continue
            for entry in result_data:
                try:
                    callback(self, entry, context)
                except Exception as e:
                    msg = f'Exception {e}'
                    if self.__backtrace:
                        slog(ERR, msg)
                        raise
                    slog(WARNING, msg)

    def find(self,
             base: str,
             scope,
             filterstr=None,
             attrlist=None,
             attrsonly=0,
             serverctrls=None,
             clientctrls=None,
             timeout=-1,
             sizelimit=0,
             assert_unique=False,
             assert_not_empty=False,
             ):
        """Search and return all matching entries as a list of ``(dn, attrs)`` tuples.

        The search arguments are passed through to ``walk()``.

        Raises:
            Exception: if the search fails (re-raised after logging), if
                ``assert_not_empty`` is set and nothing matched, or if
                ``assert_unique`` is set and more than one entry matched.
        """
        result = []

        def collect_cb(conn, entry, context):
            result.append(entry)

        def describe():
            return f'{base} -> "{filterstr}"'

        try:
            self.walk(collect_cb, base, scope=scope, filterstr=filterstr,
                      attrlist=attrlist, attrsonly=attrsonly,
                      serverctrls=serverctrls, clientctrls=clientctrls,
                      timeout=timeout, sizelimit=sizelimit)
        except Exception as e:
            slog(ERR, f'Failed search {describe()} ({e})')
            raise
        if assert_not_empty and not result:
            raise Exception(f'Empty result for search {describe()}')
        if assert_unique and len(result) > 1:
            raise Exception(f'Found {len(result)} results for search {describe()}')
        return result

    @property
    def object_classes(self) -> dict[str, ObjectClass]:
        """All object classes from the server schema, keyed by OID (cached)."""
        if self.__object_classes is None:
            # The root DSE names the subschema entry (usually cn=Subschema).
            res = self.find(base='', scope=ldap.SCOPE_BASE, filterstr='(objectClass=*)', attrlist=['subschemaSubentry'])
            dn = res[0][1]['subschemaSubentry'][0].decode('utf-8')
            # '*' requests user attributes, '+' additionally operational ones.
            res = self.find(base=dn, scope=ldap.SCOPE_BASE, filterstr='(objectClass=*)', attrlist=['*', '+'])
            subschema = ldap.schema.SubSchema(ldap.cidict.cidict(res[0][1]))
            self.__object_classes = {
                oid: subschema.get_obj(ObjectClass, oid)
                for oid in subschema.listall(ObjectClass)
            }
        return self.__object_classes

    @property
    def object_class_by_name(self):
        """Object classes keyed by OID, by each name and by each lower-cased name (cached)."""
        if self.__object_classes_by_name is None:
            by_name: dict[str, ObjectClass] = {}
            for oid, oc in self.object_classes.items():
                by_name[oid] = oc
                for name in oc.names:
                    by_name[name] = oc
                    by_name[name.lower()] = oc
            # Cache only once complete, so a failed schema fetch is retried next time.
            self.__object_classes_by_name = by_name
        return self.__object_classes_by_name

    def __oc_recurse_to_top(self, cur: str|ObjectClass, cb, context):
        """Call ``cb(oc, context)`` on ``cur`` and all its superclasses, superclasses first.

        A class reachable through several SUP paths is visited once per path.
        Unknown class names raise ``KeyError``.
        """
        cur_oc = cur if isinstance(cur, ObjectClass) else self.object_class_by_name[cur.lower()]
        for sup in cur_oc.sup:
            self.__oc_recurse_to_top(sup, cb, context)
        cb(cur_oc, context)

    def object_class_path(self, leaf: str|ObjectClass):
        """Return an iterator over ``leaf`` and its superclasses, leaf first.

        A class reachable through several SUP paths appears once per path.
        """
        def cb(oc, context):
            ret.append(oc)
        ret = []
        self.__oc_recurse_to_top(leaf, cb, None)
        return reversed(ret)

    @property
    def object_class_tree(self) -> nx.Graph:
        """Directed inheritance graph of all object classes (cached).

        Nodes are OIDs with the attributes ``oid``, ``name``, ``kind`` and the
        sorted, comma-separated inherited ``must``/``may`` attribute names.
        Each edge points from a class to one of its SUP classes.
        """
        if self.__object_class_tree is None:
            def collect(root, attr):
                # Union of `attr` ('must' or 'may') over root and all its superclasses.
                found = set()
                def cb(oc, attr):
                    vals = getattr(oc, attr)
                    if vals is None:
                        return
                    found.update(vals)
                self.__oc_recurse_to_top(root, cb, attr)
                return found
            kind = {
                0: 'STRUCTURAL',
                1: 'ABSTRACT',
                2: 'AUXILIARY'
            }
            graph = nx.DiGraph()
            for oid, oc in self.object_classes.items():
                graph.add_node(
                    oid,
                    oid=oid,
                    name=oc.names[0] if oc.names else oid,
                    kind=kind[oc.kind],
                    must=', '.join(sorted(collect(oc, 'must'))),
                    may=', '.join(sorted(collect(oc, 'may')))
                )
                for base_class in oc.sup:
                    try:
                        graph.add_edge(oid, self.object_class_by_name[base_class.lower()].oid)
                    except Exception as e:
                        slog(WARNING, f'Failed to add edge {oid}:{oc.names} -> {base_class} ({e})')
            self.__object_class_tree = graph
        return self.__object_class_tree

    def object_class_attrs(self, oc: str|ObjectClass, required: AttrType = AttrType.Must, origins: bool=False) -> dict[ObjectClass, set[str]]|set[str]:
        """Collect the attribute names of ``oc`` including inherited ones.

        Args:
            oc: Object class name, OID or ``ObjectClass``.
            required: Which attribute kinds to collect (``Must``, ``May`` or both).
            origins: If True, return a dict mapping each contributing
                ``ObjectClass`` to the attributes it declares. If False,
                return the union of all of them as a set.
        """
        all_attrs: set[str] = set()
        attrs_by_origin: dict[ObjectClass, set[str]] = {}
        for cls in self.object_class_path(oc):
            cur: set[str] = set()
            if required & self.AttrType.Must:
                cur |= set(cls.must)
            if required & self.AttrType.May:
                cur |= set(cls.may)
            if cur:
                all_attrs |= cur
                attrs_by_origin[cls] = cur
        return attrs_by_origin if origins else all_attrs

    def is_derived_class(self, name, base_candidate):
        """Return True if object class ``name`` is ``base_candidate`` or inherits from it.

        Names are matched case-insensitively. Unknown names raise ``KeyError``.
        """
        by_name = self.object_class_by_name
        return nx.has_path(self.object_class_tree,
                           by_name[name.lower()].oid,
                           by_name[base_candidate.lower()].oid)
def default_config() -> Config: # export
    """Build a ``Config`` with no external configuration attached.

    Every setting therefore falls back to the built-in defaults of ``Config``.
    """
    cfg = Config()
    return cfg
def bind(conf: Config|BaseConfig|None=None, backtrace: bool=False) -> Connection:
    """Open an authenticated LDAP connection.

    Args:
        conf: Passed to ``Connection``. Either a ``Config``, a project config
            object to wrap in one, or None for built-in defaults.
        backtrace: Passed to ``Connection``. If True, exceptions raised by
            ``walk()`` callbacks are re-raised instead of only logged.
    """
    return Connection(conf, backtrace=backtrace)