DotProductAttention

class seli.net.DotProductAttention(dim: int, heads_q: int, heads_k: int | None = None, *, norm: bool = False, tanh_cap: float | None = None, scale: float | None = None, is_causal: bool = False, key_value_seq_lengths: Array | None = None, implementation: Literal['xla', 'cudnn'] | None = None)

Bases: Module

Apply dot-product attention to a sequence.
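For reference, jax.nn.dot_product_attention, the function this page defers to for its keyword arguments, computes standard scaled dot-product attention. For one head, let Q ∈ ℝ^{T×d_h} hold the queries, K, V ∈ ℝ^{S×d_h} the keys and values, B an optional additive bias of shape T×S, and M an optional boolean mask of shape T×S (is_causal=True adds the condition j ≤ i). JAX replaces masked logits with a large negative constant, written here as −∞. This is the generic JAX formula; any additional steps controlled by options specific to this module, such as norm or tanh_cap, are not shown.

```latex
L = s\, Q K^{\top} + B, \qquad s = \frac{1}{\sqrt{d_h}} \quad \text{(default when \texttt{scale} is None)}

\tilde{L}_{ij} =
\begin{cases}
  L_{ij} & \text{if } M_{ij} \text{ is True} \\
  -\infty & \text{otherwise}
\end{cases}

\operatorname{Attention}(Q, K, V) = \operatorname{softmax}_{j}\big(\tilde{L}\big)\, V \in \mathbb{R}^{T \times d_h}
```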

Parameters:
  • dim (int) – The dimension of the final output.

  • heads_q (int) – The number of heads for the query.

  • heads_k (int | None) – The number of heads for the key. If None, defaults to heads_q.

  • scale, is_causal, key_value_seq_lengths, implementation – Keyword-only arguments with the same names and meaning as in jax.nn.dot_product_attention; see its documentation for details.

Attributes Summary

  • dim_in – Return the input dimension of the module.

Methods Summary

  • __call__(x[, bias, mask]) – Apply dot-product self-attention to x.

Attributes Documentation

dim_in

Return the input dimension of the module. If the module does not have a fixed input dimension yet, return None.

Methods Documentation

__call__(x: Float[Array, '*batch seq dim'], bias: Array | None = None, mask: Array | None = None) → Float[Array, '*batch seq dim']

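Apply dot-product self-attention to x, returning an array of the same shape (*batch, seq, dim). bias and mask are optional arrays applied to the attention logits; see jax.nn.dot_product_attention for their semantics.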