ArrayPlaceholder
- class seli.ArrayPlaceholder(index: int)
Bases: Module
Placeholder for an array that will be serialized and deserialized later.
- index
The index of the array in the list of arrays that will be serialized.
- Type:
int
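The placeholder idea can be illustrated with a minimal pure-Python sketch. It assumes the object being serialized is a nested structure of dicts and lists, and it uses the standard-library `array.array` as a stand-in for JAX arrays. The `Placeholder` class and the `extract_arrays`/`restore_arrays` helpers are hypothetical and are not part of seli; seli's own serialization may work differently.

```python
from array import array
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Placeholder:
    """Hypothetical stand-in for seli.ArrayPlaceholder: remembers only an index."""
    index: int


def extract_arrays(obj: Any, arrays: list) -> Any:
    """Replace every array in a nested dict/list with a Placeholder.

    Arrays are appended to `arrays` in traversal order, so each
    placeholder's index is its array's position in that list.
    """
    if isinstance(obj, array):
        arrays.append(obj)
        return Placeholder(len(arrays) - 1)
    if isinstance(obj, dict):
        return {k: extract_arrays(v, arrays) for k, v in obj.items()}
    if isinstance(obj, list):
        return [extract_arrays(v, arrays) for v in obj]
    return obj


def restore_arrays(obj: Any, arrays: list) -> Any:
    """Inverse of extract_arrays: swap each Placeholder back for its array."""
    if isinstance(obj, Placeholder):
        return arrays[obj.index]
    if isinstance(obj, dict):
        return {k: restore_arrays(v, arrays) for k, v in obj.items()}
    if isinstance(obj, list):
        return [restore_arrays(v, arrays) for v in obj]
    return obj


# Usage: the skeleton could be stored as metadata and the arrays in a binary format.
state = {"w": array("f", [1.0, 2.0]), "layers": [array("f", [3.0]), "relu"]}
arrays: list = []
skeleton = extract_arrays(state, arrays)
print(skeleton)  # {'w': Placeholder(index=0), 'layers': [Placeholder(index=1), 'relu']}
assert restore_arrays(skeleton, arrays) == state
```

Storing only an integer index keeps the skeleton small and independent of the array data, so the two parts can be written with different serializers.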
Methods Summary
- set_rngs(key_or_seed[, collection]): Set the state of the random number generator(s) for the module.
- tree_flatten(): Flatten the module into a list of arrays and a tuple for reconstructing the original module.
- tree_unflatten(aux_data, arrs_vals): Reconstruct the module from the outputs produced by Module.tree_flatten.
Methods Documentation
- set_rngs(key_or_seed: Key[Array, ''] | UInt32[Array, '2'] | int, collection: list[str] | None = None) → Self
Set the state of the random number generator(s) for the module.
- Parameters:
key_or_seed (PRNGKeyArray | int) – The random number generator key or seed.
collection (list[str] | None, optional) – The collection of random number generators to set. If None, all random number generators will be set.
- Returns:
The module with the updated random number generator state.
- Return type:
Self
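The documented meaning of `collection` (update a subset of generators, or all of them when `None`) can be modeled with a toy class. This is a hypothetical sketch: `ToyModule`, its `rngs` dict, and the collection names are invented for illustration. It accepts only integer seeds (not PRNG keys), seeds every selected generator identically, and mutates the object in place before returning `self`. Whether seli's `set_rngs` copies the module or derives distinct keys per generator is not stated here.

```python
from __future__ import annotations

import random
from typing import Optional


class ToyModule:
    """Hypothetical module that keeps one RNG per named collection."""

    def __init__(self, collections: list[str]) -> None:
        # Every generator starts from an arbitrary default seed of 0.
        self.rngs = {name: random.Random(0) for name in collections}

    def set_rngs(self, seed: int, collection: Optional[list[str]] = None) -> ToyModule:
        # None selects every generator, mirroring the documented default.
        names = list(self.rngs) if collection is None else collection
        for name in names:
            self.rngs[name] = random.Random(seed)
        # Returning self allows chaining, like the documented `Self` return type.
        return self


# Usage: reseed only the "dropout" generator and leave "params" untouched.
m = ToyModule(["dropout", "params"])
params_rng = m.rngs["params"]
m.set_rngs(42, collection=["dropout"])
assert m.rngs["params"] is params_rng
assert m.rngs["dropout"].random() == random.Random(42).random()
```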
- tree_flatten() → tuple[list[Array], tuple[list[PathKey], None | bool | int | float | str | type | Array | ShapeDtypeStruct | list | dict | Module | Any]]
Flatten the module into a list of arrays and a tuple for reconstructing the original module. The tuple contains the path keys to the arrays and a copy of the original module without the arrays. This method is required for compatibility with the JAX PyTree API.
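JAX's `jax.tree_util.register_pytree_node` expects a flatten function that returns `(children, aux_data)` and an unflatten function called as `unflatten(aux_data, children)`. This has the same shape as `tree_flatten` and `tree_unflatten(aux_data, arrs_vals)` here. The following pure-Python sketch shows that contract under stated assumptions. It handles nested dicts and lists, represents path keys as plain tuples of keys and indices (not JAX `PathKey` objects), uses `array.array` as a stand-in for JAX arrays, and does not register anything with JAX. `toy_tree_flatten` and `toy_tree_unflatten` are hypothetical names, not seli functions.

```python
from __future__ import annotations

import copy
from array import array
from typing import Any

Path = tuple  # hypothetical: a path is a tuple of dict keys and list indices


def toy_tree_flatten(obj: Any) -> tuple[list[array], tuple[list[Path], Any]]:
    """Return (leaves, (paths, skeleton)); the skeleton holds None where arrays were."""
    leaves: list[array] = []
    paths: list[Path] = []

    def strip(node: Any, path: Path) -> Any:
        if isinstance(node, array):
            leaves.append(node)
            paths.append(path)
            return None  # the array is removed from the skeleton
        if isinstance(node, dict):
            return {k: strip(v, path + (k,)) for k, v in node.items()}
        if isinstance(node, list):
            return [strip(v, path + (i,)) for i, v in enumerate(node)]
        return copy.deepcopy(node)  # non-array fields are copied

    skeleton = strip(obj, ())
    return leaves, (paths, skeleton)


def toy_tree_unflatten(aux_data: tuple[list[Path], Any], leaves: list[array]) -> Any:
    """Write each leaf back into a copy of the skeleton at its recorded path."""
    paths, skeleton = aux_data
    rebuilt = copy.deepcopy(skeleton)
    for path, leaf in zip(paths, leaves):
        if not path:  # the whole tree was a single array
            return leaf
        node = rebuilt
        for key in path[:-1]:
            node = node[key]
        node[path[-1]] = leaf
    return rebuilt


# Usage: flatten a small "module", then rebuild it from the two outputs.
module = {"linear": {"w": array("f", [0.5, -1.0]), "b": array("f", [0.0])}, "act": "relu"}
leaves, aux = toy_tree_flatten(module)
print(aux[0])  # [('linear', 'w'), ('linear', 'b')]
assert toy_tree_unflatten(aux, leaves) == module
```

Because the skeleton is deep-copied, later changes to the original's non-array fields do not leak into the reconstructed object.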