API Reference

Flexible Flash Attention

To support computing irregular-shaped masks, we implemented a flexible_flash_attention kernel, which can be invoked through the following interface.

magi_attention.functional.flex_flash_attn.flex_flash_attn_func()

An interface similar to FlashAttention that does not require a distributed environment, dispatch, or undispatch. It directly calls magi_attn_kernel to compute the attention output and lse, which is faster when you do not need context parallelism.

Parameters:
  • q (torch.Tensor) – Query tensor.

  • k (torch.Tensor) – Key tensor.

  • v (torch.Tensor) – Value tensor.

  • q_ranges (torch.Tensor) – Query ranges tensor representing the attention mask.

  • k_ranges (torch.Tensor) – Key ranges tensor representing the attention mask.

  • max_seqlen_q (int) – Maximum length among the ranges in q_ranges.

  • max_seqlen_k (int) – Maximum length among the ranges in k_ranges.

  • attn_type_map (torch.Tensor) –

    Attention type map tensor with dtype=torch.int32. Each value specifies the attention type of the corresponding (q_range, k_range) pair:

    • 0: full attention

    • 1: causal attention

    • 2: inverse causal attention

    • 3: bidirectional causal attention

    More information about the attention type map can be found in the Note below.

  • softmax_scale (float, optional) – Softmax scale, defaults to 1/sqrt(head_dim).

  • softcap (float, optional) – Softcap value, defaults to 0.

  • deterministic (bool, optional) – Whether to use deterministic attention, defaults to False.

  • sm_margin (int, optional) – The number of SMs (streaming multiprocessors) reserved for communication.

  • disable_fwd_atomic_reduction (bool) –

    Whether to disable forward atomic reduction:

    If you can guarantee that q_ranges has no overlap, set this to True for better performance. Ranges are half-open, and q_ranges overlap if any two of them have a non-empty intersection. For example, q_ranges = [[0, 15], [10, 20], [20, 30]] has overlap because [0, 15] and [10, 20] intersect, whereas q_ranges = [[0, 15], [15, 20], [20, 30]] has no overlap.

  • auto_range_merge (bool, optional) –

    Whether to automatically merge k_ranges for the same q_range, defaults to False.

    Note: This flag is usually used in sparse attention cases and is still under development.

Returns:

  • out (torch.Tensor): Attention output tensor

  • lse (torch.Tensor): Log-sum-exp values with dtype=torch.float32.

Return type:

tuple[torch.Tensor, torch.Tensor]

Shape:
  • q: (num_tokens_q, num_heads, head_dim)

  • k: (num_tokens_kv, num_heads, head_dim)

  • v: (num_tokens_kv, num_heads, head_dim)

  • q_ranges: (num_ranges, 2)

  • k_ranges: (num_ranges, 2)

  • attn_type_map: (num_ranges, )

  • out: (num_tokens_q, num_heads, head_dim)

  • lse: (num_heads, num_tokens_q)
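The overlap condition for disable_fwd_atomic_reduction treats each range as half-open, so [0, 15] and [15, 20] only touch. The pure-Python sketch below checks that condition on a plain list (for example, q_ranges.tolist()) before you opt into the faster path; q_ranges_overlap is a hypothetical helper written for illustration, not part of magi_attention.

```python
def q_ranges_overlap(q_ranges: list[list[int]]) -> bool:
    """Return True if any two half-open [start, end) ranges share a token."""
    # Empty ranges contain no tokens, so they cannot intersect anything.
    non_empty = sorted((s, e) for s, e in q_ranges if e > s)
    max_end = None
    for start, end in non_empty:
        # After sorting by start, a range overlaps an earlier one
        # iff it starts before the furthest end seen so far.
        if max_end is not None and start < max_end:
            return True
        max_end = end if max_end is None else max(max_end, end)
    return False


print(q_ranges_overlap([[0, 15], [10, 20], [20, 30]]))  # True
print(q_ranges_overlap([[0, 15], [15, 20], [20, 30]]))  # False
```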

Note

This note explains the semantics of the different attention types in attn_type_map. In addition to the descriptions below, see our blog for a visual explanation: https://sandai-org.github.io/MagiAttention/blog/#flex-flash-attn

  1. Full attention:

    If seqlen_q = 5 and seqlen_k = 2:

    1 1
    1 1
    1 1
    1 1
    1 1
    

    If seqlen_q = 2 and seqlen_k = 5:

    1 1 1 1 1
    1 1 1 1 1
    

    If seqlen_q = 5 and seqlen_k = 5:

    1 1 1 1 1
    1 1 1 1 1
    1 1 1 1 1
    1 1 1 1 1
    1 1 1 1 1
    
  2. Causal attention (bottom-right aligned):

    If seqlen_q = 5 and seqlen_k = 2:

    0 0
    0 0
    0 0
    1 0
    1 1
    

    If seqlen_q = 2 and seqlen_k = 5:

    1 1 1 1 0
    1 1 1 1 1
    

    If seqlen_q = 5 and seqlen_k = 5:

    1 0 0 0 0
    1 1 0 0 0
    1 1 1 0 0
    1 1 1 1 0
    1 1 1 1 1
    
  3. Inverse causal attention (top-left aligned):

    If seqlen_q = 5 and seqlen_k = 2:

    1 1
    0 1
    0 0
    0 0
    0 0
    

    If seqlen_q = 2 and seqlen_k = 5:

    1 1 1 1 1
    0 1 1 1 1
    

    If seqlen_q = 5 and seqlen_k = 5:

    1 1 1 1 1
    0 1 1 1 1
    0 0 1 1 1
    0 0 0 1 1
    0 0 0 0 1
    
  4. Bidirectional causal attention (intersection of causal and inverse causal):

    This is the element-wise AND of causal and inverse causal masks.

    If seqlen_q = 5 and seqlen_k = 2:

    0 0
    0 0
    0 0
    0 0
    0 0
    

    If seqlen_q = 2 and seqlen_k = 5:

    1 1 1 1 0
    0 1 1 1 1
    

    If seqlen_q = 5 and seqlen_k = 5:

    1 0 0 0 0
    0 1 0 0 0
    0 0 1 0 0
    0 0 0 1 0
    0 0 0 0 1
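All four types can be read off two inequalities on a slice's local row i and column j: causal keeps j <= i + (seqlen_k - seqlen_q), inverse causal keeps j >= i, and bidirectional causal keeps both. The sketch below (attn_type_mask is a hypothetical helper, for illustration only) rebuilds the diagrams above as dense 0/1 lists, which can be handy when checking a custom attn_type_map by hand.

```python
def attn_type_mask(attn_type: int, seqlen_q: int, seqlen_k: int) -> list[list[int]]:
    """Dense 0/1 mask of one slice for attn_type 0 (full), 1 (causal),
    2 (inverse causal) or 3 (bidirectional causal)."""
    offset = seqlen_k - seqlen_q  # shift that gives causal its bottom-right alignment
    mask = []
    for i in range(seqlen_q):
        row = []
        for j in range(seqlen_k):
            causal = j <= i + offset  # bottom-right aligned
            inv_causal = j >= i       # top-left aligned
            if attn_type == 0:
                allowed = True
            elif attn_type == 1:
                allowed = causal
            elif attn_type == 2:
                allowed = inv_causal
            elif attn_type == 3:
                allowed = causal and inv_causal  # element-wise AND
            else:
                raise ValueError(f"unknown attn type: {attn_type}")
            row.append(int(allowed))
        mask.append(row)
    return mask


for row in attn_type_mask(3, seqlen_q=2, seqlen_k=5):
    print(*row)
# 1 1 1 1 0
# 0 1 1 1 1
```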
    

Dispatch

Varlen Dispatch

If you are using a mask defined by cu_seqlens, such as a varlen full or varlen causal mask, we provide an interface modeled on FlashAttention's varlen API so you can get started quickly. Calling magi_attn_varlen_dispatch returns the dispatched x and the key.

magi_attention.api.magi_attn_interface.magi_attn_varlen_dispatch()

A flash_attn_varlen-like interface that generates q_ranges, k_ranges, and attn_mask_type from cu_seqlens_q, cu_seqlens_k, and the causal flag; calculates the DistAttnRuntimeKey and creates the corresponding inner DistAttnRuntimeMgr; and finally pads and dispatches the input tensor to a local tensor.

Parameters:
  • x (torch.Tensor) – Input tensor.

  • cu_seqlens_q (torch.Tensor) – Cumulative sequence lengths for queries.

  • cu_seqlens_k (torch.Tensor) – Cumulative sequence lengths for keys.

  • pad_size (int) – The size to pad along seq_dim. The padded seq_len must be divisible by chunk_size * cp_size.

  • chunk_size (int) – The chunk size used to split the input tensor x along the seqlen dim for dispatch; it controls the granularity of computation load balancing.

  • cp_group_or_mesh (dist.ProcessGroup | DeviceMesh) – Process group or device mesh. NOTE: process groups currently support only the NCCL backend, and device meshes currently support only 1D or 2D meshes.

  • causal (bool) – If True, every attn_mask_type is CAUSAL; otherwise, every attn_mask_type is FULL.

  • dist_attn_config (DistAttnConfig) – Distributed attention config.

Returns:

  • x (torch.Tensor): The local input tensor after padding and dispatching.

  • key (DistAttnRuntimeKey): The key that points to the inner DistAttnRuntimeMgr.

Return type:

tuple[torch.Tensor, DistAttnRuntimeKey]

Example

>>> local_x, dist_attn_runtime_key = magi_attn_varlen_dispatch(
...     x=torch.randn(
...         4096,  # seqlen
...         2048,  # hidden_size
...         device=device,
...         dtype=dtype,
...         requires_grad=True
...     ),
...     cu_seqlens_q=torch.tensor(
...         [0, 2048, 4096], dtype=torch.int32
...     ),
...     cu_seqlens_k=torch.tensor(
...         [0, 2048, 4096], dtype=torch.int32
...     ),
...     pad_size=compute_pad_size(4096, 4, 512),  # seqlen, cp_size, chunk_size
...     chunk_size=512,
...     cp_group_or_mesh=dist.new_group(list(range(4)), backend="nccl"),
...     causal=False,
...     dist_attn_config=DistAttnConfig(
...         dispatch_config=DispatchConfig(alg=MinHeapDispatchAlg()),
...         overlap_config=OverlapConfig(
...             enable=True,
...             mode=AttnOverlapMode.STATIC,
...             degree=2,
...             min_chunk_size=512,
...             max_num_chunks=64,
...             alg=OverlapAlgType.UNIFORM,
...         ),
...     ),
... )
>>> local_q, local_k, local_v = q_project(local_x), k_project(local_x), v_project(local_x)
>>> # Do local attention computation
>>> local_out, _ = calc_attn(local_q, local_k, local_v, dist_attn_runtime_key)
>>> # Gather local attention results to global result
>>> total_out = undispatch(local_out, dist_attn_runtime_key)

The magi_attn_varlen_dispatch function works in two steps: it first calls magi_attn_varlen_key to compute a key, and then uses this key to dispatch the input x. magi_attn_varlen_key is described as follows.

magi_attention.api.magi_attn_interface.magi_attn_varlen_key()

A flash_attn_varlen-like interface that generates q_ranges, k_ranges, and attn_mask_type from cu_seqlens_q, cu_seqlens_k, and the causal flag, calculates the DistAttnRuntimeKey, and creates the corresponding inner DistAttnRuntimeMgr.

Parameters:
  • cu_seqlens_q (torch.Tensor) – Cumulative sequence lengths for queries.

  • cu_seqlens_k (torch.Tensor) – Cumulative sequence lengths for keys.

  • pad_size (int) – The size to pad along seq_dim. The padded seq_len must be divisible by chunk_size * cp_size.

  • chunk_size (int) – The chunk size used to split the input tensor x along the seqlen dim for dispatch; it controls the granularity of computation load balancing.

  • cp_group_or_mesh (dist.ProcessGroup | DeviceMesh) – Process group or device mesh. NOTE: process groups currently support only the NCCL backend, and device meshes currently support only 1D or 2D meshes.

  • causal (bool) – If True, every attn_mask_type is CAUSAL; otherwise, every attn_mask_type is FULL.

  • dist_attn_config (DistAttnConfig) – Distributed attention config.

Returns:

The key that points to the inner DistAttnRuntimeMgr.

Return type:

DistAttnRuntimeKey

Example

>>> dist_attn_runtime_key = magi_attn_varlen_key(
...     cu_seqlens_q=torch.tensor(
...         [0, 2048, 4096], dtype=torch.int32
...     ),
...     cu_seqlens_k=torch.tensor(
...         [0, 2048, 4096], dtype=torch.int32
...     ),
...     pad_size=compute_pad_size(4096, 4, 512),  # seqlen, cp_size, chunk_size
...     chunk_size=512,
...     cp_group_or_mesh=dist.new_group(list(range(4)), backend="nccl"),
...     causal=False,
...     dist_attn_config=DistAttnConfig(
...         dispatch_config=DispatchConfig(alg=MinHeapDispatchAlg()),
...         overlap_config=OverlapConfig(
...             enable=True,
...             mode=AttnOverlapMode.STATIC,
...             degree=2,
...             min_chunk_size=512,
...             max_num_chunks=64,
...             alg=OverlapAlgType.UNIFORM,
...         ),
...     ),
... )
>>> # Dispatch global query tensor to local query tensor
>>> local_q = dispatch(total_q, dist_attn_runtime_key)
>>> # Dispatch global key tensor to local key tensor
>>> local_k = dispatch(total_k, dist_attn_runtime_key)
>>> # Dispatch global value tensor to local value tensor
>>> local_v = dispatch(total_v, dist_attn_runtime_key)
>>> # Calculate local attention result
>>> local_out, _ = calc_attn(local_q, local_k, local_v, dist_attn_runtime_key)
>>> # Gather local attention results to global result
>>> total_out = undispatch(local_out, dist_attn_runtime_key)
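The first step described for magi_attn_varlen_key, turning cu_seqlens into q_ranges, k_ranges, and attn_mask_type, can be sketched in pure Python as below. It assumes the usual varlen convention that each sample attends only within itself; varlen_to_ranges is hypothetical, and the mask types are plain strings rather than AttnMaskType values.

```python
def varlen_to_ranges(cu_seqlens_q, cu_seqlens_k, causal):
    """Turn cumulative sequence lengths into one (q_range, k_range, mask type) slice per sample."""
    if len(cu_seqlens_q) != len(cu_seqlens_k):
        raise ValueError("cu_seqlens_q and cu_seqlens_k must describe the same number of samples")
    q_ranges, k_ranges = [], []
    for i in range(len(cu_seqlens_q) - 1):
        # Sample i owns tokens [cu[i], cu[i + 1]) of the packed sequence.
        q_ranges.append([cu_seqlens_q[i], cu_seqlens_q[i + 1]])
        k_ranges.append([cu_seqlens_k[i], cu_seqlens_k[i + 1]])
    mask_type = "causal" if causal else "full"  # one type shared by all samples
    return q_ranges, k_ranges, [mask_type] * len(q_ranges)


print(varlen_to_ranges([0, 2048, 4096], [0, 2048, 4096], causal=False))
# ([[0, 2048], [2048, 4096]], [[0, 2048], [2048, 4096]], ['full', 'full'])
```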

Flexible Dispatch

If the masks you’re using are not limited to varlen full or varlen causal, but also include sliding window masks or other more diverse types, we recommend using the following API. By calling magi_attn_flex_dispatch, you can obtain the dispatched x and key.

magi_attention.api.magi_attn_interface.magi_attn_flex_dispatch()

The most flexible interface: it takes q_ranges, k_ranges, and attn_mask_type directly, calculates the DistAttnRuntimeKey, creates the corresponding inner DistAttnRuntimeMgr, and finally pads and dispatches the input tensor to a local tensor.

Parameters:
  • x (torch.Tensor) – Input tensor.

  • q_ranges (AttnRanges) – Global query ranges in the reference attention mask.

  • k_ranges (AttnRanges) – Global key ranges in the reference attention mask.

  • attn_mask_type (str | AttnMaskType | list[str | AttnMaskType]) – Attention mask type, or a list of them, each given as a str, an AttnMaskType enum value, or a mix of both.

  • total_seqlen_q (int) – Total seqlen of the query (i.e., the number of rows in the reference attention mask).

  • total_seqlen_k (int) – Total seqlen of the key (i.e., the number of columns in the reference attention mask).

  • pad_size (int) – The size to pad along seq_dim. The padded seq_len must be divisible by chunk_size * cp_size.

  • chunk_size (int) – The chunk size used to split the input tensor x along the seqlen dim for dispatch; it controls the granularity of computation load balancing.

  • cp_group_or_mesh (dist.ProcessGroup | DeviceMesh) – Process group or device mesh. NOTE: process groups currently support only the NCCL backend, and device meshes currently support only 1D or 2D meshes.

  • dist_attn_config (DistAttnConfig) – Distributed attention config.

  • is_same_source (bool) – Whether the query and key tensors come from the same source.

  • is_q_permutable (bool) – Whether the query tensor is permutable.

  • is_k_permutable (bool) – Whether the key tensor is permutable.

Returns:

  • local_x (torch.Tensor): The local input x after padding and dispatching.

  • key (DistAttnRuntimeKey): The key that points to the inner DistAttnRuntimeMgr.

Return type:

tuple[torch.Tensor, DistAttnRuntimeKey]

Note

  1. For decoder-only transformers (e.g., GPT), it applies ‘self-attn’ as follows:

    1. is_same_source is True.

    2. Both q and k are permutable, as long as they are permuted in the same way.

  2. For encoder-decoder transformers (e.g., T5), it applies ‘cross-attn’ as follows:

    1. is_same_source is False.

    2. q is permutable but k is not.

  3. For multi-modal transformers with external encoders, it applies ‘cross-attn’ as follows:

    1. is_same_source is False.

    2. q is not permutable due to self-attn, but k is permutable, even in a way different from q.

Example

>>> local_x, dist_attn_runtime_key = magi_attn_flex_dispatch(
...     x=torch.randn(
...         4096,   # seqlen
...         2048,   # hidden_size
...         device=device,
...         dtype=dtype,
...         requires_grad=True
...     ),
...     q_ranges=AttnRanges.from_ranges([[0, 2048], [2048, 4096]]),
...     k_ranges=AttnRanges.from_ranges([[0, 2048], [0, 4096]]),
...     attn_mask_type="full",
...     total_seqlen_q=4096,
...     total_seqlen_k=4096,
...     pad_size=compute_pad_size(4096, 4, 512),  # seqlen, cp_size, chunk_size
...     chunk_size=512,
...     cp_group_or_mesh=dist.new_group(list(range(4)), backend="nccl"),
...     dist_attn_config=DistAttnConfig(
...         dispatch_config=DispatchConfig(alg=MinHeapDispatchAlg()),
...         overlap_config=OverlapConfig(
...             enable=True,
...             mode=AttnOverlapMode.STATIC,
...             degree=2,
...             min_chunk_size=512,
...             max_num_chunks=64,
...             alg=OverlapAlgType.UNIFORM,
...         ),
...     ),
...     is_same_source=True,
...     is_q_permutable=True,
...     is_k_permutable=True,
... )
>>> local_q, local_k, local_v = q_project(local_x), k_project(local_x), v_project(local_x)
>>> # Do local attention computation
>>> local_out, _ = calc_attn(local_q, local_k, local_v, dist_attn_runtime_key)
>>> # Gather local attention results to global result
>>> total_out = undispatch(local_out, dist_attn_runtime_key)

Similar to the logic of magi_attn_varlen_dispatch, magi_attn_flex_dispatch first calls magi_attn_flex_key to obtain a key, and then uses this key to dispatch x. The description of magi_attn_flex_key is as follows.

magi_attention.api.magi_attn_interface.magi_attn_flex_key()

The most flexible interface: it takes q_ranges, k_ranges, and attn_mask_type directly, calculates the DistAttnRuntimeKey, and creates the corresponding inner DistAttnRuntimeMgr.

Parameters:
  • q_ranges (AttnRanges) – Global query ranges in the reference attention mask.

  • k_ranges (AttnRanges) – Global key ranges in the reference attention mask.

  • attn_mask_type (str | AttnMaskType | list[str | AttnMaskType]) – Attention mask type, or a list of them, each given as a str, an AttnMaskType enum value, or a mix of both.

  • total_seqlen_q (int) – Total seqlen of the query (i.e., the number of rows in the reference attention mask).

  • total_seqlen_k (int) – Total seqlen of the key (i.e., the number of columns in the reference attention mask).

  • pad_size (int) – The size to pad along seq_dim. The padded seq_len must be divisible by chunk_size * cp_size.

  • chunk_size (int) – The chunk size used to split the input tensor x along the seqlen dim for dispatch; it controls the granularity of computation load balancing.

  • cp_group_or_mesh (dist.ProcessGroup | DeviceMesh) – Process group or device mesh. NOTE: process groups currently support only the NCCL backend, and device meshes currently support only 1D or 2D meshes.

  • dist_attn_config (DistAttnConfig) – Distributed attention config.

  • is_same_source (bool) – Whether the query and key tensors come from the same source.

  • is_q_permutable (bool) – Whether the query tensor is permutable.

  • is_k_permutable (bool) – Whether the key tensor is permutable.

Returns:

The key that points to the inner DistAttnRuntimeMgr.

Return type:

DistAttnRuntimeKey

Note

  1. For decoder-only transformers (e.g., GPT), it applies ‘self-attn’ as follows:

    1. is_same_source is True.

    2. Both q and k are permutable, as long as they are permuted in the same way.

  2. For encoder-decoder transformers (e.g., T5), it applies ‘cross-attn’ as follows:

    1. is_same_source is False.

    2. q is permutable but k is not.

  3. For multi-modal transformers with external encoders, it applies ‘cross-attn’ as follows:

    1. is_same_source is False.

    2. q is not permutable due to self-attn, but k is permutable, even in a way different from q.

Example

>>> dist_attn_runtime_key = magi_attn_flex_key(
...     q_ranges=AttnRanges.from_ranges([[0, 2048], [2048, 4096]]),
...     k_ranges=AttnRanges.from_ranges([[0, 2048], [0, 4096]]),
...     attn_mask_type="full",
...     total_seqlen_q=4096,
...     total_seqlen_k=4096,
...     pad_size=compute_pad_size(4096, 4, 512),  # seqlen, cp_size, chunk_size
...     chunk_size=512,
...     cp_group_or_mesh=dist.new_group(list(range(4)), backend="nccl"),
...     is_same_source=True,
...     is_q_permutable=True,
...     is_k_permutable=True,
...     dist_attn_config=DistAttnConfig(
...         dispatch_config=DispatchConfig(alg=MinHeapDispatchAlg()),
...         overlap_config=OverlapConfig(
...             enable=True,
...             mode=AttnOverlapMode.STATIC,
...             degree=2,
...             min_chunk_size=512,
...             max_num_chunks=64,
...             alg=OverlapAlgType.UNIFORM,
...         ),
...     ),
... )
>>> # Dispatch global query tensor to local query tensor
>>> local_q = dispatch(total_q, dist_attn_runtime_key)
>>> # Dispatch global key tensor to local key tensor
>>> local_k = dispatch(total_k, dist_attn_runtime_key)
>>> # Dispatch global value tensor to local value tensor
>>> local_v = dispatch(total_v, dist_attn_runtime_key)
>>> # Calculate local attention result
>>> local_out, _ = calc_attn(local_q, local_k, local_v, dist_attn_runtime_key)
>>> # Gather local attention results to global result
>>> total_out = undispatch(local_out, dist_attn_runtime_key)
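The Example above uses two full slices whose key ranges grow with the query ranges, which produces a block-causal pattern. Scaled down from 2048-token blocks to 2-token blocks, the pure-Python sketch below paints each (q_range, k_range, mask type) slice into a dense 0/1 mask to make that visible. slices_to_dense is hypothetical, handles only "full" and bottom-right-aligned "causal", and is meant for intuition only.

```python
def slices_to_dense(q_ranges, k_ranges, mask_types, total_seqlen_q, total_seqlen_k):
    """Paint every (q_range, k_range, mask type) slice into one dense 0/1 mask."""
    mask = [[0] * total_seqlen_k for _ in range(total_seqlen_q)]
    for (qs, qe), (ks, ke), mtype in zip(q_ranges, k_ranges, mask_types):
        if mtype not in ("full", "causal"):
            raise ValueError(f"unsupported mask type in this sketch: {mtype}")
        offset = (ke - ks) - (qe - qs)  # shift for bottom-right alignment
        for i in range(qe - qs):
            for j in range(ke - ks):
                if mtype == "full" or j <= i + offset:
                    mask[qs + i][ks + j] = 1
    return mask


dense = slices_to_dense(
    q_ranges=[[0, 2], [2, 4]],
    k_ranges=[[0, 2], [0, 4]],
    mask_types=["full", "full"],
    total_seqlen_q=4,
    total_seqlen_k=4,
)
for row in dense:
    print(*row)
# 1 1 0 0
# 1 1 0 0
# 1 1 1 1
# 1 1 1 1
```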

Dispatch Function

If you already have the key, you can call the dispatch function to get the padded and dispatched local tensor.

magi_attention.api.magi_attn_interface.dispatch()

Pad and dispatch the input tensor to local tensor on each rank along the seqlen dim.

Parameters:
  • x (torch.Tensor) – Global input tensor.

  • key (DistAttnRuntimeKey) – The key that holds internal metadata. It is passed as an argument to many other magi_attention APIs, and users usually do not need to inspect it.

Returns:

The padded and dispatched local tensor.

Return type:

torch.Tensor

Raises:

ValueError – If the provided key does not exist in dist_attn_runtime_dict.
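Which chunks end up on each rank is governed by the dispatch_config in DistAttnConfig (MinHeapDispatchAlg in the examples above), and chunk_size sets the granularity of that balancing. As rough intuition only, and not MagiAttention's actual algorithm, the sketch below balances estimated per-chunk workloads across cp_size ranks with a min-heap while giving every rank the same number of chunks. greedy_min_heap_dispatch and the per-chunk workload model are hypothetical.

```python
import heapq


def greedy_min_heap_dispatch(workloads, cp_size):
    """Give every rank the same number of chunks while balancing total workload.

    workloads[c] is an estimated cost for chunk c; heavier chunks are placed first,
    each on the currently least-loaded rank that still has room.
    """
    num_chunks = len(workloads)
    if num_chunks % cp_size != 0:
        raise ValueError("num_chunks must be divisible by cp_size (pad the input first)")
    capacity = num_chunks // cp_size
    heap = [(0, rank) for rank in range(cp_size)]  # (current load, rank)
    heapq.heapify(heap)
    assignment = [[] for _ in range(cp_size)]
    for chunk in sorted(range(num_chunks), key=lambda c: -workloads[c]):
        load, rank = heapq.heappop(heap)
        assignment[rank].append(chunk)
        # A rank that reached its quota is not pushed back, so it gets no more chunks.
        if len(assignment[rank]) < capacity:
            heapq.heappush(heap, (load + workloads[chunk], rank))
    return [sorted(chunks) for chunks in assignment]


# Rough causal model: chunk c attends to about c + 1 chunks of keys.
print(greedy_min_heap_dispatch([c + 1 for c in range(8)], cp_size=4))
# [[0, 7], [1, 6], [2, 5], [3, 4]]
```

Pairing an early (cheap) chunk with a late (expensive) chunk on each rank is what keeps a causal workload even; a plain contiguous split would leave the last rank with most of the work.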

Calculate Attention

After dispatch and projection, you have the local query, key, and value needed for computation. With the key returned by magi_attn_varlen_dispatch or magi_attn_flex_dispatch (or the corresponding *_key function), call calc_attn to perform the computation; it returns out and lse. calc_attn is described as follows.

magi_attention.api.magi_attn_interface.calc_attn()

Compute attention on the local (dispatched) query, key, and value tensors.

Parameters:
  • q (torch.Tensor) – Query tensor of shape (num_tokens_q, num_heads, head_dim).

  • k (torch.Tensor) – Key tensor of shape (num_tokens_k, num_heads, head_dim).

  • v (torch.Tensor) – Value tensor of shape (num_tokens_v, num_heads, head_dim).

  • key (DistAttnRuntimeKey) – The object that holds internal metadata. It is passed as an argument to many other magi_attention APIs, and users usually do not need to inspect it.

Returns:

  • out (torch.Tensor): Attention output tensor of shape (num_tokens_q, num_heads, head_dim).

  • lse (torch.Tensor): Log-sum-exp values for numerical stability.

Return type:

tuple[torch.Tensor, torch.Tensor]

Raises:

ValueError – If the provided key does not exist in dist_attn_runtime_dict.

Undispatch

Undispatch Function

When you need to recover the complete global tensor from a local tensor, for example to compute the loss, call the undispatch function to undispatch and unpad the local tensor along the seqlen dim.

magi_attention.api.magi_attn_interface.undispatch()

Undispatch and unpad the local tensor to global tensor along the seqlen dim.

Parameters:
  • x (torch.Tensor) – Local tensor.

  • key (DistAttnRuntimeKey) – The key that holds internal metadata. It is passed as an argument to many other magi_attention APIs, and users usually do not need to inspect it.

Returns:

The undispatched and unpadded tensor.

Return type:

torch.Tensor

Raises:

ValueError – If the provided key does not exist in dist_attn_runtime_dict.

Utility Functions

Compute Pad Size and Padding

When using MagiAttention, we divide total_seqlen into chunks of size chunk_size and distribute them evenly across GPUs. To ensure that total_seqlen is divisible by chunk_size and that each GPU receives the same number of chunks, the original input must be padded. Call compute_pad_size to calculate the required padding length, and pass this value to subsequent functions.

magi_attention.api.functools.compute_pad_size()

Compute the number of tokens to pad at the end of the input tensor along the seqlen dim.

Parameters:
  • total_seqlen_q (int) – Total seqlen of q.

  • cp_size (int) – The size of the cp group.

  • chunk_size (int) – The chunk size used to split the input tensor x along the seqlen dim for dispatch; it controls the granularity of computation load balancing.

Returns:

The number of tokens to pad.

Return type:

int
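Given the requirement stated for pad_size (the padded length must be a multiple of chunk_size * cp_size), the padding amount reduces to modular arithmetic. The sketch below shows that arithmetic; pad_size_for is a hypothetical stand-in, and the real compute_pad_size may differ in details.

```python
def pad_size_for(total_seqlen: int, cp_size: int, chunk_size: int) -> int:
    """Tokens to append so that total_seqlen + pad is a multiple of chunk_size * cp_size."""
    block = chunk_size * cp_size
    return (block - total_seqlen % block) % block


print(pad_size_for(4096, 4, 512))  # 0: 4096 is already a multiple of 2048
print(pad_size_for(4000, 4, 512))  # 96: 4000 + 96 = 4096 = 2 * 2048
```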

After obtaining pad_size, you can use the pad_at_dim and unpad_at_dim functions to pad and unpad the tensor.

magi_attention.api.functools.pad_at_dim()

Pads a tensor along a specified dimension with a given value, either on the left or right side.

Parameters:
  • x (torch.Tensor) – Input tensor to be padded.

  • dim (int) – The dimension along which to apply padding.

  • pad_size (int) – The number of values to pad.

  • value (float, optional) – The padding value. Default is 0.

  • side (str, optional) – Side on which to apply the padding, either “left” or “right”. Default is “right”.

Returns:

The padded tensor with the same number of dimensions as the input.

Return type:

torch.Tensor

magi_attention.api.functools.unpad_at_dim()

Removes padding from a tensor along a specified dimension.

Parameters:
  • x (torch.Tensor) – Input tensor from which padding will be removed.

  • dim (int) – The dimension along which to remove padding.

  • pad_size (int) – The number of elements to remove from the end of the specified dimension.

Returns:

The tensor with padding removed along the specified dimension.

Return type:

torch.Tensor

Similarly, you can use pad_size together with total_seqlen and related information to pad a (q_ranges, k_ranges, masktypes) tuple with the apply_padding function. This function fills the padding region with invalid slices.

magi_attention.api.functools.apply_padding()

Appends padding to the attention ranges and updates the corresponding mask type.

This function adds a padding query range at the end of q_ranges, a dummy key range to k_ranges, and appends a FULL attention mask type to maintain alignment. It is typically used when padding is required for alignment or block-wise processing.

Parameters:
  • q_ranges (AttnRanges) – Query token ranges before padding.

  • k_ranges (AttnRanges) – Key token ranges before padding.

  • attn_mask_type (list[AttnMaskType]) – List of attention mask types corresponding to the ranges.

  • total_seqlen (int) – The total original sequence length (used to place the padding at the end).

  • pad_size (int) – The size of the padding to append.

Returns:

  • Updated query ranges with padding added.

  • Updated key ranges with a dummy range for padding.

  • Updated attention mask type list with a FULL mask for the padding block.

Return type:

tuple[AttnRanges, AttnRanges, list[AttnMaskType]]

Get Position Ids

Since MagiAttention permutes the input tensor along the seqlen dim, position-dependent ops such as RoPE may be affected. Therefore, we provide get_position_ids to obtain the position ids of the local tensor, similar to the position ids used in Llama.

magi_attention.api.magi_attn_interface.get_position_ids()

Get the position ids of the local tensor with respect to the global tensor after dispatching.

Parameters:

key (DistAttnRuntimeKey) – The key that holds internal metadata. It is passed as an argument to many other magi_attention APIs, and users usually do not need to inspect it.

Returns:

Position ids of the local tensor w.r.t. the global tensor.

Return type:

torch.Tensor
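Conceptually, a rank that holds non-adjacent chunks (as a load-balanced causal layout might produce) has local tokens whose global positions are not contiguous, so RoPE must use those global positions rather than 0..n-1. The sketch below assumes each rank stores its chunks concatenated in the listed order; local_position_ids is hypothetical, and in practice you would call get_position_ids(key) instead.

```python
def local_position_ids(chunk_ids, chunk_size):
    """Global positions of the tokens a rank holds, given the chunks it owns in local order."""
    return [c * chunk_size + offset for c in chunk_ids for offset in range(chunk_size)]


# A rank holding chunks 0 and 7 (chunk_size=2) sees local tokens 0..3,
# but their global positions, which RoPE should use, are not contiguous.
print(local_position_ids([0, 7], chunk_size=2))  # [0, 1, 14, 15]
```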

Get Most Recent Key

If you cannot easily access the key and only need the most recently created one, you can call get_most_recent_key to retrieve it. However, we strongly recommend using the key passed through the arguments to avoid unexpected inconsistencies.

magi_attention.api.magi_attn_interface.get_most_recent_key()

Get the most recently inserted key.

This is useful when you cannot access the key through the arguments and only need the most recently inserted key. However, we strongly recommend using the key passed through the arguments to avoid unexpected inconsistencies.

Returns:

The most recently inserted key.

Return type:

DistAttnRuntimeKey

Infer Varlen Masks

If you want to use a varlen mask where every segment has the same length, the infer_varlen_mask_from_batch function generates the corresponding cu_seqlens tensors for you.

magi_attention.api.functools.infer_varlen_mask_from_batch()

Converts fixed-length full attention into the varlen full attention format by generating cumulative sequence lengths for queries and keys.

Parameters:
  • batch_size (int) – The number of sequences in the batch.

  • seq_len (int) – The fixed sequence length for each sequence in the batch.

Returns:

A pair of 1D tensors (cu_seqlens_q, cu_seqlens_k), each of shape (batch_size + 1,), representing the cumulative sequence lengths for the queries and keys respectively.

Return type:

tuple[torch.Tensor, torch.Tensor]
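Since every sequence has length seq_len, the cumulative lengths are simply multiples of seq_len, and queries and keys share the same values. A pure-Python sketch of that arithmetic (cu_seqlens_from_batch is hypothetical; the real function returns torch tensors), which reproduces the cu_seqlens used in the varlen examples above:

```python
def cu_seqlens_from_batch(batch_size: int, seq_len: int) -> list[int]:
    """Cumulative sequence lengths for batch_size sequences of identical length seq_len."""
    return [i * seq_len for i in range(batch_size + 1)]


cu_seqlens = cu_seqlens_from_batch(batch_size=2, seq_len=2048)
print(cu_seqlens)  # [0, 2048, 4096]
```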

When using a varlen mask, you often need to reshape a tensor of shape [batch_size, seq_len, ...] into [batch_size × seq_len, ...]. To make this easier with the APIs above, we provide the squash_batch_dim function, which merges the batch and sequence dimensions.

magi_attention.api.functools.squash_batch_dim()

Reshapes a tensor from shape [b, s, ...] to [b x s, ...], effectively flattening the batch and sequence dimensions into a single leading dimension.

Parameters:

x (torch.Tensor) – Input tensor of shape [batch_size, seq_len, ...] to be merged.

Returns:

Reshaped tensor of shape [batch_size x seq_len, ...].

Return type:

torch.Tensor

Infer Sliding Window Masks

In the design of MagiAttention, a slice is represented by a (q_range, k_range, masktype) tuple. There is no dedicated masktype for sliding window masks, but a sliding window mask can be decomposed into a combination of existing masktypes such as full, causal, inv_causal, and bi_causal. If you are unsure how to perform this decomposition, the infer_attn_mask_from_sliding_window function handles it for you.

magi_attention.api.functools.infer_attn_mask_from_sliding_window()

Convert a single sliding window mask into a representation using q_ranges, k_ranges, and mask types. The window is specified by window_size.

Parameters:
  • q_range (AttnRange) – q_range of this sliding window mask

  • k_range (AttnRange) – k_range of this sliding window mask

  • window_size (list[int]) – window_size of sliding window mask

Returns:

The (q_ranges, k_ranges, masktypes) triple into which the sliding window mask has been decomposed.

Return type:

tuple[AttnRanges, AttnRanges, list[AttnMaskType]]

Example

Here’s an example of infer_attn_mask_from_sliding_window:

>>> q_ranges, k_ranges, attn_mask_type = infer_attn_mask_from_sliding_window(
...     q_range=AttnRange.from_range([5, 15]),
...     k_range=AttnRange.from_range([5, 15]),
...     window_size=(2, 3),
... )

The code above represents the sliding window mask within the [5, 15] x [5, 15] region with a window size of (2, 3).
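To see what is being decomposed, the sketch below draws the dense mask for the same region. It assumes window_size = (left, right) means each query attends to keys from left positions before it through right positions after it, following FlashAttention's window_size convention; that reading, and the helper sliding_window_mask, are assumptions for illustration and are not taken from magi_attention.

```python
def sliding_window_mask(q_range, k_range, window_size):
    """Dense 0/1 sliding window mask over a square q_range x k_range region.

    Assumes window_size = (left, right): local query i sees local keys
    i - left .. i + right (inclusive), clipped to the region.
    """
    (qs, qe), (ks, ke) = q_range, k_range
    if qe - qs != ke - ks:
        raise ValueError("this sketch only handles equal-length q_range and k_range")
    left, right = window_size
    return [
        [int(i - left <= j <= i + right) for j in range(ke - ks)]
        for i in range(qe - qs)
    ]


for row in sliding_window_mask((5, 15), (5, 15), window_size=(2, 3)):
    print(*row)
```

Each row is a band of at most left + 1 + right ones, which is why the band can be cut into a few full, causal, inv_causal, and bi_causal slices.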