get_arrays

seli.opt.get_arrays(module: None | bool | int | float | str | type | Array | ShapeDtypeStruct | list | dict | Module | Any, collection: str | None = None) → dict[str, Array]

Extract arrays from parameters in a module.

This function traverses the module and extracts all arrays from Param objects, optionally filtering by collection. It returns a dictionary mapping path strings to arrays.

Parameters:
  • module (NodeType) – The module to extract arrays from.

  • collection (str | None, default=None) – If provided, only extract arrays from Param objects with this collection. If None, extract arrays from all Param objects.

Returns:

A dictionary mapping path strings to the extracted arrays.

Return type:

dict[str, jax.Array]
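
Example:

The traversal and collection filter described above can be illustrated with a minimal, self-contained sketch. It does not import seli: the Param class, the get_arrays_sketch function, and the path format (".key" and "[index]") are hypothetical stand-ins chosen for illustration, nested dicts and lists stand in for modules, and plain Python lists stand in for arrays. The real function may build paths differently.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class Param:
    """Hypothetical stand-in for a parameter wrapper; not seli's own class."""
    value: Any
    collection: Optional[str] = None


def get_arrays_sketch(node: Any, collection: Optional[str] = None, path: str = "") -> dict:
    """Collect Param values from nested dicts/lists, keyed by an assumed path format."""
    found = {}
    if isinstance(node, Param):
        # Keep the value if no filter is set or the collection matches.
        if collection is None or node.collection == collection:
            found[path] = node.value
    elif isinstance(node, dict):
        for key, child in node.items():
            found.update(get_arrays_sketch(child, collection, f"{path}.{key}"))
    elif isinstance(node, (list, tuple)):
        for index, child in enumerate(node):
            found.update(get_arrays_sketch(child, collection, f"{path}[{index}]"))
    # Any other leaf (None, bool, int, str, ...) holds no Param and is skipped.
    return found


# A toy "module": nested containers holding Param objects and one plain leaf.
model = {
    "dense": {
        "weight": Param([[1.0, 2.0]], collection="weights"),
        "bias": Param([0.0], collection="biases"),
    },
    "layers": [Param([3.0], collection="weights")],
    "name": "toy",
}

all_arrays = get_arrays_sketch(model)                          # collection=None: every Param
weights_only = get_arrays_sketch(model, collection="weights")  # only the "weights" collection
```

With collection=None every Param contributes an entry; passing a collection name keeps only the Param objects tagged with that name, which mirrors the filtering behaviour documented for the collection parameter.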