# -*- coding: utf-8 -*-

from typing import Dict, Iterable, List, Optional

import abc

from ...log import *

from .Table import Table
from .Column import Column
from .DataType import DataType
from .CompositeForeignKey import CompositeForeignKey

class Schema(abc.ABC): # export
    """Abstract, lazily cached view of a database schema.

    Subclasses supply the raw metadata through the ``_table_names()``,
    ``_table()``, ``_foreign_keys()`` and ``_access_defining_columns()``
    hooks.  The public accessors call each hook on first use and serve
    the stored result afterwards.
    """

    def __init__(self) -> None:
        # Backing store for the ``__tables`` property: table name -> Table.
        # The third underscore keeps its mangled name (_Schema___tables)
        # distinct from the property's (_Schema__tables).
        self.___tables: Optional[Dict[str, Table]] = None
        self.__foreign_keys: Optional[List[CompositeForeignKey]] = None
        self.__access_defining_columns: Optional[List[str]] = None

    @property
    def __tables(self) -> Dict[str, Table]:
        """Return the name -> Table mapping, building it on first access.

        Every name yielded by ``_table_names()`` is logged, checked to be
        a ``str`` and resolved through ``_table()``.  The mapping is only
        stored once all tables have been resolved.
        """
        if self.___tables is not None:
            return self.___tables
        tables: Dict[str, Table] = {}
        for name in self._table_names():
            slog(WARNING, f'Caching metadata for table "{name}"')
            assert isinstance(name, str)
            tables[name] = self._table(name)
        self.___tables = tables
        return tables

    # ------ API to be implemented

    @abc.abstractmethod
    def _table_names(self) -> Iterable[str]:
        """Yield the names of all tables in the schema."""
        # Reached only via an explicit super() call from a subclass.
        # NOTE(review): throw() presumably raises, making the return a
        # fallback -- confirm against the log module.
        throw(ERR, "Called pure virtual base class method")
        return []

    @abc.abstractmethod
    def _table(self, name: str) -> Table:
        """Return the Table metadata object for the table called ``name``."""
        # Reached only via an explicit super() call from a subclass.
        throw(ERR, "Called pure virtual base class method")
        return None # type: ignore

    @abc.abstractmethod
    def _foreign_keys(self) -> List[CompositeForeignKey]:
        """Return the foreign key constraints of the schema."""
        pass

    @abc.abstractmethod
    def _access_defining_columns(self):
        """Return the access-defining columns of the schema.

        The result is stored in an attribute annotated
        ``Optional[List[str]]``, so a list of strings is presumably
        expected -- confirm against the implementations.
        """
        pass

    # ------ API to be called

    @property
    def table_names(self) -> Iterable[str]:
        """Names of all cached tables (a dict keys view)."""
        return self.__tables.keys()

    @property
    def tables(self) -> Iterable[Table]:
        """All cached Table objects (a dict values view)."""
        return self.__tables.values()

    @property
    def access_defining_columns(self):
        """Result of ``_access_defining_columns()``, fetched on first use.

        A hook result of ``None`` is not distinguishable from "not yet
        fetched", so in that case the hook is called again on every
        access.
        """
        if self.__access_defining_columns is None:
            self.__access_defining_columns = self._access_defining_columns()
        return self.__access_defining_columns

    @property
    def foreign_key_constraints(self) -> List[CompositeForeignKey]:
        """Result of ``_foreign_keys()``, fetched on first use.

        A hook result of ``None`` is not distinguishable from "not yet
        fetched", so in that case the hook is called again on every
        access.
        """
        if self.__foreign_keys is None:
            self.__foreign_keys = self._foreign_keys()
        return self.__foreign_keys

    def table(self, name: str) -> Table:
        """Return the table called ``name``.

        Raises:
            KeyError: if no table of that name is cached.
        """
        return self.__tables[name]

    def table_by_model_name(self, name: str, throw=False) -> Table:
        """Return the first table whose ``model_name`` equals ``name``.

        Args:
            name: Model name to look for.
            throw: If true, raise when nothing matches instead of
                returning ``None``.  Within this method the parameter
                shadows the ``throw()`` helper used elsewhere in the
                class.

        Returns:
            The matching table, or ``None`` when nothing matches and
            ``throw`` is false (despite the ``Table`` annotation).

        Raises:
            Exception: if nothing matches and ``throw`` is true.
        """
        for table in self.__tables.values():
            if table.model_name == name:
                return table
        if throw:
            raise Exception(f'Table "{name}" not found in database metadata')
        return None # type: ignore

    def primary_keys(self, table_name: str) -> Iterable[str]:
        """Return the ``primary_keys`` of the table called ``table_name``.

        Raises:
            KeyError: if no table of that name is cached.
        """
        return self.__tables[table_name].primary_keys

    def columns(self, table_name: str) -> Iterable[Column]:
        """Return the ``columns`` of the table called ``table_name``.

        Raises:
            KeyError: if no table of that name is cached.
        """
        return self.__tables[table_name].columns