Module
- class seli.Module
Bases: ModuleBase
Base class for all modules. Modules can be used to implement parameterized functions such as neural networks.
Modules are PyTrees, which means they can be flattened and unflattened with JAX's tree_util functions.
Flattening automatically traverses all attributes, including those stored in slots. Submodules, dictionaries, lists, and arrays are flattened recursively as well.
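To make "Modules are PyTrees" concrete, here is a hand-written sketch of what PyTree registration involves, using JAX's public `jax.tree_util.register_pytree_node_class` decorator. The `Linear` class is hypothetical and is not how seli implements `Module`; it only shows that once an object is a registered PyTree, `tree_leaves`, `tree_map`, and `jax.jit` can see through it.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Linear:
    """Hypothetical, hand-registered PyTree (not seli's implementation)."""

    def __init__(self, weight, bias, name="linear"):
        self.weight = weight
        self.bias = bias
        self.name = name  # static metadata, carried as auxiliary data, not a leaf

    def tree_flatten(self):
        # JAX protocol: return (children, aux_data).
        return (self.weight, self.bias), self.name

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        weight, bias = children
        return cls(weight, bias, name=aux_data)

    def __call__(self, x):
        return x @ self.weight + self.bias


layer = Linear(jnp.ones((3, 2)), jnp.zeros(2))
leaves = jax.tree_util.tree_leaves(layer)                  # [weight, bias]
doubled = jax.tree_util.tree_map(lambda a: 2 * a, layer)   # still a Linear
y = jax.jit(lambda m, x: m(x))(layer, jnp.ones((4, 3)))    # module passed straight into jit
```

Here the string `name` travels as auxiliary data, so `jax.jit` treats it as static, while `weight` and `bias` are traced leaves. seli.Module is described as discovering its leaves from attributes automatically, which this sketch does by hand.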
If the module and its children do not contain arrays, the module supports hashing and equality checking. These checks even respect the structure of shared references.
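One way to make hashing and equality sensitive to shared references is to compare a canonical description of the object graph in which any second visit to the same object becomes a back-reference. The `structure_key` helper and the `Node`, `Leaf`, and `Pair` classes below are hypothetical, a sketch of the idea rather than seli's implementation; it assumes plain attribute dicts (no slots) and compares dict entries in insertion order.

```python
ATOMS = (int, float, str, bool, type(None))


def structure_key(obj, seen=None):
    """Hashable description of obj; repeated visits to one object become ("ref", index)."""
    if seen is None:
        seen = {}
    if isinstance(obj, ATOMS):
        return obj
    if isinstance(obj, tuple):
        # Tuples are immutable, so sharing them is irrelevant; recurse without recording identity.
        return ("tuple", tuple(structure_key(v, seen) for v in obj))
    if id(obj) in seen:
        return ("ref", seen[id(obj)])
    seen[id(obj)] = len(seen)
    if isinstance(obj, dict):
        return ("dict", tuple((k, structure_key(v, seen)) for k, v in obj.items()))
    if isinstance(obj, list):
        return ("list", tuple(structure_key(v, seen) for v in obj))
    # Generic object: described by its type name and its attribute dict.
    return (type(obj).__name__, structure_key(vars(obj), seen))


class Node:
    """Hypothetical base class whose equality and hash follow structure_key."""

    def __eq__(self, other):
        return type(self) is type(other) and structure_key(self) == structure_key(other)

    def __hash__(self):
        return hash(structure_key(self))


class Leaf(Node):
    def __init__(self, value):
        self.value = value


class Pair(Node):
    def __init__(self, left, right):
        self.left = left
        self.right = right


shared = Leaf(1)
tied = Pair(shared, shared)        # one Leaf referenced twice
untied = Pair(Leaf(1), Leaf(1))    # two separate but equal Leafs
```

With this scheme `tied != untied` even though every attribute value matches, because only `tied` shares a child; two graphs with the same sharing pattern compare equal and hash identically.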
Methods Summary
set_rngs(key_or_seed[, collection]) – Set the state of the random number generator(s) for the module.
tree_flatten() – Flatten the module into a list of arrays and a tuple for reconstructing the original module.
tree_unflatten(aux_data, arrs_vals) – Reconstruct the module from the outputs produced by Module.tree_flatten.
Methods Documentation
- set_rngs(key_or_seed: Key[Array, ''] | UInt32[Array, '2'] | int, collection: list[str] | None = None) → Self
Set the state of the random number generator(s) for the module.
- Parameters:
key_or_seed (PRNGKeyArray | int) – The random number generator key or seed.
collection (list[str] | None, optional) – The collection of random number generators to set. If None, all random number generators will be set.
- Returns:
The module with the updated random number generator state.
- Return type:
Self
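The signature accepts a typed key (`Key[Array, '']`), a legacy raw key (`UInt32[Array, '2']`, as returned by `jax.random.PRNGKey`), or an integer seed. A minimal sketch of how such an argument could be normalized and distributed to named collections follows. `as_typed_key` and `RNGHolder` are hypothetical, and splitting the key once per collection is an assumption, not documented seli behavior.

```python
import jax
import jax.numpy as jnp


def as_typed_key(key_or_seed):
    """Normalize an int seed, a raw uint32[2] key, or a typed key into a typed key."""
    if isinstance(key_or_seed, int):
        return jax.random.key(key_or_seed)
    if jnp.issubdtype(key_or_seed.dtype, jax.dtypes.prng_key):
        return key_or_seed  # already a typed key
    # Legacy raw key data, e.g. from jax.random.PRNGKey: shape (2,), dtype uint32.
    return jax.random.wrap_key_data(key_or_seed)


class RNGHolder:
    """Hypothetical stand-in for a module holding named RNG collections."""

    def __init__(self, collections):
        self.rngs = {name: None for name in collections}

    def set_rngs(self, key_or_seed, collection=None):
        key = as_typed_key(key_or_seed)
        # None means "every collection", mirroring the documented parameter.
        names = list(self.rngs) if collection is None else list(collection)
        # Assumption: give each named collection its own independent subkey.
        subkeys = jax.random.split(key, len(names))
        for i, name in enumerate(names):
            self.rngs[name] = subkeys[i]
        return self  # returned for chaining, matching the documented Self return type


holder = RNGHolder(["dropout", "noise"]).set_rngs(0)
holder.set_rngs(1, collection=["noise"])  # reseeds only "noise"
```

Returning `self` allows calls to be chained; whether seli mutates in place or returns a copy is not stated here, so treat that detail as part of the sketch.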
- tree_flatten() → tuple[list[Array], tuple[list[PathKey], None | bool | int | float | str | type | Array | ShapeDtypeStruct | list | dict | Module | Any]]
Flatten the module into a list of arrays and a tuple for reconstructing the original module. The tuple contains the path keys to the arrays and a copy of the original module without the arrays. This method is required for compatibility with JAX's PyTree API.
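The flatten/unflatten contract described for tree_flatten (arrays out; path keys plus an array-free copy kept as auxiliary data) can be sketched for plain nested dicts and lists. `flatten_arrays` and `unflatten_arrays` are hypothetical helpers, not seli functions; unlike Module.tree_flatten, they do not walk object attributes or slots and do not preserve shared references.

```python
import copy

import jax
import jax.numpy as jnp


def _strip(node, path, arrays, paths):
    """Copy node with every jax.Array replaced by None, recording each array and its path."""
    if isinstance(node, jax.Array):
        arrays.append(node)
        paths.append(path)
        return None  # placeholder; the recorded path, not the placeholder, marks the slot
    if isinstance(node, dict):
        return {k: _strip(v, path + (k,), arrays, paths) for k, v in node.items()}
    if isinstance(node, list):
        return [_strip(v, path + (i,), arrays, paths) for i, v in enumerate(node)]
    return node  # non-array values (str, float, ...) stay in the skeleton


def flatten_arrays(tree):
    arrays, paths = [], []
    skeleton = _strip(tree, (), arrays, paths)
    return arrays, (paths, skeleton)


def unflatten_arrays(aux_data, arrays):
    paths, skeleton = aux_data
    rebuilt = copy.deepcopy(skeleton)  # the skeleton holds no arrays, so no array data is copied
    for path, array in zip(paths, arrays):
        if not path:  # the whole tree was a single array
            return array
        parent = rebuilt
        for key in path[:-1]:
            parent = parent[key]
        parent[path[-1]] = array
    return rebuilt


tree = {"w": jnp.ones(2), "layers": [{"b": jnp.zeros(3)}, "relu"], "lr": 0.1}
arrays, aux = flatten_arrays(tree)
rebuilt = unflatten_arrays(aux, [a + 1 for a in arrays])  # same structure, new array values
```

Keeping the skeleton free of arrays is what lets the auxiliary data be copied and compared cheaply, while the path keys record exactly where each array must be reinserted.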