"""
Combining neural network modules.
"""
from typing import Callable
import jax
import jax.numpy as jnp
import jax.random as jrn
from jaxtyping import PRNGKeyArray
from seli.core._module import Module
# Public API of this module. `Identity` is a documented, public combinator
# and is exported alongside the others.
__all__ = [
    "Sequential",
    "Add",
    "Multiply",
    "Constant",
    "Identity",
]
class Sequential(Module, name="net.Sequential"):
    """
    Chain modules so that each one consumes the previous one's output.

    Only the first argument is threaded through the chain; any extra
    positional and keyword arguments are forwarded unchanged to every module.

    Parameters
    ----------
    *modules : Callable
        Modules to call in order. Each is invoked as
        ``module(x, *args, **kwargs)`` and its return value becomes the new
        ``x`` for the next module.
    """

    modules: list[Callable]

    def __init__(self, *modules: Callable):
        # Store a concrete list copy of the varargs tuple.
        self.modules = [*modules]

    def __call__(self, x, *args, **kwargs):
        """
        Run every module in order and return the final output.

        With no modules, ``x`` is returned unchanged.
        """
        out = x
        for layer in self.modules:
            out = layer(out, *args, **kwargs)
        return out
class Add(Sequential, name="net.Add"):
    """
    Chain modules residually: add each module's output to its input.

    For modules ``m_1, ..., m_k`` this computes, in order,
    ``x <- x + m_i(x, *args, **kwargs)``, so each module sees the running
    sum produced by the modules before it. With a single module this is the
    residual connection ``x + m_1(x)``. Extra positional and keyword
    arguments are forwarded unchanged to every module.

    Parameters
    ----------
    *modules : Callable
        Modules to call in order; each output is added to the running value.
    """

    def __call__(self, x, *args, **kwargs):
        """
        Apply every module residually and return the accumulated value.

        With no modules, ``x`` is returned unchanged.
        """
        acc = x
        for branch in self.modules:
            delta = branch(acc, *args, **kwargs)
            acc = acc + delta
        return acc
class Multiply(Sequential, name="net.Multiply"):
    """
    Chain modules multiplicatively: scale the input by each module's output.

    For modules ``m_1, ..., m_k`` this computes, in order,
    ``x <- x * m_i(x, *args, **kwargs)``, so each module sees the running
    product produced by the modules before it. Extra positional and keyword
    arguments are forwarded unchanged to every module.

    Parameters
    ----------
    *modules : Callable
        Modules to call in order; the running value is multiplied by each
        output.
    """

    def __call__(self, x, *args, **kwargs):
        """
        Apply every module multiplicatively and return the accumulated value.

        With no modules, ``x`` is returned unchanged.
        """
        acc = x
        for gate in self.modules:
            factor = gate(acc, *args, **kwargs)
            acc = acc * factor
        return acc
class Constant(Module, name="net.Constant"):
    """
    Module that always returns a stored array, regardless of its inputs.

    Parameters
    ----------
    value : jax.Array
        The array returned on every call.
    """

    value: jax.Array

    def __init__(self, value):
        self.value = value

    @classmethod
    def random_normal(
        cls,
        key: PRNGKeyArray,
        shape: tuple[int, ...],
        std: float = 1.0,
    ):
        """
        Build a constant from zero-mean normal samples.

        Parameters
        ----------
        key : PRNGKeyArray
            Random key passed to ``jax.random.normal``.
        shape : tuple[int, ...]
            Shape of the sampled array.
        std : float
            Standard deviation; the standard-normal samples are scaled by it.

        Returns
        -------
        Constant
            A new module holding the sampled array.
        """
        samples = jrn.normal(key, shape)
        return cls(samples * std)

    @classmethod
    def full(cls, x: float = 0.0, shape: tuple[int, ...] = ()):
        """
        Build a constant whose every entry equals ``x``.

        Parameters
        ----------
        x : float
            Fill value for every entry.
        shape : tuple[int, ...]
            Shape of the array; the default ``()`` gives a scalar array.

        Returns
        -------
        Constant
            A new module holding the filled array.
        """
        filled = jnp.full(shape, x)
        return cls(filled)

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape of the stored array."""
        return self.value.shape

    @property
    def dtype(self) -> jnp.dtype:
        """Dtype of the stored array."""
        return self.value.dtype

    def __call__(self, *args, **kwargs) -> jax.Array:
        """Return the stored array; all arguments are accepted and ignored."""
        return self.value
class Identity(Module, name="net.Identity"):
    """
    Return the first argument unchanged; extra args and kwargs are ignored.

    Useful as a placeholder wherever a module is expected.

    Note that ``Add`` already adds each module's output to its running
    input, so ``Add(module)`` is itself the residual connection
    ``x + module(x)``. Placing an ``Identity`` inside ``Add`` does not build
    a residual; it contributes an extra copy of the running value
    (``Add(Identity(), module)`` computes ``2x + module(2x)``).

    Parameters
    ----------
    None
        This module takes no constructor parameters.
    """

    def __call__(self, x, *args, **kwargs):
        # Extra positional/keyword arguments are accepted so Identity can sit
        # in any combinator that forwards them, but they are not used.
        return x