CUDA 算子工程:手写 FlashAttention v2 之路
第 9 章 量化 Kernel:INT8 / FP8 / INT4
第 9 章 量化 Kernel:INT8 / FP8 / INT4
"Quantization is not just about saving memory. It's about saving bandwidth — the only resource you actually run out of in LLM inference." ——一句在推理优化讨论中常见的判断
9.1 为什么 LLM 推理必须量化
来看一组数字。LLaMA-2 70B 模型,FP16 权重大小:
70B 参数 × 2 字节/参数 = 140 GB
H100 一张卡只有 80 GB HBM——70B 模型根本放不进单卡。如果 INT8 量化:
70B × 1 字节 = 70 GB 可以放进 H100 80GB 卡!
INT4 量化:
70B × 0.5 字节 = 35 GB 甚至能放进 RTX 4090 24GB(部分卸载)
但部署成本还不是量化的最大动机——它真正解决的是带宽瓶颈。
LLM 单 token decoding 的本质是"读一遍权重做一次 GEMV"。一个 70B 模型每 token 要读 140 GB 权重(FP16)/ 70 GB(INT8)/ 35 GB(INT4)。在 H100 3.35 TB/s 的 HBM 上:
| 量化 | 每 token HBM 流量 | 理论延迟 | 实际延迟(含算和其他开销) |
|---|---|---|---|
| FP16 | 140 GB | ~42 ms | ~50-60 ms |
| INT8 | 70 GB | ~21 ms | ~25-30 ms |
| INT4 | 35 GB | ~10 ms | ~12-15 ms |
INT4 推理的 token/s 是 FP16 的 4 倍——这就是量化在 LLM 推理上的现金价值。
9.2 量化基础:把 FP16 压成整数
量化的核心数学:
其中:
- = scale(浮点缩放因子)
- = zero-point(整数零点偏移)
- INT8: 或无符号
- INT4: 或无符号
例子:把 FP16 范围 [-2.0, +2.0] 压到 INT8:
fp16 = 1.5 -> int8 = round(1.5 / 0.0157) = 96
fp16 = -0.5 -> int8 = round(-0.5 / 0.0157) = -32
dequant:
int8 = 96 -> fp16 = 0.0157 * 96 = 1.507 (误差 0.007)
这就是量化的"信息损失"——0.007 的误差。但只要总体上模型仍能给出合理输出,就值得。
9.3 量化方案的"粒度"
scale 怎么选?这决定了量化方案的"粒度",权衡精度和元数据开销:
flowchart TB
subgraph Per-tensor [Per-tensor 一个 scale]
PT1[整个 weight 矩阵共享一个 s]
PT2[元数据极少, 1 个 fp32]
PT3[精度差, 大异常值会主导 scale]
end
subgraph Per-channel [Per-channel 每行/列一个 scale]
PC1[每个输出通道一个 s]
PC2[元数据 N_out × 4 字节]
PC3[精度好, 工业标配]
end
subgraph Per-group [Per-group 每 K 列一组 scale]
PG1[每 128 列一组 s]
PG2[元数据 N × N_groups × 4 字节]
PG3[精度更好, INT4 必备]
end
subgraph Per-token [Per-token 推理时每行一个 s]
PTK1[激活按 batch 维度量化]
PTK2[运行时计算 s]
PTK3[配合 SmoothQuant]
end
不同方案的精度-开销权衡:
| 方案 | 精度损失 | 元数据 | 适用 |
|---|---|---|---|
| Per-tensor | 高 | 极小 | 老的部署,不推荐 |
| Per-channel | 中 | 小 | INT8 PTQ 标配 |
| Per-group (g=128) | 低 | 中 | AWQ / GPTQ INT4 标配 |
| Per-token | 低 | 运行时算 | SmoothQuant W8A8 |
INT4 必须用 per-group——因为 INT4 范围太窄(只有 16 个值),单个 channel 内不同列的数值范围差异能让 per-channel 的精度崩溃。AWQ 和 GPTQ 默认 group_size=128,意思是每 128 个连续元素共享一个 scale。
9.4 LLM 量化算法概要
LLM 量化论文很多,工业上最常用的几个:
9.4.1 SmoothQuant (W8A8)
权重和激活都量化到 INT8。核心 insight:激活的异常值(outlier)远大于权重,但通过把"激活的难度"挪一部分到"权重的难度",可以让两边都好量化。
具体做法:用一个对角矩阵 把激活除回去、权重乘进去:
数学上完全等价,但 的范围变小了(容易 INT8), 的范围只略微变大(仍能 INT8)。
9.4.2 GPTQ (W4A16)
权重 INT4,激活 FP16。GPTQ 通过逐列贪心寻找最优量化值——量化第 列时,根据 Hessian 信息调整后面列的权重以最小化误差。开源工具 auto_gptq 是事实标准之一。
9.4.3 AWQ (W4A16)
Activation-aware Weight Quantization。核心 insight:不是所有权重都同样重要——按激活幅度选出 1% "重要权重"保持 FP16,其他 99% 量化。比 GPTQ 更简单,效果相当或更好。
9.4.4 FP8 量化
Hopper 引入的 FP8(E4M3/E5M2)是另一种思路——不再做整数量化,而是用更窄的浮点格式。FP8 量化比 INT8 精度更好(指数位提供动态范围),但需要 Hopper 这种硬件原生支持。NVIDIA 的 Transformer Engine 和 H100 的 FP8 训练栈是这条路线。
9.5 关键 Kernel: Dequant + GEMM
量化 LLM 推理的核心 kernel 是 dequant-fused GEMM:
# 输入: 量化的 W (INT4 或 INT8) 和 FP16 激活 X
# 输出: FP16 结果 Y = X @ dequant(W)
# 朴素拆分:
W_fp16 = dequantize(W, scale, zero) # 临时分配 N×K 的 FP16
Y = X @ W_fp16 # cuBLAS GEMM
# Dequant-fused:
Y = quantized_gemm(X, W_int4, scale, zero)
朴素版本的问题:
- 临时分配 N×K 的 FP16 矩阵:70B 模型最大的层 N=K=8192,临时矩阵 128 MB,HBM 写一次。
- W_fp16 写完立刻读:完全是浪费 HBM 带宽。
- 失去了量化的带宽优势:cuBLAS GEMM 还是按 FP16 读 W_fp16,HBM 流量没省。
Dequant-fused 的核心思想:在把 W 拉到 SMEM 时就 dequant——SMEM 上仍然是 FP16,但 HBM 仍然只读了 INT4。
flowchart LR HBM[HBM] -->|读 INT4 数据 + INT4 scale| SMEM[SMEM/Register] SMEM -->|dequant 在寄存器中| FP16[FP16 fragment] FP16 -->|喂给 Tensor Core| TC[mma.sync FP16] TC -->|FP32 accumulator| OUT[输出 D]
这是工业级量化 GEMM 的标准结构。Marlin、Machete、Bitsandbytes、TensorRT-LLM 的 gptq_marlin_gemm 都是这种思路。
9.6 INT4 解包的黑魔法
INT4 dequant 中最关键的是怎么高效地从 packed INT4 解出 FP16。
INT4 用 4 位存储,每字节装 2 个 INT4。从 8 个 packed INT4 字节解出 16 个 FP16 是个"小问题"——但在 GEMM 内层这一步可能跑几亿次,必须做到极致。
朴素写法:
// 8 字节 packed INT4 -> 16 个 FP16 (FP16 = bf16 或 half)
half decoded[16];
#pragma unroll
for (int i = 0; i < 8; ++i) {
int8_t packed = packed_data[i];
int8_t lo = packed & 0x0F; // 低 4 位
int8_t hi = (packed >> 4) & 0x0F;
decoded[i * 2 + 0] = __int2half_rn(lo - 8); // 中心化到 [-8, 7]
decoded[i * 2 + 1] = __int2half_rn(hi - 8);
}
这段代码 16 次类型转换,编译出来很多指令。性能差。
工业级写法(Marlin 论文里的"快速 INT4 → FP16 解包")用了一个位运算技巧:
// Marlin 的 INT4 -> FP16 快速解包 (简化版)
// 利用 FP16 的 IEEE 754 编码: 把 INT4 直接拼到 FP16 的尾数位
__device__ uint32_t int4_to_fp16x2(int8_t packed) {
// packed: 0xAB (高 4 位 a, 低 4 位 b)
// 目标: 输出两个 FP16 = (a - 8, b - 8)
// 把 4 个 INT4 拼成 0x?6?6 模式 (其中 6 是 FP16 尾数高位)
uint32_t i4s = (uint32_t)packed;
// 用 lop3 (NVIDIA 三输入逻辑指令) 一次完成 mask + or
uint32_t result;
asm("lop3.b32 %0, %1, 0x000F000F, 0x64006400, 0xea;\n"
: "=r"(result) : "r"(i4s));
// 此时 result 是两个 FP16, 数值上 = (i4 + 1024) ?
// 减去常数 (1024 + 8) 得到 (i4 - 8) 的 FP16 表示
asm("sub.f16x2 %0, %0, %1;\n"
: "+r"(result)
: "r"(0x64086408)); // 常数: (1024+8) 编码为 FP16
return result;
}
这段代码是 Marlin 论文 (Frantar et al., 2024) 中"fast dequantization"的简化版,省略了 group scale 应用部分。
这种"用位运算把整数直接拼成 FP16 编码"的技巧能比朴素写法快 5-10 倍——因为它避免了 __int2half_rn 那种"整数→浮点"硬件转换指令。
类似的技巧在 cutlass / Marlin / Machete 源码里大量存在。这是 LLM 量化 kernel 性能极限的关键之一。
9.7 Marlin INT4 GEMM 简介
Marlin 是 IST Austria 的 Frantar et al. 2024 年发布的开源 INT4 GEMM 实现,单卡 H100 上 batch_size=1 推理可以达到 cuBLAS FP16 性能的 95%——也就是说,INT4 + Marlin 几乎不损失算力,但带宽消耗砍 4 倍。
Marlin 的核心创新:
- Async pipeline:用 cp.async.bulk + WGMMA 异步流水线,TMA 拷贝和 Tensor Core 计算重叠。
- 快速 INT4→FP16 解包:上面 9.6 节的 lop3 + sub.f16x2 技巧。
- Group scale 寄存器存储:group_size=128 的 scale 表很小,全部存寄存器,避免 SMEM 访问。
- Tile 配置精调:根据 H100 的 SMEM/寄存器预算精确选择 M/N/K tile 大小。
vLLM 0.5+ 版本默认用 Marlin 跑 GPTQ INT4 模型。从 cuBLAS FP16 切到 Marlin INT4 后,LLaMA-70B 推理吞吐能从 ~50 tok/s 提到 ~200 tok/s。
9.8 FP8 GEMM 的不同路径
Hopper 的 FP8 是另一条路。FP8 不需要 dequant——Tensor Core 直接吃 FP8 输入:
// FP8 直接走 mma.sync
// 输入: FP8 A, FP8 B
// 输出: FP32 累加
asm("mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 ...");
FP8 GEMM 的关键:
- 没有 dequant 开销:直接吃 FP8。
- 算力翻倍:1979 TFLOPs vs FP16 的 989 TFLOPs。
- 每张量需要 scale:FP8 动态范围窄,必须配 scale 防溢出。
NVIDIA 的 Transformer Engine 库提供了"自动 FP8"——它会监控每层激活的范围,动态调整 scale。Megatron-LM 用 TE 训练 FP8 模型。
INT4 vs FP8 的选择:
- 极致带宽节省、低 batch 推理:选 INT4 (W4A16),4 倍带宽节省。
- 训练或大 batch 推理:选 FP8 (W8A8),无 dequant 开销,算力翻倍。
vLLM 这两种都支持,根据用户指定的量化方案选不同的 kernel 路径。
9.9 量化 Kernel 的工程要点
最后总结量化 kernel 的几个工程要点:
- 永远 fuse dequant 到下游计算:单独 dequant 到 HBM 是浪费。
- 位运算技巧极重要:INT4/INT8 解包要用 lop3、sub.f16x2 等专用指令。
- Group scale 用寄存器存:group_size=128 时 scale 表很小,不要走 SMEM。
- Per-token 量化的 reduce 和 GEMM 融合:W8A8 时激活也要量化,把"per-row max + 量化"和上一个 GEMM 的 epilogue fuse。
- 数值稳定性测试:量化后模型质量要测,不要相信"应该没问题"——LLM 量化偶尔会让某些 prompt 输出退化。
9.10 这一章的小结与下一篇
第二篇我们走完了 LLM 推理中的"小算子"舞台:
- 第 5 章 Reduction:所有归约的祖师爷。
- 第 6 章 Online Softmax:让 softmax 流式化的数学魔法。
- 第 7 章 LayerNorm/RMSNorm:用 Welford 把方差变成 1-pass。
- 第 8 章 Element-wise 融合:把无数小算子拼成大 kernel。
- 第 9 章 量化 Kernel:用 INT4/FP8 解决带宽瓶颈。
读完这五章,读者已经具备了 LLM 推理中绝大多数"非 GEMM 非 attention"算子的优化能力。剩下的就是两个真正的大头:GEMM 和 Attention。
第三篇(第 10-13 章)我们正式进入 GEMM——从朴素 GEMM 出发,经过 Tiled GEMM、Tensor Core GEMM,最后到 CUTLASS 设计哲学。读完第三篇,读者会理解为什么 cuBLAS 的 SGEMM 比朴素写法快 30 倍,以及现代 GEMM kernel 的所有"模板武器"是怎么组装的。
本章动手练习:
- 实现一个 INT8 dequant kernel,对比朴素版本 vs 用 vectorized load + lop3 的优化版本,看带宽差距。
- 用 vLLM 跑同一个模型的 FP16 和 INT4 (Marlin) 版本,记录单 token 延迟和总吞吐。
- 阅读 vLLM 的
csrc/quantization/gptq_marlin/gptq_marlin.cu,找到 9.6 节描述的 lop3 dequant 代码段。