Flash Attention 2 Math Derivation#

This blog post is a detailed math derivation of the well-known Flash Attention 2 (FA2), a memory-efficient, highly optimized, and de facto standard kernel implementation [Dao, 2023, Dao et al., 2022, Shah et al., 2024] of the scaled dot-product attention operation introduced by the Transformer [Vaswani et al., 2023], and it is re-implemented and further extended in the Flex-Flash-Attention kernels of MagiAttention [Zewei and Yunpeng, 2025].

Note

  1. We omit specific softmax strategies, e.g. softmax_scale, softcap, attention_sink, for simplicity.

  2. We omit any batch-like dimensions, e.g. batch_size, num_heads, keeping only the seqlen dimension and the head dimension for simplicity.

Forward#

Standard Attention Forward#

(20)#\[\begin{split}\begin{cases} \begin{aligned} &S = \mathrm{mask}(QK^{\mathrm{T}} + bias) \in \mathbb{R}^{N\times N} \\ &P = \mathrm{softmax}_{row\text{-}wise}(S) = \mathrm{diag}(l)^{-1}A \in \mathbb{R}^{N\times N},\\ &\quad \text{where}\; l = \mathrm{rowsum}(A) \in \mathbb{R}^{N}, \space A = \exp{(S - \mathrm{rowmax}(S))} \in \mathbb{R}^{N\times N} \\ &O = PV \in \mathbb{R}^{N\times d} \end{aligned} \end{cases}\end{split}\]
\[given\quad Q,K,V \in \mathbb{R}^{N\times d}, \space bias \in \mathbb{R}^{N\times N}\]
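For reference, here is a minimal NumPy sketch of this standard forward pass (the function name, the optional bias/mask arguments, and the toy shapes are illustrative assumptions, not part of any particular kernel):

import numpy as np

def standard_attention_forward(Q, K, V, bias=None, mask=None):
    # S = mask(Q K^T + bias)
    S = Q @ K.T
    if bias is not None:
        S = S + bias
    if mask is not None:
        S = np.where(mask, S, -np.inf)  # masked-out positions get -inf
    # numerically stable row-wise softmax: A = exp(S - rowmax(S)), l = rowsum(A)
    m = S.max(axis=-1, keepdims=True)
    A = np.exp(S - m)
    l = A.sum(axis=-1, keepdims=True)
    P = A / l                           # P = diag(l)^{-1} A
    O = P @ V                           # O = P V
    return O

# toy usage
rng = np.random.default_rng(0)
N, d = 8, 4
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
O = standard_attention_forward(Q, K, V)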

Flash Attention Forward#

Step1. Basic Row Decomposition#

(21)#\[\begin{split}\begin{cases} \begin{aligned} &S = \left[ S_1\quad S_2 \right] \in \mathbb{R}^{B_q\times 2B_k},\\ &\quad\text{where}\; S_i = \mathrm{mask}(QK_i^{\mathrm{T}} + \text{bias}_{i}) \in \mathbb{R}^{B_q\times B_k},\\ &\quad Q \in \mathbb{R}^{B_q\times d},\ K_i \in \mathbb{R}^{B_k\times d},\ i \in \{1,2\} \\ &m = \max\left( \mathrm{rowmax}(S_1), \mathrm{rowmax}(S_2) \right) \in \mathbb{R}^{B_q} \\ &A = \left[ A_1\quad A_2 \right] \in \mathbb{R}^{B_q\times 2B_k},\\ &\quad\text{where}\; A_i = \exp(S_i - m) \in \mathbb{R}^{B_q\times B_k},\ i \in \{1,2\} \\ &l = \mathrm{rowsum}(A_1) + \mathrm{rowsum}(A_2) \in \mathbb{R}^{B_q} \\ &P = \left[ P_1\quad P_2 \right] = \mathrm{diag}(l)^{-1} \left[ A_1\quad A_2 \right] \in \mathbb{R}^{B_q\times 2B_k} \\ &O = \left[ P_1\quad P_2 \right] \left[ \begin{matrix} V_1 \\ V_2 \end{matrix} \right] = \mathrm{diag}(l)^{-1} \left( A_1V_1 + A_2V_2 \right) \in \mathbb{R}^{B_q\times d} \end{aligned} \end{cases}\end{split}\]
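A quick numerical check of this two-block decomposition against the unsplit computation (the toy block sizes and random inputs below are assumptions for illustration):

import numpy as np

rng = np.random.default_rng(0)
Bq, Bk, d = 4, 3, 5
Q = rng.standard_normal((Bq, d))
K1, K2 = rng.standard_normal((Bk, d)), rng.standard_normal((Bk, d))
V1, V2 = rng.standard_normal((Bk, d)), rng.standard_normal((Bk, d))

# block-wise path: S_i = Q K_i^T, shared row-max m, A_i = exp(S_i - m)
S1, S2 = Q @ K1.T, Q @ K2.T
m = np.maximum(S1.max(axis=-1), S2.max(axis=-1))   # rowmax over both blocks, (Bq,)
A1, A2 = np.exp(S1 - m[:, None]), np.exp(S2 - m[:, None])
l = A1.sum(axis=-1) + A2.sum(axis=-1)              # l = rowsum(A1) + rowsum(A2)
O_blocked = (A1 @ V1 + A2 @ V2) / l[:, None]       # O = diag(l)^{-1}(A1 V1 + A2 V2)

# unsplit path on the concatenated K and V
S = Q @ np.concatenate([K1, K2]).T
P = np.exp(S - S.max(axis=-1, keepdims=True))
P /= P.sum(axis=-1, keepdims=True)
O_full = P @ np.concatenate([V1, V2])

assert np.allclose(O_blocked, O_full)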

Step2. Online Softmax Correction#

(22)#\[\begin{split}\text{base}: \begin{cases} \begin{aligned} &m_1 = \mathrm{rowmax}(S_1) \in \mathbb{R}^{B_q}\notag\\ &A_1 = \exp(S_1 - m_1) \in \mathbb{R}^{B_q\times B_k}\notag\\ &l_1 = \mathrm{rowsum}(A_1)\in \mathbb{R}^{B_q}\notag\\ &P_1 = \mathrm{diag}(l_1)^{-1}A_1\in \mathbb{R}^{B_q\times B_k}\notag\\ &O_1 = P_1V_1\in \mathbb{R}^{B_q\times d}\notag \end{aligned}\\ \end{cases}\end{split}\]
(23)#\[\begin{split}\text{update}: \begin{cases} \begin{aligned} &m_2 = \max(m_1, \mathrm{rowmax}(S_2)) \in \mathbb{R}^{B_q}\\ &A_2 = \exp(S_2 - m_2) \in \mathbb{R}^{B_q\times B_k}\notag\\ &l_2 = \delta_m l_1 + \mathrm{rowsum}(A_2)\in \mathbb{R}^{B_q}\\ &P_2 = \mathrm{diag}(l_2)^{-1}A_2\in \mathbb{R}^{B_q\times B_k}\notag\\ &O_2 = \mathrm{diag}(\delta_m l_1/l_2)O_1 + P_2V_2 \in \mathbb{R}^{B_q\times d} \notag \end{aligned} \end{cases}\end{split}\]
\[\begin{aligned} &\text{where}\; \delta_m := \exp(m_1 -m_2) \end{aligned}\]
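The update rule can be verified numerically: folding in the second block and rescaling the running statistics with \(\delta_m\) reproduces the softmax over both blocks (the toy shapes below are illustrative assumptions):

import numpy as np

rng = np.random.default_rng(1)
Bq, Bk, d = 4, 3, 5
Q = rng.standard_normal((Bq, d))
K1, K2, V1, V2 = (rng.standard_normal((Bk, d)) for _ in range(4))
S1, S2 = Q @ K1.T, Q @ K2.T

# base: statistics from the first block only
m1 = S1.max(axis=-1)                        # m1 = rowmax(S1)
A1 = np.exp(S1 - m1[:, None])
l1 = A1.sum(axis=-1)
O1 = A1 @ V1 / l1[:, None]                  # O1 = diag(l1)^{-1} A1 V1

# update: fold in the second block and rescale with delta_m = exp(m1 - m2)
m2 = np.maximum(m1, S2.max(axis=-1))
delta_m = np.exp(m1 - m2)
A2 = np.exp(S2 - m2[:, None])
l2 = delta_m * l1 + A2.sum(axis=-1)
P2 = A2 / l2[:, None]
O2 = (delta_m * l1 / l2)[:, None] * O1 + P2 @ V2   # O2 = diag(delta_m l1/l2) O1 + P2 V2

# reference: softmax over the concatenated keys
S = np.concatenate([S1, S2], axis=-1)
P = np.exp(S - S.max(axis=-1, keepdims=True))
P /= P.sum(axis=-1, keepdims=True)
assert np.allclose(O2, P @ np.concatenate([V1, V2]))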

Step3. Double-Loop Tiling#

  • the outer loop runs over \(i := 1 \rightarrow N_q\), one block \(Q_i\) at a time, to compute \(O_i\), where \(N_q = \lceil\frac{N}{B_q}\rceil\), and for each \(i\)-th outer iteration:

(24)#\[\begin{split}\begin{cases} \begin{aligned} &\text{load}\space Q_i \in \mathbb{R}^{B_q\times d}\space \text{from HBM to SRAM}\notag\\ &\text{initialize}\space \tilde{O_{i}}^{(0)} = 0_{ B_q\times d },\space l_i^{(0)} = 0_{B_q} \in \mathbb{R}^{B_q},\space m_i^{(0)} = -\infty_{B_q} \in \mathbb{R}^{B_q} \notag\\ \\ &\text{loop over}\space j := 1 \rightarrow N_k\space \text{following the inner-loop steps below} \notag\\ \\ &\text{after the inner loop ends:} \notag\\ &\quad\text{compute}\space O_i = \mathrm{diag}(l_{i}^{(N_k)})^{-1} \tilde{O_i}^{(N_k)}\in \mathbb{R}^{B_q\times d}\\ &\quad\quad\text{and write it to HBM to return as output} \notag\\ &\quad\text{compute}\space \mathrm{LSE_i} = m_i^{(N_k)} + \log(l_i^{(N_k)})\in \mathbb{R}^{B_q}\\ &\quad\quad\text{and write it to HBM to save for backward} \notag \end{aligned} \end{cases}\end{split}\]
\[\begin{split}\begin{aligned} &\text{where}\; \text{LSE}( \mathbf{x}) := \log\left(\sum\limits_{i=1}^n \exp(x_i)\right) = \max( \mathbf x) + \text{LSE}( \mathbf{x}-\max( \mathbf x)),\space \mathbf x \in \mathbb{R}^{n},\\ &\quad\text{and}\space \tilde{O_i} \space\text{is the un-normalized} \space O_i, \space\text{i.e.}\space O_i = \mathrm{diag}(l_{i})^{-1}\tilde{O_i} \end{aligned}\end{split}\]
  • in which the inner loop runs over \(j := 1 \rightarrow N_k\), one block of \(K_j,V_j\) at a time, to update \(\tilde{O_i}^{(j)}, l_i^{(j)}, m_i^{(j)}\), where \(N_k = \lceil\frac{N}{B_k}\rceil\), and for each \(j\)-th inner iteration (the full double loop is sketched in code after these equations):

(25)#\[\begin{split}\begin{cases} \begin{aligned} &\text{load}\space K_j, V_j \in \mathbb{R}^{B_k\times d}\space \text{from HBM to SRAM} \notag\\ &\text{compute}\space S_{i}^{(j)} = \text{mask}(Q_iK_j^{\mathrm T} + bias_{(i,j)}) \in \mathbb{R}^{B_q\times B_k} \notag\\ &\text{update}\space m_i^{(j)} = \max\big(m_i^{(j-1)}, \mathrm{rowmax}(S_{i}^{(j)})\big) \in \mathbb{R}^{B_q} \notag\\ &\text{compute}\space A_i^{(j)} = \exp(S_i^{(j)} - m_i^{(j)}) \in \mathbb{R}^{B_q\times B_k} \notag\\ &\text{update}\space l_i^{(j)} = \delta_{m_i^{(j)}}l_i^{(j-1)} + \mathrm{rowsum}(A_i^{(j)})\in \mathbb{R}^{B_q} \notag\\ &\text{update}\space \tilde{O_i}^{(j)} = \mathrm{diag}(\delta_{m_i^{(j)}})\tilde{O_i}^{(j-1)} + A_i^{(j)}V_j\in \mathbb{R}^{B_q\times d} \notag \end{aligned} \end{cases}\end{split}\]
\[\begin{aligned} &\text{where}\; \delta_{m_i^{(j)}} := \exp(m_i^{(j-1)} -m_i^{(j)}) \end{aligned}\]
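Putting both loops together, here is a minimal NumPy sketch of the tiled forward pass (the function name and block sizes are illustrative assumptions, and the HBM-to-SRAM data movement of the real kernel is not modeled):

import numpy as np

def flash_attention_forward(Q, K, V, Bq=4, Bk=3):
    # double-loop tiling of the forward pass in plain NumPy
    N, d = Q.shape
    O = np.empty_like(Q)
    LSE = np.empty(N)
    for i0 in range(0, N, Bq):                      # outer loop over Q blocks
        Qi = Q[i0:i0 + Bq]
        O_acc = np.zeros((Qi.shape[0], d))          # un-normalized \tilde{O}_i
        l = np.zeros(Qi.shape[0])                   # running row-sum
        m = np.full(Qi.shape[0], -np.inf)           # running row-max
        for j0 in range(0, N, Bk):                  # inner loop over K/V blocks
            S = Qi @ K[j0:j0 + Bk].T
            m_new = np.maximum(m, S.max(axis=-1))
            delta = np.exp(m - m_new)               # correction factor
            A = np.exp(S - m_new[:, None])
            l = delta * l + A.sum(axis=-1)
            O_acc = delta[:, None] * O_acc + A @ V[j0:j0 + Bk]
            m = m_new
        O[i0:i0 + Bq] = O_acc / l[:, None]          # O_i = diag(l)^{-1} \tilde{O}_i
        LSE[i0:i0 + Bq] = m + np.log(l)             # saved for the backward pass
    return O, LSE

# check against the plain softmax(Q K^T) V
rng = np.random.default_rng(2)
N, d = 10, 6
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
S = Q @ K.T
P = np.exp(S - S.max(axis=-1, keepdims=True))
P /= P.sum(axis=-1, keepdims=True)
O_ref = P @ V
O, LSE = flash_attention_forward(Q, K, V)
assert np.allclose(O, O_ref)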

Backward#

Standard Attention Backward#

(26)#\[\begin{split}\begin{cases} \begin{aligned} &\mathrm{d}{V} = P^{\mathrm T} \mathrm{d}{O} \in \mathbb{R}^{N\times d}, \quad \mathrm{d}{P} = \mathrm{d}{O}V^{\mathrm T} \in \mathbb{R}^{N\times N} \notag \\ &\mathrm{d}{S_{i:}} = \cfrac{\partial P_{i:}}{\partial S_{i:}}\cdot\mathrm{d}{P_{i:}}\in \mathbb{R}^{N}, \\ &\quad \text{where}\space \cfrac{\partial P_{i:}}{\partial S_{i:}} = J_{\mathrm{softmax}} = \mathrm{diag}(P_{i:}) - P_{i:}P_{i:}^{\mathrm T} \in \mathbb{R}^{N\times N} \notag \\ &\mathrm{d}{Q} = \mathrm{d}{S}K \in \mathbb{R}^{N\times d}, \quad \mathrm{d}{K} = \mathrm{d}{S}^{\mathrm T}Q \in \mathbb{R}^{N\times d} \notag \end{aligned} \end{cases}\end{split}\]
\[\begin{split}\begin{aligned} &\text{where}\space\space \mathrm{d}X \space\space\text{denotes}\space \cfrac{\partial{\text{loss}}}{\partial{X}}, \space\text{and}\space X_{i:} \space\text{denotes the column vector}\\ &\text{made of the $i$-th row of}\space X, \space\text{for any matrix}\space X \end{aligned}\end{split}\]
\[given\quad \mathrm{d}{O} \in \mathbb{R}^{N\times d}\]
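A direct NumPy sketch of these gradients, recomputing \(P\) and applying the row-wise softmax Jacobian explicitly (the function name and toy shapes are illustrative assumptions; the tiled backward below avoids ever forming this Jacobian):

import numpy as np

def standard_attention_backward(Q, K, V, dO):
    # recompute the forward to get P = row-wise softmax(Q K^T)
    S = Q @ K.T
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P /= P.sum(axis=-1, keepdims=True)
    dV = P.T @ dO                           # dV = P^T dO
    dP = dO @ V.T                           # dP = dO V^T
    # row-wise softmax Jacobian: dS_{i:} = (diag(P_{i:}) - P_{i:} P_{i:}^T) dP_{i:}
    dS = np.empty_like(S)
    for i in range(S.shape[0]):
        Pi = P[i]
        dS[i] = (np.diag(Pi) - np.outer(Pi, Pi)) @ dP[i]
    dQ = dS @ K                             # dQ = dS K
    dK = dS.T @ Q                           # dK = dS^T Q
    return dQ, dK, dV

# toy usage
rng = np.random.default_rng(3)
N, d = 8, 4
Q, K, V, dO = (rng.standard_normal((N, d)) for _ in range(4))
dQ, dK, dV = standard_attention_backward(Q, K, V, dO)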

Flash Attention Backward#

Step0. Save LSE during forward#

for each row \(i\):

(27)#\[\begin{split}\begin{cases} \begin{aligned} &\text{since}\space P_{i:} = \cfrac{A_{i:}}{l_{i}} \in \mathbb{R}^{B_k}, \; l_{i} = \mathrm{sum}(A_{i:}) \in \mathbb{R}, \\ &\quad\quad A_{i:} = \exp(S_{i:} - m_{i}) \in \mathbb{R}^{B_k}, \; m_{i} = \max(S_{i:})\in \mathbb{R} \notag\\ &\text{therefore}\space P_{i:} = \cfrac{\exp(S_{i:} - m_{i})}{\mathrm{sum}(\exp(S_{i:} - m_{i}))} = \cfrac{\exp(S_{i:} - m_{i})}{\exp(\mathrm{LSE}(S_{i:} - m_{i}))}\\ &\quad\quad\quad\quad = \exp(S_{i:} - (m_{i} + \mathrm{LSE}(S_{i:} - m_i))) \notag\\ \\ &\text{and according to}\space \text{LSE}( \mathbf{x}) = \max( \mathbf x) + \text{LSE}( \mathbf{x}-\max( \mathbf x)), \notag\\ &\text{therefore}\space P_{i:} = \exp(S_{i:} - (m_{i} + \mathrm{LSE}(S_{i:} - m_i)))\\ &\quad\quad\quad\quad = \exp(S_{i:} - \mathrm{LSE}(S_{i:})) = \exp(S_{i:} - \mathrm{LSE_i})\notag \end{aligned} \end{cases}\end{split}\]

so we can skip storing \(m_i, l_i\) for recomputing \(A_{i:}\) in the backward pass, and instead recompute \(P_{i:}\) directly from \(S_{i:}\) with only \(\mathrm{LSE_i}\)
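A one-row numerical check of this identity (the random toy data is an illustrative assumption):

import numpy as np

rng = np.random.default_rng(4)
N = 6
S_row = rng.standard_normal(N)              # one row S_{i:}

# two-pass softmax keeping both m_i and l_i
m_i = S_row.max()
A_row = np.exp(S_row - m_i)
l_i = A_row.sum()
P_row = A_row / l_i

# single statistic: LSE_i = m_i + log(l_i)
LSE_i = m_i + np.log(l_i)
assert np.allclose(P_row, np.exp(S_row - LSE_i))   # P_{i:} = exp(S_{i:} - LSE_i)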

Step1. Compute Delta as a Pre-Processing Step#

for each row \(i\):

(28)#\[\begin{split}\begin{cases} \begin{aligned} &\text{since}\space \mathrm{d}{S_{i:}} = \cfrac{\partial P_{i:}}{\partial S_{i:}}\cdot\mathrm{d}{P_{i:}} = (\mathrm{diag}(P_{i:}) - P_{i:}P_{i:}^{\mathrm T} )\cdot\mathrm{d}{P_{i:}}\\ &\quad\quad = P_{i:}\odot\mathrm{d}{P_{i:}} - (P_{i:}P_{i:}^{\mathrm T})\mathrm{d}{P_{i:}} \in \mathbb{R}^{B_k}\notag\\ &\text{then}\space \mathrm{d}{S_{i:}} = P_{i:}\odot\mathrm{d}{P_{i:}} - P_{i:}(P_{i:}^{\mathrm T}\mathrm{d}{P_{i:}}) = P_{i:}\odot\mathrm{d}{P_{i:}} - (P_{i:}^{\mathrm T}\mathrm{d}{P_{i:}})P_{i:}\notag\\ \\ &\text{define}\space \Delta_{i} = P_{i:}^{\mathrm T}\mathrm{d}{P_{i:}} \in \mathbb{R},\\ &\text{and because}\space \mathrm{d}{P_{i:}} = (\mathrm{d}{O_{i:}}^{\mathrm T}V^{\mathrm T})^{\mathrm T} = VdO_{i:} \in \mathbb{R}^{B_k}\notag\\ &\text{therefore}\space \Delta_{i} = P_{i:}^{\mathrm T}\mathrm{d}{P_{i:}} = P_{i:}^{\mathrm T}(VdO_{i:}) = (P_{i:}^{\mathrm T}V)dO_{i:} = O_{i:}^{\mathrm T}dO_{i:}\notag\\ \end{aligned} \end{cases}\end{split}\]

then, for all rows, we compute \(\Delta = \mathrm{rowsum}(O\odot dO)\in \mathbb{R}^{N}\) once during pre-processing (loaded per query tile as \(\Delta_i \in \mathbb{R}^{B_q}\)), so the backward pass never needs to materialize the per-row Jacobian term \(P_{i:}P_{i:}^{\mathrm T} \in \mathbb{R}^{B_k\times B_k}\)
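A small NumPy check that this pre-processing matches the per-row definition of \(\Delta\) and yields the simplified \(\mathrm{d}S\) (toy shapes are illustrative assumptions):

import numpy as np

rng = np.random.default_rng(5)
N, d = 8, 4
Q, K, V, dO = (rng.standard_normal((N, d)) for _ in range(4))
S = Q @ K.T
P = np.exp(S - S.max(axis=-1, keepdims=True))
P /= P.sum(axis=-1, keepdims=True)
O = P @ V
dP = dO @ V.T

# cheap pre-processing: Delta = rowsum(O * dO), one scalar per row
Delta = (O * dO).sum(axis=-1)
# it agrees with the per-row definition Delta_i = P_{i:}^T dP_{i:}
assert np.allclose(Delta, (P * dP).sum(axis=-1))
# and gives dS = P * (dP - Delta) without building any B_k x B_k Jacobian
dS = P * (dP - Delta[:, None])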

Step2. Swapped Double-Loop Tiling with Recomputation#

  • the outer loop runs over \(j := 1 \rightarrow N_k\), one block of \(K_j, V_j\) at a time, to compute \(dK_j, dV_j\), where \(N_k = \lceil\frac{N}{B_k}\rceil\), and for each \(j\)-th outer iteration:

(29)#\[\begin{split}\begin{cases} \begin{aligned} &\text{load}\space K_j, V_j \in \mathbb{R}^{B_k\times d}\space \text{from HBM to SRAM, }\\ &\text{and initialize}\space dK_j^{(0)}, dV_j^{(0)} = (0)_{B_k\times d} \in \mathbb{R}^{B_k\times d} \notag \\ \\ &\text{loop over}\space i := 1 \rightarrow N_q\space \text{following the inner-loop steps below} \notag \\ \\ &\text{after the inner loop ends:} \notag\\ &\quad\text{write}\space dK_j = dK_j^{(N_q)}, dV_j = dV_j^{(N_q)} \space \text{back to HBM to return as output} \notag \end{aligned} \end{cases}\end{split}\]
  • in which the inner loop runs over \(i := 1 \rightarrow N_q\), one block of \(Q_i, dO_i\) at a time, to update \(dQ_i, dK_j^{(i)}, dV_j^{(i)}\), where \(N_q = \lceil\frac{N}{B_q}\rceil\), and for each \(i\)-th inner iteration (the full backward double loop is sketched in code after these equations):

(30)#\[\begin{split}\begin{cases} \begin{aligned} &\text{load}\space Q_i, dO_i, \mathrm{LSE_i}, \Delta_i\space \text{from HBM to SRAM} \notag \\ &\text{recompute}\space S_j^{(i)} = \mathrm{mask}(Q_iK_j^{\mathrm{T}} + bias_{(i,j)}) \in \mathbb{R}^{B_q\times B_k} \notag \\ &\text{recompute}\space P_j^{(i)} = \exp(S_j^{(i)} - \mathrm{LSE_i}) \in \mathbb{R}^{B_q\times B_k} \notag \\ &\text{update}\space dV_j^{(i)} = dV_j^{(i-1)} + (P_j^{(i)})^{\mathrm T} dO_i \in \mathbb{R}^{B_k\times d} \notag \\ &\text{compute}\space dP_j^{(i)} = dO_iV_j^{\mathrm T} \in \mathbb{R}^{B_q\times B_k} \notag \\ &\text{compute}\space dS_j^{(i)} = P_j^{(i)}\odot (dP_j^{(i)} - \Delta_i) \in \mathbb{R}^{B_q\times B_k} \notag \\ &\text{update}\space dK_j^{(i)} = dK_j^{(i-1)} + (dS_j^{(i)})^{\mathrm T} Q_i \in \mathbb{R}^{B_k\times d} \notag \\ &\text{update}\space dQ_i \stackrel{atomic\space add}\longleftarrow dS_j^{(i)}K_j \in \mathbb{R}^{B_q\times d} \notag \end{aligned} \end{cases}\end{split}\]
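Finally, a minimal NumPy sketch of this swapped double-loop backward pass, checked against the unfused gradients (the function name, block sizes, and the plain += standing in for the atomic add are illustrative assumptions):

import numpy as np

def flash_attention_backward(Q, K, V, dO, O, LSE, Bq=4, Bk=3):
    # swapped double-loop tiling with recomputation; the atomic add of the
    # real kernel becomes a plain += since everything here is sequential
    N = Q.shape[0]
    dQ = np.zeros_like(Q)
    dK = np.zeros_like(K)
    dV = np.zeros_like(V)
    Delta = (O * dO).sum(axis=-1)                   # pre-processing: rowsum(O * dO)
    for j0 in range(0, N, Bk):                      # outer loop over K/V blocks
        Kj, Vj = K[j0:j0 + Bk], V[j0:j0 + Bk]
        for i0 in range(0, N, Bq):                  # inner loop over Q/dO blocks
            Qi, dOi = Q[i0:i0 + Bq], dO[i0:i0 + Bq]
            S = Qi @ Kj.T                           # recompute S_j^(i)
            P = np.exp(S - LSE[i0:i0 + Bq, None])   # recompute P_j^(i) from LSE only
            dV[j0:j0 + Bk] += P.T @ dOi
            dP = dOi @ Vj.T
            dS = P * (dP - Delta[i0:i0 + Bq, None])
            dK[j0:j0 + Bk] += dS.T @ Qi
            dQ[i0:i0 + Bq] += dS @ Kj               # the "atomic add" into dQ_i
    return dQ, dK, dV

# toy check against the unfused gradients
rng = np.random.default_rng(6)
N, d = 10, 6
Q, K, V, dO = (rng.standard_normal((N, d)) for _ in range(4))
S = Q @ K.T
m = S.max(axis=-1)
P = np.exp(S - m[:, None])
l = P.sum(axis=-1)
P /= l[:, None]
O, LSE = P @ V, m + np.log(l)
dQ, dK, dV = flash_attention_backward(Q, K, V, dO, O, LSE)
dP_ref = dO @ V.T
dS_ref = P * (dP_ref - (P * dP_ref).sum(axis=-1, keepdims=True))
assert np.allclose(dV, P.T @ dO)
assert np.allclose(dK, dS_ref.T @ Q)
assert np.allclose(dQ, dS_ref @ K)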

Citation#

If you find MagiAttention useful in your research, please cite:

@misc{magiattention2025,
  title={MagiAttention: A Distributed Attention Towards Linear Scalability for Ultra-Long Context, Heterogeneous Mask Training},
  author={Zewei, Tao and Yunpeng, Huang},
  year={2025},
  howpublished={\url{https://github.com/SandAI-org/MagiAttention/}},
}

References#

[1] Tri Dao. FlashAttention-2: faster attention with better parallelism and work partitioning. arXiv preprint arXiv:2307.08691, 2023.

[2] Tri Dao, Dan Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. FlashAttention: fast and memory-efficient exact attention with IO-awareness. Advances in Neural Information Processing Systems, 35:16344–16359, 2022.

[3] Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, and Tri Dao. FlashAttention-3: fast and accurate attention with asynchrony and low-precision. 2024. URL: https://arxiv.org/abs/2407.08608, arXiv:2407.08608.

[4] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need. 2023. URL: https://arxiv.org/abs/1706.03762, arXiv:1706.03762.

[5] Tao Zewei and Huang Yunpeng. MagiAttention: a distributed attention towards linear scalability for ultra-long context, heterogeneous mask training. SandAI-org/MagiAttention, 2025.