Source code for seli.opt._grad

from collections.abc import Callable
from functools import partial, wraps
from typing import Any, ParamSpec, TypeVar

import jax
import jax.numpy as jnp

from seli.core._module import AttrKey, NodeType, PathKey, dfs_map
from seli.core._typecheck import typecheck
from seli.net._param import Param

# Typing placeholders for the user-supplied function wrapped by the gradient
# helpers in this module: P stands for its parameters, T for its return type.
P = ParamSpec("P")
T = TypeVar("T")


# Floating-point dtypes. The gradient helpers in this module keep only arrays
# whose dtype is one of these when selecting what to differentiate.
FLOAT_TYPES = (
    jnp.float16,
    jnp.float32,
    jnp.float64,
    jnp.bfloat16,
)


@typecheck
def get_arrays(
    module: NodeType,
    collection: str | None = None,
) -> dict[str, jax.Array]:
    """
    Collect the arrays found in a module, keyed by their path.

    The module is traversed with ``dfs_map``. The visitor hands every node
    back unchanged, so the traversal is only used to observe the tree.

    Which arrays are collected depends on ``collection``:

    - ``None``: every ``jax.Array`` encountered is collected.
    - a string: only arrays stored as the ``value`` attribute of a ``Param``
      whose ``collection`` equals the given one are collected.

    Parameters
    ---
    module : NodeType
        The module to extract arrays from.

    collection : str | None, default=None
        Name of the collection to restrict the extraction to, or None to
        collect every array.

    Returns
    ---
    dict[str, jax.Array]
        A dictionary mapping ``repr(path)`` of each collected array to the
        array itself.
    """
    found: dict[PathKey, jax.Array] = {}

    def owned_by_collection(path: PathKey) -> bool:
        # The root object has no parent, and only the ``value`` attribute of
        # a Param counts as a parameter array.
        if not path.path or path[-1] != AttrKey("value"):
            return False

        owner = path[:-1].get(module)
        if not isinstance(owner, Param) or owner.collection != collection:
            return False

        return True

    def visit(path: PathKey, node: NodeType):
        if isinstance(node, jax.Array):
            if collection is None or owned_by_collection(path):
                found[path] = node

        # always hand the node back untouched
        return node

    dfs_map(module, visit)

    return {repr(path): array for path, array in found.items()}
@typecheck
def set_arrays(
    module: NodeType,
    arrays: dict[str, jax.Array],
) -> NodeType:
    """
    Write arrays into a module at the given paths.

    Each key of ``arrays`` is parsed with ``PathKey.from_str`` and the
    matching array is assigned at that path. The keys are expected to be the
    path strings produced by ``get_arrays``.

    If the root path is present, the module itself is the array, and the
    array stored under the root path is returned directly.

    Parameters
    ---
    module : NodeType
        The module to set arrays into.

    arrays : dict[str, jax.Array]
        A dictionary mapping path strings to arrays.

    Returns
    ---
    NodeType
        The result of ``dfs_map(module)`` with the arrays assigned at their
        paths, or the root array when the root path is given.

    Raises
    ---
    ValueError
        If the root path is given together with other entries.
    """
    targets = {PathKey.from_str(key): value for key, value in arrays.items()}
    root = PathKey([])

    if root not in targets:
        # dfs_map without a function serves as the copy step here, so the
        # assignments below presumably leave the caller's module untouched
        # (TODO confirm against dfs_map).
        updated = dfs_map(module)

        for target, value in targets.items():
            target.set(updated, value)

        return updated

    # the root object is replaced as a whole; no other path may accompany it
    if len(arrays) == 1:
        return targets[root]

    error = f"Base object is set to an array, but got path {arrays}"
    raise ValueError(error)
def grad(
    func: Callable[P, T],
    *,
    collection: str | None = None,
    has_aux: bool = False,
) -> Callable[P, Any]:
    """
    Build a function that differentiates ``func`` with respect to the arrays
    of the module passed as its first argument.

    On each call the returned function extracts the arrays of the module with
    ``get_arrays``, keeps those whose dtype is in ``FLOAT_TYPES``, and applies
    ``jax.grad`` to a function that writes the arrays back with ``set_arrays``
    before calling ``func``. The gradients come back as a dictionary mapping
    path strings to gradient arrays.

    Parameters
    ---
    func : Callable
        The function to differentiate. It takes a module as its first
        argument and returns a scalar loss value.

    collection : str | None, default=None
        Forwarded to ``get_arrays``: if given, only arrays of Param objects
        with this collection are differentiated; if None, all arrays found
        in the module are candidates.

    has_aux : bool, default=False
        Whether ``func`` returns auxiliary data. If True, ``func`` should
        return a tuple (loss, aux_data), where loss is a scalar and aux_data
        can be any type.

    Returns
    ---
    Callable
        A function taking the same arguments as ``func`` and returning the
        gradients, or a tuple (gradients, aux_data) if has_aux is True.

    Examples
    ---
    >>> def loss_fn(module, x, y):
    ...     pred = module(x)
    ...     return ((pred - y) ** 2).mean()
    >>> grad_fn = grad(loss_fn)
    >>> gradients = grad_fn(module, x, y)
    """
    # has_aux is fixed for the lifetime of the wrapper, so bind it once
    differentiate = partial(jax.grad, has_aux=has_aux)

    @wraps(func)
    def wrap_fn(module: NodeType, *args: P.args, **kwargs: P.kwargs) -> Any:
        # only floating point arrays take part in differentiation
        differentiable = {
            path: array
            for path, array in get_arrays(module, collection).items()
            if array.dtype in FLOAT_TYPES
        }

        def loss_of_arrays(
            arrays: dict[str, jax.Array],
            *inner_args: P.args,
            **inner_kwargs: P.kwargs,
        ) -> Any:
            # rebuild the module around the (traced) arrays before calling func
            rebuilt = set_arrays(module, arrays)
            return func(rebuilt, *inner_args, **inner_kwargs)

        return differentiate(loss_of_arrays)(differentiable, *args, **kwargs)

    return wrap_fn
def value_and_grad(
    func: Callable[P, T],
    *,
    collection: str | None = None,
    has_aux: bool = False,
) -> Callable[P, Any]:
    """
    Build a function that evaluates ``func`` and differentiates it with
    respect to the arrays of the module passed as its first argument.

    On each call the returned function extracts the arrays of the module with
    ``get_arrays``, keeps those whose dtype is in ``FLOAT_TYPES``, and applies
    ``jax.value_and_grad`` to a function that writes the arrays back with
    ``set_arrays`` before calling ``func``. The gradients come back as a
    dictionary mapping path strings to gradient arrays.

    Parameters
    ---
    func : Callable
        The function to evaluate and differentiate. It takes a module as its
        first argument and returns a scalar loss value.

    collection : str | None, default=None
        Forwarded to ``get_arrays``: if given, only arrays of Param objects
        with this collection are differentiated; if None, all arrays found
        in the module are candidates.

    has_aux : bool, default=False
        Whether ``func`` returns auxiliary data. If True, ``func`` should
        return a tuple (loss, aux_data), where loss is a scalar and aux_data
        can be any type.

    Returns
    ---
    Callable
        A function taking the same arguments as ``func`` and returning a
        tuple (value, gradients). If has_aux is True, it returns
        ((value, aux_data), gradients), following ``jax.value_and_grad``.

    Examples
    ---
    >>> def loss_fn(module, x, y):
    ...     pred = module(x)
    ...     return ((pred - y) ** 2).mean()
    >>> value_and_grad_fn = value_and_grad(loss_fn)
    >>> value, gradients = value_and_grad_fn(module, x, y)
    """
    # has_aux is fixed for the lifetime of the wrapper, so bind it once
    differentiate = partial(jax.value_and_grad, has_aux=has_aux)

    @wraps(func)
    def wrap_fn(module: NodeType, *args: P.args, **kwargs: P.kwargs) -> Any:
        # only floating point arrays take part in differentiation
        differentiable = {
            path: array
            for path, array in get_arrays(module, collection).items()
            if array.dtype in FLOAT_TYPES
        }

        def loss_of_arrays(
            arrays: dict[str, jax.Array],
            *inner_args: P.args,
            **inner_kwargs: P.kwargs,
        ) -> Any:
            # rebuild the module around the (traced) arrays before calling func
            rebuilt = set_arrays(module, arrays)
            return func(rebuilt, *inner_args, **inner_kwargs)

        return differentiate(loss_of_arrays)(differentiable, *args, **kwargs)

    return wrap_fn