"""
This module provides a system for registering modules. This is useful for
serializing and deserializing modules without having to know the module
structure beforehand.
See the `_serialize` module for more information on how this is used. This is
a separate module, because the registry is needed in the module system, but
the module system is needed in the serialize module.
"""
import logging
from collections.abc import Hashable
from typing import Any
import jax.nn as jnn
import jax.numpy as jnp
import jax.tree_util as jtu
from seli.core._typecheck import typecheck
logger = logging.getLogger(__name__)
__all__ = [
"registry_add",
"registry_str",
"registry_obj",
"is_registry_str",
]
REGISTRY: dict[str, Hashable] = {}
REGISTRY_INVERSE: dict[Hashable, str] = {}
@typecheck
class ModuleBase:
"""
Base class for the Module class, all modules that inherit from this class,
will be registered as Jax PyTree nodes and can be serialized and added to
the global registry, such that they can be serialized and deserialized
later.
"""
def __init_subclass__(
cls,
name: str | None = None,
overwrite: bool = False,
):
if hasattr(cls, "tree_flatten") and hasattr(cls, "tree_unflatten"):
cls = jtu.register_pytree_node_class(cls)
if name is not None:
registry_add(name, cls, overwrite=overwrite)
if hasattr(cls, "__slots__"):
error = f"Module {cls} has __slots__, which is not supported"
raise TypeError(error)
[docs]
@typecheck
def registry_add(
name: str,
value: Any,
overwrite: bool = False,
) -> None:
"""
Add a something to the registry. Everything that is added to the registry
can be serialized and deserialized later.
Parameters
---
name: str
The name of the module to register.
value: Any
The value to register.
overwrite: bool
If True, overwrite the existing value if it is already registered.
"""
if not overwrite and name in REGISTRY:
if REGISTRY[name] is value:
return
msg = f"Module {name} already registered, skipping {value} ({type(value)})"
msg += f" already registered as {REGISTRY[name]} ({type(REGISTRY[name])})"
logger.warning(msg)
return
REGISTRY[name] = value
REGISTRY_INVERSE[value] = name
[docs]
@typecheck
def registry_str(obj: Any) -> str:
"""
Get the registry string for an object. The registry string is a string
that uniquely identifies the object in the registry.
The object is looked up in the inverse registry and the name is returned
prefixed with `__registry__:`.
Parameters
----------
obj : Any
The object to get the registry string for.
Returns
-------
str
The registry string for the object.
"""
return f"__registry__:{REGISTRY_INVERSE[obj]}"
[docs]
@typecheck
def registry_obj(name: str) -> Hashable:
"""
Get the object for a registry string. The registry string is a string
that uniquely identifies the object in the registry.
The registry string is converted back to the original object by looking
up the name in the registry.
Parameters
----------
name : str
The registry string to get the object for.
Returns
-------
obj : Hashable
The object that is registered under the given name.
"""
if not is_registry_str(name):
raise ValueError(f"Invalid registry string: {name}")
name = name[len("__registry__:") :]
if name not in REGISTRY:
raise ValueError(f"Module {name} not registered")
return REGISTRY[name]
[docs]
@typecheck
def is_registry_str(obj: Any) -> bool:
"""
Check if an object is a registry string. This is useful, because to avoid
clashes were general strings are used in a module, but should not be
converted to values in the registry.
Parameters
----------
obj : Any
The object to check.
Returns
-------
bool
True if the object is a registry string, False otherwise.
"""
return isinstance(obj, str) and obj.startswith("__registry__:")
# make common activation functions serializable
registry_add("jax.nn.celu", jnn.celu)
registry_add("jax.nn.elu", jnn.elu)
registry_add("jax.nn.gelu", jnn.gelu)
registry_add("jax.nn.glu", jnn.glu)
registry_add("jax.nn.hard_sigmoid", jnn.hard_sigmoid)
registry_add("jax.nn.hard_silu", jnn.hard_silu)
registry_add("jax.nn.hard_swish", jnn.hard_swish)
registry_add("jax.nn.hard_tanh", jnn.hard_tanh)
registry_add("jax.nn.leaky_relu", jnn.leaky_relu)
registry_add("jax.nn.log_sigmoid", jnn.log_sigmoid)
registry_add("jax.nn.log_softmax", jnn.log_softmax)
registry_add("jax.nn.logsumexp", jnn.logsumexp)
registry_add("jax.nn.standardize", jnn.standardize)
registry_add("jax.nn.relu", jnn.relu)
registry_add("jax.nn.relu6", jnn.relu6)
registry_add("jax.nn.selu", jnn.selu)
registry_add("jax.nn.sigmoid", jnn.sigmoid)
registry_add("jax.nn.soft_sign", jnn.soft_sign)
registry_add("jax.nn.softmax", jnn.softmax)
registry_add("jax.nn.softplus", jnn.softplus)
registry_add("jax.nn.sparse_plus", jnn.sparse_plus)
registry_add("jax.nn.sparse_sigmoid", jnn.sparse_sigmoid)
registry_add("jax.nn.silu", jnn.silu)
registry_add("jax.nn.swish", jnn.swish)
registry_add("jax.nn.squareplus", jnn.squareplus)
registry_add("jax.nn.mish", jnn.mish)
# make the jax data types serializable
registry_add("jax.numpy.bool_", jnp.bool_)
registry_add("jax.numpy.complex64", jnp.complex64)
registry_add("jax.numpy.complex128", jnp.complex128)
registry_add("jax.numpy.float16", jnp.float16)
registry_add("jax.numpy.float32", jnp.float32)
registry_add("jax.numpy.float64", jnp.float64)
registry_add("jax.numpy.bfloat16", jnp.bfloat16)
registry_add("jax.numpy.int8", jnp.int8)
registry_add("jax.numpy.int16", jnp.int16)
registry_add("jax.numpy.int32", jnp.int32)
registry_add("jax.numpy.int64", jnp.int64)
registry_add("jax.numpy.uint8", jnp.uint8)
registry_add("jax.numpy.uint16", jnp.uint16)
registry_add("jax.numpy.uint32", jnp.uint32)
registry_add("jax.numpy.uint64", jnp.uint64)