dfs_map
- seli.dfs_map(obj: NodeType, fun: Callable[[PathKey, NodeType], NodeType] = <function <lambda>>, *, refs: dict[int, NodeType] | None = None, path: PathKey | None = None, refs_fun: Callable[[PathKey, NodeType], NodeType] | None = None) -> list | dict | Module | None | bool | int | float | str | type | Array | ShapeDtypeStruct
Here NodeType stands for None | bool | int | float | str | type | jax.Array | jax._src.api.ShapeDtypeStruct | list | dict | seli.core._module.Module | typing.Any, PathKey is seli.core._module.PathKey, and Callable is collections.abc.Callable.
Performs a depth-first traversal of a nested data structure, applying a transformation function to each element.
This function traverses dictionaries, lists, and Module objects recursively in a depth-first manner, applying the provided transformation function to each element. It builds a new structure with the same shape as the original, but with transformed values. During traversal, it tracks the path to each element and handles circular references to prevent infinite recursion.
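The traversal described above can be illustrated with a simplified, pure-Python sketch. It is not seli's implementation. It handles only dict and list containers (no Module), represents the path as a plain tuple instead of a PathKey, and assumes that fun is applied to a node before its children are visited (pre-order). The name dfs_map_sketch is hypothetical.

```python
from __future__ import annotations

from typing import Any, Callable

Path = tuple  # hypothetical stand-in for seli's PathKey: a tuple of dict keys / list indices


def dfs_map_sketch(
    obj: Any,
    fun: Callable[[Path, Any], Any] = lambda path, x: x,
    *,
    refs: dict[int, Any] | None = None,
    path: Path | None = None,
    refs_fun: Callable[[Path, Any], Any] | None = None,
) -> Any:
    if refs is None:
        refs = {}
    if path is None:
        path = ()

    # Repeated (shared or circular) container: reuse the copy built the first time.
    if id(obj) in refs:
        processed = refs[id(obj)]
        return processed if refs_fun is None else refs_fun(path, processed)

    # Assumption of this sketch: fun sees the node first, then we descend into its result.
    value = fun(path, obj)

    if isinstance(value, dict):
        new_dict: dict = {}
        refs[id(obj)] = new_dict  # register before recursing so cycles terminate
        for key in sorted(value):  # sorted keys give a deterministic visiting order
            new_dict[key] = dfs_map_sketch(
                value[key], fun, refs=refs, path=path + (key,), refs_fun=refs_fun
            )
        return new_dict

    if isinstance(value, list):
        new_list: list = []
        refs[id(obj)] = new_list
        for i, item in enumerate(value):
            new_list.append(
                dfs_map_sketch(item, fun, refs=refs, path=path + (i,), refs_fun=refs_fun)
            )
        return new_list

    return value  # leaf value: returned as transformed by fun
```

For example, mapping `{"b": shared, "a": {"x": 3.0}, "c": shared}` with a function that multiplies numbers by 10 yields a new dict in which `"b"` and `"c"` point to the same copied list, while the input is left untouched. Registering the new container in refs before recursing is what lets a self-referencing list be copied without infinite recursion.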
- Parameters:
obj (NodeType) – The object to traverse, which can be a dictionary, list, Module, or a leaf value.
fun (Callable[[PathKey, NodeType], NodeType]) – A transformation function applied to each element in the structure. It is called with two arguments, path (a PathKey representing the current position) and x (the element being processed), and should return the transformed version of the element.
refs (dict[int, NodeType] | None, optional) – A dictionary mapping object IDs to their transformed versions. Used internally to track already-processed objects and handle circular references. Default is None (an empty dict will be created).
path (PathKey | None, optional) – A PathKey object representing the current path in the structure. Used for tracking position during recursive calls. Default is None (an empty PathKey will be created).
refs_fun (Callable[[PathKey, NodeType], NodeType] | None, optional) – A function that controls how repeated references are handled. Default is None. When an object is encountered more than once during traversal: if refs_fun is None, the already-processed version is returned directly; if refs_fun is provided, it is called with (path, processed_obj) and its return value is used for the repeated reference.
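As an illustration of the two callback contracts, the functions below follow the (path, x) signature described for fun and the (path, processed_obj) signature described for refs_fun. All three names are hypothetical examples, and the path is modelled as a plain tuple, whereas seli passes a PathKey object whose interface is not described in this section.

```python
from typing import Any, Callable, Optional


def scale_leaves(path: tuple, x: Any) -> Any:
    """Example fun: double numeric leaves and pass everything else through."""
    # bool is a subclass of int, so it is excluded explicitly.
    if isinstance(x, (int, float)) and not isinstance(x, bool):
        return x * 2
    # Non-numeric values, including containers, are returned unchanged.
    return x


def tag_repeat(path: tuple, processed: Any) -> Any:
    """Example refs_fun: replace a repeated reference with a marker naming its location."""
    return ("repeat-of", type(processed).__name__, path)


def resolve_repeat(
    obj: Any,
    path: tuple,
    refs: dict,
    refs_fun: Optional[Callable[[tuple, Any], Any]] = None,
) -> Any:
    """Hypothetical helper mirroring the documented rule for an already-seen object."""
    processed = refs[id(obj)]
    # Without refs_fun the processed copy is reused; otherwise refs_fun decides.
    return processed if refs_fun is None else refs_fun(path, processed)
```

With refs_fun=None, a second occurrence shares identity with the first processed copy. Passing a function like tag_repeat breaks that sharing, which can be useful when the output should not contain aliases or cycles.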
- Returns:
A new structure with the same shape as the input, but with all elements transformed according to the provided function.
- Raises:
ValueError – If an unsupported type is encountered. Supported types are dictionaries, lists, Module objects, and leaf values.
TypeError –
Notes
The function preserves the structure of the original object while creating a new transformed copy.
Dictionary keys and Module attributes are processed in sorted order for deterministic traversal.
Circular and other repeated references are resolved through refs: the already-processed object is reused, or, if refs_fun is given, replaced by the result of refs_fun(path, processed_obj).
Module objects are reconstructed with object.__new__ without calling __init__, so any initialization logic in __init__ is skipped for the new instances.
The path parameter tracks the exact location of each element in the nested structure using:
ItemKey for dictionary keys and list indices
AttrKey for Module attributes
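The Module-related notes can be illustrated in plain Python. The sketch below uses a hypothetical ToyModule class and hypothetical ItemKey/AttrKey dataclasses as stand-ins for seli's types, whose real definitions are not shown here. It shows that object.__new__ produces an instance without running __init__, and that visiting attributes in sorted order yields a deterministic sequence of paths. Unlike dfs_map, shallow_rebuild (also hypothetical) copies attribute values by reference instead of recursively transforming them.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class ItemKey:  # hypothetical stand-in: a dict key or list index
    key: Any


@dataclass(frozen=True)
class AttrKey:  # hypothetical stand-in: a Module attribute name
    name: str


class ToyModule:
    init_calls = 0  # class-level counter of __init__ invocations

    def __init__(self, w, b):
        ToyModule.init_calls += 1
        self.w = w
        self.b = b


def shallow_rebuild(mod, path=()):
    """Copy mod's instance attributes into a fresh instance without calling __init__."""
    new = object.__new__(type(mod))  # __init__ is NOT run for `new`
    visited = []
    for name in sorted(vars(mod)):  # sorted attribute names: deterministic order
        child_path = path + (AttrKey(name),)
        value = getattr(mod, name)
        visited.append(child_path)
        if isinstance(value, list):
            # List elements are addressed with ItemKey entries appended to the path.
            visited.extend(child_path + (ItemKey(i),) for i in range(len(value)))
        setattr(new, name, value)  # shallow: the value object is shared, not copied
    return new, visited
```

Because __init__ is skipped, side effects it would normally perform, such as validation or registration, do not happen for the rebuilt instance, even though every attribute stored on the original is present on the copy.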