CUDA 算子工程:手写 FlashAttention v2 之路
第 11 章 Tiled GEMM:Shared Memory 与 Double Buffer
第 11 章 Tiled GEMM:Shared Memory 与 Double Buffer
"Tiling is to GPU what divide-and-conquer is to algorithms — once you see it, you start seeing it everywhere." ——CUDA 教学传统
11.1 为什么需要 Tiling
第 10 章看到朴素 GEMM 只能跑到 1% 峰值,瓶颈是实际算术强度只有 ~1。要让算力发挥出来,必须把 A 和 B 的元素复用——同一个元素被读到 SMEM 后,被多个线程多次使用,再写回。
Tiling(瓦片化)就是这种复用的工程实现。基本结构是:
flowchart TB
subgraph HBM_Layer [HBM 层]
A[A 矩阵 M×K]
B[B 矩阵 K×N]
C[C 矩阵 M×N]
end
subgraph SMEM_Layer [SMEM 层 · per Block]
SA[A_tile 128×Tk]
SB[B_tile Tk×128]
end
subgraph Reg_Layer [Register 层 · per Warp]
RA[A_frag 64×Tk_inner]
RB[B_frag Tk_inner×64]
RC[C_acc 64×64]
end
HBM_Layer -->|Block iteration<br/>每次 K 维移动 Tk| SMEM_Layer
SMEM_Layer -->|Warp iteration<br/>每次 K 维移动 Tk_inner| Reg_Layer
Reg_Layer -->|FMA 累加| Reg_Layer
每个 Block 处理 C 的一个 M_block × N_block 子块(典型 128×128),扫过 K 维度时不断从 HBM 加载新的 A_tile / B_tile 到 SMEM。每个 Warp 处理 block 内的一个 M_warp × N_warp 子块(典型 64×64),从 SMEM 读 fragment 到寄存器。最内层是寄存器中的 8×8 累加(每线程)。
11.2 Block Tile:SMEM 中的复用
定义 tile 大小:
constexpr int BM = 128; // Block tile M 维
constexpr int BN = 128; // Block tile N 维
constexpr int BK = 16; // Block tile K 维
每个 Block 处理 BM × BN 的 C 子块。内层 K 维分块 BK,每次从 HBM 拉 BM × BK 的 A_tile 和 BK × BN 的 B_tile 到 SMEM。
SMEM 占用:
A_tile: 128 × 16 × 4 = 8192 bytes = 8 KB
B_tile: 16 × 128 × 4 = 8192 bytes = 8 KB
合计: 16 KB
H100 单 SM SMEM 228 KB,能放 14 个 block 的 single-buffer(实际 occupancy 受寄存器限制更严,2-4 个 block)。
11.2.1 加载 A_tile / B_tile 到 SMEM
每个 Block 启动 BM × BN / (TM × TN) = 256 个线程(每线程算 8×8 = 64 个 C 元素)。256 线程协作加载 8 KB 的 A_tile 和 8 KB 的 B_tile。
__shared__ float sA[BM][BK]; // 128 × 16
__shared__ float sB[BK][BN]; // 16 × 128
const int tid = threadIdx.y * blockDim.x + threadIdx.x;
const int row = blockIdx.y * BM;
const int col = blockIdx.x * BN;
// A_tile 加载: 256 线程加载 128 × 16 = 2048 个元素
// 每线程加载 2048 / 256 = 8 个元素 (但 128 行 × 16 列, 安排成每 16 线程一行)
for (int load = 0; load < BM * BK / 256; load++) {
int idx = tid + load * 256;
int load_row = idx / BK;
int load_col = idx % BK;
sA[load_row][load_col] = A[(row + load_row) * K + (k_step + load_col)];
}
类似地加载 B_tile。
11.2.2 内层乘法:从 SMEM 到寄存器
每个 Warp(32 线程)处理 64 × 64 子块,每线程算 8 × 8 个 C 元素:
constexpr int WM = 64; // Warp tile M
constexpr int WN = 64; // Warp tile N
constexpr int TM = 8; // Thread tile M
constexpr int TN = 8; // Thread tile N
float c[TM][TN] = {0}; // 64 个寄存器
for (int kk = 0; kk < BK; ++kk) {
float a[TM], b[TN];
#pragma unroll
for (int i = 0; i < TM; ++i) a[i] = sA[warp_m * WM + thread_m * TM + i][kk];
#pragma unroll
for (int j = 0; j < TN; ++j) b[j] = sB[kk][warp_n * WN + thread_n * TN + j];
#pragma unroll
for (int i = 0; i < TM; ++i)
#pragma unroll
for (int j = 0; j < TN; ++j)
c[i][j] += a[i] * b[j];
}
每个线程从 SMEM 读 8 个 a + 8 个 b = 16 个浮点(64 字节),做 64 次 mul-add(128 FLOPs)。单线程算术强度 = 128 / 64 = 2 FLOPs/byte——比朴素 GEMM 提升 8×。
11.3 完整的 Tiled GEMM Kernel
把上面的 piece 组装起来:
template <int BM, int BN, int BK, int WM, int WN, int TM, int TN>
__global__ void gemm_tiled(
const float* A, const float* B, float* C,
int M, int N, int K
) {
__shared__ float sA[BM][BK];
__shared__ float sB[BK][BN];
const int tid = threadIdx.y * blockDim.x + threadIdx.x;
const int warp_id = tid / 32;
const int lane_id = tid % 32;
const int warp_m = warp_id / (BN / WN);
const int warp_n = warp_id % (BN / WN);
const int thread_m = lane_id / (WN / TN);
const int thread_n = lane_id % (WN / TN);
const int block_row = blockIdx.y * BM;
const int block_col = blockIdx.x * BN;
float c[TM][TN] = {0};
// 沿 K 维度迭代
for (int k_step = 0; k_step < K; k_step += BK) {
// 1. 协作加载 A_tile, B_tile 到 SMEM
#pragma unroll
for (int i = tid; i < BM * BK; i += blockDim.x * blockDim.y) {
int r = i / BK, c_ = i % BK;
sA[r][c_] = A[(block_row + r) * K + (k_step + c_)];
}
#pragma unroll
for (int i = tid; i < BK * BN; i += blockDim.x * blockDim.y) {
int r = i / BN, c_ = i % BN;
sB[r][c_] = B[(k_step + r) * N + (block_col + c_)];
}
__syncthreads();
// 2. 每线程算 TM × TN 个累加
for (int kk = 0; kk < BK; ++kk) {
float a[TM], b[TN];
#pragma unroll
for (int i = 0; i < TM; ++i)
a[i] = sA[warp_m * WM + thread_m * TM + i][kk];
#pragma unroll
for (int j = 0; j < TN; ++j)
b[j] = sB[kk][warp_n * WN + thread_n * TN + j];
#pragma unroll
for (int i = 0; i < TM; ++i)
#pragma unroll
for (int j = 0; j < TN; ++j)
c[i][j] += a[i] * b[j];
}
__syncthreads();
}
// 3. 写 C
#pragma unroll
for (int i = 0; i < TM; ++i) {
#pragma unroll
for (int j = 0; j < TN; ++j) {
int r = block_row + warp_m * WM + thread_m * TM + i;
int c_ = block_col + warp_n * WN + thread_n * TN + j;
if (r < M && c_ < N) C[r * N + c_] = c[i][j];
}
}
}
// Launch
dim3 block(16, 16); // 256 线程
dim3 grid(N / BN, M / BM);
gemm_tiled<128, 128, 16, 64, 64, 8, 8><<<grid, block>>>(A, B, C, M, N, K);
实测:H100, M=N=K=4096, FP32:
朴素 GEMM: ~700 GFLOPs (1%)
Thread tile 4×4: ~2800 GFLOPs (4%)
Tiled GEMM: ~12 TFLOPs (18%)
12 TFLOPs,提升到 18%。但还远没到 SIMT 上限。
11.4 优化 1:Double Buffer + Async Copy
上面的代码有一个明显的"同步等待":每次 K 迭代完,要 __syncthreads() 等所有线程算完,才能加载下一个 K_tile。这段时间 SM 处于"等待加载"状态——计算单元闲置。
Double buffer(双缓冲)让计算和加载重叠:
__shared__ float sA[2][BM][BK]; // 两份缓冲
__shared__ float sB[2][BK][BN];
// 预加载第一个 buffer
load_to_smem(sA[0], sB[0], k_step=0);
__syncthreads();
for (int k_step = BK; k_step < K; k_step += BK) {
int cur = ((k_step / BK) - 1) % 2;
int next = (k_step / BK) % 2;
// 异步加载下一个 buffer
cp_async_global_to_shared(sA[next], A_addr_at(k_step));
cp_async_global_to_shared(sB[next], B_addr_at(k_step));
cp_async_commit();
// 同时计算当前 buffer
compute_on_smem(sA[cur], sB[cur], &c);
cp_async_wait_all();
__syncthreads();
}
// 算最后一个 buffer
compute_on_smem(sA[(K/BK - 1) % 2], sB[(K/BK - 1) % 2], &c);
这里用到了 Ampere+ 的 cp.async.cg.shared.global 指令——异步地从 HBM 拷贝到 SMEM,期间 SIMT cores 可以继续算。
实测带来 ~30% 提升:12 TFLOPs → 16 TFLOPs(24%)。
11.5 优化 2:解决 SMEM Bank Conflict
读 sA[warp_m * WM + thread_m * TM + i][kk] 这一句,KK 固定时,TM=8 个连续行:
thread_m=0: sA[0][kk], sA[1][kk], ..., sA[7][kk]
分别落到 bank (0, 16, 0, 16, 0, 16, 0, 16) (重复!)
每行 BK=16 个 float,跨行步长 16 个 float = 16 banks。所以 sA[i][kk] 和 sA[i+1][kk] 落在同一个 bank。一个 warp 内 32 线程读 8 个连续行 → 严重 bank conflict。
解决方法:+1 padding 或 swizzled 布局。
11.5.1 +1 Padding
__shared__ float sA[BM][BK + 1]; // 17 列而不是 16
每行 17 列后,sA[i][kk] 和 sA[i+1][kk] 落到 bank (kk % 32, (kk + 17) % 32)——错开了。但每行多一列浪费 SMEM 6%。
11.5.2 Swizzled Layout
更高级的做法是按"对角线"放置元素:
__device__ __forceinline__ int swizzle(int row, int col) {
return col ^ (row & 0x7);
}
sA[i][swizzle(i, kk)] = ...;
把行号和列号 XOR 一下,让相邻行的元素自动落到不同 bank。这是 CUTLASS 中的标准技巧,第 12 章会详细看。
实测引入 padding 后:16 TFLOPs → 21 TFLOPs(31%)。
11.6 优化 3:寄存器 Blocking 与读取顺序
最内层的 mul-add 循环:
for (int kk = 0; kk < BK; ++kk) {
for (int i = 0; i < TM; ++i) a[i] = sA[...][kk];
for (int j = 0; j < TN; ++j) b[j] = sB[kk][...];
for (int i = 0; i < TM; ++i)
for (int j = 0; j < TN; ++j)
c[i][j] += a[i] * b[j];
}
这段代码每次 kk 都要重新 load a 和 b。如果 BK=16,那总共 16 次 load × (TM + TN) = 256 次 SMEM 访问。
更好的方式是外层 unroll kk 几步,复用 a 和 b 寄存器:
for (int kk = 0; kk < BK; kk += 4) {
float a[4][TM], b[4][TN];
#pragma unroll
for (int u = 0; u < 4; ++u) {
for (int i = 0; i < TM; ++i) a[u][i] = sA[...][kk + u];
for (int j = 0; j < TN; ++j) b[u][j] = sB[kk + u][...];
}
#pragma unroll
for (int u = 0; u < 4; ++u) {
#pragma unroll
for (int i = 0; i < TM; ++i)
#pragma unroll
for (int j = 0; j < TN; ++j)
c[i][j] += a[u][i] * b[u][j];
}
}
或者用更流行的"Outer Product 累加"方式。这些细节在 CUTLASS 里都有现成实现。
11.7 性能演进表
把所有优化加上:
| 版本 | 优化 | TFLOPs (H100, FP32) | % FP32 SIMT 峰值 |
|---|---|---|---|
| v0 | 朴素 | 0.7 | 1% |
| v1 | Thread tile 4×4 | 2.8 | 4% |
| v2 | + Block tile + SMEM | 12 | 18% |
| v3 | + Double buffer (cp.async) | 16 | 24% |
| v4 | + Bank conflict fix | 21 | 31% |
| v5 | + 寄存器 blocking + unroll | 25 | 37% |
| 目标 | SIMT 极限 (cuBLAS SGEMM) | ~50 | ~75% |
到 v5 我们达到了 SIMT 写法的合理水平(37%),距离 cuBLAS 还有 2× 差距。剩下的 2× 差距,70% 来自 Tensor Core——SIMT 单核的浮点率根本打不过 Tensor Core,FP32 SIMT 峰值 67 TFLOPs vs FP16 Tensor Core 989 TFLOPs。
这就是为什么第 12 章必须引入 Tensor Core——SIMT 路线已经到顶了。
11.8 这一章的小结与下一章
Tiled GEMM 是 GEMM 优化的"地基":
- 三层 tile(block / warp / thread)让数据在 SMEM 和寄存器中层层复用,把实际算术强度从 1 提到 ~100+。
- Double buffer + cp.async 让 HBM 拷贝和计算重叠,隐藏 HBM 延迟。
- Bank conflict 处理(padding 或 swizzle)确保 SMEM 带宽。
- 寄存器 blocking 和 unroll 减少 SMEM 访问。
- SIMT 路径的极限是 ~37%——不是技巧不够,是 FP32 SIMT 算力本身就打不过 Tensor Core。
第 12 章我们引入 Tensor Core——mma.sync 指令、ldmatrix 指令、layout swizzle。这是 GEMM 性能再提升 2× 的关键。读完第 12 章,读者写出的 HGEMM kernel 能达到 cuBLAS 80%+ 的水平。
本章动手练习:
- 把 v0..v5 都实现一遍,记录性能演进。
- 用 Nsight Compute 看 v2 vs v4 的
smsp__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum指标,验证 padding 消除了 bank conflict。- 思考:为什么 BM=BN=128 比 BM=BN=64 更优?(提示:复用率与 SMEM 占用)