CUDA 算子工程:手写 FlashAttention v2 之路

第 21 章 性能陷阱与反模式

作者 杨艺韬 · 2,043 字

第 21 章 性能陷阱与反模式

"Every CUDA optimization story starts the same way: 'I thought I was being clever, then I checked the SASS.'" ——CUDA 工程社区的常见自嘲

21.1 陷阱 1:寄存器 spill 让 kernel 变慢 100×

症状:kernel 性能远低于预期,ncu 显示 local memory throughput 高、Stall Long Scoreboard 高。

原因:每线程使用的寄存器超过 255,编译器把多余的寄存器溢出到 local memory(实际是 HBM)。每次访问慢 800 cycles,对原本预期 0 cycle 的寄存器访问,慢 100 倍以上。

诊断

nvcc -arch=sm_90 -Xptxas -v my_kernel.cu
# 输出: ptxas info: Used 80 registers, 0 stack, 256 bytes spill stores, 256 bytes spill loads

spill stores/loads != 0 就是问题。

修复

  1. 减少局部变量数量(合并、复用)。
  2. 降低 #pragma unroll 程度。
  3. __launch_bounds__(256, 4) 提示编译器降低寄存器使用(参数:每 block 最多 256 thread,每 SM 最少 4 block)。
  4. 把不常用的状态存到 SMEM 而不是寄存器。

21.2 陷阱 2:SMEM Bank Conflict

症状:SMEM 访问慢,ncu 显示 Bank Conflicts > 0

原因:第 4 章讲过——一个 warp 内多个线程访问同一个 bank 的不同地址。

诊断:ncu 的 metric smsp__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum

修复

  1. +1 padding__shared__ float smem[32][33] 而不是 [32][32]
  2. Swizzled layout:用 row XOR col 函数。
  3. Vectorized access:用 float4 让 4 个线程读 16 字节,bank 跨度变大。
  4. TMA + swizzle:Hopper 上让硬件自动 swizzle。

21.3 陷阱 3:Warp Divergence

症状:算力利用率低,ncu 显示 smsp__sass_average_branch_targets_threads_uniform.pct 远低于 100%。

原因:warp 内 32 个线程走不同分支,必须串行执行不同分支。

反例

// 反例: 数据相关分支
if (arr[tid] > 0) {
    do_something_a();
} else {
    do_something_b();
}

如果 arr 值随机,warp 内一半走 a 一半走 b,性能减半。

修复

  1. 数据预排序:让相邻线程的数据相同分支。
  2. Mask 化:把分支变成数学(y = mask * a + (1-mask) * b)。
  3. Warp-uniform 分支:让分支条件是 warp_id 而不是 thread_id(整 warp 走同一边)。
  4. 接受代价:如果分支不可避免,至少让"重分支"是少数(常见路径快)。

21.4 陷阱 4:L2 Thrashing

症状:L2 miss 高,但 L2 容量足够。

原因:多个数据流互相驱逐对方。最经典的例子:跑一个 kernel 同时访问 weight(固定)和 KV cache(每请求不同),两者在 L2 上"打架"。

诊断:ncu 的 lts__t_sectors_op_read_lookup_hit.sum / lts__t_sectors_op_read_lookup_miss.sum

修复

  1. L2 Persistence:把热数据(weight、cos/sin 表)锁在 L2。
  2. Streaming load:用 __ldcs 让冷数据绕过 L2。
  3. Tile 重新设计:让一次 kernel 内访问的数据集中在 L2 容量内。
// 用 L2 persistence 锁定 weight 在 L2
cudaStreamAttrValue attr = {};
attr.accessPolicyWindow.base_ptr = weight_ptr;
attr.accessPolicyWindow.num_bytes = 32 * 1024 * 1024;  // 32 MB
attr.accessPolicyWindow.hitRatio = 1.0;
attr.accessPolicyWindow.hitProp = cudaAccessPropertyPersisting;
cudaStreamSetAttribute(stream,
    cudaStreamAttributeAccessPolicyWindow, &attr);

21.5 陷阱 5:滥用全局 atomic

症状:kernel 性能塌陷,看 SASS 有大量 atomic 指令。

原因:HBM atomic 慢(几百 cycle),且高竞争时变成完全串行。

反例:第 5 章 v0 的 reduce kernel——N 个线程同时 atomic add 到一个全局变量。

修复:分层归约。

  1. Warp 内 reduce(__shfl)。
  2. Block 内 reduce(SMEM)。
  3. Cluster 内 reduce(DSMEM, Hopper)。
  4. 最后只用极少 atomic 写最终结果。

警告:fp16 atomic 在 H100 上原生支持,但精度差;fp32 atomic 慢但准确。LLM 训练中 atomic 通常用 fp32。

21.6 陷阱 6:Block Size 错误

症状:算法正确,性能与预期相差 2-3×。

原因

修复

  1. 常用值:256、512、128。先试 256。
  2. 针对算子调整
    • LayerNorm/Softmax:256 或 512(一行一 block)。
    • GEMM (Tiled):128 或 256。
    • Reduce:128 或 256。
    • Attention (FA2):128(4 warps)。
  3. 用 Occupancy Calculator:CUDA 提供的小工具,根据每 block 的 SMEM/寄存器使用估算 occupancy。

21.7 陷阱 7:迷信高 Occupancy

症状:把 occupancy 调到 100% 反而变慢。

原因:高 occupancy 让每个线程的寄存器配额变小,可能导致 spill 或减少 ILP。

反直觉事实FA、cuBLAS GEMM、CUTLASS 高性能 kernel 的 occupancy 普遍只有 25-50%。它们的优势不是"warp 多",而是"每个 warp 干的活多"。

指导原则

21.8 陷阱 8:Stream 与 Graph 误用

症状:用了 stream 但 GPU 仍然串行。

原因:Stream 之间的依赖没设置好,或者 host 端单线程派发太慢。

反例

// 反例: 默认 stream 阻塞所有
kernel1<<<...>>>(...);
cudaMemcpy(...);  // 默认 stream, 阻塞
kernel2<<<...>>>(...);

修复

  1. 用显式 stream:
cudaStream_t s1, s2;
cudaStreamCreate(&s1);
cudaStreamCreate(&s2);
kernel1<<<..., 0, s1>>>(...);
cudaMemcpyAsync(..., s2);  // 不同 stream, 并行
kernel2<<<..., 0, s1>>>(...);
  1. 用 CUDA Graph:把整个推理 forward 录制成一个 graph,每次只 launch graph:
cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
// 录制
forward(input, output, stream);
cudaStreamEndCapture(stream, &graph);
cudaGraphInstantiate(&graphExec, graph);

// 后续每次推理:
cudaGraphLaunch(graphExec, stream);  // 一次提交所有 kernel

21.9 陷阱 9:浮点精度问题

症状:FP16 / BF16 训练 loss 不收敛或推理输出异常。

原因

  1. 累加用错精度:FP16 累加溢出/精度损失,必须 FP32 累加(Tensor Core 自动这么做)。
  2. Softmax 不 safe:忘了减 max,溢出。
  3. LayerNorm 用朴素方差公式:第 7 章讲过,必须用 Welford。
  4. 量化 scale 选错:per-tensor scale 对异常值敏感。

修复

  1. 累加器永远用 FP32(甚至 FP64)。
  2. 所有 reduce 类算子都要数值稳定版本。
  3. 量化用 per-channel/per-group。

21.10 陷阱 10:误判带宽 vs 算力 bound

症状:花一周优化算法(减少 FLOPs),性能没变化。

原因:kernel 是带宽 bound,FLOPs 不是瓶颈,HBM 流量才是。

正确诊断

  1. ncu 看 Roofline:点在带宽屋顶下→带宽 bound;在算力屋顶下→算力 bound。
  2. Compute Throughput vs Memory Throughput:哪个高哪个 bound。

修复方向

21.11 陷阱 11:忽略 Kernel Launch 开销

症状:单 kernel 性能不错,整体推理慢。

原因:每次 kernel<<<...>>>() 启动有 ~5μs overhead。LLM 推理一次 forward 几百次 kernel launch,累计 ms 级。

诊断:nsys timeline 看 kernel 之间的 gap。

修复

  1. Kernel Fusion:减少 kernel 数(第 8 章)。
  2. CUDA Graph:把多个 kernel 录成 graph 一次提交。
  3. Persistent Kernel:用一个 kernel 处理多个 tile(第 18 章)。

21.12 陷阱 12:忽略 Host 端瓶颈

症状:GPU 利用率(nvidia-smi 看)只有 50%,但 GPU profiler 看每个 kernel 都很快。

原因:CPU 端阻塞——可能是数据加载、预处理、Python overhead。

诊断:nsys timeline 看 CPU thread 是否在某些点上忙。

修复

  1. Async data loading:dataloader 用多 worker。
  2. Pin memory + prefetch:减少 H2D 拷贝同步。
  3. Compile heavy logic:用 torch.compile 或写 C++ extension。

21.13 一份 LLM Kernel 优化清单

最后给读者一份 LLM kernel 优化时的快速清单:

□ 1. 用 nsys 确认这是热点 kernel (占总时间 > 5%)
□ 2. 用 ncu 看 Roofline 位置, 判断带宽 vs 算力 bound
□ 3. 检查 spill (ptxas info)
□ 4. 检查 bank conflict (ncu)
□ 5. 检查 occupancy (ncu)
□ 6. 检查 cache hit rate (ncu L2)
□ 7. 检查 warp divergence (ncu)
□ 8. 看 SASS 找次优指令
□ 9. 与 cuBLAS / CUTLASS 同尺寸对比, 看差距来自哪
□ 10. 数值精度验证 (与 reference 实现对比)

每次优化前过一遍这个清单,能避免 80% 的"白忙活"。

21.14 第五篇收官与下一篇

第五篇我们建立了性能调优的工具链与避坑指南:

到这里本书的核心内容(第 1-21 章)全部完成。读者已经掌握了:

  1. 基础(第一篇 1-4 章):GPU 范式、Hopper 架构、编程模型、内存层级。
  2. 小算子(第二篇 5-9 章):Reduction、Online Softmax、LayerNorm、Element-wise Fusion、Quantization。
  3. 大算子之 GEMM(第三篇 10-13 章):朴素到 CUTLASS。
  4. 大算子之 Attention(第四篇 14-18 章):FA1 思想到 FA3 SOTA。
  5. 性能工程(第五篇 19-21 章):诊断与避坑。

附录 A、B、C 会补充三个实用主题:CUDA Graph 与 Stream(异步执行模型)、CUDA C++ 与 Triton 的对比(什么时候选哪个)、与 vLLM·Transformer 那两本书的衔接路径(让读者知道下一步该读什么)。

附录写完,本书就完成了。

本章动手练习

  1. 写一个故意有 register spill 的 kernel(用 __launch_bounds__ 强制低寄存器),用 ncu 看性能差距。
  2. 找你自己写过的一个 CUDA kernel,按 21.13 节清单过一遍,记录每一项的状态。
  3. 思考:本书介绍的 12 个陷阱中,你之前没意识到、但确实经常踩的是哪几个?