CUDA 算子工程:手写 FlashAttention v2 之路

第 6 章 Softmax 与 Online Softmax

作者 杨艺韬 · 2,779 字

第 6 章 Softmax 与 Online Softmax

"The single most important mathematical trick that enables FlashAttention is online softmax — a way to compute softmax in a single pass, in O(1) memory." ——FlashAttention v1 论文(Dao et al., 2022)

6.1 Softmax 的标准定义

给定一个向量 x=(x1,x2,,xN)\mathbf{x} = (x_1, x_2, \ldots, x_N),softmax 定义为:

softmax(xi)=exij=1Nexj\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{N} e^{x_j}}

直接照定义计算需要两遍:

// 第一遍: 算分母 (求和)
float sum = 0;
for (int i = 0; i < N; ++i) sum += expf(x[i]);

// 第二遍: 算每个输出
for (int i = 0; i < N; ++i) y[i] = expf(x[i]) / sum;

但这段代码有一个致命问题:数值溢出

6.2 数值稳定的 Safe Softmax

如果 xix_i 中有比较大的数(比如 100),那 e100e^{100} 大约是 2.7×10432.7 \times 10^{43}。FP32 的上限大约是 3.4×10383.4 \times 10^{38}——直接溢出成 +inf。如果 xix_i 全是 1000,e1000e^{1000} 直接溢出,所有项都变 inf,分子分母 inf/inf 输出 NaN

LLM 训练和推理中,attention 的 logits(QK^T / sqrt(d))经常达到 ±10 甚至 ±50 的范围。BF16 / FP16 的指数范围更窄(FP16 上限 65504),溢出风险更高。

解决方法是经典的 safe softmax 技巧——同时减去最大值:

softmax(xi)=eximj=1Nexjm,m=maxjxj\text{softmax}(x_i) = \frac{e^{x_i - m}}{\sum_{j=1}^{N} e^{x_j - m}}, \quad m = \max_j x_j

数学上完全等价(分子分母都乘以 eme^{-m} 抵消),但所有 exim1e^{x_i - m} \le 1,绝不溢出。

但 safe softmax 现在需要 3 遍

// Pass 1: 找 max
float m = -INFINITY;
for (int i = 0; i < N; ++i) m = fmaxf(m, x[i]);

// Pass 2: 算分母
float sum = 0;
for (int i = 0; i < N; ++i) sum += expf(x[i] - m);

// Pass 3: 算每个输出
for (int i = 0; i < N; ++i) y[i] = expf(x[i] - m) / sum;

3 遍意味着 x 数组要从 HBM 读 3 次。这是大问题——softmax 的算术强度本身就低(每个元素几次浮点操作 vs 4 字节读),3 遍变成了带宽的 3 倍消耗。

6.3 Online Softmax:1 遍数学

Online Softmax 的核心思想:在遍历数据的过程中同时维护当前的 max 和 sum,不需要预先知道全局 max。

6.3.1 数学推导

定义两个状态量:

初始状态 m0=m_0 = -\infty0=0\ell_0 = 0

当看到新元素 xn+1x_{n+1} 时,更新规则:

mn+1=max(mn,xn+1)n+1=nemnmn+1+exn+1mn+1\begin{aligned} m_{n+1} &= \max(m_n, x_{n+1}) \\ \ell_{n+1} &= \ell_n \cdot e^{m_n - m_{n+1}} + e^{x_{n+1} - m_{n+1}} \end{aligned}

关键的"修正项"是 emnmn+1e^{m_n - m_{n+1}}:当 max 被更新(mn+1>mnm_{n+1} > m_n 时这个值小于 1),把之前累积的 sum 也"按比例缩小",保证它仍然是相对于新 max 的 sum。

把这个递推走完,N\ell_N 就是 safe softmax 的分母。然后再走一遍计算输出(这一遍可以和下游计算 fuse 在一起)。

如果是 attention 这种"sum 之后还要点积"的场景,可以做到真正的 1 pass——这是 FA 的核心。

6.3.2 验证一个简单例子

考虑 x=(1,5,3)\mathbf{x} = (1, 5, 3)

朴素方式

Online:

步骤 xnx_n mnm_n n\ell_n
初始 - -\infty 0
n=1 1 1 0e1+e0=10 \cdot e^{-\infty - 1} + e^{0} = 1
n=2 5 5 1e15+e0=e4+11.01831 \cdot e^{1-5} + e^{0} = e^{-4} + 1 \approx 1.0183
n=3 3 5 1.0183e55+e35=1.0183+e21.15361.0183 \cdot e^{5-5} + e^{3-5} = 1.0183 + e^{-2} \approx 1.1536

3=1.1536\ell_3 = 1.1536,和朴素方式得到的 \sum 完全一致。

6.3.3 推导的几何理解

为什么 nemnmn+1\ell_n \cdot e^{m_n - m_{n+1}} 是正确的修正?

mn+1=max(mn,xn+1)m_{n+1} = \max(m_n, x_{n+1})。我们希望 n+1=j=1n+1exjmn+1\ell_{n+1} = \sum_{j=1}^{n+1} e^{x_j - m_{n+1}}

把它拆开:

j=1n+1exjmn+1=exn+1mn+1+j=1nexjmn+1\sum_{j=1}^{n+1} e^{x_j - m_{n+1}} = e^{x_{n+1} - m_{n+1}} + \sum_{j=1}^{n} e^{x_j - m_{n+1}}

第二项可以改写:

j=1nexjmn+1=j=1nexjmnemnmn+1=nemnmn+1\sum_{j=1}^{n} e^{x_j - m_{n+1}} = \sum_{j=1}^{n} e^{x_j - m_n} \cdot e^{m_n - m_{n+1}} = \ell_n \cdot e^{m_n - m_{n+1}}

所以:

n+1=nemnmn+1+exn+1mn+1\ell_{n+1} = \ell_n \cdot e^{m_n - m_{n+1}} + e^{x_{n+1} - m_{n+1}}

干净的代数变换。这就是为什么 online softmax 数学上严格等价于 safe softmax——它只是把"先扫一遍找 max 再扫一遍累加"改写成"边扫边更新"。

6.4 Online Softmax 的 GPU 实现

把 online softmax 实现成 GPU kernel。基本模板:

__global__ void softmax_online(const float* in, float* out, int N) {
    int tid = threadIdx.x;
    extern __shared__ float smem[];

    // ====== Phase 1: Online sweep (1 pass) ======
    float m = -INFINITY;  // 当前 max
    float l = 0.0f;       // 当前调整后 sum

    for (int i = tid; i < N; i += blockDim.x) {
        float x = in[i];
        float m_new = fmaxf(m, x);
        l = l * expf(m - m_new) + expf(x - m_new);
        m = m_new;
    }
    // 此时每个 thread 持有局部的 (m, l)

    // ====== Phase 2: Block-level reduce ======
    // Warp 内合并 (m, l)
    auto warp_combine = [](float& m, float& l, float m2, float l2) {
        float m_new = fmaxf(m, m2);
        l = l * expf(m - m_new) + l2 * expf(m2 - m_new);
        m = m_new;
    };

    for (int offset = 16; offset > 0; offset >>= 1) {
        float m2 = __shfl_xor_sync(0xFFFFFFFF, m, offset);
        float l2 = __shfl_xor_sync(0xFFFFFFFF, l, offset);
        warp_combine(m, l, m2, l2);
    }
    // 此时 warp 内所有 lane 都持有 warp 的 (m, l)

    // 写入 SMEM, block 内再 reduce 一次
    int warp_id = tid / 32;
    int lane_id = tid % 32;
    if (lane_id == 0) {
        smem[warp_id * 2 + 0] = m;
        smem[warp_id * 2 + 1] = l;
    }
    __syncthreads();

    // Warp 0 收集 8 个 warp 的 (m, l)
    if (warp_id == 0) {
        m = (lane_id < 8) ? smem[lane_id * 2 + 0] : -INFINITY;
        l = (lane_id < 8) ? smem[lane_id * 2 + 1] : 0.0f;
        for (int offset = 4; offset > 0; offset >>= 1) {
            float m2 = __shfl_xor_sync(0xFF, m, offset);
            float l2 = __shfl_xor_sync(0xFF, l, offset);
            warp_combine(m, l, m2, l2);
        }
        if (lane_id == 0) {
            smem[0] = m;
            smem[1] = l;
        }
    }
    __syncthreads();

    float final_m = smem[0];
    float final_l = smem[1];

    // ====== Phase 3: Normalize ======
    for (int i = tid; i < N; i += blockDim.x) {
        out[i] = expf(in[i] - final_m) / final_l;
    }
}

关键点:

  1. Phase 1 每个 thread 维护自己的 (m, l) 状态。
  2. Phase 2warp_combine 合并不同 thread 的 (m, l)——这是 online softmax 的核心组合规则。
  3. Phase 3 用最终的全局 (m, l) 做归一化。

这里 Phase 1 和 Phase 3 都需要读 in[i]——所以严格说还是 2 pass。但 Phase 1 不再需要 3 次单独的 max/sum/exp pass——它把 max 和 sum 合并到一遍里了。这对 SRAM 受限的 GEMM/attention 场景非常关键,因为可以让 in 数据只在 SMEM/寄存器里待一次

6.4.1 Combine 函数的对称性

注意 warp_combine 函数:

auto warp_combine = [](float& m, float& l, float m2, float l2) {
    float m_new = fmaxf(m, m2);
    l = l * expf(m - m_new) + l2 * expf(m2 - m_new);
    m = m_new;
};

这个函数是对称、可结合的:

这是 reduce 类操作的必要条件——没有这个性质,并行归约就不安全。Online softmax 的优雅之处就在于它的 combine 满足 monoid 性质。

6.5 LSE 形式:Log-Sum-Exp

attention 中常用 softmax 的 log 版本(log-softmax),定义为:

logsoftmax(xi)=xilogjexj=xiLSE(x)\log \text{softmax}(x_i) = x_i - \log\sum_{j} e^{x_j} = x_i - \text{LSE}(\mathbf{x})

其中 LSE(x)=logjexj\text{LSE}(\mathbf{x}) = \log \sum_j e^{x_j} 是 log-sum-exp。

数值稳定的 LSE:

LSE(x)=m+logjexjm,m=maxjxj\text{LSE}(\mathbf{x}) = m + \log \sum_j e^{x_j - m}, \quad m = \max_j x_j

在 online 形式中,存 (mn,n)(m_n, \ell_n) 时实际上已经在算 LSE:

LSE=m+log\text{LSE} = m + \log \ell

FA2 反向传播时需要存 LSE(一个标量/行)来支持梯度计算。FA2 前向只输出 LSE,不输出全部 attention 矩阵——这是它显存从 O(N2)O(N^2) 降到 O(N)O(N) 的关键。

6.6 Online Softmax 应用到 FA:1 Pass 真的成立

FA 的核心创新就是把 attention 改写成 online 形式:

朴素 attention:

O=softmax(QKT)VO = \text{softmax}(QK^T) \cdot V

需要 3 个完整的 pass:

  1. S=QKTS = QK^T(HBM 写 N×N 矩阵)。
  2. P=softmax(S)P = \text{softmax}(S) 行级(HBM 写 N×N 矩阵)。
  3. O=PVO = PV(HBM 写 N×d 矩阵)。

中间结果 SSPP 都是 O(N2)O(N^2) 大小,对 N=4096 就是 64M 个 FP16 = 128 MB——HBM 带宽消耗巨大。

FA 用 online softmax 把它压成 1 pass

# 伪代码
m = -inf, l = 0
O = zeros(d)  # output accumulator

for j in range(0, N, B):  # block-by-block
    K_block = K[j:j+B]
    V_block = V[j:j+B]
    S_block = Q @ K_block.T              # B 个分数
    m_block = max(S_block)               # block 局部 max
    P_block = exp(S_block - m_block)     # 局部 exp
    l_block = sum(P_block)

    # Combine 到 (m, l)
    m_new = max(m, m_block)
    alpha = exp(m - m_new)
    beta = exp(m_block - m_new)

    # 重要: 之前累积的 O 也要按 alpha 缩放
    O = O * alpha + beta * (P_block @ V_block)
    l = l * alpha + beta * l_block
    m = m_new

O = O / l  # 最终归一化

注意几点:

  1. m, l, O 三个状态量同步演化。每来一个 K/V block,都更新这三个。
  2. O 的缩放因子 α=emmnew\alpha = e^{m - m_{\text{new}}} 来自 online softmax 的修正——之前累积的 attention 输出也要按新 max 缩放。
  3. 完全不写中间 S 或 P 矩阵到 HBM——只在寄存器/SMEM 中流过。

这就是 FA 的本质:用 online softmax 让 attention 变成可流式的算法。Q/K/V 可以按 tile 喂给 kernel,算完就丢,不需要存中间结果。

第 14-15 章会把这个伪代码落到具体的 CUDA kernel 上。

6.7 Softmax 的访存优化

回到 standalone softmax 的 kernel 实现。性能优化要点:

6.7.1 整行处理 vs 分块处理

LLM 中 softmax 是按行操作的(attention 的每一行独立 softmax)。两种 block 配置:

整行处理:一个 block 处理一行的所有 N 个元素。

适合: N <= 4096 (能放进 SMEM 或寄存器)
优势: 不需要跨 block 通信

分块处理:多个 block 协作处理一行。

适合: N 极大 (比如长序列 attention 的中间矩阵)
代价: 需要两阶段 + atomic 或 cluster reduce

绝大多数 LLM 场景下 N ≤ 8192,整行处理更优。

6.7.2 Vectorized I/O

和 reduce 类似,softmax 的访存也应该 vectorized:

// 每线程读 4 个 fp16 = 8 字节
half2 v0 = *reinterpret_cast<const half2*>(&in[i]);
half2 v1 = *reinterpret_cast<const half2*>(&in[i + 2]);

或者用 fp16 的 int4 (16 字节 = 8 个 fp16):

int4 packed = *reinterpret_cast<const int4*>(&in[i * 8]);
half h[8];
memcpy(h, &packed, 16);

6.7.3 Fused Softmax + Mask

attention 中 softmax 之前通常有 mask(causal mask、padding mask)。fused 写法:

float x = in[i] + mask[i];  // mask 通常是 0 或 -INFINITY
m = fmaxf(m, x);

把 mask add 直接 fuse 到 softmax 的 max/sum 阶段,避免单独走一遍 mask kernel。

6.7.4 PyTorch 实测对比

PyTorch 在 H100 上 softmax 性能(N×N 矩阵,N=4096):

实现 带宽 (GB/s) % HBM 峰值
朴素 3-pass ~700 21%
Fused safe softmax ~2200 66%
Online softmax + vectorized ~2900 87%

数字是估算量级,PyTorch 不同版本/输入大小会有变化。

87% 接近峰值——这是 standalone softmax 的合理极限。FA 把 softmax 嵌入 attention 后,因为减少了 HBM 写中间矩阵,整体性能提升远不止 4×。

6.8 这一章给我们的"内核数学"

Online softmax 是这本书第一次正式接触"用算法重写让 GPU 友好"的思路。它带给读者两个核心 insight:

  1. 数学等价不等于 GPU 等价。同一个 softmax 公式有 3-pass 和 1-pass 两种数学等价的算法,但在 GPU 上性能差几倍。算法重写和"低层 kernel 优化"是 GPU 性能工程的两个独立维度,online softmax 是前者的典范。

  2. Streaming(流式)算法在 GPU 上是黄金。能够用单 pass 维护少量状态量来累积结果的算法,天然适合 GPU——因为它把"中间结果"留在寄存器/SMEM,避免 HBM 往返。Online softmax、Welford 在线方差(下一章)、prefix sum——这些都是流式算法的代表。

第 7 章我们把 online 思路用到另一个 LLM 高频算子上:LayerNorm 与 RMSNorm。LayerNorm 需要算均值和方差——传统两遍写法(先算均值,再算方差)有数值稳定问题,Welford 算法(数学上和 online softmax 同源)可以一遍同时算出均值和方差。读完第 7 章,读者会发现"online 思维"在 LLM 算子里几乎无处不在。

本章动手练习

  1. 实现一个 N=4096 的整行 softmax kernel,先用 3-pass,再改成 online,对比性能。
  2. 用 PyTorch 的 torch.nn.functional.softmax 跑一遍,用 Nsight Compute 看它实际用的是 PyTorch 内部哪个 kernel(可能是 at::native::softmax_warp_forward),读它的源码。
  3. 思考:如果 softmax 的 N 极大(比如 N=1M),整行处理放不进单 block 的 SMEM,怎么用 online softmax + cluster reduce 解决?