CUDA 算子工程:手写 FlashAttention v2 之路

第 5 章 Reduction:从 atomic 到 cluster reduce

作者 杨艺韬 · 2,686 字

第 5 章 Reduction:从 atomic 到 cluster reduce

"If you can write a fast reduction, you understand GPU programming." ——CUDA 社区的不成文谚语

5.1 为什么 Reduction 是入门第一题

Reduction(归约)就是把一组数压成一个数——求和、求最大、求最小、求积。听起来简单,但在 GPU 上写好它需要把第 1-4 章的几乎所有概念都用上:

Mark Harris 在 2007 年发表过一份经典的 reduction 优化教程(NVIDIA Webinar),把朴素 reduce 一步步优化了 7 个版本。这一章我们沿着他的思路,再加上 Volta 之后的 warp shuffle、Hopper 的 cluster reduce,给读者一份"现代版 Mark Harris reduce"。

最终目标:对 1 亿个 float 求和,达到 3+ TB/s(91%+ HBM 峰值)

5.2 V0:朴素 atomic 版

最直观的写法:每个线程读一个元素,atomic 加到全局结果上。

__global__ void reduce_v0(const float* arr, float* out, int N) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < N) {
        atomicAdd(out, arr[tid]);
    }
}

性能:~50 GB/s(H100, N=1e8)。这是带宽峰值的 1.5%——惨不忍睹。

为什么这么慢?因为 N 个线程同时 atomic add 到一个全局变量上,HBM 上的 atomic 操作完全串行化。N=1e8 时这是 1 亿次原子操作,相当于一个全局锁被 1 亿次抢锁。

教训:永远不要在 reduce 的最内层用全局 atomic。atomic 应该是分层归约的最后一步,且参与 atomic 的元素数量 << SM 数量。

5.3 V1:Block 内 SMEM 归约

把 reduce 拆成两步:先在 block 内把一组元素归约成一个数(写到 partial sum 数组),再启动第二个 kernel 把 partial sum 归约成最终结果。

__global__ void reduce_v1(const float* arr, float* partial, int N) {
    __shared__ float smem[256];
    int tid = threadIdx.x;
    int gid = blockIdx.x * blockDim.x + tid;
    smem[tid] = (gid < N) ? arr[gid] : 0.0f;
    __syncthreads();

    // Block 内分层归约
    for (int s = 1; s < blockDim.x; s *= 2) {
        if (tid % (2 * s) == 0) {
            smem[tid] += smem[tid + s];
        }
        __syncthreads();
    }

    if (tid == 0) partial[blockIdx.x] = smem[0];
}

性能:~700 GB/s。比 v0 提升 14 倍——主要是不再全局 atomic 了。

但 700 GB/s 还是只有峰值的 21%。问题在哪?看这段代码的关键瓶颈:

if (tid % (2 * s) == 0) { ... }

这一行触发严重的 warp divergence

虽然 warp 内"活跃 lane 数减少"不会直接降低带宽(带宽瓶颈在 HBM 读),但算术单元利用率下降会让整个 kernel 变长。

5.4 V2:避免 warp divergence

把分层方式改一下:让前 N/2 个线程做加法,避免奇偶交替

__global__ void reduce_v2(const float* arr, float* partial, int N) {
    __shared__ float smem[256];
    int tid = threadIdx.x;
    int gid = blockIdx.x * blockDim.x + tid;
    smem[tid] = (gid < N) ? arr[gid] : 0.0f;
    __syncthreads();

    // 关键: 让前 s 个线程做加法
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            smem[tid] += smem[tid + s];
        }
        __syncthreads();
    }

    if (tid == 0) partial[blockIdx.x] = smem[0];
}

现在 warp divergence 大幅降低:

只有最后几个迭代(s ≤ 16)有 divergence——这部分量级很小,影响不大。

性能:~1100 GB/s。从 21% 提升到 33%。

但还有个隐藏问题:SMEM bank conflict。看这一句:

smem[tid] += smem[tid + s];

s=1 时,tid 范围 0..127。读 smem[tid+1] 时:

虽然不是 32-way conflict,但读 smem[tid+1] 的访问模式和读 smem[tid] 重叠了——硬件需要两次发射。SMEM 的"sequential"读写本身有 conflict。

实际上 NVIDIA 编译器会自动 unroll 最后几次循环,处理这种边界。但写代码的时候应该意识到这种成本存在。

5.5 V3:每线程读多个元素

reduce 的算术强度只有 0.25 FLOPs/byte——意味着每读 1 字节做 0.25 次加法。如果每线程读多个元素,可以摊薄"读元素到 SMEM"的开销。

__global__ void reduce_v3(const float* arr, float* partial, int N) {
    __shared__ float smem[256];
    int tid = threadIdx.x;
    int gid = blockIdx.x * (blockDim.x * 2) + tid;

    // 每线程读 2 个元素, 直接相加
    float v = 0.0f;
    if (gid < N) v += arr[gid];
    if (gid + blockDim.x < N) v += arr[gid + blockDim.x];
    smem[tid] = v;
    __syncthreads();

    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            smem[tid] += smem[tid + s];
        }
        __syncthreads();
    }

    if (tid == 0) partial[blockIdx.x] = smem[0];
}

每个 block 现在处理 2 * blockDim.x 个元素而不是 blockDim.x 个,grid_size 减半。性能:~1500 GB/s(45%)。

如果每线程读 4 个元素:

float v = 0.0f;
v += arr[gid];
v += arr[gid + blockDim.x];
v += arr[gid + blockDim.x * 2];
v += arr[gid + blockDim.x * 3];

性能:~1800 GB/s(54%)。继续提升,但收益递减。

5.6 V4:Vectorized Load 用 float4

把"每线程读 4 个 float"换成"每线程读 1 个 float4"——这一行代码的改动让指令带宽减为 1/4:

__global__ void reduce_v4(const float* arr, float* partial, int N) {
    __shared__ float smem[256];
    int tid = threadIdx.x;
    int gid = blockIdx.x * blockDim.x + tid;

    // Vectorized: 一次读 16 字节
    float4 v = (gid * 4 + 4 <= N) ?
        *reinterpret_cast<const float4*>(&arr[gid * 4]) :
        make_float4(0, 0, 0, 0);
    float local = v.x + v.y + v.z + v.w;
    smem[tid] = local;
    __syncthreads();

    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            smem[tid] += smem[tid + s];
        }
        __syncthreads();
    }

    if (tid == 0) partial[blockIdx.x] = smem[0];
}

性能:~2400 GB/s(72%)。一行代码(float4 替代 float)的改动带来 30% 的提升。

5.7 V5:Warp Shuffle 替代 SMEM 归约

到现在为止 block 内归约还在用 SMEM。但 warp 内的归约其实可以用 warp shuffle,无需 SMEM、无需同步

__inline__ __device__ float warp_reduce(float val) {
    for (int offset = 16; offset > 0; offset >>= 1) {
        val += __shfl_xor_sync(0xFFFFFFFF, val, offset);
    }
    return val;
}

__global__ void reduce_v5(const float* arr, float* partial, int N) {
    int tid = threadIdx.x;
    int gid = blockIdx.x * blockDim.x + tid;

    // 1. 每线程 vectorized load + 局部 sum
    float4 v = (gid * 4 + 4 <= N) ?
        *reinterpret_cast<const float4*>(&arr[gid * 4]) :
        make_float4(0, 0, 0, 0);
    float local = v.x + v.y + v.z + v.w;

    // 2. Warp 内归约 (无 SMEM!)
    local = warp_reduce(local);

    // 3. 每个 warp 的 lane 0 写到 SMEM
    __shared__ float warp_sums[8];  // blockDim.x / 32 = 8 warps
    int warp_id = tid / 32;
    int lane_id = tid % 32;
    if (lane_id == 0) warp_sums[warp_id] = local;
    __syncthreads();

    // 4. Warp 0 归约 8 个 warp_sum
    if (warp_id == 0) {
        local = (lane_id < 8) ? warp_sums[lane_id] : 0.0f;
        local = warp_reduce(local);  // 实际只需要 3 次 shuffle (8 -> 1)
        if (lane_id == 0) partial[blockIdx.x] = local;
    }
}

关键变化:

  1. warp 内 5 次 __shfl_xor_sync 替代了 5 次 SMEM 读+写+sync——约 5 倍快。
  2. block 内只有一次 __syncthreads(写 warp_sums 之后)。

性能:~2800 GB/s(84%)。

5.8 V6:移除 partial 数组,直接 atomic

V1-V5 都需要两个 kernel:第一个算 partial,第二个把 partial 数组归约成最终结果。两个 kernel 的开销和中间数组的 HBM 访问让性能受限。

如果每个 block 已经把自己的 partial sum 算到一个 float 了,这个数已经很小(几百到几千个)——这时候用 atomic 就没问题了

__global__ void reduce_v6(const float* arr, float* out, int N) {
    int tid = threadIdx.x;
    int gid = blockIdx.x * blockDim.x + tid;

    float4 v = (gid * 4 + 4 <= N) ?
        *reinterpret_cast<const float4*>(&arr[gid * 4]) :
        make_float4(0, 0, 0, 0);
    float local = v.x + v.y + v.z + v.w;
    local = warp_reduce(local);

    __shared__ float warp_sums[8];
    int warp_id = tid / 32;
    int lane_id = tid % 32;
    if (lane_id == 0) warp_sums[warp_id] = local;
    __syncthreads();

    if (warp_id == 0) {
        local = (lane_id < 8) ? warp_sums[lane_id] : 0.0f;
        local = warp_reduce(local);
        // 关键: 不写中间数组, 直接 atomic
        if (lane_id == 0) atomicAdd(out, local);
    }
}

少了一次 kernel launch + 一次中间数组 HBM 访问。grid_size 大约 1e8 / 1024 ≈ 1e5,每个 block 一次 atomicAdd,atomic 数量从 v0 的 1e8 降到 1e5——降低 1000 倍,几乎不会成为瓶颈。

性能:~3050 GB/s(91%)。已经非常接近 HBM 峰值。

5.9 V7:Cluster Reduce(Hopper)

Hopper 的 Cluster 让我们可以进一步把 atomic 数量再降低一个数量级——把 cluster 内 16 个 block 的 partial sum 在分布式 SMEM 内汇总,再 atomic 出去。

#include <cooperative_groups.h>
namespace cg = cooperative_groups;

__global__ void __cluster_dims__(16, 1, 1)
reduce_v7(const float* arr, float* out, int N) {
    auto cluster = cg::this_cluster();
    auto block = cg::this_thread_block();
    int tid = threadIdx.x;
    int gid = (cluster.block_rank() + cluster.dim_blocks() * blockIdx.x)
              * blockDim.x + tid;

    // 1-3. 同 v6: vectorized load + warp reduce + block reduce
    float4 v = (gid * 4 + 4 <= N) ?
        *reinterpret_cast<const float4*>(&arr[gid * 4]) :
        make_float4(0, 0, 0, 0);
    float local = v.x + v.y + v.z + v.w;
    local = warp_reduce(local);

    __shared__ float block_sum;
    __shared__ float warp_sums[8];
    int warp_id = tid / 32;
    int lane_id = tid % 32;
    if (lane_id == 0) warp_sums[warp_id] = local;
    __syncthreads();
    if (warp_id == 0) {
        local = (lane_id < 8) ? warp_sums[lane_id] : 0.0f;
        local = warp_reduce(local);
        if (lane_id == 0) block_sum = local;
    }
    __syncthreads();

    // 4. Cluster 内汇总: block 0 收集所有 block 的 block_sum
    cluster.sync();  // 等所有 block 都写完 block_sum
    if (cluster.block_rank() == 0 && warp_id == 0) {
        float total = 0.0f;
        for (int b = lane_id; b < cluster.dim_blocks(); b += 32) {
            float* peer = cluster.map_shared_rank(&block_sum, b);
            total += *peer;
        }
        // warp 内归约这个 total
        total = warp_reduce(total);
        if (lane_id == 0) atomicAdd(out, total);
    }
}

Cluster=16 时,每 16 个 block 共享一次 atomic,atomic 总数从 v6 的 ~1e5 降到 ~6e3——再降 16 倍。

性能:~3250 GB/s(97%)。已经几乎贴着 HBM 峰值。

5.10 性能对比与小结

把 7 个版本的性能放在一张表里:

版本 优化点 带宽 (GB/s) % HBM 峰值
v0 朴素 + 全局 atomic 50 1.5%
v1 Block 内 SMEM 归约 700 21%
v2 避免 warp divergence 1100 33%
v3 每线程读 4 元素 1800 54%
v4 float4 vectorized 2400 72%
v5 Warp shuffle 2800 84%
v6 直接 atomic, 无中间数组 3050 91%
v7 Cluster reduce (Hopper) 3250 97%

数字是 H100 SXM5、CUDA 12.3、N=1e8 的典型测量值;不同版本驱动 / 输入大小可能有 ±5% 偏差。

从 1.5% 到 97%——同一个算法,65 倍的性能差距。这就是 GPU 编程的现实:正确性容易,性能难。一个看似"小改动"(比如 vectorized load)背后藏着对硬件的深刻理解。

flowchart LR
  V0[v0 1.5%] --> V1[v1 21%]
  V1 --> V2[v2 33%]
  V2 --> V3[v3 54%]
  V3 --> V4[v4 72%]
  V4 --> V5[v5 84%]
  V5 --> V6[v6 91%]
  V6 --> V7[v7 97%]
  style V0 fill:#fee2e2
  style V7 fill:#bbf7d0

5.11 这一章给我们的工程哲学

读到这里读者应该感受到一种"层层压榨"的工程哲学:

  1. 永远先看 Roofline:reduce 是带宽 bound,目标是逼近 3.35 TB/s 峰值。
  2. 从 atomic 开始往下挖:全局 atomic → 分层 atomic → cluster atomic。
  3. 从访存开始往下挖:单 float → vectorized → 多元素 unroll。
  4. 从同步开始往下挖:每步 sync → 每 warp 一次 sync → cluster sync。
  5. 从指令带宽开始往下挖:每条指令处理更多数据,减少指令总数。

这五条线索贯穿 LLM 算子的所有优化场景。GEMM、Softmax、LayerNorm、FA2——它们的优化思路都是这五条的某种组合。

特别值得记住的两个反直觉事实

第 6 章我们把这一套手艺用到 Softmax 上。Softmax 的核心也是 reduction(找 max + 求 exp sum),但比纯 reduce 多一个数值稳定性的问题——这正好是 Online Softmax 要解决的,也是 FA2 算法的灵魂。

本章动手练习

  1. 在 H100 / A100 上把 v0..v7 都实现一遍,记录每个版本的实际带宽。
  2. 用 Nsight Compute 看 v3 vs v4 的 lts__t_sectors_op_read.sum.per_second 指标差异。
  3. 思考:如果 reduce 的不是 sum 而是 max,需要改哪些地方?为什么 max 的 atomic 没有 sum 的快?