value_and_grad


seli.opt.value_and_grad(func: Callable[[P], T], *, collection: str | None = None, has_aux: bool = False) -> Callable[[P], Any]

Create a function that computes both the value of a function and its gradients with respect to module parameters.

This function wraps a loss function that takes a module as its first argument. It returns a new function that computes the loss value together with the gradients of the loss with respect to the module’s parameters.

The returned function extracts arrays from the module, evaluates the loss, and computes the gradients. The gradients are returned as a dictionary mapping path strings to gradient arrays.
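The shape of that result can be sketched with plain JAX. This is only an analogy, not seli's implementation: the flat `params` dictionary and its path-like keys (`"linear.weight"`, `"linear.bias"`) are hypothetical stand-ins for the arrays that would be extracted from a module, and `jax.value_and_grad` plays the role of the wrapper.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for arrays extracted from a module:
# a flat dict keyed by path strings (the key names are assumptions).
params = {
    "linear.weight": jnp.array([[1.0], [2.0]]),
    "linear.bias": jnp.array([0.5]),
}

def loss_fn(params, x, y):
    # Linear prediction followed by mean squared error (a scalar loss).
    pred = x @ params["linear.weight"] + params["linear.bias"]
    return ((pred - y) ** 2).mean()

x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
y = jnp.array([[1.0], [1.0]])

# Differentiates with respect to the first argument (the dict of arrays).
value, gradients = jax.value_and_grad(loss_fn)(params, x, y)

# `gradients` has the same keys as `params`, one gradient array per path.
for path, grad_array in gradients.items():
    print(path, grad_array.shape)
```

Here `value` is the scalar loss and `gradients` mirrors the keys of `params`, which is the same "path string to gradient array" layout described above.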

Parameters:
  • func (Callable) – The function to compute gradients for. It should take a module as its first argument and return a scalar loss value.

  • collection (str | None, default=None) – If provided, only extract arrays from Param objects with this collection. If None, extract arrays from all Param objects.

  • has_aux (bool, default=False) – Whether the function returns auxiliary data. If True, the function should return a tuple (loss, aux_data), where loss is a scalar and aux_data can be any type.

Returns:

A new function that takes the same arguments as func and returns the loss value together with the gradients with respect to the module’s parameters. If has_aux is True, aux_data is returned as well.

Return type:

Callable

Examples

>>> def loss_fn(module, x, y):
...     pred = module(x)
...     return ((pred - y) ** 2).mean()
>>> grad_fn = value_and_grad(loss_fn)
>>> value, gradients = grad_fn(module, x, y)
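
The has_aux case can be illustrated with plain JAX, whose `jax.value_and_grad` accepts a flag of the same name. This is a sketch of the general pattern only: the nesting `((value, aux), gradients)` shown here is JAX's convention and is not confirmed to match what seli returns. The names `loss_with_aux` and the `"pred"` key are hypothetical.

```python
import jax
import jax.numpy as jnp

def loss_with_aux(params, x, y):
    # Returns (loss, aux_data): a scalar loss plus arbitrary extra data.
    pred = params["w"] * x + params["b"]
    loss = ((pred - y) ** 2).mean()
    return loss, {"pred": pred}

params = {"w": jnp.array(2.0), "b": jnp.array(0.0)}
x = jnp.array([1.0, 2.0])
y = jnp.array([1.0, 2.0])

# With has_aux=True, JAX differentiates only the first element of the
# returned tuple and passes the auxiliary data through unchanged.
(value, aux), gradients = jax.value_and_grad(loss_with_aux, has_aux=True)(params, x, y)
```

Only the loss is differentiated; the auxiliary data is useful for returning predictions or metrics without a second forward pass.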