CUDA 算子工程:手写 FlashAttention v2 之路

第 8 章 Element-wise 与算子融合

作者 杨艺韬 · 2,922 字

第 8 章 Element-wise 与算子融合

"An unfused kernel is a kernel that wastes bandwidth twice — once to write, once to read again. In LLM inference, bandwidth is the only thing you have." ——一句在推理引擎团队里被反复重复的话

8.1 LLM 中的 element-wise 算子盘点

打开任何一个现代 LLM(LLaMA、Qwen、Mistral)的 forward 代码,统计一下出现的 element-wise 算子:

# LLaMA-style decoder block (简化版)
def decoder_block(x, ...):
    # Attention 子层
    h = rms_norm(x)               # element-wise (per row)
    h = qkv_proj(h)               # GEMM (大算子)
    q, k = rope(q, k)             # element-wise (旋转)
    h = attention(q, k, v)        # 大算子
    h = o_proj(h)                 # GEMM
    x = x + h                     # element-wise (残差加)
    # FFN 子层
    h = rms_norm(x)               # element-wise
    g = gate_proj(h)              # GEMM
    u = up_proj(h)                # GEMM
    h = silu(g) * u               # element-wise (激活 + 乘)
    h = down_proj(h)              # GEMM
    x = x + h                     # element-wise
    return x

仔细数一下,一个 decoder block 里的 element-wise 操作有:

  1. 2 次 RMSNorm(attention 前 + FFN 前)
  2. 1 次 RoPE 旋转(per-head 的 element-wise 复数乘法)
  3. 1 次 SiLU + 1 次 element-wise 乘(FFN 中的 SwiGLU)
  4. 2 次残差加(attention 后 + FFN 后)

每个算子都需要遍历整个 hidden state(B × H 大小)。如果每个算子都是独立 kernel,整个 block 仅 element-wise 部分就要 6 次完整的 HBM 往返——B × H × 4 字节读 + B × H × 4 字节写。

对于 7B LLaMA, B=1(单 token decoding), H=4096:

而单 token decoding 总延迟目标是 ~10ms。仅仅 element-wise kernel 的 launch 开销就占了 10%——这是不可接受的。

所以算子融合是 LLM 推理中除了 attention、GEMM 优化之外的第二大主题。

8.2 算子融合的三个层次

按融合的"激进程度",可以分成三个层次:

flowchart TB
  subgraph L1 [Level 1 · 同类 element-wise 串接]
    L1A[add + scale + add] --> L1B[一个 kernel: y = scale*a + b + c]
  end
  subgraph L2 [Level 2 · element-wise + reduce 融合]
    L2A[add + rms_norm] --> L2B[一个 kernel: residual add 在 SMEM 里完成,<br/>直接接 RMS 计算]
  end
  subgraph L3 [Level 3 · 跨异构算子融合]
    L3A[GEMM + bias + GeLU] --> L3B["GEMM epilogue 直接出 GeLU(out + bias)"]
  end

每一个层次的难度和收益都更大。

8.3 Level 1:纯 element-wise 串接

最简单的融合:把多个 element-wise 操作写在一个 kernel 里,输入数据从 HBM 读一次,所有算完之后写一次。

// 反例: 三个独立 kernel
add_kernel(a, b, tmp1);          // tmp1 = a + b
scale_kernel(tmp1, alpha, tmp2); // tmp2 = alpha * tmp1
add_kernel(tmp2, c, out);        // out = tmp2 + c

// 正例: 一个 fused kernel
__global__ void fused(const float* a, const float* b, const float* c,
                      float alpha, float* out, int N) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < N) {
        out[tid] = alpha * (a[tid] + b[tid]) + c[tid];
    }
}

这种 fusion 的收益分析:

PyTorch 2.0 的 torch.compile 主要做的就是这个层次的 fusion——通过 TorchInductor 把多个 element-wise op 编译成一个 Triton kernel。

8.4 Level 2:Element-wise + Reduce 融合

更有价值的融合是 element-wise 算子和 reduce 算子(LayerNorm/RMSNorm/Softmax)的组合。比如 vLLM 中的 fused_add_rms_norm

# 反例: 两个独立 kernel
residual = residual + hidden        # 残差加 (element-wise)
hidden = rms_norm(residual)         # RMSNorm (有 reduce)

# 正例: 一个 kernel
hidden, residual = fused_add_rms_norm(residual, hidden)
# 内部: 把 residual+hidden 累加到 SMEM, 直接在 SMEM 上做 RMS 计算

完整 kernel:

template <int BLOCK_SIZE = 512>
__global__ void fused_add_rms_norm(
    float* __restrict__ residual,    // [B, H], in/out
    float* __restrict__ hidden,      // [B, H], in/out
    const float* __restrict__ gamma, // [H]
    int H,
    float eps
) {
    int row = blockIdx.x;
    int tid = threadIdx.x;
    float* res_row = residual + row * H;
    float* hid_row = hidden + row * H;

    // Phase 1: 残差加 + 累积平方和 (一遍同时做两件事)
    extern __shared__ float smem[];
    float ss = 0.0f;
    for (int i = tid; i < H; i += BLOCK_SIZE) {
        float r = res_row[i];
        float h = hid_row[i];
        float v = r + h;       // 残差加
        smem[i] = v;           // 暂存到 SMEM (避免再读 HBM)
        res_row[i] = v;        // 同时把更新的 residual 写回 (下一层用)
        ss += v * v;           // 累积平方和
    }
    // Block reduce ss
    ss = block_reduce_sum(ss);
    __shared__ float rms;
    if (tid == 0) rms = rsqrtf(ss / H + eps);
    __syncthreads();

    // Phase 2: 归一化 + scale (从 SMEM 读, 不再读 HBM)
    for (int i = tid; i < H; i += BLOCK_SIZE) {
        hid_row[i] = smem[i] * rms * gamma[i];
    }
}

关键点:

  1. HBM 流量减半:原来 add 要读 2 次写 1 次,rms_norm 要读 1 次写 1 次,共 5 次;融合后读 2 次(res+hid)写 2 次(res+hid_norm),共 4 次。
  2. SMEM 起到中间缓冲作用:残差加的结果暂存到 SMEM,rms_norm 不需要再从 HBM 读。
  3. Pre-LN 残差也写回:下一层(attention 或 FFN)的输入需要更新后的 residual,所以要写回 HBM 一份。

vLLM 中的 rms_norm kernel(csrc/layernorm_kernels.cu)就是这种结构。实测在 LLaMA-7B 上,融合版相比拆分版能快 30-40%。

8.4.1 多个 fusion 模式

LLM 中常见的 element-wise + reduce 融合:

融合 出现位置 收益
add + RMSNorm 残差 → RMS ~30%
add + LayerNorm 同上(旧模型) ~30%
Softmax + Mask Attention 之前 ~50%(避免 mask 中间矩阵)
Softmax + Dropout Training 时 ~20%
Linear + bias + GeLU FFN 中 ~20%
GeMM + LoRA LoRA fine-tune ~10-15%

8.5 Level 3:跨异构算子融合(GEMM Epilogue)

最深度的融合是把 element-wise 算子直接嵌入到 GEMM 的 epilogue(输出阶段)。这需要硬件友好的实现,CUTLASS 提供了完善的支持。

GEMM 的标准 epilogue 是 D = alpha * (A @ B) + beta * C。CUTLASS Epilogue 让你可以自定义这个最终阶段

// 伪代码:GEMM + bias + ReLU
template <typename ElementOutput>
struct EpilogueOpBiasReLU {
    ElementOutput bias;
    __device__ ElementOutput operator()(ElementOutput accumulator) {
        ElementOutput biased = accumulator + bias;
        return biased > 0 ? biased : ElementOutput(0);
    }
};

这个 epilogue 在 GEMM kernel 的最后阶段执行——accumulator 还在寄存器里时,就直接加 bias、过 ReLU、写出去。完全没有中间 HBM 写

效果:

对比写法                     | HBM 流量            | 性能
─────────────────────────────────────────────────────────
GEMM + 单独 add + ReLU       | 3 次 D 矩阵 HBM    | 基线
GEMM(epilogue=add+ReLU)      | 1 次 D 矩阵 HBM    | 比基线快 ~25%

CUTLASS 的 epilogue API 在 3.x 版本里被设计成了 CollectiveEpilogue——可以组合多个 element-wise 操作,用 lambda / functor 风格写。第 13 章会详细讲 CUTLASS 设计哲学。

8.5.1 vLLM / TensorRT-LLM 的常见 epilogue

工业级推理引擎里高频出现的 epilogue:

  1. Linear + bias + activation:FFN 的标配。
  2. Linear + LoRA add:LoRA 推理。
  3. Linear + scale + add residual:很多 LLM 把 residual add 直接 fuse 到 GEMM。
  4. Linear + quantize:W8A8/W4A16 量化推理时把输出量化也融合。

这些 epilogue 一般用 CUTLASS 写。手写 CUDA C++ 也行,但代码量大。

8.6 RoPE:一个特殊的 element-wise 算子

RoPE(Rotary Position Embedding)值得单独说一下。它是现代 LLM(LLaMA、Qwen 等)的位置编码,本质是对 Q 和 K 做按位置依赖的旋转

q2i=q2icos(θi,p)q2i+1sin(θi,p)q2i+1=q2isin(θi,p)+q2i+1cos(θi,p)\begin{aligned} q'_{2i} &= q_{2i} \cos(\theta_{i,p}) - q_{2i+1} \sin(\theta_{i,p}) \\ q'_{2i+1} &= q_{2i} \sin(\theta_{i,p}) + q_{2i+1} \cos(\theta_{i,p}) \end{aligned}

其中 θi,p=p/100002i/d\theta_{i,p} = p / 10000^{2i/d} 依赖于位置 pp 和维度 ii

RoPE 的算术强度只有 ~1(每对元素 4 个浮点操作 vs 8 字节读+写),是带宽 bound。但它有几个特点让它特别适合融合:

  1. per-head 的局部性:旋转操作只在每 (q_head, dim_pair) 内进行。
  2. cos/sin 表可以预计算(用 max_seq_len × head_dim 大的 lookup table)。
  3. 可以和 QKV projection 的 epilogue 融合qkv_proj 输出后直接 RoPE。

vLLM 中 RoPE 的实现是单独的 kernel,但 TensorRT-LLM 和 SGLang 都尝试过把它 fuse 到 QKV projection 的 epilogue 里——能节省 ~15% 的 attention 前置开销。

8.7 TMA 在 Element-wise Kernel 中的应用

到这里读者可能会问:第 4 章那么强调 TMA,element-wise kernel 用得上吗?

用得上,但收益不像 GEMM 那么大

Element-wise kernel 的主要瓶颈是 HBM 带宽,不是指令带宽——TMA 对带宽本身没有提升(HBM 物理带宽是固定的)。但 TMA 的几个特性还是有帮助:

  1. 省去地址计算:32 线程的 vectorized load 需要每线程算地址;TMA 一条指令搞定,省下指令带宽给真正的算术用。
  2. Async 提供更多 ILP:TMA 是异步的,可以在拷贝时同时算下一组。但 element-wise 算术开销极低,收益不明显。
  3. 二维数据布局更优雅:处理 [B, H] 矩阵时,二维 TMA 比手写 stride 计算更干净。

实测下来,element-wise kernel 用 TMA vs 不用,性能差距通常在 5-10%,远不如 GEMM/FA 的 ~50%。所以工业级实现里:

8.8 一个完整案例:SwiGLU FFN 融合

LLaMA / Mistral 的 FFN 是 SwiGLU

FFN(x)=down_proj(SiLU(gate_proj(x))up_proj(x))\text{FFN}(x) = \text{down\_proj}(\text{SiLU}(\text{gate\_proj}(x)) \odot \text{up\_proj}(x))

朴素实现:

g = gate_proj(x)         # GEMM: [B, H] -> [B, F]
u = up_proj(x)           # GEMM: [B, H] -> [B, F]
m = silu(g) * u          # element-wise: [B, F]
y = down_proj(m)         # GEMM: [B, F] -> [B, H]

注意 silu(g) * u 这一步:

// 朴素 element-wise
__global__ void silu_mul(const float* g, const float* u, float* m, int N) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < N) {
        float gv = g[tid];
        float silu_g = gv * (1.0f / (1.0f + expf(-gv)));  // SiLU
        m[tid] = silu_g * u[tid];
    }
}

可以做的优化:

A. Fuse SiLU 到 gate_proj 的 epilogue

g_silu = gate_proj_with_silu_epilogue(x)   # 一个 kernel
u = up_proj(x)
m = g_silu * u
y = down_proj(m)

B. 进一步 fuse 上面三步(在 H100 上越来越常见):

m = gemm_silu_mul(g_silu, u, x, gate_proj.weight, up_proj.weight)
# 一个 kernel 内: gate_proj GEMM -> SiLU -> 与 u 相乘
y = down_proj(m)

C. 极致情况:把 down_proj 的 prologue 也融合

y = full_swiglu_fused(x, gate_w, up_w, down_w)
# 一个 kernel: gate -> silu -> mul -> down -> output

C 这种程度的融合在通用框架里很少见,但 NVIDIA 的 TensorRT-LLM 在某些特殊情况下会生成这种巨型 kernel。代价是代码复杂度和编译时间。

实测在 LLaMA-7B 上,做到 B 这种程度的融合可以让 FFN 总延迟降低 15-20%。

8.9 Fusion 的边界与陷阱

不是融合越多越好。Fusion 有几个隐藏成本:

8.9.1 寄存器压力

把多个算子塞到一个 kernel 里,每线程需要的中间寄存器变多。Hopper 单线程 register 上限 255 个,超过就 spill。寄存器压力大会让 occupancy 降低,反而拖慢整体。

8.9.2 SMEM 容量

Fused kernel 经常需要在 SMEM 里 stage 中间数据。但 SMEM 总共 228 KB,过度融合会撑爆。

8.9.3 编译时间爆炸

CUTLASS 的模板化 fused kernel,编译可以慢到几分钟一个 kernel。生产环境一定要 cache 编译产物。

8.9.4 调试难度

朴素拆分的 kernel 每一步可以单独打印中间结果;fused kernel 的中间结果在寄存器里,调试需要技巧(用 conditional 写回 HBM 一段查看)。

8.9.5 错误传播

Fused kernel 的一个 bug 可能影响多个 op 的正确性。强烈推荐:fused 版本和拆分版本同时存在,CI 测试都跑,互相验证数值一致。

8.10 这一章的小结与下一章

这一章建立了 LLM kernel 优化的"融合直觉":

  1. Element-wise 算子在 LLM 中无处不在:每个 decoder block 都有 4-6 个,是 HBM 带宽的纯消费者。
  2. 三个融合层次:纯 element-wise 串接 → element-wise + reduce → 跨异构(GEMM epilogue)。每深一层收益和复杂度都更大。
  3. vLLM 的 fused_add_rms_norm 是 LLM 推理引擎的"小招牌":典型的 element-wise + reduce 融合,30%+ 提升。
  4. CUTLASS Epilogue 是工业级 GEMM fusion 的事实标准:第 13 章会详细讲。
  5. 过度融合有反效果:寄存器压力、SMEM 容量、编译时间、调试难度都是隐藏成本。

第 9 章我们继续往 LLM 推理的深处走——讲 量化 Kernel。INT8 / FP8 / INT4 量化是 LLM 推理性能跃迁的另一个关键,但 dequantize(解码)这一步本身又是一个"小算子",需要特别的设计才能高效。读完第 9 章读者会理解为什么 Marlin 的 INT4 GEMM 能比 cuBLAS 的 FP16 GEMM 还快。

本章动手练习

  1. 实现 add + rms_norm 的拆分版和 fused 版,对比 H=8192 时的延迟。预期 fused 版快 30%+。
  2. 用 PyTorch 的 torch.compile 编译一个 LLM block,看 TorchInductor 生成的 Triton kernel——观察它做了哪些 element-wise fusion。
  3. 思考:如果 GEMM 的输出是 INT8(量化推理),怎么把 dequantize(INT8 → FP16)和后续 RMSNorm 融合起来?