CUDA 算子工程:手写 FlashAttention v2 之路
第 5 章 Reduction:从 atomic 到 cluster reduce
第 5 章 Reduction:从 atomic 到 cluster reduce
"If you can write a fast reduction, you understand GPU programming." ——CUDA 社区的不成文谚语
5.1 为什么 Reduction 是入门第一题
Reduction(归约)就是把一组数压成一个数——求和、求最大、求最小、求积。听起来简单,但在 GPU 上写好它需要把第 1-4 章的几乎所有概念都用上:
- SIMT 与 warp:reduce 的核心是"32 个值合并成 1 个值"。
- SMEM 与 bank conflict:跨 warp 的 partial sum 通过 SMEM 交换。
- Warp Shuffle:warp 内的 reduce 不应该走 SMEM。
- Coalesced 访存:HBM 读取需要 coalesced。
- Vectorized load:用
float4提升指令带宽。 - Cluster Reduce(Hopper 新):跨 block 协作。
- 算术强度:reduce 是极端 bandwidth-bound(AI ≈ 0.25),优化目标是逼近 HBM 峰值带宽。
Mark Harris 在 2007 年发表过一份经典的 reduction 优化教程(NVIDIA Webinar),把朴素 reduce 一步步优化了 7 个版本。这一章我们沿着他的思路,再加上 Volta 之后的 warp shuffle、Hopper 的 cluster reduce,给读者一份"现代版 Mark Harris reduce"。
最终目标:对 1 亿个 float 求和,达到 3+ TB/s(91%+ HBM 峰值)。
5.2 V0:朴素 atomic 版
最直观的写法:每个线程读一个元素,atomic 加到全局结果上。
__global__ void reduce_v0(const float* arr, float* out, int N) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < N) {
atomicAdd(out, arr[tid]);
}
}
性能:~50 GB/s(H100, N=1e8)。这是带宽峰值的 1.5%——惨不忍睹。
为什么这么慢?因为 N 个线程同时 atomic add 到一个全局变量上,HBM 上的 atomic 操作完全串行化。N=1e8 时这是 1 亿次原子操作,相当于一个全局锁被 1 亿次抢锁。
教训:永远不要在 reduce 的最内层用全局 atomic。atomic 应该是分层归约的最后一步,且参与 atomic 的元素数量 << SM 数量。
5.3 V1:Block 内 SMEM 归约
把 reduce 拆成两步:先在 block 内把一组元素归约成一个数(写到 partial sum 数组),再启动第二个 kernel 把 partial sum 归约成最终结果。
__global__ void reduce_v1(const float* arr, float* partial, int N) {
__shared__ float smem[256];
int tid = threadIdx.x;
int gid = blockIdx.x * blockDim.x + tid;
smem[tid] = (gid < N) ? arr[gid] : 0.0f;
__syncthreads();
// Block 内分层归约
for (int s = 1; s < blockDim.x; s *= 2) {
if (tid % (2 * s) == 0) {
smem[tid] += smem[tid + s];
}
__syncthreads();
}
if (tid == 0) partial[blockIdx.x] = smem[0];
}
性能:~700 GB/s。比 v0 提升 14 倍——主要是不再全局 atomic 了。
但 700 GB/s 还是只有峰值的 21%。问题在哪?看这段代码的关键瓶颈:
if (tid % (2 * s) == 0) { ... }
这一行触发严重的 warp divergence:
- s=1 时:tid=0,2,4,6,...30 活跃(16 个),tid=1,3,5,...31 闲置 → 50% divergence。
- s=2 时:tid=0,4,8,...28 活跃(8 个),其余闲置 → 75% divergence。
- s=4 时:tid=0,8,16,24 活跃(4 个)。
- ...
- s=64 时:只有 tid=0 活跃。
虽然 warp 内"活跃 lane 数减少"不会直接降低带宽(带宽瓶颈在 HBM 读),但算术单元利用率下降会让整个 kernel 变长。
5.4 V2:避免 warp divergence
把分层方式改一下:让前 N/2 个线程做加法,避免奇偶交替。
__global__ void reduce_v2(const float* arr, float* partial, int N) {
__shared__ float smem[256];
int tid = threadIdx.x;
int gid = blockIdx.x * blockDim.x + tid;
smem[tid] = (gid < N) ? arr[gid] : 0.0f;
__syncthreads();
// 关键: 让前 s 个线程做加法
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) {
smem[tid] += smem[tid + s];
}
__syncthreads();
}
if (tid == 0) partial[blockIdx.x] = smem[0];
}
现在 warp divergence 大幅降低:
- s=128 时:tid=0..127 活跃(4 个完整 warp),tid=128..255 闲置(4 个完整 warp)。整个 warp 同进同退,无 warp 内 divergence。
- s=64 时:tid=0..63 活跃(2 个完整 warp)。
- s=32 时:tid=0..31 活跃(1 个完整 warp)。
- s=16 时:tid=0..15 活跃(半个 warp,开始有 divergence)。
只有最后几个迭代(s ≤ 16)有 divergence——这部分量级很小,影响不大。
性能:~1100 GB/s。从 21% 提升到 33%。
但还有个隐藏问题:SMEM bank conflict。看这一句:
smem[tid] += smem[tid + s];
s=1 时,tid 范围 0..127。读 smem[tid+1] 时:
- tid=0 读 smem[1] (bank 1)
- tid=1 读 smem[2] (bank 2)
- ...
- tid=31 读 smem[32] (bank 0)
虽然不是 32-way conflict,但读 smem[tid+1] 的访问模式和读 smem[tid] 重叠了——硬件需要两次发射。SMEM 的"sequential"读写本身有 conflict。
实际上 NVIDIA 编译器会自动 unroll 最后几次循环,处理这种边界。但写代码的时候应该意识到这种成本存在。
5.5 V3:每线程读多个元素
reduce 的算术强度只有 0.25 FLOPs/byte——意味着每读 1 字节做 0.25 次加法。如果每线程读多个元素,可以摊薄"读元素到 SMEM"的开销。
__global__ void reduce_v3(const float* arr, float* partial, int N) {
__shared__ float smem[256];
int tid = threadIdx.x;
int gid = blockIdx.x * (blockDim.x * 2) + tid;
// 每线程读 2 个元素, 直接相加
float v = 0.0f;
if (gid < N) v += arr[gid];
if (gid + blockDim.x < N) v += arr[gid + blockDim.x];
smem[tid] = v;
__syncthreads();
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) {
smem[tid] += smem[tid + s];
}
__syncthreads();
}
if (tid == 0) partial[blockIdx.x] = smem[0];
}
每个 block 现在处理 2 * blockDim.x 个元素而不是 blockDim.x 个,grid_size 减半。性能:~1500 GB/s(45%)。
如果每线程读 4 个元素:
float v = 0.0f;
v += arr[gid];
v += arr[gid + blockDim.x];
v += arr[gid + blockDim.x * 2];
v += arr[gid + blockDim.x * 3];
性能:~1800 GB/s(54%)。继续提升,但收益递减。
5.6 V4:Vectorized Load 用 float4
把"每线程读 4 个 float"换成"每线程读 1 个 float4"——这一行代码的改动让指令带宽减为 1/4:
__global__ void reduce_v4(const float* arr, float* partial, int N) {
__shared__ float smem[256];
int tid = threadIdx.x;
int gid = blockIdx.x * blockDim.x + tid;
// Vectorized: 一次读 16 字节
float4 v = (gid * 4 + 4 <= N) ?
*reinterpret_cast<const float4*>(&arr[gid * 4]) :
make_float4(0, 0, 0, 0);
float local = v.x + v.y + v.z + v.w;
smem[tid] = local;
__syncthreads();
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) {
smem[tid] += smem[tid + s];
}
__syncthreads();
}
if (tid == 0) partial[blockIdx.x] = smem[0];
}
性能:~2400 GB/s(72%)。一行代码(float4 替代 float)的改动带来 30% 的提升。
5.7 V5:Warp Shuffle 替代 SMEM 归约
到现在为止 block 内归约还在用 SMEM。但 warp 内的归约其实可以用 warp shuffle,无需 SMEM、无需同步:
__inline__ __device__ float warp_reduce(float val) {
for (int offset = 16; offset > 0; offset >>= 1) {
val += __shfl_xor_sync(0xFFFFFFFF, val, offset);
}
return val;
}
__global__ void reduce_v5(const float* arr, float* partial, int N) {
int tid = threadIdx.x;
int gid = blockIdx.x * blockDim.x + tid;
// 1. 每线程 vectorized load + 局部 sum
float4 v = (gid * 4 + 4 <= N) ?
*reinterpret_cast<const float4*>(&arr[gid * 4]) :
make_float4(0, 0, 0, 0);
float local = v.x + v.y + v.z + v.w;
// 2. Warp 内归约 (无 SMEM!)
local = warp_reduce(local);
// 3. 每个 warp 的 lane 0 写到 SMEM
__shared__ float warp_sums[8]; // blockDim.x / 32 = 8 warps
int warp_id = tid / 32;
int lane_id = tid % 32;
if (lane_id == 0) warp_sums[warp_id] = local;
__syncthreads();
// 4. Warp 0 归约 8 个 warp_sum
if (warp_id == 0) {
local = (lane_id < 8) ? warp_sums[lane_id] : 0.0f;
local = warp_reduce(local); // 实际只需要 3 次 shuffle (8 -> 1)
if (lane_id == 0) partial[blockIdx.x] = local;
}
}
关键变化:
- warp 内 5 次
__shfl_xor_sync替代了 5 次 SMEM 读+写+sync——约 5 倍快。 - block 内只有一次
__syncthreads(写 warp_sums 之后)。
性能:~2800 GB/s(84%)。
5.8 V6:移除 partial 数组,直接 atomic
V1-V5 都需要两个 kernel:第一个算 partial,第二个把 partial 数组归约成最终结果。两个 kernel 的开销和中间数组的 HBM 访问让性能受限。
如果每个 block 已经把自己的 partial sum 算到一个 float 了,这个数已经很小(几百到几千个)——这时候用 atomic 就没问题了:
__global__ void reduce_v6(const float* arr, float* out, int N) {
int tid = threadIdx.x;
int gid = blockIdx.x * blockDim.x + tid;
float4 v = (gid * 4 + 4 <= N) ?
*reinterpret_cast<const float4*>(&arr[gid * 4]) :
make_float4(0, 0, 0, 0);
float local = v.x + v.y + v.z + v.w;
local = warp_reduce(local);
__shared__ float warp_sums[8];
int warp_id = tid / 32;
int lane_id = tid % 32;
if (lane_id == 0) warp_sums[warp_id] = local;
__syncthreads();
if (warp_id == 0) {
local = (lane_id < 8) ? warp_sums[lane_id] : 0.0f;
local = warp_reduce(local);
// 关键: 不写中间数组, 直接 atomic
if (lane_id == 0) atomicAdd(out, local);
}
}
少了一次 kernel launch + 一次中间数组 HBM 访问。grid_size 大约 1e8 / 1024 ≈ 1e5,每个 block 一次 atomicAdd,atomic 数量从 v0 的 1e8 降到 1e5——降低 1000 倍,几乎不会成为瓶颈。
性能:~3050 GB/s(91%)。已经非常接近 HBM 峰值。
5.9 V7:Cluster Reduce(Hopper)
Hopper 的 Cluster 让我们可以进一步把 atomic 数量再降低一个数量级——把 cluster 内 16 个 block 的 partial sum 在分布式 SMEM 内汇总,再 atomic 出去。
#include <cooperative_groups.h>
namespace cg = cooperative_groups;
__global__ void __cluster_dims__(16, 1, 1)
reduce_v7(const float* arr, float* out, int N) {
auto cluster = cg::this_cluster();
auto block = cg::this_thread_block();
int tid = threadIdx.x;
int gid = (cluster.block_rank() + cluster.dim_blocks() * blockIdx.x)
* blockDim.x + tid;
// 1-3. 同 v6: vectorized load + warp reduce + block reduce
float4 v = (gid * 4 + 4 <= N) ?
*reinterpret_cast<const float4*>(&arr[gid * 4]) :
make_float4(0, 0, 0, 0);
float local = v.x + v.y + v.z + v.w;
local = warp_reduce(local);
__shared__ float block_sum;
__shared__ float warp_sums[8];
int warp_id = tid / 32;
int lane_id = tid % 32;
if (lane_id == 0) warp_sums[warp_id] = local;
__syncthreads();
if (warp_id == 0) {
local = (lane_id < 8) ? warp_sums[lane_id] : 0.0f;
local = warp_reduce(local);
if (lane_id == 0) block_sum = local;
}
__syncthreads();
// 4. Cluster 内汇总: block 0 收集所有 block 的 block_sum
cluster.sync(); // 等所有 block 都写完 block_sum
if (cluster.block_rank() == 0 && warp_id == 0) {
float total = 0.0f;
for (int b = lane_id; b < cluster.dim_blocks(); b += 32) {
float* peer = cluster.map_shared_rank(&block_sum, b);
total += *peer;
}
// warp 内归约这个 total
total = warp_reduce(total);
if (lane_id == 0) atomicAdd(out, total);
}
}
Cluster=16 时,每 16 个 block 共享一次 atomic,atomic 总数从 v6 的 ~1e5 降到 ~6e3——再降 16 倍。
性能:~3250 GB/s(97%)。已经几乎贴着 HBM 峰值。
5.10 性能对比与小结
把 7 个版本的性能放在一张表里:
| 版本 | 优化点 | 带宽 (GB/s) | % HBM 峰值 |
|---|---|---|---|
| v0 | 朴素 + 全局 atomic | 50 | 1.5% |
| v1 | Block 内 SMEM 归约 | 700 | 21% |
| v2 | 避免 warp divergence | 1100 | 33% |
| v3 | 每线程读 4 元素 | 1800 | 54% |
| v4 | float4 vectorized | 2400 | 72% |
| v5 | Warp shuffle | 2800 | 84% |
| v6 | 直接 atomic, 无中间数组 | 3050 | 91% |
| v7 | Cluster reduce (Hopper) | 3250 | 97% |
数字是 H100 SXM5、CUDA 12.3、N=1e8 的典型测量值;不同版本驱动 / 输入大小可能有 ±5% 偏差。
从 1.5% 到 97%——同一个算法,65 倍的性能差距。这就是 GPU 编程的现实:正确性容易,性能难。一个看似"小改动"(比如 vectorized load)背后藏着对硬件的深刻理解。
flowchart LR V0[v0 1.5%] --> V1[v1 21%] V1 --> V2[v2 33%] V2 --> V3[v3 54%] V3 --> V4[v4 72%] V4 --> V5[v5 84%] V5 --> V6[v6 91%] V6 --> V7[v7 97%] style V0 fill:#fee2e2 style V7 fill:#bbf7d0
5.11 这一章给我们的工程哲学
读到这里读者应该感受到一种"层层压榨"的工程哲学:
- 永远先看 Roofline:reduce 是带宽 bound,目标是逼近 3.35 TB/s 峰值。
- 从 atomic 开始往下挖:全局 atomic → 分层 atomic → cluster atomic。
- 从访存开始往下挖:单 float → vectorized → 多元素 unroll。
- 从同步开始往下挖:每步 sync → 每 warp 一次 sync → cluster sync。
- 从指令带宽开始往下挖:每条指令处理更多数据,减少指令总数。
这五条线索贯穿 LLM 算子的所有优化场景。GEMM、Softmax、LayerNorm、FA2——它们的优化思路都是这五条的某种组合。
特别值得记住的两个反直觉事实:
- 同样的 SMEM 读写次数下,warp shuffle 比 SMEM 快 5 倍——因为 shuffle 是寄存器到寄存器,SMEM 是寄存器→SMEM→寄存器。
- Cluster Reduce 不是为了 reduce 本身——它的真正价值在于"避免一次额外 kernel launch"。kernel launch 在 small data 时是隐藏的瓶颈。
第 6 章我们把这一套手艺用到 Softmax 上。Softmax 的核心也是 reduction(找 max + 求 exp sum),但比纯 reduce 多一个数值稳定性的问题——这正好是 Online Softmax 要解决的,也是 FA2 算法的灵魂。
本章动手练习:
- 在 H100 / A100 上把 v0..v7 都实现一遍,记录每个版本的实际带宽。
- 用 Nsight Compute 看 v3 vs v4 的
lts__t_sectors_op_read.sum.per_second指标差异。- 思考:如果 reduce 的不是 sum 而是 max,需要改哪些地方?为什么 max 的 atomic 没有 sum 的快?