Usage Guide

This guide will help you get started with the seli package.

Basic Concepts

Seli is built around a few key concepts:

  1. Modules: The base building block for parameterized functions

  2. Parameters: Trainable values that are initialized when first used

  3. RNGs: Random number generators for reproducible initialization

  4. Optimizers: Update model parameters during training

Creating Models

Define new layers by subclassing seli.Module:

import seli as sl
import jax.numpy as jnp

# Add a name to make the module saveable
class Linear(sl.Module, name="example:Linear"):
    def __init__(self, dim: int):
        self.dim = dim

        # Parameters can be directly initialized
        # or an initialization method can be passed
        self.weight = sl.net.Param(init=sl.net.init.Normal("Kaiming"))

    def __call__(self, x):
        # On the first call, passing the shape initializes the weight;
        # the value is then stored and reused on later calls
        return x @ self.weight((x.shape[-1], self.dim))
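The lazy initialization described in the comments above can be sketched in plain JAX. `LazyParam` and `kaiming_normal` are hypothetical names for illustration, not seli's `Param` or `init.Normal`; the sketch also assumes one common form of Kaiming (He) initialization, whose scaling may differ from seli's.

```python
import jax
import jax.numpy as jnp


class LazyParam:
    """Hypothetical stand-in for a lazily initialized parameter (not seli's Param)."""

    def __init__(self, init_fn):
        self.init_fn = init_fn  # callable (key, shape) -> array
        self.value = None       # nothing is allocated until the first call

    def __call__(self, key, shape):
        if self.value is None:
            # First call: the shape is now known, so create and store the value
            self.value = self.init_fn(key, shape)
        elif self.value.shape != tuple(shape):
            raise ValueError(f"expected shape {self.value.shape}, got {shape}")
        return self.value


def kaiming_normal(key, shape):
    # Assumed Kaiming (He) normal: std = sqrt(2 / fan_in), with fan_in = shape[0]
    return jax.random.normal(key, shape) * jnp.sqrt(2.0 / shape[0])


weight = LazyParam(kaiming_normal)
x = jnp.ones((4, 8))
first = weight(jax.random.PRNGKey(42), (x.shape[-1], 10))
# Later calls ignore the key and return the stored value
second = weight(jax.random.PRNGKey(0), (x.shape[-1], 10))
y = x @ first
```

Deferring creation until the first call is what lets `Linear` take only the output size: the input size is read from the data.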

Working with Random Number Generators

Set the random number generators of a model and all of its submodules with a single call:

# No need to pass RNGs between layers
model = Linear(10).set_rngs(42)
y = model(jnp.ones(8))
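Conceptually, setting RNGs on a whole model means deriving a separate, reproducible key for every submodule from one seed. The sketch below shows one way to do that with `jax.random.fold_in`; `Node` and `set_rngs` here are hypothetical helpers, and seli's own key derivation may differ.

```python
import jax


class Node:
    """Hypothetical minimal module: holds an RNG key and named child modules."""

    def __init__(self, **children):
        self.children = children
        self.key = None


def set_rngs(module, seed):
    """Give every module in the tree its own key, all derived from one seed."""

    def visit(node, key):
        node.key = key
        # Sorted names give a deterministic child order
        for i, name in enumerate(sorted(node.children)):
            # fold_in derives a distinct, deterministic key for each child
            visit(node.children[name], jax.random.fold_in(key, i))

    visit(module, jax.random.PRNGKey(seed))
    return module


model = set_rngs(Node(hidden=Node(), output=Node()), 42)
```

Running the same code with the same seed reproduces every key, which is what makes initialization repeatable without passing keys by hand.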

Training Models

A training step is a single call to the optimizer's minimize method:

# Define optimizer and loss
optimizer = sl.opt.Adam(1e-3)
loss = sl.opt.MeanSquaredError()

# Create sample data
x = jnp.ones((32, 8))
y = jnp.ones((32, 10))

# Perform a gradient update; returns the updated optimizer, the updated model, and the loss value
optimizer, model, loss_value = optimizer.minimize(loss, model, y, x)
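The internals of `minimize` are not described in this guide. As a rough illustration of what one gradient update involves, the sketch below uses `jax.value_and_grad` with plain stochastic gradient descent (not Adam) on a bare weight matrix; `mse` and `sgd_step` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def mse(weight, x, y):
    # Mean squared error of the linear model x @ weight against targets y
    return jnp.mean((x @ weight - y) ** 2)


def sgd_step(weight, x, y, lr=1e-2):
    # Loss and gradient in one pass, then a plain gradient-descent step
    loss_value, grad = jax.value_and_grad(mse)(weight, x, y)
    return weight - lr * grad, loss_value


x = jnp.ones((32, 8))
y = jnp.ones((32, 10))
weight = jnp.zeros((8, 10))

losses = []
for _ in range(5):
    weight, loss_value = sgd_step(weight, x, y)
    losses.append(float(loss_value))
```

Like `minimize`, the step returns the new state instead of mutating it, which fits JAX's functional style.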

Saving and Loading Models

Modules defined with a name can be saved to a file and loaded back:

# Save the model
sl.save(model, "model.npz")

# Load the model
loaded_model = sl.load("model.npz")
assert isinstance(loaded_model, Linear)
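The `.npz` extension suggests NumPy's archive format, but this guide does not specify seli's file layout. The hypothetical sketch below illustrates why a module needs a name to be loadable: the file stores the name as a string next to the arrays, and loading looks the class up in a registry. `REGISTRY`, `TinyLinear`, `save`, and `load` are illustrative, not seli's API.

```python
import os
import tempfile

import numpy as np

# Hypothetical registry mapping saved names to classes: the file stores
# a name, not the class itself, so loading needs this lookup.
REGISTRY = {}


class TinyLinear:
    name = "example:TinyLinear"

    def __init__(self, weight):
        self.weight = weight


REGISTRY[TinyLinear.name] = TinyLinear


def save(module, path):
    # Store the registered name (as a string array) alongside the parameters
    np.savez(path, module_name=np.array(module.name), weight=module.weight)


def load(path):
    with np.load(path) as data:
        cls = REGISTRY[data["module_name"].item()]
        return cls(weight=data["weight"])


path = os.path.join(tempfile.mkdtemp(), "model.npz")
save(TinyLinear(np.ones((8, 10))), path)
restored = load(path)
```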

Using JIT Compilation

Seli provides a jit function that handles shared references and static arguments:

import seli as sl
import jax.numpy as jnp

@sl.jit
def forward(model, x):
    return model(x)

# Now forward is JIT compiled for faster execution
result = forward(model, jnp.ones((1, 8)))
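The two problems `sl.jit` is described as handling can be seen in plain JAX. The sketch below uses `jax.jit` directly (not `sl.jit`, whose internals this guide does not detail): a shape-determining argument must be declared static, and a pytree that references the same array twice is flattened into two separate leaves. `project` is a hypothetical function.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Static arguments: `dim` fixes an array shape, so it must be a concrete
# Python value at trace time. static_argnames tells jax.jit to treat it that
# way (and to recompile whenever a new value of `dim` is passed).
@partial(jax.jit, static_argnames="dim")
def project(x, dim):
    weight = jnp.ones((x.shape[-1], dim))  # array shapes are static under jit
    return x @ weight


out = project(jnp.ones((1, 8)), dim=10)

# Shared references: JAX flattens inputs into a list of leaves without
# tracking identity, so one array referenced twice becomes two leaves.
shared = jnp.zeros(3)
leaves = jax.tree_util.tree_leaves({"a": shared, "b": shared})
```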

Building Complex Models

Combine modules to create more complex architectures:

class MLP(sl.Module, name="example:MLP"):
    def __init__(self, hidden_dim: int, output_dim: int):
        self.hidden = Linear(hidden_dim)
        self.output = Linear(output_dim)

    def __call__(self, x):
        x = self.hidden(x)
        x = jnp.tanh(x)  # Activation function
        return self.output(x)

# Create the model and set its RNGs; the weights are initialized on the first call
model = MLP(hidden_dim=64, output_dim=10).set_rngs(42)
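Because `Linear` reads its input size from `x.shape[-1]`, an `MLP(hidden_dim=64, output_dim=10)` applied to 8-dimensional inputs should end up with weights of shape (8, 64) and (64, 10). The hypothetical sketch below writes the same forward pass with explicit JAX arrays to make that shape flow visible; `mlp_forward` and the random weights are illustrative only.

```python
import jax
import jax.numpy as jnp


def mlp_forward(params, x):
    # Same computation as MLP.__call__ above, with explicit weight matrices
    h = jnp.tanh(x @ params["hidden"])
    return h @ params["output"]


k1, k2 = jax.random.split(jax.random.PRNGKey(42))
x = jnp.ones((4, 8))
params = {
    # Shapes a lazily initialized MLP(64, 10) would infer for 8-dim inputs
    "hidden": jax.random.normal(k1, (x.shape[-1], 64)),
    "output": jax.random.normal(k2, (64, 10)),
}
out = mlp_forward(params, x)
```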