CUDA 算子工程:手写 FlashAttention v2 之路

第 12 章 Tensor Core GEMM:mma.sync 与 ldmatrix

作者 杨艺韬 · 2,295 字

第 12 章 Tensor Core GEMM:mma.sync 与 ldmatrix

"Tensor Core is not just a faster FMA — it is a different programming model. Once you grok the fragment / mma / swizzle triangle, modern GPU programming opens up." ——CUTLASS 团队的内部分享

12.1 为什么 Tensor Core 是必经之路

第 11 章我们把 SIMT GEMM 推到了 37% 算力峰值。但回顾 Hopper 算力:

FP32 SIMT 峰值:        67 TFLOPs/s
FP16 Tensor Core 峰值: 989 TFLOPs/s
FP8  Tensor Core 峰值: 1979 TFLOPs/s

Tensor Core 比 SIMT 快 15× 到 30×。任何严肃的 LLM 训练 / 推理都必须用 Tensor Core——这不是优化选项,是入场券

但 Tensor Core 不是一个"快版本的 FMA 指令"——它是一个全新的编程模型

这一章我们把这套新的编程模型彻底讲透。

12.2 mma.sync:一条指令算一个矩阵乘

Ampere+ 上的核心 Tensor Core 指令是:

mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
   D, A, B, C

含义:

每条指令的浮点操作数:

16×8×16×2=4096 FLOPs16 \times 8 \times 16 \times 2 = 4096 \text{ FLOPs}

一个 SM 4 个 Tensor Core,每周期一条 mma.sync = 每周期 16384 FLOPs/SM。乘以 132 SM × 1.83 GHz = 989 TFLOPs——这就是 H100 FP16 峰值的来路。

12.2.1 Fragment 布局

最反直觉的部分:mma.sync 的 A、B、C、D 不是单个寄存器,而是一组寄存器,分布在 32 个线程上:

A (16×16, FP16) 共 256 个 fp16 = 512 字节 = 128 个 32-bit 寄存器。 分布在 32 lane 上,每 lane 4 个寄存器(128 / 32 = 4)。

具体的分布模式很复杂,由 NVIDIA 硬件规定:

A 的 fragment layout (m16n8k16, row-major):

         k=0..7              k=8..15
        ┌─────────────┐    ┌─────────────┐
m=0..7: │ T0  T1 ... T7│   │ T0  T1 ... T7│
        │ T8  T9 ...T15│   │ T8 ...    │
        ...
m=8..15:│ T16 T17 ...T23│  │ T16 ...   │
        │ T24 ...T31    │  │ T24 ...   │
        └─────────────┘    └─────────────┘
        每 lane 持有 2 个 fp16     每 lane 持有 2 个 fp16

也就是 lane 0 持有 A[0..1, 0..1] (4 个 fp16),lane 1 持有 A[0..1, 2..3],依此类推。

读者完全不需要记这个表——下一节的 ldmatrix 会自动按这个布局排好。但重要的是理解:fragment 不是连续存储,而是分布式存储

12.2.2 Inline PTX

CUDA C++ 写 mma.sync 用 inline PTX:

unsigned A[4];   // 4 个 32-bit, 每个 = 2 个 fp16, 共 8 个 fp16 (一行 fragment)
unsigned B[2];   // 同上
float C[4];      // 4 个 fp32 (输出 fragment 的一部分)

asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
    "{%0, %1, %2, %3}, "
    "{%4, %5, %6, %7}, "
    "{%8, %9}, "
    "{%0, %1, %2, %3};\n"
    : "+f"(C[0]), "+f"(C[1]), "+f"(C[2]), "+f"(C[3])
    : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
      "r"(B[0]), "r"(B[1]));

或者用 CUDA 11+ 的 nvcuda::wmma 包装(更高级 API,但灵活性差)。CUTLASS 用 inline PTX。

12.3 ldmatrix:把 SMEM 数据加载成 fragment

mma.sync 要求 fragment 已经在寄存器里,且按特定布局排列。怎么把 SMEM 数据装进 fragment?

最朴素的方式是每个线程自己 load:

unsigned A[4];
A[0] = reinterpret_cast<unsigned*>(&sA[m + lane_id / 4][k + (lane_id % 4) * 2])[0];
// ... 算地址再 load 4 次

地址计算超复杂,且每线程独立 load 会触发 bank conflict。

NVIDIA 提供了 ldmatrix 指令——一条指令把 SMEM 中一个 16×16 子块加载到 32 个 lane 的 fragment

unsigned A[4];
asm("ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
    "{%0, %1, %2, %3}, [%4];\n"
    : "=r"(A[0]), "=r"(A[1]), "=r"(A[2]), "=r"(A[3])
    : "l"(smem_ptr));

ldmatrix.x4 一次加载 4 个 8×8 fp16 子块(合计 16×16),输出 4 个寄存器/线程。32 lane × 4 寄存器 = 128 个寄存器 = 256 fp16 = 16×16 矩阵。完美匹配 mma.sync 的输入 fragment 布局

ldmatrix 还有一个变种 ldmatrix.x4.trans——加载时就地转置。这对加载 B 矩阵特别有用,因为 GEMM 需要 B^T 形式喂给 mma。

12.4 SMEM Layout 与 Swizzle

ldmatrix 期望 SMEM 中的数据按特定 layout 排列。如果 SMEM 是简单的 row-major,ldmatrix 会触发严重的 bank conflict——因为它一次访问 32 个不同的 SMEM 地址,如果这些地址都落在同一组 bank,性能腰斩。

NVIDIA 设计了一种预定义的 swizzle layout,让 ldmatrix 访问的地址自动错开 bank:

flowchart TB
  subgraph LinearLayout [Row-major Layout]
    L1[行 i=0:  col 0, 1, 2, 3, 4, 5, 6, 7,  8, 9, 10, ...]
    L2[行 i=1:  col 0, 1, 2, 3, 4, 5, 6, 7,  8, 9, 10, ...]
  end
  subgraph SwizzleLayout [Swizzled Layout 通过 row XOR col_high]
    S1[行 i=0:  col 0, 1, 2, 3, 4, 5, 6, 7,  8, 9, 10, ...]
    S2[行 i=1:  col 0,1,2,3 → col 8,9,10,11<br/>col 8..11 → col 0..3]
  end

简单说,swizzle 把每行的列按一个 XOR 函数重排:

__device__ int swizzle_idx(int row, int col, int row_size) {
    // 经典 NVIDIA 128-bit swizzle
    int phase = (row & 0x7) ^ ((col >> 3) & 0x7);
    return (col & ~0x7) | ((col & 0x7) ^ (row & 0x7));
}

具体实现还有几种变体,但核心思想都是用 row 的低位影响 col 的低位,让相邻行的同列元素落到不同 bank。

CUTLASS 提供了一组预定义的 swizzle layout(Swizzle<3,3,3>Swizzle<2,3,3> 等),名称对应不同的 row/col 偏移参数。第 13 章会展开。

12.5 完整的 Tensor Core GEMM 骨架

把 mma + ldmatrix + swizzle 拼起来,给一个工作的 HGEMM kernel:

template <int BM = 128, int BN = 128, int BK = 32>
__global__ void hgemm_tensorcore(
    const half* A, const half* B, half* C,
    int M, int N, int K
) {
    __shared__ half sA[BM * BK];   // 4 KB (128*32*2)
    __shared__ half sB[BN * BK];   // 4 KB

    const int tid = threadIdx.x;
    const int warp_id = tid / 32;
    const int lane_id = tid % 32;
    const int warp_m = warp_id / 2;   // 4 warps in M
    const int warp_n = warp_id % 2;   // 2 warps in N
    // 一个 block 8 warp, 处理 BM × BN = 128×128
    // 每 warp 64 × 64

    constexpr int WM = 64, WN = 64;
    constexpr int MMAS_M = WM / 16;  // 4
    constexpr int MMAS_N = WN / 8;   // 8

    // 累加器 fragment
    float c_frag[MMAS_M][MMAS_N][4] = {0};

    const int block_row = blockIdx.y * BM;
    const int block_col = blockIdx.x * BN;

    for (int k_step = 0; k_step < K; k_step += BK) {
        // 1. cp.async 加载 A_tile, B_tile 到 sA / sB (使用 swizzle layout)
        cp_async_load_a_tile(sA, A, block_row, k_step);
        cp_async_load_b_tile(sB, B, block_col, k_step);
        cp_async_commit_and_wait();
        __syncthreads();

        // 2. 内层 K (BK / 16 个 mma 步)
        for (int kk = 0; kk < BK; kk += 16) {
            // 用 ldmatrix 加载 A fragments
            unsigned a_frag[MMAS_M][4];
            #pragma unroll
            for (int i = 0; i < MMAS_M; ++i) {
                int row_offset = warp_m * WM + i * 16;
                ldmatrix_x4(sA, row_offset, kk, &a_frag[i]);
            }

            // ldmatrix 加载 B fragments (with .trans for column-major)
            unsigned b_frag[MMAS_N][2];
            #pragma unroll
            for (int j = 0; j < MMAS_N; ++j) {
                int col_offset = warp_n * WN + j * 8;
                ldmatrix_x2_trans(sB, col_offset, kk, &b_frag[j]);
            }

            // 3. mma.sync 累加
            #pragma unroll
            for (int i = 0; i < MMAS_M; ++i)
                #pragma unroll
                for (int j = 0; j < MMAS_N; ++j) {
                    asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
                        "{%0, %1, %2, %3},"
                        "{%4, %5, %6, %7},"
                        "{%8, %9},"
                        "{%0, %1, %2, %3};\n"
                        : "+f"(c_frag[i][j][0]), "+f"(c_frag[i][j][1]),
                          "+f"(c_frag[i][j][2]), "+f"(c_frag[i][j][3])
                        : "r"(a_frag[i][0]), "r"(a_frag[i][1]),
                          "r"(a_frag[i][2]), "r"(a_frag[i][3]),
                          "r"(b_frag[j][0]), "r"(b_frag[j][1]));
                }
        }
        __syncthreads();
    }

    // 4. 写 C (epilogue: fp32 -> fp16, 写回 HBM)
    #pragma unroll
    for (int i = 0; i < MMAS_M; ++i)
        #pragma unroll
        for (int j = 0; j < MMAS_N; ++j) {
            int row = block_row + warp_m * WM + i * 16;
            int col = block_col + warp_n * WN + j * 8;
            // 每 lane 写它持有的 2 个 fp16
            int my_row = row + (lane_id / 4) + (lane_id % 4) * 0; // 简化
            int my_col = col + (lane_id % 4) * 2;
            half2 v;
            v.x = __float2half(c_frag[i][j][0]);
            v.y = __float2half(c_frag[i][j][1]);
            *reinterpret_cast<half2*>(&C[my_row * N + my_col]) = v;
            // ... 写其他 fragment 元素
        }
}

这段代码省略了细节(地址计算、swizzle 实现、写 C 的完整 epilogue),但骨架就是这样。完整可工作的代码在 CUTLASS 中:cutlass/gemm/threadblock/mma_pipelined.h

实测:H100, FP16, M=N=K=4096:

SIMT FP32 GEMM (第 11 章 v5):     ~25 TFLOPs    (37% of FP32 peak)
Tensor Core HGEMM (上面骨架):     ~600 TFLOPs   (60% of FP16 peak)
cuBLAS HGEMM:                     ~750 TFLOPs   (76% of FP16 peak)

到 60% 已经是不小的成就。剩余 16% 差距来自:double buffer 流水深度、CUTLASS 级别的细致 fragment 调度、PTX 微优化等。CUTLASS 把这些做到极致,能达到 80%+。

12.6 Hopper 升级:WGMMA

Hopper 引入 WGMMA(Warp-Group MMA)后,mma 指令的粒度从 warp-level 提升到 warp-group-level:

mma.sync.m16n8k16:        16×8×16 = 2048 FLOPs/指令, warp 级
wgmma.mma_async.m64n128k16:  64×128×16 = 131072 FLOPs/指令, warp-group 级

WGMMA 单条指令的算力是 mma.sync 的 64 倍——这意味着指令调度压力减少 64 倍,更易跑满 Tensor Core。

WGMMA 还是异步指令:

wgmma.mma_async ...;       // 发起异步矩阵乘
wgmma.commit_group;        // 提交一组
... 做别的事 ...
wgmma.wait_group 0;        // 等待完成

发起 wgmma 之后 warp 可以继续做别的事(比如 TMA 加载下一个 tile),等需要结果时再同步。这是 Hopper GEMM 性能跃迁的核心机制——算和拷贝真正流水起来

完整的 Hopper WGMMA GEMM 框架第 13 章 CUTLASS 部分会展开,第 17 章 FA2 SOTA 会用到。

12.7 这一章的小结与下一章

Tensor Core 是 GEMM 性能跃迁的关键:

  1. mma.sync 是矩阵级指令:单条指令算 16×8×16 = 2048 FLOPs。
  2. ldmatrix 是配套的矩阵 load 指令:把 SMEM 中的 16×16 子块加载到 fragment。
  3. Fragment 是分布式寄存器布局:32 lane 协作持有矩阵。
  4. SMEM Swizzle 防 bank conflict:CUTLASS 的标准 swizzle layout 解决了相邻行同列的 conflict 问题。
  5. WGMMA 是 Hopper 的升级:单条指令 64 倍算力 + 异步执行。

到这里,读者已经能写出一个达到 60-80% 峰值的 HGEMM。下一步是把这套手艺工业化——CUTLASS 把所有这些技巧抽象成可组合的 C++ 模板,让 NVIDIA 和工业界能用统一的工具构建各种 GEMM 变体(包括 FA2 内的 QK^T 和 PV)。

第 13 章我们剖析 CUTLASS 3.x 的设计哲学——CollectiveOp、CuTe Layout、Hopper Kernel Schedule。读完第 13 章读者会理解为什么 CUTLASS 的代码"看起来很复杂但实际上很优雅",并学会怎么读 CUTLASS 源码。

本章动手练习

  1. 实现一个最简版 mma.sync HGEMM(小尺寸 M=N=K=64),亲手写 inline PTX,体验 fragment 布局。
  2. 阅读 CUTLASS 的 mma_pipelined.h,看双缓冲 + ldmatrix + mma 是怎么组装的。
  3. 在 H100 上跑 cuBLAS HGEMM 和你的版本,用 Nsight Compute 看 sm__inst_executed_pipe_tensor.sum.per_cycle_active 指标——你的 kernel Tensor Core 利用率是多少?