API Reference#
Flexible Flash Attention#
To support computing irregular-shaped masks, we implemented a flexible_flash_attention
kernel, which can be invoked through the following interface.
- magi_attention.functional.flex_flash_attn.flex_flash_attn_func()#
An interface similar to flash attention that requires no distributed environment, dispatch, or undispatch: it directly calls magi_attn_kernel to get the attention output and lse. This is faster when you don’t need context parallelism.
- Parameters:
q (torch.Tensor) – Query tensor.
k (torch.Tensor) – Key tensor.
v (torch.Tensor) – Value tensor.
q_ranges (torch.Tensor) – Query ranges tensor that represents the attn mask.
k_ranges (torch.Tensor) – Key ranges tensor that represents the attn mask.
max_seqlen_q (int) – Maximum sequence length among q_ranges.
max_seqlen_k (int) – Maximum sequence length among k_ranges.
attn_type_map (torch.Tensor) –
Attention type map tensor with dtype=torch.int32. The values specify the attention type for each (q_range, k_range) slice:
0: full attention
1: causal attention
2: inverse causal attention
3: bidirectional causal attention
More information about the attention type map can be found in the Note below.
softmax_scale (float, optional) – Softmax scale, defaults to 1/sqrt(head_dim).
softcap (float, optional) – Softcap value, defaults to 0.
deterministic (bool, optional) – Whether to use deterministic attention, defaults to False.
sm_margin (int, optional) – the number of SMs (streaming multiprocessors) reserved for communication.
disable_fwd_atomic_reduction (bool) –
Whether to disable forward atomic reduction.
If you can ensure q_ranges has no overlap, you can set this to True for better performance. Overlap in q_ranges is defined as: if any two q_ranges have a non-empty intersection, then there is overlap. For example, q_ranges = [[0, 15], [10, 20], [20, 30]] has overlap because [0, 15] and [10, 20] intersect, while q_ranges = [[0, 15], [15, 20], [20, 30]] has no overlap.
auto_range_merge (bool, optional) –
Whether to automatically merge k_ranges for the same q_range, defaults to False.
Note: This flag is usually used in sparse attention cases but is still under development.
- Returns:
out (torch.Tensor): Attention output tensor
lse (torch.Tensor): Log-sum-exp values with dtype=torch.float32.
- Return type:
tuple[torch.Tensor, torch.Tensor]
- Shape:
q: (num_tokens_q, num_heads, head_dim)
k: (num_tokens_kv, num_heads, head_dim)
v: (num_tokens_kv, num_heads, head_dim)
q_ranges: (num_ranges, 2)
k_ranges: (num_ranges, 2)
attn_type_map: (num_ranges, )
out: (num_tokens_q, num_heads, head_dim)
lse: (num_heads, num_tokens_q)
Note
The attn_type_map explains the semantics of different attention mask types. In addition to the descriptions below, see our blog for a visual explanation: https://sandai-org.github.io/MagiAttention/blog/#flex-flash-attn
- Full attention:
If seqlen_q = 5 and seqlen_k = 2:
1 1
1 1
1 1
1 1
1 1
If seqlen_q = 2 and seqlen_k = 5:
1 1 1 1 1
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 5:
1 1 1 1 1
1 1 1 1 1
1 1 1 1 1
1 1 1 1 1
1 1 1 1 1
- Causal attention (bottom-right aligned):
If seqlen_q = 5 and seqlen_k = 2:
0 0
0 0
0 0
1 0
1 1
If seqlen_q = 2 and seqlen_k = 5:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 5:
1 0 0 0 0
1 1 0 0 0
1 1 1 0 0
1 1 1 1 0
1 1 1 1 1
- Inverse causal attention (top-left aligned):
If seqlen_q = 5 and seqlen_k = 2:
1 1
0 1
0 0
0 0
0 0
If seqlen_q = 2 and seqlen_k = 5:
1 1 1 1 1
0 1 1 1 1
If seqlen_q = 5 and seqlen_k = 5:
1 1 1 1 1
0 1 1 1 1
0 0 1 1 1
0 0 0 1 1
0 0 0 0 1
- Bidirectional causal attention (intersection of causal and inverse causal):
This is the element-wise AND of the causal and inverse causal masks.
If seqlen_q = 5 and seqlen_k = 2:
0 0
0 0
0 0
0 0
0 0
If seqlen_q = 2 and seqlen_k = 5:
1 1 1 1 0
0 1 1 1 1
If seqlen_q = 5 and seqlen_k = 5:
1 0 0 0 0
0 1 0 0 0
0 0 1 0 0
0 0 0 1 0
0 0 0 0 1
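Example
Below is a minimal single-GPU sketch of calling flex_flash_attn_func. The tensor shapes follow the Shape section above; the concrete ranges, head configuration, device/dtype choices and keyword usage are illustrative assumptions rather than required values:
>>> import torch
>>> from magi_attention.functional.flex_flash_attn import flex_flash_attn_func
>>> # q: (num_tokens_q, num_heads, head_dim); k/v: (num_tokens_kv, num_heads, head_dim)
>>> q = torch.randn(4096, 8, 128, device="cuda", dtype=torch.bfloat16)
>>> k = torch.randn(4096, 8, 128, device="cuda", dtype=torch.bfloat16)
>>> v = torch.randn(4096, 8, 128, device="cuda", dtype=torch.bfloat16)
>>> # Two (q_range, k_range) slices: [0, 2048) attends to [0, 2048) with full attention,
>>> # and [2048, 4096) attends to [0, 4096) with causal attention
>>> q_ranges = torch.tensor([[0, 2048], [2048, 4096]], dtype=torch.int32, device="cuda")
>>> k_ranges = torch.tensor([[0, 2048], [0, 4096]], dtype=torch.int32, device="cuda")
>>> attn_type_map = torch.tensor([0, 1], dtype=torch.int32, device="cuda")  # 0: full, 1: causal
>>> out, lse = flex_flash_attn_func(
...     q, k, v,
...     q_ranges=q_ranges,
...     k_ranges=k_ranges,
...     max_seqlen_q=2048,  # longest q_range
...     max_seqlen_k=4096,  # longest k_range
...     attn_type_map=attn_type_map,
...     disable_fwd_atomic_reduction=True,  # the q_ranges above do not overlap
... )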
Dispatch#
Varlen Dispatch#
If you’re using a mask defined by cu_seqlens, such as a varlen full or varlen causal mask, we’ve designed a similar interface inspired by FlashAttention’s API, making it easy for you to get started quickly. From the function named magi_attn_varlen_dispatch, you can obtain the dispatched x and key.
- magi_attention.api.magi_attn_interface.magi_attn_varlen_dispatch()#
This is a flash_attn_varlen-like interface: it generates q_ranges, k_ranges and attn_mask_type from cu_seqlens_q, cu_seqlens_k and the causal flag, calculates the DistAttnRuntimeKey, generates the corresponding inner DistAttnRuntimeMgr, and finally pads and dispatches the input tensor to the local tensor.
- Parameters:
x (torch.Tensor) – input tensor
cu_seqlens_q (torch.Tensor) – Cumulative sequence lengths for queries.
cu_seqlens_k (torch.Tensor) – Cumulative sequence lengths for keys.
pad_size (int) – the size to pad along the seqlen dim; the padded seqlen needs to be divisible by chunk_size * cp_size.
chunk_size (int) – chunk size used to chunk the input tensor x along the seqlen dim for dispatch, to control the granularity of computation load-balancing.
cp_group_or_mesh (dist.ProcessGroup | DeviceMesh) – process group or device mesh. NOTE: for a process group, we only support the nccl backend for now; for a device mesh, we only support 1D or 2D meshes for now.
causal (bool) – if True, every attn_mask_type is CAUSAL; otherwise, every attn_mask_type is FULL.
dist_attn_config (DistAttnConfig) – dist attn config.
- Returns:
x (torch.Tensor): the local input tensor after padding and dispatching.
key (DistAttnRuntimeKey): the key that points to the inner DistAttnRuntimeMgr.
- Return type:
tuple[torch.Tensor, DistAttnRuntimeKey]
Example
>>> local_x, dist_attn_runtime_key = magi_attn_varlen_dispatch(
...     x=torch.randn(
...         4096,  # seqlen
...         2048,  # hidden_size
...         device=device,
...         dtype=dtype,
...         requires_grad=True,
...     ),
...     cu_seqlens_q=torch.tensor([0, 2048, 4096], dtype=torch.int32),
...     cu_seqlens_k=torch.tensor([0, 2048, 4096], dtype=torch.int32),
...     pad_size=compute_pad_size(4096, 4, 512),  # seqlen, cp_size, chunk_size
...     chunk_size=512,
...     cp_group_or_mesh=dist.new_group(list(range(4)), backend="nccl"),
...     causal=False,
...     dist_attn_config=DistAttnConfig(
...         dispatch_config=DispatchConfig(alg=MinHeapDispatchAlg()),
...         overlap_config=OverlapConfig(
...             enable=True,
...             mode=AttnOverlapMode.STATIC,
...             degree=2,
...             min_chunk_size=512,
...             max_num_chunks=64,
...             alg=OverlapAlgType.UNIFORM,
...         ),
...     ),
... )
>>> local_q, local_k, local_v = q_project(local_x), k_project(local_x), v_project(local_x)
>>> # Do local attention computation
>>> local_out, _ = calc_attn(local_q, local_k, local_v, dist_attn_runtime_key)
>>> # Gather local attention results to global result
>>> total_out = undispatch(local_out, dist_attn_runtime_key)
The logic of the magi_attn_varlen_dispatch function mainly consists of two parts: it first calls magi_attn_varlen_key to compute a key value, and then uses this key to dispatch the input x. The description of magi_attn_varlen_key is as follows.
- magi_attention.api.magi_attn_interface.magi_attn_varlen_key()#
This is a flash_attn_varlen-like interface: it generates q_ranges, k_ranges and attn_mask_type from cu_seqlens_q, cu_seqlens_k and the causal flag, calculates the DistAttnRuntimeKey and generates the corresponding inner DistAttnRuntimeMgr.
- Parameters:
cu_seqlens_q (torch.Tensor) – Cumulative sequence lengths for queries.
cu_seqlens_k (torch.Tensor) – Cumulative sequence lengths for keys.
pad_size (int) – the size to pad along the seqlen dim; the padded seqlen needs to be divisible by chunk_size * cp_size.
chunk_size (int) – chunk size used to chunk the input tensor x along the seqlen dim for dispatch, to control the granularity of computation load-balancing.
cp_group_or_mesh (dist.ProcessGroup | DeviceMesh) – process group or device mesh. NOTE: for a process group, we only support the nccl backend for now; for a device mesh, we only support 1D or 2D meshes for now.
causal (bool) – if True, every attn_mask_type is CAUSAL; otherwise, every attn_mask_type is FULL.
dist_attn_config (DistAttnConfig) – dist attn config.
- Returns:
the key that points to the inner DistAttnRuntimeMgr.
- Return type:
DistAttnRuntimeKey
Example
>>> dist_attn_runtime_key = magi_attn_varlen_key(
...     cu_seqlens_q=torch.tensor([0, 2048, 4096], dtype=torch.int32),
...     cu_seqlens_k=torch.tensor([0, 2048, 4096], dtype=torch.int32),
...     pad_size=compute_pad_size(4096, 4, 512),  # seqlen, cp_size, chunk_size
...     chunk_size=512,
...     cp_group_or_mesh=dist.new_group(list(range(4)), backend="nccl"),
...     causal=False,
...     dist_attn_config=DistAttnConfig(
...         dispatch_config=DispatchConfig(alg=MinHeapDispatchAlg()),
...         overlap_config=OverlapConfig(
...             enable=True,
...             mode=AttnOverlapMode.STATIC,
...             degree=2,
...             min_chunk_size=512,
...             max_num_chunks=64,
...             alg=OverlapAlgType.UNIFORM,
...         ),
...     ),
... )
>>> # Dispatch global query tensor to local query tensor
>>> local_q = dispatch(total_q, dist_attn_runtime_key)
>>> # Dispatch global key tensor to local key tensor
>>> local_k = dispatch(total_k, dist_attn_runtime_key)
>>> # Dispatch global value tensor to local value tensor
>>> local_v = dispatch(total_v, dist_attn_runtime_key)
>>> # Calculate local attention result
>>> local_out, _ = calc_attn(local_q, local_k, local_v, dist_attn_runtime_key)
>>> # Gather local attention results to global result
>>> total_out = undispatch(local_out, dist_attn_runtime_key)
Flexible Dispatch#
If the masks you’re using are not limited to varlen full or varlen causal, but also include sliding window masks or other more diverse types, we recommend using the following API. By calling magi_attn_flex_dispatch, you can obtain the dispatched x and key.
- magi_attention.api.magi_attn_interface.magi_attn_flex_dispatch()#
This is the most flexible interface: directly pass in q_ranges, k_ranges and attn_mask_type to calculate the DistAttnRuntimeKey, generate the corresponding inner DistAttnRuntimeMgr, and finally pad and dispatch the input tensor to the local tensor.
- Parameters:
x (torch.Tensor) – input tensor
q_ranges (AttnRanges) – global query ranges in the ref attn mask
k_ranges (AttnRanges) – global key ranges in the ref attn mask
attn_mask_type (str | AttnMaskType | list[str | AttnMaskType]) – attn mask type (or a list of them), represented by str, the enum AttnMaskType, or their mixed combination.
total_seqlen_q (int) – the total seqlen of query (i.e. the number of rows in the ref attn mask).
total_seqlen_k (int) – the total seqlen of key (i.e. the number of columns in the ref attn mask).
pad_size (int) – the size to pad along the seqlen dim; the padded seqlen needs to be divisible by chunk_size * cp_size.
chunk_size (int) – chunk size used to chunk the input tensor x along the seqlen dim for dispatch, to control the granularity of computation load-balancing.
cp_group_or_mesh (dist.ProcessGroup | DeviceMesh) – process group or device mesh. NOTE: for a process group, we only support the nccl backend for now; for a device mesh, we only support 1D or 2D meshes for now.
dist_attn_config (DistAttnConfig) – dist attn config.
is_same_source (bool) – whether the query tensor and key tensor share the same source.
is_q_permutable (bool) – whether the query tensor is permutable.
is_k_permutable (bool) – whether the key tensor is permutable.
- Returns:
local_x (torch.Tensor): the local input x after padding and dispatching.
key (DistAttnRuntimeKey): the key that points to the inner DistAttnRuntimeMgr.
- Return type:
tuple[torch.Tensor, DistAttnRuntimeKey]
Note
For decoder-only transformers (e.g., GPT), it applies ‘self-attn’ as follows:
is_same_source is True.
Both q and k are permutable, as long as they are permuted in the same way.
For encoder-decoder transformers (e.g., T5), it applies ‘cross-attn’ as follows:
is_same_source is False.
q is permutable but k is not.
For multi-modal transformers with external encoders, it applies ‘cross-attn’ as follows:
is_same_source is False.
q is unpermutable due to self-attn, but k is permutable, even in a different way.
Example
>>> local_x, dist_attn_runtime_key = magi_attn_flex_dispatch(
...     x=torch.randn(
...         4096,  # seqlen
...         2048,  # hidden_size
...         device=device,
...         dtype=dtype,
...         requires_grad=True,
...     ),
...     q_ranges=AttnRanges.from_ranges([[0, 2048], [2048, 4096]]),
...     k_ranges=AttnRanges.from_ranges([[0, 2048], [0, 4096]]),
...     attn_mask_type="full",
...     total_seqlen_q=4096,
...     total_seqlen_k=4096,
...     pad_size=compute_pad_size(4096, 4, 512),  # seqlen, cp_size, chunk_size
...     chunk_size=512,
...     cp_group_or_mesh=dist.new_group(list(range(4)), backend="nccl"),
...     dist_attn_config=DistAttnConfig(
...         dispatch_config=DispatchConfig(alg=MinHeapDispatchAlg()),
...         overlap_config=OverlapConfig(
...             enable=True,
...             mode=AttnOverlapMode.STATIC,
...             degree=2,
...             min_chunk_size=512,
...             max_num_chunks=64,
...             alg=OverlapAlgType.UNIFORM,
...         ),
...     ),
...     is_same_source=True,
...     is_q_permutable=True,
...     is_k_permutable=True,
... )
>>> local_q, local_k, local_v = q_project(local_x), k_project(local_x), v_project(local_x)
>>> # Do local attention computation
>>> local_out, _ = calc_attn(local_q, local_k, local_v, dist_attn_runtime_key)
>>> # Gather local attention results to global result
>>> total_out = undispatch(local_out, dist_attn_runtime_key)
Similar to the logic of magi_attn_varlen_dispatch, magi_attn_flex_dispatch first calls magi_attn_flex_key to obtain a key, and then uses this key to dispatch x. The description of magi_attn_flex_key is as follows.
- magi_attention.api.magi_attn_interface.magi_attn_flex_key()#
This is the most flexible interface: directly pass in q_ranges, k_ranges and attn_mask_type to calculate the DistAttnRuntimeKey and generate the corresponding inner DistAttnRuntimeMgr.
- Parameters:
x (torch.Tensor) – input tensor
q_ranges (AttnRanges) – global query ranges in the ref attn mask
k_ranges (AttnRanges) – global key ranges in the ref attn mask
attn_mask_type (str | AttnMaskType | list[str | AttnMaskType]) – attn mask type (or a list of them), represented by str, the enum AttnMaskType, or their mixed combination.
total_seqlen_q (int) – the total seqlen of query (i.e. the number of rows in the ref attn mask).
total_seqlen_k (int) – the total seqlen of key (i.e. the number of columns in the ref attn mask).
pad_size (int) – the size to pad along the seqlen dim; the padded seqlen needs to be divisible by chunk_size * cp_size.
chunk_size (int) – chunk size used to chunk the input tensor x along the seqlen dim for dispatch, to control the granularity of computation load-balancing.
cp_group_or_mesh (dist.ProcessGroup | DeviceMesh) – process group or device mesh. NOTE: for a process group, we only support the nccl backend for now; for a device mesh, we only support 1D or 2D meshes for now.
dist_attn_config (DistAttnConfig) – dist attn config.
is_same_source (bool) – whether the query tensor and key tensor share the same source.
is_q_permutable (bool) – whether the query tensor is permutable.
is_k_permutable (bool) – whether the key tensor is permutable.
- Returns:
the key that points to the inner DistAttnRuntimeMgr.
- Return type:
DistAttnRuntimeKey
Note
For decoder-only transformers (e.g., GPT), it applies ‘self-attn’ as follows:
is_same_source is True.
Both q and k are permutable, as long as they are permuted in the same way.
For encoder-decoder transformers (e.g., T5), it applies ‘cross-attn’ as follows:
is_same_source is False.
q is permutable but k is not.
For multi-modal transformers with external encoders, it applies ‘cross-attn’ as follows:
is_same_source is False.
q is unpermutable due to self-attn, but k is permutable, even in a different way.
Example
>>> dist_attn_runtime_key = magi_attn_flex_key(
...     q_ranges=AttnRanges.from_ranges([[0, 2048], [2048, 4096]]),
...     k_ranges=AttnRanges.from_ranges([[0, 2048], [0, 4096]]),
...     attn_mask_type="full",
...     total_seqlen_q=4096,
...     total_seqlen_k=4096,
...     pad_size=compute_pad_size(4096, 4, 512),  # seqlen, cp_size, chunk_size
...     chunk_size=512,
...     cp_group_or_mesh=dist.new_group(list(range(4)), backend="nccl"),
...     is_same_source=True,
...     is_q_permutable=True,
...     is_k_permutable=True,
...     dist_attn_config=DistAttnConfig(
...         dispatch_config=DispatchConfig(alg=MinHeapDispatchAlg()),
...         overlap_config=OverlapConfig(
...             enable=True,
...             mode=AttnOverlapMode.STATIC,
...             degree=2,
...             min_chunk_size=512,
...             max_num_chunks=64,
...             alg=OverlapAlgType.UNIFORM,
...         ),
...     ),
... )
>>> # Dispatch global query tensor to local query tensor
>>> local_q = dispatch(total_q, dist_attn_runtime_key)
>>> # Dispatch global key tensor to local key tensor
>>> local_k = dispatch(total_k, dist_attn_runtime_key)
>>> # Dispatch global value tensor to local value tensor
>>> local_v = dispatch(total_v, dist_attn_runtime_key)
>>> # Calculate local attention result
>>> local_out, _ = calc_attn(local_q, local_k, local_v, dist_attn_runtime_key)
>>> # Gather local attention results to global result
>>> total_out = undispatch(local_out, dist_attn_runtime_key)
Dispatch Function#
If you already have the key, you can call the dispatch function to get the padded and dispatched local tensor.
- magi_attention.api.magi_attn_interface.dispatch()#
Pad and dispatch the input tensor to local tensor on each rank along the seqlen dim.
- Parameters:
x (torch.Tensor) – input total tensor.
key (DistAttnRuntimeKey) – the key that holds some inner metadata and serves as an argument for many other magi_attention APIs; users generally do not need to care about its contents.
- Returns:
the padded and dispatched local tensor.
- Return type:
torch.Tensor
- Raises:
ValueError – If the provided key does not exist in dist_attn_runtime_dict.
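For instance, the same key can be reused to dispatch each global projection separately, exactly as in the examples above (a sketch assuming total_q, total_k and total_v are the global tensors and dist_attn_runtime_key comes from one of the dispatch interfaces):
>>> local_q = dispatch(total_q, dist_attn_runtime_key)
>>> local_k = dispatch(total_k, dist_attn_runtime_key)
>>> local_v = dispatch(total_v, dist_attn_runtime_key)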
Calculate Attention#
After dispatch and projection, you should obtain the query, key, and value needed for computation. Using the key obtained from the dispatch functions mentioned above, you can perform the computation by calling calc_attn, which returns the results out and lse. The description of calc_attn is as follows.
- magi_attention.api.magi_attn_interface.calc_attn()#
Do attention computation.
- Parameters:
q (torch.Tensor) – Query tensor of shape (num_tokens_q, num_heads, head_dim).
k (torch.Tensor) – Key tensor of shape (num_tokens_k, num_heads, head_dim).
v (torch.Tensor) – Value tensor of shape (num_tokens_v, num_heads, head_dim).
key (DistAttnRuntimeKey) – the key that holds some inner metadata and serves as an argument for many other magi_attention APIs; users generally do not need to care about its contents.
- Returns:
out (torch.Tensor): Attention output tensor of shape (num_tokens_q, num_heads, head_dim).
lse (torch.Tensor): Log-sum-exp values for numerical stability.
- Return type:
tuple[torch.Tensor, torch.Tensor]
- Raises:
ValueError – If the provided key does not exist in dist_attn_runtime_dict.
Undispatch#
Undispatch Function#
When you need to recover the complete global tensor from the local tensor, e.g. for computing the loss, you can call the undispatch function to unpad and undispatch the local tensor along the seqlen dim.
- magi_attention.api.magi_attn_interface.undispatch()#
Undispatch and unpad the local tensor to global tensor along the seqlen dim.
- Parameters:
x (torch.Tensor) – local tensor
key (DistAttnRuntimeKey) – the key that holds some inner metadata and serves as an argument for many other magi_attention APIs; users generally do not need to care about its contents.
- Returns:
the undispatched and unpadded tensor.
- Return type:
torch.Tensor
- Raises:
ValueError – If the provided key does not exist in dist_attn_runtime_dict.
Utility Functions#
Compute Pad Size and Padding#
During the use of MagiAttention, we divide the total_seqlen into multiple chunks of size chunk_size and evenly distribute them across multiple GPUs. To ensure that total_seqlen is divisible by chunk_size and that each GPU receives the same number of chunks, we need to pad the original input. You can call compute_pad_size to calculate the required padding length, and use this value as a parameter in subsequent functions.
- magi_attention.api.functools.compute_pad_size()#
Compute the padding size to append to the input tensor along the seqlen dim.
- Parameters:
total_seqlen_q (int) – seqlen of q.
cp_size (int) – The size of cp group.
chunk_size (int) – chunk size used to chunk the input tensor x along the seqlen dim for dispatch, to control the granularity of computation load-balancing.
- Returns:
the number of tokens to pad.
- Return type:
int
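For example (a sketch assuming the padding brings the seqlen up to the next multiple of chunk_size * cp_size, as described above):
>>> pad_size = compute_pad_size(4096, 4, 512)  # seqlen, cp_size, chunk_size; 4096 is already a multiple of 512 * 4 = 2048, so pad_size == 0
>>> pad_size = compute_pad_size(4000, 4, 512)  # 4000 is padded up to 4096, so pad_size == 96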
After obtaining pad_size, you can use the pad_at_dim and unpad_at_dim functions to pad and unpad the tensor.
- magi_attention.api.functools.pad_at_dim()#
Pads a tensor along a specified dimension with a given value, either on the left or right side.
- Parameters:
x (torch.Tensor) – Input tensor to be padded.
dim (int) – The dimension along which to apply padding.
pad_size (int) – The number of values to pad.
value (float, optional) – The padding value. Default is 0.
side (str, optional) – Side on which to apply the padding, either “left” or “right”. Default is “right”.
- Returns:
The padded tensor with the same number of dimensions as the input.
- Return type:
torch.Tensor
- magi_attention.api.functools.unpad_at_dim()#
Removes padding from a tensor along a specified dimension.
- Parameters:
x (torch.Tensor) – Input tensor from which padding will be removed.
dim (int) – The dimension along which to remove padding.
pad_size (int) – The number of elements to remove from the end of the specified dimension.
- Returns:
The tensor with padding removed along the specified dimension.
- Return type:
torch.Tensor
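A short round-trip sketch combining the two (the keyword names follow the parameter lists above; the shapes are illustrative):
>>> import torch
>>> x = torch.randn(4000, 2048)  # (seqlen, hidden_size)
>>> pad_size = compute_pad_size(4000, 4, 512)  # pad the seqlen up to a multiple of 2048
>>> x_padded = pad_at_dim(x, dim=0, pad_size=pad_size)  # zero-pad on the right of dim 0
>>> x_restored = unpad_at_dim(x_padded, dim=0, pad_size=pad_size)  # back to the original seqlen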
Similarly, you can use pad_size along with total_seqlen and other related information to apply padding to a (q_ranges, k_ranges, masktypes) tuple using the apply_padding function. This function fills the padding region with invalid slices.
- magi_attention.api.functools.apply_padding()#
Appends padding to the attention ranges and updates the corresponding mask type.
This function adds a padding query range at the end of q_ranges, a dummy key range to k_ranges, and appends a FULL attention mask type to maintain alignment. It is typically used when padding is required for alignment or block-wise processing.
- Parameters:
q_ranges (AttnRanges) – Query token ranges before padding.
k_ranges (AttnRanges) – Key token ranges before padding.
attn_mask_type (list[AttnMaskType]) – List of attention mask types corresponding to the ranges.
total_seqlen (int) – The total original sequence length (used to place the padding at the end).
pad_size (int) – The size of the padding to append.
- Returns:
Updated query ranges with padding added.
Updated key ranges with a dummy range for padding.
Updated attention mask type list with a FULL mask for the padding block.
- Return type:
tuple[AttnRanges, AttnRanges, list[AttnMaskType]]
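A sketch of padding a two-slice mask definition (keyword names follow the parameter list above; AttnMaskType.FULL and AttnMaskType.CAUSAL are assumed enum members matching the mask types used elsewhere in this document, and the same imports as the earlier examples are assumed):
>>> q_ranges = AttnRanges.from_ranges([[0, 2048], [2048, 4000]])
>>> k_ranges = AttnRanges.from_ranges([[0, 2048], [0, 4000]])
>>> attn_mask_type = [AttnMaskType.FULL, AttnMaskType.CAUSAL]
>>> pad_size = compute_pad_size(4000, 4, 512)  # 96
>>> # Append an invalid (padding) slice covering [total_seqlen, total_seqlen + pad_size)
>>> q_ranges, k_ranges, attn_mask_type = apply_padding(
...     q_ranges=q_ranges,
...     k_ranges=k_ranges,
...     attn_mask_type=attn_mask_type,
...     total_seqlen=4000,
...     pad_size=pad_size,
... )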
Get Position Ids#
Since MagiAttention needs to permute the input tensor along the seqlen dim, some token-aware ops might be affected, such as RoPE. Therefore, we provide a function get_position_ids to get the position ids of the input tensor, similar to Llama.
- magi_attention.api.magi_attn_interface.get_position_ids()#
Get the position ids of the local tensor w.r.t. the global tensor after dispatching.
- Parameters:
key (DistAttnRuntimeKey) – the key that holds some inner metadata and serves as an argument for many other magi_attention APIs; users generally do not need to care about its contents.
- Returns:
position ids of the local tensor w.r.t. the global tensor.
- Return type:
torch.Tensor
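Because dispatch permutes tokens along the seqlen dim, position-dependent ops on the local tensor should index with these global positions. A sketch (rope_freqs and apply_rope are hypothetical placeholders for your own RoPE implementation):
>>> position_ids = get_position_ids(dist_attn_runtime_key)
>>> # Look up per-token rotary frequencies by the tokens' global positions
>>> local_freqs = rope_freqs[position_ids]
>>> local_q = apply_rope(local_q, local_freqs)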
Get Most Recent Key#
If you have trouble accessing the meta key and meanwhile need to get the most recent key, you can call get_most_recent_key to get it. However, we strongly recommend accessing the key passed through the arguments, to avoid unexpected inconsistency.
- magi_attention.api.magi_attn_interface.get_most_recent_key()#
Get the most recently inserted key.
This is useful when you cannot access the key through the arguments and only need the most recently inserted key. However, we strongly recommend accessing the key passed through the arguments, to avoid unexpected inconsistency.
- Returns:
the most recently inserted key.
- Return type:
DistAttnRuntimeKey
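For example (a sketch; this simply retrieves the key produced by the most recent dispatch call, so it only makes sense when that call is the one you intend to reuse):
>>> dist_attn_runtime_key = get_most_recent_key()
>>> local_out, _ = calc_attn(local_q, local_k, local_v, dist_attn_runtime_key)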
Infer Varlen Masks#
If you want to use a varlen mask where each segment has the same length, we provide an infer_varlen_mask_from_batch function that generates the corresponding cu_seqlens tensors for you.
- magi_attention.api.functools.infer_varlen_mask_from_batch()#
Converts fixed-length full attention into the varlen full attention format by generating cumulative sequence lengths for queries and keys.
- Parameters:
batch_size (int) – The number of sequences in the batch.
seq_len (int) – The fixed sequence length for each sequence in the batch.
- Returns:
A pair of 1D tensors (cu_seqlens_q, cu_seqlens_k), each of shape (batch_size + 1,), representing the cumulative sequence lengths for the queries and keys respectively.
- Return type:
tuple[torch.Tensor, torch.Tensor]
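For example, a batch of 2 sequences of length 2048 yields cumulative lengths [0, 2048, 4096] for both queries and keys (a sketch; keyword names follow the parameter list above):
>>> cu_seqlens_q, cu_seqlens_k = infer_varlen_mask_from_batch(batch_size=2, seq_len=2048)
>>> # cu_seqlens_q == cu_seqlens_k == tensor([0, 2048, 4096])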
During the use of a varlen mask, it is often necessary to reshape a tensor of shape [batch_size, seq_len, ...] into [batch_size x seq_len, ...]. To facilitate the use of the above APIs, we provide the squash_batch_dim function to merge the tensor dimensions.
- magi_attention.api.functools.squash_batch_dim()#
Reshapes a tensor from shape [b, s, ...] to [b x s, ...], effectively flattening the batch and sequence dimensions into a single leading dimension.
- Parameters:
x (torch.Tensor) – Input tensor of shape [batch_size, seq_len, ...] to be merged.
- Returns:
Reshaped tensor of shape [batch_size x seq_len, ...].
- Return type:
torch.Tensor
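A sketch of flattening a batched tensor before using the varlen APIs above (shapes are illustrative):
>>> import torch
>>> x = torch.randn(2, 2048, 1024)  # [batch_size, seq_len, hidden_size]
>>> x = squash_batch_dim(x)         # -> [4096, 1024]
>>> cu_seqlens_q, cu_seqlens_k = infer_varlen_mask_from_batch(batch_size=2, seq_len=2048)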
Infer Sliding Window Masks#
In the design of MagiAttention, we use a (q_range, k_range, masktype) tuple to represent a slice. For sliding window masks, we do not provide a dedicated masktype to represent them directly. However, a sliding window mask can be decomposed into a combination of existing masktypes such as full, causal, inv_causal, and bi_causal. If you’re unsure how to perform this decomposition, we provide the infer_attn_mask_from_sliding_window function to handle this process for you.
- magi_attention.api.functools.infer_attn_mask_from_sliding_window()#
Convert a single sliding window mask into a representation using q_ranges, k_ranges, and mask types. The sliding window is specified using window_size.
- Parameters:
q_range (AttnRange) – q_range of this sliding window mask
k_range (AttnRange) – k_range of this sliding window mask
window_size (list[int]) – window_size of sliding window mask
- Returns:
the processed (q_ranges, k_ranges, masktypes) triple; the sliding window mask has been cut into this triple representation.
- Return type:
tuple[AttnRanges, AttnRanges, list[AttnMaskType]]
Example
Here’s an example of infer_attn_mask_from_sliding_window:
>>> q_ranges, k_ranges, attn_mask_type = infer_attn_mask_from_sliding_window(
...     q_range=AttnRange.from_range([5, 15]),
...     k_range=AttnRange.from_range([5, 15]),
...     window_size=(2, 3),
... )
The code above represents the sliding window mask within the [5, 15] x [5, 15] region with a window size of (2, 3).