Source code for seli.opt._sgd
from jax import Array
from jaxtyping import Float
from seli.opt._opt import Optimizer
def lerp(a, b, t):
    """
    Blend a and b linearly, using t as the weight placed on a.

    The result is a * t + b * (1 - t). With t = 1 the weight on b is zero
    and with t = 0 the weight on a is zero, so for finite inputs t = 1
    yields a and t = 0 yields b.
    """
    weight_a = t
    weight_b = 1 - t
    return a * weight_a + b * weight_b
[docs]
class SGD(Optimizer, name="opt.SGD"):
    """
    Stochastic Gradient Descent optimizer.

    The gradient is the direction of steepest descent. The SGD update simply
    scales the gradient by the learning rate and takes a step in that
    direction. It does not account for information from previous gradients.

    There has been some evidence that SGD has a regularization effect,
    which leads to better generalization performance, at the cost of slower
    convergence.
    """

    def __init__(self, lr: float = 1e-3):
        """
        Store the learning rate.

        Parameters
        ----------
        lr:
            Factor every gradient is multiplied by in ``call_param``.
            Defaults to 1e-3.
        """
        self.lr: float = lr

    def call_param(self, grad: Float[Array, "*s"], **_) -> Float[Array, "*s"]:
        """
        Return the gradient multiplied by the learning rate.

        Any additional keyword arguments are accepted and ignored. How the
        returned value is applied to the parameter is not decided here.
        NOTE(review): presumably handled by the Optimizer base class; confirm.
        """
        step_size = self.lr
        return grad * step_size