AttrKey

class seli.AttrKey(key: str)

Bases: ItemKey

Key for accessing an object's attributes with the dot operator (obj.attr).

key

The name of the attribute to access.

Type:

str
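The get/set pair behaves like attribute access on the target object. The sketch below shows this with a hypothetical stand-in class, AttrKeySketch. It assumes get and set map directly onto Python's getattr and setattr. It is not seli's implementation, and it does not cover from_str, whose expected string format is not documented here.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class AttrKeySketch:
    """Hypothetical stand-in mirroring the documented AttrKey(key) interface."""

    key: str  # name of the attribute to access

    def get(self, obj: Any) -> Any:
        # Equivalent to evaluating `obj.<key>`
        return getattr(obj, self.key)

    def set(self, obj: Any, value: Any) -> None:
        # Equivalent to the assignment `obj.<key> = value`
        setattr(obj, self.key, value)


class Point:
    def __init__(self, x: float, y: float) -> None:
        self.x = x
        self.y = y


p = Point(1.0, 2.0)
k = AttrKeySketch("x")
print(k.get(p))  # 1.0
k.set(p, 5.0)
print(p.x)  # 5.0
```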

Methods Summary

from_str(s)

get(obj)

set(obj, value)

set_rngs(key_or_seed[, collection])

Set the state of the random number generator(s) for the module.

tree_flatten()

Flatten the module into a list of arrays and a tuple for reconstructing the original module.

tree_unflatten(aux_data, arrs_vals)

Reconstruct the module from the outputs produced by Module.tree_flatten.

Methods Documentation

classmethod from_str(s: str) → AttrKey
get(obj: Any) → Any
set(obj: Any, value: Any) → None
set_rngs(key_or_seed: Key[Array, ''] | UInt32[Array, '2'] | int, collection: list[str] | None = None) → Self

Set the state of the random number generator(s) for the module.

Parameters:
  • key_or_seed (PRNGKeyArray | int) – The random number generator key or seed.

  • collection (list[str] | None, optional) – The collection of random number generators to set. If None, all random number generators will be set.

Returns:

The module with the updated random number generator state.

Return type:

Self
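The collection semantics (None means every generator) can be illustrated with a stdlib-only sketch. Everything in it is hypothetical: RngHolderSketch, its rngs dictionary, and the per-collection seed derivation are assumptions, not seli's API. The sketch also accepts only integer seeds, whereas the documented key_or_seed additionally accepts a JAX PRNG key. It updates the generators in place and returns self. The documentation does not say whether seli mutates in place or returns a modified copy.

```python
import random
from typing import Optional


class RngHolderSketch:
    """Hypothetical module holding one generator per named collection."""

    def __init__(self, collections: list[str]) -> None:
        self.rngs = {name: random.Random() for name in collections}

    def set_rngs(
        self, seed: int, collection: Optional[list[str]] = None
    ) -> "RngHolderSketch":
        # None selects every collection, mirroring the documented default
        names = list(self.rngs) if collection is None else collection
        for name in names:
            # Assumption: a distinct, reproducible seed per collection;
            # the real derivation in seli is not documented here
            self.rngs[name] = random.Random(f"{seed}:{name}")
        return self  # returned so calls can be chained


m = RngHolderSketch(["dropout", "params"])
m.set_rngs(0, ["dropout"])  # reseeds only the "dropout" generator
first = m.rngs["dropout"].random()
m.set_rngs(0)  # reseeds every generator
print(m.rngs["dropout"].random() == first)  # True: same seed, same stream
```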

tree_flatten() → tuple[list[Array], tuple[list[PathKey], None | bool | int | float | str | type | Array | ShapeDtypeStruct | list | dict | Module | Any]]

Flatten the module into a list of arrays and a tuple for reconstructing the original module. The tuple contains the path keys to the arrays and a copy of the original module without the arrays. This method is required for compatibility with the JAX PyTree API.

classmethod tree_unflatten(aux_data: tuple[list[PathKey], None | bool | int | float | str | type | Array | ShapeDtypeStruct | list | dict | Module | Any], arrs_vals: Sequence[Array | ShapeDtypeStruct]) → Self

Reconstruct the module from the outputs produced by Module.tree_flatten.
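The flatten/unflatten round trip can be sketched with attribute keys as the path keys. In JAX, a class exposing these two methods can be registered with jax.tree_util.register_pytree_node_class. This sketch avoids JAX, so plain floats stand in for arrays. AttrKeySketch and ModuleSketch are hypothetical. Each path here is a single attribute step, whereas seli's list[PathKey] paths may span several steps into nested modules.

```python
import copy
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class AttrKeySketch:
    """Hypothetical stand-in for an attribute path key."""

    key: str

    def get(self, obj: Any) -> Any:
        return getattr(obj, self.key)

    def set(self, obj: Any, value: Any) -> None:
        setattr(obj, self.key, value)


class ModuleSketch:
    """Hypothetical module; float attributes stand in for JAX arrays."""

    def __init__(self, **attrs: Any) -> None:
        for name, value in attrs.items():
            setattr(self, name, value)

    def tree_flatten(self):
        keys: list[AttrKeySketch] = []
        leaves: list[float] = []
        # Walk attributes in sorted order so flattening is deterministic
        for name, value in sorted(vars(self).items()):
            if isinstance(value, float):
                keys.append(AttrKeySketch(name))
                leaves.append(value)
        # Shallow copy of the module with every "array" replaced by None
        skeleton = copy.copy(self)
        for k in keys:
            k.set(skeleton, None)
        return leaves, (keys, skeleton)

    @classmethod
    def tree_unflatten(cls, aux_data, arrs_vals) -> "ModuleSketch":
        keys, skeleton = aux_data
        # Copy again so the skeleton stored in aux_data is never mutated
        module = copy.copy(skeleton)
        for k, v in zip(keys, arrs_vals):
            k.set(module, v)
        return module


m = ModuleSketch(weight=2.0, bias=0.5, name="linear")
leaves, aux = m.tree_flatten()
print(leaves)  # [0.5, 2.0]
m2 = ModuleSketch.tree_unflatten(aux, [x * 10 for x in leaves])
print(m2.bias, m2.weight, m2.name)  # 5.0 20.0 linear
```

Non-array attributes such as name travel in the skeleton copy, so only the array leaves need to pass through JAX transformations.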