mirror of
ssh://git.janware.com/srv/git/janware/proj/jw-python
synced 2026-04-24 10:03:36 +02:00
The entry instance passed to the walk callback contains raw results of LDAP search operations, i.e. all attribute values are lists, and all attribute values are bytes. Add the boolean parameters decode and unroll to walk() as a convenience method to get decoded values. They default to False, representing current behaviour. Signed-off-by: Jan Lindemann <jan@janware.com>
322 lines
12 KiB
Python
322 lines
12 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
import ldap, getpass, pathlib, copy
|
|
from ldap.schema.models import ObjectClass
|
|
from enum import Flag, auto
|
|
import networkx as nx
|
|
from typing import Any, Self
|
|
from collections.abc import Callable
|
|
|
|
from .Config import Config as BaseConfig
|
|
from .log import *
|
|
|
|
class Config:
|
|
def __init__(self, external: BaseConfig|None=None):
|
|
self.__external = external
|
|
for attr in ['ldap_uri', 'bind_dn', 'bind_pw', 'base_dn']:
|
|
setattr(self, '_Config__' + attr, None)
|
|
|
|
def __get(self, key: str, default: str):
|
|
if not self.__external:
|
|
return default
|
|
return self.__external.value(key, default=default)
|
|
|
|
@property
|
|
def ldap_uri(self):
|
|
if self.__ldap_uri is None:
|
|
for key in ['ldap_uri', 'uri']:
|
|
self.__ldap_uri = self.__get(key, default=None)
|
|
if self.__ldap_uri is not None:
|
|
break
|
|
else:
|
|
self.__ldap_uri = 'ldap://ldap.janware.com'
|
|
return self.__ldap_uri
|
|
@ldap_uri.setter
|
|
def ldap_uri(self, rhs):
|
|
self.__ldap_uri = rhs
|
|
|
|
@property
|
|
def bind_dn(self):
|
|
if self.__bind_dn is None:
|
|
self.__bind_dn = self.__get('bind_dn', default=f'uid={getpass.getuser()},ou=users,dc=jannet,dc=de')
|
|
return self.__bind_dn
|
|
@bind_dn.setter
|
|
def bind_dn(self, rhs):
|
|
self.__bind_dn = rhs
|
|
|
|
@property
|
|
def bind_pw(self):
|
|
if self.__bind_pw is None:
|
|
for key in ['bind_pw', 'password']:
|
|
ret = self.__get(key, default=None)
|
|
if ret is not None:
|
|
break
|
|
if ret is None:
|
|
ldap_secret_file = self.__get('secret_file', f'{pathlib.Path.home()}/.ldap.secret')
|
|
with open(ldap_secret_file, 'r') as file:
|
|
ret = file.read()
|
|
file.closed
|
|
ret = ret.strip()
|
|
self.__bind_pw = ret
|
|
return self.__bind_pw
|
|
@bind_pw.setter
|
|
def bind_pw(self, rhs):
|
|
self.__bind_pw = rhs
|
|
|
|
@property
|
|
def base_dn(self):
|
|
if self.__base_dn is None:
|
|
self.__base_dn = self.__get('base_dn', default=f'dc=jannet,dc=de')
|
|
return self.__base_dn
|
|
@base_dn.setter
|
|
def base_dn(self, rhs):
|
|
self.__base_dn = rhs
|
|
|
|
class Connection: # export
    """A bound LDAP connection with helpers for searching, modifying and schema inspection.

    Creating an instance initializes a python-ldap connection to the
    configured URI, starts TLS and binds with the configured credentials.
    The raw python-ldap object stays accessible through the ``ldap``
    property.
    """

    class AttrType(Flag):
        """Selects which attribute sets of an object class object_class_attrs() reports."""
        Must = auto()
        May = auto()

    def __init__(self, conf: Config|BaseConfig|None=None, backtrace=False):
        """Connect to the LDAP server, start TLS and bind.

        conf      -- Either a Config instance, or anything Config accepts as
                     external configuration (including None).
        backtrace -- If True, exceptions raised by walk() callbacks are
                     logged and re-raised, otherwise they are logged as
                     warnings and swallowed.

        Failures to connect or to bind are logged and re-raised.
        """
        c = conf if isinstance(conf, Config) else Config(conf)
        # Config.ldap_uri already falls back to the key 'uri' on its own
        uri = c.ldap_uri
        try:
            ret = ldap.initialize(uri)
            ret.start_tls_s()
        except Exception as e:
            slog(ERR, f'Failed to initialize LDAP connection to "{uri}" ({str(e)})')
            raise
        try:
            ret.bind_s(c.bind_dn, c.bind_pw)
        except Exception as e:
            slog(ERR, f'Failed to bind to "{uri}" with dn "{c.bind_dn}" ({str(e)})')
            raise
        self.__ldap = ret
        self.__backtrace = backtrace
        # Schema caches, filled on first access of the respective property
        self.__object_classes_by_oid: dict[str, ObjectClass]|None = None
        self.__object_class_tree: nx.Graph|None = None
        self.__object_classes_by_name: dict[str, ObjectClass]|None = None

    @property
    def ldap(self):
        """The underlying python-ldap connection object."""
        return self.__ldap

    @staticmethod
    def _decode_value(val):
        """Return val decoded to str, or val itself if it is not decodable text."""
        try:
            return val.decode()
        except UnicodeDecodeError:
            # Binary attribute value, leave it alone
            return val

    def add(self, attrs: dict[str, bytes], dn: str|None=None):
        """Add an entry to the directory.

        attrs -- Attributes of the new entry.
        dn    -- DN of the new entry. If None, the DN is taken from
                 attrs['dn'], which is then left out of the attributes
                 written. The passed dictionary is not modified.

        Raises an Exception if no DN can be determined. Failures of the add
        operation are logged and re-raised.
        """
        # The ldap.modlist submodule is not used anywhere else in this
        # class, so import it here instead of relying on it being loaded.
        from ldap.modlist import addModlist
        if dn is None:
            if 'dn' not in attrs:
                raise Exception('No DN to add an LDAP entry to')
            attrs = copy.deepcopy(attrs)
            dn_value: Any = attrs.pop('dn')
            # Be lenient about the DN's representation: accept a one-item
            # sequence and bytes, too.
            if isinstance(dn_value, (list, tuple)) and dn_value:
                dn_value = dn_value[0]
            if isinstance(dn_value, bytes):
                dn_value = dn_value.decode('utf-8')
            dn = dn_value
        try:
            slog(INFO, f'LDAP: Add [{dn}] -> {attrs}')
            self.__ldap.add_s(dn, addModlist(attrs))
        except Exception as e:
            slog(ERR, f'{dn}: Failed to add entry {attrs} ({e})')
            raise

    def delete(self, dn: str, recursive=False, force_existence: bool=False):
        """Delete the entry dn.

        recursive       -- If True, delete all entries below dn first.
        force_existence -- If True, a missing entry is an error, otherwise
                           ldap.NO_SUCH_OBJECT is silently ignored.

        Any other failure is logged and re-raised. Note that the children
        are deleted from within a walk() callback, so failures to delete
        them are subject to walk()'s exception handling.
        """

        def _walk_cb_delete(conn: Connection, entry, context):
            # Depth first: children need to go before their parent
            self.walk(_walk_cb_delete, base=entry[0], scope=ldap.SCOPE_ONELEVEL, context=context)
            self.__ldap.delete_s(entry[0])

        try:
            if recursive:
                self.walk(_walk_cb_delete, dn, scope=ldap.SCOPE_ONELEVEL)
            self.__ldap.delete_s(dn)
        except ldap.NO_SUCH_OBJECT:
            if force_existence:
                raise
        except Exception as e:
            slog(ERR, f'Failed to delete {dn} ({e})')
            raise

    def walk(
        self,
        callback: Callable[[Self, Any, Any], None],
        base: str,
        scope,
        context=None,
        filterstr=None,
        attrlist=None,
        attrsonly=0,
        serverctrls=None,
        clientctrls=None,
        timeout=-1,
        sizelimit=0,
        decode: bool=False,
        unroll: bool=False
    ):
        """Run a search and invoke callback(self, entry, context) for every entry found.

        entry is a tuple (dn, attributes). By default, attributes holds the
        raw search result, i.e. every attribute value is a list of bytes.

        decode -- If True, decode attribute values to str. Values which
                  can't be decoded are passed on as bytes.
        unroll -- If True, replace every attribute value list holding
                  exactly one item by that item. Lists of any other length
                  are passed on unchanged, so no values get lost.

        serverctrls, clientctrls, timeout and sizelimit are accepted but
        currently ignored.

        Exceptions raised by the callback are logged. They are re-raised
        only if the connection was created with backtrace=True, otherwise
        the walk continues with the next entry.
        """

        # TODO: Support ignored arguments
        search_return = self.__ldap.search(base=base,
                                           scope=scope,
                                           filterstr=filterstr,
                                           attrlist=attrlist,
                                           attrsonly=attrsonly)
        while True:
            # Fetch results one message at a time
            result_type, result_data = self.__ldap.result(search_return, 0)
            if result_data == []:
                break
            if result_type != ldap.RES_SEARCH_ENTRY:
                continue
            for entry in result_data:
                if decode:
                    entry = entry[0], {
                        key: [self._decode_value(val) for val in vals]
                        for key, vals in entry[1].items()
                    }
                if unroll:
                    entry = entry[0], {
                        key: vals[0] if len(vals) == 1 else vals
                        for key, vals in entry[1].items()
                    }
                try:
                    callback(self, entry, context)
                except Exception as e:
                    msg = f'Exception {e}'
                    if self.__backtrace:
                        slog(ERR, msg)
                        raise
                    slog(WARNING, msg)

    def find(self,
        base: str,
        scope,
        filterstr=None,
        attrlist=None,
        attrsonly=0,
        serverctrls=None,
        clientctrls=None,
        timeout=-1,
        sizelimit=0,
        assert_unique=False,
        assert_not_empty=False,
    ):
        """Run a search and return the list of raw entries found.

        base, scope, filterstr, attrlist and attrsonly are handed to walk().
        serverctrls, clientctrls, timeout and sizelimit are accepted but
        currently ignored.

        assert_unique    -- Raise an Exception if more than one entry is found.
        assert_not_empty -- Raise an Exception if no entry is found.

        Failures of the search itself are logged and re-raised.
        """

        result: list[Any] = []

        def _walk_cb_find(conn: Connection, entry: Any, context: Any):
            result.append(entry)

        def _search():
            # Description of the search for log and error messages
            return f'{base} -> "{filterstr}"'

        try:
            self.walk(_walk_cb_find, base, scope=scope, filterstr=filterstr,
                      attrlist=attrlist, attrsonly=attrsonly)
        except Exception as e:
            slog(ERR, f'Failed search {_search()} ({e})')
            raise
        if assert_not_empty and not result:
            raise Exception(f'Empty result for search {_search()}')
        if assert_unique and len(result) > 1:
            raise Exception(f'Found {len(result)} results for search {_search()}')
        return result

    @property
    def object_classes(self) -> dict[str, ObjectClass]:
        """All object classes of the server's schema, keyed by OID.

        The schema is read from the server on first access and cached.
        """
        if self.__object_classes_by_oid is None:
            # Ask the entry at the empty base DN where the schema lives
            res = self.find(base='', scope=ldap.SCOPE_BASE, filterstr='(objectClass=*)', attrlist=['subschemaSubentry'])
            dn = res[0][1]['subschemaSubentry'][0].decode('utf-8') # Usually yields cn=Subschema
            res = self.find(base=dn, scope=ldap.SCOPE_BASE, filterstr='(objectClass=*)', attrlist=['*', '+'])
            subschema_entry = res[0]
            subschema_subentry = ldap.cidict.cidict(subschema_entry[1])
            subschema = ldap.schema.SubSchema(subschema_subentry)
            object_class_oids = subschema.listall(ObjectClass)
            self.__object_classes_by_oid = {
                oid: subschema.get_obj(ObjectClass, oid) for oid in object_class_oids
            }
        return self.__object_classes_by_oid

    @property
    def object_class_by_name(self) -> dict[str, ObjectClass]:
        """All object classes, keyed by OID, by each of their names, and by each lower-cased name.

        Built from object_classes on first access and cached.
        """
        if self.__object_classes_by_name is None:
            ret: dict[str, ObjectClass] = {}
            for oid, oc in self.object_classes.items():
                ret[oid] = oc
                for name in oc.names:
                    ret[name] = oc
                    ret[name.lower()] = oc
            # Publish the cache only after it's complete, so that a failure
            # to read the schema doesn't leave an empty cache behind
            self.__object_classes_by_name = ret
        return self.__object_classes_by_name

    def __oc_recurse_to_top(self, cur: str|ObjectClass, cb, context):
        """Call cb(object_class, context) for all superior classes of cur, then for cur itself.

        cur may be an ObjectClass or a name, which is looked up
        case-insensitively. Superior classes are visited recursively before
        the class they belong to.
        """
        cur_oc = cur if isinstance(cur, ObjectClass) else self.object_class_by_name[cur.lower()]
        for s in cur_oc.sup:
            self.__oc_recurse_to_top(s, cb, context)
        cb(cur_oc, context)

    def object_class_path(self, leaf: str|ObjectClass):
        """Return an iterator over leaf and its superior classes, starting with leaf."""
        def cb(oc, context):
            ret.append(oc)
        ret: list[ObjectClass] = []
        self.__oc_recurse_to_top(leaf, cb, None)
        # Collected top-down, delivered leaf first
        return reversed(ret)

    @property
    def object_class_tree(self) -> nx.Graph:
        """Directed graph of the object class hierarchy.

        Every object class is a node identified by its OID, carrying the
        attributes oid, name, kind, must and may. must and may are
        comma-separated, sorted lists of the attribute names of the class
        and all of its superior classes. Edges point from a class to each
        of its superior classes. The graph is built on first access and
        cached.
        """

        if self.__object_class_tree is None:

            def collect(root, attr):
                # Gather the values of attribute attr from root and all of
                # its superior classes
                ret = set()
                def cb(oc, attr):
                    vals = getattr(oc, attr)
                    if vals is None:
                        return
                    for val in vals:
                        ret.add(val)
                self.__oc_recurse_to_top(root, cb, attr)
                return ret

            kind = {
                0: 'STRUCTURAL',
                1: 'ABSTRACT',
                2: 'AUXILIARY'
            }
            ret = nx.DiGraph()
            for oid, oc in self.object_classes.items():
                ret.add_node(
                    oid,
                    oid=oid,
                    name=oc.names[0],
                    kind=kind[oc.kind],
                    # Sorted, because set order differs between runs
                    must=', '.join(sorted(collect(oc, 'must'))),
                    may=', '.join(sorted(collect(oc, 'may')))
                )
                for base_class in oc.sup:
                    try:
                        ret.add_edge(oid, self.object_class_by_name[base_class.lower()].oid)
                    except Exception as e:
                        slog(WARNING, f'Failed to add edge {oid}:{oc.names} -> {base_class} ({e})')
            self.__object_class_tree = ret
        return self.__object_class_tree

    def object_class_attrs(self, oc: str|ObjectClass, required: AttrType = AttrType.Must, origins: bool=False) -> dict[ObjectClass, set[str]]|set[str]:
        """Return the attribute names of an object class and its superior classes.

        oc       -- Object class, or the name of one.
        required -- Combination of AttrType flags selecting mandatory
                    and / or optional attributes.
        origins  -- If False, return one set holding all attribute names.
                    If True, return a dictionary mapping every object class
                    on the path that contributes attributes to the set of
                    its attribute names.
        """
        all_attrs: set[str] = set()
        attrs_by_origin: dict[ObjectClass, set[str]] = {}
        for path_oc in self.object_class_path(oc):
            cur = set()
            if required & self.AttrType.Must:
                cur |= set(path_oc.must)
            if required & self.AttrType.May:
                cur |= set(path_oc.may)
            if cur:
                all_attrs |= cur
                attrs_by_origin[path_oc] = cur
        return attrs_by_origin if origins else all_attrs

    def is_derived_class(self, name, base_candidate):
        """Return True if base_candidate is reachable from name in the object class tree.

        Both classes are identified by name, looked up case-insensitively.
        """
        by_name = self.object_class_by_name
        return nx.has_path(self.object_class_tree, by_name[name.lower()].oid, by_name[base_candidate.lower()].oid)
|
|
|
|
def default_config() -> Config: # export
    """Create a Config without an external configuration attached."""
    conf = Config()
    return conf
|
|
|
|
def bind(conf: Config|BaseConfig|None=None) -> Connection:
    """Create a Connection from conf and return it."""
    connection = Connection(conf)
    return connection
|