CUDA 算子工程:手写 FlashAttention v2 之路
第 6 章 Softmax 与 Online Softmax
第 6 章 Softmax 与 Online Softmax
"The single most important mathematical trick that enables FlashAttention is online softmax — a way to compute softmax in a single pass, in O(1) memory." ——FlashAttention v1 论文(Dao et al., 2022)
6.1 Softmax 的标准定义
给定一个向量 ,softmax 定义为:
直接照定义计算需要两遍:
// 第一遍: 算分母 (求和)
float sum = 0;
for (int i = 0; i < N; ++i) sum += expf(x[i]);
// 第二遍: 算每个输出
for (int i = 0; i < N; ++i) y[i] = expf(x[i]) / sum;
但这段代码有一个致命问题:数值溢出。
6.2 数值稳定的 Safe Softmax
如果 中有比较大的数(比如 100),那 大约是 。FP32 的上限大约是 ——直接溢出成 +inf。如果 全是 1000, 直接溢出,所有项都变 inf,分子分母 inf/inf 输出 NaN。
LLM 训练和推理中,attention 的 logits(QK^T / sqrt(d))经常达到 ±10 甚至 ±50 的范围。BF16 / FP16 的指数范围更窄(FP16 上限 65504),溢出风险更高。
解决方法是经典的 safe softmax 技巧——同时减去最大值:
数学上完全等价(分子分母都乘以 抵消),但所有 ,绝不溢出。
但 safe softmax 现在需要 3 遍:
// Pass 1: 找 max
float m = -INFINITY;
for (int i = 0; i < N; ++i) m = fmaxf(m, x[i]);
// Pass 2: 算分母
float sum = 0;
for (int i = 0; i < N; ++i) sum += expf(x[i] - m);
// Pass 3: 算每个输出
for (int i = 0; i < N; ++i) y[i] = expf(x[i] - m) / sum;
3 遍意味着 x 数组要从 HBM 读 3 次。这是大问题——softmax 的算术强度本身就低(每个元素几次浮点操作 vs 4 字节读),3 遍变成了带宽的 3 倍消耗。
6.3 Online Softmax:1 遍数学
Online Softmax 的核心思想:在遍历数据的过程中同时维护当前的 max 和 sum,不需要预先知道全局 max。
6.3.1 数学推导
定义两个状态量:
- = 截至第 个元素时的最大值
- = 截至第 个元素的"调整后 exp 和"
初始状态 ,。
当看到新元素 时,更新规则:
关键的"修正项"是 :当 max 被更新( 时这个值小于 1),把之前累积的 sum 也"按比例缩小",保证它仍然是相对于新 max 的 sum。
把这个递推走完, 就是 safe softmax 的分母。然后再走一遍计算输出(这一遍可以和下游计算 fuse 在一起)。
如果是 attention 这种"sum 之后还要点积"的场景,可以做到真正的 1 pass——这是 FA 的核心。
6.3.2 验证一个简单例子
考虑 。
朴素方式:
Online:
| 步骤 | |||
|---|---|---|---|
| 初始 | - | 0 | |
| n=1 | 1 | 1 | |
| n=2 | 5 | 5 | |
| n=3 | 3 | 5 |
,和朴素方式得到的 完全一致。
6.3.3 推导的几何理解
为什么 是正确的修正?
设 。我们希望 。
把它拆开:
第二项可以改写:
所以:
干净的代数变换。这就是为什么 online softmax 数学上严格等价于 safe softmax——它只是把"先扫一遍找 max 再扫一遍累加"改写成"边扫边更新"。
6.4 Online Softmax 的 GPU 实现
把 online softmax 实现成 GPU kernel。基本模板:
__global__ void softmax_online(const float* in, float* out, int N) {
int tid = threadIdx.x;
extern __shared__ float smem[];
// ====== Phase 1: Online sweep (1 pass) ======
float m = -INFINITY; // 当前 max
float l = 0.0f; // 当前调整后 sum
for (int i = tid; i < N; i += blockDim.x) {
float x = in[i];
float m_new = fmaxf(m, x);
l = l * expf(m - m_new) + expf(x - m_new);
m = m_new;
}
// 此时每个 thread 持有局部的 (m, l)
// ====== Phase 2: Block-level reduce ======
// Warp 内合并 (m, l)
auto warp_combine = [](float& m, float& l, float m2, float l2) {
float m_new = fmaxf(m, m2);
l = l * expf(m - m_new) + l2 * expf(m2 - m_new);
m = m_new;
};
for (int offset = 16; offset > 0; offset >>= 1) {
float m2 = __shfl_xor_sync(0xFFFFFFFF, m, offset);
float l2 = __shfl_xor_sync(0xFFFFFFFF, l, offset);
warp_combine(m, l, m2, l2);
}
// 此时 warp 内所有 lane 都持有 warp 的 (m, l)
// 写入 SMEM, block 内再 reduce 一次
int warp_id = tid / 32;
int lane_id = tid % 32;
if (lane_id == 0) {
smem[warp_id * 2 + 0] = m;
smem[warp_id * 2 + 1] = l;
}
__syncthreads();
// Warp 0 收集 8 个 warp 的 (m, l)
if (warp_id == 0) {
m = (lane_id < 8) ? smem[lane_id * 2 + 0] : -INFINITY;
l = (lane_id < 8) ? smem[lane_id * 2 + 1] : 0.0f;
for (int offset = 4; offset > 0; offset >>= 1) {
float m2 = __shfl_xor_sync(0xFF, m, offset);
float l2 = __shfl_xor_sync(0xFF, l, offset);
warp_combine(m, l, m2, l2);
}
if (lane_id == 0) {
smem[0] = m;
smem[1] = l;
}
}
__syncthreads();
float final_m = smem[0];
float final_l = smem[1];
// ====== Phase 3: Normalize ======
for (int i = tid; i < N; i += blockDim.x) {
out[i] = expf(in[i] - final_m) / final_l;
}
}
关键点:
- Phase 1 每个 thread 维护自己的
(m, l)状态。 - Phase 2 用
warp_combine合并不同 thread 的(m, l)——这是 online softmax 的核心组合规则。 - Phase 3 用最终的全局
(m, l)做归一化。
这里 Phase 1 和 Phase 3 都需要读 in[i]——所以严格说还是 2 pass。但 Phase 1 不再需要 3 次单独的 max/sum/exp pass——它把 max 和 sum 合并到一遍里了。这对 SRAM 受限的 GEMM/attention 场景非常关键,因为可以让 in 数据只在 SMEM/寄存器里待一次。
6.4.1 Combine 函数的对称性
注意 warp_combine 函数:
auto warp_combine = [](float& m, float& l, float m2, float l2) {
float m_new = fmaxf(m, m2);
l = l * expf(m - m_new) + l2 * expf(m2 - m_new);
m = m_new;
};
这个函数是对称、可结合的:
这是 reduce 类操作的必要条件——没有这个性质,并行归约就不安全。Online softmax 的优雅之处就在于它的 combine 满足 monoid 性质。
6.5 LSE 形式:Log-Sum-Exp
attention 中常用 softmax 的 log 版本(log-softmax),定义为:
其中 是 log-sum-exp。
数值稳定的 LSE:
在 online 形式中,存 时实际上已经在算 LSE:
FA2 反向传播时需要存 LSE(一个标量/行)来支持梯度计算。FA2 前向只输出 LSE,不输出全部 attention 矩阵——这是它显存从 降到 的关键。
6.6 Online Softmax 应用到 FA:1 Pass 真的成立
FA 的核心创新就是把 attention 改写成 online 形式:
朴素 attention:
需要 3 个完整的 pass:
- 算 (HBM 写 N×N 矩阵)。
- 算 行级(HBM 写 N×N 矩阵)。
- 算 (HBM 写 N×d 矩阵)。
中间结果 和 都是 大小,对 N=4096 就是 64M 个 FP16 = 128 MB——HBM 带宽消耗巨大。
FA 用 online softmax 把它压成 1 pass:
# 伪代码
m = -inf, l = 0
O = zeros(d) # output accumulator
for j in range(0, N, B): # block-by-block
K_block = K[j:j+B]
V_block = V[j:j+B]
S_block = Q @ K_block.T # B 个分数
m_block = max(S_block) # block 局部 max
P_block = exp(S_block - m_block) # 局部 exp
l_block = sum(P_block)
# Combine 到 (m, l)
m_new = max(m, m_block)
alpha = exp(m - m_new)
beta = exp(m_block - m_new)
# 重要: 之前累积的 O 也要按 alpha 缩放
O = O * alpha + beta * (P_block @ V_block)
l = l * alpha + beta * l_block
m = m_new
O = O / l # 最终归一化
注意几点:
- m, l, O 三个状态量同步演化。每来一个 K/V block,都更新这三个。
- O 的缩放因子 来自 online softmax 的修正——之前累积的 attention 输出也要按新 max 缩放。
- 完全不写中间 S 或 P 矩阵到 HBM——只在寄存器/SMEM 中流过。
这就是 FA 的本质:用 online softmax 让 attention 变成可流式的算法。Q/K/V 可以按 tile 喂给 kernel,算完就丢,不需要存中间结果。
第 14-15 章会把这个伪代码落到具体的 CUDA kernel 上。
6.7 Softmax 的访存优化
回到 standalone softmax 的 kernel 实现。性能优化要点:
6.7.1 整行处理 vs 分块处理
LLM 中 softmax 是按行操作的(attention 的每一行独立 softmax)。两种 block 配置:
整行处理:一个 block 处理一行的所有 N 个元素。
适合: N <= 4096 (能放进 SMEM 或寄存器)
优势: 不需要跨 block 通信
分块处理:多个 block 协作处理一行。
适合: N 极大 (比如长序列 attention 的中间矩阵)
代价: 需要两阶段 + atomic 或 cluster reduce
绝大多数 LLM 场景下 N ≤ 8192,整行处理更优。
6.7.2 Vectorized I/O
和 reduce 类似,softmax 的访存也应该 vectorized:
// 每线程读 4 个 fp16 = 8 字节
half2 v0 = *reinterpret_cast<const half2*>(&in[i]);
half2 v1 = *reinterpret_cast<const half2*>(&in[i + 2]);
或者用 fp16 的 int4 (16 字节 = 8 个 fp16):
int4 packed = *reinterpret_cast<const int4*>(&in[i * 8]);
half h[8];
memcpy(h, &packed, 16);
6.7.3 Fused Softmax + Mask
attention 中 softmax 之前通常有 mask(causal mask、padding mask)。fused 写法:
float x = in[i] + mask[i]; // mask 通常是 0 或 -INFINITY
m = fmaxf(m, x);
把 mask add 直接 fuse 到 softmax 的 max/sum 阶段,避免单独走一遍 mask kernel。
6.7.4 PyTorch 实测对比
PyTorch 在 H100 上 softmax 性能(N×N 矩阵,N=4096):
| 实现 | 带宽 (GB/s) | % HBM 峰值 |
|---|---|---|
| 朴素 3-pass | ~700 | 21% |
| Fused safe softmax | ~2200 | 66% |
| Online softmax + vectorized | ~2900 | 87% |
数字是估算量级,PyTorch 不同版本/输入大小会有变化。
87% 接近峰值——这是 standalone softmax 的合理极限。FA 把 softmax 嵌入 attention 后,因为减少了 HBM 写中间矩阵,整体性能提升远不止 4×。
6.8 这一章给我们的"内核数学"
Online softmax 是这本书第一次正式接触"用算法重写让 GPU 友好"的思路。它带给读者两个核心 insight:
-
数学等价不等于 GPU 等价。同一个 softmax 公式有 3-pass 和 1-pass 两种数学等价的算法,但在 GPU 上性能差几倍。算法重写和"低层 kernel 优化"是 GPU 性能工程的两个独立维度,online softmax 是前者的典范。
-
Streaming(流式)算法在 GPU 上是黄金。能够用单 pass 维护少量状态量来累积结果的算法,天然适合 GPU——因为它把"中间结果"留在寄存器/SMEM,避免 HBM 往返。Online softmax、Welford 在线方差(下一章)、prefix sum——这些都是流式算法的代表。
第 7 章我们把 online 思路用到另一个 LLM 高频算子上:LayerNorm 与 RMSNorm。LayerNorm 需要算均值和方差——传统两遍写法(先算均值,再算方差)有数值稳定问题,Welford 算法(数学上和 online softmax 同源)可以一遍同时算出均值和方差。读完第 7 章,读者会发现"online 思维"在 LLM 算子里几乎无处不在。
本章动手练习:
- 实现一个 N=4096 的整行 softmax kernel,先用 3-pass,再改成 online,对比性能。
- 用 PyTorch 的
torch.nn.functional.softmax跑一遍,用 Nsight Compute 看它实际用的是 PyTorch 内部哪个 kernel(可能是at::native::softmax_warp_forward),读它的源码。- 思考:如果 softmax 的 N 极大(比如 N=1M),整行处理放不进单 block 的 SMEM,怎么用 online softmax + cluster reduce 解决?