CUDA 算子工程:手写 FlashAttention v2 之路
第 8 章 Element-wise 与算子融合
第 8 章 Element-wise 与算子融合
"An unfused kernel is a kernel that wastes bandwidth twice — once to write, once to read again. In LLM inference, bandwidth is the only thing you have." ——一句在推理引擎团队里被反复重复的话
8.1 LLM 中的 element-wise 算子盘点
打开任何一个现代 LLM(LLaMA、Qwen、Mistral)的 forward 代码,统计一下出现的 element-wise 算子:
# LLaMA-style decoder block (简化版)
def decoder_block(x, ...):
# Attention 子层
h = rms_norm(x) # element-wise (per row)
h = qkv_proj(h) # GEMM (大算子)
q, k = rope(q, k) # element-wise (旋转)
h = attention(q, k, v) # 大算子
h = o_proj(h) # GEMM
x = x + h # element-wise (残差加)
# FFN 子层
h = rms_norm(x) # element-wise
g = gate_proj(h) # GEMM
u = up_proj(h) # GEMM
h = silu(g) * u # element-wise (激活 + 乘)
h = down_proj(h) # GEMM
x = x + h # element-wise
return x
仔细数一下,一个 decoder block 里的 element-wise 操作有:
- 2 次 RMSNorm(attention 前 + FFN 前)
- 1 次 RoPE 旋转(per-head 的 element-wise 复数乘法)
- 1 次 SiLU + 1 次 element-wise 乘(FFN 中的 SwiGLU)
- 2 次残差加(attention 后 + FFN 后)
每个算子都需要遍历整个 hidden state(B × H 大小)。如果每个算子都是独立 kernel,整个 block 仅 element-wise 部分就要 6 次完整的 HBM 往返——B × H × 4 字节读 + B × H × 4 字节写。
对于 7B LLaMA, B=1(单 token decoding), H=4096:
- 一次完整 element-wise pass = 32 KB(不大)
- 但要乘以 32 层 × 6 次 = 192 次完整 pass = 6 MB HBM 流量
- 每次 launch 5μs 开销 × 6 × 32 = ~1ms 仅仅是 launch 开销
而单 token decoding 总延迟目标是 ~10ms。仅仅 element-wise kernel 的 launch 开销就占了 10%——这是不可接受的。
所以算子融合是 LLM 推理中除了 attention、GEMM 优化之外的第二大主题。
8.2 算子融合的三个层次
按融合的"激进程度",可以分成三个层次:
flowchart TB
subgraph L1 [Level 1 · 同类 element-wise 串接]
L1A[add + scale + add] --> L1B[一个 kernel: y = scale*a + b + c]
end
subgraph L2 [Level 2 · element-wise + reduce 融合]
L2A[add + rms_norm] --> L2B[一个 kernel: residual add 在 SMEM 里完成,<br/>直接接 RMS 计算]
end
subgraph L3 [Level 3 · 跨异构算子融合]
L3A[GEMM + bias + GeLU] --> L3B["GEMM epilogue 直接出 GeLU(out + bias)"]
end
每一个层次的难度和收益都更大。
8.3 Level 1:纯 element-wise 串接
最简单的融合:把多个 element-wise 操作写在一个 kernel 里,输入数据从 HBM 读一次,所有算完之后写一次。
// 反例: 三个独立 kernel
add_kernel(a, b, tmp1); // tmp1 = a + b
scale_kernel(tmp1, alpha, tmp2); // tmp2 = alpha * tmp1
add_kernel(tmp2, c, out); // out = tmp2 + c
// 正例: 一个 fused kernel
__global__ void fused(const float* a, const float* b, const float* c,
float alpha, float* out, int N) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < N) {
out[tid] = alpha * (a[tid] + b[tid]) + c[tid];
}
}
这种 fusion 的收益分析:
- HBM 流量:3 次独立 kernel = 6 次 N×4B 读 + 3 次 N×4B 写 = 9N×4B;融合 = 3 次读 + 1 次写 = 4N×4B。节省 56%。
- launch 开销:3 次 → 1 次,节省 ~10μs。
- 代码改动:~10 行。
PyTorch 2.0 的 torch.compile 主要做的就是这个层次的 fusion——通过 TorchInductor 把多个 element-wise op 编译成一个 Triton kernel。
8.4 Level 2:Element-wise + Reduce 融合
更有价值的融合是 element-wise 算子和 reduce 算子(LayerNorm/RMSNorm/Softmax)的组合。比如 vLLM 中的 fused_add_rms_norm:
# 反例: 两个独立 kernel
residual = residual + hidden # 残差加 (element-wise)
hidden = rms_norm(residual) # RMSNorm (有 reduce)
# 正例: 一个 kernel
hidden, residual = fused_add_rms_norm(residual, hidden)
# 内部: 把 residual+hidden 累加到 SMEM, 直接在 SMEM 上做 RMS 计算
完整 kernel:
template <int BLOCK_SIZE = 512>
__global__ void fused_add_rms_norm(
float* __restrict__ residual, // [B, H], in/out
float* __restrict__ hidden, // [B, H], in/out
const float* __restrict__ gamma, // [H]
int H,
float eps
) {
int row = blockIdx.x;
int tid = threadIdx.x;
float* res_row = residual + row * H;
float* hid_row = hidden + row * H;
// Phase 1: 残差加 + 累积平方和 (一遍同时做两件事)
extern __shared__ float smem[];
float ss = 0.0f;
for (int i = tid; i < H; i += BLOCK_SIZE) {
float r = res_row[i];
float h = hid_row[i];
float v = r + h; // 残差加
smem[i] = v; // 暂存到 SMEM (避免再读 HBM)
res_row[i] = v; // 同时把更新的 residual 写回 (下一层用)
ss += v * v; // 累积平方和
}
// Block reduce ss
ss = block_reduce_sum(ss);
__shared__ float rms;
if (tid == 0) rms = rsqrtf(ss / H + eps);
__syncthreads();
// Phase 2: 归一化 + scale (从 SMEM 读, 不再读 HBM)
for (int i = tid; i < H; i += BLOCK_SIZE) {
hid_row[i] = smem[i] * rms * gamma[i];
}
}
关键点:
- HBM 流量减半:原来 add 要读 2 次写 1 次,rms_norm 要读 1 次写 1 次,共 5 次;融合后读 2 次(res+hid)写 2 次(res+hid_norm),共 4 次。
- SMEM 起到中间缓冲作用:残差加的结果暂存到 SMEM,rms_norm 不需要再从 HBM 读。
- Pre-LN 残差也写回:下一层(attention 或 FFN)的输入需要更新后的 residual,所以要写回 HBM 一份。
vLLM 中的 rms_norm kernel(csrc/layernorm_kernels.cu)就是这种结构。实测在 LLaMA-7B 上,融合版相比拆分版能快 30-40%。
8.4.1 多个 fusion 模式
LLM 中常见的 element-wise + reduce 融合:
| 融合 | 出现位置 | 收益 |
|---|---|---|
| add + RMSNorm | 残差 → RMS | ~30% |
| add + LayerNorm | 同上(旧模型) | ~30% |
| Softmax + Mask | Attention 之前 | ~50%(避免 mask 中间矩阵) |
| Softmax + Dropout | Training 时 | ~20% |
| Linear + bias + GeLU | FFN 中 | ~20% |
| GeMM + LoRA | LoRA fine-tune | ~10-15% |
8.5 Level 3:跨异构算子融合(GEMM Epilogue)
最深度的融合是把 element-wise 算子直接嵌入到 GEMM 的 epilogue(输出阶段)。这需要硬件友好的实现,CUTLASS 提供了完善的支持。
GEMM 的标准 epilogue 是 D = alpha * (A @ B) + beta * C。CUTLASS Epilogue 让你可以自定义这个最终阶段:
// 伪代码:GEMM + bias + ReLU
template <typename ElementOutput>
struct EpilogueOpBiasReLU {
ElementOutput bias;
__device__ ElementOutput operator()(ElementOutput accumulator) {
ElementOutput biased = accumulator + bias;
return biased > 0 ? biased : ElementOutput(0);
}
};
这个 epilogue 在 GEMM kernel 的最后阶段执行——accumulator 还在寄存器里时,就直接加 bias、过 ReLU、写出去。完全没有中间 HBM 写。
效果:
对比写法 | HBM 流量 | 性能
─────────────────────────────────────────────────────────
GEMM + 单独 add + ReLU | 3 次 D 矩阵 HBM | 基线
GEMM(epilogue=add+ReLU) | 1 次 D 矩阵 HBM | 比基线快 ~25%
CUTLASS 的 epilogue API 在 3.x 版本里被设计成了 CollectiveEpilogue——可以组合多个 element-wise 操作,用 lambda / functor 风格写。第 13 章会详细讲 CUTLASS 设计哲学。
8.5.1 vLLM / TensorRT-LLM 的常见 epilogue
工业级推理引擎里高频出现的 epilogue:
- Linear + bias + activation:FFN 的标配。
- Linear + LoRA add:LoRA 推理。
- Linear + scale + add residual:很多 LLM 把 residual add 直接 fuse 到 GEMM。
- Linear + quantize:W8A8/W4A16 量化推理时把输出量化也融合。
这些 epilogue 一般用 CUTLASS 写。手写 CUDA C++ 也行,但代码量大。
8.6 RoPE:一个特殊的 element-wise 算子
RoPE(Rotary Position Embedding)值得单独说一下。它是现代 LLM(LLaMA、Qwen 等)的位置编码,本质是对 Q 和 K 做按位置依赖的旋转:
其中 依赖于位置 和维度 。
RoPE 的算术强度只有 ~1(每对元素 4 个浮点操作 vs 8 字节读+写),是带宽 bound。但它有几个特点让它特别适合融合:
- per-head 的局部性:旋转操作只在每 (q_head, dim_pair) 内进行。
- cos/sin 表可以预计算(用 max_seq_len × head_dim 大的 lookup table)。
- 可以和 QKV projection 的 epilogue 融合:
qkv_proj输出后直接 RoPE。
vLLM 中 RoPE 的实现是单独的 kernel,但 TensorRT-LLM 和 SGLang 都尝试过把它 fuse 到 QKV projection 的 epilogue 里——能节省 ~15% 的 attention 前置开销。
8.7 TMA 在 Element-wise Kernel 中的应用
到这里读者可能会问:第 4 章那么强调 TMA,element-wise kernel 用得上吗?
用得上,但收益不像 GEMM 那么大。
Element-wise kernel 的主要瓶颈是 HBM 带宽,不是指令带宽——TMA 对带宽本身没有提升(HBM 物理带宽是固定的)。但 TMA 的几个特性还是有帮助:
- 省去地址计算:32 线程的 vectorized load 需要每线程算地址;TMA 一条指令搞定,省下指令带宽给真正的算术用。
- Async 提供更多 ILP:TMA 是异步的,可以在拷贝时同时算下一组。但 element-wise 算术开销极低,收益不明显。
- 二维数据布局更优雅:处理 [B, H] 矩阵时,二维 TMA 比手写 stride 计算更干净。
实测下来,element-wise kernel 用 TMA vs 不用,性能差距通常在 5-10%,远不如 GEMM/FA 的 ~50%。所以工业级实现里:
- GEMM、FA:必用 TMA。
- 大型 fused 算子(如 fused_add_rms_norm):一般用 TMA。
- 简单 element-wise(比如
y = a * x + b):用普通的 vectorized load,TMA 收益不值得复杂度。
8.8 一个完整案例:SwiGLU FFN 融合
LLaMA / Mistral 的 FFN 是 SwiGLU:
朴素实现:
g = gate_proj(x) # GEMM: [B, H] -> [B, F]
u = up_proj(x) # GEMM: [B, H] -> [B, F]
m = silu(g) * u # element-wise: [B, F]
y = down_proj(m) # GEMM: [B, F] -> [B, H]
注意 silu(g) * u 这一步:
// 朴素 element-wise
__global__ void silu_mul(const float* g, const float* u, float* m, int N) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < N) {
float gv = g[tid];
float silu_g = gv * (1.0f / (1.0f + expf(-gv))); // SiLU
m[tid] = silu_g * u[tid];
}
}
可以做的优化:
A. Fuse SiLU 到 gate_proj 的 epilogue:
g_silu = gate_proj_with_silu_epilogue(x) # 一个 kernel
u = up_proj(x)
m = g_silu * u
y = down_proj(m)
B. 进一步 fuse 上面三步(在 H100 上越来越常见):
m = gemm_silu_mul(g_silu, u, x, gate_proj.weight, up_proj.weight)
# 一个 kernel 内: gate_proj GEMM -> SiLU -> 与 u 相乘
y = down_proj(m)
C. 极致情况:把 down_proj 的 prologue 也融合:
y = full_swiglu_fused(x, gate_w, up_w, down_w)
# 一个 kernel: gate -> silu -> mul -> down -> output
C 这种程度的融合在通用框架里很少见,但 NVIDIA 的 TensorRT-LLM 在某些特殊情况下会生成这种巨型 kernel。代价是代码复杂度和编译时间。
实测在 LLaMA-7B 上,做到 B 这种程度的融合可以让 FFN 总延迟降低 15-20%。
8.9 Fusion 的边界与陷阱
不是融合越多越好。Fusion 有几个隐藏成本:
8.9.1 寄存器压力
把多个算子塞到一个 kernel 里,每线程需要的中间寄存器变多。Hopper 单线程 register 上限 255 个,超过就 spill。寄存器压力大会让 occupancy 降低,反而拖慢整体。
8.9.2 SMEM 容量
Fused kernel 经常需要在 SMEM 里 stage 中间数据。但 SMEM 总共 228 KB,过度融合会撑爆。
8.9.3 编译时间爆炸
CUTLASS 的模板化 fused kernel,编译可以慢到几分钟一个 kernel。生产环境一定要 cache 编译产物。
8.9.4 调试难度
朴素拆分的 kernel 每一步可以单独打印中间结果;fused kernel 的中间结果在寄存器里,调试需要技巧(用 conditional 写回 HBM 一段查看)。
8.9.5 错误传播
Fused kernel 的一个 bug 可能影响多个 op 的正确性。强烈推荐:fused 版本和拆分版本同时存在,CI 测试都跑,互相验证数值一致。
8.10 这一章的小结与下一章
这一章建立了 LLM kernel 优化的"融合直觉":
- Element-wise 算子在 LLM 中无处不在:每个 decoder block 都有 4-6 个,是 HBM 带宽的纯消费者。
- 三个融合层次:纯 element-wise 串接 → element-wise + reduce → 跨异构(GEMM epilogue)。每深一层收益和复杂度都更大。
- vLLM 的 fused_add_rms_norm 是 LLM 推理引擎的"小招牌":典型的 element-wise + reduce 融合,30%+ 提升。
- CUTLASS Epilogue 是工业级 GEMM fusion 的事实标准:第 13 章会详细讲。
- 过度融合有反效果:寄存器压力、SMEM 容量、编译时间、调试难度都是隐藏成本。
第 9 章我们继续往 LLM 推理的深处走——讲 量化 Kernel。INT8 / FP8 / INT4 量化是 LLM 推理性能跃迁的另一个关键,但 dequantize(解码)这一步本身又是一个"小算子",需要特别的设计才能高效。读完第 9 章读者会理解为什么 Marlin 的 INT4 GEMM 能比 cuBLAS 的 FP16 GEMM 还快。
本章动手练习:
- 实现 add + rms_norm 的拆分版和 fused 版,对比 H=8192 时的延迟。预期 fused 版快 30%+。
- 用 PyTorch 的
torch.compile编译一个 LLM block,看 TorchInductor 生成的 Triton kernel——观察它做了哪些 element-wise fusion。- 思考:如果 GEMM 的输出是 INT8(量化推理),怎么把 dequantize(INT8 → FP16)和后续 RMSNorm 融合起来?