Source code for seli.net._norm

"""
Normalization layers.
"""

import jax.lax as lax
from jaxtyping import Array, Float, jaxtyped

from seli.core._module import Module
from seli.core._typecheck import typecheck
from seli.net._init import InitOnes, InitZeros
from seli.net._param import Param

__all__ = [
    "LayerNorm",
    "RMSNorm",
]


[docs] class LayerNorm(Module, name="net.LayerNorm"): """ Normalize the input along the last axis. Then add a learnable offset and scale by a learnable weight along the last axis. Parameters --- eps: float Epsilon value for numerical stability. offset: bool Whether to add 1 to the scale weight before multiplying. If true, the model is initialized to the identity function. If false, the model is initialized to the constant zero function. """ @typecheck def __init__( self, eps: float = 1e-6, offset: float | int = 1, ) -> None: self.eps = eps self.offset = offset self.weight = Param(init=InitZeros()) self.bias = Param(init=InitZeros())
[docs] @jaxtyped(typechecker=typecheck) def __call__( self, x: Float[Array, "*batch dim"], ) -> Float[Array, "*batch dim"]: w = self.weight((x.shape[-1],), x.dtype) b = self.bias((x.shape[-1],), x.dtype) m = x.mean(axis=-1, keepdims=True) x = x - m v = x.var(axis=-1, keepdims=True) r = lax.rsqrt(v + self.eps) x = x * r x = x * (w + self.offset) x = x + b return x
[docs] class RMSNorm(Module, name="net.RMSNorm"): """ Scale the input by the reciprocal of the root mean square along the last axis. Then add a learnable offset and scale by a learnable weight along the last axis. Parameters --- eps: float Epsilon value for numerical stability. axis: int The axis to calculate the root mean square. """ @typecheck def __init__( self, eps: float = 1e-6, offset: float | int = 1, ) -> None: self.eps = eps self.offset = offset self.weight = Param(init=InitOnes()) self.bias = Param(init=InitZeros())
[docs] @jaxtyped(typechecker=typecheck) def __call__( self, x: Float[Array, "*batch dim"], ) -> Float[Array, "*batch dim"]: w = self.weight((x.shape[-1],), x.dtype) b = self.bias((x.shape[-1],), x.dtype) v = (x * x).mean(axis=-1, keepdims=True) r = lax.rsqrt(v + self.eps) x = x * r x = x * (w + self.offset) x = x + b return x