Source code for seli.opt._nesterov
import jax.numpy as jnp
from jax import Array
from jaxtyping import Float
from seli.opt._opt import Optimizer
[docs]
class Nesterov(Optimizer, name="opt.Nesterov"):
    """
    Nesterov Accelerated Gradient optimizer.

    Improves on standard momentum by computing gradients at a "lookahead"
    position, providing better convergence rates.

    One velocity buffer is kept per parameter key. Every call performs::

        v      <- beta * v + grad
        update <- lr * (grad + beta * v)

    which is the undampened Nesterov form: the current gradient plus the
    freshly updated velocity scaled by ``beta``.

    Args:
        lr: Learning rate; scales the returned update.
        beta: Momentum coefficient applied to the stored velocity.
    """

    def __init__(self, lr: float = 1e-3, beta: float = 0.9):
        self.lr = lr
        self.beta = beta
        # Velocity buffers keyed by parameter key; filled lazily on first use.
        self.v: dict[str, Float[Array, "*_"]] = {}

    def call_param(
        self,
        key: str,
        grad: Float[Array, "*s"],
        **_,
    ) -> Float[Array, "*s"]:
        """
        Compute the Nesterov update for a single parameter.

        Args:
            key: Identifier used to look up this parameter's velocity buffer.
            grad: Gradient for the parameter.
            **_: Extra keyword arguments are accepted and ignored.

        Returns:
            The learning-rate-scaled update, with the same shape as ``grad``.
            NOTE(review): presumably the caller subtracts this from the
            parameter -- confirm against ``Optimizer``.
        """
        velocity = self.v.get(key)
        if velocity is None:
            # First time this key is seen: start from a zero velocity.
            velocity = jnp.zeros_like(grad)

        # Undampened momentum accumulation; persisted for the next call.
        velocity = self.beta * velocity + grad
        self.v[key] = velocity

        # Lookahead step. Because the velocity is accumulated WITHOUT a
        # (1 - beta) dampening factor, the current gradient must enter with
        # weight 1 here as well; weighting it by (1 - beta) would mix the
        # dampened and undampened conventions and under-weight the gradient.
        return self.lr * (grad + self.beta * velocity)