CUDA 算子工程:手写 FlashAttention v2 之路

第 7 章 LayerNorm 与 RMSNorm

作者 杨艺韬 · 2,973 字

第 7 章 LayerNorm 与 RMSNorm

"RMSNorm is to LayerNorm what online softmax is to safe softmax — a simplification that removes one pass of bookkeeping while preserving the essential function." ——常见于 LLM 优化讨论

7.1 为什么要归一化

Transformer 里每个 block 都套着一层归一化:Pre-LN 或 Post-LN。它的作用是把每个 token 的隐藏维度(hidden_size,4096 / 8192 / ...)的激活值"拉到合适的尺度",防止训练时数值爆炸或塌陷。

LayerNorm 的标准定义:

LN(x)=xμσ2+ϵγ+β\text{LN}(\mathbf{x}) = \frac{\mathbf{x} - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta

其中:

每次 attention 之前一次 LayerNorm,每次 FFN 之前一次 LayerNorm——一个 32 层的 Transformer,每次 forward 要做 64+ 次 LayerNorm。这是 LLM 推理中除 GEMM 和 attention 之外最频繁的算子。

7.2 朴素两遍算法

最直观的写法:

// Pass 1: 算 mean
float mean = 0;
for (int i = 0; i < H; ++i) mean += x[i];
mean /= H;

// Pass 2: 算 variance
float var = 0;
for (int i = 0; i < H; ++i) var += (x[i] - mean) * (x[i] - mean);
var /= H;

// Pass 3: 归一化
float rstd = rsqrtf(var + eps);
for (int i = 0; i < H; ++i)
    y[i] = (x[i] - mean) * rstd * gamma[i] + beta[i];

3 pass,每 pass 都遍历 x 一次。在 GPU 上意味着 x 数据从 HBM 读 3 次——和 safe softmax 一样的痛点。

7.3 一个数学上的等价:方差的两种公式

数学上有这样一个等式:

σ2=E[X2](E[X])2\sigma^2 = E[X^2] - (E[X])^2

这意味着方差可以用"和"与"平方和"两个量计算:

σ2=1Hixi2μ2\sigma^2 = \frac{1}{H} \sum_{i} x_i^2 - \mu^2

如果用这个公式,可以用一遍同时收集 xi\sum x_ixi2\sum x_i^2

// 1 pass: 算 sum 和 sum_sq
float sum = 0, sum_sq = 0;
for (int i = 0; i < H; ++i) {
    sum += x[i];
    sum_sq += x[i] * x[i];
}
float mean = sum / H;
float var = sum_sq / H - mean * mean;

// Pass 2: 归一化 (这一遍不可避免, 因为输出依赖 mean/var)
float rstd = rsqrtf(var + eps);
for (int i = 0; i < H; ++i)
    y[i] = (x[i] - mean) * rstd * gamma[i] + beta[i];

成功——从 3 pass 降到 2 pass!但有一个致命的数值稳定问题

E[X2](E[X])2E[X^2] - (E[X])^2 这个公式叫naive 方差公式,它在数值上非常不稳定。当 XX 的均值很大、方差很小时,E[X2]E[X^2](E[X])2(E[X])^2 会接近相等,相减时会发生灾难性消除(catastrophic cancellation),损失大量有效数字。

举个具体例子:

x = [1000.001, 1000.002, 1000.003, ..., 1000.010]  (10 个数)

真实 mean = 1000.0055
真实 var = 0.000033 (大约)

朴素公式:
  sum = 10000.055
  sum_sq = 100001100.11 (大约)
  mean = 1000.0055
  var = (100001100.11 / 10) - 1000.0055^2
      = 10000110.011 - 10000110.0110030...
      = 0.0... (FP32 算这个会得到几乎随机的结果, 甚至负数)

这就是为什么真实场景下不能用这个公式——输入的 magnitude 一旦大了,FP32 精度根本不够。

LLM 训练中,激活值的 magnitude 经常超过 100。如果用朴素方差公式做 LayerNorm,方差经常会算成负数(开方变 NaN)或者完全错误的小数。

7.4 Welford 算法:数值稳定的一遍方差

1962 年统计学家 B. P. Welford 提出了一个数值稳定的在线方差算法。它的核心是维护"当前均值"和"M2(偏差平方和)"两个状态,用增量更新:

定义:

初始 n=0n=0, μ=0\mu = 0, M2=0M_2 = 0。每来一个新元素 xn+1x_{n+1}

n=n+1δ=xn+1μnμn=μn+δ/nM2,n=M2,n+δ(xn+1μn)\begin{aligned} n' &= n + 1 \\ \delta &= x_{n+1} - \mu_n \\ \mu_{n'} &= \mu_n + \delta / n' \\ M_{2, n'} &= M_{2, n} + \delta \cdot (x_{n+1} - \mu_{n'}) \end{aligned}

最后方差 σ2=M2,N/N\sigma^2 = M_{2, N} / N

Welford 在数值上是稳定的——它不像朴素公式那样减两个相近的大数,而是不断累积小的偏差量。FP32 下也能保持精度。

7.4.1 Welford 的并行合并规则

Welford 的妙处是它有一个可结合的合并规则。如果两个独立计算的 partial Welford 状态 (na,μa,M2,a)(n_a, \mu_a, M_{2,a})(nb,μb,M2,b)(n_b, \mu_b, M_{2,b}) 要合并:

n=na+nbδ=μbμaμ=μa+δnb/nM2=M2,a+M2,b+δ2nanb/n\begin{aligned} n &= n_a + n_b \\ \delta &= \mu_b - \mu_a \\ \mu &= \mu_a + \delta \cdot n_b / n \\ M_2 &= M_{2,a} + M_{2,b} + \delta^2 \cdot n_a \cdot n_b / n \end{aligned}

这个合并是对称、可结合的。和 online softmax 一样,它构成了一个 monoid——可以用并行 reduce 来计算。

7.4.2 Welford 的 GPU kernel 形式

struct WelfordState {
    int n;
    float mean;
    float m2;
};

__device__ WelfordState welford_update(WelfordState s, float x) {
    s.n += 1;
    float delta = x - s.mean;
    s.mean += delta / s.n;
    s.m2 += delta * (x - s.mean);
    return s;
}

__device__ WelfordState welford_combine(WelfordState a, WelfordState b) {
    int n = a.n + b.n;
    if (n == 0) return {0, 0, 0};
    float delta = b.mean - a.mean;
    float new_mean = a.mean + delta * b.n / n;
    float new_m2 = a.m2 + b.m2 + delta * delta * a.n * b.n / n;
    return {n, new_mean, new_m2};
}

把这两个函数用 warp shuffle 串起来:

__device__ WelfordState warp_welford_reduce(WelfordState s) {
    for (int offset = 16; offset > 0; offset >>= 1) {
        WelfordState other;
        other.n    = __shfl_xor_sync(0xFFFFFFFF, s.n,    offset);
        other.mean = __shfl_xor_sync(0xFFFFFFFF, s.mean, offset);
        other.m2   = __shfl_xor_sync(0xFFFFFFFF, s.m2,   offset);
        s = welford_combine(s, other);
    }
    return s;
}

这就是 Apex FusedLayerNorm 内部的核心。第 7.6 节会给完整的 kernel。

7.5 RMSNorm:进一步简化

2019 年 Zhang & Sennrich 在论文 Root Mean Square Layer Normalization 中提出了 RMSNorm——一个比 LayerNorm 更简单的归一化方案。LLaMA、Mistral、Qwen、Gemma 等几乎所有现代开源 LLM 都用 RMSNorm。

RMSNorm 的定义:

RMS(x)=x1Hixi2+ϵγ\text{RMS}(\mathbf{x}) = \frac{\mathbf{x}}{\sqrt{\frac{1}{H}\sum_i x_i^2 + \epsilon}} \cdot \gamma

对比 LayerNorm,RMSNorm 去掉了:

  1. 均值减法(不再需要 xμ\mathbf{x} - \mu
  2. 均值统计量(不再需要 μ\mu
  3. bias β\beta(只保留 γ\gamma

为什么这样改?论文给出的理由是"重新中心化(减均值)对模型表现影响很小,但会增加计算"。后来的实证研究(特别是 LLaMA 的论文)也确认了这一点——大规模 LLM 训练中,RMSNorm 的最终效果和 LayerNorm 几乎一样,但训练速度快 7-10%。

7.5.1 RMSNorm 的算法只需要 1 个状态量

LayerNorm 的 Welford 需要维护 (n, mean, m2) 三个状态。RMSNorm 只需要维护平方和

float ss = 0;
for (int i = tid; i < H; i += blockDim.x) {
    float v = x[i];
    ss += v * v;
}
// ss reduce 跨 block
float rms = rsqrtf(ss / H + eps);
for (int i = tid; i < H; i += blockDim.x) {
    y[i] = x[i] * rms * gamma[i];
}

这个 kernel 比 LayerNorm 简单得多——纯标量 reduce,无需 Welford 那套增量更新。

7.5.2 性能对比

H100 上 H=8192 的 LayerNorm vs RMSNorm(per-token):

算子 浮点操作数(每 token) 实测带宽 (GB/s) 实测延迟 (μs)
LayerNorm (Welford) ~6H ~2700 ~6.5
RMSNorm ~3H ~3100 ~5.5

带宽差 15%,延迟差 ~15%。考虑到一个 32 层 LLM 推理要做 64 次归一化,这个差距累积起来不算小。

7.6 完整的 Fused LayerNorm Kernel

把上面的元素拼起来,给一个完整的、达到 90%+ HBM 峰值的 LayerNorm kernel:

template <int BLOCK_SIZE = 512, int VEC_SIZE = 4>
__global__ void layernorm_fwd(
    const float* __restrict__ x,        // [B, H]
    const float* __restrict__ gamma,    // [H]
    const float* __restrict__ beta,     // [H]
    float* __restrict__ y,              // [B, H]
    float* __restrict__ mean_out,       // [B]
    float* __restrict__ rstd_out,       // [B]
    int H,
    float eps
) {
    int row = blockIdx.x;
    int tid = threadIdx.x;
    const float* x_row = x + row * H;
          float* y_row = y + row * H;

    // ============ Phase 1: Welford reduce ============
    WelfordState state = {0, 0.0f, 0.0f};

    // 每线程读 VEC_SIZE 个元素
    for (int i = tid * VEC_SIZE; i < H; i += BLOCK_SIZE * VEC_SIZE) {
        float4 v = *reinterpret_cast<const float4*>(&x_row[i]);
        state = welford_update(state, v.x);
        state = welford_update(state, v.y);
        state = welford_update(state, v.z);
        state = welford_update(state, v.w);
    }

    // Warp 内合并
    state = warp_welford_reduce(state);

    // Block 内合并 (跨 warp)
    __shared__ WelfordState warp_states[BLOCK_SIZE / 32];
    int warp_id = tid / 32;
    int lane_id = tid % 32;
    if (lane_id == 0) warp_states[warp_id] = state;
    __syncthreads();

    if (warp_id == 0) {
        if (lane_id < BLOCK_SIZE / 32) state = warp_states[lane_id];
        else state = {0, 0.0f, 0.0f};
        for (int offset = 16; offset > 0; offset >>= 1) {
            WelfordState other;
            other.n    = __shfl_xor_sync(0xFFFFFFFF, state.n,    offset);
            other.mean = __shfl_xor_sync(0xFFFFFFFF, state.mean, offset);
            other.m2   = __shfl_xor_sync(0xFFFFFFFF, state.m2,   offset);
            state = welford_combine(state, other);
        }
    }

    __shared__ float final_mean, final_rstd;
    if (warp_id == 0 && lane_id == 0) {
        final_mean = state.mean;
        float var = state.m2 / state.n;
        final_rstd = rsqrtf(var + eps);
        if (mean_out) mean_out[row] = final_mean;
        if (rstd_out) rstd_out[row] = final_rstd;
    }
    __syncthreads();

    // ============ Phase 2: Normalize + scale ============
    for (int i = tid * VEC_SIZE; i < H; i += BLOCK_SIZE * VEC_SIZE) {
        float4 xv = *reinterpret_cast<const float4*>(&x_row[i]);
        float4 gv = *reinterpret_cast<const float4*>(&gamma[i]);
        float4 bv = *reinterpret_cast<const float4*>(&beta[i]);
        float4 yv;
        yv.x = (xv.x - final_mean) * final_rstd * gv.x + bv.x;
        yv.y = (xv.y - final_mean) * final_rstd * gv.y + bv.y;
        yv.z = (xv.z - final_mean) * final_rstd * gv.z + bv.z;
        yv.w = (xv.w - final_mean) * final_rstd * gv.w + bv.w;
        *reinterpret_cast<float4*>(&y_row[i]) = yv;
    }
}

关键点:

  1. 每个 block 处理一行:grid_size = batch_size。
  2. VEC_SIZE=4 用 float4 vectorized I/O:减少指令带宽。
  3. Welford state 在 warp 内、block 内归约:用 shfl 和 SMEM。
  4. 存 mean/rstd 给反向用:反向 LN 需要这两个量。

性能(H100, B=8192, H=8192, FP32):

7.7 反向:LayerNorm 的两个公式

LayerNorm 反向比前向复杂得多。给定 dy=L/ydy = \partial L / \partial y,需要算:

Lγi=ndyi(n)x^i(n)Lβi=ndyi(n)Lxi=γiσ2+ϵ(dyidyˉγx^idyγx^ˉ)\begin{aligned} \frac{\partial L}{\partial \gamma_i} &= \sum_n dy^{(n)}_i \cdot \hat{x}^{(n)}_i \\ \frac{\partial L}{\partial \beta_i} &= \sum_n dy^{(n)}_i \\ \frac{\partial L}{\partial x_i} &= \frac{\gamma_i}{\sqrt{\sigma^2 + \epsilon}} \left( dy_i - \bar{dy}_{\gamma} - \hat{x}_i \cdot \bar{dy_{\gamma} \hat{x}} \right) \end{aligned}

其中 x^i=(xiμ)/σ2+ϵ\hat{x}_i = (x_i - \mu) / \sqrt{\sigma^2 + \epsilon}dyγˉ\bar{dy_{\gamma}}dyiγi/H\sum dy_i \gamma_i / H 的平均。

这个反向公式有两个 reduce:dyiγi\sum dy_i \gamma_idyiγix^i\sum dy_i \gamma_i \hat{x}_i。可以一遍同时算两个 reduce——典型的双状态 Welford 结构。

完整反向 kernel 比较长,本书不展开。读者可以参考 Apex 的 layer_norm_cuda_kernel.cu,那是工业级的参考实现。

RMSNorm 反向比 LayerNorm 简单得多——只需要一个 reduce:dyiγixi/r\sum dy_i \gamma_i x_i / r。这是 RMSNorm 在训练中的另一个加速点。

7.8 与 PyTorch / Apex 实现对照

PyTorch 的 LayerNorm 实现在 aten/src/ATen/native/cuda/layer_norm_kernel.cu。它的核心是一个叫 RowwiseMomentsCUDAKernel 的 kernel,结构和上面的 layernorm_fwd 类似——也是 per-row、Welford reduce、shuffle 合并。

Apex 的 FusedLayerNormapex/csrc/layer_norm_cuda_kernel.cu,做了几个额外优化:

  1. 支持 mixed precision:输入 fp16 / bf16,中间计算 fp32(避免精度丢失)。
  2. 更细致的 Welford 状态压缩:把 n 信息编码到 mean 的低位,省一个寄存器。
  3. 更激进的 vectorized I/O:根据 H 大小自动选择 vec_size。

Apex 整体性能比 PyTorch 默认快 ~20%,但代码复杂度高很多。绝大多数情况下 PyTorch 默认实现已经够用——除非你训练超大模型且 LayerNorm 占 forward 时间 5% 以上。

7.9 工程权衡:Pre-LN vs Post-LN

最后顺便说一下 LayerNorm 的两种位置安排,因为它影响 kernel 调度:

flowchart LR
  subgraph PostLN [Post-LN · 原始 Transformer]
    P1[x] --> P2[Attention]
    P2 --> P3[+ x]
    P3 --> P4[LayerNorm]
    P4 --> P5[FFN]
    P5 --> P6[+ ...]
    P6 --> P7[LayerNorm]
  end
  subgraph PreLN [Pre-LN · 现代 LLM]
    L1[x] --> L2[LayerNorm]
    L2 --> L3[Attention]
    L3 --> L4[+ x]
    L4 --> L5[LayerNorm]
    L5 --> L6[FFN]
    L6 --> L7[+ ...]
  end

Post-LN(原始 2017 Transformer):LayerNorm 在残差之后。训练困难(梯度消失/爆炸),现在很少用。

Pre-LN(GPT-2/LLaMA/几乎所有现代 LLM):LayerNorm 在残差之前。训练稳定,但有"残差累积"问题(不严重)。

工程上 Pre-LN 还有一个kernel fusion 优势:Pre-LN + 接下来的 GEMM 可以 fuse 在一起(LN 的输出直接喂给 GEMM 的 SMEM),避免一次 HBM 写。这是为什么 NVIDIA 的 TransformerEngine 和 vLLM 都喜欢做 Pre-LN+GEMM 的 fusion。

7.10 这一章的小结与下一章

这一章的关键收获:

  1. 方差有两个公式:朴素 E[X2](E[X])2E[X^2] - (E[X])^2 数值不稳定;Welford 增量公式数值稳定。
  2. Welford 算法和 online softmax 同源——都是用一个可结合 monoid 的 combine 规则把多 pass 改成 1 pass。
  3. RMSNorm 是 LayerNorm 的简化:去掉均值统计,只算 RMS。性能提升 15-20%,效果几乎无损。这就是为什么所有现代 LLM 都用 RMSNorm。
  4. Fused LayerNorm 可以达到 90%+ HBM 带宽——这是带宽 bound kernel 的合理目标。
  5. Pre-LN + GEMM Fusion 是工业级推理引擎的常见优化。

第 8 章我们继续往工程化深入——讲 Element-wise 算子的融合。LLM 推理里有大量"小算子"(add、mul、ReLU、GeLU、SiLU、dropout),单独跑每个都是带宽 bound、利用率极低。把它们融合到一起跑(或者融合到 LayerNorm/GEMM 里),是减少 HBM 往返的关键手段。读完第 8 章读者会理解为什么 vLLM 的 fused_add_rms_norm kernel 能比"add + rms_norm 两个 kernel"快 50%。

本章动手练习

  1. 实现两版 LayerNorm:朴素方差公式版本和 Welford 版本,输入用 [1000.0001, 1000.0002, ...] 这种大均值小方差的数据,对比两者的精度。
  2. 写一个 RMSNorm kernel,对比 LayerNorm 在 H=8192 上的实测延迟。
  3. 阅读 PyTorch 的 RowwiseMomentsCUDAKernel,找到 Welford combine 规则在源码里的对应行。