CUDA 算子工程:手写 FlashAttention v2 之路
第 7 章 LayerNorm 与 RMSNorm
第 7 章 LayerNorm 与 RMSNorm
"RMSNorm is to LayerNorm what online softmax is to safe softmax — a simplification that removes one pass of bookkeeping while preserving the essential function." ——常见于 LLM 优化讨论
7.1 为什么要归一化
Transformer 里每个 block 都套着一层归一化:Pre-LN 或 Post-LN。它的作用是把每个 token 的隐藏维度(hidden_size,4096 / 8192 / ...)的激活值"拉到合适的尺度",防止训练时数值爆炸或塌陷。
LayerNorm 的标准定义:
其中:
- 是均值
- 是方差
- 是可学习的缩放与偏置(每个特征维独立)
- 是数值稳定的小量(典型 )
每次 attention 之前一次 LayerNorm,每次 FFN 之前一次 LayerNorm——一个 32 层的 Transformer,每次 forward 要做 64+ 次 LayerNorm。这是 LLM 推理中除 GEMM 和 attention 之外最频繁的算子。
7.2 朴素两遍算法
最直观的写法:
// Pass 1: 算 mean
float mean = 0;
for (int i = 0; i < H; ++i) mean += x[i];
mean /= H;
// Pass 2: 算 variance
float var = 0;
for (int i = 0; i < H; ++i) var += (x[i] - mean) * (x[i] - mean);
var /= H;
// Pass 3: 归一化
float rstd = rsqrtf(var + eps);
for (int i = 0; i < H; ++i)
y[i] = (x[i] - mean) * rstd * gamma[i] + beta[i];
3 pass,每 pass 都遍历 x 一次。在 GPU 上意味着 x 数据从 HBM 读 3 次——和 safe softmax 一样的痛点。
7.3 一个数学上的等价:方差的两种公式
数学上有这样一个等式:
这意味着方差可以用"和"与"平方和"两个量计算:
如果用这个公式,可以用一遍同时收集 和 :
// 1 pass: 算 sum 和 sum_sq
float sum = 0, sum_sq = 0;
for (int i = 0; i < H; ++i) {
sum += x[i];
sum_sq += x[i] * x[i];
}
float mean = sum / H;
float var = sum_sq / H - mean * mean;
// Pass 2: 归一化 (这一遍不可避免, 因为输出依赖 mean/var)
float rstd = rsqrtf(var + eps);
for (int i = 0; i < H; ++i)
y[i] = (x[i] - mean) * rstd * gamma[i] + beta[i];
成功——从 3 pass 降到 2 pass!但有一个致命的数值稳定问题:
这个公式叫naive 方差公式,它在数值上非常不稳定。当 的均值很大、方差很小时, 和 会接近相等,相减时会发生灾难性消除(catastrophic cancellation),损失大量有效数字。
举个具体例子:
x = [1000.001, 1000.002, 1000.003, ..., 1000.010] (10 个数)
真实 mean = 1000.0055
真实 var = 0.000033 (大约)
朴素公式:
sum = 10000.055
sum_sq = 100001100.11 (大约)
mean = 1000.0055
var = (100001100.11 / 10) - 1000.0055^2
= 10000110.011 - 10000110.0110030...
= 0.0... (FP32 算这个会得到几乎随机的结果, 甚至负数)
这就是为什么真实场景下不能用这个公式——输入的 magnitude 一旦大了,FP32 精度根本不够。
LLM 训练中,激活值的 magnitude 经常超过 100。如果用朴素方差公式做 LayerNorm,方差经常会算成负数(开方变 NaN)或者完全错误的小数。
7.4 Welford 算法:数值稳定的一遍方差
1962 年统计学家 B. P. Welford 提出了一个数值稳定的在线方差算法。它的核心是维护"当前均值"和"M2(偏差平方和)"两个状态,用增量更新:
定义:
- = 已处理元素数量
- = 前 个元素的均值
- = 偏差平方和
初始 , , 。每来一个新元素 :
最后方差 。
Welford 在数值上是稳定的——它不像朴素公式那样减两个相近的大数,而是不断累积小的偏差量。FP32 下也能保持精度。
7.4.1 Welford 的并行合并规则
Welford 的妙处是它有一个可结合的合并规则。如果两个独立计算的 partial Welford 状态 和 要合并:
这个合并是对称、可结合的。和 online softmax 一样,它构成了一个 monoid——可以用并行 reduce 来计算。
7.4.2 Welford 的 GPU kernel 形式
struct WelfordState {
int n;
float mean;
float m2;
};
__device__ WelfordState welford_update(WelfordState s, float x) {
s.n += 1;
float delta = x - s.mean;
s.mean += delta / s.n;
s.m2 += delta * (x - s.mean);
return s;
}
__device__ WelfordState welford_combine(WelfordState a, WelfordState b) {
int n = a.n + b.n;
if (n == 0) return {0, 0, 0};
float delta = b.mean - a.mean;
float new_mean = a.mean + delta * b.n / n;
float new_m2 = a.m2 + b.m2 + delta * delta * a.n * b.n / n;
return {n, new_mean, new_m2};
}
把这两个函数用 warp shuffle 串起来:
__device__ WelfordState warp_welford_reduce(WelfordState s) {
for (int offset = 16; offset > 0; offset >>= 1) {
WelfordState other;
other.n = __shfl_xor_sync(0xFFFFFFFF, s.n, offset);
other.mean = __shfl_xor_sync(0xFFFFFFFF, s.mean, offset);
other.m2 = __shfl_xor_sync(0xFFFFFFFF, s.m2, offset);
s = welford_combine(s, other);
}
return s;
}
这就是 Apex FusedLayerNorm 内部的核心。第 7.6 节会给完整的 kernel。
7.5 RMSNorm:进一步简化
2019 年 Zhang & Sennrich 在论文 Root Mean Square Layer Normalization 中提出了 RMSNorm——一个比 LayerNorm 更简单的归一化方案。LLaMA、Mistral、Qwen、Gemma 等几乎所有现代开源 LLM 都用 RMSNorm。
RMSNorm 的定义:
对比 LayerNorm,RMSNorm 去掉了:
- 均值减法(不再需要 )
- 均值统计量(不再需要 )
- bias (只保留 )
为什么这样改?论文给出的理由是"重新中心化(减均值)对模型表现影响很小,但会增加计算"。后来的实证研究(特别是 LLaMA 的论文)也确认了这一点——大规模 LLM 训练中,RMSNorm 的最终效果和 LayerNorm 几乎一样,但训练速度快 7-10%。
7.5.1 RMSNorm 的算法只需要 1 个状态量
LayerNorm 的 Welford 需要维护 (n, mean, m2) 三个状态。RMSNorm 只需要维护平方和:
float ss = 0;
for (int i = tid; i < H; i += blockDim.x) {
float v = x[i];
ss += v * v;
}
// ss reduce 跨 block
float rms = rsqrtf(ss / H + eps);
for (int i = tid; i < H; i += blockDim.x) {
y[i] = x[i] * rms * gamma[i];
}
这个 kernel 比 LayerNorm 简单得多——纯标量 reduce,无需 Welford 那套增量更新。
7.5.2 性能对比
H100 上 H=8192 的 LayerNorm vs RMSNorm(per-token):
| 算子 | 浮点操作数(每 token) | 实测带宽 (GB/s) | 实测延迟 (μs) |
|---|---|---|---|
| LayerNorm (Welford) | ~6H | ~2700 | ~6.5 |
| RMSNorm | ~3H | ~3100 | ~5.5 |
带宽差 15%,延迟差 ~15%。考虑到一个 32 层 LLM 推理要做 64 次归一化,这个差距累积起来不算小。
7.6 完整的 Fused LayerNorm Kernel
把上面的元素拼起来,给一个完整的、达到 90%+ HBM 峰值的 LayerNorm kernel:
template <int BLOCK_SIZE = 512, int VEC_SIZE = 4>
__global__ void layernorm_fwd(
const float* __restrict__ x, // [B, H]
const float* __restrict__ gamma, // [H]
const float* __restrict__ beta, // [H]
float* __restrict__ y, // [B, H]
float* __restrict__ mean_out, // [B]
float* __restrict__ rstd_out, // [B]
int H,
float eps
) {
int row = blockIdx.x;
int tid = threadIdx.x;
const float* x_row = x + row * H;
float* y_row = y + row * H;
// ============ Phase 1: Welford reduce ============
WelfordState state = {0, 0.0f, 0.0f};
// 每线程读 VEC_SIZE 个元素
for (int i = tid * VEC_SIZE; i < H; i += BLOCK_SIZE * VEC_SIZE) {
float4 v = *reinterpret_cast<const float4*>(&x_row[i]);
state = welford_update(state, v.x);
state = welford_update(state, v.y);
state = welford_update(state, v.z);
state = welford_update(state, v.w);
}
// Warp 内合并
state = warp_welford_reduce(state);
// Block 内合并 (跨 warp)
__shared__ WelfordState warp_states[BLOCK_SIZE / 32];
int warp_id = tid / 32;
int lane_id = tid % 32;
if (lane_id == 0) warp_states[warp_id] = state;
__syncthreads();
if (warp_id == 0) {
if (lane_id < BLOCK_SIZE / 32) state = warp_states[lane_id];
else state = {0, 0.0f, 0.0f};
for (int offset = 16; offset > 0; offset >>= 1) {
WelfordState other;
other.n = __shfl_xor_sync(0xFFFFFFFF, state.n, offset);
other.mean = __shfl_xor_sync(0xFFFFFFFF, state.mean, offset);
other.m2 = __shfl_xor_sync(0xFFFFFFFF, state.m2, offset);
state = welford_combine(state, other);
}
}
__shared__ float final_mean, final_rstd;
if (warp_id == 0 && lane_id == 0) {
final_mean = state.mean;
float var = state.m2 / state.n;
final_rstd = rsqrtf(var + eps);
if (mean_out) mean_out[row] = final_mean;
if (rstd_out) rstd_out[row] = final_rstd;
}
__syncthreads();
// ============ Phase 2: Normalize + scale ============
for (int i = tid * VEC_SIZE; i < H; i += BLOCK_SIZE * VEC_SIZE) {
float4 xv = *reinterpret_cast<const float4*>(&x_row[i]);
float4 gv = *reinterpret_cast<const float4*>(&gamma[i]);
float4 bv = *reinterpret_cast<const float4*>(&beta[i]);
float4 yv;
yv.x = (xv.x - final_mean) * final_rstd * gv.x + bv.x;
yv.y = (xv.y - final_mean) * final_rstd * gv.y + bv.y;
yv.z = (xv.z - final_mean) * final_rstd * gv.z + bv.z;
yv.w = (xv.w - final_mean) * final_rstd * gv.w + bv.w;
*reinterpret_cast<float4*>(&y_row[i]) = yv;
}
}
关键点:
- 每个 block 处理一行:grid_size = batch_size。
- VEC_SIZE=4 用 float4 vectorized I/O:减少指令带宽。
- Welford state 在 warp 内、block 内归约:用 shfl 和 SMEM。
- 存 mean/rstd 给反向用:反向 LN 需要这两个量。
性能(H100, B=8192, H=8192, FP32):
- 实测带宽:~3050 GB/s(91% HBM 峰值)
- vs PyTorch 默认 LayerNorm:~2400 GB/s(72%)
- vs Apex
FusedLayerNorm:~2900 GB/s(87%)
7.7 反向:LayerNorm 的两个公式
LayerNorm 反向比前向复杂得多。给定 ,需要算:
其中 , 是 的平均。
这个反向公式有两个 reduce: 和 。可以一遍同时算两个 reduce——典型的双状态 Welford 结构。
完整反向 kernel 比较长,本书不展开。读者可以参考 Apex 的 layer_norm_cuda_kernel.cu,那是工业级的参考实现。
RMSNorm 反向比 LayerNorm 简单得多——只需要一个 reduce:。这是 RMSNorm 在训练中的另一个加速点。
7.8 与 PyTorch / Apex 实现对照
PyTorch 的 LayerNorm 实现在 aten/src/ATen/native/cuda/layer_norm_kernel.cu。它的核心是一个叫 RowwiseMomentsCUDAKernel 的 kernel,结构和上面的 layernorm_fwd 类似——也是 per-row、Welford reduce、shuffle 合并。
Apex 的 FusedLayerNorm 在 apex/csrc/layer_norm_cuda_kernel.cu,做了几个额外优化:
- 支持 mixed precision:输入 fp16 / bf16,中间计算 fp32(避免精度丢失)。
- 更细致的 Welford 状态压缩:把
n信息编码到mean的低位,省一个寄存器。 - 更激进的 vectorized I/O:根据 H 大小自动选择 vec_size。
Apex 整体性能比 PyTorch 默认快 ~20%,但代码复杂度高很多。绝大多数情况下 PyTorch 默认实现已经够用——除非你训练超大模型且 LayerNorm 占 forward 时间 5% 以上。
7.9 工程权衡:Pre-LN vs Post-LN
最后顺便说一下 LayerNorm 的两种位置安排,因为它影响 kernel 调度:
flowchart LR
subgraph PostLN [Post-LN · 原始 Transformer]
P1[x] --> P2[Attention]
P2 --> P3[+ x]
P3 --> P4[LayerNorm]
P4 --> P5[FFN]
P5 --> P6[+ ...]
P6 --> P7[LayerNorm]
end
subgraph PreLN [Pre-LN · 现代 LLM]
L1[x] --> L2[LayerNorm]
L2 --> L3[Attention]
L3 --> L4[+ x]
L4 --> L5[LayerNorm]
L5 --> L6[FFN]
L6 --> L7[+ ...]
end
Post-LN(原始 2017 Transformer):LayerNorm 在残差之后。训练困难(梯度消失/爆炸),现在很少用。
Pre-LN(GPT-2/LLaMA/几乎所有现代 LLM):LayerNorm 在残差之前。训练稳定,但有"残差累积"问题(不严重)。
工程上 Pre-LN 还有一个kernel fusion 优势:Pre-LN + 接下来的 GEMM 可以 fuse 在一起(LN 的输出直接喂给 GEMM 的 SMEM),避免一次 HBM 写。这是为什么 NVIDIA 的 TransformerEngine 和 vLLM 都喜欢做 Pre-LN+GEMM 的 fusion。
7.10 这一章的小结与下一章
这一章的关键收获:
- 方差有两个公式:朴素 数值不稳定;Welford 增量公式数值稳定。
- Welford 算法和 online softmax 同源——都是用一个可结合 monoid 的 combine 规则把多 pass 改成 1 pass。
- RMSNorm 是 LayerNorm 的简化:去掉均值统计,只算 RMS。性能提升 15-20%,效果几乎无损。这就是为什么所有现代 LLM 都用 RMSNorm。
- Fused LayerNorm 可以达到 90%+ HBM 带宽——这是带宽 bound kernel 的合理目标。
- Pre-LN + GEMM Fusion 是工业级推理引擎的常见优化。
第 8 章我们继续往工程化深入——讲 Element-wise 算子的融合。LLM 推理里有大量"小算子"(add、mul、ReLU、GeLU、SiLU、dropout),单独跑每个都是带宽 bound、利用率极低。把它们融合到一起跑(或者融合到 LayerNorm/GEMM 里),是减少 HBM 往返的关键手段。读完第 8 章读者会理解为什么 vLLM 的 fused_add_rms_norm kernel 能比"add + rms_norm 两个 kernel"快 50%。
本章动手练习:
- 实现两版 LayerNorm:朴素方差公式版本和 Welford 版本,输入用
[1000.0001, 1000.0002, ...]这种大均值小方差的数据,对比两者的精度。- 写一个 RMSNorm kernel,对比 LayerNorm 在 H=8192 上的实测延迟。
- 阅读 PyTorch 的
RowwiseMomentsCUDAKernel,找到 Welford combine 规则在源码里的对应行。