CrossAttention
- class seli.net.CrossAttention(dim: int, heads_q: int, heads_k: int | None = None, bias: Array | None = None, mask: Array | None = None, *, scale: float | None = None, is_causal: bool = False, key_value_seq_lengths: Array | None = None, implementation: Literal['xla', 'cudnn'] | None = None)
Bases: Module
Perform cross-attention between two sequences.
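For orientation, standard multi-head cross-attention with grouped key/value heads can be written as below. This is a sketch under assumptions the page does not state: queries are projected from the first sequence x, keys and values from the second sequence y, W_Q, W_K, W_V and W_O are learned projection matrices, d_h is the per-head dimension, Q_i denotes the i-th head slice of Q, and h_q, h_k stand for heads_q and heads_k. The grouping g(i) matches how jax.nn.dot_product_attention shares each key/value head across a contiguous block of query heads, and its scale argument, when given, replaces the default 1/sqrt(d_h). seli's actual parameterization may differ.

```latex
\begin{aligned}
Q &= x W_Q, \quad K = y W_K, \quad V = y W_V, \\
G &= h_q / h_k, \quad g(i) = \lfloor i / G \rfloor, \quad i = 0, \dots, h_q - 1, \\
\mathrm{head}_i &= \mathrm{softmax}\!\left( \frac{Q_i K_{g(i)}^{\top}}{\sqrt{d_h}} \right) V_{g(i)}, \\
\mathrm{out} &= \left[ \mathrm{head}_0, \dots, \mathrm{head}_{h_q - 1} \right] W_O .
\end{aligned}
```

When h_k = h_q, G = 1 and g(i) = i, which is ordinary multi-head attention.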
- Parameters:
dim (int) – The dimension of the final output.
heads_q (int) – The number of heads for the query.
heads_k (int | None) – The number of heads for the key. If None, defaults to heads_q.
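bias, mask, scale, is_causal, key_value_seq_lengths, implementation – Optional attention arguments with the same meaning as the identically named arguments of jax.nn.dot_product_attention; see its documentation for details.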
Attributes Summary
dim_in_x – Return the input dimension of the first sequence.
dim_in_y – Return the input dimension of the second sequence.
Methods Summary
__call__(x, y) – Call self as a function.
Attributes Documentation
- dim_in_x
Return the input dimension of the first sequence. If the module does not have a fixed input dimension yet, return None.
- dim_in_y
Return the input dimension of the second sequence. If the module does not have a fixed input dimension yet, return None.
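The "None until fixed" behavior of dim_in_x and dim_in_y can be pictured with the plain-Python sketch below. The class LazyInputDims and its build method are hypothetical; this page does not say when seli fixes these dimensions, so the sketch simply assumes they are read from the last axis of each input the first time the module sees them.

```python
from typing import Optional, Sequence


class LazyInputDims:
    """Hypothetical illustration of lazily fixed input dimensions; not seli's code."""

    def __init__(self) -> None:
        self._dim_in_x: Optional[int] = None
        self._dim_in_y: Optional[int] = None

    @property
    def dim_in_x(self) -> Optional[int]:
        # None until the first sequence's feature size is known.
        return self._dim_in_x

    @property
    def dim_in_y(self) -> Optional[int]:
        # None until the second sequence's feature size is known.
        return self._dim_in_y

    def build(self, x_shape: Sequence[int], y_shape: Sequence[int]) -> None:
        # Assumed rule: fix each dimension from the last axis, on first use only.
        if self._dim_in_x is None:
            self._dim_in_x = x_shape[-1]
        if self._dim_in_y is None:
            self._dim_in_y = y_shape[-1]


dims = LazyInputDims()
print(dims.dim_in_x, dims.dim_in_y)  # None None
dims.build((2, 5, 16), (2, 7, 24))
print(dims.dim_in_x, dims.dim_in_y)  # 16 24
```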
Methods Documentation