Flash Attention 2 Math Derivation
This blog post gives a detailed mathematical derivation of the well-known Flash Attention 2 (FA2), a memory-efficient, highly optimized kernel implementation [Dao, 2023, Dao et al., 2022, Shah et al., 2024] that has become the de facto standard for the scaled dot-product attention operation introduced by the Transformer [Vaswani et al., 2023]. FA2 is re-implemented and further extended in the Flex-Flash-Attention kernels of MagiAttention [Zewei and Yunpeng, 2025].
Note
We omit specific softmax strategies, e.g. softmax_scale, softcap, attention_sink, for simplicity.
We omit any batch dimensions, e.g. batch_size, num_heads, but keep only the seqlen dimension and the head dimension for simplicity.
Forward
Standard Attention Forward
(20)\[\begin{split}\begin{cases}
\begin{aligned}
&S = \mathrm{mask}(QK^{\mathrm{T}} + bias) \in \mathbb{R}^{N\times N} \\
&P = \mathrm{softmax}_{row\text{-}wise}(S) = \mathrm{diag}(l)^{-1}A \in \mathbb{R}^{N\times N},\\
&\quad \text{where}\; l = \mathrm{rowsum}(A) \in \mathbb{R}^{N}, \space A = \exp{(S - \mathrm{rowmax}(S))} \in \mathbb{R}^{N\times N} \\
&O = PV \in \mathbb{R}^{N\times d}
\end{aligned}
\end{cases}\end{split}\]
\[given\quad Q,K,V \in \mathbb{R}^{N\times d}, \space bias \in \mathbb{R}^{N\times N}\]
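As a point of reference, the standard forward pass above can be sketched in NumPy. This is a minimal sketch, not a kernel: the function name `standard_attention_forward` and the convention that `mask` is a boolean array with `True` marking kept positions are assumptions made here for illustration, and the whole \(N\times N\) matrix is materialized on purpose.

```python
import numpy as np

def standard_attention_forward(Q, K, V, bias=None, mask=None):
    """Dense reference forward pass; returns O and P."""
    S = Q @ K.T
    if bias is not None:
        S = S + bias
    if mask is not None:
        # assumed convention: True = keep, False = masked out (set to -inf)
        S = np.where(mask, S, -np.inf)
    m = S.max(axis=1, keepdims=True)   # rowmax(S)
    A = np.exp(S - m)                  # un-normalized probabilities
    l = A.sum(axis=1, keepdims=True)   # rowsum(A)
    P = A / l                          # diag(l)^{-1} A
    O = P @ V
    return O, P

rng = np.random.default_rng(0)
N, d = 8, 4
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
causal = np.tril(np.ones((N, N), dtype=bool))  # example mask: causal
O, P = standard_attention_forward(Q, K, V, mask=causal)
```

Each row of `P` sums to one, and masked positions receive zero probability because \(\exp(-\infty) = 0\). A row that is masked out entirely would produce NaN in this sketch, so the example mask keeps the diagonal.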
Flash Attention Forward
Step1. Basic Row Decomposition
(21)\[\begin{split}\begin{cases}
\begin{aligned}
&S = \left[ S_1\quad S_2 \right] \in \mathbb{R}^{B_q\times 2B_k},\\
&\quad\text{where}\; S_i = \mathrm{mask}(QK_i^{\mathrm{T}} + \text{bias}_{i}) \in \mathbb{R}^{B_q\times B_k},\\
&\quad Q \in \mathbb{R}^{B_q\times d},\ K_i \in \mathbb{R}^{B_k\times d},\ i \in \{1,2\} \\
&m = \max\left( \mathrm{rowmax}(S_1), \mathrm{rowmax}(S_2) \right) \in \mathbb{R}^{B_q} \\
&A = \left[ A_1\quad A_2 \right] \in \mathbb{R}^{B_q\times 2B_k},\\
&\quad\text{where}\; A_i = \exp(S_i - m) \in \mathbb{R}^{B_q\times B_k},\ i \in \{1,2\} \\
&l = \mathrm{rowsum}(A_1) + \mathrm{rowsum}(A_2) \in \mathbb{R}^{B_q} \\
&P = \left[ P_1\quad P_2 \right] = \mathrm{diag}(l)^{-1} \left[ A_1\quad A_2 \right] \in \mathbb{R}^{B_q\times 2B_k} \\
&O = \left[ P_1\quad P_2 \right] \left[
\begin{matrix}
V_1 \\
V_2
\end{matrix}
\right] = \mathrm{diag}(l)^{-1} \left( A_1V_1 + A_2V_2 \right) \in \mathbb{R}^{B_q\times d}
\end{aligned}
\end{cases}\end{split}\]
Step2. Online Softmax Correction
(22)\[\begin{split}\text{base}:
\begin{cases}
\begin{aligned}
&m_1 = \mathrm{rowmax}(S_1) \in \mathbb{R}^{B_q}\notag\\
&A_1 = \exp(S_1 - m_1) \in \mathbb{R}^{B_q\times B_k}\notag\\
&l_1 = \mathrm{rowsum}(A_1)\in \mathbb{R}^{B_q}\notag\\
&P_1 = \mathrm{diag}(l_1)^{-1}A_1\in \mathbb{R}^{B_q\times B_k}\notag\\
&O_1 = P_1V_1\in \mathbb{R}^{B_q\times d}\notag
\end{aligned}\\
\end{cases}\end{split}\]
(23)\[\begin{split}\text{update}:
\begin{cases}
\begin{aligned}
&m_2 = \max(m_1, \mathrm{rowmax}(S_2)) \in \mathbb{R}^{B_q}\\
&A_2 = \exp(S_2 - m_2) \in \mathbb{R}^{B_q\times B_k}\notag\\
&l_2 = \delta_m l_1 + \mathrm{rowsum}(A_2)\in \mathbb{R}^{B_q}\\
&P_2 = \mathrm{diag}(l_2)^{-1}A_2\in \mathbb{R}^{B_q\times B_k}\notag\\
&O_2 = \mathrm{diag}(l_1/l_2)\,\mathrm{diag}(\delta_m)\, O_1 + P_2V_2 \in \mathbb{R}^{B_q\times d} \notag
\end{aligned}
\end{cases}\end{split}\]
\[\begin{aligned}
&\text{where}\; \delta_m := \exp(m_1 -m_2)
\end{aligned}\]
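The base and update steps can be checked numerically against a softmax over the concatenated row. The sketch below assumes random \(S_1, S_2, V_1, V_2\) and small block sizes chosen only for illustration. Since \(O_1\) is already normalized by \(l_1\), it is rescaled by \(\delta_m l_1 / l_2\) before the contribution of the second block is added.

```python
import numpy as np

rng = np.random.default_rng(0)
Bq, Bk, d = 4, 5, 3
S1, S2 = rng.standard_normal((Bq, Bk)), rng.standard_normal((Bq, Bk))
V1, V2 = rng.standard_normal((Bk, d)), rng.standard_normal((Bk, d))

# base step on the first block
m1 = S1.max(axis=1)
A1 = np.exp(S1 - m1[:, None])
l1 = A1.sum(axis=1)
O1 = (A1 / l1[:, None]) @ V1

# update step with the second block
m2 = np.maximum(m1, S2.max(axis=1))
delta_m = np.exp(m1 - m2)                  # correction factor, always <= 1
A2 = np.exp(S2 - m2[:, None])
l2 = delta_m * l1 + A2.sum(axis=1)
O2 = (delta_m * l1 / l2)[:, None] * O1 + (A2 / l2[:, None]) @ V2

# reference: one softmax over the concatenated row [S1 S2]
S = np.concatenate([S1, S2], axis=1)
A = np.exp(S - S.max(axis=1, keepdims=True))
O_ref = (A / A.sum(axis=1, keepdims=True)) @ np.concatenate([V1, V2], axis=0)
```

Under these assumptions `O2` agrees with `O_ref` up to floating-point error.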
Step3. Double-Loop Tiling
the outer loop runs through \(i := 1 \rightarrow N_q\) for each block of \(Q_i\) to compute \(O_i, \mathrm{LSE_i}\), where \(N_q = \lceil\frac{N}{B_q}\rceil\), and for each \(i\)-th outer iteration:
(24)\[\begin{split}\begin{cases}
\begin{aligned}
&\text{load}\space Q_i \in \mathbb{R}^{B_q\times d}\space \text{from HBM to SRAM}\notag\\
&\text{initialize}\space \tilde{O_{i}}^{(0)} = 0_{ B_q\times d },\space l_i^{(0)} = 0_{B_q} \in \mathbb{R}^{B_q},\space m_i^{(0)} = -\infty_{B_q} \in \mathbb{R}^{B_q} \notag\\
\\
&\text{loop over}\space j := 1 \rightarrow N_k\space \text{, running the}\space j \text{-th inner iteration given below} \notag\\
\\
&\text{after the loop, compute}\space O_i = \mathrm{diag}(l_{i}^{(N_k)})^{-1} \tilde{O_i}^{(N_k)}\in \mathbb{R}^{B_q\times d}\\
&\quad\text{and write it to HBM to return as output} \notag\\
&\text{compute}\space \mathrm{LSE_i} = m_i^{(N_k)} + \log(l_i^{(N_k)})\in \mathbb{R}^{B_q}\\
&\quad\text{and write it to HBM to save for backward} \notag
\end{aligned}
\end{cases}\end{split}\]
\[\begin{split}\begin{aligned}
&\text{where}\; \text{LSE}( \mathbf{x}) := \log\left(\sum\limits_{i=1}^n \exp(x_i)\right) = \max( \mathbf x) + \text{LSE}( \mathbf{x}-\max( \mathbf x)),\space \mathbf x \in \mathbb{R}^{n},\\
&\quad\text{and}\space \tilde{O_i} \space\text{is the un-normalized} \space O_i, \space\text{i.e.}\space O_i = \mathrm{diag}(l_{i})^{-1}\tilde{O_i}
\end{aligned}\end{split}\]
in which each inner loop goes across \(j := 1 \rightarrow N_k\) for each block of \(K_j,V_j\) to update \(\tilde{O_i}^{(j)}, l_i^{(j)}, m_i^{(j)}\), where \(N_k = \lceil\frac{N}{B_k}\rceil\), and for each \(j\)-th inner iteration:
(25)\[\begin{split}\begin{cases}
\begin{aligned}
&\text{load}\space K_j, V_j \in \mathbb{R}^{B_k\times d}\space \text{from HBM to SRAM} \notag\\
&\text{compute}\space S_{i}^{(j)} = \text{mask}(Q_iK_j^{\mathrm T} + bias_{(i,j)}) \in \mathbb{R}^{B_q\times B_k} \notag\\
&\text{update}\space m_i^{(j)} = \max\big(m_i^{(j-1)}, \mathrm{rowmax}(S_{i}^{(j)})\big) \in \mathbb{R}^{B_q} \notag\\
&\text{compute}\space A_i^{(j)} = \exp(S_i^{(j)} - m_i^{(j)}) \in \mathbb{R}^{B_q\times B_k} \notag\\
&\text{update}\space l_i^{(j)} = \delta_{m_i^{(j)}}l_i^{(j-1)} + \mathrm{rowsum}(A_i^{(j)})\in \mathbb{R}^{B_q} \notag\\
&\text{update}\space \tilde{O_i}^{(j)} = \mathrm{diag}(\delta_{m_i^{(j)}})\tilde{O_i}^{(j-1)} + A_i^{(j)}V_j\in \mathbb{R}^{B_q\times d} \notag
\end{aligned}
\end{cases}\end{split}\]
\[\begin{aligned}
&\text{where}\; \delta_{m_i^{(j)}} := \exp(m_i^{(j-1)} -m_i^{(j)})
\end{aligned}\]
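The double loop can be emulated on the CPU to check the recurrences. The following is a minimal NumPy sketch under a few assumptions: the name `flash_attention_forward` is hypothetical, mask and bias are left out, blocks are processed sequentially, and the HBM/SRAM traffic is not modeled.

```python
import numpy as np

def flash_attention_forward(Q, K, V, Bq=4, Bk=3):
    N, d = Q.shape
    O = np.empty((N, d))
    LSE = np.empty(N)
    for qs in range(0, N, Bq):                  # outer loop over Q blocks
        Qi = Q[qs:qs + Bq]
        O_acc = np.zeros((Qi.shape[0], d))      # un-normalized O_i
        l = np.zeros(Qi.shape[0])
        m = np.full(Qi.shape[0], -np.inf)
        for ks in range(0, N, Bk):              # inner loop over K/V blocks
            Kj, Vj = K[ks:ks + Bk], V[ks:ks + Bk]
            S = Qi @ Kj.T
            m_new = np.maximum(m, S.max(axis=1))
            delta = np.exp(m - m_new)           # equals 0 on the first iteration
            A = np.exp(S - m_new[:, None])
            l = delta * l + A.sum(axis=1)
            O_acc = delta[:, None] * O_acc + A @ Vj
            m = m_new
        O[qs:qs + Bq] = O_acc / l[:, None]      # normalize once, after the loop
        LSE[qs:qs + Bq] = m + np.log(l)         # saved for the backward pass
    return O, LSE

rng = np.random.default_rng(0)
N, d = 10, 4
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
O, LSE = flash_attention_forward(Q, K, V)

# dense reference
S = Q @ K.T
m_ref = S.max(axis=1)
A_ref = np.exp(S - m_ref[:, None])
O_ref = (A_ref / A_ref.sum(axis=1, keepdims=True)) @ V
LSE_ref = m_ref + np.log(A_ref.sum(axis=1))
```

Here \(N = 10\) is deliberately not a multiple of either block size, so the last block in each loop is ragged. Only `O` and `LSE` leave the outer iteration, while `m` and `l` stay local to it.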
Backward
Standard Attention Backward
(26)\[\begin{split}\begin{cases}
\begin{aligned}
&\mathrm{d}{V} = P^{\mathrm T} \mathrm{d}{O} \in \mathbb{R}^{N\times d}, \quad \mathrm{d}{P} = \mathrm{d}{O}V^{\mathrm T} \in \mathbb{R}^{N\times N} \notag \\
&\mathrm{d}{S_{i:}} = \cfrac{\partial P_{i:}}{\partial S_{i:}}\cdot\mathrm{d}{P_{i:}}\in \mathbb{R}^{N}, \\
&\quad where\space \cfrac{\partial P_{i:}}{\partial S_{i:}} = J_{softmax} = \mathrm{diag}(P_{i:}) - P_{i:}P_{i:}^{\mathrm T} \in \mathbb{R}^{N\times N} \notag \\
&\mathrm{d}{Q} = \mathrm{d}{S}K \in \mathbb{R}^{N\times d}, \quad \mathrm{d}{K} = \mathrm{d}{S}^{\mathrm T}Q \in \mathbb{R}^{N\times d} \notag
\end{aligned}
\end{cases}\end{split}\]
\[\begin{split}\begin{aligned}
&\text{where}\space\space \mathrm{d}X \space\space\text{denotes}\space \cfrac{\partial{\mathrm{loss}}}{\partial{X}}, \space\text{and}\space X_{i:} \space\text{denotes the column vector}\\
&\text{made of the $i$-th row of}\space X, \space\text{for any matrix}\space X
\end{aligned}\end{split}\]
\[given\quad \mathrm{d}{O} \in \mathbb{R}^{N\times d}\]
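The standard backward formulas can be sanity-checked with finite differences. The sketch below is a dense reference under stated assumptions: mask and bias are omitted, the helper names are hypothetical, and the scalar surrogate \(\mathrm{loss} = \sum (O \odot \mathrm{d}O)\) is chosen only because its gradient with respect to \(O\) is exactly \(\mathrm{d}O\).

```python
import numpy as np

def softmax_rows(S):
    A = np.exp(S - S.max(axis=1, keepdims=True))
    return A / A.sum(axis=1, keepdims=True)

def standard_attention_backward(Q, K, V, dO):
    P = softmax_rows(Q @ K.T)
    dV = P.T @ dO
    dP = dO @ V.T
    dS = np.empty_like(P)
    for i in range(P.shape[0]):
        # softmax Jacobian of row i: diag(P_i) - P_i P_i^T
        J = np.diag(P[i]) - np.outer(P[i], P[i])
        dS[i] = J @ dP[i]
    dQ = dS @ K
    dK = dS.T @ Q
    return dQ, dK, dV

def loss(Q, K, V, dO):
    # scalar surrogate whose gradient w.r.t. O is dO
    return float(((softmax_rows(Q @ K.T) @ V) * dO).sum())

rng = np.random.default_rng(0)
N, d = 5, 3
Q, K, V, dO = (rng.standard_normal((N, d)) for _ in range(4))
dQ, dK, dV = standard_attention_backward(Q, K, V, dO)

def finite_difference(name, idx, eps=1e-5):
    """Central difference of the surrogate loss w.r.t. one input entry."""
    args = {"Q": Q.copy(), "K": K.copy(), "V": V.copy()}
    args[name][idx] += eps
    up = loss(args["Q"], args["K"], args["V"], dO)
    args[name][idx] -= 2 * eps
    down = loss(args["Q"], args["K"], args["V"], dO)
    return (up - down) / (2 * eps)
```

Comparing, for example, `finite_difference("Q", (0, 0))` with `dQ[0, 0]` gives agreement to within the finite-difference error.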
Flash Attention Backward
Step0. Save LSE during forward
for each \(i\)-th row:
(27)\[\begin{split}\begin{cases}
\begin{aligned}
&\text{since}\space P_{i:} = \cfrac{A_{i:}}{l_{i}} \in \mathbb{R}^{B_k}, \; l_{i} = \mathrm{sum}(A_{i:}) \in \mathbb{R}, \\
&\quad\quad A_{i:} = \exp(S_{i:} - m_{i}) \in \mathbb{R}^{B_k}, \; m_{i} = \max(S_{i:})\in \mathbb{R} \notag\\
&\text{therefore}\space P_{i:} = \cfrac{\exp(S_{i:} - m_{i})}{\mathrm{sum}(\exp(S_{i:} - m_{i}))} = \cfrac{\exp(S_{i:} - m_{i})}{\exp(\mathrm{LSE}(S_{i:} - m_{i}))}\\
&\quad\quad\quad\quad = \exp(S_{i:} - (m_{i} + \mathrm{LSE}(S_{i:} - m_i))) \notag\\
\\
&\text{and according to}\space \text{LSE}( \mathbf{x}) = \max( \mathbf x) + \text{LSE}( \mathbf{x}-\max( \mathbf x)), \notag\\
&\text{therefore}\space P_{i:} = \exp(S_{i:} - (m_{i} + \mathrm{LSE}(S_{i:} - m_i)))\\
&\quad\quad\quad\quad = \exp(S_{i:} - \mathrm{LSE}(S_{i:})) = \exp(S_{i:} - \mathrm{LSE_i})\notag
\end{aligned}
\end{cases}\end{split}\]
so we can skip storing \(m_i, l_i\) (which would be needed to compute \(A_{i:}\)) and instead recompute \(P_{i:}\) directly from \(S_{i:}\) using only \(\mathrm{LSE_i}\)
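A quick numerical check of this identity, as a sketch on random scores (the scaling by 10 is an arbitrary choice to make the row maxima matter):

```python
import numpy as np

rng = np.random.default_rng(0)
S = rng.standard_normal((4, 6)) * 10.0

# route 1: keep both m and l
m = S.max(axis=1)
A = np.exp(S - m[:, None])
l = A.sum(axis=1)
P_from_ml = A / l[:, None]

# route 2: keep only LSE = m + log(l)
LSE = m + np.log(l)
P_from_lse = np.exp(S - LSE[:, None])
```

Both routes give the same probabilities, and the second needs one saved scalar per row instead of two.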
Step1. Compute Delta as a Pre-Processing
for each \(i\)-th row:
(28)\[\begin{split}\begin{cases}
\begin{aligned}
&\text{since}\space \mathrm{d}{S_{i:}} = \cfrac{\partial P_{i:}}{\partial S_{i:}}\cdot\mathrm{d}{P_{i:}} = (\mathrm{diag}(P_{i:}) - P_{i:}P_{i:}^{\mathrm T} )\cdot\mathrm{d}{P_{i:}}\\
&\quad\quad = P_{i:}\odot\mathrm{d}{P_{i:}} - (P_{i:}P_{i:}^{\mathrm T})\mathrm{d}{P_{i:}} \in \mathbb{R}^{B_k}\notag\\
&\text{then}\space \mathrm{d}{S_{i:}} = P_{i:}\odot\mathrm{d}{P_{i:}} - P_{i:}(P_{i:}^{\mathrm T}\mathrm{d}{P_{i:}}) = P_{i:}\odot\mathrm{d}{P_{i:}} - (P_{i:}^{\mathrm T}\mathrm{d}{P_{i:}})P_{i:}\notag\\
\\
&\text{define}\space \Delta_{i} = P_{i:}^{\mathrm T}\mathrm{d}{P_{i:}} \in \mathbb{R},\\
&\text{and because}\space \mathrm{d}{P_{i:}} = (\mathrm{d}{O_{i:}}^{\mathrm T}V^{\mathrm T})^{\mathrm T} = VdO_{i:} \in \mathbb{R}^{B_k}\notag\\
&\text{therefore}\space \Delta_{i} = P_{i:}^{\mathrm T}\mathrm{d}{P_{i:}} = P_{i:}^{\mathrm T}(VdO_{i:}) = (P_{i:}^{\mathrm T}V)dO_{i:} = O_{i:}^{\mathrm T}dO_{i:}\notag\\
\end{aligned}
\end{cases}\end{split}\]
then for all rows, we compute \(\Delta = \mathrm{rowsum}(O\odot dO)\in \mathbb{R}^{B_q}\) once during pre-processing, which avoids materializing large matrices such as \(P_{i:}P_{i:}^{\mathrm T} \in \mathbb{R}^{B_k\times B_k}\)
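The two claims of this step, \(\Delta_i = P_{i:}^{\mathrm T}\mathrm{d}P_{i:} = O_{i:}^{\mathrm T}\mathrm{d}O_{i:}\) and \(\mathrm{d}S_{i:} = P_{i:}\odot(\mathrm{d}P_{i:} - \Delta_i)\), can be verified on random inputs. This is a dense sketch with arbitrary small sizes, comparing against the explicit Jacobian product.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 6, 4
Q, K, V, dO = (rng.standard_normal((N, d)) for _ in range(4))

S = Q @ K.T
A = np.exp(S - S.max(axis=1, keepdims=True))
P = A / A.sum(axis=1, keepdims=True)
O = P @ V
dP = dO @ V.T

# Delta from O and dO only (N x d work), versus from P and dP (N x N work)
Delta = (O * dO).sum(axis=1)
Delta_ref = (P * dP).sum(axis=1)

# dS via the element-wise form, versus the explicit softmax Jacobian per row
dS_fast = P * (dP - Delta[:, None])
dS_ref = np.stack([
    (np.diag(P[i]) - np.outer(P[i], P[i])) @ dP[i] for i in range(N)
])
```

The element-wise form never builds a per-row Jacobian, which is what makes it usable block by block.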
Step2. Swapped Double-Loop Tiling with Recomputation
the outer loop runs through \(j := 1 \rightarrow N_k\) for each block of \(K_j, V_j\) to compute \(dK_j, dV_j\), where \(N_k = \lceil\frac{N}{B_k}\rceil\), and for each \(j\)-th outer iteration:
(29)\[\begin{split}\begin{cases}
\begin{aligned}
&\text{load}\space K_j, V_j \in \mathbb{R}^{B_k\times d}\space \text{from HBM to SRAM, }\\
&\text{and initialize}\space dK_j^{(0)}, dV_j^{(0)} = (0)_{B_k\times d} \in \mathbb{R}^{B_k\times d} \notag \\
\\
&\text{loop over}\space i := 1 \rightarrow N_q\space \text{, running the}\space i \text{-th inner iteration given below} \notag \\
\\
&\text{after the loop, write}\space dK_j = dK_j^{(N_q)}, dV_j = dV_j^{(N_q)} \space \text{back to HBM to return as output} \notag
\end{aligned}
\end{cases}\end{split}\]
in which each inner loop goes across \(i := 1 \rightarrow N_q\) for each block of \(Q_i, dO_i\) to update \(dQ_i, dK_j^{(i)}, dV_j^{(i)}\), where \(N_q = \lceil\frac{N}{B_q}\rceil\), and for each \(i\)-th inner iteration:
(30)\[\begin{split}\begin{cases}
\begin{aligned}
&\text{load}\space Q_i, dO_i, \mathrm{LSE_i}, \Delta_i\space \text{from HBM to SRAM} \notag \\
&\text{recompute}\space S_j^{(i)} = \mathrm{mask}(Q_iK_j^{\mathrm{T}} + bias_{(i,j)}) \in \mathbb{R}^{B_q\times B_k} \notag \\
&\text{recompute}\space P_j^{(i)} = \exp(S_j^{(i)} - \mathrm{LSE_i}) \in \mathbb{R}^{B_q\times B_k} \notag \\
&\text{update}\space dV_j^{(i)} = dV_j^{(i-1)} + (P_j^{(i)})^{\mathrm T} dO_i \in \mathbb{R}^{B_k\times d} \notag \\
&\text{compute}\space dP_j^{(i)} = dO_iV_j^{\mathrm T} \in \mathbb{R}^{B_q\times B_k} \notag \\
&\text{compute}\space dS_j^{(i)} = P_j^{(i)}\odot (dP_j^{(i)} - \Delta_i) \in \mathbb{R}^{B_q\times B_k} \notag \\
&\text{update}\space dK_j^{(i)} = dK_j^{(i-1)} + (dS_j^{(i)})^{\mathrm T} Q_i \in \mathbb{R}^{B_k\times d} \notag \\
&\text{update}\space dQ_i \stackrel{atomic\space add}\longleftarrow dS_j^{(i)}K_j \in \mathbb{R}^{B_q\times d} \notag
\end{aligned}
\end{cases}\end{split}\]
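The swapped double loop can likewise be emulated in NumPy and compared with a dense backward pass. This is a minimal sketch under stated assumptions: the name `flash_attention_backward` is hypothetical, mask and bias are omitted, and the atomic add into \(dQ_i\) is replaced by a plain in-place accumulation because the blocks run sequentially here.

```python
import numpy as np

def flash_attention_backward(Q, K, V, O, dO, LSE, Bq=4, Bk=3):
    N, d = Q.shape
    dQ, dK, dV = np.zeros_like(Q), np.zeros_like(K), np.zeros_like(V)
    Delta = (O * dO).sum(axis=1)                 # pre-processing step
    for ks in range(0, N, Bk):                   # outer loop over K/V blocks
        Kj, Vj = K[ks:ks + Bk], V[ks:ks + Bk]
        dKj, dVj = np.zeros_like(Kj), np.zeros_like(Vj)
        for qs in range(0, N, Bq):               # inner loop over Q blocks
            Qi, dOi = Q[qs:qs + Bq], dO[qs:qs + Bq]
            S = Qi @ Kj.T                                # recompute S
            P = np.exp(S - LSE[qs:qs + Bq, None])        # recompute P from LSE
            dVj += P.T @ dOi
            dP = dOi @ Vj.T
            dS = P * (dP - Delta[qs:qs + Bq, None])
            dKj += dS.T @ Qi
            dQ[qs:qs + Bq] += dS @ Kj            # stands in for the atomic add
        dK[ks:ks + Bk], dV[ks:ks + Bk] = dKj, dVj
    return dQ, dK, dV

rng = np.random.default_rng(0)
N, d = 10, 4
Q, K, V, dO = (rng.standard_normal((N, d)) for _ in range(4))

# dense forward, to obtain O and LSE
S = Q @ K.T
m = S.max(axis=1)
l = np.exp(S - m[:, None]).sum(axis=1)
LSE = m + np.log(l)
P = np.exp(S - LSE[:, None])
O = P @ V

dQ, dK, dV = flash_attention_backward(Q, K, V, O, dO, LSE)

# dense reference backward
dP = dO @ V.T
dS = P * (dP - (P * dP).sum(axis=1, keepdims=True))
dQ_ref, dK_ref, dV_ref = dS @ K, dS.T @ Q, P.T @ dO
```

In this ordering \(dK_j\) and \(dV_j\) are accumulated locally and written once per outer iteration, whereas every outer iteration contributes to every \(dQ_i\). That is why the \(dQ_i\) update needs an atomic add when the outer iterations run in parallel.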
Citation
If you find MagiAttention useful in your research, please cite:
@misc{magiattention2025,
title={MagiAttention: A Distributed Attention Towards Linear Scalability for Ultra-Long Context, Heterogeneous Mask Training},
author={Zewei, Tao and Yunpeng, Huang},
year={2025},
howpublished={\url{https://github.com/SandAI-org/MagiAttention/}},
}
References
[1]
Tri Dao. FlashAttention-2: faster attention with better parallelism and work partitioning. arXiv preprint arXiv:2307.08691, 2023.
[2]
Tri Dao, Dan Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. FlashAttention: fast and memory-efficient exact attention with IO-awareness. Advances in Neural Information Processing Systems, 35:16344–16359, 2022.
[5]
Tao Zewei and Huang Yunpeng. MagiAttention: a distributed attention towards linear scalability for ultra-long context, heterogeneous mask training. SandAI-org/MagiAttention, 2025.