Usage Guide
This guide will help you get started with the seli package.
Basic Concepts
Seli is built around a few key concepts:
- Modules: the base building block for parameterized functions
- Parameters: trainable values that are initialized when first used
- RNGs: random number generators for reproducible initialization
- Optimizers: update a model's parameters to minimize a loss
Creating Models
Define new layers by subclassing seli.Module:
```python
import jax.numpy as jnp
import seli as sl


# Giving the class a name makes it saveable
class Linear(sl.Module, name="example:Linear"):
    def __init__(self, dim: int):
        self.dim = dim
        # A parameter can be given a value directly, or, as here,
        # an initializer that runs once the shape is known
        self.weight = sl.net.Param(init=sl.net.init.Normal("Kaiming"))

    def __call__(self, x):
        # On the first call the weight is initialized with the given shape;
        # the value is stored and reused on later calls
        return x @ self.weight((x.shape[-1], self.dim))
```
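The lazy initialization above can be illustrated without seli. The NumPy sketch below mimics the idea of a parameter that is created on its first call, once the shape is known. `LazyParam` and `LinearSketch` are hypothetical names, not seli classes, and the Kaiming normal initializer is assumed to draw from N(0, 2/fan_in) with the weight's first axis taken as the fan-in.

```python
import numpy as np


class LazyParam:
    """Hypothetical parameter that is created on its first call, given a shape."""

    def __init__(self, rng: np.random.Generator):
        self.rng = rng
        self.value = None  # not initialized until the shape is known

    def __call__(self, shape):
        if self.value is None:
            fan_in = shape[0]
            # Kaiming (He) normal initialization: std = sqrt(2 / fan_in)
            std = np.sqrt(2.0 / fan_in)
            self.value = self.rng.normal(0.0, std, size=shape)
        elif self.value.shape != tuple(shape):
            raise ValueError(f"expected shape {self.value.shape}, got {tuple(shape)}")
        # Later calls return the stored value unchanged
        return self.value


class LinearSketch:
    def __init__(self, dim: int, rng: np.random.Generator):
        self.dim = dim
        self.weight = LazyParam(rng)

    def __call__(self, x):
        # The input's last dimension determines the weight's fan-in
        return x @ self.weight((x.shape[-1], self.dim))


layer = LinearSketch(10, np.random.default_rng(0))
y = layer(np.ones((4, 8)))
print(y.shape)  # (4, 10)
print(layer.weight.value.shape)  # (8, 10)
```

The shape check on later calls is a design choice of this sketch: it surfaces a mismatched input size instead of silently reusing a weight of the wrong shape.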
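Working with Random Number Generators
Set the random number generators of a module and all of its submodules at once:
```python
# One call seeds every submodule; no need to pass RNGs between layers
model = Linear(10).set_rngs(42)
y = model(jnp.ones(8))  # the weight is initialized here with shape (8, 10)
```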
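Seeding a whole model from one number works if each submodule receives its own independent stream derived from that seed. The NumPy sketch below shows this with `SeedSequence.spawn`; `Dense` and `set_rngs` are hypothetical stand-ins and say nothing about how seli derives its keys internally (in JAX the analogous operation is `jax.random.split`).

```python
import numpy as np


class Dense:
    """Hypothetical layer holding its own random generator."""

    def __init__(self, n_in: int, n_out: int):
        self.shape = (n_in, n_out)
        self.rng = None

    def init_weight(self):
        return self.rng.normal(size=self.shape)


def set_rngs(layers, seed: int):
    # Derive one independent child seed per layer from a single root seed
    children = np.random.SeedSequence(seed).spawn(len(layers))
    for layer, child in zip(layers, children):
        layer.rng = np.random.default_rng(child)
    return layers


a = set_rngs([Dense(8, 4), Dense(4, 2)], seed=42)
b = set_rngs([Dense(8, 4), Dense(4, 2)], seed=42)

# Same root seed -> identical weights for corresponding layers
print(np.array_equal(a[0].init_weight(), b[0].init_weight()))  # True
```

Because the child seeds are derived deterministically from the root seed, rebuilding the model with the same seed reproduces every layer's initialization, while different layers still draw from different streams.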
Training Models
Training steps can be written concisely:
```python
# Define the optimizer and the loss
optimizer = sl.opt.Adam(1e-3)
loss = sl.opt.MeanSquaredError()

# Create sample data: 32 inputs with 8 features, 32 targets with 10 outputs
x = jnp.ones((32, 8))
y = jnp.ones((32, 10))

# Perform a gradient update; returns the updated optimizer, model and loss value
optimizer, model, loss_value = optimizer.minimize(loss, model, y, x)
```
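To show what a single update involves, here is a hand-written NumPy version for a bias-free linear layer: a mean squared error loss, its gradient derived by hand, and the standard Adam update with bias correction. It is a sketch under assumptions, not seli's implementation; seli presumably computes gradients with JAX autodiff rather than by hand, and `mse` and `adam_step` are hypothetical helpers.

```python
import numpy as np


def mse(pred, target):
    return np.mean((pred - target) ** 2)


def adam_step(w, grad, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    # Exponential moving averages of the gradient and its square
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad**2
    # Bias correction for the zero-initialized averages (t starts at 1)
    m_hat = m / (1 - b1**t)
    v_hat = v / (1 - b2**t)
    w = w - lr * m_hat / (np.sqrt(v_hat) + eps)
    return w, m, v


rng = np.random.default_rng(0)
x = np.ones((32, 8))
y = np.ones((32, 10))
w = rng.normal(scale=np.sqrt(2 / 8), size=(8, 10))
m = np.zeros_like(w)
v = np.zeros_like(w)

losses = []
for t in range(1, 101):
    pred = x @ w
    losses.append(mse(pred, y))
    # Gradient of mean((x @ w - y) ** 2) with respect to w
    grad = 2 * x.T @ (pred - y) / pred.size
    w, m, v = adam_step(w, grad, m, v, t)

print(losses[0], losses[-1])
```

Each call to `adam_step` corresponds to one training step; the optimizer state (`m`, `v`, `t`) has to be carried forward, which is presumably why `minimize` in the snippet above also returns an updated optimizer.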
Saving and Loading Models
Modules whose class was given a name (as in Creating Models) can be saved to disk and loaded back:
```python
# Save the model
sl.save(model, "model.npz")

# Load the model
loaded_model = sl.load("model.npz")
assert isinstance(loaded_model, Linear)
```
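The class name is what allows a loader to rebuild the right class from a file. A minimal sketch of that idea, assuming a name-to-class registry filled by `__init_subclass__` and an `.npz` file that stores the name next to the arrays; `Module`, `TinyLinear`, `save`, `load` and the `module_name` key are all hypothetical and do not describe seli's actual file format.

```python
import os
import tempfile

import numpy as np

# Hypothetical registry: saved name -> class
REGISTRY = {}


class Module:
    def __init_subclass__(cls, name=None, **kwargs):
        super().__init_subclass__(**kwargs)
        # The `name=` keyword in the class statement arrives here
        if name is not None:
            REGISTRY[name] = cls
            cls.registry_name = name


class TinyLinear(Module, name="example:TinyLinear"):
    def __init__(self, weight):
        self.weight = weight


def save(module, path):
    # Store the registered class name alongside every array attribute
    arrays = {key: np.asarray(value) for key, value in vars(module).items()}
    np.savez(path, module_name=np.array(type(module).registry_name), **arrays)


def load(path):
    with np.load(path) as data:
        cls = REGISTRY[str(data["module_name"])]
        module = cls.__new__(cls)  # skip __init__ and restore attributes directly
        for key in data.files:
            if key != "module_name":
                setattr(module, key, data[key])
    return module


path = os.path.join(tempfile.mkdtemp(), "model.npz")
model = TinyLinear(np.arange(6.0).reshape(2, 3))
save(model, path)
loaded_model = load(path)
print(type(loaded_model).__name__)  # TinyLinear
```

In this sketch a class defined without `name=` never enters the registry, so it cannot be loaded by name, which matches the idea that the name is what makes a module saveable.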
Using JIT Compilation
Seli provides a jit function that handles shared references and static arguments:
```python
import jax.numpy as jnp
import seli as sl


@sl.jit
def forward(model, x):
    return model(x)


# forward is now JIT-compiled for faster execution
result = forward(model, jnp.ones((1, 8)))
```
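Shared references matter because JAX transformations operate on flattened trees of arrays: a plain flattening treats every attribute as a separate leaf, so two attributes pointing to the same array (for example tied weights) come back as two independent copies. One way a module-aware jit could preserve such sharing is to deduplicate leaves by identity and restore the aliasing afterwards. The NumPy sketch below shows only that bookkeeping; `Holder`, `flatten` and `unflatten` are hypothetical, and this is an assumption about the general technique, not seli's implementation.

```python
import numpy as np


class Holder:
    """Hypothetical container whose attributes are arrays."""

    def __init__(self, **attrs):
        self.__dict__.update(attrs)


def flatten(obj):
    # Keep each distinct array once (keyed by identity) and record,
    # for every attribute, the index of the leaf it refers to
    leaves, index_of, structure = [], {}, {}
    for key, value in vars(obj).items():
        if id(value) not in index_of:
            index_of[id(value)] = len(leaves)
            leaves.append(value)
        structure[key] = index_of[id(value)]
    return leaves, structure


def unflatten(structure, leaves):
    # Attributes that shared a leaf receive the same new object
    return Holder(**{key: leaves[i] for key, i in structure.items()})


w = np.ones((4, 4))
model = Holder(encoder_w=w, decoder_w=w)  # weight tying: one array, two names

leaves, structure = flatten(model)
print(len(leaves))  # 1

new_leaves = [leaf * 2 for leaf in leaves]  # stands in for a transformed result
rebuilt = unflatten(structure, new_leaves)
print(rebuilt.encoder_w is rebuilt.decoder_w)  # True
```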
Building Complex Models
Combine modules to create more complex architectures:
```python
class MLP(sl.Module, name="example:MLP"):
    def __init__(self, hidden_dim: int, output_dim: int):
        self.hidden = Linear(hidden_dim)
        self.output = Linear(output_dim)

    def __call__(self, x):
        x = self.hidden(x)
        x = jnp.tanh(x)  # activation function
        return self.output(x)


# Create the model and set the RNGs of both submodules;
# the parameters are initialized on the first call
model = MLP(hidden_dim=64, output_dim=10).set_rngs(42)
```
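The shape flow through this MLP can be traced by hand. With 8 input features, the `Linear` layer above creates its weight with shape `(x.shape[-1], dim)`, so the hidden weight becomes (8, 64) and the output weight (64, 10); since `Linear` has no bias, that is 8 * 64 + 64 * 10 = 1152 parameters. The NumPy sketch below reproduces this forward pass with explicit weights; `kaiming_normal` and `mlp_forward` are hypothetical helpers, and the N(0, 2/fan_in) initializer is an assumption about what `Normal("Kaiming")` does.

```python
import numpy as np


def kaiming_normal(rng, fan_in, fan_out):
    # Assumed Kaiming (He) normal initializer: std = sqrt(2 / fan_in)
    return rng.normal(0.0, np.sqrt(2.0 / fan_in), size=(fan_in, fan_out))


def mlp_forward(x, w_hidden, w_out):
    h = np.tanh(x @ w_hidden)  # (batch, hidden_dim), values in (-1, 1)
    return h @ w_out  # (batch, output_dim)


rng = np.random.default_rng(42)
x = np.ones((32, 8))
w_hidden = kaiming_normal(rng, 8, 64)
w_out = kaiming_normal(rng, 64, 10)

out = mlp_forward(x, w_hidden, w_out)
print(out.shape)  # (32, 10)
print(w_hidden.size + w_out.size)  # 1152
```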