Source code for seli.opt._momentum
import jax.numpy as jnp
from jax import Array
from jaxtyping import Float
from seli.opt._opt import Optimizer
[docs]
class Momentum(Optimizer, name="opt.Momentum"):
    """
    Gradient-based optimizer with a momentum (velocity) term.

    For every parameter key a running velocity is kept. On each step the
    stored velocity is multiplied by ``beta`` and the incoming gradient is
    added to it, so gradient components that keep pointing the same way build
    up over time, much like a ball picking up speed while rolling downhill.

    ``beta`` therefore sets how strongly earlier gradients still influence
    the current step, and ``lr`` scales the velocity into the value that is
    returned. On well-behaved objectives this typically converges faster
    than plain SGD.

    Parameters
    ----------
    lr:
        Factor the velocity is multiplied by to form the returned value.
    beta:
        Factor applied to the previous velocity before the new gradient is
        added.
    """

    def __init__(self, lr: float = 1e-3, beta: float = 0.9):
        self.lr = lr
        self.beta = beta
        # Per-parameter velocity, keyed by the same key given to call_param.
        # Entries are created lazily the first time a key is seen.
        self.v: dict[str, Float[Array, "*_"]] = {}

    def call_param(
        self,
        key: str,
        grad: Float[Array, "*s"],
        **_,
    ) -> Float[Array, "*s"]:
        """
        Advance the velocity stored for ``key`` and return the scaled step.

        Parameters
        ----------
        key:
            Identifier under which this parameter's velocity is stored.
        grad:
            Gradient for the parameter; the velocity takes its shape.
        **_:
            Any further keyword arguments are accepted and ignored.

        Returns
        -------
        The updated velocity multiplied by ``lr``. How the caller applies
        this value to the parameter is decided outside this class.
        """
        # An unseen key starts from a zero velocity shaped like the gradient.
        previous = self.v[key] if key in self.v else jnp.zeros_like(grad)

        # Decay the old velocity, then fold in the current gradient.
        velocity = previous * self.beta + grad
        self.v[key] = velocity

        # The learning rate turns the velocity into the emitted update.
        return velocity * self.lr