CUDA 算子工程:手写 FlashAttention v2 之路

第 9 章 量化 Kernel:INT8 / FP8 / INT4

作者 杨艺韬 · 2,562 字

第 9 章 量化 Kernel:INT8 / FP8 / INT4

"Quantization is not just about saving memory. It's about saving bandwidth — the only resource you actually run out of in LLM inference." ——一句在推理优化讨论中常见的判断

9.1 为什么 LLM 推理必须量化

来看一组数字。LLaMA-2 70B 模型,FP16 权重大小:

70B 参数 × 2 字节/参数 = 140 GB

H100 一张卡只有 80 GB HBM——70B 模型根本放不进单卡。如果 INT8 量化:

70B × 1 字节 = 70 GB    可以放进 H100 80GB 卡!

INT4 量化:

70B × 0.5 字节 = 35 GB  甚至能放进 RTX 4090 24GB(部分卸载)

部署成本还不是量化的最大动机——它真正解决的是带宽瓶颈

LLM 单 token decoding 的本质是"读一遍权重做一次 GEMV"。一个 70B 模型每 token 要读 140 GB 权重(FP16)/ 70 GB(INT8)/ 35 GB(INT4)。在 H100 3.35 TB/s 的 HBM 上:

量化 每 token HBM 流量 理论延迟 实际延迟(含算和其他开销)
FP16 140 GB ~42 ms ~50-60 ms
INT8 70 GB ~21 ms ~25-30 ms
INT4 35 GB ~10 ms ~12-15 ms

INT4 推理的 token/s 是 FP16 的 4 倍——这就是量化在 LLM 推理上的现金价值。

9.2 量化基础:把 FP16 压成整数

量化的核心数学:

xint=round(xfps)zx_{\text{int}} = \text{round}\left(\frac{x_{\text{fp}}}{s}\right) - z xfps(xint+z)x_{\text{fp}} \approx s \cdot (x_{\text{int}} + z)

其中:

例子:把 FP16 范围 [-2.0, +2.0] 压到 INT8:

s=(2.0(2.0))/255=0.0157s = (2.0 - (-2.0)) / 255 = 0.0157
fp16 = 1.5  ->  int8 = round(1.5 / 0.0157) = 96
fp16 = -0.5 ->  int8 = round(-0.5 / 0.0157) = -32

dequant:

int8 = 96  ->  fp16 = 0.0157 * 96 = 1.507  (误差 0.007)

这就是量化的"信息损失"——0.007 的误差。但只要总体上模型仍能给出合理输出,就值得。

9.3 量化方案的"粒度"

scale ss 怎么选?这决定了量化方案的"粒度",权衡精度和元数据开销:

flowchart TB
  subgraph Per-tensor [Per-tensor 一个 scale]
    PT1[整个 weight 矩阵共享一个 s]
    PT2[元数据极少, 1 个 fp32]
    PT3[精度差, 大异常值会主导 scale]
  end
  subgraph Per-channel [Per-channel 每行/列一个 scale]
    PC1[每个输出通道一个 s]
    PC2[元数据 N_out × 4 字节]
    PC3[精度好, 工业标配]
  end
  subgraph Per-group [Per-group 每 K 列一组 scale]
    PG1[每 128 列一组 s]
    PG2[元数据 N × N_groups × 4 字节]
    PG3[精度更好, INT4 必备]
  end
  subgraph Per-token [Per-token 推理时每行一个 s]
    PTK1[激活按 batch 维度量化]
    PTK2[运行时计算 s]
    PTK3[配合 SmoothQuant]
  end

不同方案的精度-开销权衡:

方案 精度损失 元数据 适用
Per-tensor 极小 老的部署,不推荐
Per-channel INT8 PTQ 标配
Per-group (g=128) AWQ / GPTQ INT4 标配
Per-token 运行时算 SmoothQuant W8A8

INT4 必须用 per-group——因为 INT4 范围太窄(只有 16 个值),单个 channel 内不同列的数值范围差异能让 per-channel 的精度崩溃。AWQ 和 GPTQ 默认 group_size=128,意思是每 128 个连续元素共享一个 scale。

9.4 LLM 量化算法概要

LLM 量化论文很多,工业上最常用的几个:

9.4.1 SmoothQuant (W8A8)

权重和激活都量化到 INT8。核心 insight:激活的异常值(outlier)远大于权重,但通过把"激活的难度"挪一部分到"权重的难度",可以让两边都好量化

具体做法:用一个对角矩阵 SS 把激活除回去、权重乘进去:

Y=XW=(XS1)(SW)=XWY = X W = (X \cdot S^{-1})(S \cdot W) = X' W'

数学上完全等价,但 X=X/SX' = X / S 的范围变小了(容易 INT8),W=SWW' = S \cdot W 的范围只略微变大(仍能 INT8)。

9.4.2 GPTQ (W4A16)

权重 INT4,激活 FP16。GPTQ 通过逐列贪心寻找最优量化值——量化第 ii 列时,根据 Hessian 信息调整后面列的权重以最小化误差。开源工具 auto_gptq 是事实标准之一。

9.4.3 AWQ (W4A16)

Activation-aware Weight Quantization。核心 insight:不是所有权重都同样重要——按激活幅度选出 1% "重要权重"保持 FP16,其他 99% 量化。比 GPTQ 更简单,效果相当或更好。

9.4.4 FP8 量化

Hopper 引入的 FP8(E4M3/E5M2)是另一种思路——不再做整数量化,而是用更窄的浮点格式。FP8 量化比 INT8 精度更好(指数位提供动态范围),但需要 Hopper 这种硬件原生支持。NVIDIA 的 Transformer Engine 和 H100 的 FP8 训练栈是这条路线。

9.5 关键 Kernel: Dequant + GEMM

量化 LLM 推理的核心 kernel 是 dequant-fused GEMM

# 输入: 量化的 W (INT4 或 INT8) 和 FP16 激活 X
# 输出: FP16 结果 Y = X @ dequant(W)

# 朴素拆分:
W_fp16 = dequantize(W, scale, zero)  # 临时分配 N×K 的 FP16
Y = X @ W_fp16                        # cuBLAS GEMM

# Dequant-fused:
Y = quantized_gemm(X, W_int4, scale, zero)

朴素版本的问题:

  1. 临时分配 N×K 的 FP16 矩阵:70B 模型最大的层 N=K=8192,临时矩阵 128 MB,HBM 写一次。
  2. W_fp16 写完立刻读:完全是浪费 HBM 带宽。
  3. 失去了量化的带宽优势:cuBLAS GEMM 还是按 FP16 读 W_fp16,HBM 流量没省。

Dequant-fused 的核心思想:在把 W 拉到 SMEM 时就 dequant——SMEM 上仍然是 FP16,但 HBM 仍然只读了 INT4。

flowchart LR
  HBM[HBM] -->|读 INT4 数据 + INT4 scale| SMEM[SMEM/Register]
  SMEM -->|dequant 在寄存器中| FP16[FP16 fragment]
  FP16 -->|喂给 Tensor Core| TC[mma.sync FP16]
  TC -->|FP32 accumulator| OUT[输出 D]

这是工业级量化 GEMM 的标准结构。Marlin、Machete、Bitsandbytes、TensorRT-LLM 的 gptq_marlin_gemm 都是这种思路。

9.6 INT4 解包的黑魔法

INT4 dequant 中最关键的是怎么高效地从 packed INT4 解出 FP16

INT4 用 4 位存储,每字节装 2 个 INT4。从 8 个 packed INT4 字节解出 16 个 FP16 是个"小问题"——但在 GEMM 内层这一步可能跑几亿次,必须做到极致。

朴素写法:

// 8 字节 packed INT4 -> 16 个 FP16 (FP16 = bf16 或 half)
half decoded[16];
#pragma unroll
for (int i = 0; i < 8; ++i) {
    int8_t packed = packed_data[i];
    int8_t lo = packed & 0x0F;       // 低 4 位
    int8_t hi = (packed >> 4) & 0x0F;
    decoded[i * 2 + 0] = __int2half_rn(lo - 8);  // 中心化到 [-8, 7]
    decoded[i * 2 + 1] = __int2half_rn(hi - 8);
}

这段代码 16 次类型转换,编译出来很多指令。性能差。

工业级写法(Marlin 论文里的"快速 INT4 → FP16 解包")用了一个位运算技巧

// Marlin 的 INT4 -> FP16 快速解包 (简化版)
// 利用 FP16 的 IEEE 754 编码: 把 INT4 直接拼到 FP16 的尾数位
__device__ uint32_t int4_to_fp16x2(int8_t packed) {
    // packed: 0xAB (高 4 位 a, 低 4 位 b)
    // 目标: 输出两个 FP16 = (a - 8, b - 8)

    // 把 4 个 INT4 拼成 0x?6?6 模式 (其中 6 是 FP16 尾数高位)
    uint32_t i4s = (uint32_t)packed;
    // 用 lop3 (NVIDIA 三输入逻辑指令) 一次完成 mask + or
    uint32_t result;
    asm("lop3.b32 %0, %1, 0x000F000F, 0x64006400, 0xea;\n"
        : "=r"(result) : "r"(i4s));

    // 此时 result 是两个 FP16, 数值上 = (i4 + 1024) ?
    // 减去常数 (1024 + 8) 得到 (i4 - 8) 的 FP16 表示
    asm("sub.f16x2 %0, %0, %1;\n"
        : "+r"(result)
        : "r"(0x64086408));  // 常数: (1024+8) 编码为 FP16

    return result;
}

这段代码是 Marlin 论文 (Frantar et al., 2024) 中"fast dequantization"的简化版,省略了 group scale 应用部分。

这种"用位运算把整数直接拼成 FP16 编码"的技巧能比朴素写法快 5-10 倍——因为它避免了 __int2half_rn 那种"整数→浮点"硬件转换指令。

类似的技巧在 cutlass / Marlin / Machete 源码里大量存在。这是 LLM 量化 kernel 性能极限的关键之一。

9.7 Marlin INT4 GEMM 简介

Marlin 是 IST Austria 的 Frantar et al. 2024 年发布的开源 INT4 GEMM 实现,单卡 H100 上 batch_size=1 推理可以达到 cuBLAS FP16 性能的 95%——也就是说,INT4 + Marlin 几乎不损失算力,但带宽消耗砍 4 倍。

Marlin 的核心创新:

  1. Async pipeline:用 cp.async.bulk + WGMMA 异步流水线,TMA 拷贝和 Tensor Core 计算重叠。
  2. 快速 INT4→FP16 解包:上面 9.6 节的 lop3 + sub.f16x2 技巧。
  3. Group scale 寄存器存储:group_size=128 的 scale 表很小,全部存寄存器,避免 SMEM 访问。
  4. Tile 配置精调:根据 H100 的 SMEM/寄存器预算精确选择 M/N/K tile 大小。

vLLM 0.5+ 版本默认用 Marlin 跑 GPTQ INT4 模型。从 cuBLAS FP16 切到 Marlin INT4 后,LLaMA-70B 推理吞吐能从 ~50 tok/s 提到 ~200 tok/s。

9.8 FP8 GEMM 的不同路径

Hopper 的 FP8 是另一条路。FP8 不需要 dequant——Tensor Core 直接吃 FP8 输入:

// FP8 直接走 mma.sync
// 输入: FP8 A, FP8 B
// 输出: FP32 累加
asm("mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 ...");

FP8 GEMM 的关键:

  1. 没有 dequant 开销:直接吃 FP8。
  2. 算力翻倍:1979 TFLOPs vs FP16 的 989 TFLOPs。
  3. 每张量需要 scale:FP8 动态范围窄,必须配 scale 防溢出。

NVIDIA 的 Transformer Engine 库提供了"自动 FP8"——它会监控每层激活的范围,动态调整 scale。Megatron-LM 用 TE 训练 FP8 模型。

INT4 vs FP8 的选择:

vLLM 这两种都支持,根据用户指定的量化方案选不同的 kernel 路径。

9.9 量化 Kernel 的工程要点

最后总结量化 kernel 的几个工程要点:

  1. 永远 fuse dequant 到下游计算:单独 dequant 到 HBM 是浪费。
  2. 位运算技巧极重要:INT4/INT8 解包要用 lop3、sub.f16x2 等专用指令。
  3. Group scale 用寄存器存:group_size=128 时 scale 表很小,不要走 SMEM。
  4. Per-token 量化的 reduce 和 GEMM 融合:W8A8 时激活也要量化,把"per-row max + 量化"和上一个 GEMM 的 epilogue fuse。
  5. 数值稳定性测试:量化后模型质量要测,不要相信"应该没问题"——LLM 量化偶尔会让某些 prompt 输出退化。

9.10 这一章的小结与下一篇

第二篇我们走完了 LLM 推理中的"小算子"舞台:

读完这五章,读者已经具备了 LLM 推理中绝大多数"非 GEMM 非 attention"算子的优化能力。剩下的就是两个真正的大头:GEMMAttention

第三篇(第 10-13 章)我们正式进入 GEMM——从朴素 GEMM 出发,经过 Tiled GEMM、Tensor Core GEMM,最后到 CUTLASS 设计哲学。读完第三篇,读者会理解为什么 cuBLAS 的 SGEMM 比朴素写法快 30 倍,以及现代 GEMM kernel 的所有"模板武器"是怎么组装的。

本章动手练习

  1. 实现一个 INT8 dequant kernel,对比朴素版本 vs 用 vectorized load + lop3 的优化版本,看带宽差距。
  2. 用 vLLM 跑同一个模型的 FP16 和 INT4 (Marlin) 版本,记录单 token 延迟和总吞吐。
  3. 阅读 vLLM 的 csrc/quantization/gptq_marlin/gptq_marlin.cu,找到 9.6 节描述的 lop3 dequant 代码段。