# Source code for seli.net._param

from typing import Generic, TypeVar

from jax import Array
from jax.typing import DTypeLike
from jaxtyping import PRNGKeyArray

from seli._env import DEFAULT_FLOAT
from seli.core._module import Module
from seli.net._init import Init
from seli.net._key import RNGs

# Names exported by ``from seli.net._param import *``.
__all__ = [
    "Param",
]

# Param is generic over V so that static type checkers can tell an
# initialized parameter (value is an Array) apart from an uninitialized
# one (value is None).
V = TypeVar("V", bound=Array | None)


class Param(Module, Generic[V], name="net.Param"):
    """
    Organizes a parameter

    A parameter either already holds an array in ``value`` or holds
    ``None`` and is materialized lazily on the first call, using ``init``
    together with the key stored in ``rngs``.

    Attributes
    ----------
    value:
        The parameter array, or ``None`` while uninitialized.
    init:
        Callable invoked as ``init(key, shape, dtype)`` to create the value
        on first use. ``None`` when the parameter was built by
        :meth:`from_value`.
    rngs:
        ``RNGs`` wrapper built from the key passed to the constructor; its
        ``initialized`` flag and ``key`` attribute are read in ``__call__``.
    collection:
        Label stored on the instance (defaults to ``"param"``).
        NOTE(review): presumably used elsewhere to group parameters; this
        class itself never reads it.
    """

    value: V
    init: Init | None
    rngs: RNGs | None

    def __init__(
        self,
        *,
        init: Init | None,
        rngs: PRNGKeyArray | None = None,
        value: V | None = None,
        collection: str | None = "param",
    ) -> None:
        """
        Store the initializer, optional key, optional value and collection.

        ``init`` may be ``None`` only when ``value`` is supplied (this is
        what :meth:`from_value` does); otherwise the first call to the
        parameter fails.
        """
        self.init = init
        self.value = value
        self.collection = collection
        # The raw key (possibly None) is always wrapped in an RNGs object.
        self.rngs = RNGs(rngs, "init")

    @classmethod
    def from_value(
        cls,
        value: Array,
        *,
        collection: str | None = "param",
    ) -> "Param[Array]":
        """
        Build an already-initialized parameter from an existing array.

        No initializer or key is attached; the array is returned as-is by
        later calls, subject to the shape and dtype checks in ``__call__``.
        """
        # The constructor keyword is ``value``; passing ``data`` raised
        # TypeError because ``__init__`` accepts no such argument.
        return cls(init=None, value=value, collection=collection)

    @property
    def initialized(self) -> bool:
        """Whether the parameter currently holds a value."""
        return self.value is not None

    def __call__(
        self,
        shape: tuple[int, ...],
        dtype: DTypeLike = DEFAULT_FLOAT,
    ) -> Array:
        """
        Return the parameter value, creating it on first use.

        If no value is stored yet, it is created with
        ``self.init(self.rngs.key, shape, dtype)`` and cached on the
        instance. The stored value is then checked against the requested
        ``shape`` and ``dtype`` on every call.

        Raises
        ------
        ValueError
            If the value must be created but the key has not been set, or
            if the stored value's shape or dtype differs from the request.
        AssertionError
            If the value must be created but ``init`` is ``None``.
        """
        if not self.initialized:
            if not self.rngs.initialized:
                error = "Key has not been set"
                raise ValueError(error)

            # Internal invariant: an uninitialized parameter needs an
            # initializer to be able to create its value.
            assert self.init is not None, "Init or value was changed to None?"
            self.value = self.init(self.rngs.key, shape, dtype)

        # Validate on every call so a mismatching request is reported even
        # when the value was created earlier or supplied by from_value.
        if self.value.shape != shape:
            error = f"Expected shape {shape}, got {self.value.shape}"
            raise ValueError(error)

        if self.value.dtype != dtype:
            error = f"Expected dtype {dtype}, got {self.value.dtype}"
            raise ValueError(error)

        return self.value