ItemKey
- class seli.ItemKey(key: str | int)
Bases: Module
Key for accessing items with the [] operator: dictionary items by key, or sequence items by index.
- key
The key which describes the position of the item in the dictionary or list.
- Type:
str | int
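The lookup semantics described above can be illustrated with a minimal stand-in. This is a hypothetical sketch, not seli's implementation: `ItemKeySketch` and `get_path` are illustrative names, and the in-place behavior of `set` is an assumption (seli's `set` may instead return a new object). The string format parsed by `from_str` is not documented here, so it is left out.

```python
from dataclasses import dataclass
from functools import reduce
from typing import Any, Sequence, Union


@dataclass(frozen=True)
class ItemKeySketch:
    """Hypothetical stand-in for seli.ItemKey: locates a value via obj[key]."""

    key: Union[str, int]

    def get(self, obj: Any) -> Any:
        # Dictionary lookup by key, or sequence lookup by index.
        return obj[self.key]

    def set(self, obj: Any, value: Any) -> None:
        # Assumption: mutates obj in place; seli's set may behave differently.
        obj[self.key] = value


def get_path(obj: Any, path: Sequence[ItemKeySketch]) -> Any:
    # Hypothetical helper: follow a chain of item keys from the root object.
    return reduce(lambda current, k: k.get(current), path, obj)
```

For example, `get_path({"layers": [10, 20]}, [ItemKeySketch("layers"), ItemKeySketch(1)])` walks into the dictionary by key and then into the list by index.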
Methods Summary
- from_str(s)
- get(obj)
- set(obj, value)
- set_rngs(key_or_seed[, collection]): Set the state of the random number generator(s) for the module.
- tree_flatten(): Flatten the module into a list of arrays and a tuple for reconstructing the original module.
- tree_unflatten(aux_data, arrs_vals): Reconstruct the module from the outputs produced by Module.tree_flatten.
Methods Documentation
- set_rngs(key_or_seed: Key[Array, ''] | UInt32[Array, '2'] | int, collection: list[str] | None = None) → Self
Set the state of the random number generator(s) for the module.
- Parameters:
key_or_seed (PRNGKeyArray | int) – The random number generator key or seed.
collection (list[str] | None, optional) – The collection of random number generators to set. If None, all random number generators will be set.
- Returns:
The module with the updated random number generator state.
- Return type:
Self
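The signature accepts three forms for `key_or_seed`: a new-style typed JAX key (`Key[Array, '']`), legacy raw key data (`UInt32[Array, '2']`, as returned by `jax.random.PRNGKey`), or an integer seed. The sketch below shows how these forms relate in JAX by normalizing each to a typed key. `normalize_key` is a hypothetical helper, not part of seli, and seli's own handling of these inputs may differ.

```python
import jax


def normalize_key(key_or_seed):
    """Hypothetical helper: convert any accepted key_or_seed form to a typed PRNG key."""
    if isinstance(key_or_seed, int):
        # Integer seed: build a new-style typed key from it.
        return jax.random.key(key_or_seed)
    if jax.dtypes.issubdtype(key_or_seed.dtype, jax.dtypes.prng_key):
        # Already a typed key (Key[Array, '']): use it unchanged.
        return key_or_seed
    # Otherwise assume legacy raw key data (UInt32[Array, '2']) and wrap it.
    return jax.random.wrap_key_data(key_or_seed)
```

With JAX's default PRNG implementation, `normalize_key(0)`, `normalize_key(jax.random.PRNGKey(0))`, and `normalize_key(jax.random.key(0))` all yield keys with the same underlying key data.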
- tree_flatten() → tuple[list[Array], tuple[list[PathKey], None | bool | int | float | str | type | Array | ShapeDtypeStruct | list | dict | Module | Any]]
Flatten the module into a list of arrays and a tuple for reconstructing the original module. The tuple contains the path keys to the arrays and a copy of the original module without the arrays. This method is required for compatibility with the JAX PyTree API.
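The flatten/unflatten contract follows JAX's custom PyTree protocol: `tree_flatten` returns the array leaves plus auxiliary data, and `tree_unflatten(aux_data, arrs_vals)` rebuilds the object from them. The sketch below uses a hypothetical `Params` container (not `seli.Module`) whose auxiliary data holds the keys locating each array and the non-array state, loosely mirroring the path keys and array-free copy described above.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Params:
    """Hypothetical container, not seli.Module: arrays in a dict plus static metadata."""

    def __init__(self, arrays, name):
        self.arrays = arrays  # dict[str, jax.Array]
        self.name = name      # non-array state, carried in aux data

    def tree_flatten(self):
        # Leaves: the arrays, in a fixed (sorted) key order.
        keys = tuple(sorted(self.arrays))
        leaves = [self.arrays[k] for k in keys]
        # Aux data: keys locating each leaf plus the non-array state (kept hashable).
        return leaves, (keys, self.name)

    @classmethod
    def tree_unflatten(cls, aux_data, arrs_vals):
        keys, name = aux_data
        return cls(dict(zip(keys, arrs_vals)), name)
```

Once registered, JAX transformations treat the container like any PyTree: `jax.tree_util.tree_map(lambda x: 2 * x, p)` doubles every array while preserving `name`, and `jax.jit` accepts and returns `Params` instances directly.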