CUDA 算子工程:手写 FlashAttention v2 之路
第 12 章 Tensor Core GEMM:mma.sync 与 ldmatrix
第 12 章 Tensor Core GEMM:mma.sync 与 ldmatrix
"Tensor Core is not just a faster FMA — it is a different programming model. Once you grok the fragment / mma / swizzle triangle, modern GPU programming opens up." ——CUTLASS 团队的内部分享
12.1 为什么 Tensor Core 是必经之路
第 11 章我们把 SIMT GEMM 推到了 37% 算力峰值。但回顾 Hopper 算力:
FP32 SIMT 峰值: 67 TFLOPs/s
FP16 Tensor Core 峰值: 989 TFLOPs/s
FP8 Tensor Core 峰值: 1979 TFLOPs/s
Tensor Core 比 SIMT 快 15× 到 30×。任何严肃的 LLM 训练 / 推理都必须用 Tensor Core——这不是优化选项,是入场券。
但 Tensor Core 不是一个"快版本的 FMA 指令"——它是一个全新的编程模型:
- 指令是矩阵级的:一条
mma.sync算 16×8×16 矩阵乘,不是单个浮点。 - 数据需要特殊布局:mma 输入要按 NVIDIA 定义的 fragment 格式排列。
- 加载需要专用指令:
ldmatrix一次性把 16×16 数据从 SMEM 拉成 fragment。 - 输出是分布式的:累加结果分布在 32 个线程的寄存器里,不是连续存储。
这一章我们把这套新的编程模型彻底讲透。
12.2 mma.sync:一条指令算一个矩阵乘
Ampere+ 上的核心 Tensor Core 指令是:
mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
D, A, B, C
含义:
- m16n8k16:算 D = A @ B + C,其中 A 是 16×16,B 是 16×8,D = C = 16×8。
- row.col:A 是行优先,B 是列优先(标准 GEMM 输入布局)。
- f32.f16.f16.f32:D 和 C 是 fp32,A 和 B 是 fp16。
- D, A, B, C:四组寄存器(不是矩阵指针!)
每条指令的浮点操作数:
一个 SM 4 个 Tensor Core,每周期一条 mma.sync = 每周期 16384 FLOPs/SM。乘以 132 SM × 1.83 GHz = 989 TFLOPs——这就是 H100 FP16 峰值的来路。
12.2.1 Fragment 布局
最反直觉的部分:mma.sync 的 A、B、C、D 不是单个寄存器,而是一组寄存器,分布在 32 个线程上:
A (16×16, FP16) 共 256 个 fp16 = 512 字节 = 128 个 32-bit 寄存器。 分布在 32 lane 上,每 lane 4 个寄存器(128 / 32 = 4)。
具体的分布模式很复杂,由 NVIDIA 硬件规定:
A 的 fragment layout (m16n8k16, row-major):
k=0..7 k=8..15
┌─────────────┐ ┌─────────────┐
m=0..7: │ T0 T1 ... T7│ │ T0 T1 ... T7│
│ T8 T9 ...T15│ │ T8 ... │
...
m=8..15:│ T16 T17 ...T23│ │ T16 ... │
│ T24 ...T31 │ │ T24 ... │
└─────────────┘ └─────────────┘
每 lane 持有 2 个 fp16 每 lane 持有 2 个 fp16
也就是 lane 0 持有 A[0..1, 0..1] (4 个 fp16),lane 1 持有 A[0..1, 2..3],依此类推。
读者完全不需要记这个表——下一节的 ldmatrix 会自动按这个布局排好。但重要的是理解:fragment 不是连续存储,而是分布式存储。
12.2.2 Inline PTX
CUDA C++ 写 mma.sync 用 inline PTX:
unsigned A[4]; // 4 个 32-bit, 每个 = 2 个 fp16, 共 8 个 fp16 (一行 fragment)
unsigned B[2]; // 同上
float C[4]; // 4 个 fp32 (输出 fragment 的一部分)
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3}, "
"{%4, %5, %6, %7}, "
"{%8, %9}, "
"{%0, %1, %2, %3};\n"
: "+f"(C[0]), "+f"(C[1]), "+f"(C[2]), "+f"(C[3])
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
"r"(B[0]), "r"(B[1]));
或者用 CUDA 11+ 的 nvcuda::wmma 包装(更高级 API,但灵活性差)。CUTLASS 用 inline PTX。
12.3 ldmatrix:把 SMEM 数据加载成 fragment
mma.sync 要求 fragment 已经在寄存器里,且按特定布局排列。怎么把 SMEM 数据装进 fragment?
最朴素的方式是每个线程自己 load:
unsigned A[4];
A[0] = reinterpret_cast<unsigned*>(&sA[m + lane_id / 4][k + (lane_id % 4) * 2])[0];
// ... 算地址再 load 4 次
地址计算超复杂,且每线程独立 load 会触发 bank conflict。
NVIDIA 提供了 ldmatrix 指令——一条指令把 SMEM 中一个 16×16 子块加载到 32 个 lane 的 fragment:
unsigned A[4];
asm("ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(A[0]), "=r"(A[1]), "=r"(A[2]), "=r"(A[3])
: "l"(smem_ptr));
ldmatrix.x4 一次加载 4 个 8×8 fp16 子块(合计 16×16),输出 4 个寄存器/线程。32 lane × 4 寄存器 = 128 个寄存器 = 256 fp16 = 16×16 矩阵。完美匹配 mma.sync 的输入 fragment 布局。
ldmatrix 还有一个变种 ldmatrix.x4.trans——加载时就地转置。这对加载 B 矩阵特别有用,因为 GEMM 需要 B^T 形式喂给 mma。
12.4 SMEM Layout 与 Swizzle
ldmatrix 期望 SMEM 中的数据按特定 layout 排列。如果 SMEM 是简单的 row-major,ldmatrix 会触发严重的 bank conflict——因为它一次访问 32 个不同的 SMEM 地址,如果这些地址都落在同一组 bank,性能腰斩。
NVIDIA 设计了一种预定义的 swizzle layout,让 ldmatrix 访问的地址自动错开 bank:
flowchart TB
subgraph LinearLayout [Row-major Layout]
L1[行 i=0: col 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, ...]
L2[行 i=1: col 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, ...]
end
subgraph SwizzleLayout [Swizzled Layout 通过 row XOR col_high]
S1[行 i=0: col 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, ...]
S2[行 i=1: col 0,1,2,3 → col 8,9,10,11<br/>col 8..11 → col 0..3]
end
简单说,swizzle 把每行的列按一个 XOR 函数重排:
__device__ int swizzle_idx(int row, int col, int row_size) {
// 经典 NVIDIA 128-bit swizzle
int phase = (row & 0x7) ^ ((col >> 3) & 0x7);
return (col & ~0x7) | ((col & 0x7) ^ (row & 0x7));
}
具体实现还有几种变体,但核心思想都是用 row 的低位影响 col 的低位,让相邻行的同列元素落到不同 bank。
CUTLASS 提供了一组预定义的 swizzle layout(Swizzle<3,3,3>、Swizzle<2,3,3> 等),名称对应不同的 row/col 偏移参数。第 13 章会展开。
12.5 完整的 Tensor Core GEMM 骨架
把 mma + ldmatrix + swizzle 拼起来,给一个工作的 HGEMM kernel:
template <int BM = 128, int BN = 128, int BK = 32>
__global__ void hgemm_tensorcore(
const half* A, const half* B, half* C,
int M, int N, int K
) {
__shared__ half sA[BM * BK]; // 4 KB (128*32*2)
__shared__ half sB[BN * BK]; // 4 KB
const int tid = threadIdx.x;
const int warp_id = tid / 32;
const int lane_id = tid % 32;
const int warp_m = warp_id / 2; // 4 warps in M
const int warp_n = warp_id % 2; // 2 warps in N
// 一个 block 8 warp, 处理 BM × BN = 128×128
// 每 warp 64 × 64
constexpr int WM = 64, WN = 64;
constexpr int MMAS_M = WM / 16; // 4
constexpr int MMAS_N = WN / 8; // 8
// 累加器 fragment
float c_frag[MMAS_M][MMAS_N][4] = {0};
const int block_row = blockIdx.y * BM;
const int block_col = blockIdx.x * BN;
for (int k_step = 0; k_step < K; k_step += BK) {
// 1. cp.async 加载 A_tile, B_tile 到 sA / sB (使用 swizzle layout)
cp_async_load_a_tile(sA, A, block_row, k_step);
cp_async_load_b_tile(sB, B, block_col, k_step);
cp_async_commit_and_wait();
__syncthreads();
// 2. 内层 K (BK / 16 个 mma 步)
for (int kk = 0; kk < BK; kk += 16) {
// 用 ldmatrix 加载 A fragments
unsigned a_frag[MMAS_M][4];
#pragma unroll
for (int i = 0; i < MMAS_M; ++i) {
int row_offset = warp_m * WM + i * 16;
ldmatrix_x4(sA, row_offset, kk, &a_frag[i]);
}
// ldmatrix 加载 B fragments (with .trans for column-major)
unsigned b_frag[MMAS_N][2];
#pragma unroll
for (int j = 0; j < MMAS_N; ++j) {
int col_offset = warp_n * WN + j * 8;
ldmatrix_x2_trans(sB, col_offset, kk, &b_frag[j]);
}
// 3. mma.sync 累加
#pragma unroll
for (int i = 0; i < MMAS_M; ++i)
#pragma unroll
for (int j = 0; j < MMAS_N; ++j) {
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%0, %1, %2, %3};\n"
: "+f"(c_frag[i][j][0]), "+f"(c_frag[i][j][1]),
"+f"(c_frag[i][j][2]), "+f"(c_frag[i][j][3])
: "r"(a_frag[i][0]), "r"(a_frag[i][1]),
"r"(a_frag[i][2]), "r"(a_frag[i][3]),
"r"(b_frag[j][0]), "r"(b_frag[j][1]));
}
}
__syncthreads();
}
// 4. 写 C (epilogue: fp32 -> fp16, 写回 HBM)
#pragma unroll
for (int i = 0; i < MMAS_M; ++i)
#pragma unroll
for (int j = 0; j < MMAS_N; ++j) {
int row = block_row + warp_m * WM + i * 16;
int col = block_col + warp_n * WN + j * 8;
// 每 lane 写它持有的 2 个 fp16
int my_row = row + (lane_id / 4) + (lane_id % 4) * 0; // 简化
int my_col = col + (lane_id % 4) * 2;
half2 v;
v.x = __float2half(c_frag[i][j][0]);
v.y = __float2half(c_frag[i][j][1]);
*reinterpret_cast<half2*>(&C[my_row * N + my_col]) = v;
// ... 写其他 fragment 元素
}
}
这段代码省略了细节(地址计算、swizzle 实现、写 C 的完整 epilogue),但骨架就是这样。完整可工作的代码在 CUTLASS 中:cutlass/gemm/threadblock/mma_pipelined.h。
实测:H100, FP16, M=N=K=4096:
SIMT FP32 GEMM (第 11 章 v5): ~25 TFLOPs (37% of FP32 peak)
Tensor Core HGEMM (上面骨架): ~600 TFLOPs (60% of FP16 peak)
cuBLAS HGEMM: ~750 TFLOPs (76% of FP16 peak)
到 60% 已经是不小的成就。剩余 16% 差距来自:double buffer 流水深度、CUTLASS 级别的细致 fragment 调度、PTX 微优化等。CUTLASS 把这些做到极致,能达到 80%+。
12.6 Hopper 升级:WGMMA
Hopper 引入 WGMMA(Warp-Group MMA)后,mma 指令的粒度从 warp-level 提升到 warp-group-level:
mma.sync.m16n8k16: 16×8×16 = 2048 FLOPs/指令, warp 级
wgmma.mma_async.m64n128k16: 64×128×16 = 131072 FLOPs/指令, warp-group 级
WGMMA 单条指令的算力是 mma.sync 的 64 倍——这意味着指令调度压力减少 64 倍,更易跑满 Tensor Core。
WGMMA 还是异步指令:
wgmma.mma_async ...; // 发起异步矩阵乘
wgmma.commit_group; // 提交一组
... 做别的事 ...
wgmma.wait_group 0; // 等待完成
发起 wgmma 之后 warp 可以继续做别的事(比如 TMA 加载下一个 tile),等需要结果时再同步。这是 Hopper GEMM 性能跃迁的核心机制——算和拷贝真正流水起来。
完整的 Hopper WGMMA GEMM 框架第 13 章 CUTLASS 部分会展开,第 17 章 FA2 SOTA 会用到。
12.7 这一章的小结与下一章
Tensor Core 是 GEMM 性能跃迁的关键:
- mma.sync 是矩阵级指令:单条指令算 16×8×16 = 2048 FLOPs。
- ldmatrix 是配套的矩阵 load 指令:把 SMEM 中的 16×16 子块加载到 fragment。
- Fragment 是分布式寄存器布局:32 lane 协作持有矩阵。
- SMEM Swizzle 防 bank conflict:CUTLASS 的标准 swizzle layout 解决了相邻行同列的 conflict 问题。
- WGMMA 是 Hopper 的升级:单条指令 64 倍算力 + 异步执行。
到这里,读者已经能写出一个达到 60-80% 峰值的 HGEMM。下一步是把这套手艺工业化——CUTLASS 把所有这些技巧抽象成可组合的 C++ 模板,让 NVIDIA 和工业界能用统一的工具构建各种 GEMM 变体(包括 FA2 内的 QK^T 和 PV)。
第 13 章我们剖析 CUTLASS 3.x 的设计哲学——CollectiveOp、CuTe Layout、Hopper Kernel Schedule。读完第 13 章读者会理解为什么 CUTLASS 的代码"看起来很复杂但实际上很优雅",并学会怎么读 CUTLASS 源码。
本章动手练习:
- 实现一个最简版 mma.sync HGEMM(小尺寸 M=N=K=64),亲手写 inline PTX,体验 fragment 布局。
- 阅读 CUTLASS 的
mma_pipelined.h,看双缓冲 + ldmatrix + mma 是怎么组装的。- 在 H100 上跑 cuBLAS HGEMM 和你的版本,用 Nsight Compute 看
sm__inst_executed_pipe_tensor.sum.per_cycle_active指标——你的 kernel Tensor Core 利用率是多少?