xFormers optimized operators¶
Memory-efficient attention¶
- xformers.ops.memory_efficient_attention(query: Tensor, key: Tensor, value: Tensor, attn_bias: Optional[Union[Tensor, AttentionBias]] = None, p: float = 0.0, scale: Optional[float] = None, *, op: Optional[Tuple[Optional[Type[AttentionFwOpBase]], Optional[Type[AttentionBwOpBase]]]] = None) Tensor [source]¶
Implements the memory-efficient attention mechanism following “Self-Attention Does Not Need O(n^2) Memory”.
- Inputs shape
Input tensors must be in format [B, M, H, K], where B is the batch size, M the sequence length, H the number of heads, and K the embedding size per head
If inputs have dimension 3, it is assumed that the dimensions are [B, M, K] and H=1 (see the sketch below)
Inputs can also be of dimension 5 with GQA - see note below
Inputs can be non-contiguous - we only require the last dimension’s stride to be 1
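For instance, a 3-dim [B, M, K] input behaves like the corresponding 4-dim [B, M, 1, K] input. A minimal sketch of this equivalence (assumes a CUDA device with fp16 support; sizes are illustrative):

import torch
import xformers.ops as xops

B, M, K = 2, 16, 64
q = torch.randn([B, M, K], device="cuda", dtype=torch.float16)  # 3-dim input: H is assumed to be 1
k = torch.randn([B, M, K], device="cuda", dtype=torch.float16)
v = torch.randn([B, M, K], device="cuda", dtype=torch.float16)

out_bmk = xops.memory_efficient_attention(q, k, v)  # [B, M, K]
out_bmhk = xops.memory_efficient_attention(
    q.unsqueeze(2), k.unsqueeze(2), v.unsqueeze(2)  # explicit [B, M, 1, K]
)
assert torch.allclose(out_bmk, out_bmhk.squeeze(2), atol=1e-3)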
- Equivalent pytorch code
scale = 1 / query.shape[-1] ** 0.5
query = query * scale
attn = query @ key.transpose(-2, -1)
if attn_bias is not None:
    attn = attn + attn_bias
attn = attn.softmax(-1)
attn = F.dropout(attn, p)
return attn @ value
- Examples
import xformers.ops as xops

# Compute regular attention
y = xops.memory_efficient_attention(q, k, v)

# With a dropout of 0.2
y = xops.memory_efficient_attention(q, k, v, p=0.2)

# Causal attention
y = xops.memory_efficient_attention(
    q, k, v,
    attn_bias=xops.LowerTriangularMask()
)
- Supported hardware
NVIDIA GPUs with compute capability above 6.0 (P100+), datatypes f16, bf16 and f32.
- EXPERIMENTAL: Using with Multi Query Attention (MQA) and Grouped Query Attention (GQA)
MQA/GQA is an experimental feature supported only for the forward pass. If you have 16 heads in query and 2 in key/value, you can provide 5-dim tensors in the [B, M, G, H, K] format, where G is the number of head groups (here 2) and H is the number of heads per group (8 in the example).
Please note that xFormers will not automatically broadcast the inputs, so you will need to broadcast them manually before calling memory_efficient_attention.
- GQA/MQA example
import torch
import xformers.ops as xops

B, M, K = 3, 32, 128
kwargs = dict(device="cuda", dtype=torch.float16)
q = torch.randn([B, M, 8, K], **kwargs)
k = torch.randn([B, M, 2, K], **kwargs)
v = torch.randn([B, M, 2, K], **kwargs)
out_gqa = xops.memory_efficient_attention(
    q.reshape([B, M, 2, 4, K]),
    k.reshape([B, M, 2, 1, K]).expand([B, M, 2, 4, K]),
    v.reshape([B, M, 2, 1, K]).expand([B, M, 2, 4, K]),
)
- Raises
NotImplementedError – if there is no operator available to compute the MHA
ValueError – if inputs are invalid
- Parameters
query – Tensor of shape [B, Mq, H, K]
key – Tensor of shape [B, Mkv, H, K]
value – Tensor of shape [B, Mkv, H, Kv]
attn_bias – Bias to apply to the attention matrix - defaults to no masking. For common biases implemented efficiently in xFormers, see xformers.ops.fmha.attn_bias.AttentionBias. This can also be a torch.Tensor for an arbitrary mask (slower).
p – Dropout probability. Disabled if set to 0.0
scale – Scaling factor for Q @ K.transpose(). If set to None, the default scale (q.shape[-1]**-0.5) will be used.
op – The operators to use - see xformers.ops.AttentionOpBase. If set to None (recommended), xFormers will dispatch to the best available operator, depending on the inputs and options.
- Returns
multi-head attention Tensor with shape [B, Mq, H, Kv]
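As an illustration of the op argument above, the sketch below pins the CUTLASS forward/backward pair instead of relying on automatic dispatch (assumes a CUDA device where these operators are available; sizes are illustrative):

import torch
import xformers.ops as xops

q = torch.randn([2, 16, 4, 64], device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Recommended: let xFormers dispatch to the best available operator
y_auto = xops.memory_efficient_attention(q, k, v)

# Explicitly select the CUTLASS forward and backward kernels
y_cutlass = xops.memory_efficient_attention(
    q, k, v,
    op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp),
)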
- class xformers.ops.AttentionBias[source]¶
Bases:
object
- Base class for a custom bias that can be applied as the attn_bias argument in xformers.ops.memory_efficient_attention.
That function has the ability to add a tensor, the attention bias, to the QK^T matrix before it is used in the softmax part of the attention calculation. The attention bias tensor with shape (B or 1, n_queries, n_keys) can be given as the attn_bias input. The most common use case is for the attention bias to contain only zeros and negative infinities, which forms a mask so that some queries only attend to some keys.
Children of this class define alternative things which can be used as the attn_bias input to define an attention bias which forms such a mask, for some common cases.
When using an xformers.ops.AttentionBias instead of a torch.Tensor, the mask matrix does not need to be materialized, and can be hardcoded into some kernels for better performance.
- class xformers.ops.AttentionOpBase[source]¶
Bases:
BaseOperator
Base class for any attention operator in xFormers
Available implementations¶
- class xformers.ops.fmha.cutlass.FwOp[source]¶
xFormers’ MHA kernel based on CUTLASS. Supports a large number of settings (including without TensorCores, f32 …) and GPUs as old as P100 (Sm60)
- class xformers.ops.fmha.cutlass.BwOp[source]¶
xFormers’ MHA kernel based on CUTLASS. Supports a large number of settings (including without TensorCores, f32 …) and GPUs as old as P100 (Sm60)
- class xformers.ops.fmha.flash.FwOp[source]¶
Operator that computes memory-efficient attention using Flash-Attention implementation.
- class xformers.ops.fmha.flash.BwOp[source]¶
Operator that computes memory-efficient attention using Flash-Attention implementation.
- class xformers.ops.fmha.triton.FwOp[source]¶
Operator that computes memory-efficient attention using Tri Dao’s implementation, based on Phil Tillet’s code
- class xformers.ops.fmha.triton.BwOp[source]¶
Operator that computes memory-efficient attention using Tri Dao’s implementation, based on Phil Tillet’s code
- class xformers.ops.fmha.small_k.FwOp[source]¶
An operator optimized for very small values of K (K <= 32) and f32 pre-Ampere as it does not use TensorCores. Only supports contiguous inputs in BMK format, so an extra reshape or contiguous call might be done.
- Deprecated
This operator is deprecated and should not be used in new code
- class xformers.ops.fmha.small_k.BwOp[source]¶
An operator optimized for very small values of K (K <= 32) and f32 pre-Ampere as it does not use TensorCores. Only supports contiguous inputs in BMK format, so an extra reshape or contiguous call might be done.
- Deprecated
This operator is deprecated and should not be used in new code
Attention biases¶
- class xformers.ops.fmha.attn_bias.AttentionBias[source]¶
Bases:
object
- Base class for a custom bias that can be applied as the attn_bias argument in xformers.ops.memory_efficient_attention.
That function has the ability to add a tensor, the attention bias, to the QK^T matrix before it is used in the softmax part of the attention calculation. The attention bias tensor with shape (B or 1, n_queries, n_keys) can be given as the attn_bias input. The most common use case is for the attention bias to contain only zeros and negative infinities, which forms a mask so that some queries only attend to some keys.
Children of this class define alternative things which can be used as the attn_bias input to define an attention bias which forms such a mask, for some common cases.
When using an xformers.ops.AttentionBias instead of a torch.Tensor, the mask matrix does not need to be materialized, and can be hardcoded into some kernels for better performance.
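To make the zeros / negative-infinities picture concrete, the sketch below compares xformers.ops.LowerTriangularMask against the same mask written out explicitly as an additive tensor in plain PyTorch (assumes a CUDA device with fp16 support; the tolerance is illustrative):

import torch
import xformers.ops as xops

B, M, H, K = 2, 8, 1, 64
q = torch.randn([B, M, H, K], device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Predefined bias: the causal mask is never materialized as a tensor
out = xops.memory_efficient_attention(q, k, v, attn_bias=xops.LowerTriangularMask())

# Reference: the same mask written out as zeros and -inf, applied in plain PyTorch
bias = torch.full([M, M], float("-inf"), device="cuda", dtype=torch.float16).triu(1)
qr, kr, vr = (t.transpose(1, 2) for t in (q, k, v))  # [B, H, M, K]
attn = (qr * K ** -0.5) @ kr.transpose(-2, -1) + bias
ref = (attn.float().softmax(-1).to(vr.dtype) @ vr).transpose(1, 2)
assert torch.allclose(out, ref, atol=1e-2)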
- class xformers.ops.fmha.attn_bias.LowerTriangularMask(*tensor_args, **tensor_kwargs)[source]¶
Bases:
AttentionBias
A lower-triangular (aka causal) mask
A query Q cannot attend to a key which is farther from the initial key than Q is from the initial query.
- class xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias(bias: Tensor)[source]¶
Bases:
LowerTriangularMask
A lower-triangular (aka causal) mask with an additive bias
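A minimal usage sketch, assuming the dispatched operator supports tensor biases (the [B, H, Mq, Mkv] bias shape shown here, and its dtype/alignment requirements, depend on the operator; sizes are illustrative):

import torch
from xformers.ops import fmha

B, M, H, K = 2, 16, 4, 64
q = torch.randn([B, M, H, K], device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# An arbitrary additive bias per head (e.g. relative-position scores), shape [B, H, Mq, Mkv]
bias = torch.randn([B, H, M, M], device="cuda", dtype=torch.float16)

# Causal masking is applied on top of the additive bias
attn_bias = fmha.attn_bias.LowerTriangularMaskWithTensorBias(bias)
out = fmha.memory_efficient_attention(q, k, v, attn_bias=attn_bias)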
- class xformers.ops.fmha.attn_bias.BlockDiagonalMask(q_seqinfo: _SeqLenInfo, k_seqinfo: _SeqLenInfo, _batch_sizes: Optional[Sequence[int]] = None)[source]¶
Bases:
AttentionBias
A block-diagonal mask that can be passed as attn_bias argument to xformers.ops.memory_efficient_attention.
Queries and Keys are each divided into the same number of blocks. Queries in block i only attend to keys in block i.
- Example
import torch
from xformers.ops import fmha

K = 16
dtype = torch.float16
device = "cuda"
list_x = [
    torch.randn([1, 3, 1, K], dtype=dtype, device=device),
    torch.randn([1, 6, 1, K], dtype=dtype, device=device),
    torch.randn([1, 2, 1, K], dtype=dtype, device=device),
]
attn_bias, x = fmha.BlockDiagonalMask.from_tensor_list(list_x)

linear = torch.nn.Linear(K, K * 3).to(device=device, dtype=dtype)
q, k, v = linear(x).reshape([1, -1, 1, 3, K]).unbind(-2)
out = fmha.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
list_out = attn_bias.split(out)
print(list_out[0].shape)  # [1, 3, 1, K]
assert tuple(list_out[0].shape) == (1, 3, 1, K)
- materialize(shape: Tuple[int, ...], dtype: dtype = torch.float32, device: Union[str, device] = 'cpu') Tensor [source]¶
Materialize the attention bias - for debugging & testing
- classmethod from_seqlens(q_seqlen: Sequence[int], kv_seqlen: Optional[Sequence[int]] = None) BlockDiagonalMask [source]¶
Creates a BlockDiagonalMask from a list of tensor lengths for query and key/value.
- Parameters
q_seqlen (Union[Sequence[int], torch.Tensor]) – List or tensor of sequence lengths for query tensors
kv_seqlen (Union[Sequence[int], torch.Tensor], optional) – List or tensor of sequence lengths for key/value. (Defaults to q_seqlen.)
- Returns
BlockDiagonalMask
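A minimal usage sketch building the mask from explicit lengths rather than from a list of tensors (assumes a CUDA device with fp16 support; the lengths are illustrative):

import torch
from xformers.ops import fmha

# Two "documents": 3 and 5 query tokens, attending to 4 and 7 key/value tokens
attn_bias = fmha.BlockDiagonalMask.from_seqlens([3, 5], kv_seqlen=[4, 7])

K = 16
q = torch.randn([1, 3 + 5, 1, K], device="cuda", dtype=torch.float16)
k = torch.randn([1, 4 + 7, 1, K], device="cuda", dtype=torch.float16)
v = torch.randn_like(k)
out = fmha.memory_efficient_attention(q, k, v, attn_bias=attn_bias)  # [1, 8, 1, K]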
- classmethod from_tensor_list(tensors: Sequence[Tensor]) Tuple[BlockDiagonalMask, Tensor] [source]¶
Creates a BlockDiagonalMask from a list of tensors, and returns the tensors concatenated on the sequence length dimension
- Parameters
tensors (Sequence[torch.Tensor]) – A list of tensors of shape [B, M_i, *]. All tensors should have the same dimension and the same batch size B, but they can have different sequence length M.
- Returns
Tuple[BlockDiagonalMask, torch.Tensor] – The corresponding bias for the attention along with tensors concatenated on the sequence length dimension, with shape [1, sum_i{M_i}, *]
- split(tensor: Tensor) Sequence[Tensor] [source]¶
The inverse operation of BlockDiagonalCausalMask.from_tensor_list
- Parameters
tensor (torch.Tensor) – Tensor of tokens of shape [1, sum_i{M_i}, *]
- Returns
Sequence[torch.Tensor] – A list of tokens with possibly different sequence lengths
- make_causal() BlockDiagonalCausalMask [source]¶
Makes each block causal
- make_causal_from_bottomright() BlockDiagonalCausalFromBottomRightMask [source]¶
Makes each block causal with a possible non-causal prefix
- make_local_attention(window_size: int) BlockDiagonalCausalLocalAttentionMask [source]¶
Experimental: Makes each block causal with local attention
- make_local_attention_from_bottomright(window_size: int) BlockDiagonalCausalLocalAttentionFromBottomRightMask [source]¶
Experimental: Makes each block causal with local attention, start from bottom right
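The make_* helpers above derive causal and local-attention variants from an existing block-diagonal mask. A minimal sketch (the window size is illustrative):

from xformers.ops import fmha

# Block-diagonal mask over two sequences of 4 and 6 tokens
bias = fmha.BlockDiagonalMask.from_seqlens([4, 6])

causal_bias = bias.make_causal()                        # each block becomes causal
local_bias = bias.make_local_attention(window_size=2)   # experimental: causal + local window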
- class xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask(q_seqinfo: _SeqLenInfo, k_seqinfo: _SeqLenInfo, _batch_sizes: Optional[Sequence[int]] = None)[source]¶
Bases:
BlockDiagonalMask
Same as xformers.ops.fmha.attn_bias.BlockDiagonalMask, except that each block is causal.
Queries and Keys are each divided into the same number of blocks. A query Q in block i cannot attend to a key which is not in block i, nor one which is farther from the initial key in block i than Q is from the initial query in block i.
- class xformers.ops.fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask(q_seqinfo: _SeqLenInfo, k_seqinfo: _SeqLenInfo, _batch_sizes: Optional[Sequence[int]] = None)[source]¶
Bases:
BlockDiagonalMask
Same as xformers.ops.fmha.attn_bias.BlockDiagonalMask, except that each block is causal. This mask allows for a non-causal prefix.
NOTE: Each block should have num_keys >= num_queries, otherwise the forward pass is not defined (softmax of a vector of -inf in the attention).
Queries and keys are each divided into the same number of blocks. A query Q in block i cannot attend to a key which is not in block i, nor one which is nearer the final key in block i than Q is to the final query in block i.
- class xformers.ops.fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask(q_seqinfo: _SeqLenInfo, k_seqinfo: _PaddedSeqLenInfo, causal_diagonal: Optional[Any] = None)[source]¶
Bases:
AttentionBias
Same as xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask, except an offset on causality is allowed for each block and we support padding for k/v.
The keys and values are divided into blocks which are padded out to the same total length. For example, if there is space for 12 keys, for three blocks of max length 4, but we only want to use the first 2, 3 and 2 of each block, use kv_padding=4 and kv_seqlens=[2, 3, 2]. The queries are divided into blocks, without padding, of lengths given by q_seqlen.
A query Q in block i cannot attend to a key which is not in block i, nor one which is not in use (i.e. in the padded area), nor one which is nearer to the final key in block i than Q is to the final query in block i.
- materialize(shape: Tuple[int, ...], dtype: dtype = torch.float32, device: Union[str, device] = 'cpu') Tensor [source]¶
Materialize the attention bias - for debugging & testing
- classmethod from_seqlens(q_seqlen: Sequence[int], kv_padding: int, kv_seqlen: Sequence[int], causal_diagonal: Optional[Any] = None) BlockDiagonalCausalWithOffsetPaddedKeysMask [source]¶
Creates a BlockDiagonalCausalWithOffsetPaddedKeysMask from a list of tensor lengths for query and key/value.
- Parameters
q_seqlen (Sequence[int]) – List of sequence lengths for the query tensors
kv_padding (int) – Padded length of each key/value block
kv_seqlen (Sequence[int]) – List of key/value lengths actually in use within each padded block
- Returns
BlockDiagonalCausalWithOffsetPaddedKeysMask
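Following the kv_padding / kv_seqlens example in the class description above, a minimal usage sketch (forward pass only; assumes a CUDA device with fp16 support; sizes are illustrative):

import torch
from xformers.ops import fmha

# Three blocks of 1 query each; key/value blocks padded to 4 slots,
# of which only the first 2, 3 and 2 are in use
attn_bias = fmha.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
    q_seqlen=[1, 1, 1], kv_padding=4, kv_seqlen=[2, 3, 2]
)

K = 64
q = torch.randn([1, 3, 1, K], device="cuda", dtype=torch.float16)   # 1+1+1 queries
k = torch.randn([1, 12, 1, K], device="cuda", dtype=torch.float16)  # 3 blocks * 4 padded slots
v = torch.randn_like(k)
out = fmha.memory_efficient_attention(q, k, v, attn_bias=attn_bias)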
- class xformers.ops.fmha.attn_bias.BlockDiagonalCausalLocalAttentionMask(q_seqinfo: _SeqLenInfo, k_seqinfo: _SeqLenInfo, _batch_sizes: Optional[Sequence[int]] = None, _window_size: int = 0)[source]¶
Bases:
BlockDiagonalCausalMask
(Experimental feature) Same as xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask. This makes the mask “local” and the attention pattern banded.
Query i only attends to keys in its block and cannot attend keys further than “window_size” from it.
- class xformers.ops.fmha.attn_bias.BlockDiagonalCausalLocalAttentionFromBottomRightMask(q_seqinfo: _SeqLenInfo, k_seqinfo: _SeqLenInfo, _batch_sizes: Optional[Sequence[int]] = None, _window_size: int = 0)[source]¶
Bases:
BlockDiagonalCausalFromBottomRightMask
(Experimental feature) Same as xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask. This makes the mask “local” and the attention pattern banded.
Query i only attends to keys in its block and cannot attend keys further than “window_size” from it.
Non-autograd implementations¶
- xformers.ops.fmha.memory_efficient_attention_backward(grad: Tensor, output: Tensor, lse: Tensor, query: Tensor, key: Tensor, value: Tensor, attn_bias: Optional[Union[Tensor, AttentionBias]] = None, p: float = 0.0, scale: Optional[float] = None, *, op: Optional[Type[AttentionBwOpBase]] = None) Tuple[Tensor, Tensor, Tensor] [source]¶
Computes the gradient of the attention. Returns a tuple (dq, dk, dv). See xformers.ops.memory_efficient_attention for an explanation of the arguments. lse is the tensor returned by xformers.ops.memory_efficient_attention_forward_requires_grad
- xformers.ops.fmha.memory_efficient_attention_forward(query: Tensor, key: Tensor, value: Tensor, attn_bias: Optional[Union[Tensor, AttentionBias]] = None, p: float = 0.0, scale: Optional[float] = None, *, op: Optional[Type[AttentionFwOpBase]] = None) Tensor [source]¶
Calculates the forward pass of xformers.ops.memory_efficient_attention.
- xformers.ops.fmha.memory_efficient_attention_forward_requires_grad(query: Tensor, key: Tensor, value: Tensor, attn_bias: Optional[Union[Tensor, AttentionBias]] = None, p: float = 0.0, scale: Optional[float] = None, *, op: Optional[Type[AttentionFwOpBase]] = None) Tuple[Tensor, Tensor] [source]¶
Returns a tuple (output, lse), where lse can be used to compute the backward pass later. See xformers.ops.memory_efficient_attention for an explanation of the arguments. See xformers.ops.memory_efficient_attention_backward for running the backward pass.
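Putting these together, a minimal sketch of running the forward and backward passes manually, outside of autograd (assumes a CUDA device with fp16 support; the output gradient is a stand-in for a real loss gradient):

import torch
import xformers.ops as xops

q = torch.randn([2, 16, 4, 64], device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Forward pass that also returns the logsumexp (lse) needed for the backward pass
out, lse = xops.memory_efficient_attention_forward_requires_grad(q, k, v)

# Later, given the gradient of some loss with respect to `out`,
# compute the gradients of the inputs without using autograd
grad_out = torch.randn_like(out)
dq, dk, dv = xops.memory_efficient_attention_backward(grad_out, out, lse, q, k, v)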