Source code for seli.opt._adam

import jax.numpy as jnp
from jax import Array
from jaxtyping import Float

from seli.opt._opt import Optimizer
from seli.opt._utils import lerp


[docs] class Adam(Optimizer, name="opt.Adam"): """ Adaptive Moment Estimation optimizer. Combines momentum and RMSProp, maintaining both first moment (mean) and second moment (variance) of gradients with bias correction. Adam has become the de facto standard optimizer for deep learning. """ def __init__( self, lr: float = 3e-4, beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8, ): self.lr = lr self.beta1 = beta1 self.beta2 = beta2 self.eps = eps # First moment (momentum) self.m: dict[str, Float[Array, "*_"]] = {} # Second moment (RMSProp) self.v: dict[str, Float[Array, "*_"]] = {} # Timestep counter for bias correction self.t = jnp.zeros(())
[docs] def call_param( self, key: str, grad: Float[Array, "*s"], **_, ) -> Float[Array, "*s"]: # Initialize moments if not already done if key not in self.m: self.m[key] = jnp.zeros_like(grad) self.v[key] = jnp.zeros_like(grad) # Update biased first moment estimate (momentum) using lerp self.m[key] = lerp(self.m[key], grad, self.beta1) # Update biased second moment estimate (RMSProp) using lerp self.v[key] = lerp(self.v[key], jnp.square(grad), self.beta2) # Compute bias-corrected first moment estimate m_corrected = self.m[key] / (1 - self.beta1**self.t) # Compute bias-corrected second moment estimate v_corrected = self.v[key] / (1 - self.beta2**self.t) # Compute the Adam update return self.lr * m_corrected / (jnp.sqrt(v_corrected) + self.eps)
[docs] def call_model(self, grads, **_): self.t = self.t + 1 return grads