CUDA 算子工程:手写 FlashAttention v2 之路

第 17 章 TMA + Warp Specialization 把 FA2 写到 SOTA

作者 杨艺韬 · 2,169 字

第 17 章 TMA + Warp Specialization 把 FA2 写到 SOTA

"Hopper is not Ampere with a faster Tensor Core. It is a different programming model — and FlashAttention v3 is the first proof that the new model genuinely matters." ——FA3 论文社区评论中的常见判断

17.1 为什么 Hopper 上的写法不一样

第 15 章我们写的 FA2 forward 用的是 Ampere 时代的工具:cp.async 异步拷贝 + mma.sync.m16n8k16 矩阵乘。这套工具在 H100 上还能用,但只能跑到 ~530 TFLOPs(54% Tensor Core 峰值)。

如果换成 Hopper 原生工具:

性能能从 ~530 TFLOPs 提升到 ~740 TFLOPs(FP16)~1200 TFLOPs(FP8)

数据来源:Shah et al., FlashAttention-3, 2024.

为什么差距这么大?三个原因:

  1. TMA 比 cp.async 更高效:单线程发起、专用硬件、原生 swizzle、不占 SIMT 算术单元。
  2. WGMMA 是异步指令:发完不阻塞,warp 可以继续算/拷下一份。
  3. Warp Specialization 把"算"和"拷"真正分离:Producer 永远在拷,Consumer 永远在算,硬件流水拉满。

17.2 Producer / Consumer 的角色分配

FA2 在 Hopper 上的核心 idea 是把 4 个 warp(128 thread)拆成两类:

flowchart TB
  subgraph PRO [Producer Warp · 1 个 warp 32 线程]
    P1[发起 TMA: K, V tile]
    P2[mbarrier.arrive 通知 consumer]
  end
  subgraph CONS [Consumer Warp Group · 3 个 warp 96 线程]
    C1[mbarrier.wait 等数据]
    C2[WGMMA 发起 S = Q @ K^T]
    C3[Online softmax]
    C4[WGMMA 发起 O += P @ V]
  end
  PRO -->|信号: tile k ready| CONS
  CONS -->|信号: tile k consumed| PRO

Producer 和 Consumer 在物理上是同一个 thread block 的不同 warp,通过 mbarrier 同步。它们各自专注自己的事,硬件层面真正异步并行。

17.3 TMA Descriptor 的构建

TMA 的关键是预先构建 TMA Descriptor——一份描述张量布局、stride、swizzle 模式的元数据。Descriptor 存在 GPU 全局内存(device memory)中,每次 TMA 指令引用它。

构建 TMA descriptor 在 host 端完成:

// Host 端构建 TMA descriptor
CUtensorMap tma_desc_K;
cuTensorMapEncodeTiled(
    &tma_desc_K,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
    /*tensorRank=*/4,                          // [B, H, N, d]
    K_global_ptr,
    /*tensorSize=*/{d, N, H, B},
    /*tensorStride=*/{d * 2, N * d * 2, ...},  // 字节
    /*boxSize=*/{d, Bc, 1, 1},                 // 每次拷贝的 tile
    /*elementStrides=*/{1, 1, 1, 1},
    CU_TENSOR_MAP_INTERLEAVE_NONE,
    CU_TENSOR_MAP_SWIZZLE_128B,                // 128B swizzle
    CU_TENSOR_MAP_L2_PROMOTION_L2_128B,
    CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE
);
// 把 tma_desc_K 拷贝到 device memory, 让 kernel 能访问

cuTensorMapEncodeTiled 是 CUDA 12.0+ 的 API,专门用来构建 TMA descriptor。

Kernel 端用 descriptor 发 TMA:

__global__ void fa3_fwd(
    const __grid_constant__ CUtensorMap tma_desc_Q,
    const __grid_constant__ CUtensorMap tma_desc_K,
    const __grid_constant__ CUtensorMap tma_desc_V,
    half* O,
    ...
) {
    extern __shared__ alignas(128) half smem[];
    half* sQ = smem;
    half* sK[STAGES];   // 多 stage pipeline
    half* sV[STAGES];
    // 设置 sK[i], sV[i] 指针 ...

    __shared__ alignas(8) uint64_t mbar[STAGES * 2];   // 每 stage 两个 barrier (full/empty)
    if (threadIdx.x == 0) {
        for (int s = 0; s < STAGES * 2; ++s) {
            mbarrier_init(&mbar[s], /*count=*/...);
        }
    }
    __syncthreads();

    // 角色分工
    int warp_id = threadIdx.x / 32;
    if (warp_id == 0) {
        // Producer warp
        producer_main(tma_desc_K, tma_desc_V, sK, sV, mbar);
    } else {
        // Consumer warp group (warp 1-3)
        consumer_main(sQ, sK, sV, mbar, /* O accumulator */);
    }
}

__grid_constant__ 是 Hopper 引入的 const 参数修饰,让 TMA descriptor 可以高效传入 kernel。

17.4 Producer Warp 的工作

__device__ void producer_main(
    const CUtensorMap& tma_desc_K,
    const CUtensorMap& tma_desc_V,
    half** sK, half** sV,
    uint64_t* mbar
) {
    if (threadIdx.x % 32 != 0) return;  // 只 lane 0 发 TMA

    int n_k_tiles = N / Bc;
    for (int k = 0; k < n_k_tiles; ++k) {
        int stage = k % STAGES;

        // 等当前 stage 被 consumer 消费完 (empty barrier)
        mbarrier_wait(&mbar[stage * 2 + 1], /* phase = ... */);

        // 发起 K[k], V[k] 的 TMA
        cp_async_bulk_tensor_2d(
            sK[stage], &tma_desc_K, k * Bc, /*head_offset*/, /*batch_offset*/,
            &mbar[stage * 2 + 0]   // 完成时通知 full barrier
        );
        cp_async_bulk_tensor_2d(
            sV[stage], &tma_desc_V, k * Bc, ..., &mbar[stage * 2 + 0]
        );
    }
}

Producer 的循环极其简单——它只有一件事:发起 TMA、等 consumer 消费完、发下一个。一个 warp 32 线程,但只有 lane 0 真正干活,其余 31 lane 闲置。这看起来浪费,但因为 producer 不做计算(不占 ALU),实际硬件资源浪费很小——CUDA 在 hopper 上引入 setmaxnreg.dec 让 producer warp 把寄存器配额还回去,给 consumer 用。

// Producer warp 在开始时把寄存器配额降到最小 (24 个)
asm("setmaxnreg.dec.sync.aligned.u32 24;\n");

这一行 PTX 让 producer 把寄存器从默认 ~64 降到 24,省下的寄存器全部给 consumer。Hopper 上 consumer 因此能拿到 ~120 个寄存器/线程,足够存大量 fragment。

17.5 Consumer Warp Group 的工作

__device__ void consumer_main(
    half* sQ, half** sK, half** sV,
    uint64_t* mbar,
    /* O accumulator */ float* O_acc, float* row_max, float* row_sum
) {
    // Consumer warp 把寄存器配额提到最高 (240)
    asm("setmaxnreg.inc.sync.aligned.u32 240;\n");

    int n_k_tiles = N / Bc;
    for (int k = 0; k < n_k_tiles; ++k) {
        int stage = k % STAGES;

        // 等 producer 拷完当前 stage (full barrier)
        mbarrier_wait(&mbar[stage * 2 + 0], /* phase = ... */);

        // ============ S = Q @ K^T (WGMMA) ============
        float S_acc[MMAS_M * MMAS_N * 4] = {0};
        wgmma_fence();
        for (int kk = 0; kk < d; kk += 16) {
            wgmma_mma_async_m64n64k16(
                S_acc, sQ + kk_offset, sK[stage] + kk_offset, /*scale_d=*/0
            );
        }
        wgmma_commit_group();
        wgmma_wait_group(/*N=*/0);  // 等 WGMMA 完成

        // ============ Online softmax ============
        // 与第 15 章一样, 但用 fragment level reduce
        update_softmax_state(S_acc, row_max, row_sum, alpha);
        scale_O_by_alpha(O_acc, alpha);

        // ============ O += P @ V (WGMMA) ============
        // 把 S_acc cast 为 fp16 P_acc
        cast_S_to_P_fp16(S_acc, P_fp16_smem);

        wgmma_fence();
        wgmma_mma_async_m64n_d_k16(
            O_acc, P_fp16_smem, sV[stage], /*scale_d=*/1
        );
        wgmma_commit_group();
        wgmma_wait_group(0);

        // 通知 producer 这个 stage 已消费完
        mbarrier_arrive(&mbar[stage * 2 + 1]);
    }

    // 最后归一化 O_acc /= row_sum, 写到 O HBM
    finalize_and_write_O(O_acc, row_sum, ...);
}

几个关键点:

  1. WGMMA 是异步的wgmma_mma_async 发完不阻塞,需要 wgmma_commit_group + wgmma_wait_group 显式同步。
  2. wgmma_fence 在每组 wgmma 之前调用,确保前面的寄存器写入对 wgmma 可见。
  3. mbarrier 同步:consumer 用 mbarrier_wait 等 producer,用 mbarrier_arrive 通知 producer。

17.6 Pipeline Depth 的选择

STAGES(pipeline 深度)是关键参数:

更多 stage 意味着更多 SMEM 占用:

每 stage SMEM = 2 * (Bc * d * 2 byte) = 2 * Bc * d * 2

Bc=64, d=64: 每 stage 16 KB
STAGES=4:    64 KB SMEM 仅 K/V 缓冲
+ Q tile, O accumulator etc., 总 SMEM ~96 KB  (可在 H100 228 KB SMEM 内)

更多 stage 意味着更高的"流水深度",能更好容忍 producer/consumer 速度不匹配。FA3 用 4 stage 是因为 H100 的 TMA 延迟 + WGMMA 延迟综合下来需要 4 stage 才完美重叠。

17.7 FP8 的特殊处理

FA3 的另一个关键创新是支持 FP8 GEMM

wgmma_mma_async_e4m3_e4m3_f32_m64n64k32(
    accumulator,
    fp8_a_smem, fp8_b_smem,
    /*scale_d=*/0
);

FP8 WGMMA 的 K 维一次性算 32(FP16 是 16)——单条指令算力加倍。但 FP8 需要 per-tensor 或 per-token 的 scale,且数值精度低。

FA3 对 FP8 做了几个工程化处理:

  1. Per-block scale:每个 K tile 自带一个 scale,避免单一 scale 损失精度。
  2. Q 保持 FP16:S = Q(FP16) @ K(FP8) 输出 FP32,精度 OK。
  3. PV 用 FP8:但 P 是 softmax 输出,先转 FP8 再 mma。

这些细节让 FA3 能在保持精度的同时,把 FP8 算力用到 ~80% 峰值(1200 TFLOPs / 1500 实际有效峰值,因为 FP8 也有 overhead)。

17.8 性能跃迁实测

H100, FP16, head_dim=64:

实现 TFLOPs % Tensor Core peak
FA1 (cp.async + mma.sync) ~280 28%
FA2 (cp.async + mma.sync, better warp split) ~530 54%
FA3 (TMA + WGMMA + Warp Spec) ~740 75%
FA3 FP8 ~1200 60% of FP8 peak

来源:Shah et al., FA3 论文 Figure 5.

FA1 → FA3 性能提升 2.6×——没有改变算法,全部来自硬件特性的更好利用

17.9 Hopper → Blackwell 迁移

Blackwell(B200)相比 Hopper 的关键变化:

  1. 第 5 代 Tensor Core:增加 FP4 支持,FP4 算力是 FP8 的 2×。
  2. 第二代 TMA:支持更大 tile 和更复杂的 swizzle 模式。
  3. CTA Pair:两个 thread block 物理上配对,共享 SMEM。

迁移策略:

具体迁移工作量:~10-20% 代码改动。但工业上 CUTLASS 4.x 已经把 Blackwell 适配做了,用 CUTLASS 写 GEMM/Attention 会自动获得 Blackwell 优化。

17.10 这一章的小结与下一章

第 17 章是本书技术深度的高峰:

  1. TMA 替代 cp.async:单线程发起、专用硬件、不占 ALU。
  2. WGMMA 替代 mma.sync:单条指令算 64×128×16,异步执行。
  3. Warp Specialization:1 producer + 3 consumer warp-group,物理硬件并行。
  4. mbarrier 同步:producer/consumer 之间用 phase 切换的同步机制。
  5. setmaxnreg:动态调整 warp 寄存器配额,让 consumer 拿到更多寄存器。
  6. STAGES=4 流水深度:足够掩盖 TMA + WGMMA 的延迟。
  7. FA3 在 H100 上做到 740 TFLOPs(FP16)/ 1200 TFLOPs(FP8)——是 FA1 的 2.6×。

第 18 章我们回到一个更广的话题——Persistent Kernel。Persistent kernel 是另一种"永远活着"的 kernel 模式:grid_size 固定为 SM 数,每个 block 通过 grid-stride loop 处理多个 tile。这种模式对小 tile 工作负载(比如 LLM 推理 decoding 阶段的小 batch)特别有效。读完第 18 章,第四篇结束,读者就完成了从基础 kernel 到 SOTA FA2 的完整训练。

本章动手练习

  1. 构建一个 TMA descriptor,发起一次 TMA 拷贝,观察 SMEM 中的 swizzle 布局。
  2. 实现一个最简化的 Producer/Consumer kernel(单 K tile,纯 GEMM),熟悉 mbarrier 同步。
  3. 阅读 FA3 官方实现 flash-attention/csrc/flash_attn/flash_fwd_kernel_sm90.h,对照本章描述的概念找代码位置。