CUDA 算子工程:手写 FlashAttention v2 之路
第 10 章 朴素 GEMM 与 Roofline 分析
第 10 章 朴素 GEMM 与 Roofline 分析
"GEMM is to GPU what
printfis to C — the canonical example, the first benchmark, and the hidden gateway to mastering the platform." ——经典 CUDA 工程师箴言
10.1 GEMM 是什么,为什么它占据 LLM 算力的 90%
GEMM = General Matrix Multiplication,通用矩阵乘:
LLM 推理 / 训练中无处不在:
- QKV projection:
X @ W_qkv(B×H × H×3H) - Attention 中的 Q@K^T 与 P@V:(B×H_q × H_k×Seq, ...)
- Output projection:
O @ W_o - FFN gate / up / down:3 个独立 GEMM
一个 LLaMA-7B 单 token decoding 的浮点操作分布大约是:
GEMM: ~85-90%
Attention: ~5-8% (但 attention 的内部本身也是 GEMM)
LayerNorm: ~1-2%
Element-wise: ~1-2%
其他: <1%
如果把 attention 内部的 GEMM 也算进去,95%+ 的浮点运算都在 GEMM 上。所以 GEMM 性能直接决定 LLM 性能。
cuBLAS 的 SGEMM (FP32 GEMM) 在 H100 上 M=N=K=4096 能跑到 ~50 TFLOPs,约 75% 的 FP32 SIMT 峰值。Tensor Core 的 HGEMM (FP16 GEMM) 能跑到 ~750 TFLOPs,约 76% 的 FP16 Tensor Core 峰值。这个性能不是凭空来的——它是几十年 GPU 架构演进 + 几十万行 CUTLASS 代码堆出来的。
要理解 GEMM 性能的来路,得从最简单的版本开始。
10.2 朴素 GEMM:教科书式的三重循环
最直观的写法:每个线程算 C 的一个元素。
__global__ void gemm_naive(
const float* A, const float* B, float* C,
int M, int N, int K
) {
int row = blockIdx.y * blockDim.y + threadIdx.y;
int col = blockIdx.x * blockDim.x + threadIdx.x;
if (row >= M || col >= N) return;
float sum = 0.0f;
for (int k = 0; k < K; ++k) {
sum += A[row * K + k] * B[k * N + col];
}
C[row * N + col] = sum;
}
启动配置:
dim3 block(16, 16);
dim3 grid((N + 15) / 16, (M + 15) / 16);
gemm_naive<<<grid, block>>>(A, B, C, M, N, K);
这段代码逻辑正确,能算对 GEMM。但性能呢?
10.2.1 实测性能
H100, M=N=K=4096, FP32:
朴素 GEMM: ~700 GFLOPs (1% of FP32 peak)
cuBLAS SGEMM: ~50 TFLOPs (75% of FP32 peak)
差距: 70x
70 倍的差距。这就是"会写"和"写好"的距离。
10.3 用 Roofline 找瓶颈
朴素 GEMM 的瓶颈在哪?用第 4 章学的 Roofline 模型分析。
10.3.1 朴素 GEMM 的算术强度
每个线程计算 C 的一个元素,需要:
- 读 A 的一行: 个 float = 字节
- 读 B 的一列: 个 float = 字节
- 写 C 的一个元素:4 字节
- 计算: 个 FLOPs(K 次 mul + K 次 add)
每个 thread 的算术强度:
K 大时趋近于 1/4。这是非常低的算术强度——比 LayerNorm 和 RMSNorm 还低!
但整个 kernel 的算术强度不是这样算的。多个线程之间会复用 A 和 B 的不同部分:
- 同一行线程都读同一行 A
- 同一列线程都读同一列 B
如果 cache 完美命中,整个 kernel 的总访存只需要:
- A 全部读 1 次: 字节
- B 全部读 1 次: 字节
- C 全部写 1 次: 字节
- 计算: FLOPs
整体算术强度:
M=N=K=4096 时:
683 FLOPs/byte 的算术强度,远超过 H100 的临界点 295。这意味着 GEMM 在大尺寸下应该是完全 compute-bound 的——理论上能跑到算力峰值。
但朴素 GEMM 只跑了 1%。问题不是算力或带宽不够,而是访存模式让 cache 完全失效。
10.3.2 朴素 GEMM 为什么 cache 失效
看朴素 GEMM 的访存模式:
sum += A[row * K + k] * B[k * N + col];
线程 (row, col=0) 和 (row, col=1) 都读 A 的同一行,但它们在不同的 SM 上(gridDim 维度大),它们之间的 A 数据在 L2 cache 里有可能命中,但远不是必然。
更致命的是 B 的访问。线程 (row, col) 读 B[k * N + col]:
- k 从 0 到 K-1
- 每次 k++,地址跨越 N×4 字节 = 16 KB(N=4096, FP32)
B 的访问步长是 16 KB——意味着每读一次 B,都跳过 4 个 cache line 大小。L1 cache 完全无效(线性扫描下 cache 永远 miss)。L2 cache 50 MB,N=4096 时 B 整个矩阵 64 MB——也放不下。
最终结果:朴素 GEMM 的 B 矩阵几乎每次都要从 HBM 读,算术强度从理论 683 退化到约 1(每读 1 字节算 1 FLOP)。带宽 bound,跑不快。
10.3.3 朴素 GEMM 的 Roofline 位置
实际 TFLOPs/s
▲
989 ─────┼────────────────────── Tensor Core FP16 峰值
│
500 ─┤
│
│ cuBLAS SGEMM
│ AI~683, perf~50
100 ─┤
│
│
10 ─┤ ╱
│ ╱
│ 朴素 GEMM
│ 实际 AI ~1, perf ~0.7
1 ──┤
└─────────────────────► 算术强度
1 10 100 1000
带宽屋顶 (3.35 TB/s)
朴素 GEMM 的"实际算术强度"约 1,落在带宽屋顶下,所以无法发挥算力优势。优化的方向是把算术强度从 1 提到 ~683——这要求充分复用 SMEM 和寄存器中的 A/B 数据,让"理论算术强度"和"实际算术强度"接近。
10.4 GEMM 优化的三大武器
要让 GEMM 跑到算力峰值,需要三层 tile:
flowchart TB HBM[A, B 在 HBM] -->|每次 K_tile 列| GTile[Block-level tile] GTile -->|放进 SMEM| SMEM[SMEM tile A, B] SMEM -->|每次 mma_k 列| WTile[Warp-level tile] WTile -->|放进寄存器| REG[Register fragment] REG -->|喂给 Tensor Core| TC[mma.sync 一次<br/>16×8×16 矩阵乘] TC -->|累加| ACC[寄存器中的 C accumulator]
三层 tile:
- Block-level tile:一个 Block 处理 C 的一个 M_tile × N_tile 子块(典型 128×128)。一个 SM 上 active 的 block 数 = SMEM_total / (block_smem_size)。
- Warp-level tile:一个 warp 处理 block tile 内的一个 M_warp × N_warp 子块(典型 64×64)。一个 block 通常 8 个 warp。
- Tensor Core mma:一个 mma.sync 操作 16×8×16(Hopper 上是 64×N×16 with WGMMA)。
这三层 tile 配合 double buffering(异步拷贝下一个 K tile 时算当前),构成现代 GEMM 的核心结构。
10.4.1 数据复用计算
理想情况下:
- Block tile (128×128):每个 K 列的 A_tile (128) + B_tile (128) 共 256 个 float = 1 KB,被用来算 128×128 = 16384 个 C 元素。复用率 16384 / 256 = 64×。
- Warp tile (64×64):一组 fragment 数据用 64×64 = 4096 次。
- Register tile (8×8):每个寄存器的元素被用 ~8 次。
这种层层复用让每个从 HBM 读上来的字节被多次使用,有效算术强度 接近理论上限。
10.4.2 Tensor Core 的角色
第 12 章会详细讲,这里先建立直觉:
- Volta+ 的 Tensor Core 一条 mma 指令算 16×8×16 矩阵乘。
- 算 16×8×16 矩阵乘需要的浮点操作 = 16 × 8 × 16 × 2 = 4096 FLOPs。
- 一个 SM 4 个 Tensor Core,每个 Tensor Core 每周期发一次 mma。
- 峰值 = 4 × 4096 / 1 周期 = 16384 FLOPs/周期/SM。
- 132 SM × 1.83 GHz × 16384 FLOPs / 周期 = 989 TFLOPs/s。
这就是 H100 FP16 Tensor Core 峰值的来路。要跑满这个峰值,每个 SM 每周期必须发出一条 mma 指令——这要求所有数据准备好、SMEM/寄存器完美配合。任何一个环节 stall(cache miss、bank conflict、寄存器 spill),峰值就打不到。
10.5 优化路径预告
后三章我们会一步步从朴素 GEMM 推进到 cuBLAS 性能:
- 第 11 章 Tiled GEMM:用 SMEM tile + double buffer,实现 SIMT GEMM,目标 ~30% 峰值。这是 Volta 之前的写法巅峰。
- 第 12 章 Tensor Core GEMM:用 mma.sync + ldmatrix + swizzle,引入 Tensor Core,目标 ~80% 峰值。
- 第 13 章 CUTLASS 设计哲学:剖析 NVIDIA 官方模板库的设计,理解工业级 GEMM 是怎么组装的。
每一步都对应一组真实的工程技巧,每一组技巧都把性能往上推一截。这条优化路径同时也是后续 FA2 优化的基础——FA2 内部的 QK^T 和 PV 矩阵乘,都基于这套 GEMM 框架。
10.6 一个朴素优化:Block 内复用
在进入第 11 章之前,先做一个小优化作为"warm-up"——让 thread 算多个 C 元素:
template <int TM = 4, int TN = 4> // 每线程算 TM × TN 个 C 元素
__global__ void gemm_thread_tile(
const float* A, const float* B, float* C,
int M, int N, int K
) {
int row = (blockIdx.y * blockDim.y + threadIdx.y) * TM;
int col = (blockIdx.x * blockDim.x + threadIdx.x) * TN;
float c[TM][TN] = {0};
for (int k = 0; k < K; ++k) {
float a[TM], b[TN];
#pragma unroll
for (int i = 0; i < TM; ++i) a[i] = A[(row + i) * K + k];
#pragma unroll
for (int j = 0; j < TN; ++j) b[j] = B[k * N + (col + j)];
#pragma unroll
for (int i = 0; i < TM; ++i)
#pragma unroll
for (int j = 0; j < TN; ++j)
c[i][j] += a[i] * b[j];
}
#pragma unroll
for (int i = 0; i < TM; ++i)
#pragma unroll
for (int j = 0; j < TN; ++j)
C[(row + i) * N + (col + j)] = c[i][j];
}
每线程算 4×4=16 个 C 元素。读取 A 的 4 个元素 + B 的 4 个元素 = 8 次访存,做 16 次 mul-add。算术强度变成 32 / 32 = 1 FLOPs/byte——比朴素的 0.25 提升 4 倍。
实测性能:朴素 GEMM 700 GFLOPs → thread tile 4×4 ≈ 2800 GFLOPs(4× 提升)。
但这远没到目标。下一步要做的是让多个 thread 共享 A/B——这就是 Block tile + SMEM 的舞台。
10.7 这一章的小结与下一章
这一章我们建立了 GEMM 优化的"地图":
- GEMM 是 LLM 算力消耗的 95%+:所有优化的最大投资点。
- 朴素 GEMM 只跑 1% 峰值:不是算力或带宽不够,而是 cache 完全失效。
- 理论算术强度高(~683),实际算术强度低(~1):必须靠 SMEM + 寄存器层层复用提升实际 AI。
- 三层 tile 是 GEMM 优化的核心结构:Block tile / Warp tile / Tensor Core mma。
- 简单的 thread tile 就能 4× 提升:但还远没到峰值,需要 SMEM 和 Tensor Core。
第 11 章我们正式进入 SMEM 优化的世界——把 Block tile 放进 SMEM,配合 double buffering 让 SMEM 和 HBM 流水起来。这个版本(不用 Tensor Core)能达到 ~30-40% 算力峰值,是 SIMT GEMM 的合理目标。第 12 章再加上 Tensor Core 后,性能会一举跃升到 ~80%。
本章动手练习:
- 实现朴素 GEMM 和 thread tile 4×4 GEMM,对比性能。
- 用 Nsight Compute 看朴素 GEMM 的
dram__sectors_read.sum,估算实际 HBM 流量,验证"实际算术强度 ~1"的判断。- 思考:为什么 thread tile 要选 4×4 而不是 8×8?(提示:寄存器压力)