API Reference

Flexible Flash Attention

To support irregularly shaped attention masks, we implemented a flexible_flash_attention kernel, which can be invoked through the following interface.

magi_attention.functional.flex_flash_attn.flex_flash_attn_func()

An interface similar to FlashAttention that requires no distributed environment and no dispatch or undispatch. It directly calls magi_attn_kernel to compute the attention output and lse, which is faster when you don’t need context parallelism.

Parameters:
  • q (torch.Tensor) – Query tensor.

  • k (torch.Tensor) – Key tensor.

  • v (torch.Tensor) – Value tensor.

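  • q_ranges (torch.Tensor) – Query ranges tensor that, together with k_ranges, represents the attn mask. Each row is a half-open [start, end) range of query tokens.

  • k_ranges (torch.Tensor) – Key ranges tensor that, together with q_ranges, represents the attn mask. Row i gives the half-open [start, end) range of keys attended by the i-th q_range.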

  • attn_type_map (torch.Tensor, optional) –

    Attention type map tensor with dtype=torch.int32. Defaults to None, which applies full attention to all ranges. The values specify the attention type for each (q_range, k_range) pair:

    • 0: full attention

    • 1: causal attention

    • 2: inverse causal attention

    • 3: bidirectional causal attention

    More information about the attention type map can be found in the Note below.

  • sink (torch.Tensor, optional) – Learnable sink token tensor. Defaults to None, in which case no attention sink is applied.

  • softmax_scale (float, optional) – Softmax scale. Defaults to None, which uses 1/sqrt(head_dim).

  • softcap (float, optional) – Softcap value applied to the attention logits. Defaults to 0.0, which disables softcapping.

  • deterministic (bool, optional) – Whether to use deterministic attention. Defaults to False.

  • sm_margin (int, optional) – The number of SMs to reserve for other work, useful when overlapping with other kernels such as communication kernels. Defaults to 0, which uses all available SMs.

  • disable_fwd_atomic_reduction (bool, optional) –

    Whether to disable forward atomic reduction. Defaults to False.

    If you can guarantee that q_ranges is non-overlapping, set this to True for better performance. q_ranges is overlapping if any two ranges in it have a non-empty intersection. For example, q_ranges = [[0, 15], [10, 20], [20, 30]] is overlapping because [0, 15] and [10, 20] intersect, whereas q_ranges = [[0, 15], [15, 20], [20, 30]] is non-overlapping (ranges are half-open, so [0, 15] and [15, 20] only touch).

  • auto_range_merge (bool, optional) – Whether to automatically merge k_ranges for the same q_range. Defaults to False. Note: this flag is useful for sparse attention scenarios but is still under development.

  • ref_block_size (tuple[int, int] | None)
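To decide whether disable_fwd_atomic_reduction is safe, you can check the overlap condition on the host before calling the kernel. The sketch below is a minimal pure-Python illustration, assuming each q_range is a half-open [start, end) pair (which is why [0, 15] and [15, 20] count as non-overlapping). The helper name q_ranges_overlap is hypothetical and not part of magi_attention; for a torch.Tensor input, call .tolist() first.

```python
def q_ranges_overlap(q_ranges):
    """Return True if any two half-open [start, end) ranges intersect.

    Hypothetical helper for illustration only; not a magi_attention API.
    """
    # Empty ranges cannot intersect anything, so drop them.
    ranges = sorted((start, end) for start, end in q_ranges if start < end)
    # Once sorted by start, any overlapping pair implies an overlapping
    # neighbour pair, so comparing consecutive ranges is sufficient.
    return any(
        prev_end > next_start
        for (_, prev_end), (next_start, _) in zip(ranges, ranges[1:])
    )


print(q_ranges_overlap([[0, 15], [10, 20], [20, 30]]))  # True
print(q_ranges_overlap([[0, 15], [15, 20], [20, 30]]))  # False
```

Sorting first keeps the check at O(n log n) instead of comparing every pair.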

Returns:

  • out (torch.Tensor): Attention output tensor

  • lse (torch.Tensor): Log-sum-exp values with dtype=torch.float32.

Return type:

tuple[torch.Tensor, torch.Tensor]

Shape:
  • q: (num_tokens_q, num_heads_q, head_dim)

  • k: (num_tokens_kv, num_heads_kv, head_dim)

  • v: (num_tokens_kv, num_heads_kv, head_dim)

  • sink: (num_tokens_sink, num_heads_q)

  • q_ranges: (num_ranges, 2)

  • k_ranges: (num_ranges, 2)

  • attn_type_map: (num_ranges,)

  • out: (num_tokens_q, num_heads_q, head_dim)

  • lse: (num_tokens_q, num_heads_q)
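To make the roles of softmax_scale, softcap, out and lse concrete, here is a naive single-head reference in pure Python. It is an illustrative sketch, not the kernel: it takes an explicit boolean mask instead of q_ranges/k_ranges, ignores sink, assumes the FlashAttention-style softcap form (logits become softcap * tanh(logits / softcap) when softcap > 0), and uses its own convention that a fully masked query row yields a zero output and lse = -inf. Running it once per head gives out and lse with the shapes listed above.

```python
import math


def reference_attention(q, k, v, mask, softmax_scale=None, softcap=0.0):
    """Naive single-head attention returning (out, lse); illustration only.

    q: num_tokens_q x head_dim, k and v: num_tokens_kv x head_dim,
    mask: num_tokens_q x num_tokens_kv booleans (True = attend).
    """
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(len(q[0]))  # documented default
    out, lse = [], []
    for i, qi in enumerate(q):
        logits = {}
        for j, kj in enumerate(k):
            if not mask[i][j]:
                continue
            s = softmax_scale * sum(a * b for a, b in zip(qi, kj))
            if softcap > 0.0:
                s = softcap * math.tanh(s / softcap)  # assumed softcap form
            logits[j] = s
        if not logits:  # this sketch's convention for fully masked rows
            out.append([0.0] * len(v[0]))
            lse.append(float("-inf"))
            continue
        m = max(logits.values())
        row_lse = m + math.log(sum(math.exp(s - m) for s in logits.values()))
        out.append([
            sum(math.exp(s - row_lse) * v[j][d] for j, s in logits.items())
            for d in range(len(v[0]))
        ])
        lse.append(row_lse)  # natural-log log-sum-exp of the visible logits
    return out, lse


q = [[1.0, 0.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
print(reference_attention(q, k, v, [[False, True]]))  # ([[3.0, 4.0]], [0.0])
```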

Note

This note explains the semantics of the attention types in attn_type_map. In addition to the descriptions below, see our blog for a visual explanation: https://sandai-org.github.io/MagiAttention/blog/#flex-flash-attn

  1. Full attention:

    If seqlen_q = 5 and seqlen_k = 2:

    1 1
    1 1
    1 1
    1 1
    1 1
    

    If seqlen_q = 2 and seqlen_k = 5:

    1 1 1 1 1
    1 1 1 1 1
    

    If seqlen_q = 5 and seqlen_k = 5:

    1 1 1 1 1
    1 1 1 1 1
    1 1 1 1 1
    1 1 1 1 1
    1 1 1 1 1
    
  2. Causal attention (bottom-right aligned):

    If seqlen_q = 5 and seqlen_k = 2:

    0 0
    0 0
    0 0
    1 0
    1 1
    

    If seqlen_q = 2 and seqlen_k = 5:

    1 1 1 1 0
    1 1 1 1 1
    

    If seqlen_q = 5 and seqlen_k = 5:

    1 0 0 0 0
    1 1 0 0 0
    1 1 1 0 0
    1 1 1 1 0
    1 1 1 1 1
    
  3. Inverse causal attention (top-left aligned):

    If seqlen_q = 5 and seqlen_k = 2:

    1 1
    0 1
    0 0
    0 0
    0 0
    

    If seqlen_q = 2 and seqlen_k = 5:

    1 1 1 1 1
    0 1 1 1 1
    

    If seqlen_q = 5 and seqlen_k = 5:

    1 1 1 1 1
    0 1 1 1 1
    0 0 1 1 1
    0 0 0 1 1
    0 0 0 0 1
    
  4. Bidirectional causal attention (intersection of causal and inverse causal):

    This is the element-wise AND of causal and inverse causal masks.

    If seqlen_q = 5 and seqlen_k = 2:

    0 0
    0 0
    0 0
    0 0
    0 0
    

    If seqlen_q = 2 and seqlen_k = 5:

    1 1 1 1 0
    0 1 1 1 1
    

    If seqlen_q = 5 and seqlen_k = 5:

    1 0 0 0 0
    0 1 0 0 0
    0 0 1 0 0
    0 0 0 1 0
    0 0 0 0 1
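The four types can be written as per-slice predicates. The pure-Python sketch below (slice_mask and dense_mask are hypothetical helper names) builds each slice with the alignment rules above (causal is bottom-right aligned, inverse causal is top-left aligned, bidirectional causal is their AND) and then places the slices into a global mask from q_ranges, k_ranges and attn_type_map. It assumes half-open [start, end) ranges and that no two slices cover the same (query, key) cell. It reproduces the 5×2, 2×5 and 5×5 tables in this note.

```python
# attn_type_map codes as documented for flex_flash_attn_func
FULL, CAUSAL, INV_CAUSAL, BI_CAUSAL = 0, 1, 2, 3


def slice_mask(seqlen_q, seqlen_k, attn_type):
    """0/1 mask of shape (seqlen_q, seqlen_k) for one (q_range, k_range) slice."""
    shift = seqlen_k - seqlen_q  # offset for bottom-right alignment

    def visible(i, j):
        causal = j <= i + shift  # causal: bottom-right aligned
        inv_causal = j >= i  # inverse causal: top-left aligned
        if attn_type == FULL:
            return True
        if attn_type == CAUSAL:
            return causal
        if attn_type == INV_CAUSAL:
            return inv_causal
        if attn_type == BI_CAUSAL:
            return causal and inv_causal  # element-wise AND of the two
        raise ValueError(f"unknown attn_type: {attn_type}")

    return [[int(visible(i, j)) for j in range(seqlen_k)] for i in range(seqlen_q)]


def dense_mask(q_ranges, k_ranges, attn_type_map, total_q, total_k):
    """Place every slice into a (total_q, total_k) mask; ranges are [start, end)."""
    mask = [[0] * total_k for _ in range(total_q)]
    for (qs, qe), (ks, ke), t in zip(q_ranges, k_ranges, attn_type_map):
        sub = slice_mask(qe - qs, ke - ks, t)
        for i, row in enumerate(sub):
            for j, bit in enumerate(row):
                if bit:
                    mask[qs + i][ks + j] = 1
    return mask


print(slice_mask(2, 5, CAUSAL))  # [[1, 1, 1, 1, 0], [1, 1, 1, 1, 1]]
# Varlen causal mask for two sequences of length 3 (block-diagonal).
for row in dense_mask([[0, 3], [3, 6]], [[0, 3], [3, 6]], [CAUSAL, CAUSAL], 6, 6):
    print(row)
```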
    

Dispatch

Varlen Dispatch

If your mask is defined by cu_seqlens, such as a varlen full or varlen causal mask, we provide an interface modeled on FlashAttention’s varlen API so that you can get started quickly. Calling magi_attn_varlen_dispatch returns the dispatched x together with its key.

magi_attention.api.magi_attn_interface.magi_attn_varlen_dispatch()

This is a flash-attn-varlen-like interface. It generates q_ranges, k_ranges and attn_mask_type from cu_seqlens_q, cu_seqlens_k, causal and window_size, computes the DistAttnRuntimeKey, creates the corresponding inner DistAttnRuntimeMgr, and finally pads and dispatches the input tensor into a local tensor.

Parameters:
  • x (torch.Tensor) – input tensor

  • cu_seqlens_q (torch.Tensor) – Cumulative sequence lengths for queries.

  • cu_seqlens_k (torch.Tensor) – Cumulative sequence lengths for keys.

  • pad_size (int) – the size to pad along seq_dim. The padded seqlen must be divisible by chunk_size * cp_size.

  • chunk_size (int) – the chunk size used to split the input tensor x along the seqlen dim for dispatch; it controls the granularity of computation load balancing.

  • cp_group_or_mesh (dist.ProcessGroup | DeviceMesh) – process group or device mesh. NOTE: process groups currently support only the nccl backend, and device meshes must currently be 1D or 2D.

  • causal (bool, optional) – if True, all mask types are set to CAUSAL; otherwise, the mask types are determined by window_size. Defaults to False.

  • window_size (tuple[int, int], optional) – the sliding window size [window_size_left, window_size_right]. It takes effect only when causal is False; when causal is True, it must be (-1, -1). Defaults to (-1, -1).

  • dist_attn_config (DistAttnConfig) – dist attn config.
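The divisibility rule for pad_size is plain modular arithmetic. A minimal sketch, assuming pad_size should be the smallest non-negative padding that makes the sequence length a multiple of chunk_size * cp_size. pad_size_for is a hypothetical helper; in practice you would call compute_pad_size from magi_attention.api.functools, as the example does.

```python
def pad_size_for(seqlen: int, cp_size: int, chunk_size: int) -> int:
    """Smallest pad so that (seqlen + pad) % (chunk_size * cp_size) == 0.

    Hypothetical helper that only mirrors the documented divisibility rule.
    """
    # Python's % always returns a non-negative result for a positive modulus.
    return (-seqlen) % (chunk_size * cp_size)


print(pad_size_for(4096, 4, 512))  # 0: 4096 is already a multiple of 2048
print(pad_size_for(5000, 4, 512))  # 1144: 6144 tokens = 12 chunks, 3 per rank
```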

Returns:

  • x (torch.Tensor): the local input tensor after padding and dispatch.

  • key (DistAttnRuntimeKey): the key pointing to the inner DistAttnRuntimeMgr.

Return type:

tuple[torch.Tensor, DistAttnRuntimeKey]

Example

>>> import torch
>>> import torch.distributed as dist
>>> from magi_attention.api import magi_attn_varlen_dispatch, undispatch, calc_attn
>>> from magi_attention.api.functools import compute_pad_size
>>> from magi_attention.config import (
...     DistAttnConfig,
...     DispatchConfig,
...     OverlapConfig,
...     MinHeapDispatchAlg,
...     UniformOverlapAlg
... )
>>> from magi_attention.common.enum import AttnOverlapMode
>>>
>>> # Generate a DistAttnRuntimeKey and dispatch the input for flash-attn-varlen style mask
>>> local_x, dist_attn_runtime_key = magi_attn_varlen_dispatch(
...     x=torch.randn(
...         4096,  # seqlen
...         2048,  # hidden_size
...         device="cuda",
...         dtype=torch.bfloat16,
...         requires_grad=True
...     ),
...     cu_seqlens_q=torch.tensor(
...         [0, 2048, 4096], dtype=torch.int32
...     ),
...     cu_seqlens_k=torch.tensor(
...         [0, 2048, 4096], dtype=torch.int32
...     ),
...     pad_size=compute_pad_size(4096, 4, 512),  # seqlen, cp_size, chunk_size
...     chunk_size=512,
...     cp_group_or_mesh=dist.new_group(list(range(4)), backend="nccl"),
...     causal=False,
...     window_size=(-1, -1),
...     dist_attn_config=DistAttnConfig(
...         dispatch_config=DispatchConfig(alg=MinHeapDispatchAlg()),
...         overlap_config=OverlapConfig(
...             enable=True,
...             mode=AttnOverlapMode.STATIC,
...             degree=2,
...             min_chunk_size=512,
...             max_num_chunks=64,
...             alg=UniformOverlapAlg(),
...         ),
...     ),
... )
>>>
>>> # Apply QKV projection
>>> local_q, local_k, local_v = q_project(local_x), k_project(local_x), v_project(local_x)
>>>
>>> # Calculate local attention
>>> local_out, _ = calc_attn(local_q, local_k, local_v, dist_attn_runtime_key)
>>>
>>> # Gather local attention outputs to total output if needed
>>> total_out = undispatch(local_out, dist_attn_runtime_key)
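The first step described above, turning cu_seqlens into ranges, can be sketched as follows for the full and causal cases. This is an illustration under assumptions, not the library’s code: varlen_to_ranges is a hypothetical helper, each sequence becomes one (q_range, k_range) pair with half-open ranges, and window_size handling is omitted. Tensors would need .tolist() first.

```python
def varlen_to_ranges(cu_seqlens_q, cu_seqlens_k, causal=False):
    """Map flash-attn-style cu_seqlens to (q_ranges, k_ranges, mask_types).

    Hypothetical sketch: sequence b covers [cu_seqlens[b], cu_seqlens[b + 1]).
    Sliding-window masks (window_size != (-1, -1)) are not handled here.
    """
    if len(cu_seqlens_q) != len(cu_seqlens_k):
        raise ValueError("cu_seqlens_q and cu_seqlens_k must have the same length")
    num_seqs = len(cu_seqlens_q) - 1
    q_ranges = [[cu_seqlens_q[b], cu_seqlens_q[b + 1]] for b in range(num_seqs)]
    k_ranges = [[cu_seqlens_k[b], cu_seqlens_k[b + 1]] for b in range(num_seqs)]
    mask_types = ["causal" if causal else "full"] * num_seqs
    return q_ranges, k_ranges, mask_types


print(varlen_to_ranges([0, 2048, 4096], [0, 2048, 4096]))
# ([[0, 2048], [2048, 4096]], [[0, 2048], [2048, 4096]], ['full', 'full'])
```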

magi_attn_varlen_dispatch works in two steps: it first calls magi_attn_varlen_key to compute a key, and then uses this key to dispatch the input x. magi_attn_varlen_key is described below.

magi_attention.api.magi_attn_interface.magi_attn_varlen_key()

This is a flash-attn-varlen-like interface. It generates q_ranges, k_ranges and attn_mask_type from cu_seqlens_q, cu_seqlens_k, causal and window_size, computes the DistAttnRuntimeKey, and creates the corresponding inner DistAttnRuntimeMgr.

Parameters:
  • cu_seqlens_q (torch.Tensor) – Cumulative sequence lengths for queries.

  • cu_seqlens_k (torch.Tensor) – Cumulative sequence lengths for keys.

  • pad_size (int) – the size to pad along seq_dim. The padded seqlen must be divisible by chunk_size * cp_size.

  • chunk_size (int) – the chunk size used to split the input tensor x along the seqlen dim for dispatch; it controls the granularity of computation load balancing.

  • cp_group_or_mesh (dist.ProcessGroup | DeviceMesh) – process group or device mesh. NOTE: process groups currently support only the nccl backend, and device meshes must currently be 1D or 2D.

  • causal (bool, optional) – if True, all mask types are set to CAUSAL; otherwise, the mask types are determined by window_size. Defaults to False.

  • window_size (tuple[int, int], optional) – the sliding window size [window_size_left, window_size_right]. It takes effect only when causal is False; when causal is True, it must be (-1, -1). Defaults to (-1, -1).

  • dist_attn_config (DistAttnConfig) – dist attn config.

Returns:

the key pointing to the inner DistAttnRuntimeMgr.

Return type:

DistAttnRuntimeKey

Example

>>> import torch
>>> import torch.distributed as dist
>>> from magi_attention.api import magi_attn_varlen_key, dispatch, undispatch, calc_attn
>>> from magi_attention.api.functools import compute_pad_size
>>> from magi_attention.config import (
...     DistAttnConfig,
...     DispatchConfig,
...     OverlapConfig,
...     MinHeapDispatchAlg,
...     UniformOverlapAlg
... )
>>> from magi_attention.common.enum import AttnOverlapMode
>>>
>>> # Generate a DistAttnRuntimeKey to dispatch for flash-attn-varlen style mask
>>> dist_attn_runtime_key = magi_attn_varlen_key(
...     cu_seqlens_q=torch.tensor(
...         [0, 2048, 4096], dtype=torch.int32
...     ),
...     cu_seqlens_k=torch.tensor(
...         [0, 2048, 4096], dtype=torch.int32
...     ),
...     pad_size=compute_pad_size(4096, 4, 512), # seqlen, cp_size, chunk_size
...     chunk_size=512,
...     cp_group_or_mesh=dist.new_group(list(range(4)), backend="nccl"),
...     causal=False,
...     window_size=(-1, -1),
...     dist_attn_config=DistAttnConfig(
...         dispatch_config=DispatchConfig(alg=MinHeapDispatchAlg()),
...         overlap_config=OverlapConfig(
...             enable=True,
...             mode=AttnOverlapMode.STATIC,
...             degree=2,
...             min_chunk_size=512,
...             max_num_chunks=64,
...             alg=UniformOverlapAlg(),
...         ),
...     ),
... )
>>>
>>> # Dispatch several tensors with the same key
>>> local_x, local_label, local_rope = [
...     dispatch(tensor, dist_attn_runtime_key)
...     for tensor in [total_x, total_label, total_rope]
... ]
>>>
>>> # Apply QKV projection
>>> local_q, local_k, local_v = q_project(local_x), k_project(local_x), v_project(local_x)
>>>
>>> # Calculate local attention
>>> local_out, _ = calc_attn(local_q, local_k, local_v, dist_attn_runtime_key)
>>>
>>> # Gather local attention outputs to total output if needed
>>> total_out = undispatch(local_out, dist_attn_runtime_key)

If you want to apply more than one mask within the same training pass, you can use make_varlen_key_for_new_mask_after_dispatch to create a key for the new mask from its flash-attn-varlen-style mask arguments and the existing key used for dispatch.

The new mask then reuses the dispatch solution of the mask used for dispatch, but with its own meta arguments for computation and communication.

magi_attention.api.magi_attn_interface.make_varlen_key_for_new_mask_after_dispatch()

Make a new dist attn runtime key for a new mask after dispatch, given the flash-attn-varlen-style arguments for the new mask and the key used for dispatch.

NOTE: this API is useful when your model adopts a hybrid-attn structure and applies more than one mask within the same training pass. In that case, only one of the masks can be chosen for dispatch, while the others reuse the same dispatch solution with different meta arguments for computation and communication.

WARNING: in this case, we cannot guarantee that every mask is load-balanced in computation and optimized in communication.

Parameters:
  • cu_seqlens_q (torch.Tensor) – Cumulative sequence lengths for queries.

  • cu_seqlens_k (torch.Tensor) – Cumulative sequence lengths for keys.

  • key_for_dispatch (DistAttnRuntimeKey) – the key used for dispatch

  • causal (bool, optional) – whether the varlen attention mask is causal. Defaults to False.

  • window_size (tuple[int, int], optional) – the sliding window size [window_size_left, window_size_right]. It takes effect only when causal is False; when causal is True, it must be (-1, -1). Defaults to (-1, -1).

  • dist_attn_config (DistAttnConfig, optional) –

    an optional new dist attn config.

    NOTE: if not provided, the config of key_for_dispatch is reused; if provided, the dispatch config within it is not applied to the new mask.

Returns:

the new dist attn runtime key for the new mask, sharing the same dispatch solution as key_for_dispatch.

Return type:

DistAttnRuntimeKey

Example

>>> import torch
>>> import torch.distributed as dist
>>> from magi_attention.api import magi_attn_varlen_key, dispatch, undispatch, calc_attn
>>> from magi_attention.api import make_varlen_key_for_new_mask_after_dispatch
>>> from magi_attention.api.functools import compute_pad_size
>>> from magi_attention.config import (
...     DistAttnConfig,
...     DispatchConfig,
...     OverlapConfig,
...     MinHeapDispatchAlg,
...     UniformOverlapAlg
... )
>>> from magi_attention.common.enum import AttnOverlapMode
>>>
>>> # Generate a DistAttnRuntimeKey to dispatch for flash-attn-varlen style mask
>>> # in the following case, we use a causal mask as the key for dispatch, thus it will consider
>>> # computation load-balance, communication optimization and computation-communication overlap
>>> # according to the causal mask pattern
>>> key_for_dispatch = magi_attn_varlen_key(
...     cu_seqlens_q=torch.tensor(
...         [0, 4096], dtype=torch.int32
...     ),
...     cu_seqlens_k=torch.tensor(
...         [0, 4096], dtype=torch.int32
...     ),
...     pad_size=compute_pad_size(4096, 4, 512), # seqlen, cp_size, chunk_size
...     chunk_size=512,
...     cp_group_or_mesh=dist.new_group(list(range(4)), backend="nccl"),
...     causal=True,
...     window_size=(-1, -1),
...     dist_attn_config=DistAttnConfig(
...         dispatch_config=DispatchConfig(alg=MinHeapDispatchAlg()),
...         overlap_config=OverlapConfig(
...             enable=True,
...             mode=AttnOverlapMode.STATIC,
...             degree=2,
...             min_chunk_size=512,
...             max_num_chunks=64,
...             alg=UniformOverlapAlg(),
...         ),
...     ),
... )
>>>
>>> # Dispatch several tensors with the same key_for_dispatch
>>> local_x, local_label, local_rope = [
...     dispatch(tensor, key_for_dispatch)
...     for tensor in [total_x, total_label, total_rope]
... ]
>>>
>>> # Make a new dist attn runtime key from key_for_dispatch
>>> # for a new mask, such as a sliding window causal mask below,
>>> # with the same dispatch solution as the causal mask used for dispatch,
>>> # i.e. this new key shares the same dispatch meta as key_for_dispatch
>>> # but it can handle the computation and communication of the new mask
>>> # and calculate attn correctly as well, though no optimization is applied for now
>>> new_key_for_swa_mask = make_varlen_key_for_new_mask_after_dispatch(
...     cu_seqlens_q=torch.tensor([0, 4096], dtype=torch.int32),
...     cu_seqlens_k=torch.tensor([0, 4096], dtype=torch.int32),
...     causal=False,
...     window_size=(512, 0), # sliding window causal mask
...     key_for_dispatch=key_for_dispatch,
... )
>>>
>>> # Apply QKV projection
>>> local_q, local_k, local_v = q_project(local_x), k_project(local_x), v_project(local_x)
>>>
>>> # Calculate local attention for the mask used to dispatch with key_for_dispatch
>>> local_out1, _ = calc_attn(local_q, local_k, local_v, key_for_dispatch)
>>>
>>> # Calculate local attention for the new swa mask with the new key
>>> # w/o undispatching back and dispatching again to avoid OOM
>>> local_out2, _ = calc_attn(local_q, local_k, local_v, new_key_for_swa_mask)
>>>
>>> # Gather local attention outputs to total output if needed
>>> total_out1 = undispatch(local_out1, key_for_dispatch)
>>> total_out2 = undispatch(local_out2, new_key_for_swa_mask)
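window_size is documented as [window_size_left, window_size_right]. Assuming it follows flash-attn’s sliding-window convention (an assumption based on the flash-attn-varlen-like design), query i attends to key j only if i + shift - left <= j <= i + shift + right, where shift = seqlen_k - seqlen_q, and a negative bound means unbounded on that side. The pure-Python sketch below uses a hypothetical helper, window_mask; a scaled-down analogue of the window_size=(512, 0) used above lets each query see itself plus at most two earlier keys.

```python
def window_mask(seqlen_q, seqlen_k, window_size=(-1, -1)):
    """0/1 sliding-window mask under the assumed flash-attn convention."""
    left, right = window_size
    shift = seqlen_k - seqlen_q  # bottom-right alignment
    mask = []
    for i in range(seqlen_q):
        lo = 0 if left < 0 else i + shift - left  # negative = unbounded
        hi = seqlen_k - 1 if right < 0 else i + shift + right
        mask.append([int(lo <= j <= hi) for j in range(seqlen_k)])
    return mask


# Scaled-down analogue of window_size=(512, 0): window (2, 0) on 5 tokens.
for row in window_mask(5, 5, (2, 0)):
    print(row)
```

With window_size=(-1, 0) the same function yields the bottom-right aligned causal mask, and (-1, -1) yields full attention.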

Flexible Dispatch

If your masks are not limited to varlen full or varlen causal but also include sliding window masks or other more diverse patterns, we recommend the following API. Calling magi_attn_flex_dispatch returns the dispatched x together with its key.

magi_attention.api.magi_attn_interface.magi_attn_flex_dispatch()

This is the most flexible interface: it takes q_ranges, k_ranges and attn_mask_type directly to compute the DistAttnRuntimeKey, creates the corresponding inner DistAttnRuntimeMgr, and finally pads and dispatches the input tensor into a local tensor.

Parameters:
  • x (torch.Tensor) – input tensor

  • q_ranges (AttnRanges) – the global query ranges

  • k_ranges (AttnRanges) – the global key ranges

  • attn_mask_type (str | AttnMaskType | list[str | AttnMaskType]) – the global attn mask type, either a single type applied to all ranges or a list with one type per (q_range, k_range) pair, given as str, AttnMaskType enum, or a mix of both

  • total_seqlen_q (int) – the total seqlen of query

  • total_seqlen_k (int) – the total seqlen of key

  • pad_size (int) – the size to pad along seq_dim. The padded seqlen must be divisible by chunk_size * cp_size.

  • chunk_size (int) – the chunk size used to split the input tensor x along the seqlen dim for dispatch; it controls the granularity of computation load balancing.

  • cp_group_or_mesh (dist.ProcessGroup | DeviceMesh) – process group or device mesh. NOTE: process groups currently support only the nccl backend, and device meshes must currently be 1D or 2D.

  • dist_attn_config (DistAttnConfig) – dist attn config

  • is_same_source (bool) – whether the query and key tensors come from the same source

  • is_q_permutable (bool) – whether the query tensor is permutable

  • is_k_permutable (bool) – whether the key tensor is permutable

Returns:

  • local_x (torch.Tensor): the local input x after padding and dispatch.

  • key (DistAttnRuntimeKey): the key pointing to the inner DistAttnRuntimeMgr.

Return type:

tuple[torch.Tensor, DistAttnRuntimeKey]

Note

  1. For decoder-only transformers (e.g., GPT), it applies ‘self-attn’ as follows:

    1. is_same_source is True.

    2. Both q and k are permutable, as long as they are permuted in the same way.

  2. For encoder-decoder transformers (e.g., T5), it applies ‘cross-attn’ as follows:

    1. is_same_source is False.

    2. q is permutable but k is not.

  3. For multi-modal transformers with external encoders, it applies ‘cross-attn’ as follows:

    1. is_same_source is False.

    2. q is not permutable due to self-attn, but k is permutable, even in a different way.

Example

>>> import torch
>>> import torch.distributed as dist
>>> from magi_attention.api import magi_attn_flex_dispatch, undispatch, calc_attn
>>> from magi_attention.api.functools import compute_pad_size
>>> from magi_attention.config import (
...     DistAttnConfig,
...     DispatchConfig,
...     OverlapConfig,
...     MinHeapDispatchAlg,
...     UniformOverlapAlg
... )
>>> from magi_attention.common.enum import AttnOverlapMode
>>> from magi_attention.common import AttnRanges
>>>
>>> # Generate a DistAttnRuntimeKey and dispatch the input for arbitrary mask represented by attn-slices
>>> local_x, dist_attn_runtime_key = magi_attn_flex_dispatch(
...     x=torch.randn(
...         4096,   # seqlen
...         2048,   # hidden_size
...         device="cuda",
...         dtype=torch.bfloat16,
...         requires_grad=True
...     ),
...     q_ranges=AttnRanges.from_ranges([[0, 2048], [2048, 4096]]),
...     k_ranges=AttnRanges.from_ranges([[0, 2048], [0, 4096]]),
...     attn_mask_type="full",
...     total_seqlen_q=4096,
...     total_seqlen_k=4096,
...     pad_size=compute_pad_size(4096, 4, 512),  # seqlen, cp_size, chunk_size
...     chunk_size=512,
...     cp_group_or_mesh=dist.new_group(list(range(4)), backend="nccl"),
...     dist_attn_config=DistAttnConfig(
...         dispatch_config=DispatchConfig(alg=MinHeapDispatchAlg()),
...         overlap_config=OverlapConfig(
...             enable=True,
...             mode=AttnOverlapMode.STATIC,
...             degree=2,
...             min_chunk_size=512,
...             max_num_chunks=64,
...             alg=UniformOverlapAlg(),
...         ),
...     ),
...     is_same_source=True,
...     is_q_permutable=True,
...     is_k_permutable=True,
... )
>>>
>>> # Apply QKV projection
>>> local_q, local_k, local_v = q_project(local_x), k_project(local_x), v_project(local_x)
>>>
>>> # Calculate local attention
>>> local_out, _ = calc_attn(local_q, local_k, local_v, dist_attn_runtime_key)
>>>
>>> # Gather local attention outputs to total output if needed
>>> total_out = undispatch(local_out, dist_attn_runtime_key)

Like magi_attn_varlen_dispatch, magi_attn_flex_dispatch first calls magi_attn_flex_key to obtain a key and then uses this key to dispatch x. magi_attn_flex_key is described below.

magi_attention.api.magi_attn_interface.magi_attn_flex_key()

This is the most flexible interface: it takes q_ranges, k_ranges and attn_mask_type directly to compute the DistAttnRuntimeKey and creates the corresponding inner DistAttnRuntimeMgr.

Parameters:
  • q_ranges (AttnRanges) – the global query ranges

  • k_ranges (AttnRanges) – the global key ranges

  • attn_mask_type (str | AttnMaskType | list[str | AttnMaskType]) – the global attn mask type, either a single type applied to all ranges or a list with one type per (q_range, k_range) pair, given as str, AttnMaskType enum, or a mix of both

  • total_seqlen_q (int) – the total seqlen of query

  • total_seqlen_k (int) – the total seqlen of key

  • pad_size (int) – the size to pad along seq_dim. The padded seqlen must be divisible by chunk_size * cp_size.

  • chunk_size (int) – the chunk size used to split the input tensor x along the seqlen dim for dispatch; it controls the granularity of computation load balancing.

  • cp_group_or_mesh (dist.ProcessGroup | DeviceMesh) – process group or device mesh. NOTE: process groups currently support only the nccl backend, and device meshes must currently be 1D or 2D.

  • dist_attn_config (DistAttnConfig) – dist attn config

  • is_same_source (bool) – whether the query and key tensors come from the same source

  • is_q_permutable (bool) – whether the query tensor is permutable

  • is_k_permutable (bool) – whether the key tensor is permutable

Returns:

the key pointing to the inner DistAttnRuntimeMgr.

Return type:

DistAttnRuntimeKey

Note

  1. For decoder-only transformers (e.g., GPT), it applies ‘self-attn’ as follows:

    1. is_same_source is True.

    2. Both q and k are permutable, as long as they are permuted in the same way.

  2. For encoder-decoder transformers (e.g., T5), it applies ‘cross-attn’ as follows:

    1. is_same_source is False.

    2. q is permutable but k is not.

  3. For multi-modal transformers with external encoders, it applies ‘cross-attn’ as follows:

    1. is_same_source is False.

    2. q is not permutable due to self-attn, but k is permutable, even in a different way.

Example

>>> import torch
>>> import torch.distributed as dist
>>> from magi_attention.api import magi_attn_flex_key, dispatch, undispatch, calc_attn
>>> from magi_attention.api.functools import compute_pad_size
>>> from magi_attention.config import (
...     DistAttnConfig,
...     DispatchConfig,
...     OverlapConfig,
...     MinHeapDispatchAlg,
...     UniformOverlapAlg
... )
>>> from magi_attention.common.enum import AttnOverlapMode
>>> from magi_attention.common import AttnRanges
>>>
>>> # Generate a DistAttnRuntimeKey to dispatch for arbitrary mask represented by attn-slices
>>> dist_attn_runtime_key = magi_attn_flex_key(
...     q_ranges=AttnRanges.from_ranges([[0, 2048], [2048, 4096]]),
...     k_ranges=AttnRanges.from_ranges([[0, 2048], [0, 4096]]),
...     attn_mask_type="full",
...     total_seqlen_q=4096,
...     total_seqlen_k=4096,
...     pad_size=compute_pad_size(4096, 4, 512),  # seqlen, cp_size, chunk_size
...     chunk_size=512,
...     cp_group_or_mesh=dist.new_group(list(range(4)), backend="nccl"),
...     is_same_source=True,
...     is_q_permutable=True,
...     is_k_permutable=True,
...     dist_attn_config=DistAttnConfig(
...         dispatch_config=DispatchConfig(alg=MinHeapDispatchAlg()),
...         overlap_config=OverlapConfig(
...             enable=True,
...             mode=AttnOverlapMode.STATIC,
...             degree=2,
...             min_chunk_size=512,
...             max_num_chunks=64,
...             alg=UniformOverlapAlg(),
...         ),
...     ),
... )
>>>
>>> # Dispatch several tensors with the same key
>>> local_x, local_label, local_rope = [
...     dispatch(tensor, dist_attn_runtime_key)
...     for tensor in [total_x, total_label, total_rope]
... ]
>>>
>>> # Apply QKV projection
>>> local_q, local_k, local_v = q_project(local_x), k_project(local_x), v_project(local_x)
>>>
>>> # Calculate local attention
>>> local_out, _ = calc_attn(local_q, local_k, local_v, dist_attn_runtime_key)
>>>
>>> # Gather local attention outputs to total output if needed
>>> total_out = undispatch(local_out, dist_attn_runtime_key)

If you want to apply more than one mask within the same training pass, you can use make_flex_key_for_new_mask_after_dispatch to create a key for the new mask from its mask arguments and the existing key used for dispatch.

The new mask then reuses the dispatch solution of the mask used for dispatch, but with its own meta arguments for computation and communication.

magi_attention.api.magi_attn_interface.make_flex_key_for_new_mask_after_dispatch()

Make a new dist attn runtime key for a new mask after dispatch, given the arguments for the new mask and the key used for dispatch.

NOTE: this API is useful when your model adopts a hybrid-attn structure and applies more than one mask within the same training pass. In that case, only one of the masks can be chosen for dispatch, while the others reuse the same dispatch solution with different meta arguments for computation and communication.

WARNING: in this case, we cannot currently guarantee that every mask is load-balanced in computation and optimized in communication. We are working on this with the dynamic dist-attn solver, which is intended to optimize computation and communication for each distinct mask under the same dispatch solution.

Parameters:
  • q_ranges (AttnRanges) – the global query ranges

  • k_ranges (AttnRanges) – the global key ranges

  • attn_mask_type (str | AttnMaskType | list[str | AttnMaskType]) – the global attn mask type, either a single type applied to all ranges or a list with one type per (q_range, k_range) pair, given as str, AttnMaskType enum, or a mix of both

  • key_for_dispatch (DistAttnRuntimeKey) – the key used for dispatch

  • dist_attn_config (DistAttnConfig, optional) –

    an optional new dist attn config.

    NOTE: if not provided, the config of key_for_dispatch is reused; if provided, the dispatch config within it is not applied to the new mask.

Returns:

the new dist attn runtime key for the new mask, sharing the same dispatch solution as key_for_dispatch.

Return type:

DistAttnRuntimeKey

Example

>>> import torch
>>> import torch.distributed as dist
>>> from magi_attention.api import magi_attn_flex_key, dispatch, undispatch, calc_attn
>>> from magi_attention.api import make_flex_key_for_new_mask_after_dispatch
>>> from magi_attention.api.functools import compute_pad_size
>>> from magi_attention.config import (
...     DistAttnConfig,
...     DispatchConfig,
...     OverlapConfig,
...     MinHeapDispatchAlg,
...     UniformOverlapAlg
... )
>>> from magi_attention.common.enum import AttnOverlapMode
>>> from magi_attention.common import AttnRanges
>>>
>>> # Generate a DistAttnRuntimeKey to dispatch for arbitrary mask represented by attn-slices
>>> # in the following case, we use a causal mask as the key for dispatch, thus it will consider
>>> # computation load-balance, communication optimization and computation-communication overlap
>>> # according to the causal mask pattern
>>> key_for_dispatch = magi_attn_flex_key(
...     q_ranges=AttnRanges.from_ranges([[0, 4096]]),
...     k_ranges=AttnRanges.from_ranges([[0, 4096]]),
...     attn_mask_type="causal",
...     total_seqlen_q=4096,
...     total_seqlen_k=4096,
...     pad_size=compute_pad_size(4096, 4, 512),  # seqlen, cp_size, chunk_size
...     chunk_size=512,
...     cp_group_or_mesh=dist.new_group(list(range(4)), backend="nccl"),
...     is_same_source=True,
...     is_q_permutable=True,
...     is_k_permutable=True,
...     dist_attn_config=DistAttnConfig(
...         dispatch_config=DispatchConfig(alg=MinHeapDispatchAlg()),
...         overlap_config=OverlapConfig(
...             enable=True,
...             mode=AttnOverlapMode.STATIC,
...             degree=2,
...             min_chunk_size=512,
...             max_num_chunks=64,
...             alg=UniformOverlapAlg(),
...         ),
...     ),
... )
>>>
>>> # Dispatch several tensors with the same key_for_dispatch
>>> local_x, local_label, local_rope = [
...     dispatch(tensor, key_for_dispatch)
...     for tensor in [total_x, total_label, total_rope]
... ]
>>>
>>> # Make a new dist attn runtime key from key_for_dispatch
>>> # for a new mask, such as a sliding window causal mask below,
>>> # with the same dispatch solution as the causal mask used for dispatch,
>>> # i.e. this new key shares the same dispatch meta as key_for_dispatch
>>> # but it can handle the computation and communication of the new mask
>>> # and calculate attn correctly as well, though no optimization is applied for now
>>> new_key_for_swa_mask = make_flex_key_for_new_mask_after_dispatch(
...     q_ranges=AttnRanges.from_ranges([[0, 512], [512, 4096]]),
...     k_ranges=AttnRanges.from_ranges([[0, 512], [0, 4096]]),
...     attn_mask_type=["causal", "bi_causal"], # sliding window causal mask
...     key_for_dispatch=key_for_dispatch,
... )
>>>
>>> # Apply QKV projection
>>> local_q, local_k, local_v = q_project(local_x), k_project(local_x), v_project(local_x)
>>>
>>> # Calculate local attention for the mask used to dispatch with key_for_dispatch
>>> local_out1, _ = calc_attn(local_q, local_k, local_v, key_for_dispatch)
>>>
>>> # Calculate local attention for the new swa mask with the new key
>>> # w/o undispatching back and dispatching again to avoid OOM
>>> local_out2, _ = calc_attn(local_q, local_k, local_v, new_key_for_swa_mask)
>>>
>>> # Gather local attention outputs to total output if needed
>>> total_out1 = undispatch(local_out1, key_for_dispatch)
>>> total_out2 = undispatch(local_out2, new_key_for_swa_mask)

Dispatch Function#

If you already have the key, you can call the dispatch function to get the padded and dispatched local tensor.

magi_attention.api.magi_attn_interface.dispatch()#

Pad the global input tensor and dispatch it to a local tensor on each rank along the seqlen dim.

Parameters:
  • x (torch.Tensor) – global input tensor.

  • key (DistAttnRuntimeKey) – the key that holds internal metadata; it is passed as an argument to many other magi_attention APIs, and users do not need to inspect its contents.

  • pad_value (float) – the value used to pad the input tensor. Defaults to 0.

Returns:

the padded and dispatched local tensor.

Return type:

torch.Tensor

Raises:

ValueError – If the provided key does not exist in dist_attn_runtime_dict.
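To make the data movement concrete, here is a toy, pure-Python sketch of what padding and dispatching a global sequence could look like. It is not the library's implementation: toy_dispatch is a hypothetical helper, plain lists stand in for torch.Tensor, and chunks are assigned round-robin, whereas the real dispatch solution comes from the configured dispatch algorithm (such as MinHeapDispatchAlg in the example above) and may order chunks differently.

```python
def toy_dispatch(x, cp_size, chunk_size, pad_value=0):
    """Hypothetical sketch: pad a global sequence and split it into per-rank shards.

    Chunks are assigned round-robin purely for illustration; the real dispatch
    uses a load-balancing dispatch algorithm instead.
    """
    block = cp_size * chunk_size
    pad_size = (-len(x)) % block  # tokens needed to reach a multiple of block
    padded = list(x) + [pad_value] * pad_size
    chunks = [padded[i:i + chunk_size] for i in range(0, len(padded), chunk_size)]
    shards = [[] for _ in range(cp_size)]
    for chunk_id, chunk in enumerate(chunks):
        shards[chunk_id % cp_size].extend(chunk)  # every rank gets the same number of chunks
    return shards


# 10 tokens, 2 ranks, chunks of 3 tokens: padded to 12 tokens, 2 chunks per rank
shards = toy_dispatch(list(range(10)), cp_size=2, chunk_size=3, pad_value=-1)
```

Every shard has the same length, which is the property the padding guarantees regardless of how chunks are assigned.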

Calculate Attention#

After dispatch and projection, you obtain the local query, key, and value tensors needed for computation. Using the same key that was passed to dispatch, you can then call calc_attn, which returns the attention output out and the log-sum-exp lse.

magi_attention.api.magi_attn_interface.calc_attn()#

Apply attention computation.

Parameters:
  • q (torch.Tensor) – local query tensor.

  • k (torch.Tensor) – local key tensor.

  • v (torch.Tensor) – local value tensor.

  • key (DistAttnRuntimeKey) – the key that holds internal metadata; it is passed as an argument to many other magi_attention APIs, and users do not need to inspect its contents.

  • sink (torch.Tensor, optional) – global sink tensor (replicated among cp ranks). Defaults to None to not apply attention sink.

  • softmax_scale (float, optional) – softmax scale. Defaults to None to use: 1/sqrt(head_dim).

  • softcap (float, optional) – softcap. Defaults to 0.0.

Returns:

  • out (torch.Tensor): local output tensor.

  • lse (torch.Tensor): local log-sum-exp tensor.

Return type:

tuple[torch.Tensor, torch.Tensor]

Shapes:
  • q: [num_tokens_q_local, num_heads_q, head_dim]

  • k: [num_tokens_kv_local, num_heads_kv, head_dim]

  • v: [num_tokens_kv_local, num_heads_kv, head_dim]

  • sink: [num_tokens_sink_global, num_heads_q]

  • out: [num_tokens_q_local, num_heads_q, head_dim]

  • lse: [num_tokens_q_local, num_heads_q]

Raises:

ValueError – If the provided key does not exist in dist_attn_runtime_dict.

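The meaning of the two return values can be illustrated with a single query token and a single head: out is the softmax-weighted sum of the values, and lse is the natural log of the softmax denominator. The sketch below is a plain-Python reference for that math, not the kernel; attend_one_query is a hypothetical helper, the sink argument is omitted, and the softcap formula softcap * tanh(score / softcap) is assumed from the flash-attention convention.

```python
import math


def attend_one_query(q, keys, values, softmax_scale=None, softcap=0.0):
    """Reference sketch of the (out, lse) pair for one query token and one head."""
    head_dim = len(q)
    scale = 1.0 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    if softcap > 0.0:
        # Assumed flash-attention-style capping of the scaled scores
        scores = [softcap * math.tanh(s / softcap) for s in scores]
    m = max(scores)  # subtract the max for numerical stability
    lse = m + math.log(sum(math.exp(s - m) for s in scores))
    weights = [math.exp(s - lse) for s in scores]  # softmax probabilities
    out = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(len(values[0]))]
    return out, lse


out, lse = attend_one_query([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0]], [[1.0, 2.0], [3.0, 4.0]])
```

Because exp(score - lse) recovers each softmax weight, lse is what allows partial attention results over disjoint key sets to be rescaled and combined.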

Undispatch#

Undispatch Function#

When you need to recover the complete global tensor from the local tensor, e.g. to compute the loss, you can call the undispatch function to unpad and undispatch the local tensor along the seqlen dim.

magi_attention.api.magi_attn_interface.undispatch()#

Undispatch and unpad the local tensor to the global tensor along the seqlen dim.

Parameters:
  • x (torch.Tensor) – local tensor

  • key (DistAttnRuntimeKey) – the key that holds internal metadata; it is passed as an argument to many other magi_attention APIs, and users do not need to inspect its contents.

Returns:

the undispatched and unpadded tensor.

Return type:

torch.Tensor

Raises:

ValueError – If the provided key does not exist in dist_attn_runtime_dict.
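Conceptually, undispatch inverts dispatch: every local token is placed back at its global position, and the trailing padding is dropped. The toy sketch below assumes each rank's global position ids are already known (for a toy round-robin layout of 10 tokens over 2 ranks with chunks of 3); toy_undispatch is a hypothetical helper operating on plain lists, not the library's implementation, which also gathers the shards across ranks.

```python
def toy_undispatch(shards, position_ids, total_seqlen):
    """Hypothetical sketch: scatter local shards back into global order, then unpad."""
    padded_len = sum(len(shard) for shard in shards)
    global_x = [None] * padded_len
    for shard, positions in zip(shards, position_ids):
        for value, p in zip(shard, positions):
            global_x[p] = value  # place each local token at its global position
    return global_x[:total_seqlen]  # drop the trailing padding tokens


shards = [[0, 1, 2, 6, 7, 8], [3, 4, 5, 9, -1, -1]]
position_ids = [[0, 1, 2, 6, 7, 8], [3, 4, 5, 9, 10, 11]]
restored = toy_undispatch(shards, position_ids, total_seqlen=10)
```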

Utility Functions#

Compute Pad Size and Padding#

When using MagiAttention, we divide the total_seqlen into multiple chunks of size chunk_size and evenly distribute them across multiple GPUs. To ensure that total_seqlen is divisible by chunk_size and that each GPU receives the same number of chunks, we need to pad the original input.

You can call compute_pad_size to calculate the required padding length, and use this value as a parameter in subsequent functions.

magi_attention.api.functools.compute_pad_size()#

Compute the number of tokens to pad at the end of the input tensor along the seqlen dim.

Parameters:
  • total_seqlen_q (int) – seqlen of q.

  • cp_size (int) – The size of cp group.

  • chunk_size (int) – the size of the chunks into which the input tensor x is split along the seqlen dim for dispatch; it controls the granularity of computation load balancing.

Returns:

the number of tokens to pad.

Return type:

int
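The padding rule above can be written down directly: the padded length must split into whole chunks, and the chunk count must split evenly across cp_size ranks, which together means the padded length is a multiple of cp_size * chunk_size. The sketch below assumes compute_pad_size returns the smallest such padding; reference_pad_size is a hypothetical name, so check its results against the library before relying on them.

```python
def reference_pad_size(total_seqlen_q, cp_size, chunk_size):
    """Hypothetical re-derivation: smallest pad making the length a multiple of cp_size * chunk_size."""
    block = cp_size * chunk_size  # one chunk for every rank
    return (-total_seqlen_q) % block


# Same arguments as the compute_pad_size call in the example above: 4096 is already aligned
aligned = reference_pad_size(4096, 4, 512)
# 4000 tokens need 96 extra tokens to reach 4096 = 2 * (4 * 512)
padded = reference_pad_size(4000, 4, 512)
```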

After obtaining pad_size, you can use the pad_at_dim and unpad_at_dim functions to pad and unpad the tensor.

magi_attention.api.functools.pad_at_dim()#

Pads a tensor along a specified dimension with a given value, either on the left or right side.

Parameters:
  • x (torch.Tensor) – Input tensor to be padded.

  • dim (int) – The dimension along which to apply padding.

  • pad_size (int) – The number of values to pad.

  • value (float, optional) – The padding value. Defaults to 0.0.

  • side (str, optional) – Side on which to apply the padding, either left or right. Defaults to right.

Returns:

The padded tensor with the same number of dimensions as the input.

Return type:

torch.Tensor

magi_attention.api.functools.unpad_at_dim()#

Removes padding from a tensor along a specified dimension.

Parameters:
  • x (torch.Tensor) – Input tensor from which padding will be removed.

  • dim (int) – The dimension along which to remove padding.

  • pad_size (int) – The number of elements to remove from the end of the specified dimension.

Returns:

The tensor with padding removed along the specified dimension.

Return type:

torch.Tensor
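The documented behavior of the two functions can be mirrored on nested Python lists, restricted to dim 0 (the seqlen dim for the tensors used here). pad_seq and unpad_seq are hypothetical stand-ins for illustration only; the real functions take a torch.Tensor and an arbitrary dim.

```python
def pad_seq(x, pad_size, value=0.0, side="right"):
    """Sketch of pad_at_dim restricted to dim 0 of a nested list."""
    if side not in ("left", "right"):
        raise ValueError(f"side must be 'left' or 'right', got {side!r}")

    def filled_like(item):
        # Build an element with the same nested shape as `item`, filled with `value`
        return [filled_like(sub) for sub in item] if isinstance(item, list) else value

    padding = [filled_like(x[0]) for _ in range(pad_size)]
    return padding + x if side == "left" else x + padding


def unpad_seq(x, pad_size):
    """Sketch of unpad_at_dim restricted to dim 0: drop the last pad_size rows."""
    return x[: len(x) - pad_size]


x = [[1, 2], [3, 4], [5, 6]]
padded = pad_seq(x, 2)
restored = unpad_seq(padded, 2)  # round trip back to x
```

Note that unpad_at_dim removes elements from the end, so the round trip only holds for right-side padding.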

Similarly, you can use pad_size along with total_seqlen and other related information to apply padding to a (q_ranges, k_ranges, mask_types) tuple using the apply_padding function.

This function fills the padding region with invalid slices.

magi_attention.api.functools.apply_padding()#

Appends padding to the attention ranges and updates the corresponding mask type.

This function adds a padding query range at the end of q_ranges, a dummy key range to k_ranges, and appends a FULL attention mask type to maintain alignment. It is typically used when padding is required for alignment or block-wise processing.

Parameters:
  • q_ranges (AttnRanges) – Query token ranges before padding.

  • k_ranges (AttnRanges) – Key token ranges before padding.

  • attn_mask_type (list[AttnMaskType]) – List of attention mask types corresponding to the ranges.

  • total_seqlen (int) – The total original sequence length (used to place the padding at the end).

  • pad_size (int) – The size of the padding to append.

Returns:

  • Updated query ranges with padding added.

  • Updated key ranges with a dummy range for padding.

  • Updated attention mask type list with a FULL mask for the padding block.

Return type:

tuple[AttnRanges, AttnRanges, list[AttnMaskType]]
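The bookkeeping described above can be sketched on plain (start, end) tuples instead of AttnRanges. apply_padding_sketch is a hypothetical helper, and the empty key range (0, 0) used as the dummy range is an assumption made so that the padded queries attend to no real keys; check the library for the exact dummy range it appends.

```python
def apply_padding_sketch(q_ranges, k_ranges, mask_types, total_seqlen, pad_size):
    """Hypothetical sketch of apply_padding on plain (start, end) tuples."""
    # The padding query range covers the tokens appended after the original sequence
    q_out = list(q_ranges) + [(total_seqlen, total_seqlen + pad_size)]
    # Assumption: the dummy key range is empty, so padded queries attend to nothing
    k_out = list(k_ranges) + [(0, 0)]
    t_out = list(mask_types) + ["full"]  # FULL mask type for the padding slice
    return q_out, k_out, t_out


q, k, t = apply_padding_sketch([(0, 10)], [(0, 10)], ["causal"], total_seqlen=10, pad_size=6)
```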

Get Position Ids#

Since MagiAttention permutes the input tensor along the seqlen dim during dispatch, position-dependent ops such as RoPE might be affected.

Therefore, we provide the get_position_ids function, which returns the global position ids of the local tensor, similar to the position ids used in Llama.

magi_attention.api.magi_attn_interface.get_position_ids()#

Get the position ids of the local tensor with respect to the global tensor after dispatching.

Parameters:

key (DistAttnRuntimeKey) – the key that holds internal metadata; it is passed as an argument to many other magi_attention APIs, and users do not need to inspect its contents.

Returns:

position ids of the local tensor w.r.t. the global tensor.

Return type:

torch.Tensor

Raises:

ValueError – If the provided key does not exist in dist_attn_runtime_dict.
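The reason position ids matter is that RoPE angles depend on each token's position in the original sequence, not on its index inside the local shard. The sketch below computes Llama-style rotary angles (inv_freq = base ** (-2i / head_dim)) from whatever position ids it is given; rope_angles is a hypothetical helper and the local position ids are made-up values for illustration.

```python
def rope_angles(position_ids, head_dim, base=10000.0):
    """Llama-style rotary angles: angle[t][i] = position_ids[t] * base ** (-2i / head_dim)."""
    inv_freq = [base ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]
    return [[p * f for f in inv_freq] for p in position_ids]


# Hypothetical global positions of the tokens held by one rank after dispatch
local_position_ids = [3, 4, 5, 9, 10, 11]
correct = rope_angles(local_position_ids, head_dim=4)
# Using the local index 0..5 instead would silently rotate tokens by the wrong angles
wrong = rope_angles(range(len(local_position_ids)), head_dim=4)
```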

Get Most Recent Key#

If you cannot access the key directly and only need the most recently inserted key, you can call get_most_recent_key to get it.

However, we strongly recommend using the key passed through the arguments to avoid unexpected inconsistencies.

magi_attention.api.magi_attn_interface.get_most_recent_key()#

Get the most recently inserted key.

This is useful when you cannot access the key through the arguments and only need the most recently inserted key. However, we strongly recommend using the key passed through the arguments to avoid unexpected inconsistencies.

Returns:

the most recently inserted key.

Return type:

DistAttnRuntimeKey
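The recommendation above follows from what "most recent" means: whichever key was inserted last, which can change as soon as another mask is registered. The toy registry below illustrates that behavior with an insertion-ordered dict; ToyKeyRegistry is a hypothetical stand-in and makes no claim about how dist_attn_runtime_dict is actually implemented.

```python
class ToyKeyRegistry:
    """Hypothetical stand-in for a runtime-key registry ordered by insertion."""

    def __init__(self):
        self._runtimes = {}  # dicts preserve insertion order in Python 3.7+

    def insert(self, key, runtime):
        self._runtimes[key] = runtime

    def lookup(self, key):
        if key not in self._runtimes:
            raise ValueError(f"key {key!r} does not exist in the registry")
        return self._runtimes[key]

    def most_recent_key(self):
        return next(reversed(self._runtimes))  # reversed(dict) needs Python 3.8+


registry = ToyKeyRegistry()
registry.insert("causal_key", "runtime_a")
registry.insert("swa_key", "runtime_b")
latest = registry.most_recent_key()  # "swa_key", not the key used for dispatch
```

A layer that relies on the most recent key instead of the key it was given would silently pick up "swa_key" here, which is the kind of inconsistency the recommendation guards against.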

Infer Varlen Masks#

If you want to use a varlen mask where each segment has the same length, we provide an infer_varlen_mask_from_batch function that generates the corresponding cu_seqlens tensors for you.

magi_attention.api.functools.infer_varlen_mask_from_batch()#

Converts fixed-length full attention into varlen full attention format by generating cumulative sequence lengths for queries and keys.

Parameters:
  • batch_size (int) – The number of sequences in the batch.

  • seq_len (int) – The fixed sequence length for each sequence in the batch.

  • device (str, optional) – The device to allocate the tensors on. Defaults to "cuda".

Returns:

A pair of 1D tensors (cu_seqlens_q, cu_seqlens_k), each of shape [batch_size + 1,], representing the cumulative sequence lengths for the queries and keys respectively.

Return type:

tuple[torch.Tensor, torch.Tensor]
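For equal-length sequences the cumulative lengths are simply multiples of seq_len, starting at 0 and ending at batch_size * seq_len. The sketch below produces those values as plain lists; cu_seqlens_from_batch is a hypothetical helper, whereas the real function returns torch tensors on the requested device.

```python
def cu_seqlens_from_batch(batch_size, seq_len):
    """Cumulative sequence lengths for batch_size sequences that all have length seq_len."""
    cu_seqlens = [i * seq_len for i in range(batch_size + 1)]
    return cu_seqlens, list(cu_seqlens)  # queries and keys share the same boundaries


cu_q, cu_k = cu_seqlens_from_batch(batch_size=3, seq_len=4)
```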

When using a varlen mask, it is often necessary to reshape a tensor of shape [batch_size, seq_len, ...] into [batch_size × seq_len, ...].

To facilitate the use of the above APIs, we provide the squash_batch_dim function to merge the batch and sequence dimensions.

magi_attention.api.functools.squash_batch_dim()#

Reshapes a tensor from shape [b, s, ...] to [b x s, ...], effectively flattening the batch and sequence dimensions into a single leading dimension.

Parameters:

x (torch.Tensor) – Input tensor of shape [batch_size, seq_len, ...] to be merged.

Returns:

Reshaped tensor of shape [batch_size x seq_len, ...].

Return type:

torch.Tensor
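Squashing the batch dimension is the same as concatenating the sequences end to end, which is exactly the token layout the cu_seqlens boundaries describe. The sketch below shows this on nested lists; squash_batch_dim_sketch is a hypothetical helper, and for a torch.Tensor the same reshape can be expressed as x.flatten(0, 1).

```python
def squash_batch_dim_sketch(x):
    """[b, s, ...] -> [b * s, ...] on nested lists: concatenate the sequences in batch order."""
    return [token for sequence in x for token in sequence]


# batch_size=2, seq_len=2, one feature per token
batch = [[[1], [2]], [[3], [4]]]
flat = squash_batch_dim_sketch(batch)
```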

Moreover, if you have already computed the cu_seqlens tensors and want to generate a varlen mask from them, we provide the infer_attn_mask_from_cu_seqlens function. This function can create three types of masks (varlen full, varlen causal, and varlen sliding window) according to cu_seqlens, causal, and window_size, and returns the result as a (q_ranges, k_ranges, mask_types, total_seqlen_q, total_seqlen_k) tuple.

magi_attention.api.functools.infer_attn_mask_from_cu_seqlens()#

Infer query ranges, key ranges and other arguments for flexible attn mask representation from cu_seqlens, widely used for varlen masks.

Parameters:
  • cu_seqlens_q (torch.Tensor) – cumulative sequence lengths for queries

  • cu_seqlens_k (torch.Tensor) – cumulative sequence lengths for keys

  • causal (bool, optional) – whether the varlen attention mask is causal. Defaults to False.

  • window_size (tuple[int, int], optional) – window_size of sliding window mask which represents [window_size_left, window_size_right]. The parameter is effective only when causal is False; when causal is True, it is required to be (-1, -1). Defaults to (-1, -1).

Returns:

query ranges, key ranges, attn mask type list, total seqlen of q, total seqlen of k

Return type:

tuple[AttnRanges, AttnRanges, list[AttnMaskType], int, int]
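For the varlen full and varlen causal cases, the mapping is direct: sample i contributes one slice pairing its query segment [cu_seqlens_q[i], cu_seqlens_q[i + 1]) with its key segment, and the last entries give the total seqlens. The sketch below covers only those two cases on plain tuples and strings; ranges_from_cu_seqlens is a hypothetical helper, and the sliding window case, which needs several slices per sample, is left out.

```python
def ranges_from_cu_seqlens(cu_seqlens_q, cu_seqlens_k, causal=False):
    """Sketch of the varlen full / varlen causal cases (no sliding window)."""
    if len(cu_seqlens_q) != len(cu_seqlens_k):
        raise ValueError("cu_seqlens_q and cu_seqlens_k must describe the same number of samples")
    q_ranges = [(cu_seqlens_q[i], cu_seqlens_q[i + 1]) for i in range(len(cu_seqlens_q) - 1)]
    k_ranges = [(cu_seqlens_k[i], cu_seqlens_k[i + 1]) for i in range(len(cu_seqlens_k) - 1)]
    mask_types = ["causal" if causal else "full"] * len(q_ranges)  # one slice per sample
    return q_ranges, k_ranges, mask_types, cu_seqlens_q[-1], cu_seqlens_k[-1]


q_ranges, k_ranges, mask_types, total_q, total_k = ranges_from_cu_seqlens([0, 3, 7], [0, 3, 7], causal=True)
```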

Infer Sliding Window Masks#

In the design of MagiAttention, we use a (q_range, k_range, masktype) tuple to represent a slice.

For sliding window masks, we do not provide a dedicated masktype to represent them directly.

However, a sliding window mask can be decomposed into a combination of existing masktypes such as full, causal, inv_causal, and bi_causal.

If you’re unsure how to perform this decomposition, we provide the infer_attn_mask_from_sliding_window function to handle this process for you.

magi_attention.api.functools.infer_attn_mask_from_sliding_window()#

Convert a single sliding window mask into a representation based on q_range, k_range, and mask type. The shape of the window is specified by window_size.

Parameters:
  • q_range (AttnRange) – q_range of this sliding window mask

  • k_range (AttnRange) – k_range of this sliding window mask

  • window_size (tuple[int, int]) – window_size of sliding window mask which represents [window_size_left, window_size_right]

Returns:

the processed (q_ranges, k_ranges, masktypes) triple, i.e. the sliding window mask decomposed into the triple representation.

Return type:

tuple[AttnRanges, AttnRanges, list[AttnMaskType]]

Example

Here’s an example of infer_attn_mask_from_sliding_window:

>>> q_ranges, k_ranges, attn_mask_type = infer_attn_mask_from_sliding_window(
...     q_range=AttnRange.from_range([5, 15]),
...     k_range=AttnRange.from_range([5, 15]),
...     window_size=(2, 3),
... )

The code above represents the sliding window mask within the [5, 15] x [5, 15] region with a window size of (2, 3).
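To check a decomposition like this one, it helps to have a dense reference mask to compare against. The sketch below builds that reference for a square region (q_range equal to k_range, both half-open), assuming the flash-attention window convention: query i may attend to key j when i - window_size_left <= j <= i + window_size_right. sliding_window_mask is a hypothetical helper, and the convention is an assumption rather than something this page states.

```python
def sliding_window_mask(q_range, k_range, window_size):
    """Dense boolean reference mask for one sliding window region (half-open ranges)."""
    left, right = window_size
    q_start, q_end = q_range
    k_start, k_end = k_range
    return [
        [q - left <= k <= q + right for k in range(k_start, k_end)]  # visible keys of query q
        for q in range(q_start, q_end)
    ]


mask = sliding_window_mask((5, 15), (5, 15), window_size=(2, 3))
keys_per_query = [sum(row) for row in mask]
```

The window is clipped at the region borders, so queries near the edges see fewer keys than the full width of 2 + 1 + 3 = 6; that clipping is what the full, causal, inv_causal, and bi_causal pieces of the decomposition have to reproduce.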