MagiAttention
A Distributed Attention Towards Linear Scalability for Ultra-Long Context, Heterogeneous Mask Training
Overview

Training large-scale models for video generation presents two major challenges: (1) The extremely long context length of video tokens, which reaches up to 4 million during training, results in prohibitive computational and memory overhead. (2) The combination of block-causal attention and Packing-and-Padding (PnP) introduces highly complex attention mask patterns.
To address these challenges, we propose MagiAttention, which aims to support a wide variety of attention mask types with kernel-level flexibility while achieving linear scalability with respect to context-parallel (CP) size across a broad range of scenarios, making it particularly suitable for training tasks involving ultra-long contexts and heterogeneous masks, such as video generation for Magi-1.
Introduction
Training large-scale autoregressive diffusion models like Magi-1 for video generation presents two major challenges:
- The extremely long context length of video tokens, which reaches up to 4 million during training, results in prohibitive computational and memory overhead. Context-Parallelism (CP) is designed to deal with such long-context challenges, but existing state-of-the-art CP methods face scalability limitations due to size constraints or the high communication overhead inherent in inefficient ring-style point-to-point (P2P) patterns. While recent efforts dynamically adjust CP sizes to avoid unnecessary sharding and redundant communication for shorter sequences, they still incur extra memory overhead for NCCL buffers and involve complex scheduling to balance loads and synchronize across different subsets of ranks.
- The combination of block-causal attention and Packing-and-Padding (PnP) introduces highly complex attention mask patterns with variable sequence lengths, which cannot be efficiently handled by existing attention implementations.
To address the aforementioned challenges, we propose MagiAttention, which aims to support a wide variety of attention mask types (i.e., kernel flexibility) while achieving linear scalability with respect to context-parallel (CP) size across a broad range of scenarios. Achieving this goal depends on meeting the following fundamental conditions:
- Linearly Scalable Attention Kernel: The performance of the attention kernel should not degrade as CP size increases. To this end, we introduce Flex-Flash-Attention, an extension of FlashAttention-3 (FA3) that natively accounts for the efficiency impact of attention mask partitioning in distributed environments. It supports distributable mask representations with a tailored kernel implementation to ensure scalability while accommodating a broader range of attention mask types.
- Balanced Computational Workloads: Imbalances in the computational load across CP ranks lead to unavoidable idle bubbles that hinder scalability. MagiAttention is natively designed to ensure Computation Load Balancing, mitigating such inefficiencies.
- Full Overlap of Communication and Computation: Without sufficient overlap, increasing CP size results in communication-induced idle time on GPUs, impairing scalability. MagiAttention introduces novel Zero-Redundant Communication Primitives to minimize communication overhead, along with an Adaptive Multi-Stage Overlap strategy that enables effective communication-computation overlap.
An overview of MagiAttention is shown in Overview; we introduce its key designs in the following Methodology section, with comprehensive experimental results presented in the Experiment section.
Related Work
To tackle the ultra-long context challenge in large-scale model training, the distributed attention mechanism, or context parallelism (CP), is essential.
However, current strategies fall short in our demanding settings. DeepSpeed's Ulysses shards along the attention-head dimension, so its CP size is bounded by the number of heads, while Ring-Attention relies on ring-style point-to-point (P2P) communication whose overhead grows quickly with CP size.
Worse still, for irregular attention mask patterns like the aforementioned varlen masks, classic Ring-Attention-based CP strategies face additional challenges beyond the attention kernel limitations. First, naive sequential even sharding along the sequence dimension distributes the varlen mask area unevenly, leading to imbalanced computational loads across CP ranks. Although the customized zigzag sharding design alleviates this imbalance for standard causal masks, it does not generalize to irregular varlen mask patterns.

Second, the communication overhead issue is exacerbated under sparse varlen mask settings: entire sequence chunks are still transferred across all CP ranks even when not all ranks require them, potentially causing over 30% redundant communication cost, as illustrated in Zero-Redundant Comm. Third, the preceding issues cause the pipelined compute-communication overlap strategy to fail more often due to imbalanced loads and large communication overheads, further limiting scalability.
Recent efforts that dynamically adjust CP sizes avoid unnecessary sharding and redundant communication for shorter sequences, but they still incur extra memory overhead for NCCL buffers and require complex scheduling to balance loads and synchronize across different subsets of ranks.
Methodology
Flex-Flash-Attn
Flash Attention provides highly optimized kernels, but it supports only a limited set of mask patterns and cannot efficiently handle the complex, variable-length masks described above.
Therefore, we introduce Flex-Flash-Attention (FFA), which is natively designed for distribution scenarios and provides greater flexibility in handling diverse attention mask types. The core idea behind FFA is to generalize a distributable formulation for irregular attention masks by decomposing the entire mask into multiple computational units, each referred to as an $\mathrm{AttnSlice}$. Each $\mathrm{AttnSlice}$ is defined by a triplet $\mathrm{(QRange, KRange, MaskType)}$, which specifies a submask with a basic shape bounded by a contiguous 2D query-key region as seen in the figure below.

Using this formulation, as shown in the figure below, a wide variety of commonly used attention masks, including the varlen block-causal mask for autoregressive video generation, can be expressed as a composition of multiple such triplets, even after sharding and rearrangement in distributed settings, making FFA highly suitable for distributed attention computation.

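To make the formulation concrete, the sketch below expresses a packed (varlen) block-causal mask as a list of $\mathrm{AttnSlice}$ triplets: queries in each block attend fully to every token up to the end of their own block. The `AttnSlice`/`MaskType` definitions and the helper function are simplified stand-ins for illustration, not MagiAttention's actual API.

```python
# Illustrative sketch only: the AttnSlice / MaskType names mirror the
# formulation in the text, but the actual MagiAttention API may differ.
from dataclasses import dataclass
from enum import Enum


class MaskType(Enum):
    FULL = "full"
    CAUSAL = "causal"


@dataclass
class AttnSlice:
    q_range: tuple[int, int]   # [start, end) over the query dimension
    k_range: tuple[int, int]   # [start, end) over the key dimension
    mask_type: MaskType


def block_causal_slices(seq_start: int, seq_end: int, block_size: int) -> list[AttnSlice]:
    """Express a block-causal mask for one packed sequence as AttnSlice triplets:
    queries in block b attend fully to every token up to the end of block b."""
    slices = []
    for q_start in range(seq_start, seq_end, block_size):
        q_end = min(q_start + block_size, seq_end)
        slices.append(AttnSlice((q_start, q_end), (seq_start, q_end), MaskType.FULL))
    return slices


# A packed (varlen) batch is just the concatenation of per-sequence slices.
cu_seqlens = [0, 6144, 10240]  # two packed sequences of lengths 6144 and 4096
slices = [
    s
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])
    for s in block_causal_slices(start, end, block_size=2048)
]
```

Because each triplet is self-contained, the list can be freely split and reassigned across CP ranks, which is what makes the representation distributable.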
Built on Flash-Attention 3 (FA3) kernels, FFA extends them to compute attention over arbitrary compositions of $\mathrm{AttnSlice}$ triplets while preserving their efficiency on Hopper GPUs.
Although $\mathrm{AttnSlice}$ can represent most mask patterns using two common types ($\mathrm{FULL}$ and $\mathrm{CAUSAL}$), it is inefficient for patterns like $\textit{sliding-window}$, which requires row-by-row expression. To address this, we introduce two new mask types, $\mathrm{INV\text{-}CAUSAL}$ and $\mathrm{BI\text{-}CAUSAL}$, to efficiently represent more specific $\textit{sliding-window}$-style patterns. We provide basic examples of these four mask types in the following figures.



Based on the four mask types currently supported, we provide examples of how to express common $\textit{sliding-window}$-style mask patterns using the $\mathrm{AttnSlice}$ formulation, as illustrated in the figure below.

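As a worked example of these types, a causal sliding-window mask (window $w$) can plausibly be covered by just two slices, assuming $\mathrm{CAUSAL}$ is bottom-right aligned, $\mathrm{INV\text{-}CAUSAL}$ is top-left aligned, and $\mathrm{BI\text{-}CAUSAL}$ is their intersection (a diagonal band). The sketch below illustrates the idea under those assumptions and is not the exact recipe used by MagiAttention.

```python
# Illustrative sketch: one way a causal sliding-window mask (window w) might be
# decomposed into two AttnSlice triplets, assuming CAUSAL is bottom-right
# aligned, INV-CAUSAL is top-left aligned, and BI-CAUSAL is their intersection.
from dataclasses import dataclass
from enum import Enum


class MaskType(Enum):
    FULL = "full"
    CAUSAL = "causal"          # bottom-right aligned (assumed)
    INV_CAUSAL = "inv_causal"  # top-left aligned (assumed)
    BI_CAUSAL = "bi_causal"    # intersection of the two, i.e. a diagonal band


@dataclass
class AttnSlice:
    q_range: tuple[int, int]
    k_range: tuple[int, int]
    mask_type: MaskType


def sliding_window_causal_slices(seqlen: int, window: int) -> list[AttnSlice]:
    """Query i attends to keys j with i - window < j <= i."""
    head = min(window, seqlen)
    slices = [
        # Head of the sequence: queries [0, w) attend to keys [0, w) causally.
        AttnSlice((0, head), (0, head), MaskType.CAUSAL),
    ]
    if seqlen > window:
        # Queries i in [w, s) attend to the band (i - w, i], which is exactly
        # the BI-CAUSAL region of the rectangle Q = [w, s) x K = [1, s).
        slices.append(AttnSlice((window, seqlen), (1, seqlen), MaskType.BI_CAUSAL))
    return slices
```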
Comp Load-Balance
In context-parallel settings, different CP ranks may be assigned heterogeneous attention masks, resulting in imbalanced computational workloads across ranks. Ring-Attention, as mentioned in Related Work, employs a specialized partitioning strategy designed specifically for causal attention, which limits its applicability to more general attention patterns. To overcome this limitation, we propose a generic and efficient dispatch solver that enables balanced workload distribution across CP ranks for a broad range of attention types.
First, to enable finer-grained control, we propose a chunk-wise permutable sharding strategy, as seen in Overview. Specifically, the entire mask is evenly partitioned along the query dimension into chunks, each associated with a submask area: $\lbrace(C_i, \mathrm{Area}(C_i))\rbrace_{i=1}^n$, where $C_i$ denotes the $i$-th chunk, $\mathrm{Area}(C_i)$ is the mask area of $C_i$, $n$ is $\frac{seqlen}{\textit{chunk_size}}$, and $\textit{chunk_size}$ is a hyperparameter controlling granularity.
These chunks are then assigned to $\textit{cp_size}$ buckets, with each bucket containing exactly the same number of chunks to ensure token-level load balance in non-attention modules; each bucket is attached with a summed submask area, denoted $\lbrace(B_j, \mathrm{SumArea}(B_j))\rbrace_{j=1}^{\textit{cp_size}}$.
With the above strategy, we can control the computational workload of each CP rank at a fine granularity, and the load-balancing dispatch becomes a combinatorial optimization problem: finding an optimal mapping function $f^*: \lbrace C_i\rbrace_{i=1}^n \rightarrow \lbrace B_j\rbrace_{j=1}^{\textit{cp_size}}$ as follows:
\[\begin{aligned} &f^* = \arg \min\limits_{f}\max\limits_{j}\left\{\mathrm{SumArea}(B_j)\right\} \\ &\text{s.t.}\;\;|B_j| = \frac{n}{\textit{cp_size}}, \;\; seqlen \;\%\; (\textit{cp_size} \times \textit{chunk_size}) = 0 \end{aligned}\]
However, this optimization is a known NP-hard problem, making it impractical to find an optimal solution on-the-fly during each training iteration, especially given the varying mask patterns across micro-batches. Thus, we propose an efficient greedy algorithm, as shown below, that provides a suboptimal yet effective solution within $O(n\log n)$ complexity.

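As a rough illustration of such a greedy dispatch (not necessarily the exact algorithm shown above), one could sort chunks by mask area and repeatedly assign each to the currently lightest bucket that still has spare capacity:

```python
# Hedged sketch of a greedy chunk-to-bucket dispatch: sort chunks by mask area
# (descending) and assign each to the currently lightest bucket with spare
# capacity. Sorting dominates, giving O(n log n) overall; illustrative only.
import heapq


def greedy_dispatch(chunk_areas: list[float], cp_size: int) -> list[list[int]]:
    """Assign n chunks to cp_size buckets, minimizing the max summed area,
    under the constraint that every bucket holds exactly n / cp_size chunks."""
    n = len(chunk_areas)
    assert n % cp_size == 0, "n must be a multiple of cp_size"
    capacity = n // cp_size

    # Min-heap keyed by each bucket's accumulated mask area: (SumArea, bucket_id).
    heap = [(0.0, j) for j in range(cp_size)]
    heapq.heapify(heap)
    buckets: list[list[int]] = [[] for _ in range(cp_size)]

    # Place the largest chunks first so later, smaller chunks can even out load.
    for chunk_id in sorted(range(n), key=lambda i: chunk_areas[i], reverse=True):
        area, j = heapq.heappop(heap)
        buckets[j].append(chunk_id)
        # A full bucket simply never re-enters the heap.
        if len(buckets[j]) < capacity:
            heapq.heappush(heap, (area + chunk_areas[chunk_id], j))
    return buckets
```

Assigning the largest chunks first leaves the small chunks to smooth out residual imbalance, which is why such greedy schemes tend to land close to the optimum in practice.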
Zero-Redundant Comm
The existing ring-style implementation uses point-to-point send/recv communication primitives, which cannot provide sufficient communication granularity, resulting in redundant communication. Taking the causal mask as an example, we analyze the redundant communication by recording the distribution of remote key-value ($\mathrm{KV}$) requests and their gradients ($\mathrm{dKV}$) under sparse attention masks. As shown in the following figure, $\mathrm{KV}_0$ is required by all queries and should be sent to all devices via Broadcast in the forward pass, with $\mathrm{dKV}_0$ reduced via All-Reduce in the backward pass. In contrast, $\mathrm{KV}_7$ is only needed by its host device but still circulates through all devices, and this redundancy intensifies in varlen scenarios.

To address this, as illustrated in the figure below, we introduce two communication primitives: $\textit{Group-Cast}$ and $\textit{Group-Reduce}$, which model the communication patterns of low-demand $\mathrm{KV}$ and $\mathrm{dKV}$. For example, under the causal mask, $\mathrm{KV}_5$ on $\mathrm{rank}_2$ is required only by $\lbrace\mathrm{Q}_6,\mathrm{Q}_7\rbrace$ and should be sent exclusively to the target ranks $\lbrace\mathrm{rank}_0, \mathrm{rank}_1\rbrace$ via Group-Cast, while the partial $\mathrm{dKV}_5$ is collected and reduced back to $\mathrm{rank}_2$ via Group-Reduce accordingly.

As no existing communication kernels support these primitives, we prototype them using All-to-All-v, achieving zero-redundant communication in both the forward and backward passes. However, this approach introduces extra pre-/post-processing overhead, similar to the (un)permutation in expert parallelism (EP).
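A minimal sketch of how such primitives could be prototyped on top of All-to-All-v is given below, using `torch.distributed.all_to_all_single`; the function names, arguments, and row bookkeeping are illustrative assumptions rather than MagiAttention's actual interface.

```python
# Rough prototype of Group-Cast / Group-Reduce on top of All-to-All-v.
# Names and bookkeeping are assumptions for illustration only.
import torch
import torch.distributed as dist


def group_cast(kv_send: torch.Tensor,
               send_splits: list[int],
               recv_splits: list[int],
               group=None) -> torch.Tensor:
    """Forward pass: ship each KV row only to the ranks whose queries need it.

    `kv_send` must already be permuted so rows destined for rank j are
    contiguous (the pre-processing step mentioned above); `send_splits[j]` /
    `recv_splits[j]` are the numbers of rows exchanged with rank j.
    """
    kv_recv = kv_send.new_empty((sum(recv_splits), *kv_send.shape[1:]))
    dist.all_to_all_single(kv_recv, kv_send,
                           output_split_sizes=recv_splits,
                           input_split_sizes=send_splits,
                           group=group)
    return kv_recv


def group_reduce(dkv_partial: torch.Tensor,
                 send_splits: list[int],
                 recv_splits: list[int],
                 owner_rows: torch.Tensor,
                 dkv_local: torch.Tensor,
                 group=None) -> torch.Tensor:
    """Backward pass: return partial dKV rows to their owner rank and sum them.

    `owner_rows[i]` maps the i-th received row back to a row of the local dKV
    buffer (the post-processing step, analogous to EP un-permutation).
    """
    dkv_recv = group_cast(dkv_partial, send_splits, recv_splits, group)
    dkv_local.index_add_(0, owner_rows, dkv_recv)
    return dkv_local
```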
Multi-Stage Overlap
Leveraging the previous optimizations, we achieve high-performance computation through an efficient kernel and balanced workload dispatch, while minimizing communication overhead with our new primitives. To drive true linear scalability, we further improve end-to-end performance by introducing a multi-stage compute-communication overlap strategy that effectively hides communication latency and adaptively optimizes the overlap through manual or automatic tuning.
Similar to prior works, we partition each rank's attention computation into multiple stages and pipeline them, so that the remote $\mathrm{KV}$ (or $\mathrm{dKV}$) communication of one stage overlaps with the partial attention computation of another, as illustrated below.

In the forward pass, the scheduler first launches the Group-Cast kernel to prefetch the next remote $\mathrm{KV}$, then asynchronously executes the FFA kernel for partial attention computation, hiding all communication behind computation. To prevent the attention kernel from occupying all SMs, by default we ensure the communication kernel is launched first by setting CUDA_DEVICE_MAX_CONNECTIONS=1, and additionally expose an sm_margin argument for the FFA kernel to reserve some SMs for the communication kernels to be launched.
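The forward scheduling loop can be pictured as follows; `group_cast` and `ffa_forward` are injected placeholders standing in for the real Group-Cast launch and FFA kernel, and only the ordering and overlap structure is meant to be representative.

```python
# Schematic of the forward multi-stage overlap loop described above; the
# callables are placeholders, not MagiAttention's actual scheduler API.
def forward_overlap(stages, q_local, kv_local, group_cast, ffa_forward,
                    sm_margin: int = 8):
    """Prefetch stage i+1's remote KV while computing stage i's partial
    attention, so that communication hides behind computation."""
    out_partials = []
    kv_cur, handle, kv_next = kv_local, None, None
    for i, stage in enumerate(stages):
        if i + 1 < len(stages):
            # Enqueue Group-Cast first (CUDA_DEVICE_MAX_CONNECTIONS=1 keeps this
            # launch order) to prefetch the next stage's remote KV.
            handle, kv_next = group_cast(stages[i + 1])
        # Partial attention on the current stage's KV; `sm_margin` SMs stay free
        # so the in-flight communication kernel can make progress.
        out_partials.append(ffa_forward(q_local, kv_cur, stage, sm_margin=sm_margin))
        # Make sure the prefetched KV has arrived before the next stage uses it.
        if handle is not None:
            handle.wait()
            kv_cur, handle, kv_next = kv_next, None, None
    # Partial outputs are finally merged with their log-sum-exp statistics
    # (online-softmax correction, omitted here).
    return out_partials
```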
In the backward pass, besides prefetching the next $\mathrm{KV}$, the Group-Reduce kernel reduces the last $\mathrm{dKV}$ in a separate CUDA stream before launching the FFA kernel for the current stage, ensuring communication is overlapped across all stages except the final $\mathrm{dKV}$ reduction. Because PyTorch maps each process group to a single collective communication stream (including for All-to-All-v), this requires maintaining a dedicated process group for Group-Reduce so that it does not block the Group-Cast prefetch.
To adaptively control overlap granularity, we further introduce a tunable hyperparameter, $\texttt{num_stages}$, accounting for varying compute-to-communication ratios across training setups, microbatches, or between forward and backward passes. This parameter can be manually configured or automatically determined by our $\textit{overlap solver}$, with a simple dynamic search algorithm as shown below.

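One simple way such a search could work is to time a few candidate values of $\texttt{num_stages}$ and keep the fastest; the sketch below is illustrative and not necessarily the solver's actual strategy.

```python
# Hedged sketch of a num_stages search: time one attention step for each
# candidate value and keep the fastest. The real overlap solver may use a
# different search strategy and cost model.
import time


def tune_num_stages(run_step, candidates=(1, 2, 4, 8), warmup: int = 2,
                    iters: int = 5) -> int:
    """`run_step(num_stages)` executes one forward+backward attention step."""
    best, best_time = candidates[0], float("inf")
    for num_stages in candidates:
        for _ in range(warmup):          # discard warm-up iterations
            run_step(num_stages)
        start = time.perf_counter()
        for _ in range(iters):
            run_step(num_stages)
        elapsed = (time.perf_counter() - start) / iters
        if elapsed < best_time:
            best, best_time = num_stages, elapsed
    return best
```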
Experiment
Kernel-Level
To demonstrate FFA kernels’ state-of-the-art performance and flexibility in handling ultra-long, heterogeneous mask training, we measure the throughput (in $\texttt{TFLOPs/s}$) on Hopper GPUs for both forward and backward passes of prevalent attention kernels across standard and irregular mask patterns.
settings | value |
---|---|
batch size (b) | 1 |
number of heads (nh) | nhq:nhk:nhv = 64:8:8 (GQA) |
head dimension (hd) | 128 |
dtype | torch.bfloat16 |
window size | 1024 (for sliding window masks only) |
Benchmark settings: for each mask pattern, we vary the sequence length $seqlen$ from $4k,8k,16k,…,$ up to $128k$ ($seqlen_q = seqlen_k = seqlen$) while measuring the throughput (in $\texttt{TFLOPs/s}$) for forward and backward passes of different attention kernels. Other configurations are fixed using common training settings (see the table above) to focus on the impact of sequence length and mask pattern. For the varlen packed data, we simply follow the variable sequence length distribution in the open-sourced dataset
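A hedged sketch of such a benchmark loop is shown below; `attention_fn` stands in for whichever kernel is under test, `mask_area` is the number of unmasked query-key pairs for the chosen pattern, and the FLOPs accounting assumes two GEMMs ($QK^\top$ and $PV$) per query head.

```python
# Hedged sketch of the kernel-level forward benchmark: time the kernel and
# convert to TFLOPs/s. `attention_fn` is a placeholder for the kernel under
# test (FFA, FA3, ...); shapes follow the settings table above.
import torch


def bench_forward(attention_fn, seqlen: int, mask_area: int,
                  nhq: int = 64, nhk: int = 8, hd: int = 128,
                  iters: int = 10) -> float:
    q = torch.randn(1, seqlen, nhq, hd, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(1, seqlen, nhk, hd, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(1, seqlen, nhk, hd, device="cuda", dtype=torch.bfloat16)

    for _ in range(3):                     # warm-up
        attention_fn(q, k, v)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        attention_fn(q, k, v)
    end.record()
    torch.cuda.synchronize()
    seconds = start.elapsed_time(end) / 1e3 / iters

    # Forward attention FLOPs: two GEMMs (QK^T and PV) per query head, i.e.
    # 2 * 2 * hd FLOPs per unmasked (q, k) pair.
    flops = 4 * mask_area * nhq * hd
    return flops / seconds / 1e12          # TFLOPs/s
```

For a full mask, `mask_area` equals $seqlen^2$, while a causal mask covers roughly half of that.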

Results are reported in the following figures.






Module-Level
To validate the scalability of MagiAttention, we assess the per-GPU throughput (in $\texttt{TFLOPs/s/GPU}$) of the attention module during both forward and backward propagation, as the sequence length and parallel size increase across various mask patterns. This assessment is compared against common CP strategies, including Ring-Attention.
The experiments are conducted on a large-scale production GPU cluster.
The tensor-parallel size $\textit{tp_size}$ is fixed at 8, with sequence-parallel enabled. Other data and model configurations for different mask types are the same as in the table in Kernel-Level Experiments.
Therefore, in every training setting, each rank is consistently assigned $seqlen=64k$, $\textit{num_heads_q} = 8$, and $\textit{num_heads_k} = 1$ for attention propagation, while the remaining activations stay at $seqlen=8k$, $\textit{num_heads_q} = 64$, and $\textit{num_heads_k} = 8$ with SP enabled. This setup simulates a common training configuration.
The results are presented in the following figures.




Discussion
coming soon …
Future Work
coming soon …
FAQ
coming soon …
Acknowledgement
We are grateful to the contributors listed below for their valuable contributions during the early stages of MagiAttention.
Member | Affiliations | Email | GitHub Account |
---|---|---|---|
Zewei Tao | SandAI | zeweitao@sand.ai | littsk |
Yunpeng Huang | SandAI, Nanjing University | yunpenghuang@sand.ai,hyp@smail.nju.edu.cn | Strivin0311 |
Qiangang Wang | Nanjing University | 522024330081@smail.nju.edu.cn | WT1W |
Hanwen Sun | SandAI, Peking University | sunhanwen@stu.pku.edu.cn | hanwen-sun |
Tao Bu | Nanjing University | 502024330002@smail.nju.edu.cn | Big-TRex |
WenYang Fang | Nanjing University | fwy@smail.nju.edu.cn | kagami4243 |
Siyuang Yan | Nanjing University | siyuanyan@smail.nju.edu.cn | FibonaccciYan |
Zixu Jiang | Nanjing University | 522023330040@smail.nju.edu.cn | 191220042 |
Dingkun Xu | Nanjing University | 211220090@smail.nju.edu.cn | PureDimension |
Mingyu Liang | Nanjing University | mingyuliang518@gmail.com | gaomusiki |
Jingwei Xu | Nanjing University | jingweix@nju.edu.cn | paragonlight |
Citation
If you use MagiAttention in your research, please cite:
@misc{magiattention2025,
title={MagiAttention: A Distributed Attention Towards Linear Scalability for Ultra-Long Context, Heterogeneous Mask Training},
author={Tao, Zewei and Huang, Yunpeng},
year={2025},
howpublished={\url{https://github.com/SandAI-org/MagiAttention/}},
}