value_and_grad
- seli.opt.value_and_grad(func: Callable[[P], T], *, collection: str | None = None, has_aux: bool = False) → Callable[[P], Any]
Create a function that computes the value of a loss and its gradients with respect to module parameters.
This function wraps a loss function that takes a module as its first argument. It returns a new function that computes both the loss and the gradients of the loss with respect to the module’s parameters.
The returned function extracts arrays from the module, computes the loss and its gradients, and returns the gradients as a dictionary mapping path strings to gradient arrays.
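The mechanism described above (collect the values held by Param objects, differentiate the loss with respect to each one, key the results by a path string) can be sketched in plain Python. This is a minimal illustration under stated assumptions, not seli's implementation: `Param`, `Linear`, `extract_params` and `value_and_grad_fd` are hypothetical names, dotted attribute names such as `"weight"` stand in for seli's path strings, parameters are scalars rather than arrays, and central finite differences replace the automatic differentiation a real implementation would typically use.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional


@dataclass
class Param:
    # Hypothetical stand-in for a seli Param: holds a scalar instead of an array.
    value: float
    collection: Optional[str] = None


class Linear:
    # Hypothetical toy module computing pred = weight * x + bias.
    def __init__(self, weight: float, bias: float) -> None:
        self.weight = Param(weight, "param")
        self.bias = Param(bias, "param")

    def __call__(self, x: float) -> float:
        return self.weight.value * x + self.bias.value


def extract_params(obj: Any, prefix: str = "") -> Dict[str, Param]:
    # Recursively walk attributes, keying each Param by a dotted path string.
    found: Dict[str, Param] = {}
    for name, attr in vars(obj).items():
        path = f"{prefix}.{name}" if prefix else name
        if isinstance(attr, Param):
            found[path] = attr
        elif hasattr(attr, "__dict__"):
            found.update(extract_params(attr, path))
    return found


def value_and_grad_fd(
    func: Callable[..., float], eps: float = 1e-6
) -> Callable[..., Any]:
    # Wraps func(module, *args) -> scalar; the result returns (value, grads).
    def wrapped(module: Any, *args: Any):
        value = func(module, *args)
        grads: Dict[str, float] = {}
        for path, param in extract_params(module).items():
            original = param.value
            param.value = original + eps
            plus = func(module, *args)
            param.value = original - eps
            minus = func(module, *args)
            param.value = original  # restore the unperturbed value
            grads[path] = (plus - minus) / (2 * eps)  # central difference
        return value, grads

    return wrapped


def loss_fn(module: Linear, xs: list, ys: list) -> float:
    # Mean squared error, mirroring the loss used in the Examples section.
    preds = [module(x) for x in xs]
    return sum((p - y) ** 2 for p, y in zip(preds, ys)) / len(xs)


model = Linear(weight=2.0, bias=0.5)
value, gradients = value_and_grad_fd(loss_fn)(model, [1.0, 2.0], [3.0, 5.0])
```

Here `value` is 0.25 and `gradients` comes out as roughly `{"weight": -1.5, "bias": -1.0}`, matching the analytic mean-squared-error gradients for this linear model.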
- Parameters:
func (Callable) – The function to compute gradients for. It should take a module as its first argument and return a scalar loss value.
collection (str | None, default=None) – If provided, extract arrays only from Param objects in this collection. If None, extract arrays from all Param objects.
has_aux (bool, default=False) – Whether the function returns auxiliary data. If True, the function should return a tuple (loss, aux_data), where loss is a scalar and aux_data can be any type.
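The `collection` filter and the `has_aux` flag can be illustrated with the same kind of toy sketch. Every name here (`Param`, `Model`, `extract`, `value_and_grad_sketch`) is hypothetical, gradients come from finite differences rather than autodiff, and the output layout `((loss, aux_data), gradients)` follows the `jax.value_and_grad` convention, which is an assumption; the layout seli actually returns is the one stated under Returns.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional


@dataclass
class Param:
    # Hypothetical stand-in for a seli Param: a scalar tagged with a collection name.
    value: float
    collection: Optional[str] = None


class Model:
    # Hypothetical module: one "param" entry and one "stats" entry.
    def __init__(self) -> None:
        self.scale = Param(3.0, collection="param")
        self.running_mean = Param(0.0, collection="stats")

    def __call__(self, x: float) -> float:
        return self.scale.value * x - self.running_mean.value


def extract(module: Any, collection: Optional[str] = None) -> Dict[str, Param]:
    # collection=None keeps every Param; otherwise keep only Params tagged with it.
    return {
        name: attr
        for name, attr in vars(module).items()
        if isinstance(attr, Param)
        and (collection is None or attr.collection == collection)
    }


def value_and_grad_sketch(
    func: Callable[..., Any],
    *,
    collection: Optional[str] = None,
    has_aux: bool = False,
    eps: float = 1e-6,
) -> Callable[..., Any]:
    def loss_only(module: Any, *args: Any) -> float:
        out = func(module, *args)
        # With has_aux=True, func returns (loss, aux); only the loss is differentiated.
        return out[0] if has_aux else out

    def wrapped(module: Any, *args: Any):
        out = func(module, *args)  # loss, or (loss, aux) when has_aux=True
        grads: Dict[str, float] = {}
        for path, param in extract(module, collection).items():
            original = param.value
            param.value = original + eps
            plus = loss_only(module, *args)
            param.value = original - eps
            minus = loss_only(module, *args)
            param.value = original  # restore the unperturbed value
            grads[path] = (plus - minus) / (2 * eps)
        # Assumed layout, following jax.value_and_grad: (out, grads).
        return out, grads

    return wrapped


def loss_fn(model: Model, x: float, y: float):
    pred = model(x)
    return (pred - y) ** 2, {"pred": pred}  # (loss, aux_data)


model = Model()
(loss, aux), grads = value_and_grad_sketch(
    loss_fn, collection="param", has_aux=True
)(model, 2.0, 5.0)
_, all_grads = value_and_grad_sketch(loss_fn, has_aux=True)(model, 2.0, 5.0)
```

With `collection="param"`, `grads` contains only `"scale"` (about 4.0, i.e. 2·(pred − y)·x = 2·1·2), while `all_grads` also includes `"running_mean"` (about −2.0) because `collection=None` keeps every Param. The auxiliary dictionary is passed through untouched.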
- Returns:
A new function that takes the same arguments as func and returns a tuple (value, gradients): the loss value and the gradients of the loss with respect to the module’s parameters. If has_aux is True, it returns a tuple (gradients, aux_data).
- Return type:
Callable
Examples
>>> from seli.opt import value_and_grad
>>> def loss_fn(module, x, y):
...     pred = module(x)
...     return ((pred - y) ** 2).mean()
>>> grad_fn = value_and_grad(loss_fn)
>>> value, gradients = grad_fn(module, x, y)
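For the mean-squared-error loss in this example, each entry of gradients follows from the chain rule. Writing ŷ_i for the n entries of pred, y_i for the matching targets, and θ for any one parameter of the module:

$$
L(\theta) = \frac{1}{n}\sum_{i=1}^{n}\bigl(\hat{y}_i - y_i\bigr)^2,
\qquad
\frac{\partial L}{\partial \theta} = \frac{2}{n}\sum_{i=1}^{n}\bigl(\hat{y}_i - y_i\bigr)\,\frac{\partial \hat{y}_i}{\partial \theta}.
$$

The array stored under a parameter's path string holds this derivative for each element of that parameter, and value is L itself.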