API Reference#
Flexible Flash Attention#
To support computing attention with irregular-shaped masks, we implemented a flexible_flash_attention kernel, which can be invoked through the following interface.
- magi_attention.functional.flex_flash_attn.flex_flash_attn_func()#
An interface similar to flash attention that requires no distributed environment, dispatch, or undispatch. It directly calls magi_attn_kernel to get the attention output and lse, and is faster when you don't need context parallelism.
- Parameters:
q (torch.Tensor) – Query tensor.
k (torch.Tensor) – Key tensor.
v (torch.Tensor) – Value tensor.
q_ranges (torch.Tensor) – Query ranges tensor to represent the attn mask.
k_ranges (torch.Tensor) – Key ranges tensor to represent the attn mask.
max_seqlen_q (int | None, optional) – Maximum sequence length for query. Defaults to None. If provided, enables an optimization for the forward tile_scheduler, especially in block sparse attention scenarios. It is most recommended to set this when using auto_range_merge (for block sparse attention).
attn_type_map (torch.Tensor, optional) – Attention type map tensor with dtype=torch.int32. Defaults to None to apply full attention for all ranges. The values specify the attention type for each range:
0: full attention
1: causal attention
2: inverse causal attention
3: bidirectional causal attention
More information about the attention type map can be found in the Note below.
sink (torch.Tensor, optional) – Learnable sink token tensor. Defaults to None to not apply attention sink.
sink_layout (AttnSinkLayout, optional) – The layout of the sink tokens. Defaults to "sh". Available options: "sh", "ssh".
softmax_scale (float, optional) – Softmax scale. Defaults to None to use 1/sqrt(head_dim).
softcap (float, optional) – Softcap. Defaults to 0.0.
deterministic (bool, optional) – Whether to use deterministic attention. Defaults to False.
sm_margin (int, optional) – The number of SMs reserved out, useful when overlapping with other kernels such as communication kernels. Defaults to 0 to use all available SMs.
disable_fwd_atomic_reduction (bool, optional) – Whether to disable forward atomic reduction. Defaults to False. If you can ensure q_ranges is non-overlapped, you can set this to True for better performance. The "overlap" term among q_ranges is defined as: if any two q_range in q_ranges have a non-empty intersection, then it is overlapped. For example, q_ranges=[[0, 15], [10, 20], [20, 30]] is overlapped since q_range1=[0, 15] and q_range2=[10, 20] intersect, while q_ranges=[[0, 15], [15, 20], [20, 30]] is non-overlapped.
disable_bwd_dkv_atomic_reduction (bool, optional) – Whether to disable backward dK/dV atomic reduction. Defaults to False. If you can ensure k_ranges (used in backward) is non-overlapped and sorted, you can set this to True for better performance. The "overlap" term among k_ranges is defined as: if any two k_range in k_ranges have a non-empty intersection, then it is overlapped. For example, k_ranges=[[0, 15], [10, 20], [20, 30]] is overlapped since k_range1=[0, 15] and k_range2=[10, 20] intersect, while k_ranges=[[0, 15], [15, 20], [20, 30]] is non-overlapped. Note: this flag can only be enabled with MHA or catGQA.
ref_block_size (tuple[int, int], optional) – Reference block size (M, N) for kernel selection. Defaults to None to use the internal heuristic. Note: this flag is useful for sparse attention scenarios but still under development.
auto_range_merge (bool, optional) – Whether to automatically merge k_ranges for the same q_range. Defaults to False. Note: this flag is useful for sparse attention scenarios but still under development.
swap_ab (bool, optional) – Whether to swap the order of the A and B operands of the matmul (i.e. transpose C = A x B^T to C^T = B x A^T) in the attention forward pass. Defaults to False. Note: this flag is useful for sparse attention scenarios but still under development.
pack_gqa (bool, optional) – Whether to group query heads sharing the same KV head into a single computation block tile for small seqlen_q scenarios. This significantly improves the computational efficiency of block sparse attention when seqlen_q is small. Defaults to False. Note: kblockm must be divisible by qhead_per_khead (num_qhead // num_khead). For the backward pass, this flag only takes effect when swap_bwd_qk_loop is True.
cat_gqa (bool, optional) – Whether to concatenate multiple Q heads sharing the same KV head, to optimize backward performance under GQA settings. Defaults to False.
sparse_load (bool, optional) – Whether to enable sparse load mode to optimize performance when the k_range size is small (< 64). Must be used together with auto_range_merge=True for enhanced performance. Defaults to False.
swap_bwd_qk_loop (bool, optional) – Whether to swap the order of the Q and K double-loops (i.e. from the default K outer-loop / Q inner-loop to Q outer-loop / K inner-loop) in the attention backward pass. Defaults to False. Note: this flag is useful for sparse attention scenarios but still under development.
return_max_logits (bool, optional) – Whether to return the maximum attention logits, as used by the Muon QK-Clip technique introduced in Kimi K2: https://arxiv.org/pdf/2507.20534.pdf. Defaults to False.
- Returns:
out (torch.Tensor): Attention output tensor
- meta (AttnForwardMeta): Meta information of the attention forward pass; for now, it includes lse (torch.Tensor) with dtype=torch.float32, and max_logits (torch.Tensor) with dtype=torch.float32 if return_max_logits is True, otherwise None.
- Return type:
tuple[torch.Tensor, AttnForwardMeta]
- Shape:
q: (num_tokens_q, num_heads_q, head_dim)
k: (num_tokens_kv, num_heads_kv, head_dim)
v: (num_tokens_kv, num_heads_kv, head_dim)
- sink:
if sink_layout == “sh”: (num_tokens_sink, num_heads_q)
if sink_layout == “ssh”: (num_tokens_q, num_tokens_sink, num_heads_q)
q_ranges: (num_ranges, 2)
k_ranges: (num_ranges, 2)
attn_type_map: (num_ranges,)
out: (num_tokens_q, num_heads_q, head_dim)
lse: (num_tokens_q, num_heads_q)
max_logits: (num_heads_q,)
Note
The attn_type_map explains the semantics of the different attention mask types. In addition to the descriptions below, see our blog for a visual explanation: https://sandai-org.github.io/MagiAttention/blog/#flex-flash-attn
- Full attention:
If seqlen_q = 5 and seqlen_k = 2:
1 1
1 1
1 1
1 1
1 1
If seqlen_q = 2 and seqlen_k = 5:
1 1 1 1 1
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 5:
1 1 1 1 1
1 1 1 1 1
1 1 1 1 1
1 1 1 1 1
1 1 1 1 1
- Causal attention (bottom-right aligned):
If seqlen_q = 5 and seqlen_k = 2:
0 0
0 0
0 0
1 0
1 1
If seqlen_q = 2 and seqlen_k = 5:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 5:
1 0 0 0 0
1 1 0 0 0
1 1 1 0 0
1 1 1 1 0
1 1 1 1 1
- Inverse causal attention (top-left aligned):
If seqlen_q = 5 and seqlen_k = 2:
1 1
0 1
0 0
0 0
0 0
If seqlen_q = 2 and seqlen_k = 5:
1 1 1 1 1
0 1 1 1 1
If seqlen_q = 5 and seqlen_k = 5:
1 1 1 1 1
0 1 1 1 1
0 0 1 1 1
0 0 0 1 1
0 0 0 0 1
- Bidirectional causal attention (intersection of causal and inverse causal):
This is the element-wise AND of causal and inverse causal masks.
If seqlen_q = 5 and seqlen_k = 2:
0 0
0 0
0 0
0 0
0 0
If seqlen_q = 2 and seqlen_k = 5:
1 1 1 1 0
0 1 1 1 1
If seqlen_q = 5 and seqlen_k = 5:
1 0 0 0 0
0 1 0 0 0
0 0 1 0 0
0 0 0 1 0
0 0 0 0 1
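Below is a minimal single-GPU usage sketch, not taken from the original docs: the tensor shapes follow the Shape section above, while the dtype, device, and keyword-argument usage are illustrative assumptions.
>>> import torch
>>> from magi_attention.functional.flex_flash_attn import flex_flash_attn_func
>>>
>>> # two ranges over a 4096-token packed sequence:
>>> # the first range uses full attention, the second uses causal attention
>>> q = torch.randn(4096, 16, 128, dtype=torch.bfloat16, device="cuda")
>>> k = torch.randn(4096, 4, 128, dtype=torch.bfloat16, device="cuda")
>>> v = torch.randn(4096, 4, 128, dtype=torch.bfloat16, device="cuda")
>>> q_ranges = torch.tensor([[0, 2048], [2048, 4096]], dtype=torch.int32, device="cuda")
>>> k_ranges = torch.tensor([[0, 2048], [0, 4096]], dtype=torch.int32, device="cuda")
>>> attn_type_map = torch.tensor([0, 1], dtype=torch.int32, device="cuda")  # 0: full, 1: causal
>>>
>>> out, meta = flex_flash_attn_func(
...     q, k, v,
...     q_ranges=q_ranges,
...     k_ranges=k_ranges,
...     attn_type_map=attn_type_map,
... )
>>> # out has shape (4096, 16, 128); meta carries the lse tensor of shape (4096, 16)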
Dispatch#
Varlen Dispatch#
If you’re using a mask defined by cu_seqlens, such as a varlen full or varlen causal mask, we’ve designed a similar interface magi_attn_varlen_key inspired by FlashAttention’s API as follows, making it easy for you to get started quickly.
- magi_attention.api.magi_attn_interface.magi_attn_varlen_key()#
This is a flash-attn-varlen-like interface: it generates q_ranges, k_ranges and attn_mask_type from cu_seqlens_q, cu_seqlens_k, causal and window_size, calculates the dist_attn_runtime_key, and generates the corresponding inner dist_attn_runtime_mgr.
- Parameters:
cu_seqlens_q (torch.Tensor) – the cumulative sequence lengths for queries.
cu_seqlens_k (torch.Tensor) – the cumulative sequence lengths for keys.
num_heads_q (int) – the number of heads for query.
num_heads_kv (int) – the number of heads for key/value.
head_dim (int) – the dimension of each attention head.
pad_size (int) – the size to pad the global input tensor along the sequence dim, due to the constraint that the sequence length needs to be divisible by chunk_size * cp_size.
chunk_size (int) – the size to chunk the global input tensor along the seqlen dim for later sharding and dispatching among the cp ranks, as a granularity factor of computational load-balance.
cp_group_or_mesh (dist.ProcessGroup | DeviceMesh) – process group or device mesh. NOTE: for process group, we only support nccl backend for now, and for device mesh, we only support 1D or 2D mesh for now.
causal (bool, optional) – if True, all mask types are set to CAUSAL; otherwise, the mask types are determined by window_size. Defaults to False.
window_size (tuple[int, int], optional) – window size of the sliding window mask, which represents [window_size_left, window_size_right]. This parameter is effective only when causal is False; when causal is True, it is required to be (-1, -1). Defaults to (-1, -1).
dist_attn_config (DistAttnConfig) – dist attn config.
- Returns:
the key that points to the inner DistAttnRuntimeMgr.
- Return type:
DistAttnRuntimeKey
Example
>>> import torch
>>> import torch.distributed as dist
>>> from magi_attention.api import magi_attn_varlen_key, dispatch, undispatch, calc_attn
>>> from magi_attention.api.functools import compute_pad_size
>>> from magi_attention.config import (
...     DistAttnConfig,
...     DispatchConfig,
...     OverlapConfig,
...     MinHeapDispatchAlg,
...     UniformOverlapAlg
... )
>>> from magi_attention.common.enum import AttnOverlapMode
>>>
>>> # Step1. generate a dist_attn_runtime_key to store and indicate the inner meta info
>>> dist_attn_runtime_key = magi_attn_varlen_key(
...     cu_seqlens_q=torch.tensor(
...         [0, 2048, 4096], dtype=torch.int32
...     ),
...     cu_seqlens_k=torch.tensor(
...         [0, 2048, 4096], dtype=torch.int32
...     ),
...     num_heads_q=16,
...     num_heads_kv=4,
...     head_dim=128,
...     pad_size=compute_pad_size(4096, 4, 512),  # seqlen, cp_size, chunk_size
...     chunk_size=512,
...     cp_group_or_mesh=dist.new_group(list(range(4)), backend="nccl"),
...     causal=False,
...     window_size=(-1, -1),
...     dist_attn_config=DistAttnConfig(
...         dispatch_config=DispatchConfig(alg=MinHeapDispatchAlg()),
...         overlap_config=OverlapConfig(
...             enable=True,
...             mode=AttnOverlapMode.STATIC,
...             degree=2,
...             min_chunk_size=512,
...             max_num_chunks=64,
...             alg=UniformOverlapAlg(),
...         ),
...     ),
... )
>>>
>>> # Step2. dispatch the global tensors to local tensors
>>> local_x, local_label, local_rope = [
...     dispatch(tensor, dist_attn_runtime_key)
...     for tensor in [total_x, total_label, total_rope]
... ]
>>>
>>> # Step3. apply QKV projection on local tensors
>>> local_q, local_k, local_v = q_project(local_x), k_project(local_x), v_project(local_x)
>>>
>>> # Step4. calculate distributed attention to get the local attention output tensor
>>> local_out, meta = calc_attn(local_q, local_k, local_v, dist_attn_runtime_key)
>>>
>>> # Step5. undispatch local attention output to the global one if needed
>>> total_out = undispatch(local_out, dist_attn_runtime_key)
If you want to apply more than one mask within the same training pass, you can use make_varlen_key_for_new_mask_after_dispatch to make a new key for the new mask, given the mask arguments specific to the varlen mask in flash-attn-varlen style and the existing key used for dispatch.
Then the new mask will reuse the same dispatch solution as the mask used for dispatch, but with different meta arguments for computation and communication.
- magi_attention.api.magi_attn_interface.make_varlen_key_for_new_mask_after_dispatch()#
Make a new dist attn runtime key for a new mask after dispatch, given the arguments for the new mask in flash-attn-varlen style and the key used for dispatch.
NOTE: this API is useful when you want to apply more than one mask within the same training pass, e.g. if your model adopts a hybrid-attn structure. In this case, we can only choose one of the masks to dispatch, while the others are supposed to reuse the same dispatch solution with different meta arguments for computation and communication.
WARNING: in such a case, we cannot guarantee that all the masks are load-balanced in computation and optimized in communication.
- Parameters:
cu_seqlens_q (torch.Tensor) – the cumulative sequence lengths for queries.
cu_seqlens_k (torch.Tensor) – the cumulative sequence lengths for keys.
key_for_dispatch (DistAttnRuntimeKey) – the key used for dispatch.
causal (bool, optional) – whether the varlen attention mask is causal. Defaults to False.
window_size (tuple[int, int], optional) – window size of the sliding window mask, which represents [window_size_left, window_size_right]. This parameter is effective only when causal is False; when causal is True, it is required to be (-1, -1). Defaults to (-1, -1).
dist_attn_config (DistAttnConfig, optional) – the optional new dist attn config. NOTE: if not provided, we will use the same config as the key_for_dispatch; if provided, the dispatch config of the new dist attn config won't be applied to the new mask.
- Returns:
the new dist attn runtime key for the new mask with the same dispatch solution as the key_for_dispatch.
- Return type:
DistAttnRuntimeKey
Example
>>> import torch
>>> import torch.distributed as dist
>>> from magi_attention.api import magi_attn_varlen_key, dispatch, undispatch, calc_attn
>>> from magi_attention.api import make_varlen_key_for_new_mask_after_dispatch
>>> from magi_attention.api.functools import compute_pad_size
>>> from magi_attention.config import (
...     DistAttnConfig,
...     DispatchConfig,
...     OverlapConfig,
...     MinHeapDispatchAlg,
...     UniformOverlapAlg
... )
>>> from magi_attention.common.enum import AttnOverlapMode
>>>
>>> # Step1. generate a dist_attn_runtime_key to dispatch for flash-attn-varlen style mask
>>> # in the following case, we use a causal mask as the key for dispatch, thus it will consider
>>> # computation load-balance, communication optimization and computation-communication overlap
>>> # according to the causal mask pattern
>>> key_for_dispatch = magi_attn_varlen_key(
...     cu_seqlens_q=torch.tensor(
...         [0, 4096], dtype=torch.int32
...     ),
...     cu_seqlens_k=torch.tensor(
...         [0, 4096], dtype=torch.int32
...     ),
...     num_heads_q=16,
...     num_heads_kv=4,
...     head_dim=128,
...     pad_size=compute_pad_size(4096, 4, 512),  # seqlen, cp_size, chunk_size
...     chunk_size=512,
...     cp_group_or_mesh=dist.new_group(list(range(4)), backend="nccl"),
...     causal=True,
...     window_size=(-1, -1),
...     dist_attn_config=DistAttnConfig(
...         dispatch_config=DispatchConfig(alg=MinHeapDispatchAlg()),
...         overlap_config=OverlapConfig(
...             enable=True,
...             mode=AttnOverlapMode.STATIC,
...             degree=2,
...             min_chunk_size=512,
...             max_num_chunks=64,
...             alg=UniformOverlapAlg(),
...         ),
...     ),
... )
>>>
>>> # Step2. dispatch the global tensors to local tensors with the same key_for_dispatch
>>> local_x, local_label, local_rope = [
...     dispatch(tensor, key_for_dispatch)
...     for tensor in [total_x, total_label, total_rope]
... ]
>>>
>>> # Step3. make a new dist_attn_runtime_key from key_for_dispatch
>>> # for a new mask, such as a sliding window causal mask below,
>>> # with the same dispatch solution as the causal mask used for dispatch,
>>> # i.e. this new key shares the same dispatch meta as key_for_dispatch
>>> # but it can handle the computation and communication of the new mask
>>> # and calculate attn correctly as well, though no optimization is applied for now
>>> new_key_for_swa_mask = make_varlen_key_for_new_mask_after_dispatch(
...     cu_seqlens_q=torch.tensor([0, 4096], dtype=torch.int32),
...     cu_seqlens_k=torch.tensor([0, 4096], dtype=torch.int32),
...     causal=False,
...     window_size=(512, 0),  # sliding window causal mask
...     key_for_dispatch=key_for_dispatch,
... )
>>>
>>> # Step4. apply QKV projection on local tensors
>>> local_q, local_k, local_v = q_project(local_x), k_project(local_x), v_project(local_x)
>>>
>>> # Step5. calculate distributed attention
>>> # for the causal mask used to dispatch with key_for_dispatch
>>> local_out1, _ = calc_attn(local_q, local_k, local_v, key_for_dispatch)
>>>
>>> # Step6. calculate distributed attention
>>> # for the new swa mask with the new key
>>> # w/o undispatching back and re-dispatching again to avoid OOM
>>> local_out2, _ = calc_attn(local_q, local_k, local_v, new_key_for_swa_mask)
>>>
>>> # Step7. undispatch local attention output to the global one if needed
>>> total_out1 = undispatch(local_out1, key_for_dispatch)
>>> total_out2 = undispatch(local_out2, new_key_for_swa_mask)
Flexible Dispatch#
If the masks you’re using are not limited to varlen full or varlen causal, but also include sliding window masks or other more diverse types, we recommend using the magi_attn_flex_key as follows.
- magi_attention.api.magi_attn_interface.magi_attn_flex_key()#
This is the most flexible interface: you directly pass in q_ranges, k_ranges and attn_mask_type to generate the dist_attn_runtime_key, which stores and indicates the inner meta data and is a required argument for the following APIs including dispatch, undispatch, calc_attn, etc.
- Parameters:
q_ranges (AttnRanges) – the global query ranges.
k_ranges (AttnRanges) – the global key ranges.
attn_mask_type (str | AttnMaskType | list[str | AttnMaskType]) – the global attn mask type (list), represented by str or the enum AttnMaskType, or their mixed combination.
total_seqlen_q (int) – the total seqlen of query.
total_seqlen_k (int) – the total seqlen of key.
num_heads_q (int) – the number of heads for query.
num_heads_kv (int) – the number of heads for key/value.
head_dim (int) – the dimension of each attention head.
pad_size (int) – the size to pad the global input tensor along the sequence dim, due to the constraint that the sequence length needs to be divisible by chunk_size * cp_size.
chunk_size (int) – the size to chunk the global input tensor along the seqlen dim for later sharding and dispatching among the cp ranks, as a granularity factor of computational load-balance.
cp_group_or_mesh (dist.ProcessGroup | DeviceMesh) – process group or device mesh. NOTE: for process group, we only support nccl backend for now, and for device mesh, we only support 1D or 2D mesh for now.
dist_attn_config (DistAttnConfig) – dist attn config.
is_same_source (bool) – whether the query tensor and key tensor share the same source. Defaults to True.
is_q_permutable (bool) – whether the query tensor is permutable. Defaults to True.
is_k_permutable (bool) – whether the key tensor is permutable. Defaults to True.
- Returns:
the key that stores and indicates the inner meta data.
- Return type:
DistAttnRuntimeKey
Note
For decoder-only transformers (e.g., GPT), it applies 'self-attn' as follows:
is_same_source is True.
Both q and k are permutable, as long as they are permuted in the same way.
For encoder-decoder transformers (e.g., T5), it applies 'cross-attn' as follows:
is_same_source is False.
q is permutable but k is not.
For multi-modal transformers with external encoders, it applies 'cross-attn' as follows:
is_same_source is False.
q is unpermutable due to self-attn, but k is permutable, even in a different way.
Example
>>> import torch
>>> import torch.distributed as dist
>>> from magi_attention.api import magi_attn_flex_key, dispatch, undispatch, calc_attn
>>> from magi_attention.api.functools import compute_pad_size
>>> from magi_attention.config import (
...     DistAttnConfig,
...     DispatchConfig,
...     OverlapConfig,
...     MinHeapDispatchAlg,
...     UniformOverlapAlg
... )
>>> from magi_attention.common.enum import AttnOverlapMode
>>> from magi_attention.common import AttnRanges
>>>
>>> # Step1. generate a dist_attn_runtime_key to store and indicate the inner meta info
>>> dist_attn_runtime_key = magi_attn_flex_key(
...     q_ranges=AttnRanges.from_ranges([[0, 2048], [2048, 4096]]),
...     k_ranges=AttnRanges.from_ranges([[0, 2048], [0, 4096]]),
...     attn_mask_type="full",
...     total_seqlen_q=4096,
...     total_seqlen_k=4096,
...     num_heads_q=16,
...     num_heads_kv=4,
...     head_dim=128,
...     pad_size=compute_pad_size(4096, 4, 512),  # seqlen, cp_size, chunk_size
...     chunk_size=512,
...     cp_group_or_mesh=dist.new_group(list(range(4)), backend="nccl"),
...     dist_attn_config=DistAttnConfig(
...         dispatch_config=DispatchConfig(alg=MinHeapDispatchAlg()),
...         overlap_config=OverlapConfig(
...             enable=True,
...             mode=AttnOverlapMode.STATIC,
...             degree=2,
...             min_chunk_size=512,
...             max_num_chunks=64,
...             alg=UniformOverlapAlg(),
...         ),
...     ),
... )
>>>
>>> # Step2. dispatch the global tensors to local tensors
>>> local_x, local_label, local_rope = [
...     dispatch(tensor, dist_attn_runtime_key)
...     for tensor in [total_x, total_label, total_rope]
... ]
>>>
>>> # Step3. apply QKV projection on local tensors
>>> local_q, local_k, local_v = q_project(local_x), k_project(local_x), v_project(local_x)
>>>
>>> # Step4. calculate distributed attention to get the local attention output tensor
>>> local_out, meta = calc_attn(local_q, local_k, local_v, dist_attn_runtime_key)
>>>
>>> # Step5. undispatch local attention output to the global one if needed
>>> total_out = undispatch(local_out, dist_attn_runtime_key)
If you want to apply more than one varlen mask within the same training pass, you can use make_flex_key_for_new_mask_after_dispatch to make a new key for the new mask, given the mask arguments and the existing key used for dispatch.
Then the new mask will reuse the same dispatch solution as the mask used for dispatch, but with different meta arguments for computation and communication.
- magi_attention.api.magi_attn_interface.make_flex_key_for_new_mask_after_dispatch()#
Make a new dist attn runtime key for a new mask after dispatch, given the arguments for the new mask and the key used for dispatch.
NOTE: this API is useful when you want to apply more than one mask within the same training pass, e.g. if your model adopts a hybrid-attn structure. In this case, we can only choose one of the masks to dispatch, while the others are supposed to reuse the same dispatch solution with different meta arguments for computation and communication.
WARNING: in such a case, we cannot guarantee that all the masks are load-balanced in computation and optimized in communication for now. However, we are working on a dynamic dist-attn solver to optimize the computation and communication for each distinct mask under the same dispatch solution.
- Parameters:
q_ranges (AttnRanges) – the global query ranges.
k_ranges (AttnRanges) – the global key ranges.
attn_mask_type (str | AttnMaskType | list[str | AttnMaskType]) – the global attn mask type (list), represented by str or the enum AttnMaskType, or their mixed combination.
key_for_dispatch (DistAttnRuntimeKey) – the key used for dispatch.
dist_attn_config (DistAttnConfig, optional) – the optional new dist attn config. NOTE: if not provided, we will use the same config as the key_for_dispatch; if provided, the dispatch config of the new dist attn config won't be applied to the new mask.
- Returns:
the new dist attn runtime key for the new mask with the same dispatch solution as the key_for_dispatch.
- Return type:
DistAttnRuntimeKey
Example
>>> import torch
>>> import torch.distributed as dist
>>> from magi_attention.api import magi_attn_flex_key, dispatch, undispatch, calc_attn
>>> from magi_attention.api import make_flex_key_for_new_mask_after_dispatch
>>> from magi_attention.api.functools import compute_pad_size
>>> from magi_attention.config import (
...     DistAttnConfig,
...     DispatchConfig,
...     OverlapConfig,
...     MinHeapDispatchAlg,
...     UniformOverlapAlg
... )
>>> from magi_attention.common.enum import AttnOverlapMode
>>> from magi_attention.common import AttnRanges
>>>
>>> # Step1. generate a dist_attn_runtime_key to dispatch for arbitrary mask represented by attn slices
>>> # in the following case, we use a causal mask as the key for dispatch, thus it will consider
>>> # computation load-balance, communication optimization and computation-communication overlap
>>> # according to the causal mask pattern
>>> key_for_dispatch = magi_attn_flex_key(
...     q_ranges=AttnRanges.from_ranges([[0, 4096]]),
...     k_ranges=AttnRanges.from_ranges([[0, 4096]]),
...     attn_mask_type="causal",
...     total_seqlen_q=4096,
...     total_seqlen_k=4096,
...     num_heads_q=16,
...     num_heads_kv=4,
...     head_dim=128,
...     pad_size=compute_pad_size(4096, 4, 512),  # seqlen, cp_size, chunk_size
...     chunk_size=512,
...     cp_group_or_mesh=dist.new_group(list(range(4)), backend="nccl"),
...     dist_attn_config=DistAttnConfig(
...         dispatch_config=DispatchConfig(alg=MinHeapDispatchAlg()),
...         overlap_config=OverlapConfig(
...             enable=True,
...             mode=AttnOverlapMode.STATIC,
...             degree=2,
...             min_chunk_size=512,
...             max_num_chunks=64,
...             alg=UniformOverlapAlg(),
...         ),
...     ),
... )
>>>
>>> # Step2. dispatch the global tensors to local tensors with the same key_for_dispatch
>>> local_x, local_label, local_rope = [
...     dispatch(tensor, key_for_dispatch)
...     for tensor in [total_x, total_label, total_rope]
... ]
>>>
>>> # Step3. make a new dist_attn_runtime_key from key_for_dispatch
>>> # for a new mask, such as a sliding window causal mask below,
>>> # with the same dispatch solution as the causal mask used for dispatch,
>>> # i.e. this new key shares the same dispatch meta as key_for_dispatch
>>> # but it can handle the computation and communication of the new mask
>>> # and calculate attn correctly as well, though no optimization is applied for now
>>> new_key_for_swa_mask = make_flex_key_for_new_mask_after_dispatch(
...     q_ranges=AttnRanges.from_ranges([[0, 512], [512, 4096]]),
...     k_ranges=AttnRanges.from_ranges([[0, 512], [0, 4096]]),
...     attn_mask_type=["causal", "bi_causal"],  # sliding window causal mask
...     key_for_dispatch=key_for_dispatch,
... )
>>>
>>> # Step4. apply QKV projection on local tensors
>>> local_q, local_k, local_v = q_project(local_x), k_project(local_x), v_project(local_x)
>>>
>>> # Step5. calculate distributed attention
>>> # for the causal mask used to dispatch with key_for_dispatch
>>> local_out1, _ = calc_attn(local_q, local_k, local_v, key_for_dispatch)
>>>
>>> # Step6. calculate distributed attention
>>> # for the new swa mask with the new key
>>> # w/o undispatching back and re-dispatching again to avoid OOM
>>> local_out2, _ = calc_attn(local_q, local_k, local_v, new_key_for_swa_mask)
>>>
>>> # Step7. undispatch local attention output to the global one if needed
>>> total_out1 = undispatch(local_out1, key_for_dispatch)
>>> total_out2 = undispatch(local_out2, new_key_for_swa_mask)
Dispatch Function#
Once you have the dist attn runtime key, you can call the dispatch function to dispatch the global input tensor(s) and get the padded local tensor(s) along the seqlen dim.
- magi_attention.api.magi_attn_interface.dispatch()#
Pad and dispatch the global input tensor to local input tensor for each cp rank along the seqlen dim.
- Parameters:
x (torch.Tensor) – the global input tensor.
key (DistAttnRuntimeKey) – the key that holds some inner meta data, as a required argument for many APIs of magi_attention, which users don't have to bother with.
pad_value (float) – the specific value to pad to the input tensor. Defaults to 0.
- Returns:
the padded local input tensor.
- Return type:
torch.Tensor
- Raises:
ValueError – If the provided key does not exist in the cached dist_attn_runtime_dict.
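As an illustrative sketch, reusing the dist_attn_runtime_key and total_x from the examples above, a pad_value can also be passed explicitly:
>>> from magi_attention.api import dispatch
>>> # pad the tail of the global tensor with zeros (the default) and shard it across cp ranks
>>> local_x = dispatch(total_x, dist_attn_runtime_key, pad_value=0.0)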
Calculate Attention#
After dispatch and QKV projection, you should obtain the local query, key, and value. Then you can calculate the distributed attention by calling calc_attn with the dist attn runtime key to get the local attention output tensor.
- magi_attention.api.magi_attn_interface.calc_attn()#
Calculate distributed attention with local q, k, v tensors.
- Parameters:
q (torch.Tensor) – the local query tensor.
k (torch.Tensor) – the local key tensor.
v (torch.Tensor) – the local value tensor.
key (DistAttnRuntimeKey) – the key that holds some inner meta data, as a required argument for many APIs of magi_attention, which users don't have to bother with.
sink (torch.Tensor, optional) – the global sink tensor (replicated among cp ranks). Defaults to None to not apply attention sink.
softmax_scale (float, optional) – softmax scale. Defaults to None to use 1/sqrt(head_dim).
softcap (float, optional) – softcap. Defaults to 0.0.
return_max_logits (bool, optional) – whether to return the global maximum attention logits (replicated among cp ranks), as used by the Muon QK-Clip technique introduced in Kimi K2: https://arxiv.org/pdf/2507.20534.pdf. Defaults to False.
- Returns:
out (torch.Tensor): local output tensor.
- meta (AttnForwardMeta): Meta information of the attention forward pass; for now, it includes the local lse (torch.Tensor) with dtype=torch.float32, and the global max_logits (torch.Tensor) with dtype=torch.float32 if return_max_logits is True, otherwise None.
- Return type:
tuple[torch.Tensor, AttnForwardMeta]
- Shapes:
q: [num_tokens_q_local, num_heads_q, head_dim]
k: [num_tokens_kv_local, num_heads_kv, head_dim]
v: [num_tokens_kv_local, num_heads_kv, head_dim]
sink: [num_tokens_sink_global, num_heads_q]
out: [num_tokens_q_local, num_heads_q, head_dim]
lse: [num_tokens_q_local, num_heads_q]
max_logits: [num_heads_q,]
- Raises:
ValueError – If the provided key does not exist in the cached dist_attn_runtime_dict.
Undispatch#
Undispatch Function#
When you need to recover the global output tensor(s) from the local one(s), e.g. to compute the loss, you can call the undispatch function to undispatch the padded local output tensor(s) back to the unpadded global tensor along the seqlen dim.
- magi_attention.api.magi_attn_interface.undispatch()#
Undispatch and unpad the local output tensor to global output tensor for each cp rank along the seqlen dim.
- Parameters:
x (torch.Tensor) – the local output tensor.
key (DistAttnRuntimeKey) – the key that holds some inner meta data, as a required argument for many APIs of magi_attention, which users don't have to bother with.
- Returns:
the unpadded global output tensor.
- Return type:
torch.Tensor
- Raises:
ValueError – If the provided key does not exist in the cached dist_attn_runtime_dict.
Utility Functions#
Compute Pad Size and Padding#
During the use of MagiAttention, we divide the total_seqlen into multiple chunks of size chunk_size and evenly distribute them across multiple GPUs. To ensure that total_seqlen is divisible by chunk_size and that each GPU receives the same number of chunks, we need to pad the original input.
You can call compute_pad_size to calculate the required padding length, and use this value as a parameter in subsequent functions.
- magi_attention.api.functools.compute_pad_size()#
Compute the size to pad at the end of the input tensor along the seqlen dim.
- Parameters:
total_seqlen_q (int) – seqlen of q.
cp_size (int) – The size of cp group.
chunk_size (int) – chunk size to chunk the input tensor x along the seqlen dim for dispatch to control the granularity of computation load-balance.
- Returns:
the number of tokens to pad.
- Return type:
int
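For instance, assuming cp_size=4 and chunk_size=512 as in the examples above:
>>> from magi_attention.api.functools import compute_pad_size
>>> # total_seqlen_q=4097 is not divisible by chunk_size * cp_size = 2048,
>>> # so some tail padding is required before dispatching
>>> pad_size = compute_pad_size(4097, 4, 512)  # total_seqlen_q, cp_size, chunk_size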
After obtaining pad_size, you can use the pad_at_dim and unpad_at_dim functions to pad and unpad the tensor.
- magi_attention.api.functools.pad_at_dim()#
Pads a tensor along a specified dimension with a given value, either on the left or right side.
- Parameters:
x (torch.Tensor) – Input tensor to be padded.
dim (int) – The dimension along which to apply padding.
pad_size (int) – The number of values to pad.
value (float, optional) – The padding value. Defaults to 0.0.
side (str, optional) – Side on which to apply the padding, either "left" or "right". Defaults to "right".
- Returns:
The padded tensor with the same number of dimensions as the input.
- Return type:
torch.Tensor
- magi_attention.api.functools.unpad_at_dim()#
Removes padding from a tensor along a specified dimension.
- Parameters:
x (torch.Tensor) – Input tensor from which padding will be removed.
dim (int) – The dimension along which to remove padding.
pad_size (int) – The number of elements to remove from the end of the specified dimension.
- Returns:
The tensor with padding removed along the specified dimension.
- Return type:
torch.Tensor
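A small round-trip sketch; the shapes are chosen for illustration and the keyword usage follows the parameter lists above:
>>> import torch
>>> from magi_attention.api.functools import pad_at_dim, unpad_at_dim
>>> x = torch.randn(4096, 2048)                          # (seqlen, hidden)
>>> x_pad = pad_at_dim(x, dim=0, pad_size=128)           # -> shape (4224, 2048), zero-padded on the right
>>> x_unpad = unpad_at_dim(x_pad, dim=0, pad_size=128)   # -> shape (4096, 2048)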
Similarly, you can use pad_size along with total_seqlen and other related information to apply padding to a (q_ranges, k_ranges, mask_types) tuple using the apply_padding function.
This function fills the padding region with invalid slices.
- magi_attention.api.functools.apply_padding()#
Appends padding to the attention ranges and updates the corresponding mask type.
This function adds a padding query range at the end of q_ranges, a dummy key range to k_ranges, and appends a FULL attention mask type to maintain alignment. It is typically used when padding is required for alignment or block-wise processing.
- Parameters:
q_ranges (AttnRanges) – Query token ranges before padding.
k_ranges (AttnRanges) – Key token ranges before padding.
attn_mask_type (list[AttnMaskType]) – List of attention mask types corresponding to the ranges.
total_seqlen (int) – The total original sequence length (used to place the padding at the end).
pad_size (int) – The size of the padding to append.
- Returns:
Updated query ranges with padding added.
Updated key ranges with a dummy range for padding.
Updated attention mask type list with a FULL mask for the padding block.
- Return type:
tuple[AttnRanges, AttnRanges, list[AttnMaskType]]
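A hedged sketch of padding a single causal slice; the AttnMaskType import path and member name are assumptions not confirmed by the original docs:
>>> from magi_attention.api.functools import apply_padding
>>> from magi_attention.common import AttnRanges
>>> from magi_attention.common.enum import AttnMaskType  # assumed import path
>>> q_ranges, k_ranges, attn_mask_type = apply_padding(
...     q_ranges=AttnRanges.from_ranges([[0, 4096]]),
...     k_ranges=AttnRanges.from_ranges([[0, 4096]]),
...     attn_mask_type=[AttnMaskType.CAUSAL],  # assumed enum member name
...     total_seqlen=4096,
...     pad_size=128,
... )
>>> # the returned ranges now include a trailing padding query range [4096, 4224]
>>> # paired with a dummy key range and a FULL mask type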
Get Position Ids#
Since MagiAttention needs to permute the input tensor along the seqlen dim, some token-aware ops might be affected, such as RoPE.
Therefore, we provide a function get_position_ids to get the position ids of the input tensor similar to Llama.
- magi_attention.api.magi_attn_interface.get_position_ids()#
Get the global positional ids of the local tensor, as it is sliced from the global tensor after dispatching.
- Parameters:
key (DistAttnRuntimeKey) – the key that holds some inner meta data, as a required argument for many APIs of magi_attention, which users don't have to bother with.
- Returns:
the global positional ids.
- Return type:
torch.Tensor
- Raises:
ValueError – If the provided key does not exist in the cached dist_attn_runtime_dict.
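For example, to index token-position-dependent data such as a RoPE table with the permuted local layout (total_rope here reuses the hypothetical global tensor from the examples above):
>>> from magi_attention.api.magi_attn_interface import get_position_ids
>>> position_ids = get_position_ids(dist_attn_runtime_key)
>>> local_rope = total_rope[position_ids]  # gather the rope rows for the local (permuted) tokens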
Get Most Recent Key#
If you have trouble accessing the meta key but need the most recent key for a certain cp_group, you can call get_most_recent_key to get it by specifying the cp_group.
However, we strongly recommend accessing the key passed through the arguments, to avoid unexpected inconsistency.
- magi_attention.api.magi_attn_interface.get_most_recent_key()#
Get the most recent inserted key.
NOTE: this is useful when you cannot access the key through the arguments and only need the most recently inserted key. However, we strongly recommend accessing the key passed through the arguments, to avoid unexpected inconsistency.
- Returns:
the most recent inserted dist_attn_runtime_key.
- Return type:
DistAttnRuntimeKey
- Parameters:
cp_group (ProcessGroup)
Infer Varlen Masks#
If you want to use a varlen mask where each segment has the same length, we provide an infer_varlen_mask_from_batch function that generates the corresponding cu_seqlens tensors for you.
- magi_attention.api.functools.infer_varlen_mask_from_batch()#
Converts fixed-length full attention into varlen full attention format by generating cumulative sequence lengths for queries and keys.
- Parameters:
batch_size (int) – The number of sequences in the batch.
seq_len (int) – The fixed sequence length for each sequence in the batch.
device (str, optional) – The device to allocate the tensors on. Defaults to
"cuda".
- Returns:
A pair of 1D tensors (cu_seqlens_q, cu_seqlens_k), each of shape [batch_size + 1,], representing the cumulative sequence lengths for the queries and keys respectively.
- Return type:
tuple[torch.Tensor, torch.Tensor]
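For example, a batch of 4 sequences of length 1024:
>>> from magi_attention.api.functools import infer_varlen_mask_from_batch
>>> cu_seqlens_q, cu_seqlens_k = infer_varlen_mask_from_batch(batch_size=4, seq_len=1024)
>>> # both hold the cumulative lengths [0, 1024, 2048, 3072, 4096]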
During the use of a varlen mask, it is often necessary to reshape a tensor of shape [batch_size, seq_len, ...] into [batch_size x seq_len, ...].
To facilitate the use of the above APIs, we provide the squash_batch_dim function to merge the tensor dimensions.
- magi_attention.api.functools.squash_batch_dim()#
Reshapes a tensor from shape [b, s, ...] to [b x s, ...], effectively flattening the batch and sequence dimensions into a single leading dimension.
- Parameters:
x (torch.Tensor) – Input tensor of shape [batch_size, seq_len, ...] to be merged.
- Returns:
Reshaped tensor of shape [batch_size x seq_len, ...].
- Return type:
torch.Tensor
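For example:
>>> import torch
>>> from magi_attention.api.functools import squash_batch_dim
>>> x = torch.randn(4, 1024, 2048)   # (batch_size, seq_len, hidden)
>>> x_flat = squash_batch_dim(x)     # -> shape (4096, 2048)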
Moreover, if you have already computed the cu_seqlens tensor and want to generate a varlen mask based on it, we provide the infer_attn_mask_from_cu_seqlens function. This function can create three types of masks (varlen full, varlen causal, and varlen sliding window) according to cu_seqlens, causal, and window_size, and returns the result as a (q_ranges, k_ranges, mask_types, total_seqlen_q, total_seqlen_k) tuple.
- magi_attention.api.functools.infer_attn_mask_from_cu_seqlens()#
Infer query ranges, key ranges and other arguments for flexible attn mask representation from cu_seqlens, widely used for varlen masks.
- Parameters:
cu_seqlens_q (torch.Tensor) – cumulative sequence lengths for queries
cu_seqlens_k (torch.Tensor) – cumulative sequence lengths for keys
causal (bool, optional) – whether the varlen attention mask is causal. Defaults to False.
window_size (tuple[int, int], optional) – window size of the sliding window mask, which represents [window_size_left, window_size_right]. This parameter is effective only when causal is False; when causal is True, it is required to be (-1, -1). Defaults to (-1, -1).
- Returns:
query ranges, key ranges, attn mask type list, total seqlen of q, total seqlen of k
- Return type:
tuple[AttnRanges, AttnRanges, list[AttnMaskType], int, int]
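A sketch of inferring a varlen causal mask from two documents of length 2048 each, using only the documented arguments:
>>> import torch
>>> from magi_attention.api.functools import infer_attn_mask_from_cu_seqlens
>>> q_ranges, k_ranges, attn_mask_type, total_seqlen_q, total_seqlen_k = (
...     infer_attn_mask_from_cu_seqlens(
...         cu_seqlens_q=torch.tensor([0, 2048, 4096], dtype=torch.int32),
...         cu_seqlens_k=torch.tensor([0, 2048, 4096], dtype=torch.int32),
...         causal=True,
...         window_size=(-1, -1),
...     )
... )
>>> # total_seqlen_q == total_seqlen_k == 4096, with one causal slice per document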
Infer Sliding Window Masks#
In the design of MagiAttention, we use a (q_range, k_range, masktype) tuple to represent a slice.
For sliding window masks, we do not provide a dedicated masktype to represent them directly.
However, a sliding window mask can be decomposed into a combination of existing masktypes such as full, causal, inv_causal, and bi_causal.
If you're unsure how to perform this decomposition, we provide the infer_attn_mask_from_sliding_window function to handle this process for you.
- magi_attention.api.functools.infer_attn_mask_from_sliding_window()#
Converts a single sliding window mask into a representation using q_range, k_range, and mask type; the resulting mask types are determined by window_size.
- Parameters:
q_range (AttnRange) – q_range of this sliding window mask
k_range (AttnRange) – k_range of this sliding window mask
window_size (tuple[int, int]) – window_size of sliding window mask which represents
[window_size_left, window_size_right]
- Returns:
the processed (q_ranges, k_ranges, masktypes) triple, into which the sliding window mask has been cut.
- Return type:
tuple[AttnRanges, AttnRanges, list[AttnMaskType]]
Example
Here's an example of infer_attn_mask_from_sliding_window:
>>> q_ranges, k_ranges, attn_mask_type = infer_attn_mask_from_sliding_window(
...     q_range=AttnRange.from_range([5, 15]),
...     k_range=AttnRange.from_range([5, 15]),
...     window_size=(2, 3),
... )
The code above represents the sliding window mask within the [5, 15] x [5, 15] region with a window size of (2, 3).