FFA with Attention Sink
Integrating Flex-Flash-Attention with Attention Sink
Introduction
Large language models (LLMs) tend to assign significant attention to a few tokens (such as the initial tokens in the sequence), even when they are not semantically important, a phenomenon known as attention sink.
However, w.r.t. the GPT-OSS implementation, sink tokens only affect the $\mathrm{softmax}$ operation during the attention forward/backward passes: they contribute learnable logits but no key/value entries. As a result, they can be supported on top of existing Flash-Attention-style kernels with only lightweight pre-/post-processing, as we show below.
Overview
With the release of MagiAttention-v1.0.5, we have not only added support for the learnable attention sink mechanism to our own kernel and distributed implementations (Flex-Flash-Attention and MagiAttention, respectively), but also provided plug-and-play implementations that extend the original Flash Attention 2/3 interfaces.
In this blog, we share how we integrate Flash-Attention-style attention implementations with the learnable attention sink mechanism, including:
- the User Interface update for Flex-Flash-Attention, MagiAttention and Flash-Attention Extension.
- the Math Derivation of applying the attention sink in both forward and backward passes of Flex-Flash-Attention.
- the Implementations of the (distributed) learnable attention sink mechanism for Flex-Flash-Attention and MagiAttention, as well as the naive Torch Reference.
User Interface
Below, we show the minor updates to the user interfaces that support the learnable attention sink mechanism for the original Flex-Flash-Attention, MagiAttention, and Flash-Attention 2/3 (as one of the MagiAttention extensions).
FFA API
- Just add an optional tensor `sink` to the argument list of `flex_flash_attn_func`.
- When and only when the `sink` tensor is given, `flex_flash_attn_func` will apply attention sink during the forward pass and compute `dsink` (the gradient of `sink`) during the backward pass. Otherwise, attention sink is skipped and `dsink` is returned as `None`.
- dtype: `float32` only.
- shape: `[seqlen_sink, num_heads_q]`, where `seqlen_sink` is in `[1, 8]`.
- interface difference with the original `flex_flash_attn_func`:
def flex_flash_attn_func(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
q_ranges: torch.Tensor,
k_ranges: torch.Tensor,
attn_type_map: torch.Tensor | None = None,
+ sink: torch.Tensor | None = None,
softmax_scale: float | None = None,
softcap: float = 0.0,
deterministic: bool = False,
sm_margin: int = 0,
disable_fwd_atomic_reduction: bool = False,
auto_range_merge: bool = False,
ref_block_size: tuple[int, int] | None = None,
profile_mode: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
...
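For a quick feel of the updated interface, here is a minimal usage sketch based only on the signature above. The packed `[seqlen, num_heads, head_dim]` layout, the `int32 [num_slices, 2]` range format, the meaning of `attn_type_map` values, and the import path are assumptions for illustration; please refer to the FFA docs for the authoritative semantics.

```python
import torch
from magi_attention.functional import flex_flash_attn_func  # import path may differ

sq = sk = 1024
nhq, nhk, hd = 8, 8, 128

# packed (batch-free) q/k/v, as assumed for FFA
q = torch.randn(sq, nhq, hd, dtype=torch.bfloat16, device="cuda")
k = torch.randn(sk, nhk, hd, dtype=torch.bfloat16, device="cuda")
v = torch.randn(sk, nhk, hd, dtype=torch.bfloat16, device="cuda")

# a single attention slice covering the whole sequence (assumed range semantics)
q_ranges = torch.tensor([[0, sq]], dtype=torch.int32, device="cuda")
k_ranges = torch.tensor([[0, sk]], dtype=torch.int32, device="cuda")
attn_type_map = torch.zeros(1, dtype=torch.int32, device="cuda")  # assumed: 0 = full attention

# learnable sink logits: float32, [seqlen_sink, num_heads_q] with seqlen_sink in [1, 8]
sink = torch.zeros(1, nhq, dtype=torch.float32, device="cuda", requires_grad=True)

out, lse = flex_flash_attn_func(
    q, k, v, q_ranges, k_ranges,
    attn_type_map=attn_type_map,
    sink=sink,  # pass None (or omit) to skip the sink path; dsink is then None in the backward
)
```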
MagiAttn API
- Just add an optional replicated tensor `sink` to the argument list of `calc_attn`.
- When and only when the replicated `sink` tensor is given, `calc_attn` will apply attention sink during the forward pass for each local query token and compute a partial `dsink` during the backward pass. An `all-reduce` communication might then be applied across cp ranks to return the reduced `dsink` if required (see the environment variable `MAGI_ATTENTION_DSINK_ALL_REDUCE_OP` in our docs). Otherwise, attention sink is skipped and `dsink` is returned as `None`.
- dtype: `float32` only.
- shape: `[seqlen_sink, num_heads_q]`, where `seqlen_sink` is in `[1, 8]`.
- parallel style: `Replicate`.
- interface difference with the original `calc_attn`:
def calc_attn(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
key: DistAttnRuntimeKey,
+ sink: torch.Tensor | None = None,
softmax_scale: float | None = None,
softcap: float = 0.0,
) -> tuple[torch.Tensor, torch.Tensor]:
...
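As a usage sketch (not a runnable end-to-end example), the only change on the training side is to create the replicated `sink` parameter and pass it through; `q`, `k`, `v` and the `DistAttnRuntimeKey` are assumed to come from MagiAttention's usual dispatch workflow, which is omitted here, and the import path may differ.

```python
import torch
import torch.nn as nn
from magi_attention.api import calc_attn  # import path may differ

# replicated learnable sink logits: every cp rank holds the same [seqlen_sink, num_heads_q] values
seqlen_sink, num_heads_q = 1, 8
sink = nn.Parameter(torch.zeros(seqlen_sink, num_heads_q, dtype=torch.float32, device="cuda"))

# q / k / v are the local (already dispatched) tensors on this cp rank,
# and `key` is the DistAttnRuntimeKey returned by the dispatch step (both omitted here)
out, lse = calc_attn(q, k, v, key, sink=sink)
```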
Flash Attention Extension
- Just add an optional tensor `sink` to the argument list of `flash_attn_func`, `flash_attn_varlen_func`, etc.
- When and only when the `sink` tensor is given, flash attention will apply attention sink during the forward pass and compute `dsink` during the backward pass. Otherwise, attention sink is skipped and `dsink` is returned as `None`.
- dtype: `float32` only.
- shape: `[seqlen_sink, num_heads_q]`, where `seqlen_sink` has no limit.
- interface difference with the original flash attention:
- def flash_attn_func(
+ def flash_attn_func_with_sink(
q,
k,
v,
+ sink=None,
softmax_scale=None,
causal=False,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=(-1, -1),
attention_chunk=0,
softcap=0.0,
num_splits=1,
pack_gqa=None,
deterministic=False,
sm_margin=0,
return_attn_probs=False,
):
...
- def flash_attn_varlen_func(
+ def flash_attn_varlen_func_with_sink(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
+ sink=None,
seqused_q=None,
seqused_k=None,
softmax_scale=None,
causal=False,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=(-1, -1),
attention_chunk=0,
softcap=0.0,
num_splits=1,
pack_gqa=None,
deterministic=False,
sm_margin=0,
return_attn_probs=False,
):
...
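A minimal usage sketch of the extended interface; the import path and the `[batch, seqlen, num_heads, head_dim]` layout below are assumptions for illustration, while the rest follows the standard Flash-Attention 3 calling convention.

```python
import torch
# hypothetical import path; check the MagiAttention flash-attn extension module for the actual one
from magi_attention.functional import flash_attn_func_with_sink

b, s, nh, hd = 2, 4096, 8, 128
q = torch.randn(b, s, nh, hd, dtype=torch.bfloat16, device="cuda")
k = torch.randn(b, s, nh, hd, dtype=torch.bfloat16, device="cuda")
v = torch.randn(b, s, nh, hd, dtype=torch.bfloat16, device="cuda")

# learnable sink logits: float32, [seqlen_sink, num_heads_q], no limit on seqlen_sink here
sink = torch.zeros(1, nh, dtype=torch.float32, device="cuda", requires_grad=True)

out = flash_attn_func_with_sink(q, k, v, sink=sink, causal=True)
```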
Math Derivation
Below, we provide the step-by-step math derivation of the original forward / backward passes of Flex-Flash-Attention (the same as Flash-Attention) w/o sink tokens, and then the differences when the learnable attention sink mechanism is involved, serving as the guidance for our implementations in the next section.
NOTE:
1. To simplify the derivation, we drop the batch dimension and only keep the num_heads dimension to the leftmost acting as the implicit batch dimension.
2. To focus on the attention sink mechanism, we assume you’re already familiar with Flash Attention and will skip over its finer details, like the double-loop tiling strategy and the derivation of online softmax correction based on log-sum-exp operations.
3. If you are new to Flash Attention or interested in the full original math derivation, we highly recommend this blog post: Flash Attention 2 Math Derivation.
Symbol Notation:
| symbol | notation |
|---|---|
| $\times$ | matrix multiplication |
| $\cdot$ | scalar multiplication |
| $\odot$ | element-wise multiplication (Hadamard product) |
| $sq, sk, s\_sink$ | the sequence length of query tokens, key tokens, and attention sink tokens |
| $nhq, nhk$ | the number of heads of query tokens and key tokens |
| $hd$ | the head dimension of query, key and value tokens |
| $X_i$ | the column vector made by the $i$-th row of matrix $X$ along the sequence dimension |
FFA Forward
FFA forward w/o sink tokens
- step1: compute the scaled (and masked) attention scores: $S = Q \times K^T \cdot \mathrm{scale} + \mathrm{bias} \in \mathbb{R}^{nhq \times sq \times sk}$, where $\mathrm{bias}$ encodes the attention mask
- step2: compute the row-wise log-sum-exp and softmax: $lse_i = \log\sum_{j=1}^{sk} \exp(S_{i,j})$, $\ P_i = \exp(S_i - lse_i)$
- step3: compute the output: $O = P \times V \in \mathbb{R}^{nhq \times sq \times hd}$, and return $O$ together with $lse$ (needed by the backward pass)
FFA forward with sink tokens
- step1: (the same) $S = Q \times K^T \cdot \mathrm{scale} + \mathrm{bias}$
- step2: append the sink logits as extra columns before the row-wise softmax: $\tilde{S}_i = [\,S_i \,\|\, \mathrm{sink}\,] \in \mathbb{R}^{sk + s\_sink}$, so $lse_i = \log\big(\sum_{j=1}^{sk} \exp(S_{i,j}) + \sum_{t=1}^{s\_sink} \exp(\mathrm{sink}_t)\big)$ and $\tilde{P}_i = \exp(\tilde{S}_i - lse_i)$
- step3: (the same matmul, with the sink columns dropped since sink tokens carry no values) $O = P \times V$, where $P$ keeps only the first $sk$ columns of $\tilde{P}$
- sink correction: (as a post-processing of the original ffa forward w/o sink tokens) let $lse\_sink = \log\sum_{t=1}^{s\_sink} \exp(\mathrm{sink}_t)$; then the sink-free results can be corrected as $lse_i^{new} = \log\big(e^{lse_i} + e^{lse\_sink}\big)$ and $O_i^{new} = O_i \cdot e^{lse_i - lse_i^{new}}$ (a small numerical check of this equivalence follows below)
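To make the sink correction above concrete, here is a small standalone PyTorch check (not library code) verifying that correcting a sink-free `out` / `lse` after the fact matches running the softmax with the sink logits appended; the `[nhq, 1, s_sink]` sink layout is chosen just to match the derivation shapes.

```python
import torch

torch.manual_seed(0)
nhq, sq, sk, s_sink, hd = 2, 5, 7, 1, 4
q, k, v = (torch.randn(nhq, n, hd) for n in (sq, sk, sk))
sink = torch.randn(nhq, 1, s_sink)  # sink logits, broadcast over the query dimension
scale = hd ** -0.5

s = q @ k.transpose(-2, -1) * scale                       # [nhq, sq, sk]

# (a) direct: softmax over [keys | sink], then drop the sink columns before O = P x V
s_full = torch.cat([s, sink.expand(nhq, sq, s_sink)], dim=-1)
out_direct = s_full.softmax(dim=-1)[..., :sk] @ v
lse_direct = s_full.logsumexp(dim=-1)

# (b) sink-free attention followed by the post-hoc sink correction
lse = s.logsumexp(dim=-1)                                 # [nhq, sq]
out = s.softmax(dim=-1) @ v                               # [nhq, sq, hd]
lse_sink = sink.logsumexp(dim=-1).expand(nhq, sq)         # [nhq, sq]
corrected_lse = torch.logaddexp(lse, lse_sink)            # log(exp(lse) + exp(lse_sink))
corrected_out = out * (lse - corrected_lse).exp().unsqueeze(-1)

assert torch.allclose(out_direct, corrected_out, atol=1e-6)
assert torch.allclose(lse_direct, corrected_lse, atol=1e-6)
```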
FFA Backward
FFA backward w/o sink tokens
- step1: (as a pre-processing) $\Delta_i = O_i^T \times dO_i$, i.e. the row-wise sum of $O \odot dO$
- step2: (recomputation) $S = Q \times K^T \cdot \mathrm{scale} + \mathrm{bias}$, $\ P_i = \exp(S_i - lse_i)$
- step3: $dV = P^T \times dO$
- step4: $dP = dO \times V^T$, $\ dS_i = P_i \odot (dP_i - \Delta_i)$
- step5: $dQ = dS \times K \cdot \mathrm{scale}$, $\ dK = dS^T \times Q \cdot \mathrm{scale}$
FFA backward with sink tokens
- step1: (as a pre-processing as well) $\Delta_i = O_i^T \times dO_i$, unchanged, since the sink columns carry no values and thus contribute nothing to $O$
- step2: (recomputation) $P_i = \exp(S_i - lse_i)$, where $lse_i$ already contains the sink term, so $P_i$ is exactly the sink-aware softmax restricted to the key columns
- step3: $dV = P^T \times dO$
- step4: $dP = dO \times V^T$, $\ dS_i = P_i \odot (dP_i - \Delta_i)$
- step5: (the same) $dQ = dS \times K \cdot \mathrm{scale}$, $\ dK = dS^T \times Q \cdot \mathrm{scale}$
- dsink computation: (as another pre-processing of the original ffa backward w/o sink tokens) since the gradient w.r.t. the sink columns of $\tilde{P}$ is zero, $dsink_t = \sum_{i=1}^{sq} p\_sink_{i,t} \cdot (-\Delta_i)$ with $p\_sink_{i,t} = \exp(\mathrm{sink}_t - lse_i)$ (a quick autograd cross-check follows below)
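Likewise, the closed-form `dsink` above can be sanity-checked against PyTorch autograd on the naive reference (a standalone check, not library code):

```python
import torch

torch.manual_seed(0)
nhq, sq, sk, s_sink, hd = 2, 5, 7, 1, 4
q, k, v = (torch.randn(nhq, n, hd) for n in (sq, sk, sk))
do = torch.randn(nhq, sq, hd)
sink = torch.randn(nhq, 1, s_sink, requires_grad=True)
scale = hd ** -0.5

# naive reference with sink, kept differentiable w.r.t. sink
s_full = torch.cat([q @ k.transpose(-2, -1) * scale,
                    sink.expand(nhq, sq, s_sink)], dim=-1)
out = s_full.softmax(dim=-1)[..., :sk] @ v
lse = s_full.logsumexp(dim=-1, keepdim=True)              # [nhq, sq, 1], already sink-aware
out.backward(do)
dsink_ref = sink.grad                                     # [nhq, 1, s_sink]

# closed form: dsink_t = sum_i exp(sink_t - lse_i) * (-Delta_i), with Delta_i = O_i . dO_i
with torch.no_grad():
    delta = (out * do).sum(dim=-1, keepdim=True)          # [nhq, sq, 1]
    p_sink = torch.exp(sink - lse)                        # [nhq, sq, s_sink]
    dsink = (p_sink * -delta).sum(dim=1, keepdim=True)    # reduce over sq -> [nhq, 1, s_sink]

assert torch.allclose(dsink, dsink_ref, atol=1e-5)
```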
Implementations
Based on the math derivation in the previous section, folding a learnable attention sink into the attention implementations in the Flash Attention style boils down to just two edits:
- For the forward pass, we change nothing in the original implementation, but apply an additional post-processing to correct the returned `out` and `lse` with the `sink` tokens (see the sink correction of the FFA forward with sink tokens).
- For the backward pass, we change nothing in the original implementation, but apply an additional pre-processing to compute `dsink`, i.e. the gradient of `sink` (see the dsink computation of the FFA backward with sink tokens).
Therefore, we share the following code snippets to present our implementations of the learnable attention sink mechanism: a naive PyTorch reference, Flex-Flash-Attention (both internal and external to the kernels, which fit Flash Attention as well), and the distributed implementation of MagiAttention.
Torch Reference
- reference implementation w/o sink tokens:
# apply `S = Q x K.T * scale + bias`
# where S.shape = [nhq, sq, sk]
s = q @ k.transpose(-2, -1) * softmax_scale + bias
# apply row-wise lse `LSE = logsumexp(S, dim=-1)`
# where LSE.shape = [nhq, sq, 1]
lse = s.logsumexp(dim=-1, keepdim=True)
# apply row-wise softmax `P = softmax(S, dim=-1)`
# where P.shape = [nhq, sq, sk]
p = s.softmax(dim=-1).to(q.dtype)
# apply `O = P x V`
# where O.shape = [nhq, sq, d]
out = p @ v
return out, lse
- reference implementation difference with sink tokens:
# apply `S = Q x K.T * scale + bias`
# where S.shape = [nhq, sq, sk]
s = q @ k.transpose(-2, -1) * softmax_scale + bias
+ # apply `S = S.concat(sink, dim=-1)`
+ # where S.shape = [nhq, sq, sk + s_sink]
+ # (sink is assumed already expanded / broadcast to [nhq, sq, s_sink] here)
+ s = torch.concat([s, sink], dim=-1)
# apply row-wise lse `LSE = logsumexp(S, dim=-1)`
# where LSE.shape = [nhq, sq, 1]
lse = s.logsumexp(dim=-1, keepdim=True)
# apply row-wise softmax `P = softmax(S, dim=-1)`
- # where P.shape = [nhq, sq, sk]
+ # where P.shape = [nhq, sq, sk + s_sink]
p = s.softmax(dim=-1).to(q.dtype)
+ # apply `P = P.drop(sink, dim=-1)`
+ # where P.shape = [nhq, sq, sk]
+ p = p[..., : -sink.size(dim=-1)]
# apply `O = P x V`
# where O.shape = [nhq, sq, d]
out = p @ v
return out, lse
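Putting the two snippets above together, a self-contained version of the reference (a sketch; the function name, argument defaults and the broadcastable `sink` layout are ours, not the library's) would look like:

```python
import torch

def attn_ref(q, k, v, sink=None, softmax_scale=None, bias=0.0):
    """Naive attention reference with an optional learnable sink.

    q: [nhq, sq, hd], k/v: [nhq, sk, hd]
    sink: broadcastable to [nhq, sq, s_sink] (e.g. [nhq, 1, s_sink]), or None
    returns out: [nhq, sq, hd], lse: [nhq, sq, 1]
    """
    if softmax_scale is None:
        softmax_scale = q.size(-1) ** -0.5
    # S = Q x K.T * scale + bias, shape [nhq, sq, sk]
    s = q @ k.transpose(-2, -1) * softmax_scale + bias
    if sink is not None:
        # append the sink logits as extra columns: [nhq, sq, sk + s_sink]
        s = torch.cat([s, sink.expand(*s.shape[:-1], sink.size(-1))], dim=-1)
    # row-wise lse and softmax over all columns (keys + sink, if any)
    lse = s.logsumexp(dim=-1, keepdim=True)
    p = s.softmax(dim=-1).to(q.dtype)
    if sink is not None:
        # drop the sink columns: they carry no values
        p = p[..., : -sink.size(-1)]
    # O = P x V, shape [nhq, sq, hd]
    out = p @ v
    return out, lse
```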
FFA Impl
FFA Forward Impl
External Impl
- Use sink correction to correct `out` and `lse` after the ffa forward kernel returns, as an external post-processing kernel (which is the way we extend the Flash Attention 2/3 forward with sink tokens; see the source code for more details):
# given sink with shape: [s_sink, nhq]
# calculate and repeat to lse_sink with shape: [sq, nhq]
lse_sink = sink.logsumexp(dim=0, keepdim=True).repeat(sq, 1)
# given ffa returned lse with shape: [sq, nhq]
# correct lse with lse_sink: `corrected_lse = log(exp(lse) + exp(lse_sink))`
corrected_lse = torch.logaddexp(lse, lse_sink)
# given ffa returned out with shape: [sq, nhq, hd]
# correct out with corrected_lse and original lse
out *= (lse - corrected_lse).exp().unsqueeze(-1)
return out, corrected_lse
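For completeness, here is the snippet above as a self-contained helper (a sketch following the `[sq, nhq]` lse layout shown above; the function name is ours, not the library's):

```python
import torch

def apply_sink_correction(out: torch.Tensor, lse: torch.Tensor, sink: torch.Tensor):
    """Post-process the sink-free forward results with the sink logits.

    out:  [sq, nhq, hd]  output returned by the sink-free forward kernel
    lse:  [sq, nhq]      log-sum-exp returned by the sink-free forward kernel (float32)
    sink: [s_sink, nhq]  learnable sink logits (float32)
    """
    # lse over the sink logits per head, broadcast to every query row: [sq, nhq]
    lse_sink = sink.logsumexp(dim=0, keepdim=True).expand_as(lse)
    # numerically-stable log(exp(lse) + exp(lse_sink))
    corrected_lse = torch.logaddexp(lse, lse_sink)
    # rescale each output row by the ratio of the old to the new softmax normalizer
    corrected_out = out * (lse - corrected_lse).exp().unsqueeze(-1)
    return corrected_out, corrected_lse
```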
Internal Impl
- Since FFA forward already has a post-processing kernel `FlashAttnFwdPostprocess` to zero-fill the never-stored rows of `O` (indicated by whether the corresponding row of `lse` is still `-inf`), ...
- ... we can fuse the sink correction process into the `FlashAttnFwdPostprocess` kernel as follows (see the source code for more details, and the Python sketch after this list):
  - As for lse correction:
    - If the current row of `lse` is not `-inf`, then we update this row of `lse` with `lse_sink`.
    - Otherwise, the `lse` should also be filled with `lse_sink`, instead of `-inf`.
  - As for out correction:
    - If the current row of `lse` is not `-inf`, then load the corresponding row of `O`, rescale it, and write it back.
    - Otherwise, the corresponding row of `O` still needs to be zero-filled, the same as before.
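A tiny Python sketch of the per-row branch that gets fused into the kernel (an illustration of the logic only, not the CUDA code; all arguments are assumed to be scalar / 1-D float tensors):

```python
import torch

def fused_fwd_postprocess_row(lse_row, out_row, lse_sink_row):
    """Per-(query, head) row logic fused into FlashAttnFwdPostprocess (Python sketch).

    lse_row:      0-dim lse tensor, still -inf if this row was never stored by any kv block
    out_row:      [hd] accumulator row for this (query, head) pair
    lse_sink_row: 0-dim lse tensor over the sink logits of this head
    """
    if lse_row == float("-inf"):
        # never-stored row: zero-fill O as before, but initialize lse with lse_sink
        return torch.zeros_like(out_row), lse_sink_row
    # stored row: correct lse with the sink term and rescale the O row accordingly
    new_lse = torch.logaddexp(lse_row, lse_sink_row)
    return out_row * (lse_row - new_lse).exp(), new_lse
```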
FFA Backward Impl
External Impl
- Use dsink computation to compute `dsink` before the ffa backward kernel launches, as an external pre-processing kernel (which is the way we extend the Flash Attention 2/3 backward with sink tokens; see the source code for more details):
# calculate delta = (o * do).sum(dim=-1), using `reduce` from einops
# where o.shape = [sq, nhq, d]
# do.shape = [sq, nhq, d]
# delta.shape = [nhq, sq, 1]
delta = reduce((o * do).to(lse.dtype), "sq hq d -> hq sq 1", "sum")
# calculate p_sink = exp(sink - lse)
# where sink.shape = [nhq, sq, s_sink]
# lse.shape = [nhq, sq, 1]
# p_sink.shape = [nhq, sq, s_sink]
p_sink = torch.exp(sink - lse)
# calculate dsink = p_sink.T x -delta
# where p_sink.shape = [nhq, sq, s_sink]
# delta.shape = [nhq, sq, 1]
# dsink.shape = [s_sink, nhq]
dsink = reduce(p_sink * -delta, "nhq sq s_sink -> s_sink nhq", "sum")
return dsink
Internal Impl
- Since FFA backward already has a pre-processing kernel `FlashAttnBwdPreprocess` to compute $\Delta$ (in FA / FFA, we name it `dPsum`), w.r.t. the step1 in the FFA backward w/o sink tokens, ...
- ... then we can fuse the dsink computation process into the `FlashAttnBwdPreprocess` kernel as follows (see the source code for more details, and the Python emulation after this list):
  - As for `lse`, the same as before: each thread in one block loads one unique row of `lse`.
  - As for `p_sink`, the first `seqlen_sink` threads in one block load the `sink` to shared memory, and each thread computes `p_sink = exp(sink - lse)` with its own unique row of `lse`, storing the result to shared memory as well.
  - As for `dPsum`, the same as before: each block loads a unique `kBlockM` rows of `O` and `dO`, applies `O * dO`, reduces across the head dimension to get the local block of `dPsum` in register files, and stores it to global memory.
  - As for `dsink`, since it requires a reduction across the whole `seqlen_q` dimension, the following steps are performed:
    - step1: each thread loads a unique row of `dPsum` from register files and the corresponding row of `p_sink` from shared memory, computes the thread-partial `dsink = p_sink * -dPsum` for this row, and stores it to shared memory first (since `p_sink` is not used afterwards, we can reuse its shared memory buffer to store `dsink`).
    - step2: each block loads all the thread-partial `dsink` from shared memory, applies a `block-reduction` to get the block-reduced `dsink` for these `kBlockM` rows, and stores it to a temporary buffer in global memory.
    - step3: after a device-level memory fence, the last block to store its block-reduced `dsink` loads all the block-reduced `dsink` back from the temporary buffer, applies another `block-reduction` to get the reduced `dsink` across the whole `seqlen_q` dimension, and finally stores it to global memory.
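The same two-level reduction, emulated in plain PyTorch for clarity (this illustrates the reduction order only and is not the CUDA kernel):

```python
import torch
import torch.nn.functional as F

def dsink_two_level_reduce(p_sink, dpsum, kBlockM=128):
    """Emulate the per-row -> per-block -> global reduction of dsink.

    p_sink: [nhq, sq, s_sink], dpsum: [nhq, sq, 1]  ->  dsink: [s_sink, nhq]
    """
    partial = p_sink * -dpsum                                   # thread-partial dsink per row
    nhq, sq, s_sink = partial.shape
    pad = (-sq) % kBlockM                                       # pad sq up to a multiple of kBlockM
    partial = F.pad(partial, (0, 0, 0, pad))
    block_reduced = partial.view(nhq, -1, kBlockM, s_sink).sum(dim=2)  # block-reduction over kBlockM rows
    return block_reduced.sum(dim=1).transpose(0, 1)             # final reduction over all blocks
```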
MagiAttn Impl
MagiAttn Forward
- Since `sink` is replicated across cp ranks, we can easily apply attention sink by just passing `sink` into `_flex_flash_attn_forward`.
- However, the attention sink is supposed to be applied once and only once for each query token, thus we apply it only at the host stage, i.e. each cp rank applies it only to its own local `q`.
- Then, if the host stage is not skipped, we just apply attention sink by passing `sink` into `_flex_flash_attn_forward`:
partial_out, partial_lse = _flex_flash_attn_forward(
q=q,
k=k,
v=v,
+ # NOTE: sink token needs to be applied only once
+ # thus we only apply it at the host stage if not skipped
+ sink=sink if is_host_stage else None,
out=out_acc,
lse=lse_acc,
**attn_arg.to_ffa_args(is_bwd=False),
...
)
- Otherwise, we should zero-initialize `local_out` as before, but initialize `local_lse` with `lse_sink` instead of `-inf`:
out = torch.zeros_like(
q,
dtype=torch.float32,
device=q.device,
)
+ if sink is not None:
+ # in skipped host stage if sink is given,
+ # we directly use lse_sink to initialize lse
+ lse = calc_lse_sink(
+ sink=sink,
+ seqlen_lse=q.size(0),
+ )
+ else:
lse = torch.full(
(q.size(0), q.size(1)),
fill_value=float("-inf"),
dtype=torch.float32,
device=q.device,
)
return out, lse
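For reference, `calc_lse_sink` presumably reduces the sink logits per head and repeats the result for every local query row, matching the `lse_sink` used in the sink correction; a sketch of that behavior (the real helper lives in the MagiAttention source):

```python
import torch

def calc_lse_sink_sketch(sink: torch.Tensor, seqlen_lse: int) -> torch.Tensor:
    """Sketch of calc_lse_sink: sink [seqlen_sink, nhq] -> lse_sink [seqlen_lse, nhq]."""
    # log-sum-exp over the sink logits of each head, repeated for every local query row
    return sink.logsumexp(dim=0, keepdim=True).repeat(seqlen_lse, 1)
```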
MagiAttn Backward
- The same as in the forward: to form a complete, non-overlapping breakdown of the `dsink` computation, we compute the partial `dsink` by passing `sink` into `_flex_flash_attn_backward` only at the host stage, if not skipped.
(
partial_dq,
partial_dk,
partial_dv,
+ partial_dsink,
) = _flex_flash_attn_backward(
dout=do,
q=q,
k=k,
v=v,
+ # NOTE: dsink should be computed only once
+ # thus we only compute it at the host stage if not skipped
+ sink=sink if is_host_stage else None,
out=o,
lse=lse,
dq=dq_acc,
dk=partial_dk,
dv=partial_dv,
+ dsink=None, # let kernel initialize dsink if required
**attn_arg.to_ffa_args(is_bwd=True),
...
)
- According to the formula of dsink computation, `dsink` is required to be sum-reduced along the `seqlen_q` dim. Therefore, to get the fully reduced `dsink` on each cp rank, we have to additionally launch an all-reduce communication with `ReduceOp.SUM` and wait for it to complete before returning from the backward.
- However, the tricky thing is that in an actual training scenario, the learnable `sink` tensor is treated as a regular model parameter, similar to the `bias` in an `nn.Linear` layer. So under some popular training frameworks, such as `Megatron-LM` and `FSDP`, the sum-reduction of the partial gradients of `sink` across cp ranks might be automatically applied within the whole `dp x cp` mesh.
- To avoid repeated reduction, we provide the environment variable `MAGI_ATTENTION_DSINK_ALL_REDUCE_OP` to let the user specify the all-reduce op for `dsink` within MagiAttention (see the docs for more details). It defaults to `none`, which applies no reduction to `dsink` and lets the framework handle it; other options include `sum` and `avg` if needed.
+ # after the host stage when the partial dsink is ready
+ work = dist.all_reduce(
+ dsink,
+ op=dsink_reduce_op, # specified by `MAGI_ATTENTION_DSINK_ALL_REDUCE_OP`
+ group=self.cp_group_gc,
+ async_op=True,
+ )
...
+ # before returning from the backward
+ work.wait()
...
- return dq, dk, dv, ...
+ return dq, dk, dv, dsink, ...
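For example, if the training framework does not already reduce the `sink` gradient across cp ranks, the op can be set explicitly before launching MagiAttention (values per the docs cited above):

```python
import os

# make MagiAttention itself sum-reduce the partial dsink across cp ranks;
# keep the default "none" if your framework (e.g. Megatron-LM / FSDP) already reduces it
os.environ["MAGI_ATTENTION_DSINK_ALL_REDUCE_OP"] = "sum"
```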
Citation
If you find MagiAttention useful in your research, please cite:
@misc{magiattention2025,
title={MagiAttention: A Distributed Attention Towards Linear Scalability for Ultra-Long Context, Heterogeneous Mask Training},
author={Zewei, Tao and Yunpeng, Huang},
year={2025},
howpublished={\url{https://github.com/SandAI-org/MagiAttention/}},
}