Source code for seli.opt._loss

import jax.nn as jnn
import jax.numpy as jnp
from jax import Array
from jaxtyping import Float

from seli.core._module import Module, NodeType
from seli.core._typecheck import typecheck


[docs] @typecheck class Loss(Module, name="opt.Loss"): """ Base class for all loss functions. """ @property def collection(self) -> str | None: if not hasattr(self, "_collection"): return "param" return self._collection @collection.setter def collection(self, value: str | None): self._collection = value
[docs] def __call__(self, model: NodeType, *args, **kwargs) -> Float[Array, ""]: error = "Subclasses must implement this method" raise NotImplementedError(error)
[docs] class MeanSquaredError(Loss, name="opt.MeanSquaredError"): """ Mean squared error loss function. """
[docs] def __call__( self, model: NodeType, y_true: Float[Array, "..."], *model_args, **model_kwargs, ) -> Float[Array, ""]: y_pred = model(*model_args, **model_kwargs) return jnp.mean(jnp.square(y_pred - y_true))
[docs] class MeanAbsoluteError(Loss, name="opt.MeanAbsoluteError"): """ Mean absolute error loss function. """
[docs] def __call__( self, model: NodeType, y_true: Float[Array, "..."], *model_args, **model_kwargs, ) -> Float[Array, ""]: y_pred = model(*model_args, **model_kwargs) return jnp.mean(jnp.abs(y_pred - y_true))
[docs] class BinaryCrossEntropy(Loss, name="opt.BinaryCrossEntropy"): """ Binary cross entropy loss function. """
[docs] def __call__( self, model: NodeType, y_true: Float[Array, "..."], *model_args, **model_kwargs, ) -> Float[Array, ""]: y_logits = model(*model_args, **model_kwargs) pos_term = y_true * jnn.log_sigmoid(y_logits) neg_term = (1 - y_true) * jnn.log_sigmoid(-y_logits) return -jnp.mean(pos_term + neg_term)