QuickStart
Basic Usage for Flex-Flash-Attention
import torch
from magi_attention.api import flex_flash_attn_func
# --- Define attention config --- #
total_seqlen = 2048 # 2k tokens
num_heads_q = 8 # number of attention (query) heads
num_heads_kv = 2 # number of key/value heads (GQA)
head_dim = 128 # dimension of each attention head
dtype = torch.bfloat16 # attention activation / computation dtype (the reduction dtype is always fp32 for FFA right now)
device = "cuda"
# --- Initialize QKV tensor --- #
q = torch.randn(total_seqlen, num_heads_q, head_dim, dtype=dtype, device=device)
k = torch.randn(total_seqlen, num_heads_kv, head_dim, dtype=dtype, device=device)
v = torch.randn(total_seqlen, num_heads_kv, head_dim, dtype=dtype, device=device)
# --- Initialize FFA meta args for customized attention mask --- #
# the following customized attention mask looks like (`*` for unmasked, `0` for masked):
# - - - - - - - - -> (k)
# | * * * * 0 0 0 0
# | * * * * 0 0 0 0
# | * * * * 0 0 0 0
# | * * * * 0 0 0 0
# | * * * * * 0 0 0
# | * * * * * * 0 0
# | * * * * * * * 0
# | * * * * * * * *
# V
# (q)
q_ranges_tensor = torch.tensor([[0, 1024], [1024, 2048]], dtype=torch.int32, device=device)
k_ranges_tensor = torch.tensor([[0, 1024], [0, 2048]], dtype=torch.int32, device=device)
attn_type_map_tensor = torch.tensor([0, 1], dtype=torch.int32, device=device) # full mask for 1st slice, causal mask for 2nd
max_seqlen_q = 1024 # Max length of all q_ranges (2048 - 1024 = 1024)
max_seqlen_k = 2048 # Max length of all k_ranges (2048 - 0 = 2048)
# --- Attention computation --- #
out, _ = flex_flash_attn_func( # the second return value is `lse` (log-sum-exp), known as the online-softmax correction factor
q, k, v,
q_ranges=q_ranges_tensor,
k_ranges=k_ranges_tensor,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
attn_type_map=attn_type_map_tensor,
softmax_scale=None, # defaults to 1/sqrt(head_dim)
)
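# --- (Optional) Reference check --- #
# a minimal sanity-check sketch against a plain PyTorch SDPA reference in fp32;
# NOTE: this is illustrative only and not part of the magi_attention API; it assumes
# the causal slice is bottom-right aligned as drawn in the diagram above, that query
# head h groups with kv head h // (num_heads_q // num_heads_kv), and that the loose
# tolerances below are reasonable for bf16 inputs
import torch.nn.functional as F

mask = torch.zeros(total_seqlen, total_seqlen, dtype=torch.bool, device=device)
mask[0:1024, 0:1024] = True  # 1st slice: full attention over k range [0, 1024)
row_idx = torch.arange(1024, 2048, device=device)
col_idx = torch.arange(total_seqlen, device=device)
mask[1024:2048] = col_idx[None, :] <= row_idx[:, None]  # 2nd slice: causal over k range [0, 2048)

k_rep = k.repeat_interleave(num_heads_q // num_heads_kv, dim=1)  # expand kv heads for GQA
v_rep = v.repeat_interleave(num_heads_q // num_heads_kv, dim=1)
ref = F.scaled_dot_product_attention(
    q.transpose(0, 1).float(),      # (num_heads_q, total_seqlen, head_dim)
    k_rep.transpose(0, 1).float(),
    v_rep.transpose(0, 1).float(),
    attn_mask=mask,
).transpose(0, 1).to(dtype)

torch.testing.assert_close(out, ref, atol=2e-2, rtol=2e-2)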
Basic Usage for MagiAttention
NOTE: You should run the following example in a distributed environment, e.g. launched with the common torchrun script.
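For example, assuming the snippet below is saved as quickstart_magi_attention.py (a hypothetical filename) and you have a single node with 8 GPUs (the world size the comments below assume), it could be launched with:
torchrun --nproc_per_node=8 quickstart_magi_attention.py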
import torch
import torch.nn as nn
from magi_attention.api import (
magi_attn_flex_dispatch, calc_attn, undispatch, # interface functions
compute_pad_size, # helper functions
)
from magi_attention.common import AttnRanges
from magi_attention.common.enum import AttnMaskType
from magi_attention.utils import setup_dist_env, clearup_dist_env
# --- Set up distributed environment --- #
rank, local_rank, world_size, world_group, device, seed = setup_dist_env()
# --- Define attention config --- #
total_seqlen = 32 * 1024 # 32k tokens, if we dispatch it to 8 GPUs, then each GPU holds 4k tokens
num_heads_q = 48 # number of attention (query) heads
num_heads_kv = 8 # number of key/value heads (GQA)
head_dim = 128 # dimension of each attention head
dtype = torch.bfloat16 # attention activation / computation dtype (the reduction dtype for partial attention outputs is always fp32 for magi_attention right now)
chunk_size = 512 # chunk size used to split the input tensor x along the seqlen dim during dispatch; it controls the granularity of computation load-balancing
# --- Initialize token embedding tensor --- #
embed_dim = 4096
x = torch.randn(total_seqlen, embed_dim, device=device, dtype=dtype, requires_grad=True)
# --- Initialize MagiAttention meta configs for customized attention mask --- #
# the following customized attention mask is known as a `block-causal` mask with `block_size` = 4096 (4k),
# which looks like (`*` for unmasked, `0` for masked):
# - - - - - - - - -> (k)
# | * * 0 0 0 0 0 0
# | * * 0 0 0 0 0 0
# | * * * * 0 0 0 0
# | * * * * 0 0 0 0
# | * * * * * * 0 0
# | * * * * * * 0 0
# | * * * * * * * *
# | * * * * * * * *
# V
# (q)
q_ranges = AttnRanges.from_ranges(
[
[0, 4096], # 0~4k
[4096, 8192], # 4k~8k
[8192, 12288], # 8k~12k
[12288, 16384], # 12k~16k
[16384, 20480], # 16k~20k
[20480, 24576], # 20k~24k
[24576, 28672], # 24k~28k
[28672, 32768], # 28k~32k
]
)
k_ranges = AttnRanges.from_ranges(
[
[0, 4096], # 0~4k
[0, 8192], # 0~8k
[0, 12288], # 0~12k
[0, 16384], # 0~16k
[0, 20480], # 0~20k
[0, 24576], # 0~24k
[0, 28672], # 0~28k
[0, 32768], # 0~32k
]
)
attn_mask_type = [AttnMaskType.FULL] * len(q_ranges)
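# (Optional) the hard-coded range lists above can also be generated programmatically;
# a minimal sketch that reproduces exactly the same block-causal layout:
# block_size = 4096
# q_ranges = AttnRanges.from_ranges(
#     [[i, i + block_size] for i in range(0, total_seqlen, block_size)]
# )
# k_ranges = AttnRanges.from_ranges(
#     [[0, i + block_size] for i in range(0, total_seqlen, block_size)]
# )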
total_seqlen_q = total_seqlen_k = total_seqlen
pad_size = compute_pad_size( # pad embeds along seqlen dim for better performance
total_seqlen_q=total_seqlen_q,
cp_size=world_size, # assuming we only have 1-dim context parallelism (cp)
chunk_size=chunk_size,
)
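# (Optional) hypothetical sanity check -- an assumption about the padding contract rather
# than anything documented here: the padded total seqlen is expected to split evenly across
# cp ranks at `chunk_size` granularity (for 32k tokens, 8 ranks and chunk_size=512,
# 32768 is already divisible by 8 * 512 = 4096, so pad_size should come out as 0)
assert (total_seqlen_q + pad_size) % (world_size * chunk_size) == 0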
# --- Dispatch token embedding tensor along seqlen dim to multiple ranks --- #
# NOTE:
# 1. the dispatched local token embedding may be shuffled along the seqlen dim,
#    so it is safe for token-wise operations such as matmul, layer-norm, etc.,
#    while sample-wise or position-dependent operations like RoPE require extra care
# 2. the returned `magi_attn_runtime_key` holds internal metadata and is passed as an argument
#    to many other magi_attention APIs; users generally do not need to inspect it
local_x, magi_attn_runtime_key = magi_attn_flex_dispatch(
x,
q_ranges=q_ranges,
k_ranges=k_ranges,
attn_mask_type=attn_mask_type,
total_seqlen_q=total_seqlen_q,
total_seqlen_k=total_seqlen_k,
pad_size=pad_size,
chunk_size=chunk_size,
cp_group_or_mesh=world_group, # assuming we only have 1-dim context parallelism (cp)
)
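# (Optional) each rank now holds only its local shard of the (padded) sequence;
# a quick illustrative print -- the exact local length is an implementation detail,
# but it is expected to be roughly (total_seqlen + pad_size) / world_size tokens
print(f"[rank {rank}] local_x shape: {tuple(local_x.shape)} vs global x shape: {tuple(x.shape)}")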
# --- Simulate QKV projection --- #
q_proj = nn.Linear(embed_dim, num_heads_q * head_dim, dtype=dtype, device=device)
k_proj = nn.Linear(embed_dim, num_heads_kv * head_dim, dtype=dtype, device=device)
v_proj = nn.Linear(embed_dim, num_heads_kv * head_dim, dtype=dtype, device=device)
local_q = q_proj(local_x).view(-1, num_heads_q, head_dim)
local_k = k_proj(local_x).view(-1, num_heads_kv, head_dim)
local_v = v_proj(local_x).view(-1, num_heads_kv, head_dim)
# --- Distributed attention computation --- #
local_out, _ = calc_attn( # the second return value is `local_lse` (log-sum-exp), known as the online-softmax correction factor
q=local_q,
k=local_k,
v=local_v,
key=magi_attn_runtime_key,
)
# --- Undispatch the output tensor along seqlen dim from multiple ranks and unpad --- #
# NOTE: the undispatch API does not need to be called until you need the seqlen dim to be complete and ordered again,
# e.g. for the aforementioned sample-wise operations, or for loss computation
total_out = undispatch(
x=local_out,
key=magi_attn_runtime_key,
)
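# --- (Optional) Illustrative loss computation and backward --- #
# NOTE: this is only a sketch with a dummy loss, showing that gradients can flow back
# through undispatch / attention / dispatch to the original input `x` (created with
# requires_grad=True above); a real training loop would use a proper head and loss instead
loss = total_out.float().pow(2).mean()
loss.backward()
assert x.grad is not None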
# --- Clear up distributed environment --- #
clearup_dist_env()