第20章 量化与混合精度训练

“Mixed precision is the cheap free lunch of deep learning. Quantization is the cheap not-quite-free dinner.”

—— PyTorch quantization team

本章要点

  • 混合精度 (AMP) 让 forward / backward 大部分走 fp16 或 bf16,省一半显存 + 1.5-2x 加速。autocast + GradScaler 两件套
  • autocast 是 dispatcher mode(第 5 章 §5.7):拦截每个算子,根据 op 类别决定要不要把 fp32 输入 cast 到 fp16
  • GradScaler 解决 fp16 梯度下溢:把 loss 放大固定倍数让 grad 不变 0,反向后再 unscale;如果 inf/nan 检测到则跳过该 step
  • bf16 不需要 GradScaler(数值范围与 fp32 同),是大模型训练首选
  • 量化(quantization)针对推理,不是训练:把 fp16/fp32 模型转 int8,4x 显存节省 + 2-3x 推理加速
  • PT2E 是新一代量化框架:基于 torch.export + Inductor,取代 v1.x 的 FX graph mode

20.1 为什么要降精度

fp32 单精度浮点:4 字节/数。70B 模型 = 280GB params。fp16 半精度:2 字节/数 → 140GB,省一半。int8:1 字节/数 → 70GB,省 3/4。

但精度降低有代价:

  • fp16 数值范围 ±65504,太小的梯度(如 1e-7)会下溢成 0
  • bf16 范围与 fp32 一样大但只有 7 位尾数(fp16 是 10 位),表达精度低
  • int8 只能表示 256 个值,需要”量化-反量化”过程

PyTorch 把这些权衡封装进两套独立机制:AMP(混合精度训练)量化(推理)

20.2 AMP:混合精度训练

torch.amp.autocast + torch.amp.GradScaler 是经典组合:

from torch.amp import autocast, GradScaler

scaler = GradScaler()

for batch in loader:
    optimizer.zero_grad()
    with autocast(device_type='cuda', dtype=torch.bfloat16):
        out = model(batch)             # forward 自动 cast 到 bf16
        loss = loss_fn(out, target)
    scaler.scale(loss).backward()      # bf16 不用 scale, fp16 时 scaler 放大梯度
    scaler.step(optimizer)             # 检查 inf/nan, 没问题就 step
    scaler.update()                     # 调整 scale 系数

20.2.1 autocast 是 dispatcher mode

autocast 是注册在 dispatcher 上的 mode(第 5 章 §5.7)。它工作机制:

  1. 进入 with autocast(dtype=bfloat16):在 thread-local 加入 AutocastCUDA DispatchKey
  2. 每次算子调用进 dispatcher,命中 AutocastCUDA key
  3. autocast kernel 检查这个 op 的”cast 策略”:
    • 应该用 bf16(GEMM 类):cast 输入到 bf16,调底层 op
    • 应该用 fp32(reduction、normalize 等,数值敏感):保留 fp32
    • 保留输入 dtype:不 cast
  4. redispatch 到下一层(autograd / backend)

每个 op 的策略硬编码在 aten/src/ATen/autocast_mode.cpp 里。比如 mm / linear / conv2d 都是 bf16,softmax / cross_entropy / layer_norm 都是 fp32。这套白名单是社区精心调出的”哪些 op 在 bf16 下安全”经验。

20.2.2 GradScaler:解决 fp16 下溢

fp16 数值范围 ±65504,但 ML 梯度经常在 1e-5 到 1e-7 之间。直接 fp16 backward 大量梯度下溢成 0 → 训练不收敛。

GradScaler 的招数:

loss * scale (default 65536) → backward → grads 自动放大 scale 倍

                                      unscale grads (除以 scale)

                                      检查 inf/nan

                              没问题 → optimizer.step()
                              有问题 → 跳过 step, 把 scale 减半重试

scaler.update() 用一个状态机调 scale:连续 N 步无 inf 就翻倍 scale(找最大可用值),出现 inf 就减半。这种”动态 loss scaling”让 fp16 训练可用。

20.2.3 bf16 不需要 GradScaler

bf16 与 fp32 同 8 位指数 → 数值范围一样大(±3.4e38)。即使梯度小到 1e-30 也不会下溢,不需要 loss scaling。代价是尾数只 7 位,相比 fp32 的 23 位精度低 65000 倍。

实测对绝大多数训练任务,bf16 与 fp32 精度差异不可见(loss 曲线几乎重合)。bf16 是大模型训练的事实标准 —— Llama / Qwen / DeepSeek 都用 bf16 + fp32 master weights 训练。

# bf16 标准训练 (无 GradScaler)
with autocast(device_type='cuda', dtype=torch.bfloat16):
    out = model(x)
    loss = loss_fn(out, y)
loss.backward()
optimizer.step()

代码比 fp16 简单。

20.3 量化:推理时把权重压成 int8

训练用 fp32 / bf16 已经够,推理时再量化。流程:

  1. 训练完拿到 fp32 / bf16 模型
  2. 用一批 calibration data 跑模型,记录每层的激活值范围
  3. 算出每层的量化参数(scale + zero_point):int8_val = round(fp32_val / scale) + zero_point
  4. 把 weight 与 activation 都用 int8 存储
  5. 推理时用 int8 GEMM kernel 算,输出反量化回 fp32

收益:

  • weight 显存 4x 缩小(fp32 → int8)
  • int8 GEMM 在 NVIDIA Tensor Core 上是 fp16 的 2x 吞吐
  • 显存带宽节省 75%(memory-bound op 提升明显)

代价:精度损失(typically 0.5-2% accuracy drop)。生产推理服务能接受这个 trade-off。

20.4 PT2E:v2.x 的新一代量化

PyTorch 量化历史:

  • v1.x eager mode:手动 prepare → calibrate → convert,麻烦
  • v1.x FX graph mode:自动化但用 torch.fx.symbolic_trace,对动态控制流处理差
  • v2.x PT2E (PyTorch 2 Export):基于 torch.export + Inductor,是当前推荐路径

PT2E 流程:

flowchart LR
    Model[fp32 model]
    Model --> Export["torch.export(model)<br/>导出 ExportedProgram"]
    Export --> Prep["prepare_pt2e<br/>插入 observer 节点"]
    Prep --> Cal["跑 calibration data<br/>observer 记录激活范围"]
    Cal --> Conv["convert_pt2e<br/>observer → quantize/dequantize"]
    Conv --> Comp["torch.compile<br/>Inductor 生成 int8 kernel"]
    Comp --> Run[int8 推理]

    style Conv fill:#fef3c7,stroke:#f59e0b,stroke-width:2px

代码:

import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer

example_inputs = (torch.randn(1, 3, 224, 224),)
exported = torch.export.export(model, example_inputs)
quantizer = X86InductorQuantizer().set_global(...)
prepared = prepare_pt2e(exported.module(), quantizer)
# ... calibration loop ...
quantized = convert_pt2e(prepared)
optimized = torch.compile(quantized)   # Inductor 生成 int8 kernel

关键好处:

  • 与 torch.compile 完全集成:量化模型仍享受 Inductor 的 fusion、autotune
  • Quantizer 抽象:不同硬件后端(X86 / XNNPACK / 自家芯片)实现自己的 Quantizer,决定哪些 op 怎么量化
  • 基于 ExportedProgram:稳定的 IR,避免 FX 的 trace 不全问题

PT2E 是大模型推理服务的事实新选择。HuggingFace、vLLM 都在迁移过去。

20.5 量化粒度与精度

量化的几个维度:

粒度描述精度 / 性能 trade-off
Per-tensor整张量一个 scale简单但精度差
Per-channel每个 channel 一个 scale精度好(CV 必备)
Per-tokenNLP 中每个 token 一个 scale适合 LLM
Per-group每 N 个元素一个 scale(如 128)LLM 量化新标准 (GPTQ / AWQ)

LLM 量化(如 4-bit AWQ / GPTQ)通常是 per-group + 复杂 calibration,不在 PyTorch 内置量化路径里 —— 用专门库(bitsandbytes / auto-gptq)。

20.5.5 autocast 的 5 类 CastPolicy

aten/src/ATen/autocast_mode.h:416 定义了 5 种 CastPolicy每个算子归到某一类决定 autocast 怎么处理:

CastPolicy含义典型算子
lower_precision_fp把输入全部 cast 到 lower precision (CUDA 默认 fp16,CPU 默认 bf16)mm / addmm / bmm / conv2d / linear
fp32强制全部 fp32 算softmax / log_softmax / cross_entropy / binary_cross_entropy / mse_loss / cumsum / prod
fp32_set_opt_dtype输出 dtype 用户没设就用 fp32(针对带 dtype: Optional 参数的 op)sum / mean / prod(带 dtype kwarg 时)
fp32_append_dtype给本来无 dtype 参数的 overload 自动 append fp32norm 系列
promote取所有输入 dtype 里”最宽”的(fp32 > fp16)addcmul / addcdiv / cat(输入混合 dtype 时)

为什么 softmax 强制 fp32? 因为 exp(x) 在 fp16 下数值范围 ±65504 容易溢出。fp32 能稳定表示。norm 类似 —— 平方和容易炸。mm / conv 这些”线性”算子在 lower_precision 下数值稳定,可以放心走 bf16/fp16。

每个算子被注册到哪一类,定义在 aten/src/ATen/autocast_mode.cpp 的几百行 KERNEL_CUDA(...) / KERNEL_CPU(...) 宏里。新加自定义算子时如果要参与 amp,要给它显式标 CastPolicy(用 torch.cuda.amp.custom_fwd(cast_inputs=...))。

理解这套白名单,你能解释为什么 amp 训练时某个算子”看起来没省内存” —— 它被归在 fp32 类、根本没 cast。

20.5.6 PT2E Quantizer 接口

PT2E 把”哪些 op 怎么量化”抽象成 Quantizer 类。每个硬件后端实现自己的:

Quantizer后端
X86InductorQuantizerx86 CPU + Inductor
XNNPACKQuantizer移动端 XNNPACK
OpenVINOQuantizerIntel OpenVINO
厂商自家 (NPU / MLU / 寒武纪)各自实现

Quantizer 主要 3 个方法:

  • annotate(model):扫描 graph,在每个要量化的 op 上加 QuantizationAnnotation 标记(哪些输入要 quantize、什么 scale 范围)
  • validate(model):annotation 之间一致性检查
  • transform_for_annotation(model):annotation 之前的图变换(如把 conv-bn fuse 后再量化)

prepare_pt2e(exported, quantizer) 内部就是调 quantizer.annotate → 在每个标记位置插 observer 节点。convert_pt2e 把 observer 替换成 quantize / dequantize 节点对。最后 torch.compile 让 Inductor 看到 q → kernel → dq 这种模式自动 fuse 成 int8 kernel。

这套 Quantizer 抽象 + Inductor codegen 让国产芯片厂商接量化路径不用改 PT2E 主仓 —— 实现自家 Quantizer 即可。

20.5.7 fake_quantize 与 QAT

torch/ao/quantization/fake_quantize.pyFakeQuantize 是个 nn.Module:

class FakeQuantize(Module):
    def forward(self, x):
        if self.fake_quant_enabled[0] == 1:
            x = torch.fake_quantize_per_tensor_affine(
                x, self.scale, self.zero_point, self.quant_min, self.quant_max
            )
        return x

fake_quantize_per_tensor_affine 的语义:“先量化再反量化”,模拟量化噪声但保持张量为浮点。这让训练时模型能”看到量化误差”、调整权重适应它。这是 QAT (Quantization-Aware Training) 的核心

PT2E 提供 prepare_qat_pt2e(exported, quantizer) 走 QAT 路径:

  1. 在每个量化点插入 FakeQuantize 模块(比 PTQ 的 observer 多了”反量化”动作)
  2. 用户继续训练几个 epoch,模型在 fake quantize 下学习
  3. convert_pt2e(prepared) 把 FakeQuantize 替换成真量化 / 反量化对

QAT 通常比 PTQ 精度高 1-3%,代价是训练时间 + 复杂度。Llama / Mistral 等大模型量化几乎都走 QAT(直接 PTQ 精度损失太大)。

20.5.8 Observer 类型:量化校准的核心算法

prepare_pt2e 在每个量化点插入的”observer”是 nn.Module,forward 时记录激活值范围、训练完用记录的范围算 scale/zero_pointtorch/ao/quantization/observer.py 提供 4 种主要 Observer:

Observer行号算法
MinMaxObserver440记录所有见过的 min / max,用 [min, max] 算 scale
MovingAverageMinMaxObserver587min/max 用 EMA 平滑,避免被 outlier batch 拖动
PerChannelMinMaxObserver686每个 channel 独立 min/max(per-channel 量化)
HistogramObserver987记录值的直方图,用 KL divergence 选最优 [min, max]

算法差异 + 适用场景

  • MinMaxObserver(最简):scale = (max - min) / (qmax - qmin)。问题:一个 outlier 让整段 range 拉大 → 大部分值挤在窄区间 → 量化精度差
  • MovingAverageMinMax:用滑动平均吸收 outlier 影响。生产 PTQ 推荐
  • PerChannelMinMax:CV 模型必备 —— 不同 channel 的值范围差异大,per-tensor 量化精度损失严重
  • HistogramObserver:最精细 —— 用直方图近似分布,搜索最小化 KL divergence 的 [min, max]。计算贵但精度最高

实战推荐:

  • weights:PerChannelMinMaxObserver(输出 channel 维 per-channel)
  • activations:MovingAverageMinMaxObserver(per-tensor 但平滑)
  • 极致精度:HistogramObserver(calibration 慢 10x 但精度高 1-2%)

PT2E 的 Quantizer.annotate 在每个量化点指定用哪种 Observer,最终 PTQ 精度很大程度由这个选择决定。这套 Observer 体系也是从 v1.x 量化路径继承下来的、PT2E 仍在用的核心组件。

20.5.9 FP8 训练:H100 / B100 时代的新精度

NVIDIA H100 引入 FP8 Tensor Core,吞吐是 BF16 的 2x。FP8 有两种格式:

格式指数位尾数位范围适用
E4M343±448精度高、范围小 → forward / weight 适合
E5M252±57344精度低、范围大 → backward / gradient 适合

PyTorch v2.4+ 内置 FP8 支持,通过 torch.float8_e4m3fn / torch.float8_e5m2 类型暴露。但 FP8 训练不能直接用 autocast 替换 —— 因为 FP8 范围太小,每个 tensor 都需要 per-tensor scale 维护数值范围。

实战 FP8 训练要用 NVIDIA 的 TransformerEngine 库(基于 PyTorch)或 PyTorch 自家的 torchao.float8

import torchao.float8 as float8

# 把 model 的 nn.Linear 转换成 fp8 版本
float8.convert_to_float8_training(model)

# 之后正常训练循环, fp8 处理在 op 内部完成
for batch in loader:
    out = model(batch)
    loss.backward()
    optim.step()

内部机制:每个 fp8 op 维护 input / output 的”amax 历史”,每步根据上一步的 amax 算出最优 scale、用 scale 把 fp32 input cast 到 fp8。这种 delayed scaling 让 FP8 训练既稳定又能用上硬件加速。

实测 Llama-7B FP8 训练比 BF16 快 1.3-1.5x,loss 曲线几乎重合。这是 v2.x 时代降精度的最新前沿。但 FP8 训练对 hyperparameter 敏感(lr / weight decay 都要重调),生产推荐先把 BF16 调通再迁 FP8。

理解 FP8 让你看到精度降级是个”持续向下”的工程方向:fp32 → bf16 → fp8 → 未来 fp4 / int4。每一步硬件支持先到位、PyTorch 跟进、社区调出 best practice。

20.5.10 INT4 LLM 量化:GPTQ / AWQ / NF4 三家算法

LLM 推理量化已经从 int8 卷到 int4 —— 70B 模型 280GB → 35GB,单机 H100 可用。三种主流算法:

GPTQ(One-shot Post-Training):

逐层量化,每层用一批 calibration data 找出最优 4-bit 量化参数。具体:扫描 weight 的每行(每个输出 neuron),用 Hessian 信息决定怎么 round 让最终 output error 最小。算法复杂但 calibration 快(几小时)。

AWQ(Activation-aware):

观察到不同 weight column 的”重要程度”差异大(少数 column 主导大部分 activation)。AWQ 给重要 column 更高 precision、不重要的更低。简化为”找出每层的 1% 重要 column 不量化”。比 GPTQ 简单且精度更高。

NF4 (Normal Float 4-bit)

不是均匀量化,而是按”标准正态分布的分位数”量化。深度网络权重大致服从 N(0, σ²),按分位数量化在概率密度高的区域分辨率高。bitsandbytes / QLoRA 用 NF4。

graph LR
    Fp16[Llama-70B<br/>fp16: 140GB]
    Fp16 -->|GPTQ| GP[GPTQ INT4<br/>~35GB, accuracy -1%]
    Fp16 -->|AWQ| AW[AWQ INT4<br/>~35GB, accuracy -0.5%]
    Fp16 -->|NF4| NF[NF4 + double quant<br/>~17.5GB, accuracy -0.3%]

    style GP fill:#dcfce7
    style AW fill:#dbeafe
    style NF fill:#fef3c7

PyTorch 这层不内置这些算法(在 auto-gptq / bitsandbytes 等独立库),但 v2.4+ 的 torchao 仓库开始集成(torchao.quantization.quant_api.int4_weight_only)。生态的方向:把这些算法标准化、避免每家各自实现。

实战决策

  • 70B+ 模型推理:AWQ INT4(精度 / 速度平衡最好)
  • 受限设备 + LoRA fine-tuning:NF4 + double quant(QLoRA)
  • 自家服务追求极致:自训 GPTQ 量化

理解三家算法让你看到”4-bit 量化”不是黑魔法,而是工程妥协。每家通过不同的 calibration 策略 / 量化方式逼近精度极限。

20.5.11 KV cache 量化:推理服务的关键

LLM 推理的 KV cache 占显存 50%+(长序列时甚至 80%)。70B 模型 4K 序列的 KV cache = 几十 GB。把 KV cache 也量化能省巨大显存。

KV cache 量化的特殊性:

  • 运行时生成:不像 weight 提前量化好,KV 是 forward 时新生成的
  • read-write 频繁:每个 token 都要写新 KV、读历史 KV
  • 精度敏感:attention 输出对 KV 误差敏感

主流方案:

FP8 KV cache(vLLM 用):

generate K, V (bf16) → quantize 到 FP8 (E5M2) → 存进 cache
读 cache → dequantize → attention 计算

实测吞吐降 5-10%(量化 / 反量化开销),但显存省 50%。4K → 8K 序列长度可行

INT8 KV cache(GPTQ / SmoothQuant 类):

更激进,4x 显存节省。但需要更精细的 calibration(要校准 K 与 V 的 scale)。

PyTorch 这层支持 FP8 KV cache 通过 torch.float8_e5m2 类型。vLLM / SGLang 等推理引擎集成了完整 KV cache 量化路径。生产部署 70B+ 模型几乎都要用 —— 否则显存装不下长序列。

理解 KV cache 量化让你看到 LLM 推理的”显存优化”不只是 weight 量化,KV 也是一大块。两个一起量化才能让 70B + 32K 上下文在 8 卡 H100 上跑。

20.5.12 SmoothQuant:处理 outlier 的精度恢复

LLM 量化的真正痛点是 activation 含 outlier —— 少数维度的值远大于其他(差距 10-100x)。直接量化这些维度让 scale 巨大、其他正常维度精度降到几个 bit、模型崩盘。

SmoothQuant 的解:把 outlier “迁移” 到 weight 层。

graph LR
    A[activation X<br/>有 outlier] --> S[除以 scale α]
    W[weight W] --> M[乘以 scale α]
    S --> Mul[X' = X / α]
    M --> WW[W' = W * α]
    Mul --> Out[X' @ W' = X @ W<br/>等价]
    WW --> Out

    style S fill:#fef3c7
    style M fill:#fef3c7

数学上 X @ W = (X/α) @ (αW) 完全等价。但 X’ 没有 outlier(被 α 平滑了)、W’ 多了点动态范围。激活值更”均匀” → 量化精度变好;权重在量化前已经 calibration 过、能容忍微小动态范围扩大。

实战:先在 calibration 阶段计算每层的 α(让 activation 与 weight 量化误差平衡)、把 α 应用到 weight、再量化。整套预处理是”无损”的(α 系数可吸收进 W)。

PyTorch 这层不直接提供 SmoothQuant 实现,但 torchao + 第三方库提供(如 smoothquant pip 包)。实战 LLM 量化几乎都先做 SmoothQuant + 再做 INT8/INT4。这是 INT8 LLM 推理从”精度损失大”变成”几乎无损”的关键工程改造(2022 之后的 paper)。

理解 outlier 处理让你看到量化不是”扔掉 bits 就完了”,而是”扔掉 bits + 重新分配数值范围”的完整流水线。

20.5.13 Master Weights:训练精度的双重存储

混合精度训练的标准实践不只是”算 bf16”,而是 master weights in fp32

fp32 master weight (训练状态, optimizer 看到这个)
    ↓ cast 到 bf16
bf16 weight (forward 用的副本)
    ↓ forward + backward
bf16 grad
    ↓ accumulate 到 fp32 grad
fp32 grad (用 master weight 更新)
    ↓ optimizer.step()
fp32 master weight (新值)
    ↓ cast 到 bf16
bf16 weight (更新后的副本)

为什么需要 master weights?

  • bf16 / fp16 加法数值不稳定:weight = weight - lr * grad,lr 通常 1e-4,weight 1.0,grad 1e-3 → 更新量 1e-7。fp16 表示不出,更新被 round 掉
  • fp32 master 让”小步累积”可行 —— 即便单次更新 1e-7,几千步后能累积成 1e-4 的实质变化

代价:显存翻倍。70B 模型 bf16 weight 140 GB + fp32 master 280 GB = 420 GB。FSDP / ZeRO 通过 sharding master 来解决。

PyTorch 实现:

  • AdamW + AMP:默认在 optimizer 里维护 fp32 momentum / variance + 直接用 fp32 master weight
  • torch.optim.Adam 配合 autocast:optimizer 自动处理 cast
  • FSDP-2 MixedPrecision:精细控制 params / grads / buffers 各自 dtype

理解 master weights 让你看到”AMP 训练”实际是 mixed dtype 流水线:bf16 算、fp32 累、fp32 更新。每个环节用最合适的精度。

20.5.14 Stochastic Rounding:低精度训练的另一招

bf16 算 fp32 加法时如果走标准 round-to-nearest,小幅更新被吃掉(如 1.0 + 1e-7 → 1.0)。Stochastic Rounding:以概率 round 到 floor 或 ceil。

原值: 1.0 + 1e-7 = 1.0000001
fp16 round-to-nearest: 1.0 (差 1e-7 被吃掉)
fp16 stochastic rounding:
  - 99.99% 概率 round 到 1.0 (floor)
  - 0.01% 概率 round 到 1.0001 (ceil, 因为更接近 ceil 0.0001)
期望值 ≈ 1.0000001 (与原值相等)

虽然单次仍可能被 round,但统计期望保留信息。配合大量 batch 训练,最终 weight 收敛到正确位置。

PyTorch v2.4+ 通过 torch.cuda.amp.autocast(stochastic_rounding=True) 或 hardware 级支持(H100/B100 stochastic rounding 是 hardware feature)启用。NVIDIA TransformerEngine 默认开启。

实测:FP8 训练里开 stochastic rounding 让 loss 曲线更接近 BF16 baseline、不开会偏离 1-3%。训练越激进降精度(FP8 / FP4)、stochastic rounding 越关键

20.5.15 Dynamic vs Static Quantization

PyTorch 量化路径分两类:

Static Quantization (PTQ)

calibration 阶段记录 activation 范围 → 推理时用记录的 scale。但实际 activation 可能与 calibration 时不同(用户 input 分布不一定)→ 出现 outlier 时精度差。

Dynamic Quantization

不预先记录 scale,每次 inference 时实时算 input 的 min/max

import torch.quantization as q
model_quant = q.quantize_dynamic(
    model,
    {nn.Linear, nn.LSTM},      # 哪些 module 要量化
    dtype=torch.qint8,
)

每次 forward:

  1. input 来 → 计算 input 的 min/max
  2. 用这个范围算 scale
  3. quantize input
  4. int8 GEMM
  5. dequantize output

代价:每次 inference 多 2 个 reduction(min/max)、几个 cast op,比 static 慢 5-10%。

适用场景

  • Dynamic:input 分布不稳定(如 NLP 不同长度文本、推理服务多租户)
  • Static:input 分布稳定(如 CV 固定图像 size)

LLM 推理通常用 weight static + activation dynamic 混合:weight 已知不变、activation 因输入而异。这是 vLLM / SGLang 的实际做法。

PyTorch 这两条路径分别由 quantize_dynamic / PT2E + convert_pt2e 提供。理解差异让你选对量化策略。

20.5.16 Symmetric vs Asymmetric Quantization

量化的另一个维度:是否需要 zero_point。

Symmetricq = round(x / scale),0 量化为 0。范围 [-127*scale, 127*scale]

Asymmetricq = round(x / scale) + zero_point,可表示任意 [min, max]。范围 [zero_point*scale, (255+zero_point)*scale]

graph LR
    F1[fp32 范围: -3 to +5] --> Sym[Symmetric:<br/>对称到 ±5<br/>scale=5/127<br/>浪费 -5 to -3 区间]
    F1 --> Asym["Asymmetric:<br/>zero_point 让 [-3, 5] 完全填满<br/>精度高一倍"]

    style Sym fill:#fee2e2
    style Asym fill:#dcfce7

取舍

  • Symmetric:硬件友好(int8 mul 只算乘法、不需要 zero_point 减法)。NVIDIA Tensor Core 偏好 symmetric weights
  • Asymmetric:精度高,特别是 ReLU 后的 activation(值都非负,asymmetric 范围利用率高 2x)

实战:

  • weight:symmetric per-channel(硬件加速友好 + 精度足够)
  • activation:symmetric per-tensor(硬件友好)或 asymmetric per-tensor(精度更好)

PT2E 的 Quantizer 在 annotate 时指定每处的 symmetric 选择。X86InductorQuantizer 默认 weight symmetric + activation symmetric。理解这个细节让你 debug 量化精度问题时知道在哪里调。

20.5.17 量化精度调试流程

PTQ 完后发现精度掉 5%+,怎么定位是哪一层量化崩了?标准 flow:

flowchart TD
    Loss[精度严重下降]
    Loss --> Check1{是不是 outlier?}
    Check1 -->|histogram 看激活值分布| Y1[尝试 SmoothQuant + per-channel]
    Check1 -->|否| Check2

    Check2{是不是某层敏感?}
    Check2 -->|Per-layer A/B test:<br/>逐层关闭量化看精度恢复多少| Y2["找出'敏感层' 改 fp16 跑"]
    Check2 -->|否| Check3

    Check3{Calibration data 够吗?}
    Check3 -->|加 batch / 改用 in-domain data| Y3[重新 PTQ]
    Check3 -->|够了仍不行| Y4[切到 QAT 或更高精度 INT8]

    style Y2 fill:#dcfce7
    style Y4 fill:#fef3c7

具体动作:

  1. 看激活分布:对每层 dump activation histogram,看是否有 outlier。outlier 多 → SmoothQuant
  2. A/B test 量化层:逐层把量化关掉(这层用 fp16)测精度。能恢复多少决定这层是否”敏感”
  3. 改 calibration data:用更多 / 更接近真实 input 的 data。NLP 模型用真实推理 prompt、不要用训练数据 calibration
  4. 混合精度:敏感层(通常是 first / last layer + LN)保留 fp16,其他 INT8。这是 LLM INT8 推理的实战做法

torch.ao.quantization.fx.utils.compare_outputs 帮你逐层比较 fp32 vs INT8 输出。这是量化精度 debug 的标准工具。

理解这套 flow 让你在量化崩盘时不抓瞎。“精度损失多少能接受” 也是产品决策(如 search 类应用 1% drop 不可接受、聊天 bot 5% drop 可能 OK)。

20.5.18 group-wise quantization:4-bit 量化的精度密码

LLM 4-bit 量化(GPTQ / AWQ / NF4)几乎都是 group-wise:把 weight 切成 N 个连续元素一组,每组独立 scale。

# weight: shape [out_features, in_features]
# group-wise int4: 每 128 个 in_feature 一个 scale
group_size = 128
weight_int4 = ...   # shape [out, in], dtype=int4
scales = ...        # shape [out, in // group_size], dtype=fp16
zero_points = ...   # shape [out, in // group_size]

为什么不 per-tensor 或 per-channel?

  • per-tensor:精度差到不可用(scale 范围太广)
  • per-channel:精度好但 group_size = in_features(如 4096),仍有 outlier 分布问题
  • per-group (g128):精度接近 fp16,显存仅多 0.5%(scale tensor 比 weight 小 256x)

scale 存储开销的具体数字:

配置weightscales
fp162 byte/param02 byte/param
int4 per-tensor0.5 byte几 byte0.5 byte
int4 per-channel0.5 byte0.0005 byte0.5 byte
int4 g1280.5 byte0.005 byte0.51 byte

g128 几乎不增加存储但大幅提精度。这是 LLM 4-bit 量化的工程精髓。

torchao.quantization.quant_api.int4_weight_only(group_size=128) 是 PyTorch 内置 API,将 nn.Linear 的 weight 改成 int4 g128 + 推理时 dequantize-on-the-fly。INT4 GEMM kernel 内部用一个 lookup table(16 个值)替代 mul,比 fp16 快 1.5-2x。

理解 group-wise 让你看到现代 LLM 量化的实质:用极少存储开销换大幅精度。这是工程取舍调到极致的体现。

20.5.19 amp 接入自定义算子:custom_fwd / custom_bwd

写自定义 torch.autograd.Function 时如何让它 amp-friendly?用 custom_fwd / custom_bwd 装饰器:

from torch.cuda.amp import custom_fwd, custom_bwd

class MyOp(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, x, weight):
        # x, weight 已被 cast 到 fp16
        ctx.save_for_backward(x, weight)
        out = my_kernel(x, weight)
        return out

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        # backward 自动用 fp32 (无论 forward 用了什么)
        x, weight = ctx.saved_tensors
        # ... compute grads ...
        return grad_x, grad_weight

cast_inputs=torch.float16 让 forward 输入自动 cast。如果不指定,autocast mode 不会触及 forward 输入(保留用户传入的 dtype)。

custom_bwd 自动给 backward 关闭 autocast,避免反向计算时再次 cast 引入的累积误差。反向应该用 forward 时的 dtype(saved_tensors 是什么 dtype 就用什么)。

实战:写自定义 attention / 自定义 normalization 算子时必须加这两个装饰器,否则 autocast 行为不对(要么没省内存、要么输出 dtype 错乱)。

PyTorch 标准库的算子已经在 autocast_mode.cpp 里注册了 CastPolicy(§20.5.5),用户层 torch.autograd.Function 是另一条路,要靠装饰器主动声明。

20.5.20 量化的演进时间线

PyTorch 量化的关键节点:

版本改进意义
v1.3 (2019)eager mode 量化引入第一个量化路径,手动繁琐
v1.6 (2020)FX graph mode quantization自动化但 trace 限制
v1.10 (2021)quantize_dynamic 稳定推理服务可用
v1.12 (2022)torch.ao 命名 + observer 体系完善量化生态成型
v2.0 (2023)PT2E 实验性基于 export 的新一代
v2.2 (2024)PT2E 稳定FX graph mode 进入维护
v2.4 (2024)torchao 子库引入LLM 量化(int4 / fp8)开始
v2.6 (2025)FP8 训练完整支持H100/B100 时代关键 feature
v2.8 (2025)INT4 weight only + group-wise 内置取代第三方 bitsandbytes
v2.10 (2025)float8 训练 API 稳定大模型训练降精度
v2.11 (2026)量化生态完整(fp32 / bf16 / fp8 / int8 / int4)工程级别成熟

整体趋势:

  • v1.x:把 int8 PTQ / QAT 这套基础打牢
  • v2.0-v2.4:迁到 export-based 的现代量化(PT2E)
  • v2.4+:扩展到 LLM 时代(INT4 + FP8)、把第三方功能收编

理解时间线让你看到量化是个慢工出细活的领域。每个 minor version 都有改进,但每条路径走稳要 1-2 年时间。生产代码尽量跟主流(v2.x 用 PT2E + torchao)、避免用还在迭代的实验性 API。

20.5.21 精度 vs 速度 vs 显存:综合 trade-off 表

把全章话题合起来,给一张实战决策表(H100 + Llama-7B):

配置精度 (vs fp32)推理速度weight 显存KV cache 显存
fp32100%1.0x28 GB16 GB
bf16~99.9%2x14 GB8 GB
fp16 + GradScaler~99.9%2.2x14 GB8 GB
fp8 (E4M3)~99.5%4x7 GB4 GB
INT8 PTQ~99%3x7 GB4 GB
INT8 QAT~99.7%3x7 GB4 GB
INT8 + SmoothQuant~99.6%3x7 GB4 GB
INT4 GPTQ g128~98.5%4x3.5 GB8 GB (KV 仍 bf16)
INT4 AWQ g128~99%4x3.5 GB8 GB
INT4 NF4 + double quant (QLoRA)~99%3x1.75 GB8 GB
AWQ + FP8 KV~98.5%5x3.5 GB4 GB

实战决策

  • 训练:bf16 + autocast + master weights(业界标准);激进试 fp8(H100/B100 + TransformerEngine)
  • 服务推理(精度优先):bf16 / fp16
  • 服务推理(吞吐优先):INT8 PTQ + SmoothQuant
  • 边缘 / 受限显存:INT4 AWQ + FP8 KV
  • fine-tuning 小显存:QLoRA (NF4 + double quant)

理解每个配置的 trade-off 让你 5 分钟为新场景选对量化策略。这是 LLM 推理工程师的核心决策能力。

20.5.22 QLoRA:量化 + LoRA 的工程组合

QLoRA(2023 paper)让 65B 模型能在单张 24GB 4090 上 fine-tune。组合三个技术:

1. NF4 base model

把 base model weight 量化到 NF4(4-bit 正态分布量化)。65B model 130GB → 32GB。

2. LoRA adapter

冻结 base、只训练 LoRA 的 A / B 矩阵(fp16)。LoRA 参数量是 base 的 0.1-1%。

3. Double Quantization

NF4 量化的 scale 自身也量化(NF4 → NF8)。再省 0.4 byte/param × 65B = 26GB。最终 65B model + scales 仅 17.5GB。

graph TB
    Base[65B base model<br/>fp16: 130GB] --> Quant1[NF4 量化<br/>32GB + 8GB scales]
    Quant1 --> Quant2[Double Quantization<br/>32GB + 0.5GB scales = 17.5GB]

    Quant2 --> LoRA[+ LoRA adapter<br/>fp16: ~200MB]
    LoRA --> Train[fine-tune<br/>仅训练 LoRA]
    Train --> Save[保存 LoRA<br/>~50MB]

    style Quant1 fill:#dcfce7
    style Quant2 fill:#fef3c7
    style LoRA fill:#dbeafe

forward 流程:

  1. NF4 weight → on-the-fly dequantize 到 fp16
  2. fp16 weight + LoRA → 算 attention / MLP
  3. backward 只算 LoRA 梯度(base 冻结)
  4. optimizer 只更新 LoRA

代价:每次 forward 要 dequantize(增加 5-15% 开销)。但显存省得离谱 → 让单卡 fine-tune 大模型成为可能。

PyTorch 实现:bitsandbytes 库(最早)+ torchao (v2.4+ 内置)。HuggingFace peft 库把 QLoRA 封装成几行代码:

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# 加载 base model + 直接量化
model = AutoModelForCausalLM.from_pretrained(
    "Llama-65B",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    ),
)

# 加 LoRA
lora_config = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"])
model = get_peft_model(model, lora_config)

# 训练
trainer.train()

理解 QLoRA 让你看到现代大模型工程是多技术栈合成的产物:量化 + PEFT + ckpt 增量保存 + 推理时 dequantize。每个独立技术不重要、组合在一起改变可能性。这是 LLM 时代的工程哲学。

20.5.23 INT8 GEMM 的 Tensor Core 实现

为什么 INT8 GEMM 在 Tensor Core 上 2x 于 FP16?看 NVIDIA Tensor Core 设计:

Tensor Core 配置单 cycle 完成A100 ops/cycle/SMH100 ops/cycle/SM
FP32 GEMM8x4x43264
TF32 GEMM8x4x864128
FP16/BF16 GEMM16x8x16128256
INT8 GEMM16x8x32256512
FP8 GEMM (H100+)16x8x32-1024
INT4 GEMM (H100+)16x8x64-2048

每往下一档(fp16 → fp8 → int4),单 cycle 处理元素数翻倍 → 吞吐翻倍。这是硬件层面的”降精度收益”。

但实际生产收益往往不到理论值:

  • memory bandwidth bound:INT8 weight 比 FP16 小一半,理论应快 2x 但 memory 仍是 bottleneck → 实际只快 1.5x
  • dequantize 开销:INT4 weight 要 on-the-fly dequantize 到 FP16/INT8 算 → 抵消部分收益
  • kernel launch overhead:INT4 kernel 复杂、launch overhead 高 → 小 batch 时优势缩小

理解硬件限制让你在做量化决策时不会过度乐观 ——“INT4 比 FP16 快 4x”是上限,实际 2-3x 已经是优秀。配合 batching(让 memory bound 转 compute bound)才能逼近上限。

20.5.24 量化模型部署的端到端 flow

实战量化部署的 7 步流程:

flowchart TD
    Step1[1. 训练 fp32/bf16 模型]
    Step1 --> Step2[2. 选量化策略<br/>INT8 PTQ / INT4 AWQ / FP8]
    Step2 --> Step3[3. 准备 calibration data<br/>~512 个 in-domain samples]
    Step3 --> Step4[4. 跑 calibration]
    Step4 --> Step5[5. 验证精度<br/>vs fp32 测 metrics]
    Step5 --> Q1{精度 OK?}
    Q1 -->|否| Step6[6a. 调 SmoothQuant /<br/>per-channel / 改算法]
    Q1 -->|是| Step7[7. 集成到推理引擎<br/>vLLM / SGLang / TensorRT]
    Step6 --> Step5
    Step7 --> Done[部署上线]

    style Step5 fill:#fef3c7
    style Q1 fill:#dbeafe

每步的工程细节:

  • calibration data:用真实 prompt(不是训练数据)。NLP 任务 256-512 prompt 够;CV 任务 100-500 image
  • 精度验证:定 metrics(如 perplexity / accuracy / BLEU),与 fp32 baseline 对比。关键指标 drop > 1% 立刻调
  • 集成推理引擎:vLLM v0.5+ 接受 AWQ checkpoint;SGLang 支持 GPTQ;TensorRT-LLM 支持 INT8 / FP8
  • A/B test:上线前先 5% 流量灰度、看用户反馈是否有质量下降。即便 perplexity 几乎一致,用户偶尔能感觉到”答案变笨”

整套流程从训练完到上线大约 1-3 天(取决于精度调试轮次)。生产团队通常有专门的”量化工程师”做这条 pipeline。

20.5.25 PT2E + Inductor:量化 kernel 怎么 codegen

PT2E 把 model 转成”含 quantize/dequantize 节点的 fx graph”,但真正的 INT8 kernel 是 Inductor 生成的。看具体怎么协作:

# convert_pt2e 后 model 的 fx graph (简化):
def forward(x):
    x_q = quantize(x, scale=0.1, zp=128)        # fp32 → int8
    w_q = quantize(weight, scale=0.05, zp=0)    # static, 已量化好
    out_q = qlinear(x_q, w_q, bias)              # int8 GEMM, 输出 int32 累加
    out = dequantize(out_q, scale=0.005, zp=0)  # int32 → fp32
    return out

torch.compile(quantized_model) 让 Inductor 看到 quantize → qlinear → dequantize 这种模式。Inductor 的 fusion pattern matcher(§14.5)认得这种 “Q-Linear-DQ” 三元组、合并成一个 INT8 GEMM kernel:

# Inductor 生成的 Triton kernel (简化)
@triton.jit
def fused_qlinear(x_ptr, w_ptr, scale_x, scale_w, ...):
    x_int8 = tl.load(x_ptr, mask, dtype=tl.int8)
    w_int8 = tl.load(w_ptr, mask, dtype=tl.int8)
    out_int32 = tl.dot(x_int8, w_int8)            # INT8 Tensor Core
    out_fp32 = out_int32 * (scale_x * scale_w)    # 直接 dequantize
    tl.store(out_ptr, out_fp32)

整个 GEMM + dequantize 一个 kernel 完成,不写中间 int32 buffer 到 HBM。这是 PT2E + Inductor 的核心价值 —— 让量化模型享受 fusion 红利。

对比 v1.x FX graph mode:每个 quantize / qlinear / dequantize 都是独立 ATen 算子、各自一个 kernel、都过 HBM。PT2E + Inductor 让这套合一 → 比 v1.x 快 1.5-2x。

理解这层让你看到 PyTorch v2.x 量化的现代化不只是”换 IR”,是让量化与编译栈深度集成。这是 PyTorch 整体 v2 化战略的延伸。

20.5.26 AMP × FSDP:分布式训练的精度协作

第 18 章 §18.6 提过 FSDP 的 MixedPrecision 配置。本章看怎么与 autocast 协作。

两条路径

1. autocast + FSDP(不推荐,混乱)

fsdp_model = FSDP(model)
with autocast(dtype=torch.bfloat16):
    out = fsdp_model(x)

问题:FSDP 内部 unshard / reshard 操作的 dtype 由 FSDP MixedPrecision 决定,外层 autocast 只影响 forward op 的 cast。两套 dtype 配置容易冲突。

2. FSDP MixedPrecision(推荐)

mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,        # weight 用 bf16
    reduce_dtype=torch.float32,         # AllReduce 梯度用 fp32(防 outlier)
    buffer_dtype=torch.bfloat16,        # buffer (BN running stats) 用 bf16
)
fsdp_model = FSDP(model, mixed_precision=mp_policy)

# 不用 autocast
out = fsdp_model(x)                     # forward 自动 bf16

FSDP 内部把 forward 操作都 cast 到 param_dtype、grad reduce 用 reduce_dtype、buffer 操作用 buffer_dtype。比 autocast 控制粒度更细。

实战决策树:

  • 单卡训练:autocast 简单够用
  • DDP:autocast,配合 default reduce dtype
  • FSDP:用 FSDP.MixedPrecision,不要再叠 autocast
  • TP / PP:用对应框架(Megatron / PiPPy)的 mp 配置

这种”分布式训练框架接管 mp 配置”是 v2.x 时代的工程标准。autocast 本身仍重要(dispatcher mode 的实现机制是教材),但生产代码里被框架包了一层。理解两条路径让你不会”抄两份代码导致 dtype 错乱”。

20.5.27 一段总结:精度的多维度

把全章话题合起来,精度策略实际是多维度配置:

graph TB
    subgraph Train[训练阶段]
        T1[Forward dtype: bf16/fp16]
        T2[Backward dtype: bf16/fp16]
        T3[Master weight: fp32]
        T4[Optimizer state: fp32 or 8-bit]
        T5[Grad reduce dtype: fp32]
        T6[Buffer dtype: 同 forward]
    end

    subgraph Infer[推理阶段]
        I1[Weight: int8 / int4]
        I2[Activation: int8 / int8 dynamic]
        I3[KV cache: fp8 / int8]
        I4[Quantization scheme: per-tensor / per-group]
        I5[Calibration: PTQ / QAT]
    end

    style T1 fill:#dcfce7
    style I1 fill:#dbeafe
    style I3 fill:#fef3c7

每个维度独立可调。单看任何一个都简单,但组合空间巨大:bf16 train + int4 AWQ infer + fp8 KV、bf16 train + int8 QAT infer、fp8 train + int8 PTQ infer……

现实生产里通用配置:

  • 训练:bf16 forward + fp32 master + fp32 reduce
  • 推理:INT4 AWQ weight + bf16 activation + FP8 KV cache(vLLM 标准)

但 cutting edge 在不断推进:FP8 训练(H100+)、INT4 训练(实验)、FP4 KV cache(GH200)。理解这些维度让你能跟踪前沿、做出有依据的工程决策。

20.5.28 模型显存估算:训练 vs 推理

经常被问的问题:“这台机器能 train / infer 多大模型?” 用本章的精度知识算。

训练显存(Adam optimizer + 标准 mixed precision):

Per param 显存:
- bf16 weight: 2 bytes
- bf16 gradient: 2 bytes
- fp32 master weight: 4 bytes
- fp32 Adam exp_avg: 4 bytes
- fp32 Adam exp_avg_sq: 4 bytes
合计: 16 bytes/param

70B 模型: 70e9 × 16 = 1120 GB

加上 activation memory(约 weight 的 0.5-2x,取决于 batch size + seq length),70B 训练总占用 ~1500-2500 GB → 必须 FSDP/ZeRO-3 分到 16+ 张 H100(每张 80GB)。

FP8 训练(H100+):

bf16 weight + fp8 forward / backward → 整体降到 ~10 bytes/param,70B = 700 GB。让 8 卡 H100 能 fit。

推理显存(不带 KV cache)

Per param:
- fp16: 2 bytes
- int8: 1 byte
- int4 g128: 0.51 byte

70B 模型推理:
- fp16: 140 GB → 双卡 H100
- int8: 70 GB → 单卡 H100
- int4 AWQ: 35 GB → 单卡 A100/L40 即可

推理 + KV cache

Llama-70B, seq=4096, batch=1:
- KV cache fp16: 16 GB
- KV cache fp8: 8 GB

总:
- fp16 weight + fp16 KV: 156 GB
- int4 + fp8 KV: 35 + 8 = 43 GB

→ int4 + fp8 KV 让 70B 模型 + 4K 序列单卡 A100 (80GB) 跑得动

把这套估算公式记下,让你 5 秒回答”X 模型在 Y 配置上能不能跑、要 reshape 到什么精度”。具体公式:训练显存 ≈ 16 byte/param × 模型参数数;推理显存 ≈ (precision_byte) × 参数数 + KV cache。代入参数数、batch、seq_len 即可估算单卡 / 多卡是否够用。

20.5.29 量化精度问题排查 cheat sheet

实战排查 Cheat sheet:

症状可能原因解决
INT8 推理输出 NaN某层 scale 算成 0calibration data 质量差,加更多样本
INT8 推理输出全 0所有 activation 量化到 0scale 太大,改 per-channel 或 SmoothQuant
INT8 perplexity 飙升 50%某层 outlier 主导用 SmoothQuant 或保留该层 fp16
INT4 推理跑得比 INT8 慢dequantize 开销 + small batch加 batch 让 GEMM 主导、或换 INT8
FP8 训练 loss 发散scale 历史不稳定用 TransformerEngine 的 delayed scaling
AMP 训练慢于 fp32大量 cast 开销检查模型是否大量小算子(kernel launch 主导)
GradScaler scale 一直减半fp16 下 grad 频繁 inf改用 bf16 + 不用 GradScaler
QAT 收敛慢fake_quantize 引入太多噪声先训几 epoch fp32、再开 fake_quantize
PT2E export 失败某个 op 不支持 exporttorch._dynamo.allow_in_graph 或自定义 decomposition
量化后 KV cache 仍占大显存KV 没量化单独配置 KV cache 量化(vLLM 的 kv_cache_dtype

每条都是真实生产场景。把这张表存到 wiki,新人量化项目能省至少 1 周排查时间。

20.5.30 量化的”哲学”:精度是分配的

把全章看完最后留下的核心 insight:

精度不是单一标量,是个分配问题。整个网络总精度预算固定(如 INT8 总位数),怎么分配到不同位置最优?

  • GEMM-heavy 算子linear / conv / mm):bf16/fp16 安全
  • 数值敏感算子softmax / layer_norm / cross_entropy):必须 fp32
  • 大权重 layer(如 transformer 的 FFN):可激进量化(INT4 OK)
  • 小权重 + 关键 layer(first / last layer + LN):保留 fp16
  • 激活值(dynamic):per-tensor 或 per-token
  • 权重(static):per-channel 或 per-group

设计精度策略 = 把精度预算”分配到最需要的地方”。这种思想适用于任何”资源受限的优化问题”:内存预算、算力预算、注意力预算……

autocast 用白名单显式分配(哪些 op 哪个精度);PT2E 用 Quantizer + annotate 灵活分配(哪些位置 INT8 哪些 fp16);FP8 训练用 per-tensor scale 动态分配。每套机制都在解决同一个 meta 问题:给定数据流,把精度预算分配到对的位置

理解这条让你看 PyTorch 量化生态不是 “一堆零散的 API”,而是”精度分配工具的工具箱”。每个工具服务一类分配场景。这是看完全章应该带走的最终认知。

20.5.31 H100 + Llama-70B 训练的实战精度配置

把全章理论落到具体生产配置(基于公开 paper 与社区最佳实践,本章话题的最终汇集):

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FSDP, MixedPrecision, ShardingStrategy
import torchao.float8 as float8

# 1. 模型 + FSDP 配置
model = Llama70B(config)

mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,        # 千卡训练防 outlier 累积
    buffer_dtype=torch.bfloat16,
)

fsdp_model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    mixed_precision=mp_policy,
    device_id=torch.cuda.current_device(),
)

# 2. (可选) 转 FP8 forward
float8.convert_to_float8_training(
    fsdp_model,
    config=float8.Float8LinearConfig(
        scaling_strategy=float8.ScalingStrategy.DELAYED,
    ),
)

# 3. Optimizer 用 fp32 master + 8-bit Adam (省显存)
import bitsandbytes as bnb
optimizer = bnb.optim.AdamW8bit(
    fsdp_model.parameters(),
    lr=1e-4,
    weight_decay=0.01,
)

# 4. 训练循环(不需要 autocast,FSDP 接管)
for step, batch in enumerate(loader):
    optimizer.zero_grad()
    out = fsdp_model(batch)
    loss = compute_loss(out, target)
    loss.backward()
    optimizer.step()

这套配置的精度分配:

  • forward / backward:bf16(FSDP MixedPrecision)→ FP8(torchao.float8 转换)
  • gradient reduce:fp32(防 outlier 累积)
  • master weight:fp32(FSDP 内部维护)
  • optimizer state:8-bit Adam(bitsandbytes 量化的 Adam moments)

显存占用(每 rank,假设 32 卡):

  • bf16 weight (sharded):140/32 = 4.4 GB
  • fp32 master (sharded):280/32 = 8.8 GB
  • 8-bit Adam state (sharded):(560 → 140) /32 = 4.4 GB
  • bf16 grad buffer:4.4 GB
  • activation:~10 GB
  • 每 rank 总计 ~32 GB,刚好 fit H100 80GB(留余量给 batch size)

这是 70B 大模型训练的”精度优化压榨” —— 每个维度都用最合适的精度,让单机 8卡 + 4机 32 卡能 train 70B。理解每个数字背后的精度选择,才能在自家场景做调整。

20.5.32 vLLM 推理服务的精度配置示例

部署侧的实战配置(vLLM 0.6+,70B 模型 H100 单卡推理):

from vllm import LLM, SamplingParams

# 配置 1: 极致吞吐 (INT4 weight + FP8 KV)
llm = LLM(
    model="path/to/llama-70b-awq",       # 已量化好的 INT4 AWQ checkpoint
    quantization="awq",
    kv_cache_dtype="fp8_e5m2",            # KV cache 用 FP8
    max_model_len=8192,
    tensor_parallel_size=1,
)

# 配置 2: 平衡精度 (INT8 + FP8 KV)
llm = LLM(
    model="path/to/llama-70b",            # 原始 fp16 checkpoint
    quantization="smoothquant",           # 在线 SmoothQuant + INT8
    kv_cache_dtype="fp8_e5m2",
)

# 配置 3: 最高精度 (BF16 + BF16 KV)
llm = LLM(
    model="path/to/llama-70b",
    dtype="bfloat16",
    kv_cache_dtype="auto",                 # 与 weight 同 dtype = bf16
    tensor_parallel_size=2,                # 单卡装不下,TP=2
)

每种配置的实测:

配置weightKV单卡显存吞吐 (tokens/s)精度 (vs bf16)
AWQ INT4 + FP8 KV35 GB8 GB~50 GB450099%
SmoothQuant INT8 + FP8 KV70 GB8 GB~80 GB350099.5%
BF16 (TP=2)70/2 GB8/2 GB~50 GB2200100%

通常生产推荐 AWQ INT4 + FP8 KV —— 精度损失 1% 但吞吐 2x,性价比最高。除非应用对精度极度敏感(如代码生成、math reasoning),否则不该用 BF16。

理解这套对照让你看到本章理论与实际部署的最后一公里。精度策略最终要落在生产 metrics 上:吞吐多少、精度多少、成本多少。这是工程而非学术。

20.5.33 GradScaler 的”_growth_tracker”状态机细节

§20.2.2 提了 GradScaler 动态调 scale,具体内部状态机在 torch/amp/grad_scaler.py(v2.11 实测 714 行;torch/cuda/amp/grad_scaler.py 在 v2.5+ 是 deprecated 转发壳):

class GradScaler:
    def __init__(
        self,
        init_scale=2.**16,        # 起始 scale
        growth_factor=2.0,         # 没问题时翻倍
        backoff_factor=0.5,        # 出问题时减半
        growth_interval=2000,      # 连续 2000 步无 inf 才翻倍 scale
    ):
        self._scale = init_scale
        self._growth_tracker = 0   # 连续无 inf 计数器

scaler.update() 的逻辑:

def update(self, new_scale=None):
    if any_inf_per_device():           # 这步有 inf/nan
        self._scale *= self._backoff_factor    # scale 减半
        self._growth_tracker = 0                # 重置计数
    else:
        self._growth_tracker += 1
        if self._growth_tracker >= self._growth_interval:
            self._scale *= self._growth_factor  # 翻倍 scale
            self._growth_tracker = 0

为什么 growth_interval = 2000?因为:

  • 太小(如 100):scale 震荡(一会儿翻倍一会儿减半),训练初期不稳定
  • 太大(如 10000):scale 调整慢,错过最优值
  • 2000:经验值,让 scale 稳定追踪”当前最大可用 scale”

实战:监控 scaler.get_scale() 在训练前 1k 步快速从 65536 攀升到 ~1e8(调到 fp16 精度极限)→ 之后稳定波动几次。如果一直减半下去(最终 < 1)说明 fp16 训练不稳定,需切 bf16。

理解这个状态机让 fp16 训练的”为什么 scale 在变”不再神秘 —— 它在自适应找最大可用 scale。

20.6 几条工程经验

1. 训练用 bf16 + autocast,不用 fp16:bf16 更稳、不需要 GradScaler 维护。Llama 类大模型几乎都 bf16

2. 但旧硬件(V100 / T4)只支持 fp16:A100 / H100 都同时支持 fp16 + bf16。生产部署确认硬件能力

3. AMP 的”哪些 op 走低精度”是社区调出的白名单:自定义算子要 torch.cuda.amp.custom_fwd(cast_inputs=...) 显式声明

4. GradScaler.step() 内部检查 inf/nan:直接看 scaler.get_scale() 监控,scale 一直被减半说明 fp16 下数值不稳定,要么改 bf16 要么减小 lr

5. PT2E 是新代码首选:v2.6+ 之后 FX graph mode quantization 不再被推荐

6. activation_checkpoint + AMP 兼容:两者正交,可叠加

7. FSDP + AMP:用 FSDP 的 MixedPrecision 配置(第 18 章 §18.6),不是单独 autocast。FSDP 的 mp 控制更细(params/grads/buffers 各自独立 dtype)

8. 量化感知训练 (QAT):边训练边模拟量化,让模型适应量化噪声。prepare_qat_pt2e API。比直接 PTQ 精度高但训练时间多

20.7 跨书关联

  • 第 5 章 §5.7 dispatcher mode:autocast 的实现机制
  • 第 7 章 §7.5.3 saved_tensors_hooks:activation_checkpoint 与 AMP 配合的底座
  • 第 18 章 §18.6 FSDP MixedPrecision:分布式训练里 mp 的精细控制
  • 《vLLM 内核探秘》第 13 章 量化引擎:vLLM 的推理量化(AWQ / GPTQ / FP8)与本章互补 —— vLLM 关注推理时执行,本章关注训练 / 模型转换

20.8 设计启示

精度策略的核心思想:

第一不同操作有不同 dtype 适配性:GEMM 在低精度下数值稳定,softmax / norm / reduce 不行。混合精度比”全模型一种 dtype”灵活得多

第二训练与推理的精度要求完全不同:训练需要稳定的梯度信号,推理只需准确的输出。所以训练用 bf16/fp16,推理可以激进到 int8 / int4

第三autocast 用 dispatcher mode 实现:让”自动 cast”成为非侵入式特性,用户代码不用改

第四GradScaler 的状态机:动态 loss scaling 是”自适应数值稳定”的经典工具

第五精度是分配问题,不是单一标量:每个算子、每层、每个 tensor 都可以独立选 dtype。混合精度的本质是”在数据流的不同位置用不同精度”,把固定预算分配到最需要的地方

第六编译栈与精度栈共用基础设施:dispatcher mode 让 autocast 落地、export 让 PT2E 落地、Inductor 让两者的产物高效执行。这种”基础设施复用”让 PyTorch 的多个高级特性能持续向前演进而不互相阻塞

20.9 跨章呼应:精度与编译栈的协作

把第 12-15 章(编译栈)+ 本章(精度)合起来看:

  • autocast 是 dispatcher mode(§5.7、§20.2.1):编译时 Dynamo 能 trace 进 autocast 决策、把”哪些 op cast 到 bf16”固化到 fx graph
  • PT2E 用 torch.export(§12.8.28、§20.4):量化路径与 export 路径共用同一套 IR
  • Inductor 编译量化 graph(§20.5.25):Q-Linear-DQ 三元组被 fusion pattern 识别为 INT8 GEMM kernel
  • FP8 训练 + torch.compile:torchao.float8 的 fp8 module 也能被 Dynamo trace、由 Inductor 生成 fp8 kernel

这种”精度策略与编译栈深度耦合”是 v2.x 的工程红利。v1.x 时代量化与 torchscript 半割裂、autocast 与 trace 难协作。v2.x 让两条线在 dispatcher mode + export + Inductor 三层全部统一 —— 用户不再需要关心”我用了 amp 还能 compile 吗”,因为答案永远是”能”。

下一章拆 Profiler —— 当训练慢、显存爆、卡 idle,这些症状的诊断都靠 profiler。本章说”精度策略多种”,下一章给”诊断每种策略实际效果”的工具。两章配合让你在精度优化时有据可依:profile 看吞吐 / 显存 / 实际 dtype、再回到本章选合适配置、再 profile 验证。这种”测量-优化-再测量”的工程闭环是大模型时代的核心方法论。

评论 0