set_arrays


seli.opt.set_arrays(module: None | bool | int | float | str | type | Array | ShapeDtypeStruct | list | dict | Module | Any, arrays: dict[str, Array]) → None | bool | int | float | str | type | Array | ShapeDtypeStruct | list | dict | Module | Any

Set arrays back into parameters in a module.

This function takes a module and a dictionary of arrays and returns a new module in which each array has been set into the corresponding Param object. The dictionary keys are path strings and should match the paths returned by get_arrays.

Parameters:
  • module (NodeType) – The module to set arrays into.

  • arrays (dict[str, jax.Array]) – A dictionary mapping path strings to arrays.

Returns:

A new module with the arrays set into the parameters.

Return type:

NodeType

Raises:

ValueError – If a path in the arrays dictionary doesn’t point to a Param object.
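The contract described above can be illustrated with a self-contained toy: paths come from get_arrays, only Param leaves are replaced, a new module is returned, and a ValueError is raised for a path that does not reach a Param. This is a minimal sketch, not seli's implementation. The Param class, the dot-separated path format (for example "layers.0"), and the use of nested dicts and lists as the "module" are all assumptions made for illustration. Plain Python lists stand in for jax.Array so the example runs without JAX.

```python
from dataclasses import dataclass, replace
from typing import Any


@dataclass(frozen=True)
class Param:
    """Hypothetical stand-in for seli's Param: a leaf wrapping one array."""
    value: Any


def _split(path: str) -> list[str]:
    # Assumed path format: keys joined by "."; the empty path means the root.
    return path.split(".") if path else []


def get_arrays(node: Any, prefix: str = "") -> dict[str, Any]:
    """Hypothetical toy: collect {path: array} for every Param under `node`."""
    if isinstance(node, Param):
        return {prefix: node.value}
    if isinstance(node, dict):
        items = node.items()
    elif isinstance(node, list):
        items = enumerate(node)
    else:
        return {}  # other leaves hold no Param
    out: dict[str, Any] = {}
    for key, child in items:
        child_path = f"{prefix}.{key}" if prefix else str(key)
        out.update(get_arrays(child, child_path))
    return out


def _set_one(node: Any, keys: list[str], array: Any, full_path: str) -> Any:
    # Walk down `keys`, copying each container on the way so the input is
    # never mutated; untouched branches are shared with the original.
    if not keys:
        if not isinstance(node, Param):
            raise ValueError(f"path {full_path!r} does not point to a Param")
        return replace(node, value=array)
    head, rest = keys[0], keys[1:]
    if isinstance(node, dict) and head in node:
        new_dict = dict(node)
        new_dict[head] = _set_one(node[head], rest, array, full_path)
        return new_dict
    if isinstance(node, list) and head.isdigit() and int(head) < len(node):
        idx = int(head)
        new_list = list(node)
        new_list[idx] = _set_one(node[idx], rest, array, full_path)
        return new_list
    raise ValueError(f"path {full_path!r} does not point to a Param")


def set_arrays(module: Any, arrays: dict[str, Any]) -> Any:
    """Hypothetical toy: return a copy of `module` with arrays set into Params."""
    for path, array in arrays.items():
        module = _set_one(module, _split(path), array, path)
    return module


# Usage: round-trip through get_arrays, then trigger the error case.
model = {
    "linear": {"weight": Param([1.0, 2.0]), "bias": Param([0.0])},
    "layers": [Param([3.0])],
}
arrays = get_arrays(model)  # keys: "linear.weight", "linear.bias", "layers.0"
doubled = {path: [2 * x for x in a] for path, a in arrays.items()}
new_model = set_arrays(model, doubled)
print(new_model["linear"]["weight"].value)  # [2.0, 4.0]
print(model["linear"]["weight"].value)      # [1.0, 2.0] (input untouched)

try:
    set_arrays(model, {"linear": [0.0]})  # "linear" is a dict, not a Param
except ValueError as err:
    print(err)
```

Copying only the containers along each path (structural sharing) is one way to meet the "new module" return contract without deep-copying every array; whether seli does exactly this is not stated here.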