dfs_map


seli.dfs_map(obj: NodeType, fun: Callable[[PathKey, NodeType], NodeType] = <function <lambda>>, *, refs: dict[int, NodeType] | None = None, path: PathKey | None = None, refs_fun: Callable[[PathKey, NodeType], NodeType] | None = None) -> list | dict | Module | None | bool | int | float | str | type | Array | ShapeDtypeStruct

where NodeType stands for None | bool | int | float | str | type | jax.Array | ShapeDtypeStruct | list | dict | Module | Any.

Performs a depth-first traversal of a nested data structure, applying a transformation function to each element.

This function traverses dictionaries, lists, and Module objects recursively in a depth-first manner, applying the provided transformation function to each element. It builds a new structure with the same shape as the original, but with transformed values. During traversal, it tracks the path to each element and handles circular references to prevent infinite recursion.
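The traversal described above can be sketched in plain Python. This is a simplified illustration, not seli's implementation: the path is modeled as a tuple of keys instead of a PathKey, Module objects and repeated references are ignored, fun is applied to leaves only, and the name dfs_map_sketch is hypothetical.

```python
from typing import Any, Callable

PathTuple = tuple  # hypothetical stand-in for seli's PathKey


def dfs_map_sketch(obj: Any, fun: Callable[[PathTuple, Any], Any], path: PathTuple = ()) -> Any:
    """Depth-first map over nested dicts and lists (simplified sketch)."""
    if isinstance(obj, dict):
        # Visit keys in sorted order so the traversal is deterministic.
        return {k: dfs_map_sketch(obj[k], fun, path + (k,)) for k in sorted(obj)}
    if isinstance(obj, list):
        # List indices become path entries, just like dictionary keys.
        return [dfs_map_sketch(v, fun, path + (i,)) for i, v in enumerate(obj)]
    # Leaf value: apply the transformation together with its full path.
    return fun(path, obj)


params = {"layer": {"w": [1.0, 2.0], "b": 0.5}, "name": "mlp"}

# Double every float stored under a "w" key; leave everything else unchanged.
doubled = dfs_map_sketch(
    params,
    lambda path, x: x * 2 if "w" in path and isinstance(x, float) else x,
)
```

Because every container is rebuilt rather than modified, params itself is left untouched and doubled has the same nesting with the "w" entries scaled.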

Parameters:
  • obj (NodeType) – The object to traverse, which can be a dictionary, list, Module, or a leaf value.

  • fun (Callable[[PathKey, NodeType], NodeType]) – A transformation function to apply to each element in the structure. It should return a transformed version of the element and accept two arguments:

      • path: a PathKey object representing the current path

      • x: the current element being processed

  • refs (dict[int, NodeType] | None, optional) – A dictionary mapping object IDs to their transformed versions. Used internally to track already-processed objects and handle circular references. Default is None (an empty dict will be created).

  • path (PathKey | None, optional) – A PathKey object representing the current path in the structure. Used for tracking position during recursive calls. Default is None (an empty PathKey will be created).

  • refs_fun (Callable[[PathKey, NodeType], NodeType] | None, optional) – A function that handles repeated references. Default is None. When an object is encountered more than once during traversal: if refs_fun is None, the already-processed version is returned directly; if refs_fun is provided, it is called with (path, processed_obj), and its return value is used for the repeated reference.
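The interplay of refs and refs_fun can be illustrated with an id()-keyed memo. This is a hypothetical simplification, not seli's code: only lists are memoized, paths are tuples, fun is applied to leaves, and map_with_refs is an illustrative name. Each new list is registered in refs before its children are visited, so a circular reference resolves to the partially built copy instead of recursing forever.

```python
from typing import Any, Callable, Optional

PathTuple = tuple  # hypothetical stand-in for seli's PathKey


def map_with_refs(
    obj: Any,
    fun: Callable[[PathTuple, Any], Any],
    refs: Optional[dict] = None,
    path: PathTuple = (),
    refs_fun: Optional[Callable[[PathTuple, Any], Any]] = None,
) -> Any:
    if refs is None:
        refs = {}  # id(original object) -> its processed copy
    if isinstance(obj, list):
        if id(obj) in refs:
            # Repeated (or circular) reference: reuse the processed copy,
            # or let refs_fun decide what to return instead.
            seen = refs[id(obj)]
            return seen if refs_fun is None else refs_fun(path, seen)
        new: list = []
        refs[id(obj)] = new  # register before recursing so cycles terminate
        for i, item in enumerate(obj):
            new.append(map_with_refs(item, fun, refs, path + (i,), refs_fun))
        return new
    return fun(path, obj)


# A list shared twice stays shared in the output.
shared = [1, 2]
data = [shared, shared]
out = map_with_refs(data, lambda path, x: x * 10)

# A self-referencing list is copied without infinite recursion.
cyclic = [1]
cyclic.append(cyclic)
out_cyc = map_with_refs(cyclic, lambda path, x: x + 1)

# With refs_fun, the second occurrence is replaced by a marker.
marked = map_with_refs(data, lambda path, x: x, refs_fun=lambda path, seen: ("ref", path))
```

Here out[0] and out[1] are the same object, out_cyc[1] is out_cyc itself, and marked becomes [[1, 2], ("ref", (1,))].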

Returns:
  • A new structure with the same shape as the input, with all elements transformed according to the provided function.

Raises:
  • ValueError – Supported types are: dictionaries, lists, Module objects, and leaf values.

  • TypeError

Notes

  • The function preserves the structure of the original object while creating a new transformed copy.

  • Dictionary keys and Module attributes are processed in sorted order for deterministic traversal.

  • Repeated references, including circular ones, are resolved through refs: the already-processed object is reused, or, when refs_fun is provided, its return value is used instead.

  • Module objects are created using object.__new__ without calling __init__, which may bypass important initialization logic.

  • The path parameter tracks the exact location of each element in the nested structure using:

      • ItemKey for dictionary keys and list indices

      • AttrKey for Module attributes
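The path bookkeeping and Module handling from these notes can be sketched with stand-in classes. ItemKey, AttrKey, Module, Linear, and trace_paths below are hypothetical definitions made up for illustration; seli's real classes may differ. The sketch visits Module attributes via vars() in sorted order and builds each copy with object.__new__, so __init__ never runs on the new instance.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class ItemKey:  # hypothetical: a dictionary key or list index
    key: Any


@dataclass(frozen=True)
class AttrKey:  # hypothetical: a Module attribute name
    name: str


class Module:  # hypothetical stand-in for seli's Module
    pass


class Linear(Module):
    init_count = 0  # counts __init__ calls to show that copies skip them

    def __init__(self, w: list, b: float) -> None:
        Linear.init_count += 1
        self.w = w
        self.b = b


def trace_paths(obj: Any, fun: Callable[[tuple, Any], Any], path: tuple = ()) -> Any:
    if isinstance(obj, Module):
        new = object.__new__(type(obj))  # __init__ is NOT called here
        for name in sorted(vars(obj)):  # attributes in sorted order
            setattr(new, name, trace_paths(getattr(obj, name), fun, path + (AttrKey(name),)))
        return new
    if isinstance(obj, dict):
        return {k: trace_paths(obj[k], fun, path + (ItemKey(k),)) for k in sorted(obj)}
    if isinstance(obj, list):
        return [trace_paths(v, fun, path + (ItemKey(i),)) for i, v in enumerate(obj)]
    return fun(path, obj)


visited: list = []


def record(path: tuple, x: Any) -> Any:
    visited.append(path)  # remember the visiting order of leaf paths
    return x


layer = Linear(w=[1.0, 2.0], b=0.0)
result = trace_paths({"fc": layer}, record)
```

The leaf paths are recorded as (ItemKey("fc"), AttrKey("b")) first, then the two entries under AttrKey("w"), and Linear.init_count stays at 1 because the copied Linear is never initialized.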