# Source code for seli.net._einops
"""
Wrappers around the popular einops library.
"""
from typing import Literal
import einops
from seli.core._module import Module
from seli.core._typecheck import typecheck
# Public API of this module: the einops-backed classes defined below.
__all__ = ["Rearrange", "Reduce", "Repeat", "Einsum"]
@typecheck
class Rearrange(Module, name="net.Rearrange"):
    """
    Rearrange the input tensor according to the given pattern.
    See `einops.rearrange` for more information.

    Parameters
    ----------
    pattern: str
        The pattern to rearrange the input tensor.
    dims: int
        The dimensions to use in the pattern, forwarded to
        `einops.rearrange` as keyword axis lengths.
    """

    def __init__(self, pattern: str, **dims: int) -> None:
        self.pattern = pattern
        self.dims = dims

    def __call__(self, *args):
        """
        Apply ``self.pattern`` to the input(s).

        A single positional argument is passed to `einops.rearrange`
        unchanged. Several positional arguments are collected into a list,
        which einops accepts as a list of tensors of the same type and shape.

        Raises
        ------
        TypeError
            If no input tensor is given.
        """
        if not args:
            raise TypeError("Rearrange expects at least one input tensor.")
        # Forwarding ``*args`` directly would bind a second input to einops'
        # positional ``pattern`` parameter and collide with ``pattern=`` below.
        tensor = args[0] if len(args) == 1 else list(args)
        return einops.rearrange(tensor, pattern=self.pattern, **self.dims)
@typecheck
class Reduce(Module, name="net.Reduce"):
    """
    Reduce the input tensor according to the given pattern and the reduction
    type. See `einops.reduce` for more information.

    Parameters
    ----------
    pattern: str
        The pattern to reduce the input tensor.
    reduction: str
        The reduction type to use. Can be one of "sum", "mean", "max", "min",
        or "prod".
    dims: int
        The dimensions to use in the pattern, forwarded to `einops.reduce`
        as keyword axis lengths.
    """

    def __init__(
        self,
        pattern: str,
        reduction: Literal["sum", "mean", "max", "min", "prod"],
        **dims: int,
    ) -> None:
        self.pattern = pattern
        self.reduction = reduction
        self.dims = dims

    def __call__(self, *args):
        """
        Reduce the input(s) with ``self.pattern`` and ``self.reduction``.

        A single positional argument is passed to `einops.reduce`
        unchanged. Several positional arguments are collected into a list,
        which einops accepts as a list of tensors of the same type and shape.

        Raises
        ------
        TypeError
            If no input tensor is given.
        """
        if not args:
            raise TypeError("Reduce expects at least one input tensor.")
        # Forwarding ``*args`` directly would bind a second input to einops'
        # positional ``pattern`` parameter and collide with ``pattern=`` below.
        tensor = args[0] if len(args) == 1 else list(args)
        return einops.reduce(
            tensor,
            pattern=self.pattern,
            reduction=self.reduction,
            **self.dims,
        )
@typecheck
class Repeat(Module, name="net.Repeat"):
    """
    Repeat the input tensor according to the given pattern. See `einops.repeat`
    for more information.

    Parameters
    ----------
    pattern: str
        The pattern to repeat the input tensor.
    dims: int
        The dimensions to use in the pattern, forwarded to `einops.repeat`
        as keyword axis lengths.
    """

    def __init__(self, pattern: str, **dims: int) -> None:
        self.pattern = pattern
        self.dims = dims

    def __call__(self, *args):
        """
        Apply ``self.pattern`` to the input(s).

        A single positional argument is passed to `einops.repeat`
        unchanged. Several positional arguments are collected into a list,
        which einops accepts as a list of tensors of the same type and shape.

        Raises
        ------
        TypeError
            If no input tensor is given.
        """
        if not args:
            raise TypeError("Repeat expects at least one input tensor.")
        # Forwarding ``*args`` directly would bind a second input to einops'
        # positional ``pattern`` parameter and collide with ``pattern=`` below.
        tensor = args[0] if len(args) == 1 else list(args)
        return einops.repeat(tensor, pattern=self.pattern, **self.dims)
@typecheck
class Einsum(Module, name="net.Einsum"):
    """
    Wrapper around `einops.einsum`. Performs a contraction on the input tensors
    according to the given pattern.

    Parameters
    ----------
    pattern: str
        The pattern to contract the input tensors.
    """

    def __init__(self, pattern: str) -> None:
        self.pattern = pattern

    def __call__(self, *args):
        """
        Contract the given tensors according to ``self.pattern``.

        `einops.einsum` takes its pattern as the *last positional* argument
        and has no ``pattern`` keyword, so the pattern is appended after the
        input tensors rather than passed as ``pattern=``.
        """
        return einops.einsum(*args, self.pattern)