CUDA 算子工程:手写 FlashAttention v2 之路
第 17 章 TMA + Warp Specialization 把 FA2 写到 SOTA
第 17 章 TMA + Warp Specialization 把 FA2 写到 SOTA
"Hopper is not Ampere with a faster Tensor Core. It is a different programming model — and FlashAttention v3 is the first proof that the new model genuinely matters." ——FA3 论文社区评论中的常见判断
17.1 为什么 Hopper 上的写法不一样
第 15 章我们写的 FA2 forward 用的是 Ampere 时代的工具:cp.async 异步拷贝 + mma.sync.m16n8k16 矩阵乘。这套工具在 H100 上还能用,但只能跑到 ~530 TFLOPs(54% Tensor Core 峰值)。
如果换成 Hopper 原生工具:
- TMA 替代
cp.async - WGMMA 替代
mma.sync - Warp Specialization 替代对称 warp
性能能从 ~530 TFLOPs 提升到 ~740 TFLOPs(FP16) 或 ~1200 TFLOPs(FP8)。
数据来源:Shah et al., FlashAttention-3, 2024.
为什么差距这么大?三个原因:
- TMA 比 cp.async 更高效:单线程发起、专用硬件、原生 swizzle、不占 SIMT 算术单元。
- WGMMA 是异步指令:发完不阻塞,warp 可以继续算/拷下一份。
- Warp Specialization 把"算"和"拷"真正分离:Producer 永远在拷,Consumer 永远在算,硬件流水拉满。
17.2 Producer / Consumer 的角色分配
FA2 在 Hopper 上的核心 idea 是把 4 个 warp(128 thread)拆成两类:
flowchart TB
subgraph PRO [Producer Warp · 1 个 warp 32 线程]
P1[发起 TMA: K, V tile]
P2[mbarrier.arrive 通知 consumer]
end
subgraph CONS [Consumer Warp Group · 3 个 warp 96 线程]
C1[mbarrier.wait 等数据]
C2[WGMMA 发起 S = Q @ K^T]
C3[Online softmax]
C4[WGMMA 发起 O += P @ V]
end
PRO -->|信号: tile k ready| CONS
CONS -->|信号: tile k consumed| PRO
Producer 和 Consumer 在物理上是同一个 thread block 的不同 warp,通过 mbarrier 同步。它们各自专注自己的事,硬件层面真正异步并行。
17.3 TMA Descriptor 的构建
TMA 的关键是预先构建 TMA Descriptor——一份描述张量布局、stride、swizzle 模式的元数据。Descriptor 存在 GPU 全局内存(device memory)中,每次 TMA 指令引用它。
构建 TMA descriptor 在 host 端完成:
// Host 端构建 TMA descriptor
CUtensorMap tma_desc_K;
cuTensorMapEncodeTiled(
&tma_desc_K,
CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
/*tensorRank=*/4, // [B, H, N, d]
K_global_ptr,
/*tensorSize=*/{d, N, H, B},
/*tensorStride=*/{d * 2, N * d * 2, ...}, // 字节
/*boxSize=*/{d, Bc, 1, 1}, // 每次拷贝的 tile
/*elementStrides=*/{1, 1, 1, 1},
CU_TENSOR_MAP_INTERLEAVE_NONE,
CU_TENSOR_MAP_SWIZZLE_128B, // 128B swizzle
CU_TENSOR_MAP_L2_PROMOTION_L2_128B,
CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE
);
// 把 tma_desc_K 拷贝到 device memory, 让 kernel 能访问
cuTensorMapEncodeTiled 是 CUDA 12.0+ 的 API,专门用来构建 TMA descriptor。
Kernel 端用 descriptor 发 TMA:
__global__ void fa3_fwd(
const __grid_constant__ CUtensorMap tma_desc_Q,
const __grid_constant__ CUtensorMap tma_desc_K,
const __grid_constant__ CUtensorMap tma_desc_V,
half* O,
...
) {
extern __shared__ alignas(128) half smem[];
half* sQ = smem;
half* sK[STAGES]; // 多 stage pipeline
half* sV[STAGES];
// 设置 sK[i], sV[i] 指针 ...
__shared__ alignas(8) uint64_t mbar[STAGES * 2]; // 每 stage 两个 barrier (full/empty)
if (threadIdx.x == 0) {
for (int s = 0; s < STAGES * 2; ++s) {
mbarrier_init(&mbar[s], /*count=*/...);
}
}
__syncthreads();
// 角色分工
int warp_id = threadIdx.x / 32;
if (warp_id == 0) {
// Producer warp
producer_main(tma_desc_K, tma_desc_V, sK, sV, mbar);
} else {
// Consumer warp group (warp 1-3)
consumer_main(sQ, sK, sV, mbar, /* O accumulator */);
}
}
__grid_constant__ 是 Hopper 引入的 const 参数修饰,让 TMA descriptor 可以高效传入 kernel。
17.4 Producer Warp 的工作
__device__ void producer_main(
const CUtensorMap& tma_desc_K,
const CUtensorMap& tma_desc_V,
half** sK, half** sV,
uint64_t* mbar
) {
if (threadIdx.x % 32 != 0) return; // 只 lane 0 发 TMA
int n_k_tiles = N / Bc;
for (int k = 0; k < n_k_tiles; ++k) {
int stage = k % STAGES;
// 等当前 stage 被 consumer 消费完 (empty barrier)
mbarrier_wait(&mbar[stage * 2 + 1], /* phase = ... */);
// 发起 K[k], V[k] 的 TMA
cp_async_bulk_tensor_2d(
sK[stage], &tma_desc_K, k * Bc, /*head_offset*/, /*batch_offset*/,
&mbar[stage * 2 + 0] // 完成时通知 full barrier
);
cp_async_bulk_tensor_2d(
sV[stage], &tma_desc_V, k * Bc, ..., &mbar[stage * 2 + 0]
);
}
}
Producer 的循环极其简单——它只有一件事:发起 TMA、等 consumer 消费完、发下一个。一个 warp 32 线程,但只有 lane 0 真正干活,其余 31 lane 闲置。这看起来浪费,但因为 producer 不做计算(不占 ALU),实际硬件资源浪费很小——CUDA 在 hopper 上引入 setmaxnreg.dec 让 producer warp 把寄存器配额还回去,给 consumer 用。
// Producer warp 在开始时把寄存器配额降到最小 (24 个)
asm("setmaxnreg.dec.sync.aligned.u32 24;\n");
这一行 PTX 让 producer 把寄存器从默认 ~64 降到 24,省下的寄存器全部给 consumer。Hopper 上 consumer 因此能拿到 ~120 个寄存器/线程,足够存大量 fragment。
17.5 Consumer Warp Group 的工作
__device__ void consumer_main(
half* sQ, half** sK, half** sV,
uint64_t* mbar,
/* O accumulator */ float* O_acc, float* row_max, float* row_sum
) {
// Consumer warp 把寄存器配额提到最高 (240)
asm("setmaxnreg.inc.sync.aligned.u32 240;\n");
int n_k_tiles = N / Bc;
for (int k = 0; k < n_k_tiles; ++k) {
int stage = k % STAGES;
// 等 producer 拷完当前 stage (full barrier)
mbarrier_wait(&mbar[stage * 2 + 0], /* phase = ... */);
// ============ S = Q @ K^T (WGMMA) ============
float S_acc[MMAS_M * MMAS_N * 4] = {0};
wgmma_fence();
for (int kk = 0; kk < d; kk += 16) {
wgmma_mma_async_m64n64k16(
S_acc, sQ + kk_offset, sK[stage] + kk_offset, /*scale_d=*/0
);
}
wgmma_commit_group();
wgmma_wait_group(/*N=*/0); // 等 WGMMA 完成
// ============ Online softmax ============
// 与第 15 章一样, 但用 fragment level reduce
update_softmax_state(S_acc, row_max, row_sum, alpha);
scale_O_by_alpha(O_acc, alpha);
// ============ O += P @ V (WGMMA) ============
// 把 S_acc cast 为 fp16 P_acc
cast_S_to_P_fp16(S_acc, P_fp16_smem);
wgmma_fence();
wgmma_mma_async_m64n_d_k16(
O_acc, P_fp16_smem, sV[stage], /*scale_d=*/1
);
wgmma_commit_group();
wgmma_wait_group(0);
// 通知 producer 这个 stage 已消费完
mbarrier_arrive(&mbar[stage * 2 + 1]);
}
// 最后归一化 O_acc /= row_sum, 写到 O HBM
finalize_and_write_O(O_acc, row_sum, ...);
}
几个关键点:
- WGMMA 是异步的:
wgmma_mma_async发完不阻塞,需要wgmma_commit_group+wgmma_wait_group显式同步。 - wgmma_fence 在每组 wgmma 之前调用,确保前面的寄存器写入对 wgmma 可见。
- mbarrier 同步:consumer 用
mbarrier_wait等 producer,用mbarrier_arrive通知 producer。
17.6 Pipeline Depth 的选择
STAGES(pipeline 深度)是关键参数:
- STAGES=2:经典 double buffer。SMEM 占用最小,但拷贝完全等于计算时延才能完美重叠。
- STAGES=3:三缓冲。多一份缓冲能容忍轻微的不对齐。
- STAGES=4:四缓冲。FA3 默认。
更多 stage 意味着更多 SMEM 占用:
每 stage SMEM = 2 * (Bc * d * 2 byte) = 2 * Bc * d * 2
Bc=64, d=64: 每 stage 16 KB
STAGES=4: 64 KB SMEM 仅 K/V 缓冲
+ Q tile, O accumulator etc., 总 SMEM ~96 KB (可在 H100 228 KB SMEM 内)
更多 stage 意味着更高的"流水深度",能更好容忍 producer/consumer 速度不匹配。FA3 用 4 stage 是因为 H100 的 TMA 延迟 + WGMMA 延迟综合下来需要 4 stage 才完美重叠。
17.7 FP8 的特殊处理
FA3 的另一个关键创新是支持 FP8 GEMM:
wgmma_mma_async_e4m3_e4m3_f32_m64n64k32(
accumulator,
fp8_a_smem, fp8_b_smem,
/*scale_d=*/0
);
FP8 WGMMA 的 K 维一次性算 32(FP16 是 16)——单条指令算力加倍。但 FP8 需要 per-tensor 或 per-token 的 scale,且数值精度低。
FA3 对 FP8 做了几个工程化处理:
- Per-block scale:每个 K tile 自带一个 scale,避免单一 scale 损失精度。
- Q 保持 FP16:S = Q(FP16) @ K(FP8) 输出 FP32,精度 OK。
- PV 用 FP8:但 P 是 softmax 输出,先转 FP8 再 mma。
这些细节让 FA3 能在保持精度的同时,把 FP8 算力用到 ~80% 峰值(1200 TFLOPs / 1500 实际有效峰值,因为 FP8 也有 overhead)。
17.8 性能跃迁实测
H100, FP16, head_dim=64:
| 实现 | TFLOPs | % Tensor Core peak |
|---|---|---|
| FA1 (cp.async + mma.sync) | ~280 | 28% |
| FA2 (cp.async + mma.sync, better warp split) | ~530 | 54% |
| FA3 (TMA + WGMMA + Warp Spec) | ~740 | 75% |
| FA3 FP8 | ~1200 | 60% of FP8 peak |
来源:Shah et al., FA3 论文 Figure 5.
FA1 → FA3 性能提升 2.6×——没有改变算法,全部来自硬件特性的更好利用。
17.9 Hopper → Blackwell 迁移
Blackwell(B200)相比 Hopper 的关键变化:
- 第 5 代 Tensor Core:增加 FP4 支持,FP4 算力是 FP8 的 2×。
- 第二代 TMA:支持更大 tile 和更复杂的 swizzle 模式。
- CTA Pair:两个 thread block 物理上配对,共享 SMEM。
迁移策略:
- TMA 描述符接口几乎一样:换枚举值即可。
- WGMMA 升级到 mma.sm100.async:API 类似,K 维大小不同。
- Warp Specialization 框架不变:producer/consumer 模式继续用。
- CTA Pair 是新东西:可以让 attention kernel 进一步增大有效 tile。
具体迁移工作量:~10-20% 代码改动。但工业上 CUTLASS 4.x 已经把 Blackwell 适配做了,用 CUTLASS 写 GEMM/Attention 会自动获得 Blackwell 优化。
17.10 这一章的小结与下一章
第 17 章是本书技术深度的高峰:
- TMA 替代 cp.async:单线程发起、专用硬件、不占 ALU。
- WGMMA 替代 mma.sync:单条指令算 64×128×16,异步执行。
- Warp Specialization:1 producer + 3 consumer warp-group,物理硬件并行。
- mbarrier 同步:producer/consumer 之间用 phase 切换的同步机制。
- setmaxnreg:动态调整 warp 寄存器配额,让 consumer 拿到更多寄存器。
- STAGES=4 流水深度:足够掩盖 TMA + WGMMA 的延迟。
- FA3 在 H100 上做到 740 TFLOPs(FP16)/ 1200 TFLOPs(FP8)——是 FA1 的 2.6×。
第 18 章我们回到一个更广的话题——Persistent Kernel。Persistent kernel 是另一种"永远活着"的 kernel 模式:grid_size 固定为 SM 数,每个 block 通过 grid-stride loop 处理多个 tile。这种模式对小 tile 工作负载(比如 LLM 推理 decoding 阶段的小 batch)特别有效。读完第 18 章,第四篇结束,读者就完成了从基础 kernel 到 SOTA FA2 的完整训练。
本章动手练习:
- 构建一个 TMA descriptor,发起一次 TMA 拷贝,观察 SMEM 中的 swizzle 布局。
- 实现一个最简化的 Producer/Consumer kernel(单 K tile,纯 GEMM),熟悉 mbarrier 同步。
- 阅读 FA3 官方实现
flash-attention/csrc/flash_attn/flash_fwd_kernel_sm90.h,对照本章描述的概念找代码位置。