CrossAttention

class seli.net.CrossAttention(dim: int, heads_q: int, heads_k: int | None = None, bias: Array | None = None, mask: Array | None = None, *, scale: float | None = None, is_causal: bool = False, key_value_seq_lengths: Array | None = None, implementation: Literal['xla', 'cudnn'] | None = None)

Bases: Module

Perform cross-attention between two sequences.

Parameters:
  • dim (int) – The dimension of the final output.

  • heads_q (int) – The number of heads for the query.

  • heads_k (int | None) – The number of heads for the key. If None, defaults to heads_q.

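  • kwargs (Any) – The remaining arguments in the signature (bias, mask, scale, is_causal, key_value_seq_lengths, implementation) correspond to those of jax.nn.dot_product_attention; see its documentation for their meaning.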

Attributes Summary

dim_in_x

Return the input dimension of the first sequence.

dim_in_y

Return the input dimension of the second sequence.

Methods Summary

__call__(x, y)

Apply cross-attention to the sequences x and y.

Attributes Documentation

dim_in_x

Return the input dimension of the first sequence. If the module does not have a fixed input dimension yet, return None.

dim_in_y

Return the input dimension of the second sequence. If the module does not have a fixed input dimension yet, return None.

Methods Documentation

__call__(x: Float[Array, '*batch seq dim_seq'], y: Float[Array, '*batch aux dim_aux']) → Float[Array, '*batch seq dim_seq']

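Apply cross-attention to the sequences x and y. Per the type annotations, x has shape (*batch, seq, dim_seq), y has shape (*batch, aux, dim_aux), and the result is annotated with the same shape as x, so the output has one position per element of x.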