CUDA 算子工程:手写 FlashAttention v2 之路
第 4 章 内存层级与代价表
第 4 章 内存层级与代价表
"Memory is the new CPU. Bandwidth is the new clock speed. Cache is the new memory." ——常见于现代体系结构课程
4.1 一张图看清五层内存
打开任何一份 GPU 性能调优指南,第一张图大概率是这样:
Hopper GPU 内存层级(H100 SXM5)
延迟 带宽
┌───────────────────────────────────────────┐
│ Registers · 256 KB / SM × 132 SM = 33 MB │ 0 cycles 无限大
│ (但每线程私有, 不可共享) │
└────────────────────┬──────────────────────┘
│
┌────────────────────▼──────────────────────┐
│ Shared Memory (SMEM) · 228 KB / SM │ ~25 cycles ~50 TB/s/SM
│ 软件管理, block 内共享 │ ~6.6 PB/s 全卡
└────────────────────┬──────────────────────┘
│
┌────────────────────▼──────────────────────┐
│ L1 Cache · 28 KB / SM (与 SMEM 共享) │ ~30 cycles ~50 TB/s/SM
│ 硬件管理, 自动命中 │
└────────────────────┬──────────────────────┘
│
┌────────────────────▼──────────────────────┐
│ L2 Cache · 50 MB 全卡共享 │ ~250 cycles ~12 TB/s
│ 硬件管理, 自动命中, set-associative │
└────────────────────┬──────────────────────┘
│
┌────────────────────▼──────────────────────┐
│ HBM3 Memory · 80 GB │ ~800 cycles 3.35 TB/s
│ HBM 物理是 5 stack × 16 GB │
└───────────────────────────────────────────┘
数据来源:综合 NVIDIA H100 Whitepaper、Hopper Tuning Guide、与第三方 microbenchmark(Luo et al., 2024)。延迟为 SIMT 端到端典型值,带宽为单 SM 或全卡聚合。
把这张图记住——这本书后面每一章都会回到它。它是 LLM 算子优化的"地图"。
几个值得反复强调的事实:
- 从 HBM 读 1 字节比从 SMEM 读 1 字节贵 30 倍。这就是为什么"把数据搬到 SMEM"在 GPU 上是头等大事。
- 从 HBM 读 1 字节比从寄存器读 1 字节贵 800 倍。这就是为什么 GEMM 一定要把 K 维度拆开,让数据在寄存器里反复用。
- L1 几乎没用:在 LLM 算子里,几乎所有重要数据都被显式放到 SMEM 里了,L1 只剩下少量 spill 数据。所以 H100 默认配置 SMEM 228 KB / L1 28 KB。
- L2 是隐藏的关键层:50 MB 听起来不大,但放下 5-10 个 GEMM tile 完全够。两个相邻的 kernel 可以靠 L2 的命中实现"准 fused"——这是 vLLM 等推理引擎调度的关键考虑。
4.2 寄存器:零延迟,零开销,但稀缺
寄存器是 GPU 上最快的存储——访问延迟基本是 0 cycle(与算术指令融合在一起)。但寄存器有几个硬约束:
4.2.1 寄存器是线程私有的
__global__ void kernel() {
float x = 1.0f; // x 是寄存器, 线程 0 的 x 跟线程 1 的 x 是两个独立寄存器
int y = threadIdx.x;
// 线程间想交换 x: 不能直接, 必须用 shfl 或 SMEM
int peer_y = __shfl_sync(0xFFFFFFFF, y, 0); // 拿到 lane 0 的 y
}
寄存器不能在线程间共享。要交换数据必须用 warp shuffle 或 SMEM。
4.2.2 寄存器数量的硬约束
Hopper 上每个 SM 总共 65536 个 32-bit 寄存器。如果每个线程要 N 个寄存器,那 SM 上能 active 的线程数最多是 65536 / N。这个约束直接决定 occupancy:
每线程寄存器数 | SM 上最大 active 线程数 | Occupancy
─────────────────────────────────────────────────────────────
16 | 4096 | 100% (上限是 2048)
32 | 2048 | 100%
64 | 1024 | 50%
128 | 512 | 25%
255 | 256 | 12.5% (256 是单线程上限)
所以"用更多寄存器存中间结果"和"让更多 warp 活跃以隐藏延迟"是直接竞争的。第 12 章手写 GEMM 时会非常具体地讨论这个权衡。
4.2.3 Register Spill:性能杀手
如果一个线程需要的寄存器超过单线程上限(255 个 32-bit),编译器会把"装不下的寄存器"溢出到 local memory。Local memory 物理上在 HBM 里——就是说,你以为是寄存器的访问,实际上变成了 HBM 访问,慢 800 倍。
判断是否发生 spill 最简单的方法是 nvcc 加 -Xptxas -v:
ptxas info: Used 80 registers, 0 stack, 0 bytes spill stores, 0 bytes spill loads
spill stores/loads = 0 是好的。如果非零,说明发生了 spill。任何严肃的 LLM kernel 都不能有 spill,第 12、17 章会详细讲怎么避免。
4.3 Shared Memory:黄金 228 KB
SMEM 是 LLM kernel 优化的核心舞台。它在物理上是 SM 内部的一块 SRAM,访问延迟约 25 cycles,带宽极高(每 SM ~50 TB/s,是 HBM 全卡带宽 3.35 TB/s 的 ~15 倍——但要记住 HBM 是被全卡 132 SM 共享的)。
SMEM 的两个关键特性需要特别理解:32 banks 和 bank conflict。
4.3.1 SMEM 被切成 32 个 bank
SMEM 物理上不是一整块 SRAM,而是 32 个 bank——每个 bank 可以独立地处理读/写请求。一个 warp(32 线程)同时访问 SMEM 时,硬件并行地让 32 个 bank 各自服务一个线程。
SMEM (228 KB)
┌──────────────────────────────────────────────────────────────┐
│ bank 0 bank 1 bank 2 ... bank 30 bank 31 │
│ 4 B 4 B 4 B ... 4 B 4 B <- 第 1 个 32-bit 字 │
│ 4 B 4 B 4 B ... 4 B 4 B <- 第 2 个 32-bit 字 │
│ ... │
└──────────────────────────────────────────────────────────────┘
地址 → bank 映射: bank_id = (addr / 4) % 32
地址映射意味着:
smem[0], smem[1], ... smem[31]分别落在 bank 0..31。smem[32], smem[33], ... smem[63]也分别落在 bank 0..31。smem[i]和smem[i + 32k]落在同一个 bank。
4.3.2 Bank Conflict 的代价
如果一个 warp 内 N 个线程访问同一个 bank 的不同地址,硬件只能串行处理——这叫 N-way bank conflict,性能是无冲突的 1/N。
// 反例: 32-way bank conflict
__shared__ float smem[32 * 32];
int tid = threadIdx.x;
float v = smem[tid * 32]; // 所有 32 线程访问 bank 0 -> 32-way conflict!
// 性能是无冲突的 1/32
// 正例: 无冲突
__shared__ float smem[32 * 32];
int tid = threadIdx.x;
float v = smem[tid]; // 32 线程访问 32 个不同的 bank -> 无冲突
// 也是正例: 广播 (broadcast)
__shared__ float smem[32];
int tid = threadIdx.x;
float v = smem[0]; // 所有 32 线程都访问 bank 0 同一个地址 -> 硬件检测到 broadcast,
// 0 cycle 额外开销
特殊情况:所有线程访问同一个地址会触发硬件 broadcast,零额外开销。但只要有一部分线程访问同 bank 不同地址,就是 conflict。
4.3.3 矩阵转置的经典 bank conflict
最经典的 bank conflict 例子是矩阵转置:
// 原始矩阵 (32 行 × 32 列), 转置后写入 SMEM
__shared__ float smem[32][32];
int tid = threadIdx.x;
int row = blockIdx.y;
// 把 mat[row][tid] 写到 smem[tid][row] (转置)
smem[tid][row] = mat[row * 32 + tid];
// 注意 smem[tid][row]:
// tid=0, row=0: smem[0][0] -> bank 0
// tid=1, row=0: smem[1][0] -> 偏移 128 字节 -> bank 0 (32-way conflict!)
// tid=2, row=0: smem[2][0] -> 偏移 256 字节 -> bank 0
// ...
// 32 个线程全部撞到 bank 0!
解决方法是 padding:
__shared__ float smem[32][33]; // 多一列 (33 而不是 32)
smem[tid][row] = mat[row * 32 + tid];
// tid=0, row=0: smem[0][0] -> bank 0
// tid=1, row=0: smem[1][0] -> 偏移 132 字节 -> bank 1
// tid=2, row=0: smem[2][0] -> 偏移 264 字节 -> bank 2
// ... 完美错开!
这种 padding 叫 "+1 trick"。它浪费了一点 SMEM 容量(每 32 列加 1 列),但完全消除 bank conflict,性能差距可以是 32 倍。
4.3.4 Hopper 的 swizzle:硬件加速的反 conflict
Hopper 引入了硬件 swizzle:TMA 指令可以指定一种 swizzle 模式,硬件在写 SMEM 时自动按 swizzle 模式重排地址,无需程序员手动 padding。
NVIDIA 提供四种内置 swizzle 模式:
INTERLEAVE_NONE:不 swizzle。INTERLEAVE_32B:每 32 字节一个块,块内 swizzle。INTERLEAVE_64B:每 64 字节一个块。INTERLEAVE_128B:每 128 字节一个块(最常用)。
128B swizzle 对 16×16 BF16 fragment(256 字节)是完美匹配——一行 fragment 恰好两个 128B 块,写入时硬件自动错开 bank。
第 12 章会用 ldmatrix + swizzle 给 Tensor Core 喂 fragment,第 17 章会用 TMA + swizzle 给 WGMMA 喂 tile。
4.4 L2 Cache:被忽视的中间层
L2 cache 50 MB,对所有 SM 共享。它是连接"片上 SMEM"和"片外 HBM"的中间层。
L2 的几个关键特性:
4.4.1 L2 是 set-associative 的
L2 cache line 大小 128 字节,set-associative 设计。同一个地址只能映射到 L2 的某一组 way 上,如果这组 way 都被占用,新数据来时会驱逐旧数据。
实践影响:两个数据流在 L2 上"打架"——比如一个 stream 的 KV cache 和另一个 stream 的 weight 都映射到同一组 way,互相驱逐对方,导致 L2 命中率塌陷。
4.4.2 L2 持久化(L2 Persistence)
CUDA 11+ 提供了 L2 persistence API,让程序员显式标记某些数据"应该常驻 L2":
// 把某段地址标记为 persistent (Hopper 上最多 75% L2 = 37.5 MB)
cudaStreamAttrValue attr = {};
attr.accessPolicyWindow.base_ptr = kv_cache_ptr;
attr.accessPolicyWindow.num_bytes = 32 * 1024 * 1024; // 32 MB
attr.accessPolicyWindow.hitRatio = 1.0;
attr.accessPolicyWindow.hitProp = cudaAccessPropertyPersisting;
attr.accessPolicyWindow.missProp = cudaAccessPropertyStreaming;
cudaStreamSetAttribute(stream,
cudaStreamAttributeAccessPolicyWindow, &attr);
这对反复访问的小热数据(比如 RoPE 的 cos/sin 表、layer-wise scale 表)有用。但对冷数据(KV cache 这种大几个 GB)没用——它放不下 L2。
4.4.3 L2 的 streaming 写
LLM 算子里很多写操作是 write-only(输出张量),写完之后这次 kernel 不会再读。这种数据放进 L2 没意义——只会驱逐有用的数据。CUDA 提供 streaming store 来避开 L2:
// PTX: st.global.cs (cs = cache streaming, 不污染 L2)
asm("st.global.cs.f32 [%0], %1;" :: "l"(addr), "f"(val));
或者 CUDA C++ 的 hint:
__stcs(addr, val); // streaming store, 绕过 L2
__ldcg(addr); // global cache load (默认)
__ldca(addr); // all cache load (打到 L1+L2)
__ldcs(addr); // streaming load, 绕过 L2
这些 hint 在 GEMM epilogue(写最终结果)和 attention 的 mask 写入这种"用一次就扔"的场景非常有用。
4.5 HBM3:3.35 TB/s 的金矿
HBM3 是 Hopper 与外部世界的接口。80 GB 容量、3.35 TB/s 带宽——听起来很多,但被 132 SM 平摊后,每 SM 大约 25 GB/s。这就是为什么"减少 HBM 访问"是 LLM 算子的第一性原理。
4.5.1 Coalesced Access:访存的第一原则
HBM 的内存控制器以 transaction 为单位访问——每次最少读 32 字节。一个 warp(32 线程)发 32 个 4 字节读请求时:
- 如果 32 个请求是连续的 128 字节(地址 0..127):硬件合并成 1 次 128 字节 transaction,效率 100%。
- 如果 32 个请求散落在 4096 字节范围内(每个间隔 128 字节):硬件发出 32 次独立 transaction(每次只用 4 字节有效数据),效率 1/32。
flowchart LR
subgraph C [Coalesced]
direction LR
T1[32 线程读地址 0,4,8,...124] --> M1[1 次 128B transaction]
M1 --> Eff1[效率 100%]
end
subgraph U [Uncoalesced]
direction LR
T2[32 线程读地址 0,128,256,...3968] --> M2[32 次 32B transaction]
M2 --> Eff2[效率 12.5% 实际占用 1024B 总线 拿到 128B 有效数据]
end
LLM 算子里 coalesced 的常见保证方式:
- 行优先存储 + 沿最低维度展开线程:让 thread 0..31 访问相邻 32 个元素。
- vectorized load(float4 / int4):每个线程一次读 16 字节,4 个线程就读完一个 cache line,更易合并。
- TMA:硬件自动保证 coalesced,无需程序员关心。
4.5.2 Vectorized Load:一行代码的 4x
在 Ampere 之前没有 TMA 时,最常用的访存优化是 vectorized load:
// 标量 load: 4 次访存
float v0 = arr[i + 0];
float v1 = arr[i + 1];
float v2 = arr[i + 2];
float v3 = arr[i + 3];
// Vectorized load: 1 次访存读 16 字节
float4 v = *reinterpret_cast<float4*>(&arr[i]);
// 或:
float4 v = __ldg(reinterpret_cast<const float4*>(&arr[i]));
float4 在硬件层面对应一条 LDG.E.128 指令,单条指令读 16 字节。这有几个好处:
- 指令带宽减半甚至更多:原来 4 条指令变成 1 条。
- 更易触发 coalesced:8 个线程的 float4 = 128 字节,正好一个 cache line。
- 更高的内存级并行度:一条指令在 pipeline 里"占用"的资源更少。
但要注意 vectorized load 要求地址 16 字节对齐。如果输入指针没对齐,要么硬件 trap,要么慢路径处理。LLM 算子里常用 cudaMallocAlign 保证对齐。
4.5.3 一个真实案例:从 256 GB/s 到 2 TB/s
最能说明问题的是这样一个例子:对一个 long×long 的 INT4 量化矩阵做反量化(dequantize)。
// 朴素版本: 每线程读 1 个 INT4 (实际读 1 字节, 用其中 4 位)
int8_t packed = arr[tid];
int8_t a = packed & 0x0F;
int8_t b = (packed >> 4) & 0x0F;
out[tid * 2 + 0] = scale * (float)a;
out[tid * 2 + 1] = scale * (float)b;
// 测得带宽: ~256 GB/s
这版本看起来没毛病——线程是 coalesced 的。但带宽只有 256 GB/s(理论上限的 7%)。
问题在哪?每条指令处理的数据太少。读 1 字节、做 4 个算术、写 8 字节——指令带宽(每 SM 每周期 4 条)成为瓶颈。
优化版:
// Vectorized 版本: 每线程读 16 字节 = 32 个 INT4
uint4 packed = *reinterpret_cast<const uint4*>(&arr[tid * 16]);
// 在寄存器里展开 32 个 INT4
#pragma unroll
for (int i = 0; i < 32; ++i) {
int8_t v = (packed.x >> (i * 4)) & 0x0F; // (示意, 实际跨 uint4 字段)
out[tid * 32 + i] = scale * (float)v;
}
// 测得带宽: ~2.1 TB/s (62% of theoretical)
10 倍提升来自三处:
- 指令密度:每条 LDG 处理 16 字节而不是 1 字节。
- 更高 ILP:编译器可以把 32 次解包并行调度。
- 写入也 vectorized:每线程一次 STG 写 32×4=128 字节。
这个案例很典型——LLM 推理的 dequant kernel 几乎都长这个样。第 9 章会详细讲。
4.6 算术强度与 Roofline 模型
理解了内存层级,就可以引入一个关键的性能分析工具:Roofline 模型。
Arithmetic Intensity(算术强度)= 算术操作数 ÷ 内存访问字节数
对于一个 kernel:
- 如果算术强度低(比如 1 FLOP/byte),它一定被带宽限制。
- 如果算术强度高(比如 1000 FLOPs/byte),它有可能被算力限制。
- 临界点就是:算力峰值 ÷ 带宽峰值。Hopper FP16 是 989 TFLOPs / 3.35 TB/s ≈ 295 FLOPs/byte。
flowchart LR
X[算术强度 FLOPs/byte] --> Y[实际性能 TFLOPs]
subgraph Roofline
direction TB
BW[带宽墙: y = 3.35 × x]
COMP[算力墙: y = 989]
BW -.-> COMP
end
把不同 kernel 标在 roofline 上:
实际 TFLOPs/s
▲
989 ─────┼────────────────────── 算力天花板
│ ╱
│╱
│
│ GEMM (4096^3)
│ AI=1300, perf~989
│
500 ─┤
│
│
│ FA2 (long seq)
│ AI~80, perf~270
│ ╱
100 ─┤ ╱
│╱ FA2 (short seq)
│ AI~25, perf~84
│
│ Attention naive
│ AI~10, perf~33
│
└─────────────────────► 算术强度 (FLOPs/byte)
1 10 100 1000
带宽屋顶 (3.35 TB/s)
阈值 295 FLOPs/byte 来自 H100 FP16 算力 / 带宽 = 989/3.35。
Roofline 是 LLM 算子优化的"温度计"。看一个新算子,先估算它的算术强度,就能预估它是带宽 bound 还是算力 bound。
LLM 工作负载的算术强度分布:
| 算子 | 算术强度 | 类型 | 优化方向 |
|---|---|---|---|
| GEMM (M=N=K=4096) | ~1300 | 算力 bound | Tensor Core 利用 |
| GEMM (small batch) | ~10-50 | 带宽 bound | 减少 HBM 访问 |
| Attention (FA2, 长序列) | ~80 | 接近 bound | TMA + WGMMA 流水 |
| Attention (短序列) | ~25 | 带宽 bound | 减少 KV 重复读 |
| LayerNorm | ~1 | 带宽 bound | fused kernel |
| RMSNorm | ~1 | 带宽 bound | fused kernel |
| Softmax | ~3 | 带宽 bound | 减少 pass 数 |
| Element-wise | 0.25 | 严重带宽 bound | 算子融合 |
LLM 推理的一个隐藏事实:绝大多数算子都跑不到算力峰值,它们被带宽卡住。这是为什么 Tensor Core 989 TFLOPs 的算力大部分时间是闲置的——HBM 带宽追不上。减少 HBM 访问比"调用更厉害的 Tensor Core 指令"更重要。
4.7 一个完整案例:reduce 的访存优化
为了把这一章所有概念串起来,看一个具体例子:对 1 亿个 float 求和。
朴素版本:
__global__ void reduce_v1(const float* arr, float* out, int N) {
__shared__ float smem[256];
int tid = threadIdx.x;
int gid = blockIdx.x * 256 + tid;
smem[tid] = (gid < N) ? arr[gid] : 0.0f;
__syncthreads();
for (int s = 128; s > 0; s >>= 1) {
if (tid < s) smem[tid] += smem[tid + s];
__syncthreads();
}
if (tid == 0) atomicAdd(out, smem[0]);
}
// 实测 H100: ~1500 GB/s (45% of HBM 峰值)
带宽 1500 GB/s——离 3.35 TB/s 还很远。问题在哪?
- 每线程读 1 个 float(4 字节)。指令带宽是瓶颈。
atomicAdd在 HBM 上,所有 block 的 atomic 串行化。
优化版:
__global__ void reduce_v2(const float* arr, float* out, int N) {
__shared__ float smem[256];
int tid = threadIdx.x;
int gid = blockIdx.x * 256 * 4 + tid * 4;
// 1. Vectorized load: 每线程读 4 个 float
float4 v = (gid + 4 <= N) ?
*reinterpret_cast<const float4*>(&arr[gid]) :
make_float4(0, 0, 0, 0);
float local_sum = v.x + v.y + v.z + v.w;
// 2. Warp shuffle 归约 (无 SMEM, 无 sync)
for (int offset = 16; offset > 0; offset >>= 1) {
local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset);
}
// 3. Block 内汇总 warp partial sum (8 warps -> 1 block)
if (tid % 32 == 0) smem[tid / 32] = local_sum;
__syncthreads();
if (tid < 8) {
local_sum = smem[tid];
for (int offset = 4; offset > 0; offset >>= 1) {
local_sum += __shfl_xor_sync(0xFF, local_sum, offset);
}
if (tid == 0) atomicAdd(out, local_sum);
}
}
// 实测 H100: ~3050 GB/s (91% of HBM 峰值)
提升了 2 倍,从 45% 到 91%。改的是什么?
- vectorized load:每线程读 16 字节,cache line 利用率拉满。
- warp shuffle 替代 SMEM 归约:减少 SMEM 访问和 sync 次数。
- block 内 partial sum:减少 atomic 次数。
这个例子的细节会在第 5 章 reduce 那一章完整展开。但要看到的是:同样的算法,不同的访存策略,性能差 2 倍。这种差距在 LLM 算子里到处都是。
4.8 这一章的小结与下一章
这一章我们建立了 GPU 访存的精确认知:
- 五层内存的代价表:register (0c) → SMEM (25c) → L1 (30c) → L2 (250c) → HBM (800c)。每一层延迟差几倍到几十倍。
- Coalesced 访存是 HBM 的第一性原则:32 线程访问连续 128 字节 = 1 次 transaction,散落访问 = 32 次 transaction。
- SMEM 的 32 banks 与 bank conflict:访问同 bank 不同地址 → N-way conflict → 性能 1/N。padding 或 swizzle 可避免。
- Roofline 决定了一个 kernel 的上限:算术强度低 = 带宽 bound,高 = 算力 bound。LLM 大多算子在 100 以下,都是带宽 bound。
- Vectorized load 与 streaming store:用
float4、__stcs等可以显著提升带宽利用率。
第 5 章我们正式开始写 kernel 代码——以 reduce 为例,把第 1-4 章的所有概念落到具体的 kernel 实现上。读者会看到一个 reduce kernel 从 100 GB/s 一步步优化到 3 TB/s 的完整过程,这是后续 GEMM/FA2 优化的"压缩演练"。
本章动手练习:
- 写一个 kernel 故意制造 32-way bank conflict,再写一个无冲突版本,用 Nsight Compute 测 SMEM 带宽对比。
- 跑 NVIDIA 官方 cuda-samples 里的
bandwidthTest,记录你 H100 的实际 H2D / D2H / D2D 带宽。D2D 应该是 ~3000 GB/s。- 计算一个 7B 模型 (hidden_size=4096, num_layers=32) 推理时一次 forward pass 的总 HBM 访问字节数。验证一下"为什么大模型推理是带宽 bound"。