CUDA 算子工程:手写 FlashAttention v2 之路

第 18 章 Persistent Kernel 与 Producer-Consumer

作者 杨艺韬 · 2,083 字

第 18 章 Persistent Kernel 与 Producer-Consumer

"A persistent kernel turns CUDA's block-based scheduling on its head: instead of launching enough blocks to fill the GPU, you launch exactly the right number — and let each block do many tiles' worth of work." ——Mark Harris's old NVIDIA blog (paraphrased)

18.1 经典 Kernel vs Persistent Kernel

到目前为止,本书所有 kernel 都是"经典"模式:

// 经典模式: grid_size = N_tiles, 每 block 一个 tile
dim3 grid(M_tiles, N_tiles);
my_kernel<<<grid, block>>>(...);

GPU 调度器把 M_tiles × N_tiles 个 block 自动分配到 SM 上。如果 block 数 >> SM 数,调度器自动队列、循环执行。

这个模式简单优雅,对大 tile 工作负载(比如训练时的大 GEMM)非常合适。但对小 tile 工作负载有几个隐藏问题:

  1. Block 启动开销:每个 block 启动有 ~100-200 cycles 的 dispatch + register init 开销。tile 多但小(比如 LLM decoding 时每个 batch_size=1, seq_len=1)时,启动开销占总时间的可观比例。
  2. L2 cache 不复用:每个 block 独立运行,相邻 block 的数据 L2 cache 命中无法保证。
  3. 调度延迟:block 数远超 SM 数时,调度器的全局调度有延迟。

Persistent Kernel 的思路:grid_size 固定为 SM 数(或其倍数),每个 block 通过 grid-stride loop 处理多个 tile:

// Persistent 模式
int n_sms = 132;  // H100
dim3 grid(n_sms);
persistent_kernel<<<grid, block>>>(...);

__global__ void persistent_kernel(int total_tiles, ...) {
    while (true) {
        // 用 atomic counter 抢下一个 tile
        int tile_id = atomicAdd(&global_counter, 1);
        if (tile_id >= total_tiles) break;

        process_tile(tile_id, ...);
    }
}

每个 block 启动一次,活到所有 tile 都做完才退出。

18.2 Persistent Kernel 的优势

18.2.1 摊薄启动开销

如果总 tile 数 = N,传统 kernel 启动 N 个 block;persistent kernel 只启动 132 个(SM 数)。block 启动总开销从 N × 200 cycles 降到 132 × 200 cycles,N 大时几乎可以忽略

例:N=10000 tiles, 节省启动开销 ≈ (10000 - 132) × 200 cycles = ~2 × 10^6 cycles ≈ 1ms(在 H100 1.83 GHz 下)。1ms 听起来不多,但 LLM 推理一次 forward 要做几百次 kernel launch,总开销几十 ms。

18.2.2 L2 Cache 复用

同一个 block 在执行多个 tile 时,SMEM/L1 上的数据可以被复用。比如 attention 中相邻 query block 共享 K/V tile,persistent 模式下可以把 K/V tile 缓存在 SMEM,避免重复 HBM 加载。

18.2.3 更友好的负载均衡

如果 tiles 之间的工作量不均衡(比如 sparse attention,某些 tile 几乎是空的),传统 kernel 中"轻 tile"的 block 早早完成、SM 空闲;"重 tile"的 block 还在算。Persistent kernel 中每个 block 通过 atomic counter 抢任务——自动负载均衡:完成快的 block 多抢几个,慢的少抢。

18.2.4 CUDA Graph 友好

LLM 推理的核心优化之一是 CUDA Graph——把多个 kernel launch 录成一个 graph,整体提交,省去重复 launch 开销。但 CUDA Graph 要求每次 launch 的 grid_size 一样。Persistent kernel 天然满足这一点(grid_size 永远 = SM 数),适配 CUDA Graph 极其顺滑。

18.3 Tile Scheduler 设计

Persistent kernel 的核心是 tile scheduler——决定哪个 block 处理哪些 tile。最简单的 scheduler 是 atomic counter:

__device__ int next_tile() {
    __shared__ int s_tile;
    if (threadIdx.x == 0) {
        s_tile = atomicAdd(&global_counter, 1);
    }
    __syncthreads();
    return s_tile;
}

但 atomic 有竞争开销。更高级的 scheduler:

18.3.1 Static Scheduler

在 host 端预先分配每个 block 处理哪些 tile:

// Host 端
int tiles_per_sm = (total_tiles + n_sms - 1) / n_sms;

// Kernel 端
__global__ void kernel(int tiles_per_sm, ...) {
    int my_first = blockIdx.x * tiles_per_sm;
    int my_last = min((blockIdx.x + 1) * tiles_per_sm, total_tiles);
    for (int t = my_first; t < my_last; ++t) {
        process_tile(t, ...);
    }
}

简单、无竞争,但负载不均衡时差。

18.3.2 Round-Robin Scheduler

__global__ void kernel(...) {
    for (int t = blockIdx.x; t < total_tiles; t += gridDim.x) {
        process_tile(t, ...);
    }
}

也是无竞争,且天然循环——所有 block 平均分担。CUTLASS 默认 scheduler。

18.3.3 Dynamic Atomic Scheduler

__global__ void kernel(...) {
    while (true) {
        int t = atomic_get_next_tile();
        if (t >= total_tiles) break;
        process_tile(t, ...);
    }
}

最灵活,但 atomic 开销。tile 内工作量大时(GEMM 这种),atomic 开销可以忽略。

18.3.4 GPU-side Tile Scheduler with Optimization

CUTLASS 3.x 的 PersistentTileScheduler 把多种策略组合:

18.4 Persistent + Producer/Consumer

把第 17 章的 Producer/Consumer 模式和 Persistent kernel 组合,给出 现代 attention kernel 的最终形态

__global__ void persistent_fa3(...) {
    // 1. Persistent loop: 抢 tile
    while (true) {
        int tile_id;
        if (threadIdx.x == 0) {
            tile_id = atomic_get_next_tile();
        }
        tile_id = __shfl_sync(0xFFFFFFFF, tile_id, 0);

        if (tile_id >= total_tiles) break;

        int (q_tile_idx, head_idx, batch_idx) = decode_tile_id(tile_id);

        // 2. 在这个 tile 上跑 FA3 producer/consumer 流水
        if (warp_id == 0) {
            producer_main(q_tile_idx, head_idx, batch_idx, ...);
        } else {
            consumer_main(q_tile_idx, head_idx, batch_idx, ...);
        }

        // 3. (可选) 同步, 准备下一个 tile
        __syncthreads();
    }
}

这就是 vLLM、TensorRT-LLM、SGLang 中"高性能 attention kernel"的标准结构。

18.5 LLM Decoding 的特殊优化:Flash-Decoding

LLM 推理的 decoding 阶段有一个特殊形态:Q 只有 1 个 token(当前生成的 token),但 K/V 有几千甚至几十万个(已生成的所有历史 tokens)。

这个形态下:

Flash-Decoding 的解决方案:把 K 维度也切分给多个 SM,每个 SM 算一段 K 的 partial 结果,最后合并。

# Flash-Decoding 伪代码
# Q: [1, d] (单 token)
# K, V: [N, d] (N 可达 100K)

n_splits = 8  # 把 K 维度切 8 份
chunk_size = N // n_splits

# Phase 1: 每个 SM 算自己那一段 K 的 partial
for split_id in range(n_splits):  # 8 个 block 并行
    K_chunk = K[split_id * chunk_size : (split_id + 1) * chunk_size]
    V_chunk = V[split_id * chunk_size : (split_id + 1) * chunk_size]
    S_chunk = Q @ K_chunk.T
    P_chunk = softmax(S_chunk)  # 局部 softmax
    O_partial[split_id] = P_chunk @ V_chunk
    LSE_partial[split_id] = lse(S_chunk)

# Phase 2: 合并 partial
final_lse = combine_lse(LSE_partial)
final_O = sum(O_partial[s] * exp(LSE_partial[s] - final_lse) for s in range(n_splits))
final_O /= sum(exp(LSE_partial[s] - final_lse) for s in range(n_splits))

Flash-Decoding 把单 Q 的 attention 从 1 个 SM 提升到 8-32 个 SM,解码延迟减半甚至更多。vLLM 0.4+ 默认启用。

18.6 Persistent Kernel 的代价

Persistent 模式不是免费的:

18.6.1 寄存器固定

Persistent kernel 的所有 tile 共享同一个寄存器布局。如果不同 tile 的最优寄存器需求不一样(比如 short context 的 attention 和 long context 的 attention),persistent 模式下只能取最大需求——可能浪费寄存器。

18.6.2 SMEM 复用要谨慎

不同 tile 之间复用 SMEM 听起来好,但需要小心 race condition——前一个 tile 还没写完 SMEM,后一个 tile 已经在读了。需要 __syncthreads__threadfence_block 同步。

18.6.3 调度复杂度

简单的 round-robin scheduler 可能导致负载不均;动态 atomic scheduler 又有竞争开销。CUTLASS 提供的 stream-K scheduler 是个不错的折中,但实现复杂。

18.6.4 不适合所有工作负载

Persistent 适合大量小 tile有复用机会的工作。如果 tile 都是大 GEMM(比如训练),传统模式同样高效,persistent 没有优势,反而增加复杂度。

18.7 实战案例:vLLM 的 PagedAttention

vLLM 的 PagedAttention kernel 是 persistent + producer/consumer 的典型应用。其核心设计:

  1. Persistent 模式:grid_size = N_blocks * N_heads,每个 block 处理一个 (page, head) tuple。
  2. 每个 page 是 16 个 token 的 KV:page-aligned 访存。
  3. Atomic 累加:page 级别的 partial sum 通过 atomic 合并。
  4. CUDA Graph 友好:因为 grid_size 固定(与 batch 中的 token 总数无关),CUDA Graph 可以稳定捕获。

详细设计请参考《vLLM 内核探秘》第 4 章 PagedAttention。

18.8 第四篇收官:从理论到 SOTA

第四篇我们完成了 attention kernel 优化的完整旅程:

章节 主题 性能(H100, FP16, N=4096)
第 14 章 IO-Aware 思想 朴素 attention 40 TFLOPs
第 15 章 FA2 forward 骨架 ~530 TFLOPs
第 16 章 FA2 backward ~250 TFLOPs
第 17 章 TMA + WGMMA + Warp Spec (FA3) ~740 TFLOPs (FP16) / 1200 TFLOPs (FP8)
第 18 章 Persistent + Producer/Consumer + Flash-Decoding 长上下文加速

从 40 TFLOPs 到 1200 TFLOPs,30 倍性能跃迁。这就是 GPU 工程的极限——同一份算法,不同的实现,性能差几十倍。

到这里,读者已经具备了 LLM 推理 / 训练中所有核心 kernel(GEMM、Attention、LayerNorm、Softmax、量化)的优化能力。

第五篇(第 19-21 章)我们换个视角——讲性能工程的工具链:怎么用 Nsight Compute 找瓶颈、怎么读 PTX/SASS、常见性能反模式。这些工具是日常 kernel 调优中的瑞士军刀,没有它们,再好的优化思路也找不到落点。

本章动手练习

  1. 把第 17 章的 FA3 forward 改成 persistent 模式,对比性能。
  2. 实现 Flash-Decoding(n_splits=8),在 N=64K 长上下文 decoding 上测试加速比。
  3. 阅读 vLLM 的 csrc/attention/attention_kernels.cu,找到 persistent loop 的代码位置。