Source code for seli.core._jit

from functools import partial, wraps
from typing import Any, Callable, ParamSpec, TypeVar

import jax

from seli.core._module import Module

T = TypeVar("T")
P = ParamSpec("P")

__all__ = [
    "jit",
]


class Arguments(Module, name="builtin.Arguments"):
    """
    Wrapper for the arguments used to call a function.
    """

    def __init__(self, args, kwargs):
        self.args = list(args)
        self.kwargs = kwargs


class Result(Module, name="builtin.Result"):
    """
    Wrapper for the result of a function.
    """

    is_tuple: bool

    def __init__(self, value: Any) -> None:
        self.is_tuple = isinstance(value, tuple)

        if self.is_tuple:
            value = list(value)

        self._value = value

    @property
    def value(self) -> Any:
        if self.is_tuple:
            return tuple(self._value)

        return self._value


@partial(jax.jit, static_argnames=("function",))
def _apply_filter_jit(module: Arguments, function: Any) -> Any:
    result = function(*module.args, **module.kwargs)
    return Result(result)


[docs] def jit(function: Callable[P, T]) -> Callable[P, T]: """ Just-in-time compiling functions. This is a drop-in replacement for jax.jit, that traces shared references between the different arguments. Using jax.jit with references shared between arguments will untie the references in the body of the function and the output. The function will not recompile if only the values inside of the jax.Arrays change. Parameters --- function: Callable Function to apply the jax.jit to. Returns --- compiled: Callable The compiled function. This function takes the same arguments as the original function. """ @wraps(function) def compiled(*args: P.args, **kwargs: P.kwargs) -> T: module = Arguments(args, kwargs) return _apply_filter_jit(module, function).value return compiled