# Source code for seli.core._module

"""
Modules are the core data structure for organising state. This module provides
the core functionality for creating, traversing, and modifying modules.
"""

import re
from collections.abc import Callable, Hashable, Sequence
from typing import Any, Self, TypeAlias

import jax
from jaxtyping import PRNGKeyArray

from seli.core._registry import REGISTRY_INVERSE, ModuleBase
from seli.core._typecheck import typecheck
from seli.core._utils import array_summary

# Public names of this module, i.e. what `from seli.core._module import *`
# exports. Helpers such as `node_repr`, `_keys_lt` and the `DeepType` alias are
# defined in this file but are intentionally not listed here.
__all__ = [
    "Module",
    "ItemKey",
    "AttrKey",
    "PathKey",
    "NodeType",
    "LeafType",
    "to_tree",
    "to_tree_inverse",
    "flat_path_dict",
    "dfs_map",
]


@typecheck
class Module(ModuleBase, name="builtin.Module"):
    """
    Base class for all modules.

    Modules can be used to implement parameterized functions like neural
    networks. Modules are PyTrees, which means they can be flattened and
    unflattened using JAX's tree_util functions.

    The flattening walks the module with `dfs_map`: attributes are taken
    from the instance `__dict__`, and submodules, dictionaries and lists are
    flattened recursively. Note that `dfs_map` raises a `TypeError` for
    modules that expose `__slots__`.

    If the module and its children do not contain arrays, the module supports
    hashing and equality checking. These checks even respect the structure of
    shared references, because they operate on `flat_path_dict`, which
    replaces repeated references by their path.
    """

    def __hash__(self):
        # hash the sorted (path, value) pairs of the flattened module
        flat = flat_path_dict(self)
        return hash(tuple(flat.items()))

    def __eq__(self, other):
        # A non-module can never have the same flat representation: the root
        # `__class__` entry always differs. Answer directly instead of
        # flattening `other`, which raises `ValueError` inside `dfs_map` for
        # objects of unsupported types.
        if not isinstance(other, Module):
            return False

        return flat_path_dict(self) == flat_path_dict(other)

    def tree_flatten(
        self,
    ) -> tuple[list[jax.Array], tuple[list["PathKey"], "NodeType"]]:
        """
        Flatten the module into a list of arrays and a tuple for
        reconstructing the original module. The tuple contains the path keys
        to the arrays and a copy of the original module without the arrays
        (each array is replaced by `None` in the copy).

        Only `jax.Array` instances are extracted as children.

        This function is needed to be compatible with the JAX PyTree API.
        """
        tree = to_tree(self)
        arrs: dict[PathKey, jax.Array] = {}

        def get_arrs(path: PathKey, obj: NodeType):
            # pull the array out of the structure and remember where it was
            if isinstance(obj, jax.Array):
                arrs[path] = obj
                return None

            return obj

        tree = dfs_map(tree, get_arrs)

        # keys and values are kept in the same (traversal) order so that
        # `tree_unflatten` can pair them up again
        arrs_keys = list(arrs.keys())
        arrs_vals = [arrs[key] for key in arrs_keys]

        return arrs_vals, (arrs_keys, tree)

    @classmethod
    def tree_unflatten(
        cls: type[Self],
        aux_data: tuple[list["PathKey"], "NodeType"],
        arrs_vals: Sequence[jax.Array | jax.ShapeDtypeStruct],
    ) -> Self:
        """
        Reconstruct the module from the outputs produced by
        `Module.tree_flatten`.

        Parameters
        ----------
        aux_data : tuple[list[PathKey], NodeType]
            The path keys of the arrays and the array-free tree, as returned
            by `tree_flatten`.
        arrs_vals : Sequence[jax.Array | jax.ShapeDtypeStruct]
            The values to place at the given paths, in the same order as the
            path keys.

        Returns
        -------
        Self
            The reconstructed module with shared references restored.
        """
        arrs_keys, tree = aux_data

        # restore shared/cyclic references first, then write the leaves back
        obj = to_tree_inverse(tree)

        for path, child in zip(arrs_keys, arrs_vals):
            path.set(obj, child)

        return obj

    def __repr__(self) -> str:
        return node_repr(self)

    def set_rngs(
        self,
        key_or_seed: PRNGKeyArray | int,
        collection: list[str] | None = None,
    ) -> Self:
        """
        Set the state of the random number generator(s) for the module.

        Parameters
        ----------
        key_or_seed : PRNGKeyArray | int
            The random number generator key or seed.
        collection : list[str] | None, optional
            The collection of random number generators to set. If `None`, all
            random number generators will be set.

        Returns
        -------
        Self
            The module with the updated random number generator state.
        """
        # imported lazily inside the method rather than at module level
        from seli.net._key import set_rngs

        return set_rngs(self, key_or_seed, collection)
# Values that `dfs_map` treats as leaves: they are returned as-is and never
# descended into.
LeafType: TypeAlias = (
    None | bool | int | float | str | type | jax.Array | jax.ShapeDtypeStruct
)
# Container types that `dfs_map` traverses recursively.
DeepType: TypeAlias = list | dict | Module
# Anything that may appear in a module tree. Because `Any` is part of the
# union, this alias documents intent rather than constraining the type.
NodeType: TypeAlias = LeafType | DeepType | Any
@typecheck
class ItemKey(Module, name="builtin.ItemKey"):
    """
    Key for accessing items using the [] operator.

    Used to access dictionary items by key or sequence items by index.

    Attributes
    ----------
    key : str | int
        The key which describes the position of the item in the dictionary or
        list.
    """

    key: str | int

    def __init__(self, key: str | int) -> None:
        self.key = key

    def get(self, obj: dict | list) -> Any:
        """Return `obj[self.key]`."""
        return obj[self.key]

    def set(self, obj: dict | list, value: Any) -> None:
        """Assign `value` to `obj[self.key]`."""
        obj[self.key] = value

    def __repr__(self):
        return f"[{self.key!r}]"

    # add sorting to allow deterministic traversal
    def __lt__(self, other: "ItemKey | AttrKey") -> bool:
        return _keys_lt(self, other)

    def __hash__(self) -> int:
        return hash((type(self), self.key))

    def __eq__(self, other: "ItemKey | AttrKey") -> bool:
        # Compare the exact type, mirroring `__hash__`. An `isinstance` check
        # would make `ItemKey("a") == AttrKey("a")` true (AttrKey subclasses
        # ItemKey) although both hash differently, which breaks the
        # requirement that equal objects have equal hashes.
        return type(other) is type(self) and self.key == other.key

    @classmethod
    def from_str(cls, s: str) -> "ItemKey":
        """
        Parse the textual form of an item key.

        Accepted forms are `[<digits>]` for integer keys and `['<text>']` for
        string keys.

        Raises
        ------
        ValueError
            If `s` is not enclosed in square brackets, or if its content is
            neither a single-quoted string nor a sequence of digits.
        """
        # both brackets are required, so reject the string if either one is
        # missing
        if not s.startswith("[") or not s.endswith("]"):
            raise ValueError(f"Invalid item key string: {s}")

        key = s[1:-1]

        # a quoted string needs an opening and a closing quote, a lone `'`
        # must not be read as the empty string
        if len(key) >= 2 and key.startswith("'") and key.endswith("'"):
            return cls(key[1:-1])

        if not key.isdigit():
            raise ValueError(f"Invalid item key string: `{key}`")

        return cls(int(key))
@typecheck
class AttrKey(ItemKey, name="builtin.AttrKey"):
    """
    Key that addresses an attribute of an object (dot access, `obj.attr`).

    Attributes
    ----------
    key : str
        Name of the attribute this key refers to.
    """

    key: str

    def __init__(self, key: str) -> None:
        self.key = key

    def get(self, obj: Any) -> Any:
        """Read the attribute named `self.key` from `obj`."""
        return getattr(obj, self.key)

    def set(self, obj: Any, value: Any) -> None:
        """Write `value` to the attribute named `self.key` of `obj`."""
        # goes through `object.__setattr__`, so a `__setattr__` defined on the
        # class of `obj` is not invoked
        object.__setattr__(obj, self.key, value)

    def __repr__(self):
        return ".{}".format(self.key)

    @classmethod
    def from_str(cls, s: str) -> "AttrKey":
        """
        Parse the textual form `.<identifier>` into an attribute key.

        Raises
        ------
        ValueError
            If `s` does not start with a dot or the text after the dot is not
            a valid Python identifier.
        """
        name = s[1:]

        # the empty name produced by "." is not an identifier either, so one
        # check covers both the missing-name and the invalid-name case
        if not s.startswith(".") or not name.isidentifier():
            raise ValueError(f"Invalid attribute key string: {s}")

        return cls(name)
@typecheck
def _keys_lt(a: ItemKey | AttrKey, b: ItemKey | AttrKey) -> bool:
    """
    Strict ordering of keys, used to make traversal deterministic.

    Keys of different classes are ordered by class, with `ItemKey` first.
    Keys of the same class whose `key` values have different types are
    ordered with integer values first. Otherwise the `key` values are
    compared directly.
    """
    kind_a, kind_b = type(a), type(b)
    if kind_a is not kind_b:
        return kind_a is ItemKey

    left, right = a.key, b.key
    if type(left) is type(right):
        return left < right

    return isinstance(left, int)
@typecheck
class PathKey(Module, name="builtin.PathKey"):
    """
    Ordered sequence of keys addressing a location in a nested structure.

    A path chains `ItemKey` and `AttrKey` objects; applying them one after
    another walks through nested objects, dictionaries and lists.

    Attributes
    ----------
    path : list[ItemKey | AttrKey]
        The keys to apply, outermost first.
    """

    path: list[ItemKey | AttrKey]

    def __init__(self, path: list[ItemKey | AttrKey]) -> None:
        self.path = path

    def __add__(self, item: ItemKey | AttrKey) -> "PathKey":
        # builds a new list, the key objects themselves are shared between
        # the old and the new path
        extended = self.path + [item]
        return PathKey(extended)

    def get(self, obj):
        """Follow every key of the path starting at `obj`."""
        current = obj
        for step in self.path:
            current = step.get(current)

        return current

    def set(self, obj: DeepType, value: NodeType):
        """
        Store `value` at the location described by the path inside `obj`.

        An empty path has no parent to assign into, the call is a no-op then.
        """
        if not self.path:
            return

        *parent_steps, last_step = self.path

        # walk down to the container that holds the final location
        target = obj
        for step in parent_steps:
            target = step.get(target)

        last_step.set(target, value)

    def __repr__(self):
        return "".join(map(repr, self.path))

    # add sorting to allow deterministic traversal
    def __lt__(self, other):
        return tuple(self.path) < tuple(other.path)

    def __hash__(self):
        return hash((type(self), tuple(self.path)))

    def __eq__(self, other):
        if not isinstance(other, PathKey):
            return False

        return self.path == other.path

    @classmethod
    def from_str(cls, s: str) -> "PathKey":
        """
        Parse a path such as `.layers[0]['bias']` into a `PathKey`.

        Raises
        ------
        ValueError
            If a segment starts with neither `.` nor `[`, or if one of the
            segments is rejected by `AttrKey.from_str` / `ItemKey.from_str`.
        """
        # NOTE(review): the split happens before every "." and "[", so a
        # quoted item key that itself contains one of these characters is cut
        # into pieces and rejected.
        keys = []

        for part in re.split(r"(?=[.\[])", s):
            # regular expression might produce empty string at the start or
            # end
            if not part:
                continue

            if part.startswith("."):
                keys.append(AttrKey.from_str(part))
            elif part.startswith("["):
                keys.append(ItemKey.from_str(part))
            else:
                raise ValueError(f"Invalid path key string: {s}")

        return cls(keys)

    def __getitem__(self, item: int | slice) -> "ItemKey | AttrKey | PathKey":
        selected = self.path[item]

        # a slice yields a sub-path, an index yields the key itself
        if isinstance(item, slice):
            return type(self)(selected)

        return selected
def dfs_map(
    obj: NodeType,
    fun: Callable[[PathKey, NodeType], NodeType] = lambda _, x: x,
    *,
    refs: dict[int, NodeType] | None = None,
    path: PathKey | None = None,
    refs_fun: Callable[[PathKey, NodeType], NodeType] | None = None,
) -> DeepType | LeafType:
    """
    Performs a depth-first traversal of a nested data structure, applying a
    transformation function to each element.

    This function traverses dictionaries, lists, and Module objects
    recursively in a depth-first manner, applying the provided transformation
    function to each element. It builds a new structure with the same shape as
    the original, but with transformed values. During traversal, it tracks the
    path to each element and handles repeated and circular references to
    prevent infinite recursion.

    Parameters
    ----------
    obj : NodeType
        The object to traverse, which can be a dictionary, list, Module, or a
        leaf value.
    fun : Callable[[PathKey, NodeType], NodeType]
        A transformation function to apply to each element in the structure.
        It is called as `fun(path, x)` where `path` is the PathKey of the
        current element and `x` is the element itself, and returns the
        transformed element. The traversal continues into the returned value.
    refs : dict[int, NodeType] | None, optional
        A dictionary mapping `id(obj)` of already visited objects to the value
        `fun` returned for them. Used to detect repeated references. A
        dictionary passed by the caller is filled in place. Default is `None`
        (a new dict is created).
    path : PathKey | None, optional
        The path of `obj` within the overall structure. Used for tracking the
        position during recursive calls. Default is `None` (an empty PathKey
        is created).
    refs_fun : Callable[[PathKey, NodeType], NodeType] | None, optional
        A function to handle repeated references. When a non-leaf object is
        encountered again: if `refs_fun` is `None`, the value stored in `refs`
        is returned directly, otherwise `refs_fun(path, stored_value)` is
        returned. Default is `None`.

    Returns
    -------
    DeepType | LeafType
        A new structure with the same shape as the input, but with all
        elements transformed according to the provided function.

    Raises
    ------
    ValueError
        If an object of an unsupported type is encountered. Supported types
        are: dictionaries, lists, Module objects, leaf values, and hashable
        objects contained in `REGISTRY_INVERSE`.
    TypeError
        If a dictionary with non-string keys is encountered, or if a Module
        exposes `__slots__`.

    Notes
    -----
    - Dictionary keys and Module attributes are processed in sorted order for
      deterministic traversal.
    - Module objects are created using `object.__new__` without calling
      `__init__`, which may bypass important initialization logic.
    - The path uses ItemKey for dictionary keys and list indices and AttrKey
      for Module attributes.
    """
    # Compare against None explicitly. `refs or {}` would throw away an empty
    # dict supplied by the caller, so the caller would never see the recorded
    # references.
    if path is None:
        path = PathKey([])

    if refs is None:
        refs = {}

    # repeated reference to a non-leaf object: do not descend again
    if id(obj) in refs and not isinstance(obj, LeafType):
        if refs_fun is None:
            return refs[id(obj)]

        return refs_fun(path, refs[id(obj)])

    obj_fun = fun(path, obj)
    # register before descending so that cycles back to `obj` are detected
    refs[id(obj)] = obj_fun

    if isinstance(obj_fun, LeafType):
        return obj_fun

    # if object is registered it is also a valid type, since we can convert it
    # to a string and back, we need to test for hashability and non-module
    # otherwise we cannot perform the isin check, the obj_fun may not be a
    # Module, since Module.__hash__ would be a RecursionError.
    if isinstance(obj_fun, Hashable) and not isinstance(obj_fun, Module):
        if obj_fun in REGISTRY_INVERSE:
            return obj_fun

    if isinstance(obj_fun, dict):
        if not all(isinstance(key, str) for key in obj_fun):
            error = f"Dictionary keys must be strings got {obj_fun.keys()}"
            raise TypeError(error)

        obj_new = {}
        for key, value in sorted(obj_fun.items(), key=lambda x: x[0]):
            obj_new[key] = dfs_map(
                value,
                fun,
                path=path + ItemKey(key),
                refs=refs,
                refs_fun=refs_fun,
            )

        return obj_new

    if isinstance(obj_fun, list):
        return [
            dfs_map(
                value,
                fun,
                path=path + ItemKey(i),
                refs=refs,
                refs_fun=refs_fun,
            )
            for i, value in enumerate(obj_fun)
        ]

    if isinstance(obj_fun, Module):
        keys = []

        if hasattr(obj_fun, "__dict__"):
            keys.extend(obj_fun.__dict__.keys())

        if hasattr(obj_fun, "__slots__"):
            error = f"Module {obj_fun} has __slots__, which is not supported"
            raise TypeError(error)

        # build the copy without running __init__, attributes are filled in
        # one by one below
        obj_new = object.__new__(type(obj_fun))

        for key in sorted(keys):
            value = getattr(obj_fun, key)
            setattr(
                obj_new,
                key,
                dfs_map(
                    value,
                    fun,
                    path=path + AttrKey(key),
                    refs=refs,
                    refs_fun=refs_fun,
                ),
            )

        return obj_new

    raise ValueError(f"Unknown object type: {type(obj_fun)}, {obj_fun}")
def to_tree(obj: NodeType):
    """
    Turn a structure with shared or cyclic references into a tree.

    The structure is copied with `dfs_map`. The first time a non-leaf object
    is visited it is kept in place; every further reference to the same
    object is replaced by the `PathKey` under which it was first visited.
    The result therefore contains every non-leaf object only once.

    Parameters
    ----------
    obj : NodeType
        The object to convert: a dictionary, list, Module, or leaf value.

    Returns
    -------
    NodeType
        A copy of the input in which repeated references are replaced by
        path references.

    Notes
    -----
    - `to_tree_inverse` resolves the path references again.
    """
    first_seen: dict[int, PathKey] = {}

    def remember(path: PathKey, node: NodeType):
        # called for every visited element, records where it lives
        first_seen[id(node)] = path
        return node

    def as_reference(_: PathKey, node: NodeType):
        # called for repeated references, `node` is the object that was
        # recorded on its first visit
        return first_seen[id(node)]

    return dfs_map(obj, remember, refs_fun=as_reference)
def to_tree_inverse(obj: NodeType):
    """
    Rebuild the structure with shared references from a tree made by
    `to_tree`.

    The tree is copied with `dfs_map` while the location of every `PathKey`
    found in it is collected. Afterwards each of these locations is
    overwritten with the object its `PathKey` points to inside the copy,
    which restores shared and cyclic references.

    Parameters
    ----------
    obj : NodeType
        A tree, typically produced by `to_tree`, in which repeated references
        are represented by `PathKey` objects.

    Returns
    -------
    NodeType
        The copy of the tree with all path references replaced by the
        referenced objects.
    """
    pending: dict[PathKey, PathKey] = {}

    def collect(path: PathKey, node: NodeType):
        if isinstance(node, PathKey):
            pending[path] = node

        return node

    # `collect` is also used for repeated references, so that a PathKey
    # instance occurring at several locations is recorded for each of them
    rebuilt = dfs_map(obj, collect, refs_fun=collect)

    for location, target in pending.items():
        location.set(rebuilt, target.get(rebuilt))

    return rebuilt
def flat_path_dict(obj: NodeType):
    """
    Flatten a nested object into a dictionary keyed by path.

    The object is first converted with `to_tree`. Then every element of the
    tree contributes one entry:

    - leaf values and `PathKey` references are stored under their own path,
    - any other node stores its type under the path extended by the
      attribute key `__class__`.

    Parameters
    ----------
    obj : NodeType
        The object to flatten.

    Returns
    -------
    dict[PathKey, NodeType]
        The entries ordered by their path, which makes the output
        deterministic.
    """
    entries: dict[PathKey, NodeType] = {}

    def record(path, node: NodeType):
        if not isinstance(node, (LeafType, PathKey)):
            entries[path + AttrKey("__class__")] = type(node)
        else:
            entries[path] = node

        return node

    # only the side effect of `record` matters, the copy is discarded
    dfs_map(to_tree(obj), record)

    # sort dict by keys for deterministic output
    return {key: entries[key] for key in sorted(entries)}
def node_repr(obj: NodeType, /, indent: str = " " * 4) -> str:
    """
    Generate a structured, readable string representation of nested objects.

    The object is converted with `to_tree` once, so that shared or cyclic
    references show up as path references relative to `obj`, and the
    resulting tree is then formatted recursively:

    - JAX arrays: summarized using `array_summary`
    - PathKey objects: displayed as `<obj...>` with the path they point to
    - Lists: each item on a new indented line
    - Dictionaries: key-value pairs on indented lines, sorted by key
    - Module objects: class name and attribute values, sorted by name
    - Other types: their native `repr()`

    Parameters
    ----------
    obj : NodeType
        The object to represent as a string.
    indent : str, default four spaces
        The indentation string used for each nesting level.

    Returns
    -------
    str
        A formatted string representation of the object.
    """
    # Convert exactly once. Running `to_tree` again on nested values treats
    # the PathKey objects of the tree as ordinary modules; since paths share
    # their key objects, a subtree holding two references got the second one
    # rewritten into a meaningless path. It also made the function quadratic.
    return _tree_repr(to_tree(obj), indent)


def _tree_repr(node: NodeType, indent: str) -> str:
    # Format a node of an already converted tree, see `node_repr`.
    def nested(child: NodeType) -> str:
        # continuation lines of a child are shifted by one indent level
        return _tree_repr(child, indent).replace("\n", "\n" + indent)

    if isinstance(node, jax.Array):
        return array_summary(node)

    # must be tested before Module, PathKey is a Module itself
    if isinstance(node, PathKey):
        return f"<obj{node!r}>"

    if isinstance(node, list):
        if not node:
            return "[]"

        rows = [f"{indent}{nested(item)},\n" for item in node]
        return "[\n" + "".join(rows) + "]"

    if isinstance(node, dict):
        if not node:
            return "{}"

        rows = [
            f"{indent}{key!r}: {nested(value)},\n"
            for key, value in sorted(node.items(), key=lambda x: x[0])
        ]
        return "{\n" + "".join(rows) + "}"

    if isinstance(node, Module):
        name = node.__class__.__name__
        keys = sorted(node.__dict__) if hasattr(node, "__dict__") else []

        if not keys:
            return f"{name}()"

        rows = [
            f"{indent}{key}={nested(getattr(node, key))},\n" for key in keys
        ]
        return f"{name}(\n" + "".join(rows) + ")"

    return repr(node)