Source code for seli.opt._rmsprop
import jax.numpy as jnp
from jax import Array
from jaxtyping import Float
from seli.opt._opt import Optimizer
from seli.opt._utils import lerp
[docs]
class RMSProp(Optimizer, name="opt.RMSProp"):
"""
Root Mean Square Propagation optimizer.
Addresses Adagrad's diminishing learning rates by using exponential moving
average of squared gradients.
"""
def __init__(
self,
lr: float = 1e-3,
beta: float = 0.9,
eps: float = 1e-8,
):
self.lr = lr
self.beta = beta
self.eps = eps
self.g2: dict[str, Float[Array, "*_"]] = {}
[docs]
def call_param(
self,
key: str,
grad: Float[Array, "*s"],
**_,
) -> Float[Array, "*s"]:
if key not in self.g2:
self.g2[key] = jnp.zeros_like(grad)
# compute the EMA of the squared gradients
self.g2[key] = lerp(self.g2[key], jnp.square(grad), self.beta)
# Normalize the gradient by the EMA of the squared gradients
return self.lr * grad / (jnp.sqrt(self.g2[key]) + self.eps)