CUDA 算子工程:手写 FlashAttention v2 之路
第 18 章 Persistent Kernel 与 Producer-Consumer
第 18 章 Persistent Kernel 与 Producer-Consumer
"A persistent kernel turns CUDA's block-based scheduling on its head: instead of launching enough blocks to fill the GPU, you launch exactly the right number — and let each block do many tiles' worth of work." ——Mark Harris's old NVIDIA blog (paraphrased)
18.1 经典 Kernel vs Persistent Kernel
到目前为止,本书所有 kernel 都是"经典"模式:
// 经典模式: grid_size = N_tiles, 每 block 一个 tile
dim3 grid(M_tiles, N_tiles);
my_kernel<<<grid, block>>>(...);
GPU 调度器把 M_tiles × N_tiles 个 block 自动分配到 SM 上。如果 block 数 >> SM 数,调度器自动队列、循环执行。
这个模式简单优雅,对大 tile 工作负载(比如训练时的大 GEMM)非常合适。但对小 tile 工作负载有几个隐藏问题:
- Block 启动开销:每个 block 启动有 ~100-200 cycles 的 dispatch + register init 开销。tile 多但小(比如 LLM decoding 时每个 batch_size=1, seq_len=1)时,启动开销占总时间的可观比例。
- L2 cache 不复用:每个 block 独立运行,相邻 block 的数据 L2 cache 命中无法保证。
- 调度延迟:block 数远超 SM 数时,调度器的全局调度有延迟。
Persistent Kernel 的思路:grid_size 固定为 SM 数(或其倍数),每个 block 通过 grid-stride loop 处理多个 tile:
// Persistent 模式
int n_sms = 132; // H100
dim3 grid(n_sms);
persistent_kernel<<<grid, block>>>(...);
__global__ void persistent_kernel(int total_tiles, ...) {
while (true) {
// 用 atomic counter 抢下一个 tile
int tile_id = atomicAdd(&global_counter, 1);
if (tile_id >= total_tiles) break;
process_tile(tile_id, ...);
}
}
每个 block 启动一次,活到所有 tile 都做完才退出。
18.2 Persistent Kernel 的优势
18.2.1 摊薄启动开销
如果总 tile 数 = N,传统 kernel 启动 N 个 block;persistent kernel 只启动 132 个(SM 数)。block 启动总开销从 N × 200 cycles 降到 132 × 200 cycles,N 大时几乎可以忽略。
例:N=10000 tiles, 节省启动开销 ≈ (10000 - 132) × 200 cycles = ~2 × 10^6 cycles ≈ 1ms(在 H100 1.83 GHz 下)。1ms 听起来不多,但 LLM 推理一次 forward 要做几百次 kernel launch,总开销几十 ms。
18.2.2 L2 Cache 复用
同一个 block 在执行多个 tile 时,SMEM/L1 上的数据可以被复用。比如 attention 中相邻 query block 共享 K/V tile,persistent 模式下可以把 K/V tile 缓存在 SMEM,避免重复 HBM 加载。
18.2.3 更友好的负载均衡
如果 tiles 之间的工作量不均衡(比如 sparse attention,某些 tile 几乎是空的),传统 kernel 中"轻 tile"的 block 早早完成、SM 空闲;"重 tile"的 block 还在算。Persistent kernel 中每个 block 通过 atomic counter 抢任务——自动负载均衡:完成快的 block 多抢几个,慢的少抢。
18.2.4 CUDA Graph 友好
LLM 推理的核心优化之一是 CUDA Graph——把多个 kernel launch 录成一个 graph,整体提交,省去重复 launch 开销。但 CUDA Graph 要求每次 launch 的 grid_size 一样。Persistent kernel 天然满足这一点(grid_size 永远 = SM 数),适配 CUDA Graph 极其顺滑。
18.3 Tile Scheduler 设计
Persistent kernel 的核心是 tile scheduler——决定哪个 block 处理哪些 tile。最简单的 scheduler 是 atomic counter:
__device__ int next_tile() {
__shared__ int s_tile;
if (threadIdx.x == 0) {
s_tile = atomicAdd(&global_counter, 1);
}
__syncthreads();
return s_tile;
}
但 atomic 有竞争开销。更高级的 scheduler:
18.3.1 Static Scheduler
在 host 端预先分配每个 block 处理哪些 tile:
// Host 端
int tiles_per_sm = (total_tiles + n_sms - 1) / n_sms;
// Kernel 端
__global__ void kernel(int tiles_per_sm, ...) {
int my_first = blockIdx.x * tiles_per_sm;
int my_last = min((blockIdx.x + 1) * tiles_per_sm, total_tiles);
for (int t = my_first; t < my_last; ++t) {
process_tile(t, ...);
}
}
简单、无竞争,但负载不均衡时差。
18.3.2 Round-Robin Scheduler
__global__ void kernel(...) {
for (int t = blockIdx.x; t < total_tiles; t += gridDim.x) {
process_tile(t, ...);
}
}
也是无竞争,且天然循环——所有 block 平均分担。CUTLASS 默认 scheduler。
18.3.3 Dynamic Atomic Scheduler
__global__ void kernel(...) {
while (true) {
int t = atomic_get_next_tile();
if (t >= total_tiles) break;
process_tile(t, ...);
}
}
最灵活,但 atomic 开销。tile 内工作量大时(GEMM 这种),atomic 开销可以忽略。
18.3.4 GPU-side Tile Scheduler with Optimization
CUTLASS 3.x 的 PersistentTileScheduler 把多种策略组合:
- 启动时 round-robin 分配第一组 tile(避免初始 atomic 竞争)。
- 之后用 atomic counter 处理剩余 tiles(动态平衡)。
- 支持 "stream-K":把一个大 tile 拆成多个小 tile 跨 SM 并行(解决长尾)。
18.4 Persistent + Producer/Consumer
把第 17 章的 Producer/Consumer 模式和 Persistent kernel 组合,给出 现代 attention kernel 的最终形态:
__global__ void persistent_fa3(...) {
// 1. Persistent loop: 抢 tile
while (true) {
int tile_id;
if (threadIdx.x == 0) {
tile_id = atomic_get_next_tile();
}
tile_id = __shfl_sync(0xFFFFFFFF, tile_id, 0);
if (tile_id >= total_tiles) break;
int (q_tile_idx, head_idx, batch_idx) = decode_tile_id(tile_id);
// 2. 在这个 tile 上跑 FA3 producer/consumer 流水
if (warp_id == 0) {
producer_main(q_tile_idx, head_idx, batch_idx, ...);
} else {
consumer_main(q_tile_idx, head_idx, batch_idx, ...);
}
// 3. (可选) 同步, 准备下一个 tile
__syncthreads();
}
}
这就是 vLLM、TensorRT-LLM、SGLang 中"高性能 attention kernel"的标准结构。
18.5 LLM Decoding 的特殊优化:Flash-Decoding
LLM 推理的 decoding 阶段有一个特殊形态:Q 只有 1 个 token(当前生成的 token),但 K/V 有几千甚至几十万个(已生成的所有历史 tokens)。
这个形态下:
- Q 只有 1 行,外层 Q 循环只有 1 次迭代。
- K/V 维度极长,内层 K 循环可能跑几百次。
- 单个 Q 的 attention 几乎全部分配给一个 block——只能用 1 个 SM。
- 其他 131 个 SM 闲置。
Flash-Decoding 的解决方案:把 K 维度也切分给多个 SM,每个 SM 算一段 K 的 partial 结果,最后合并。
# Flash-Decoding 伪代码
# Q: [1, d] (单 token)
# K, V: [N, d] (N 可达 100K)
n_splits = 8 # 把 K 维度切 8 份
chunk_size = N // n_splits
# Phase 1: 每个 SM 算自己那一段 K 的 partial
for split_id in range(n_splits): # 8 个 block 并行
K_chunk = K[split_id * chunk_size : (split_id + 1) * chunk_size]
V_chunk = V[split_id * chunk_size : (split_id + 1) * chunk_size]
S_chunk = Q @ K_chunk.T
P_chunk = softmax(S_chunk) # 局部 softmax
O_partial[split_id] = P_chunk @ V_chunk
LSE_partial[split_id] = lse(S_chunk)
# Phase 2: 合并 partial
final_lse = combine_lse(LSE_partial)
final_O = sum(O_partial[s] * exp(LSE_partial[s] - final_lse) for s in range(n_splits))
final_O /= sum(exp(LSE_partial[s] - final_lse) for s in range(n_splits))
Flash-Decoding 把单 Q 的 attention 从 1 个 SM 提升到 8-32 个 SM,解码延迟减半甚至更多。vLLM 0.4+ 默认启用。
18.6 Persistent Kernel 的代价
Persistent 模式不是免费的:
18.6.1 寄存器固定
Persistent kernel 的所有 tile 共享同一个寄存器布局。如果不同 tile 的最优寄存器需求不一样(比如 short context 的 attention 和 long context 的 attention),persistent 模式下只能取最大需求——可能浪费寄存器。
18.6.2 SMEM 复用要谨慎
不同 tile 之间复用 SMEM 听起来好,但需要小心 race condition——前一个 tile 还没写完 SMEM,后一个 tile 已经在读了。需要 __syncthreads 或 __threadfence_block 同步。
18.6.3 调度复杂度
简单的 round-robin scheduler 可能导致负载不均;动态 atomic scheduler 又有竞争开销。CUTLASS 提供的 stream-K scheduler 是个不错的折中,但实现复杂。
18.6.4 不适合所有工作负载
Persistent 适合大量小 tile 或有复用机会的工作。如果 tile 都是大 GEMM(比如训练),传统模式同样高效,persistent 没有优势,反而增加复杂度。
18.7 实战案例:vLLM 的 PagedAttention
vLLM 的 PagedAttention kernel 是 persistent + producer/consumer 的典型应用。其核心设计:
- Persistent 模式:grid_size = N_blocks * N_heads,每个 block 处理一个 (page, head) tuple。
- 每个 page 是 16 个 token 的 KV:page-aligned 访存。
- Atomic 累加:page 级别的 partial sum 通过 atomic 合并。
- CUDA Graph 友好:因为 grid_size 固定(与 batch 中的 token 总数无关),CUDA Graph 可以稳定捕获。
详细设计请参考《vLLM 内核探秘》第 4 章 PagedAttention。
18.8 第四篇收官:从理论到 SOTA
第四篇我们完成了 attention kernel 优化的完整旅程:
| 章节 | 主题 | 性能(H100, FP16, N=4096) |
|---|---|---|
| 第 14 章 | IO-Aware 思想 | 朴素 attention 40 TFLOPs |
| 第 15 章 | FA2 forward 骨架 | ~530 TFLOPs |
| 第 16 章 | FA2 backward | ~250 TFLOPs |
| 第 17 章 | TMA + WGMMA + Warp Spec (FA3) | ~740 TFLOPs (FP16) / 1200 TFLOPs (FP8) |
| 第 18 章 | Persistent + Producer/Consumer | + Flash-Decoding 长上下文加速 |
从 40 TFLOPs 到 1200 TFLOPs,30 倍性能跃迁。这就是 GPU 工程的极限——同一份算法,不同的实现,性能差几十倍。
到这里,读者已经具备了 LLM 推理 / 训练中所有核心 kernel(GEMM、Attention、LayerNorm、Softmax、量化)的优化能力。
第五篇(第 19-21 章)我们换个视角——讲性能工程的工具链:怎么用 Nsight Compute 找瓶颈、怎么读 PTX/SASS、常见性能反模式。这些工具是日常 kernel 调优中的瑞士军刀,没有它们,再好的优化思路也找不到落点。
本章动手练习:
- 把第 17 章的 FA3 forward 改成 persistent 模式,对比性能。
- 实现 Flash-Decoding(n_splits=8),在 N=64K 长上下文 decoding 上测试加速比。
- 阅读 vLLM 的
csrc/attention/attention_kernels.cu,找到 persistent loop 的代码位置。