Source code for seli.net._linear

"""
Parametrized linear and affine transformation layers.
"""

from jaxtyping import Array, Float, jaxtyped

from seli.core._module import Module
from seli.core._typecheck import typecheck
from seli.net._init import InitUniform, InitZeros
from seli.net._param import Param

# Public API of this module. Every layer class defined below is exported,
# in definition order, so that star-imports expose all of them.
__all__ = [
    "Linear",
    "Bias",
    "Affine",
    "Scale",
]


class Linear(Module, name="net.Linear"):
    """
    Learnable linear map applied along the last axis of the input.

    The weight is requested from its `Param` on every call with shape
    `(dim_in, dim)` and the dtype of the input, where `dim_in` is the size of
    the input's last axis. The constructor takes no random key.

    Parameters
    ---
    dim: int
        Dimensionality of the output. The input dimension is inferred from
        the last axis of the first input.
    """

    def __init__(self, dim: int) -> None:
        self.dim = dim
        # NOTE(review): the weight is presumably materialized lazily on the
        # first call, since no shape is known here -- confirm against Param.
        glorot_uniform = InitUniform(init="Glorot")
        self.weight = Param(init=glorot_uniform)

    @jaxtyped(typechecker=typecheck)
    def __call__(
        self,
        x: Float[Array, "*batch dim_in"],
    ) -> Float[Array, "*batch {self.dim}"]:
        """
        Multiply the last axis of `x` by the weight matrix.

        The leading batch axes are kept; the last axis has size `self.dim`
        in the result.
        """
        kernel_shape = (x.shape[-1], self.dim)
        kernel = self.weight(kernel_shape, x.dtype)
        return x @ kernel

    @property
    def dim_in(self) -> int | None:
        """
        Return the input dimension of the module, or None while the weight
        has not been initialized and the input dimension is therefore not
        fixed yet.
        """
        weight = self.weight
        return weight.value.shape[0] if weight.initialized else None
class Bias(Module, name="net.Bias"):
    """
    Learnable bias vector added along the last axis of the input.

    The bias is requested from its `Param` on every call with shape `(dim,)`
    and the dtype of the input, where `dim` is the size of the input's last
    axis. It uses the `InitZeros` initializer.
    """

    def __init__(self) -> None:
        zeros = InitZeros()
        self.bias = Param(init=zeros)

    @jaxtyped(typechecker=typecheck)
    def __call__(
        self,
        x: Float[Array, "*batch dim"],
    ) -> Float[Array, "*batch dim"]:
        """
        Add the bias vector to `x`; the output has the same shape as `x`.
        """
        bias_shape = (x.shape[-1],)
        shift = self.bias(bias_shape, x.dtype)
        return x + shift

    @property
    def dim(self) -> int | None:
        """
        Return the length of the bias vector, or None while the bias has not
        been initialized yet.
        """
        bias = self.bias
        if bias.initialized:
            return bias.value.shape[0]
        return None
class Affine(Module, name="net.Affine"):
    """
    Learnable affine map: a `Linear` layer followed by a `Bias` layer.

    Parameters
    ---
    dim: int
        The output dimension of the linear transformation. The input
        dimension is inferred from the last axis of the first input.
    """

    def __init__(self, dim: int) -> None:
        self.linear = Linear(dim)
        self.bias = Bias()

    @jaxtyped(typechecker=typecheck)
    def __call__(
        self,
        x: Float[Array, "*batch dim_in"],
    ) -> Float[Array, "*batch dim"]:
        """
        Apply the linear transformation to `x`, then add the bias.
        """
        projected = self.linear(x)
        return self.bias(projected)

    @property
    def dim_in(self) -> int | None:
        """
        Return the input dimension reported by the wrapped linear layer,
        which is None until that layer's weight has been initialized.
        """
        linear = self.linear
        return linear.dim_in
class Scale(Module, name="net.Scale"):
    """
    Scale the last axis of the input by a learnable vector.

    The input is multiplied elementwise by `scale + offset`. The scale is
    requested from its `Param` on every call with shape `(dim,)` and the
    dtype of the input, and it uses the `InitZeros` initializer.

    Parameters
    ---
    offset: float
        Constant added to the learnable scale before multiplying. With the
        default of 1 the input is scaled by `1 + scale`; pass 0 to scale by
        `scale` alone.
    """

    def __init__(self, offset: float = 1) -> None:
        self.offset = offset
        zeros = InitZeros()
        self.scale = Param(init=zeros)

    @jaxtyped(typechecker=typecheck)
    def __call__(
        self,
        x: Float[Array, "*batch dim"],
    ) -> Float[Array, "*batch dim"]:
        """
        Multiply `x` by `scale + offset`; the output has the same shape as
        `x`.
        """
        scale_shape = (x.shape[-1],)
        learned = self.scale(scale_shape, x.dtype)
        factor = learned + self.offset
        return x * factor

    @property
    def dim(self) -> int | None:
        """
        Return the length of the scale vector, or None while the scale has
        not been initialized yet.
        """
        scale = self.scale
        if scale.initialized:
            return scale.value.shape[0]
        return None