mirror of
ssh://git.janware.com/srv/git/janware/proj/jw-python
synced 2026-01-15 01:52:56 +01:00
jwutils.ldap: Add module
Signed-off-by: Jan Lindemann <jan@janware.com>
This commit is contained in:
parent
e4cee86482
commit
1e43cdc715
2 changed files with 300 additions and 17 deletions
|
|
@ -5,7 +5,7 @@ from typing import Optional, Union
|
|||
import ldap
|
||||
|
||||
from ...log import *
|
||||
from ... import Config
|
||||
from ...ldap import bind
|
||||
from .. import Access
|
||||
from .. import Auth as AuthBase
|
||||
from .. import Group as GroupBase
|
||||
|
|
@ -62,22 +62,7 @@ class Auth(AuthBase): # export
|
|||
self.__dummy = self.load('dummy', conf)
|
||||
|
||||
def __bind(self):
|
||||
ldap_uri = self.conf['ldap_uri']
|
||||
bind_dn = self.conf['bind_dn']
|
||||
bind_pw = self.conf.get('password')
|
||||
if bind_pw is None:
|
||||
with open(ldap_secret_file, 'r') as file:
|
||||
bind_pw = file.read()
|
||||
file.closed
|
||||
bind_pw = bind_pw.strip()
|
||||
ret = ldap.initialize(ldap_uri)
|
||||
ret.start_tls_s()
|
||||
try:
|
||||
rr = ret.bind_s(bind_dn, bind_pw) # method)
|
||||
except Exception as e:
|
||||
#pw = f' (pw={bind_pw})'
|
||||
raise Exception(f'Failed to bind to {ldap_uri} with dn {bind_dn} ({e})')
|
||||
return ret
|
||||
return bind(self.conf)
|
||||
|
||||
@property
|
||||
def __users(self) -> User:
|
||||
|
|
|
|||
298
tools/python/jwutils/ldap.py
Normal file
298
tools/python/jwutils/ldap.py
Normal file
|
|
@ -0,0 +1,298 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import ldap, getpass, pathlib, copy
|
||||
from ldap.schema.models import ObjectClass
|
||||
from enum import Flag, auto
|
||||
import networkx as nx
|
||||
from typing import Any, Self
|
||||
from collections.abc import Callable
|
||||
|
||||
from .Config import Config as BaseConfig
|
||||
from .log import *
|
||||
|
||||
class Config:
|
||||
def __init__(self, external: BaseConfig|None=None):
|
||||
self.__external = external
|
||||
for attr in ['ldap_uri', 'bind_dn', 'bind_pw', 'base_dn']:
|
||||
setattr(self, '_Config__' + attr, None)
|
||||
|
||||
def __get(self, key: str, default: str):
|
||||
if not self.__external:
|
||||
return default
|
||||
return self.__external.value(key, default=default)
|
||||
|
||||
@property
|
||||
def ldap_uri(self):
|
||||
if self.__ldap_uri is None:
|
||||
self.__ldap_uri = self.__get('ldap_uri', default='ldap://ldap.janware.com')
|
||||
return self.__ldap_uri
|
||||
@ldap_uri.setter
|
||||
def ldap_uri(self, rhs):
|
||||
self.__ldap_uri = rhs
|
||||
|
||||
@property
|
||||
def bind_dn(self):
|
||||
if self.__bind_dn is None:
|
||||
self.__bind_dn = self.__get('bind_dn', default=f'uid={getpass.getuser()},ou=users,dc=jannet,dc=de')
|
||||
return self.__bind_dn
|
||||
@bind_dn.setter
|
||||
def bind_dn(self, rhs):
|
||||
self.__bind_dn = rhs
|
||||
|
||||
@property
|
||||
def bind_pw(self):
|
||||
if self.__bind_pw is None:
|
||||
ret = self.__get('bind_pw', default=None)
|
||||
if ret is None:
|
||||
ldap_secret_file = self.__get('secret_file', f'{pathlib.Path.home()}/.ldap.secret')
|
||||
with open(ldap_secret_file, 'r') as file:
|
||||
ret = file.read()
|
||||
file.closed
|
||||
ret = ret.strip()
|
||||
self.__bind_pw = ret
|
||||
return self.__bind_pw
|
||||
@bind_pw.setter
|
||||
def bind_pw(self, rhs):
|
||||
self.__bind_pw = rhs
|
||||
|
||||
@property
|
||||
def base_dn(self):
|
||||
if self.__base_dn is None:
|
||||
self.__base_dn = self.__get('base_dn', default=f'dc=jannet,dc=de')
|
||||
return self.__base_dn
|
||||
@base_dn.setter
|
||||
def base_dn(self, rhs):
|
||||
self.__base_dn = rhs
|
||||
|
||||
class Connection: # export
|
||||
|
||||
class AttrType(Flag):
|
||||
Must = auto()
|
||||
May = auto()
|
||||
|
||||
def __init__(self, conf: Config|BaseConfig|None=None, backtrace=False):
|
||||
c = conf if isinstance(conf, Config) else Config(conf)
|
||||
ret = ldap.initialize(c.ldap_uri)
|
||||
ret.start_tls_s()
|
||||
try:
|
||||
rr = ret.bind_s(c.bind_dn, c.bind_pw) # method)
|
||||
except Exception as e:
|
||||
raise Exception(f'Failed to bind to {c.ldap_uri} with dn {c.bind_dn} ({e})')
|
||||
self.__ldap = ret
|
||||
self.__backtrace = backtrace
|
||||
self.__object_classes: dict[str, ObjectClass]|None = None
|
||||
self.__object_class_tree: nx.Graph|None = None
|
||||
self.__object_classes_by_name: dict[str, ObjectClass]|None = None
|
||||
|
||||
@property
|
||||
def ldap(self):
|
||||
return self.__ldap
|
||||
|
||||
def add(self, attrs: dict[str, bytes], dn: str|None=None):
|
||||
if dn is None:
|
||||
dn = attrs.get('dn')
|
||||
if dn is None:
|
||||
raise Exception('No DN to add an LDAP entry to')
|
||||
else:
|
||||
attrs = copy.deepcopy(attrs)
|
||||
del attrs['dn']
|
||||
try:
|
||||
slog(INFO, f'LDAP: Add [{dn}] -> {attrs}')
|
||||
self.__ldap.add_s(dn, ldap.modlist.addModlist(attrs))
|
||||
except Exception as e:
|
||||
slog(ERR, f'{dn}: Failed to add entry {attrs} ({e})')
|
||||
raise
|
||||
|
||||
def delete(self, dn: str, recursive=False, force_existence: bool=False):
|
||||
|
||||
def __walk_cb_delete(conn: Connection, entry, context):
|
||||
self.walk(__walk_cb_delete, base=entry[0], scope=ldap.SCOPE_ONELEVEL, context=context)
|
||||
self.__ldap.delete_s(entry[0])
|
||||
|
||||
try:
|
||||
if recursive:
|
||||
self.walk(__walk_cb_delete, dn, scope=ldap.SCOPE_ONELEVEL)
|
||||
self.__ldap.delete_s(dn)
|
||||
else:
|
||||
self.__ldap.delete_s(dn)
|
||||
except ldap.NO_SUCH_OBJECT as e:
|
||||
if force_existence:
|
||||
raise
|
||||
except Exception as e:
|
||||
slog(ERR, f'Failed to delete {dn} ({e})')
|
||||
raise
|
||||
|
||||
def walk(self,
|
||||
callback: Callable[[Self, Any, Any], None],
|
||||
base: str,
|
||||
scope,
|
||||
context=None,
|
||||
filterstr=None,
|
||||
attrlist=None,
|
||||
attrsonly=0,
|
||||
serverctrls=None,
|
||||
clientctrls=None,
|
||||
timeout=-1,
|
||||
sizelimit=0):
|
||||
|
||||
# TODO: Support ignored arguments
|
||||
search_return = self.__ldap.search(base=base,
|
||||
scope=scope,
|
||||
filterstr=filterstr,
|
||||
attrlist=attrlist,
|
||||
attrsonly=attrsonly)
|
||||
while True:
|
||||
result_type, result_data = self.__ldap.result(search_return, 0)
|
||||
if (result_data == []):
|
||||
break
|
||||
if result_type != ldap.RES_SEARCH_ENTRY:
|
||||
continue
|
||||
for entry in result_data:
|
||||
try:
|
||||
callback(self, entry, context)
|
||||
except Exception as e:
|
||||
msg = f'Exception {e}'
|
||||
if self.__backtrace:
|
||||
slog(ERR, msg)
|
||||
raise
|
||||
slog(WARNING, msg)
|
||||
continue
|
||||
|
||||
def find(self,
|
||||
base: str,
|
||||
scope,
|
||||
filterstr=None,
|
||||
attrlist=None,
|
||||
attrsonly=0,
|
||||
serverctrls=None,
|
||||
clientctrls=None,
|
||||
timeout=-1,
|
||||
sizelimit=0,
|
||||
assert_unique=False,
|
||||
assert_not_empty=False,
|
||||
):
|
||||
|
||||
def __walk_cb_find(conn: Connection, entry, context):
|
||||
result.append(entry)
|
||||
|
||||
def __search():
|
||||
return f'{base} -> "{filterstr}"'
|
||||
|
||||
try:
|
||||
result = []
|
||||
self.walk(__walk_cb_find, base, scope=scope, filterstr=filterstr, attrlist=attrlist)
|
||||
except Exception as e:
|
||||
slog(ERR, f'Failed search {__search()} ({e})')
|
||||
raise
|
||||
if assert_not_empty and not result:
|
||||
raise Exception(f'Empty result for search {__search()}')
|
||||
if assert_unique and len(result) > 1:
|
||||
raise Exception(f'Found {len(result)} results for search {__search()}')
|
||||
return result
|
||||
|
||||
@property
|
||||
def object_classes(self) -> list[ObjectClass]:
|
||||
#def object_classes(self):
|
||||
if self.__object_classes is None:
|
||||
res = self.find(base='', scope=ldap.SCOPE_BASE, filterstr='(objectClass=*)', attrlist=['subschemaSubentry'])
|
||||
dn = res[0][1]['subschemaSubentry'][0].decode('utf-8') # Usually yields cn=Subschema
|
||||
res = self.find(base=dn, scope=ldap.SCOPE_BASE, filterstr='(objectClass=*)', attrlist=['*', '+'])
|
||||
subschema_entry = res[0]
|
||||
subschema_subentry = ldap.cidict.cidict(subschema_entry[1])
|
||||
subschema = ldap.schema.SubSchema(subschema_subentry)
|
||||
object_class_oids = subschema.listall(ObjectClass)
|
||||
self.__object_classes = {
|
||||
oid: subschema.get_obj(ObjectClass, oid) for oid in object_class_oids
|
||||
}
|
||||
return self.__object_classes
|
||||
|
||||
@property
|
||||
def object_class_by_name(self):
|
||||
if self.__object_classes_by_name is None:
|
||||
ret: dict[str, ObjectClass] = {}
|
||||
self.__object_classes_by_name = ret
|
||||
for oid, oc in self.object_classes.items():
|
||||
ret[oid] = oc
|
||||
for name in oc.names:
|
||||
ret[name] = oc
|
||||
ret[name.lower()] = oc
|
||||
return self.__object_classes_by_name
|
||||
|
||||
def __oc_recurse_to_top(self, cur: str|ObjectClass, cb, context):
|
||||
cur_oc = cur if isinstance(cur, ObjectClass) else self.object_class_by_name[cur.lower()]
|
||||
for s in cur_oc.sup:
|
||||
self.__oc_recurse_to_top(s, cb, context)
|
||||
cb(cur_oc, context)
|
||||
|
||||
def object_class_path(self, leaf: str|ObjectClass):
|
||||
def cb(oc, context):
|
||||
ret.append(oc)
|
||||
ret = []
|
||||
self.__oc_recurse_to_top(leaf, cb, None)
|
||||
return reversed(ret)
|
||||
|
||||
@property
|
||||
def object_class_tree(self) -> nx.Graph:
|
||||
|
||||
if self.__object_class_tree is None:
|
||||
|
||||
def collect(root, attr):
|
||||
ret = set()
|
||||
def cb(oc, attr):
|
||||
vals = getattr(oc, attr)
|
||||
if vals is None:
|
||||
return
|
||||
for val in vals:
|
||||
ret.add(val)
|
||||
self.__oc_recurse_to_top(root, cb, attr)
|
||||
return ret
|
||||
|
||||
kind = {
|
||||
0: 'STRUCTURAL',
|
||||
1: 'ABSTRACT',
|
||||
2: 'AUXILIARY'
|
||||
}
|
||||
ret = nx.DiGraph()
|
||||
for oid, oc in self.object_classes.items():
|
||||
ret.add_node(
|
||||
oid,
|
||||
oid=oid,
|
||||
name=oc.names[0],
|
||||
kind=kind[oc.kind],
|
||||
must=', '.join(collect(oc, 'must')),
|
||||
may=', '.join(collect(oc, 'may'))
|
||||
)
|
||||
for base_class in oc.sup:
|
||||
try:
|
||||
ret.add_edge(oid, self.object_class_by_name[base_class.lower()].oid)
|
||||
except Exception as e:
|
||||
slog(WARNING, f'Failed to add edge {oid}:{oc.names} -> {base_class} ({e})')
|
||||
self.__object_class_tree = ret
|
||||
return self.__object_class_tree
|
||||
|
||||
def object_class_attrs(self, oc: str|ObjectClass, required: AttrType = AttrType.Must, origins: bool=False) -> dict[str, set[str]]|set[str]:
|
||||
all_attrs: set[str] = set()
|
||||
attrs_by_origin: dict[str, set[str]] = {}
|
||||
for oc in self.object_class_path(oc):
|
||||
cur = set()
|
||||
if required & self.AttrType.Must:
|
||||
cur |= set(oc.must)
|
||||
if required & self.AttrType.May:
|
||||
cur |= set(oc.may)
|
||||
if cur:
|
||||
all_attrs |= cur
|
||||
attrs_by_origin[oc] = cur
|
||||
return attrs_by_origin if origins else all_attrs
|
||||
|
||||
def is_derived_class(self, name, base_candidate):
|
||||
#oid = self.object_class_by_name[name].oid
|
||||
#base_oid = self.object_class_by_name[base_candidate].oid
|
||||
#if base_oid in [oc.oid for oc in self.object_class_path(name)]:
|
||||
# return True
|
||||
return nx.has_path(self.object_class_tree, self.object_class_by_name[name.lower()].oid, self.object_class_by_name[base_candidate.lower()].oid)
|
||||
|
||||
def default_config() -> Config: # export
|
||||
return Config()
|
||||
|
||||
def bind(conf: Config|BaseConfig|None=None) -> Connection:
|
||||
return Connection(conf)
|
||||
Loading…
Add table
Add a link
Reference in a new issue