QuickStart

Basic Usage for Flex-Flash-Attention

import torch
from magi_attention.api import flex_flash_attn_func

# --- Define attention config --- #

total_seqlen = 2048    # 2k tokens
seqlen_sink = 4        # 4 sink tokens
num_heads_q = 8        # number of attention (query) heads
num_heads_kv = 2       # number of key/value heads (GQA)
head_dim = 128         # dimension of each attention head
dtype = torch.bfloat16 # attention activation / computation dtype (the reduction dtype is always fp32 for FFA right now)
device = "cuda"
has_sink = True        # whether to apply attention sink

# --- Initialize q,k,v,do tensors --- #

q = torch.randn(total_seqlen, num_heads_q, head_dim, dtype=dtype, device=device, requires_grad=True)
k = torch.randn(total_seqlen, num_heads_kv, head_dim, dtype=dtype, device=device, requires_grad=True)
v = torch.randn(total_seqlen, num_heads_kv, head_dim, dtype=dtype, device=device, requires_grad=True)
do = torch.randn_like(q)

# --- Initialize optional sink tensor --- #

sink = torch.randn(seqlen_sink, num_heads_q, dtype=torch.float32, device=device, requires_grad=True) if has_sink else None

# --- Initialize FFA meta args for customized attention mask --- #

# the following customized attention mask looks like (`*` for unmasked, `0` for masked):
#     - - - - - - - - -> (k)
#   | * * * * 0 0 0 0
#   | * * * * 0 0 0 0
#   | * * * * 0 0 0 0
#   | * * * * 0 0 0 0
#   | * * * * * 0 0 0
#   | * * * * * * 0 0
#   | * * * * * * * 0
#   | * * * * * * * *
#   V
#  (q)
q_ranges_tensor = torch.tensor([[0, 1024], [1024, 2048]], dtype=torch.int32, device=device)
k_ranges_tensor = torch.tensor([[0, 1024], [0, 2048]], dtype=torch.int32, device=device)
attn_type_map_tensor = torch.tensor([0, 1], dtype=torch.int32, device=device) # full mask for 1st slice, causal mask for 2nd
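
# For illustration only: a minimal sketch (plain PyTorch, not part of the magi_attention API) that
# materializes the ranges above into the equivalent dense boolean mask drawn in the diagram,
# assuming attn_type 1 applies a bottom-right-aligned causal mask within its (q_range, k_range) slice
dense_mask = torch.zeros(total_seqlen, total_seqlen, dtype=torch.bool, device=device)
for (q_start, q_end), (k_start, k_end), attn_type in zip(
    q_ranges_tensor.tolist(), k_ranges_tensor.tolist(), attn_type_map_tensor.tolist()
):
    if attn_type == 0:  # full mask within this slice
        dense_mask[q_start:q_end, k_start:k_end] = True
    else:  # attn_type == 1: causal mask within this slice
        q_len, k_len = q_end - q_start, k_end - k_start
        dense_mask[q_start:q_end, k_start:k_end] = (
            torch.ones(q_len, k_len, device=device).tril(diagonal=k_len - q_len).bool()
        )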

# --- Attention computation --- #

out, lse = flex_flash_attn_func(
    q=q,
    k=k,
    v=v,
    q_ranges=q_ranges_tensor,
    k_ranges=k_ranges_tensor,
    attn_type_map=attn_type_map_tensor,
    sink=sink, # Defaults to None, i.e. attention sink is not applied
    softmax_scale=None, # Defaults to 1/sqrt(head_dim)
    softcap=0, # Defaults to 0, i.e. no softcapping is applied
)
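
# out is expected to have the same shape as q, while lse holds the per-query softmax log-sum-exp
# statistics (in fp32, consistent with the reduction-dtype note above); its exact layout is
# implementation-defined, so consult the API reference before relying on it.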

out.backward(do)

dq, dk, dv = q.grad, k.grad, v.grad
dsink = sink.grad if has_sink else None

Basic Usage for MagiAttention

NOTE: You need to run the following example in a distributed environment, e.g. by launching it with the common torchrun command shown below

# run this python script with the command like:
# torchrun --standalone --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 ${SCRIPT_PATH}
import torch
import torch.nn as nn
import torch.distributed as dist

import magi_attention
from magi_attention.api import (
    magi_attn_flex_dispatch, calc_attn, undispatch, # interface functions
    compute_pad_size, # helper functions
)
from magi_attention.common import AttnRanges
from magi_attention.common.enum import AttnMaskType
from magi_attention.utils import setup_dist_env, clearup_dist_env

# --- Set up distributed environment --- #

rank, local_rank, world_size, num_nodes, num_local_ranks, world_group, device, seed = setup_dist_env()

# --- Define attention config --- #

total_seqlen = 32 * 1024   # 32k tokens; if we dispatch it to 8 GPUs, each GPU holds 4k tokens
seqlen_sink = 4            # 4 sink tokens
num_heads_q = 48           # number of attention (query) heads
num_heads_kv = 8           # number of key/value heads (GQA)
head_dim = 128             # dimension of each attention head
chunk_size = 512           # chunk size along the seqlen dim used when dispatching the input tensor x, controlling the granularity of computation load-balancing
dtype = torch.bfloat16     # attention activation / computation dtype (the reduction dtype for partial attention outputs is always fp32 for magi_attention right now)
has_sink = True            # whether to apply attention sink

# --- Initialize token embedding tensor --- #

embed_dim = 4096
x = torch.randn(total_seqlen, embed_dim, device=device, dtype=dtype, requires_grad=True)

# --- Initialize MagiAttention meta configs for customized attention mask --- #

# the following customized attention mask is known as a `block-causal` mask with `block_size` = 4096 (4k),
# which looks like (`*` for unmasked, `0` for masked, where each cell below stands for one 4k block):
#     - - - - - - - - -> (k)
#   | * 0 0 0 0 0 0 0
#   | * * 0 0 0 0 0 0
#   | * * * 0 0 0 0 0
#   | * * * * 0 0 0 0
#   | * * * * * 0 0 0
#   | * * * * * * 0 0
#   | * * * * * * * 0
#   | * * * * * * * *
#   V
#  (q)
q_ranges = AttnRanges.from_ranges(
    [
        [0, 4096], # 0~4k
        [4096, 8192], # 4k~8k
        [8192, 12288], # 8k~12k
        [12288, 16384], # 12k~16k
        [16384, 20480], # 16k~20k
        [20480, 24576], # 20k~24k
        [24576, 28672], # 24k~28k
        [28672, 32768], # 28k~32k
    ]
)
k_ranges = AttnRanges.from_ranges(
    [
        [0, 4096], # 0~4k
        [0, 8192], # 0~8k
        [0, 12288], # 0~12k
        [0, 16384], # 0~16k
        [0, 20480], # 0~20k
        [0, 24576], # 0~24k
        [0, 28672], # 0~28k
        [0, 32768], # 0~32k
    ]
)
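
# Equivalently, the block-causal ranges above can be generated programmatically; a small
# convenience sketch (plain Python over the same AttnRanges API), assuming total_seqlen is a
# multiple of block_size:
block_size = 4096
q_ranges = AttnRanges.from_ranges(
    [[start, start + block_size] for start in range(0, total_seqlen, block_size)]
)
k_ranges = AttnRanges.from_ranges(
    [[0, start + block_size] for start in range(0, total_seqlen, block_size)]
)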
attn_mask_type = [AttnMaskType.FULL] * len(q_ranges)
total_seqlen_q = total_seqlen_k = total_seqlen
pad_size = compute_pad_size( # pad embeds along seqlen dim for better performance
    total_seqlen_q=total_seqlen_q,
    cp_size=world_size, # assuming we only have 1-dim context parallelism (cp)
    chunk_size=chunk_size,
)
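
# With this config, pad_size is expected to be 0: assuming compute_pad_size pads total_seqlen_q up to
# a multiple of cp_size * chunk_size (an assumption about its behavior, not a documented formula),
# 32768 is already divisible by 8 * 512 = 4096, so no padding is needed.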

# --- Dispatch token embedding tensor along seqlen dim to multiple ranks --- #

# NOTE:
# 1. the dispatched local token embedding may be shuffled along the seqlen dim,
#    so it's safe for token-wise operations such as matmul, layer-norm, etc.,
#    while for sample-wise operations like RoPE, you might need to be more careful
# 2. the returned `magi_attn_runtime_key` holds some internal metadata and is passed as an argument
#    to many other magi_attention APIs; users don't have to bother with its contents
local_x, magi_attn_runtime_key = magi_attn_flex_dispatch(
    x,
    q_ranges=q_ranges,
    k_ranges=k_ranges,
    attn_mask_type=attn_mask_type,
    total_seqlen_q=total_seqlen_q,
    total_seqlen_k=total_seqlen_k,
    pad_size=pad_size,
    chunk_size=chunk_size,
    cp_group_or_mesh=world_group, # assuming we only have 1-dim context parallelism (cp)
)
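
# Under the even-sharding assumption (each cp rank holds an equal slice of the padded sequence),
# local_x is expected to have shape [(total_seqlen + pad_size) // world_size, embed_dim],
# i.e. 4k tokens x 4096 embed_dim per rank here, matching the "each GPU holds 4k tokens" note above.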

# --- Simulate QKV projection --- #

q_proj = nn.Linear(embed_dim, num_heads_q * head_dim, dtype=dtype, device=device)
k_proj = nn.Linear(embed_dim, num_heads_kv * head_dim, dtype=dtype, device=device)
v_proj = nn.Linear(embed_dim, num_heads_kv * head_dim, dtype=dtype, device=device)

local_q = q_proj(local_x).view(-1, num_heads_q, head_dim)
local_k = k_proj(local_x).view(-1, num_heads_kv, head_dim)
local_v = v_proj(local_x).view(-1, num_heads_kv, head_dim)

# --- Simulate attention sink parameter --- #

global_sink = nn.Parameter(torch.randn(seqlen_sink, num_heads_q, dtype=torch.float32, device=device)) if has_sink else None

# --- Distributed attention computation --- #

local_out, local_lse = calc_attn(
    q=local_q,
    k=local_k,
    v=local_v,
    key=magi_attn_runtime_key,
    sink=global_sink, # Defaults to None to not apply attention sink
)

# --- Undispatch the output tensor along seqlen dim from multiple ranks and unpad --- #

# NOTE: the undispatch API does not need to be called until the moment you need the seqlen dimension
# to be complete and ordered, e.g. for either the aforementioned sample-wise operations or loss computation
total_out = undispatch(
    x=local_out,
    key=magi_attn_runtime_key,
)
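
# total_out is expected to be back to the full, ordered (and unpadded) seqlen dim,
# i.e. with shape [total_seqlen, num_heads_q, head_dim], so it is safe to use for the loss computation below.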

# --- Simulate loss computation --- #

loss = total_out.sum()

# --- Simulate backward pass --- #

loss.backward()

dx = x.grad
dq_proj, dk_proj, dv_proj = q_proj.weight.grad, k_proj.weight.grad, v_proj.weight.grad

if has_sink:
    dsink = global_sink.grad
    # NOTE: training frameworks such as Megatron-LM or FSDP usually handle the reduction of
    # parameter gradients across the whole dp x cp group, so by default MagiAttention skips
    # the reduction of the sink's gradients, unless the user specifies the environment variable
    # `MAGI_ATTENTION_DSINK_ALL_REDUCE_OP` (see our docs for more details)
    if (op := magi_attention.comm.dsink_all_reduce_op()) != "none":
        match op:
            case "sum":
                dist.all_reduce(dsink, op=dist.ReduceOp.SUM, group=world_group)
            case "avg":
                dist.all_reduce(dsink, op=dist.ReduceOp.AVG, group=world_group)
            case _:
                raise ValueError(f"Unknown all_reduce_op: {op}")

# --- Clear up distributed environment --- #

clearup_dist_env()