Source code for seli.opt._adagrad

import jax.numpy as jnp
from jax import Array
from jaxtyping import Float

from seli.opt._opt import Optimizer


class Adagrad(Optimizer, name="opt.Adagrad"):
    """
    Adagrad (Adaptive Gradient) optimizer.

    For every parameter key it keeps a running sum of element-wise squared
    gradients. Each incoming gradient is divided by the square root of that
    sum, so entries with a large gradient history receive smaller steps than
    rarely-updated ones.

    Parameters
    ----------
    lr:
        Global learning rate multiplying every returned update.
    eps:
        Small constant added to the denominator so the division never hits
        zero on the first step or for entries whose gradients are all zero.
    """

    def __init__(self, lr: float = 1e-2, eps: float = 1e-8):
        self.lr = lr
        self.eps = eps
        # Per-key running sum of squared gradients; an entry is created lazily
        # the first time `call_param` sees a given key.
        self.G2: dict[str, Float[Array, "*_"]] = {}

    def call_param(
        self,
        key: str,
        grad: Float[Array, "*s"],
        **_,
    ) -> Float[Array, "*s"]:
        """
        Compute the Adagrad update for the parameter identified by ``key``.

        Side effect: adds ``grad ** 2`` to the stored accumulator for ``key``
        (starting from zeros shaped like ``grad`` on first use).

        Parameters
        ----------
        key:
            Identifier of the parameter; selects which accumulator to use.
        grad:
            Gradient for that parameter.
        **_:
            Extra keyword arguments are accepted and ignored (presumably
            supplied by the ``Optimizer`` base class — not visible here).

        Returns
        -------
        ``lr * grad / (sqrt(accumulated) + eps)``, same shape as ``grad``.
        NOTE(review): how this value is applied to the parameter (e.g.
        subtracted) is decided outside this class — confirm in ``Optimizer``.
        """
        running = self.G2.get(key)
        if running is None:
            running = jnp.zeros_like(grad)

        # Fold the current squared gradient into the history.
        running = running + jnp.square(grad)
        self.G2[key] = running

        # eps is added after the square root, keeping the denominator > 0.
        scale = jnp.sqrt(running) + self.eps
        return self.lr * grad / scale