第14章 QAT 与 act_quant:量化感知训练全链路
“Train as you serve. If inference is FP8, train with FP8 in the loop.” —— V3 训练栈的核心理念
V4 的 QAT 不是后训练量化的”补救”,而是训练全程都在与量化误差共舞。
14.1 引子:PTQ vs QAT 的根本差别
把模型从 BF16 量化到 FP8 / FP4 通常有两种路径:
PTQ (Post-Training Quantization, 后训练量化):
- 模型用 FP32 / BF16 训完
- 训完之后离线”量化校准”——根据少量样本估算每层 weight 的最佳 scale
- 推理时按校准的 scale 量化运行
- 优点:训练流程不变
- 缺点:量化引入的精度损失没有”训练补偿”——某些极敏感层会塌陷
QAT (Quantization-Aware Training, 量化感知训练):
- 训练时在 forward 路径上插入”假量化”——把激活/权重量化再反量化,模拟推理时的精度损失
- 反向传播时让模型”感知”量化误差,自适应地学到对量化鲁棒的参数
- 优点:模型主动学会”对抗量化误差”,量化后精度损失极小
- 缺点:训练成本上升(多了 quant/dequant 操作)
V4 选 QAT 的理由:
- 1.6T 参数 + FP4 expert 的极端低精度——PTQ 几乎肯定塌陷
- DeepSeek 已经在 V3 上把 QAT 工业化——技术路线成熟
- QAT 的训练成本上升被 FP8 训练栈的整体加速抵消
flowchart LR
subgraph PTQ["PTQ 路径"]
PT1[BF16 训练] --> PT2[训练完成]
PT2 --> PT3[离线校准 scale]
PT3 --> PT4[FP8/FP4 推理]
end
subgraph QAT["QAT 路径 (V4)"]
QT1[BF16 / FP8 训练循环]
QT1 --> QT2[forward: act_quant 注入]
QT2 --> QT3[backward: STE 透传梯度]
QT3 --> QT1
QT1 --> QT4[训练完成]
QT4 --> QT5[直接 FP8/FP4 推理]
end
PT4 -.精度可能塌陷.-> Bad[质量下降]
QT5 --> Good[精度无明显损失]
14.2 act_quant 的 V4 源码
V4 在 inference/model.py 里没有 act_quant 的具体实现——它从 kernel 模块导入:
from kernel import act_quant, fp4_act_quant, fp8_gemm, fp4_gemm, sparse_attn, hc_split_sinkhorn
act_quant 在 DeepGEMM 仓库的 csrc/deep_gemm/act_quant.cu 里实现。从调用点看接口:
# 在 Attention.forward
act_quant(kv[..., :-rd], 64, scale_fmt, scale_dtype, True)
# 在 Indexer.forward
fp4_act_quant(q, fp4_block_size, True)
act_quant 的语义:
act_quant(x: Tensor, block_size: int, scale_fmt, scale_dtype, in_place: bool)
→ 把 x 量化到 FP8 e4m3,按 block_size 分块求 scale
→ 返回 (quantized_x, scale)
→ 如果 in_place=True,直接覆盖 x(可能是个 quant + dequant 的 fake quant 结果)
fp4_act_quant 同理但量化到 FP4 e2m1,block_size 通常是 32(FP4 的标准 block 大小)。
14.3 fake quantization:训练时的”假量化”
QAT 的核心是 fake quantization——训练时做 quant + dequant 一对操作,模拟推理时的精度损失,但保持梯度流动。
具体地:
def fake_quant(x, block_size, dtype):
"""量化 + 反量化,等于训练时模拟推理量化"""
# 1. 计算 per-block scale
blocks = x.view(-1, block_size)
amax = blocks.abs().max(dim=1, keepdim=True).values
scale = compute_scale(amax, dtype) # ue8m0 scale
# 2. 量化
quantized = round_to_grid(blocks / scale, dtype) # FP4 / FP8 grid
# 3. 反量化
dequantized = quantized * scale
return dequantized.view(x.shape)
forward 时这等同于”把 x 替换成 dequant(quant(x))“——数值上等于推理时的精度。
backward 时有个问题:round_to_grid 不可微(rounding 的导数处处为 0 或 inf)。QAT 通常用 Straight-Through Estimator (STE)——把 round 的梯度直接 pass through:
class FakeQuant(torch.autograd.Function):
@staticmethod
def forward(ctx, x, block_size, dtype):
return fake_quant(x, block_size, dtype)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None, None # 直接透传梯度
这种”前向量化、后向透传”的不一致让模型感知量化的数值损失,但学习目标仍然是原始 BF16 的损失——两者一起让训练收敛到”对量化鲁棒”的参数。
14.3·补 fake quantization 的 forward / backward 路径图
flowchart TB
subgraph Forward["forward 路径(量化误差注入)"]
F1["x: BF16"] --> F2["compute amax per block"]
F2 --> F3["scale = 2^ceil(log2(amax/fp_max))<br/>(ue8m0)"]
F3 --> F4["x_scaled = x / scale"]
F4 --> F5["quantized = round_to_grid(x_scaled, FP4/FP8)"]
F5 --> F6["dequantized = quantized × scale"]
F6 --> F7["x_fake_quanted: BF16"]
end
subgraph Backward["backward 路径(STE 透传)"]
B1["grad_output (BF16)"] --> B2["STE: 假装 round 可微"]
B2 --> B3["grad_input = grad_output<br/>(直接透传)"]
end
Forward -.前向数值已被 round.-> Backward
Backward -.后向梯度按未 round 算.-> ParamUpdate["weight 更新走<br/>'对齐量化网格的方向'"]
classDef quant fill:#7c2d12,stroke:#fb923c,color:#ffedd5
classDef ste fill:#312e81,stroke:#a78bfa,color:#ede9fe
class F2,F3,F4,F5 quant
class B2,B3 ste
这种”前向有损 + 后向无损” 的不对称是 STE 的关键——让模型既感知量化误差又能正常学习。
14.4 V4 中的 act_quant 调用点
把 V4 源码里所有 act_quant / fp4_act_quant 的调用点列出来:
| 位置 | 调用 | 量化目标 | block_size |
|---|---|---|---|
Attention.forward | act_quant(kv[..., :-rd], 64, ...) | KV 非 rope 部分 | 64 |
Indexer.forward | fp4_act_quant(q, fp4_block_size, True) | Indexer query | 32 |
Compressor.forward (rotate) | fp4_act_quant(kv, fp4_block_size, True) | Indexer 的 KV | 32 |
Compressor.forward (no rotate) | act_quant(kv[..., :-rd], 64, ...) | 主 attention 的压缩 KV | 64 |
linear 函数内部 | act_quant(x, block_size, ...) | linear 输入 | 128 |
读这张表能看到 V4 的精度策略:
- 关键路径(KV、Q、压缩 KV)走 FP8 + block_size=64:精度敏感
- score net 路径(Indexer / Indexer Compressor)走 FP4 + block_size=32:可承受精度损失
- linear 输入走 FP8 + block_size=128:与 weight block 对齐
每个调用点都对应一段精度损失。模型在 QAT 训练中学会”对这些精度损失鲁棒”——这就是 V4 在 FP4 expert 下仍能保持表达力的原因。
14.5 in-place 修饰符的语义
V4 的 act_quant 调用大多带 in_place=True:
act_quant(kv[..., :-rd], 64, scale_fmt, scale_dtype, True)
fp4_act_quant(q, fp4_block_size, True)
in_place=True 的语义:直接修改输入张量,不返回新张量。
这在工程上的好处:
- 节省显存——不需要为量化后的张量分配新空间
- 让”在某段代码里 x 是 BF16,下一行就是 fake quanted BF16” 的语义直白
但 in_place 操作对 autograd 友好需要小心——PyTorch 的 autograd 默认禁止 in_place 修改有梯度的张量。V4 的实现里,fp4_act_quant 等函数应该有特殊的 autograd 注册(用 Function.apply + STE)让 in_place 与梯度兼容。
14.6 QAT 调试的常见问题
实现自己的 QAT 训练时容易踩的坑:
坑一:量化粒度与 GEMM tile 不对齐
如果 act_quant 用 block_size=64 但 GEMM 用 tile=128——每个 tile 内有两组不同 scale 的元素,TensorCore 只能用一组 scale,会出错。V4 让所有 act_quant 的 block_size 与对应 GEMM 的 tile 对齐(KV 是 64,linear 是 128)——这是必须的。
坑二:scale 的数据类型不匹配
V4 的 ue8m0 scale 在 PyTorch 里通常用 torch.float8_e8m0fnu dtype(如果 PyTorch 版本支持)或 torch.float32 占位。如果你的实现里 scale 是 BF16,per-block scale 的精度不足,QAT 训练会有大量噪声。
坑三:STE 的反向梯度过大
直接透传梯度(STE)在某些极端情况下会让梯度爆炸——量化前后差异 10x 时,梯度仍按原值传递。V4 在 backward 里通常会做 gradient clipping 防止爆炸。
坑四:训练初期的不稳定
QAT 在训练开始就引入量化误差,初期 loss 曲线比 PTQ 训练曲线更陡。建议有一个 warm-up 阶段(前几 K 步只用 BF16,不开启 QAT),然后逐步打开 QAT。V4 的训练曲线(V3 论文中描述类似机制)也用了 warm-up。
14.7 与 PyTorch 的 quantization 工具对比
把 V4 的 QAT 与 PyTorch 官方 quantization 工具对比:
| 工具 | 支持精度 | per-block | grouped | DeepSeek V4 兼容 |
|---|---|---|---|---|
| torch.ao.quantization | INT8 | per-tensor / per-channel | ❌ | ❌ |
| torch.float8_e4m3fn | FP8 | per-tensor | ❌ | 部分(需要外加 scale) |
| torchao FP8 | FP8 | per-tensor / per-row | ❌ | 部分 |
| DeepGEMM act_quant | FP8 + FP4 | per-block ue8m0 | ✅ | ✅ |
V4 的 QAT 工具链与 PyTorch 主流量化工具有显著差异——V4 走的是”DeepGEMM 自带 + 在 model.py 里直接调用”的特化路线,不复用 torch.ao.quantization。
这种特化让 V4 的 QAT 性能极好,但也意味着:如果你想把 V4 的 QAT 思路移植到其他框架(如 transformers),需要自己实现一套 act_quant。
14.8 动手实验:fake_quant 模拟
import torch
def fake_quant_fp8(x, block_size=128):
"""模拟 V4 的 act_quant:FP8 e4m3 + ue8m0 per-block scale"""
shape = x.shape
flat = x.flatten()
blocks = flat.view(-1, block_size)
# ue8m0 scale: 2^ceil(log2(amax / fp8_max))
fp8_max = 448.0 # FP8 e4m3 的最大值
amax = blocks.abs().max(dim=1, keepdim=True).values
scale_exp = torch.ceil(torch.log2((amax / fp8_max).clamp(min=1e-30)))
scale = torch.pow(2.0, scale_exp)
# 量化:把 blocks / scale 截断到 FP8 网格
scaled = blocks / scale
# 简化:用 round 模拟
quantized = torch.round(scaled * 128) / 128 # 简化的 FP8 网格
quantized = quantized.clamp(-fp8_max, fp8_max)
# 反量化
dequantized = quantized * scale
return dequantized.view(shape)
# 对比 BF16 与 fake_quant 的精度损失
torch.manual_seed(0)
x = torch.randn(2048, 2048) * 5
x_fq = fake_quant_fp8(x, block_size=128)
err = (x - x_fq).abs().mean().item()
rel_err = err / x.abs().mean().item()
print(f"Mean abs error: {err:.5f}")
print(f"Relative error: {rel_err*100:.3f}%")
# 对比矩阵乘法的精度
y = torch.randn(2048, 1024) * 5
y_fq = fake_quant_fp8(y, block_size=128)
z_bf16 = x @ y
z_qat = x_fq @ y_fq
mm_err = (z_bf16 - z_qat).abs().mean().item()
mm_rel_err = mm_err / z_bf16.abs().mean().item()
print(f"\nGEMM error after FP8 fake quant:")
print(f"Mean abs error: {mm_err:.5f}")
print(f"Relative error: {mm_rel_err*100:.3f}%")
跑这段代码:典型 FP8 fake quant 的相对误差约 0.5-1.0%,传到 GEMM 后误差约 1-2%。这是 V4 训练时模型实际感知的”量化噪声”。
14.8·补 QAT 与精度退化的攻防
V4 的 QAT 不只是”训练时模拟量化”那么简单——它在训练全程参与一场与精度退化的”攻防战”。把这场攻防的几个关键回合摆出来:
攻防回合一:FP4 expert 的精度地板
FP4 e2m1 只有 16 个离散值——任何 expert weight 在 FP4 下的”最低精度地板” 已经被硬件锁死。QAT 能做的是让模型主动避免落入 FP4 的”精度盲区”——比如让 weight 不要出现需要”在 0.5 与 1.0 之间”精确表达的中间值,而是把 weight 学得”明显偏向某一侧”。
具体做法是:训练时每次 forward 都把 weight 反量化(dequant),用 dequant 后的近似 weight 算 forward。模型感受到”我的 weight 实际上只有这么几个离散值”,就学会把决策边界放到这些离散点之间,避开盲区。
攻防回合二:activation 的动态范围控制
激活值在不同 layer 的数值范围差异巨大——某些 layer 的 activation 可能在 [-1, 1],某些可能在 [-100, 100]。FP8 e4m3 的最大值约 448,但配合 ue8m0 scale 可以覆盖任意范围——前提是 scale 算得正确。
QAT 的工作之一是让 activation 的范围保持稳定——不要让某些 layer 的 activation 突然跨数量级。具体表现是 SwiGLU 的 swiglu_limit=10、HC 的 Sinkhorn 归一化、attn_sink 的可学约束——这些机制一起把 activation 的范围控制在 FP8 友好的区间。
攻防回合三:scale 的”温度调节”
ue8m0 scale 是 2 的整数幂——离散度比”任意 FP32 scale” 大。QAT 训练中,每个 block 的 scale 选择是离散决策——scale = 2^ceil(log2(amax / fp8_max))。这种离散决策会让训练曲线出现”阶梯感”——某个 batch scale 跳一级,loss 突然抖动。
V4 的应对是block_size 的设计:128×128 的 block 内 amax 已经被多个 element 的统计稳定化——不会因为单个 outlier 让整 block 的 scale 跳级。这是为什么 block_size=128 而非 32 的工程理由之一。
攻防回合四:训练后期的”放大效应”
QAT 训练的后期(cooldown 阶段),lr 衰减到很小——这时候每次更新的 weight 变化只有量化网格的 1/100 量级。这种小变化在反量化后可能完全消失——梯度更新”被量化网格吸收”。
V4 的应对是 Muon 的 spectral norm 约束——它让 weight 的更新方向更”显著”,即便幅度小也能逃出量化网格的吸收。这与传统 AdamW + 量化的”小 lr 后期失效” 形成对比。
这四个攻防回合让 V4 的 QAT 在 1.6T 规模下保持稳定——任何一个环节松懈,模型就会塌陷到”低精度训出来的低质量”。
14.8·补·补 QAT 与”训练-推理一致性”的工程承诺
V4 走 QAT 还有一个更深远的工程动机:训练-推理的 bit-level 一致性。
什么是 bit-level 一致性?意思是:模型在训练时见到的精度、推理时跑出来的精度,精确到 bit 都相等。这听起来是个理所当然的需求,但在传统 PTQ 模型里是不成立的——PTQ 模型训练时是 BF16,推理时是 FP8,两者的数值误差从训练完成那一刻起就开始”偏移”。
V4 的 QAT 工程承诺:训练时 forward 路径上每个 op 的精度 = 推理时该 op 的精度。这意味着:
- 你在 V4 训练 loss 上看到的数字,与你在 vLLM 部署 V4 后跑同样输入得到的数字,差异完全可控
- 你做 fine-tune 时如果改了某层 weight,可以精确预测推理时的输出变化
- A/B test 不同训练配置时,差异不会被”训练-推理精度漂移”污染
这种 bit-level 一致性是工程化大模型训练栈的”金标准”——大多数公司做不到,因为它需要训练框架与推理框架共享底层 op。V4 通过把 DeepGEMM 同时放在训练和推理路径,达到了这个标准。
如果你的项目想从 V4 借鉴 QAT 思路,最关键的不是 act_quant 函数本身——而是确保训练和推理的 op 实现完全一致。这是 V4 最难复制的工程能力之一。
14.8·延展 QAT 与 LLM 评估的”精度漂移”检测
V4 的 QAT 让训练-推理精度对齐,但生产中仍需持续监控”是否真的对齐”。这部分监控属于评估工程的范畴——与《LLM 评估工程》卷有直接关联。
精度漂移的来源:
- 软件层:vLLM / DeepGEMM 升级时算子实现可能有微小差异
- 硬件层:不同 GPU 批次的 FP4 / FP8 行为可能略有差异
- 数据层:用户输入分布与训练分布不同,激活值范围偏移
漂移检测方法:
- golden test:选 100-1000 个 fixed prompt + reference output,定期跑一遍对比 token-by-token 差异
- distribution test:监控生产输出的统计分布(token 频率、长度、log-prob),与训练时 reference 对比
- shadow test:把一小部分流量同时跑 V4 production 和 V4 reference(FP32)实例,对比差异
当检测到漂移时的应对:
- 小漂移(<0.1% token-level diff):通常无害,不必处理
- 中漂移(0.1%-1%):可能是软件版本不兼容,回滚或修复
- 大漂移(>1%):必须排查根本原因——可能是部署配置错(如精度模式被误设)
这套精度漂移检测工程在 V3 时代已经成熟,V4 沿用。具体的 evaluation pipeline 实现细节(指标定义、自动化跑批、告警阈值)会在 evals 卷的”训练-推理一致性测试”章节展开。
14.9 延伸阅读
- 量化感知训练综述(arXiv:2103.13630):QAT 经典理论
- Straight-Through Estimator(arXiv:1308.3432):STE 的源头
- DeepSeek-V3 报告中关于 FP8 训练的章节
- 本书第 12 章:FP4 / FP8 / ue8m0 格式细节
- 本书第 13 章:DeepGEMM 内部如何处理 fake quant 后的 GEMM
- 本书第 17 章:Muon 优化器如何与 QAT 配合
14.9·补 QAT 训练栈的 5 个工程教训
V4 的 QAT 把”训练时模拟推理量化” 推到了 1.6T 规模。从源码可以反推出 V4 团队踩过的几个工程教训——任何想复刻 QAT 的项目都值得借鉴:
教训 1:QAT 必须从训练第一步就启用
某些团队尝试”先 BF16 训练,后期切 QAT”——结果是 weight 已经走到 BF16 的某个局部最优,强行切到 FP8/FP4 后陷入塌陷。V4 选择从训练第一步就 QAT——让 weight 一开始就在量化网格上演化,避免局部最优。
教训 2:scale 计算必须可微
ue8m0 scale 是离散的(2 的整数幂)——直接 ceil 计算是不可微的。V4 大概率用了”ceil 的 STE”——前向用 ceil、反向假装可微。如果你的实现里 scale 计算”完全 round 死”,反向梯度会被截断——某些参数训不动。
教训 3:block 边界的 scale 选择很关键
每 128×128 块独立 scale。如果两个相邻块的 scale 差异极大(如 2^4 vs 2^-4),block 边界的 weight 会因为 scale 跳变而出现”discontinuity” ——影响 GEMM 累加精度。V4 通过”训练时让相邻块的 amax 接近”来缓解——具体机制不明,但实测分布显示相邻块 scale 差异通常 ≤2^2。
教训 4:fake quant 的 in-place 操作必须配合 autograd hook
V4 的 fp4_act_quant 是 in-place 操作。PyTorch 的 autograd 默认禁止有梯度张量的 in-place 修改——除非你用自定义 Function + STE。这是 V4 的 kernel 实现里专门处理的——如果你照搬 V4 源码而没配套 autograd hook,反向会报 “modified in-place” 错误。
教训 5:精度漂移监控比”训练 loss 监控”更重要
QAT 的 loss 可能看起来正常下降,但精度在悄悄退化(前向输出与 BF16 reference 偏差增大)。V4 的训练大概率有一套独立的”精度对比” 监控——每隔 N step 跑同一组输入在 BF16 reference 上对比 QAT 输出的 KL 散度。这个指标反映”QAT 是否真的在学量化鲁棒性”——比 loss 更直接。
这 5 个教训是 V4 / V3 时代 QAT 工业化的”血泪经验”——任何后续做 1T+ 规模 QAT 的团队都要重新踩这些坑(除非他们直接借鉴 V4 的经验)。
14.9·补·补 QAT 工程师速记
关键概念:
- QAT = 训练时模拟推理的量化误差
- act_quant:把激活动态量化到 FP8 / FP4
- STE:Straight-Through Estimator,反向梯度透传
- in-place fake quant:节省显存的量化-反量化对
V4 中 act_quant 调用点:
- KV cache 写入:FP8 + block 64
- linear 输入:FP8 + block 128
- Indexer query:FP4 + block 32
- Compressor (rotate=True):FP4 + block 32
调试 QAT 的 5 个症状:
- loss 早期 NaN —— scale 计算溢出
- 训练曲线”阶梯感” —— scale 跳级,正常
- 推理输出与训练 reference 偏差大 —— QAT 没真正生效,回查 act_quant 是否启用
- fine-tune 后量化误差变大 —— fine-tune 时关掉了 QAT
- 某些层精度异常 —— 该层的 block_size 与 GEMM tile 不对齐
与 PyTorch 官方 quantization 的差异:
- V4 用 DeepGEMM 自带 act_quant,不用 torch.ao.quantization
- per-block ue8m0 scale,不是 per-tensor
- in-place 操作(autograd 友好需要自定义 Function)
生产监控指标:
- 推理输出与 BF16 reference 的 KL 散度(应稳定 < 0.01)
- 每层 act_quant 的 amax 分布(应稳定,无突跳)
- 部署后的精度漂移(每周跑 golden test)
14.9·延展 QAT 与”反向传播误差累积” 的边界
V4 在 1.6T + 32T tokens 的训练规模下,QAT 的反向传播误差累积是个真实问题。把这个边界讲清楚。
误差累积来源 1:STE 的不精确
Straight-Through Estimator 把 round 的反向梯度透传——前向数值已经被 round,反向梯度按”未 round” 算。这个不一致让某些极端方向的梯度估计有偏差。
在小模型上偏差可以忽略——大模型 + 长训练后偏差累积成可见的精度损失。
误差累积来源 2:FP4 / FP8 GEMM 的累加误差
虽然 GEMM 内部用 FP32 累加,但累加几千次 token 的输出后,FP32 自身的舍入误差也会累积。
误差累积来源 3:跨 layer 误差放大
第 1 层的小误差经过 60 层会被放大——理论上约 1.0001^60 ≈ 1.006,即 0.6% 放大。这是 dense 模型的情况;sparse + HC + 多种精度的复杂结构里,误差放大可能是 2-5%。
V4 的 4 个对抗机制:
- HC 的 Sinkhorn 双随机性:让 4 路 hidden 的能量守恒——抑制误差累积
- per-layer 精度复位:每层 RMSNorm 把 hidden 重新归一——切断误差链
- attn_sink 的兜底:稀疏 attention 选错时不会让数值崩溃
- swiglu_limit 的 clip:防止 SwiGLU 输出爆炸
这 4 个机制让 V4 的误差累积在可接受范围(最终模型推理 ≈ BF16 reference 的 99%+)。
理解这个边界的意义:
如果你想从 V4 借鉴 QAT 设计到自己的模型,必须同时引入这 4 个对抗机制——否则 QAT 在大模型上会失败。光有 act_quant 不够,必须配套架构 + 数值兜底。
14.10 本章小结
- V4 用 QAT 而非 PTQ——训练时就让模型”看到”量化误差
- act_quant / fp4_act_quant 在 forward 路径上做”假量化”(量化 + 反量化),backward 用 STE 透传梯度
- V4 源码里 5 个 act_quant 调用点:KV、Q、Compressor 双路径、linear 输入——每个点的 block_size 与对应 GEMM tile 对齐
- in_place=True 节省显存——配合自定义 autograd 函数兼容梯度
- QAT 调试 4 个常见坑:粒度不对齐、scale 类型错、STE 梯度爆炸、初期不稳定
- V4 的 QAT 工具链与 PyTorch 官方 quantization 工具差异显著——是 DeepGEMM 自带的特化方案
第 15 章我们离开精度链路,进入 V4 的分布式工程:Tensor / Expert 并行的实现细节。
评论 0
还没有评论,来说两句吧。
评论加载失败,刷新重试。