API Reference#

Flexible Flash Attention#

To support computing attention under irregular-shaped masks, we implemented a flexible_flash_attention kernel, which can be invoked through the following interface.

Dispatch#

Varlen Dispatch#

If you’re using a mask defined by cu_seqlens, such as a varlen full or varlen causal mask, we provide a similar interface, magi_attn_varlen_key, inspired by FlashAttention’s API, as follows, so you can get started quickly.

If you want to apply more than one mask within the same training pass, you can use make_varlen_key_for_new_mask_after_dispatch to make a new key for the new mask, given the mask arguments specific to a varlen mask in flash-attn-varlen style and the existing key used for dispatch.

Then the new mask will reuse the same dispatch solution as the mask used for dispatch, but with different meta arguments for computation and communication.
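
Below is a minimal sketch of this workflow. The import path and the exact keyword arguments (e.g. the padding/chunking and cp_group parameters) are assumptions for illustration, not the authoritative signatures.

```python
import torch

from magi_attention.api import (  # assumed import path
    magi_attn_varlen_key,
    make_varlen_key_for_new_mask_after_dispatch,
)

# flash-attn-varlen style cu_seqlens: three sequences of lengths 1024 / 2048 / 1024
cu_seqlens_q = torch.tensor([0, 1024, 3072, 4096], dtype=torch.int32, device="cuda")
cu_seqlens_k = cu_seqlens_q.clone()

# build the dist attn runtime key for a varlen causal mask
key = magi_attn_varlen_key(
    cu_seqlens_q=cu_seqlens_q,
    cu_seqlens_k=cu_seqlens_k,
    causal=True,
    # ... plus the total_seqlen / chunk_size / cp_group arguments your setup requires
)

# derive a key for a second mask (here: varlen full over the same cu_seqlens)
# that reuses the dispatch solution of `key`
full_key = make_varlen_key_for_new_mask_after_dispatch(
    cu_seqlens_q=cu_seqlens_q,
    cu_seqlens_k=cu_seqlens_k,
    causal=False,
    key=key,
)
```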

Flexible Dispatch#

If the masks you’re using are not limited to varlen full or varlen causal, but also include sliding-window masks or other more diverse types, we recommend using magi_attn_flex_key as follows.

If you want to apply more than one mask within the same training pass, you can use make_flex_key_for_new_mask_after_dispatch to make a new key for the new mask, given the mask arguments and the existing key used for dispatch.

Then the new mask will reuse the same dispatch solution as the mask used for dispatch, but with different meta arguments for computation and communication.
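
As a rough sketch (the import path, the keyword names, and the concrete range / mask-type containers are assumptions for illustration; plain lists and strings stand in for the library’s own types):

```python
from magi_attention.api import (  # assumed import path
    magi_attn_flex_key,
    make_flex_key_for_new_mask_after_dispatch,
)

# describe the mask as (q_ranges, k_ranges, mask_types) slices
q_ranges = [[0, 2048], [2048, 4096]]
k_ranges = [[0, 2048], [0, 4096]]
mask_types = ["causal", "full"]

key = magi_attn_flex_key(
    q_ranges=q_ranges,
    k_ranges=k_ranges,
    mask_types=mask_types,
    total_seqlen_q=4096,
    total_seqlen_k=4096,
    # ... plus the padding / chunking / cp_group arguments your setup requires
)

# a different mask over the same tokens reuses the dispatch solution of `key`
new_key = make_flex_key_for_new_mask_after_dispatch(
    q_ranges=q_ranges,
    k_ranges=[[0, 2048], [2048, 4096]],
    mask_types=["causal", "causal"],
    total_seqlen_q=4096,
    total_seqlen_k=4096,
    key=key,
)
```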

Dispatch Function#

Once you have the dist attn runtime key, you can call the dispatch function to dispatch the global input tensor(s) into the padded local tensor(s) along the seqlen dim.
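
For example (the import path and the two-argument signature are assumptions for illustration; `key` is the runtime key obtained from one of the dispatch interfaces above):

```python
import torch
from magi_attention.api import dispatch  # assumed import path

total_seqlen, hidden_size = 4096, 2048
x_global = torch.randn(total_seqlen, hidden_size, device="cuda", dtype=torch.bfloat16)

# shard the global tensor into the padded local tensor along the seqlen dim
x_local = dispatch(x_global, key)
```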

Calculate Attention#

After dispatch and QKV projection, you should obtain the local query, key, and value. Then you can calculate the distributed attention by calling calc_attn with the dist attn runtime key to get the local attention output tensor.
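
A sketch of that step, assuming a fused QKV projection and an (output, lse) return convention (both assumptions for illustration):

```python
import torch
from magi_attention.api import calc_attn  # assumed import path

num_heads, head_dim, hidden_size = 16, 128, 2048
qkv_proj = torch.nn.Linear(
    hidden_size, 3 * num_heads * head_dim, device="cuda", dtype=torch.bfloat16
)

# x_local is the padded local tensor returned by dispatch()
q, k, v = qkv_proj(x_local).view(-1, 3, num_heads, head_dim).unbind(dim=1)

# distributed attention on the local shard, driven by the dist attn runtime key
out, lse = calc_attn(q, k, v, key)
```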

Roll#

Roll Function#

After dispatching, if you need to cyclically shift the local tensor along the sequence dimension, you can call roll with the dist attn runtime key. This is primarily designed for Multi-Token Prediction (MTP), where the labels need to be shifted by one or more positions relative to the input tokens. It can also serve other use cases such as relative positional offsets or shifted-window patterns.

Semantically, roll is equivalent to undispatch -> torch.roll -> dispatch, but avoids materialising the full global tensor, cutting peak memory from O(N) to O(N/P) and reducing communication volume by ~P times.
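
For instance, shifting MTP labels by one position (the import path and the argument names, e.g. shifts, are assumptions for illustration):

```python
import torch
from magi_attention.api import dispatch, roll  # assumed import path

total_seqlen, vocab_size = 4096, 32000
labels_global = torch.randint(0, vocab_size, (total_seqlen,), device="cuda")
labels_local = dispatch(labels_global, key)

# cyclic shift by one position along the global seqlen dim,
# without ever materialising the full global tensor
labels_local_shifted = roll(labels_local, shifts=-1, key=key)
```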

Undispatch#

Undispatch Function#

When you need to recover the global output tensor(s) from the local one(s), e.g. to compute the loss, you can call the undispatch function to undispatch the padded local output tensor(s) back to the unpadded global tensor(s) along the seqlen dim.
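
For example (assumed import path and signature):

```python
from magi_attention.api import undispatch  # assumed import path

# `out` is the local attention output from calc_attn(), `key` the dist attn runtime key;
# gather the unpadded global output back from the padded local shard
out_global = undispatch(out, key)
```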

Utility Functions#

Compute Pad Size and Padding#

During the use of MagiAttention, we divide the total_seqlen into multiple chunks of size chunk_size and evenly distribute them across multiple GPUs. To ensure that total_seqlen is divisible by chunk_size and that each GPU receives the same number of chunks, we need to pad the original input.

You can call compute_pad_size to calculate the required padding length, and use this value as a parameter in subsequent functions.

After obtaining pad_size, you can use the pad_at_dim and unpad_at_dim functions to pad and unpad the tensor.

Similarly, you can use pad_size along with total_seqlen and other related information to apply padding to a (q_ranges, k_ranges, mask_types) tuple using the apply_padding function.

This function fills the padding region with invalid slices.
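
Putting these utilities together (the import path, argument names, and positional order are assumptions for illustration):

```python
import torch
from magi_attention.api import (  # assumed import path
    apply_padding,
    compute_pad_size,
    pad_at_dim,
    unpad_at_dim,
)

total_seqlen, chunk_size, cp_size, hidden_size = 4000, 512, 8, 2048

# number of tokens to append so total_seqlen splits into equal-sized chunks per GPU
pad_size = compute_pad_size(total_seqlen, cp_size, chunk_size)

x = torch.randn(total_seqlen, hidden_size, device="cuda", dtype=torch.bfloat16)
x_padded = pad_at_dim(x, dim=0, pad_size=pad_size)             # pad along the seqlen dim
x_unpadded = unpad_at_dim(x_padded, dim=0, pad_size=pad_size)  # and undo it

# extend the mask description so the padded region is covered by invalid slices
q_ranges, k_ranges, mask_types = [[0, 4000]], [[0, 4000]], ["causal"]
q_ranges, k_ranges, mask_types = apply_padding(
    q_ranges, k_ranges, mask_types, total_seqlen, pad_size
)
```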

Get Position Ids#

Since MagiAttention needs to permute the input tensor along the seqlen dim, some token-aware ops might be affected, such as RoPE.

Therefore, we provide a get_position_ids function to get the position ids of the input tensor, similar to the position ids used in Llama.
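
For instance (assumed import path; the RoPE module below is hypothetical):

```python
from magi_attention.api import get_position_ids  # assumed import path

# position ids of the local (permuted) tokens w.r.t. the original global sequence,
# so that RoPE and other position-aware ops remain correct after dispatch
position_ids = get_position_ids(key)
cos, sin = rotary_emb(position_ids)  # hypothetical Llama-style RoPE module
```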

Get Most Recent Key#

If you have trouble accessing the meta key but need the most recent key for a certain cp_group, you can call get_most_recent_key to get it by specifying the cp_group.

However, we strongly recommend accessing the key passed through the arguments, to avoid unexpected inconsistency.
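
A sketch (assumed import path and signature):

```python
import torch.distributed as dist
from magi_attention.api import get_most_recent_key  # assumed import path

cp_group = dist.group.WORLD  # the context-parallel process group used at dispatch time

# fallback only: fetch the key most recently created for this cp_group;
# prefer passing the key explicitly through your function arguments
key = get_most_recent_key(cp_group)
```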

Infer Varlen Masks#

If you want to use a varlen mask where each segment has the same length, we provide an infer_varlen_mask_from_batch function that generates the corresponding cu_seqlens tensors for you.

During the use of a varlen mask, it is often necessary to reshape a tensor of shape [batch_size, seq_len, ...] into [batch_size × seq_len, ...].

To facilitate the use of the above APIs, we provide the squash_batch_dim function to merge the tensor dimensions.
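
For example (assumed import path and argument names):

```python
import torch
from magi_attention.api import (  # assumed import path
    infer_varlen_mask_from_batch,
    squash_batch_dim,
)

batch_size, seq_len, hidden_size = 4, 1024, 2048
x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=torch.bfloat16)

# equal-length segments -> cu_seqlens = [0, 1024, 2048, 3072, 4096]
cu_seqlens_q, cu_seqlens_k = infer_varlen_mask_from_batch(batch_size, seq_len)

# merge [batch_size, seq_len, ...] into [batch_size * seq_len, ...]
x_squashed = squash_batch_dim(x)
```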

Moreover, if you have already computed the cu_seqlens tensor and want to generate a varlen mask based on it, we provide the infer_attn_mask_from_cu_seqlens function. This function can create four types of masks (varlen full, varlen causal, varlen sliding window, and varlen sliding window with global attention) according to cu_seqlens, causal, window_size, and global_window_size, and returns the result as a (q_ranges, k_ranges, mask_types, total_seqlen_q, total_seqlen_k) tuple.

When global_window_size is set to a positive integer, every query in a sample always attends to the first global_window_size key tokens of that sample in addition to the sliding window, which is useful for architectures that require certain prefix tokens (e.g. system prompt, sink tokens) to be globally visible. To prevent information leakage, a query at relative position i can only see global tokens at positions [0, min(global_window_size, i + window_size_right + 1)).
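
A sketch of a varlen sliding-window mask with a globally visible prefix (the keyword names follow the description above but remain assumptions for illustration):

```python
import torch
from magi_attention.api import infer_attn_mask_from_cu_seqlens  # assumed import path

cu_seqlens = torch.tensor([0, 1024, 3072, 4096], dtype=torch.int32, device="cuda")

q_ranges, k_ranges, mask_types, total_seqlen_q, total_seqlen_k = infer_attn_mask_from_cu_seqlens(
    cu_seqlens=cu_seqlens,
    causal=False,
    window_size=(256, 0),   # (left, right) window in flash-attn convention (assumed)
    global_window_size=64,  # the first 64 tokens of each sample stay globally visible
)
```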

Infer Sliding Window Masks#

In the design of MagiAttention, we use a (q_range, k_range, masktype) tuple to represent a slice.

For sliding window masks, we do not provide a dedicated masktype to represent them directly.

However, a sliding window mask can be decomposed into a combination of existing masktypes such as full, causal, inv_causal, and bi_causal.

If you’re unsure how to perform this decomposition, we provide the infer_attn_mask_from_sliding_window function to handle this process for you.
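
A sketch of that helper (assumed import path and argument names):

```python
from magi_attention.api import infer_attn_mask_from_sliding_window  # assumed import path

# a single sliding-window slice over the whole sequence, decomposed into
# full / causal / inv_causal / bi_causal slices
q_ranges, k_ranges, mask_types = infer_attn_mask_from_sliding_window(
    q_ranges=[[0, 4096]],
    k_ranges=[[0, 4096]],
    window_size=(256, 0),
)
```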