DotProductAttention
- class seli.net.DotProductAttention(dim: int, heads_q: int, heads_k: int | None = None, *, norm: bool = False, tanh_cap: float | None = None, scale: float | None = None, is_causal: bool = False, key_value_seq_lengths: Array | None = None, implementation: Literal['xla', 'cudnn'] | None = None)
Bases: Module
Apply dot-product attention to a sequence.
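As a rough illustration of what dot-product attention computes, the sketch below implements single-head scaled dot-product attention directly in jax.numpy. It is not the seli implementation; the helper name reference_attention and the array shapes are assumptions chosen for the example.

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v, is_causal=False):
    """Single-head scaled dot-product attention on (seq, head_dim) arrays.

    Hypothetical helper for illustration only; not part of seli.
    """
    head_dim = q.shape[-1]
    # Similarity logits between every query and every key, scaled by 1/sqrt(d).
    logits = (q @ k.T) / jnp.sqrt(head_dim)
    if is_causal:
        # Position i may only attend to positions j <= i.
        allowed = jnp.tril(jnp.ones(logits.shape, dtype=bool))
        logits = jnp.where(allowed, logits, -jnp.inf)
    # Normalise the logits over the key axis and mix the values.
    weights = jax.nn.softmax(logits, axis=-1)
    return weights @ v


key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (5, 8))
k = jax.random.normal(key_k, (5, 8))
v = jax.random.normal(key_v, (5, 8))

out = reference_attention(q, k, v, is_causal=True)
print(out.shape)
```

With the causal mask enabled, the first position can only attend to itself, so its output row equals the first value row.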
- Parameters:
dim (int) – The dimension of the final output.
heads_q (int) – The number of heads for the query.
heads_k (int | None) – The number of heads for the key. If None, defaults to heads_q.
kwargs (Any) – For the full list of keyword arguments, see jax.nn.dot_product_attention.
Attributes Summary
dim_in – Return the input dimension of the module.
Methods Summary
__call__(x[, bias, mask]) – Call self as a function.
Attributes Documentation
- dim_in
Return the input dimension of the module. If the module does not have a fixed input dimension yet, return None.
Methods Documentation