第14章 QAT 与 act_quant:量化感知训练全链路

“Train as you serve. If inference is FP8, train with FP8 in the loop.” —— V3 训练栈的核心理念

V4 的 QAT 不是后训练量化的”补救”,而是训练全程都在与量化误差共舞。


14.1 引子:PTQ vs QAT 的根本差别

把模型从 BF16 量化到 FP8 / FP4 通常有两种路径:

PTQ (Post-Training Quantization, 后训练量化)

  • 模型用 FP32 / BF16 训完
  • 训完之后离线”量化校准”——根据少量样本估算每层 weight 的最佳 scale
  • 推理时按校准的 scale 量化运行
  • 优点:训练流程不变
  • 缺点:量化引入的精度损失没有”训练补偿”——某些极敏感层会塌陷

QAT (Quantization-Aware Training, 量化感知训练)

  • 训练时在 forward 路径上插入”假量化”——把激活/权重量化再反量化,模拟推理时的精度损失
  • 反向传播时让模型”感知”量化误差,自适应地学到对量化鲁棒的参数
  • 优点:模型主动学会”对抗量化误差”,量化后精度损失极小
  • 缺点:训练成本上升(多了 quant/dequant 操作)

V4 选 QAT 的理由:

  • 1.6T 参数 + FP4 expert 的极端低精度——PTQ 几乎肯定塌陷
  • DeepSeek 已经在 V3 上把 QAT 工业化——技术路线成熟
  • QAT 的训练成本上升被 FP8 训练栈的整体加速抵消
flowchart LR
  subgraph PTQ["PTQ 路径"]
    PT1[BF16 训练] --> PT2[训练完成]
    PT2 --> PT3[离线校准 scale]
    PT3 --> PT4[FP8/FP4 推理]
  end
  subgraph QAT["QAT 路径 (V4)"]
    QT1[BF16 / FP8 训练循环]
    QT1 --> QT2[forward: act_quant 注入]
    QT2 --> QT3[backward: STE 透传梯度]
    QT3 --> QT1
    QT1 --> QT4[训练完成]
    QT4 --> QT5[直接 FP8/FP4 推理]
  end
  PT4 -.精度可能塌陷.-> Bad[质量下降]
  QT5 --> Good[精度无明显损失]

14.2 act_quant 的 V4 源码

V4 在 inference/model.py 里没有 act_quant 的具体实现——它从 kernel 模块导入:

from kernel import act_quant, fp4_act_quant, fp8_gemm, fp4_gemm, sparse_attn, hc_split_sinkhorn

act_quant 在 DeepGEMM 仓库的 csrc/deep_gemm/act_quant.cu 里实现。从调用点看接口:

# 在 Attention.forward
act_quant(kv[..., :-rd], 64, scale_fmt, scale_dtype, True)

# 在 Indexer.forward
fp4_act_quant(q, fp4_block_size, True)

act_quant 的语义:

act_quant(x: Tensor, block_size: int, scale_fmt, scale_dtype, in_place: bool)
  → 把 x 量化到 FP8 e4m3,按 block_size 分块求 scale
  → 返回 (quantized_x, scale)
  → 如果 in_place=True,直接覆盖 x(可能是个 quant + dequant 的 fake quant 结果)

fp4_act_quant 同理但量化到 FP4 e2m1,block_size 通常是 32(FP4 的标准 block 大小)。


14.3 fake quantization:训练时的”假量化”

QAT 的核心是 fake quantization——训练时做 quant + dequant 一对操作,模拟推理时的精度损失,但保持梯度流动。

具体地:

def fake_quant(x, block_size, dtype):
    """量化 + 反量化,等于训练时模拟推理量化"""
    # 1. 计算 per-block scale
    blocks = x.view(-1, block_size)
    amax = blocks.abs().max(dim=1, keepdim=True).values
    scale = compute_scale(amax, dtype)  # ue8m0 scale

    # 2. 量化
    quantized = round_to_grid(blocks / scale, dtype)  # FP4 / FP8 grid

    # 3. 反量化
    dequantized = quantized * scale

    return dequantized.view(x.shape)

forward 时这等同于”把 x 替换成 dequant(quant(x))“——数值上等于推理时的精度。

backward 时有个问题:round_to_grid 不可微(rounding 的导数处处为 0 或 inf)。QAT 通常用 Straight-Through Estimator (STE)——把 round 的梯度直接 pass through:

class FakeQuant(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, block_size, dtype):
        return fake_quant(x, block_size, dtype)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None, None  # 直接透传梯度

这种”前向量化、后向透传”的不一致让模型感知量化的数值损失,但学习目标仍然是原始 BF16 的损失——两者一起让训练收敛到”对量化鲁棒”的参数。


14.3·补 fake quantization 的 forward / backward 路径图

flowchart TB
  subgraph Forward["forward 路径(量化误差注入)"]
    F1["x: BF16"] --> F2["compute amax per block"]
    F2 --> F3["scale = 2^ceil(log2(amax/fp_max))<br/>(ue8m0)"]
    F3 --> F4["x_scaled = x / scale"]
    F4 --> F5["quantized = round_to_grid(x_scaled, FP4/FP8)"]
    F5 --> F6["dequantized = quantized × scale"]
    F6 --> F7["x_fake_quanted: BF16"]
  end
  subgraph Backward["backward 路径(STE 透传)"]
    B1["grad_output (BF16)"] --> B2["STE: 假装 round 可微"]
    B2 --> B3["grad_input = grad_output<br/>(直接透传)"]
  end
  
  Forward -.前向数值已被 round.-> Backward
  Backward -.后向梯度按未 round 算.-> ParamUpdate["weight 更新走<br/>'对齐量化网格的方向'"]
  
  classDef quant fill:#7c2d12,stroke:#fb923c,color:#ffedd5
  classDef ste fill:#312e81,stroke:#a78bfa,color:#ede9fe
  class F2,F3,F4,F5 quant
  class B2,B3 ste

这种”前向有损 + 后向无损” 的不对称是 STE 的关键——让模型既感知量化误差又能正常学习。


14.4 V4 中的 act_quant 调用点

把 V4 源码里所有 act_quant / fp4_act_quant 的调用点列出来:

位置调用量化目标block_size
Attention.forwardact_quant(kv[..., :-rd], 64, ...)KV 非 rope 部分64
Indexer.forwardfp4_act_quant(q, fp4_block_size, True)Indexer query32
Compressor.forward (rotate)fp4_act_quant(kv, fp4_block_size, True)Indexer 的 KV32
Compressor.forward (no rotate)act_quant(kv[..., :-rd], 64, ...)主 attention 的压缩 KV64
linear 函数内部act_quant(x, block_size, ...)linear 输入128

读这张表能看到 V4 的精度策略:

  • 关键路径(KV、Q、压缩 KV)走 FP8 + block_size=64:精度敏感
  • score net 路径(Indexer / Indexer Compressor)走 FP4 + block_size=32:可承受精度损失
  • linear 输入走 FP8 + block_size=128:与 weight block 对齐

每个调用点都对应一段精度损失。模型在 QAT 训练中学会”对这些精度损失鲁棒”——这就是 V4 在 FP4 expert 下仍能保持表达力的原因。


14.5 in-place 修饰符的语义

V4 的 act_quant 调用大多带 in_place=True

act_quant(kv[..., :-rd], 64, scale_fmt, scale_dtype, True)
fp4_act_quant(q, fp4_block_size, True)

in_place=True 的语义:直接修改输入张量,不返回新张量。

这在工程上的好处:

  • 节省显存——不需要为量化后的张量分配新空间
  • 让”在某段代码里 x 是 BF16,下一行就是 fake quanted BF16” 的语义直白

但 in_place 操作对 autograd 友好需要小心——PyTorch 的 autograd 默认禁止 in_place 修改有梯度的张量。V4 的实现里,fp4_act_quant 等函数应该有特殊的 autograd 注册(用 Function.apply + STE)让 in_place 与梯度兼容。


14.6 QAT 调试的常见问题

实现自己的 QAT 训练时容易踩的坑:

坑一:量化粒度与 GEMM tile 不对齐

如果 act_quant 用 block_size=64 但 GEMM 用 tile=128——每个 tile 内有两组不同 scale 的元素,TensorCore 只能用一组 scale,会出错。V4 让所有 act_quant 的 block_size 与对应 GEMM 的 tile 对齐(KV 是 64,linear 是 128)——这是必须的。

坑二:scale 的数据类型不匹配

V4 的 ue8m0 scale 在 PyTorch 里通常用 torch.float8_e8m0fnu dtype(如果 PyTorch 版本支持)或 torch.float32 占位。如果你的实现里 scale 是 BF16,per-block scale 的精度不足,QAT 训练会有大量噪声。

坑三:STE 的反向梯度过大

直接透传梯度(STE)在某些极端情况下会让梯度爆炸——量化前后差异 10x 时,梯度仍按原值传递。V4 在 backward 里通常会做 gradient clipping 防止爆炸。

坑四:训练初期的不稳定

QAT 在训练开始就引入量化误差,初期 loss 曲线比 PTQ 训练曲线更陡。建议有一个 warm-up 阶段(前几 K 步只用 BF16,不开启 QAT),然后逐步打开 QAT。V4 的训练曲线(V3 论文中描述类似机制)也用了 warm-up。


14.7 与 PyTorch 的 quantization 工具对比

把 V4 的 QAT 与 PyTorch 官方 quantization 工具对比:

工具支持精度per-blockgroupedDeepSeek V4 兼容
torch.ao.quantizationINT8per-tensor / per-channel
torch.float8_e4m3fnFP8per-tensor部分(需要外加 scale)
torchao FP8FP8per-tensor / per-row部分
DeepGEMM act_quantFP8 + FP4per-block ue8m0

V4 的 QAT 工具链与 PyTorch 主流量化工具有显著差异——V4 走的是”DeepGEMM 自带 + 在 model.py 里直接调用”的特化路线,不复用 torch.ao.quantization。

这种特化让 V4 的 QAT 性能极好,但也意味着:如果你想把 V4 的 QAT 思路移植到其他框架(如 transformers),需要自己实现一套 act_quant


14.8 动手实验:fake_quant 模拟

import torch

def fake_quant_fp8(x, block_size=128):
    """模拟 V4 的 act_quant:FP8 e4m3 + ue8m0 per-block scale"""
    shape = x.shape
    flat = x.flatten()
    blocks = flat.view(-1, block_size)
    
    # ue8m0 scale: 2^ceil(log2(amax / fp8_max))
    fp8_max = 448.0  # FP8 e4m3 的最大值
    amax = blocks.abs().max(dim=1, keepdim=True).values
    scale_exp = torch.ceil(torch.log2((amax / fp8_max).clamp(min=1e-30)))
    scale = torch.pow(2.0, scale_exp)

    # 量化:把 blocks / scale 截断到 FP8 网格
    scaled = blocks / scale
    # 简化:用 round 模拟
    quantized = torch.round(scaled * 128) / 128  # 简化的 FP8 网格
    quantized = quantized.clamp(-fp8_max, fp8_max)

    # 反量化
    dequantized = quantized * scale
    return dequantized.view(shape)


# 对比 BF16 与 fake_quant 的精度损失
torch.manual_seed(0)
x = torch.randn(2048, 2048) * 5
x_fq = fake_quant_fp8(x, block_size=128)
err = (x - x_fq).abs().mean().item()
rel_err = err / x.abs().mean().item()
print(f"Mean abs error: {err:.5f}")
print(f"Relative error: {rel_err*100:.3f}%")

# 对比矩阵乘法的精度
y = torch.randn(2048, 1024) * 5
y_fq = fake_quant_fp8(y, block_size=128)

z_bf16 = x @ y
z_qat = x_fq @ y_fq

mm_err = (z_bf16 - z_qat).abs().mean().item()
mm_rel_err = mm_err / z_bf16.abs().mean().item()
print(f"\nGEMM error after FP8 fake quant:")
print(f"Mean abs error: {mm_err:.5f}")
print(f"Relative error: {mm_rel_err*100:.3f}%")

跑这段代码:典型 FP8 fake quant 的相对误差约 0.5-1.0%,传到 GEMM 后误差约 1-2%。这是 V4 训练时模型实际感知的”量化噪声”。


14.8·补 QAT 与精度退化的攻防

V4 的 QAT 不只是”训练时模拟量化”那么简单——它在训练全程参与一场与精度退化的”攻防战”。把这场攻防的几个关键回合摆出来:

攻防回合一:FP4 expert 的精度地板

FP4 e2m1 只有 16 个离散值——任何 expert weight 在 FP4 下的”最低精度地板” 已经被硬件锁死。QAT 能做的是让模型主动避免落入 FP4 的”精度盲区”——比如让 weight 不要出现需要”在 0.5 与 1.0 之间”精确表达的中间值,而是把 weight 学得”明显偏向某一侧”。

具体做法是:训练时每次 forward 都把 weight 反量化(dequant),用 dequant 后的近似 weight 算 forward。模型感受到”我的 weight 实际上只有这么几个离散值”,就学会把决策边界放到这些离散点之间,避开盲区。

攻防回合二:activation 的动态范围控制

激活值在不同 layer 的数值范围差异巨大——某些 layer 的 activation 可能在 [-1, 1],某些可能在 [-100, 100]。FP8 e4m3 的最大值约 448,但配合 ue8m0 scale 可以覆盖任意范围——前提是 scale 算得正确。

QAT 的工作之一是让 activation 的范围保持稳定——不要让某些 layer 的 activation 突然跨数量级。具体表现是 SwiGLU 的 swiglu_limit=10、HC 的 Sinkhorn 归一化、attn_sink 的可学约束——这些机制一起把 activation 的范围控制在 FP8 友好的区间。

攻防回合三:scale 的”温度调节”

ue8m0 scale 是 2 的整数幂——离散度比”任意 FP32 scale” 大。QAT 训练中,每个 block 的 scale 选择是离散决策——scale = 2^ceil(log2(amax / fp8_max))。这种离散决策会让训练曲线出现”阶梯感”——某个 batch scale 跳一级,loss 突然抖动。

V4 的应对是block_size 的设计:128×128 的 block 内 amax 已经被多个 element 的统计稳定化——不会因为单个 outlier 让整 block 的 scale 跳级。这是为什么 block_size=128 而非 32 的工程理由之一。

攻防回合四:训练后期的”放大效应”

QAT 训练的后期(cooldown 阶段),lr 衰减到很小——这时候每次更新的 weight 变化只有量化网格的 1/100 量级。这种小变化在反量化后可能完全消失——梯度更新”被量化网格吸收”。

V4 的应对是 Muon 的 spectral norm 约束——它让 weight 的更新方向更”显著”,即便幅度小也能逃出量化网格的吸收。这与传统 AdamW + 量化的”小 lr 后期失效” 形成对比。

这四个攻防回合让 V4 的 QAT 在 1.6T 规模下保持稳定——任何一个环节松懈,模型就会塌陷到”低精度训出来的低质量”。


14.8·补·补 QAT 与”训练-推理一致性”的工程承诺

V4 走 QAT 还有一个更深远的工程动机:训练-推理的 bit-level 一致性

什么是 bit-level 一致性?意思是:模型在训练时见到的精度、推理时跑出来的精度,精确到 bit 都相等。这听起来是个理所当然的需求,但在传统 PTQ 模型里是不成立的——PTQ 模型训练时是 BF16,推理时是 FP8,两者的数值误差从训练完成那一刻起就开始”偏移”。

V4 的 QAT 工程承诺:训练时 forward 路径上每个 op 的精度 = 推理时该 op 的精度。这意味着:

  • 你在 V4 训练 loss 上看到的数字,与你在 vLLM 部署 V4 后跑同样输入得到的数字,差异完全可控
  • 你做 fine-tune 时如果改了某层 weight,可以精确预测推理时的输出变化
  • A/B test 不同训练配置时,差异不会被”训练-推理精度漂移”污染

这种 bit-level 一致性是工程化大模型训练栈的”金标准”——大多数公司做不到,因为它需要训练框架与推理框架共享底层 op。V4 通过把 DeepGEMM 同时放在训练和推理路径,达到了这个标准。

如果你的项目想从 V4 借鉴 QAT 思路,最关键的不是 act_quant 函数本身——而是确保训练和推理的 op 实现完全一致。这是 V4 最难复制的工程能力之一。


14.8·延展 QAT 与 LLM 评估的”精度漂移”检测

V4 的 QAT 让训练-推理精度对齐,但生产中仍需持续监控”是否真的对齐”。这部分监控属于评估工程的范畴——与《LLM 评估工程》卷有直接关联。

精度漂移的来源

  • 软件层:vLLM / DeepGEMM 升级时算子实现可能有微小差异
  • 硬件层:不同 GPU 批次的 FP4 / FP8 行为可能略有差异
  • 数据层:用户输入分布与训练分布不同,激活值范围偏移

漂移检测方法

  • golden test:选 100-1000 个 fixed prompt + reference output,定期跑一遍对比 token-by-token 差异
  • distribution test:监控生产输出的统计分布(token 频率、长度、log-prob),与训练时 reference 对比
  • shadow test:把一小部分流量同时跑 V4 production 和 V4 reference(FP32)实例,对比差异

当检测到漂移时的应对

  • 小漂移(<0.1% token-level diff):通常无害,不必处理
  • 中漂移(0.1%-1%):可能是软件版本不兼容,回滚或修复
  • 大漂移(>1%):必须排查根本原因——可能是部署配置错(如精度模式被误设)

这套精度漂移检测工程在 V3 时代已经成熟,V4 沿用。具体的 evaluation pipeline 实现细节(指标定义、自动化跑批、告警阈值)会在 evals 卷的”训练-推理一致性测试”章节展开。


14.9 延伸阅读

  • 量化感知训练综述(arXiv:2103.13630):QAT 经典理论
  • Straight-Through Estimator(arXiv:1308.3432):STE 的源头
  • DeepSeek-V3 报告中关于 FP8 训练的章节
  • 本书第 12 章:FP4 / FP8 / ue8m0 格式细节
  • 本书第 13 章:DeepGEMM 内部如何处理 fake quant 后的 GEMM
  • 本书第 17 章:Muon 优化器如何与 QAT 配合

14.9·补 QAT 训练栈的 5 个工程教训

V4 的 QAT 把”训练时模拟推理量化” 推到了 1.6T 规模。从源码可以反推出 V4 团队踩过的几个工程教训——任何想复刻 QAT 的项目都值得借鉴:

教训 1:QAT 必须从训练第一步就启用

某些团队尝试”先 BF16 训练,后期切 QAT”——结果是 weight 已经走到 BF16 的某个局部最优,强行切到 FP8/FP4 后陷入塌陷。V4 选择从训练第一步就 QAT——让 weight 一开始就在量化网格上演化,避免局部最优。

教训 2:scale 计算必须可微

ue8m0 scale 是离散的(2 的整数幂)——直接 ceil 计算是不可微的。V4 大概率用了”ceil 的 STE”——前向用 ceil、反向假装可微。如果你的实现里 scale 计算”完全 round 死”,反向梯度会被截断——某些参数训不动。

教训 3:block 边界的 scale 选择很关键

每 128×128 块独立 scale。如果两个相邻块的 scale 差异极大(如 2^4 vs 2^-4),block 边界的 weight 会因为 scale 跳变而出现”discontinuity” ——影响 GEMM 累加精度。V4 通过”训练时让相邻块的 amax 接近”来缓解——具体机制不明,但实测分布显示相邻块 scale 差异通常 ≤2^2。

教训 4:fake quant 的 in-place 操作必须配合 autograd hook

V4 的 fp4_act_quant 是 in-place 操作。PyTorch 的 autograd 默认禁止有梯度张量的 in-place 修改——除非你用自定义 Function + STE。这是 V4 的 kernel 实现里专门处理的——如果你照搬 V4 源码而没配套 autograd hook,反向会报 “modified in-place” 错误。

教训 5:精度漂移监控比”训练 loss 监控”更重要

QAT 的 loss 可能看起来正常下降,但精度在悄悄退化(前向输出与 BF16 reference 偏差增大)。V4 的训练大概率有一套独立的”精度对比” 监控——每隔 N step 跑同一组输入在 BF16 reference 上对比 QAT 输出的 KL 散度。这个指标反映”QAT 是否真的在学量化鲁棒性”——比 loss 更直接。

这 5 个教训是 V4 / V3 时代 QAT 工业化的”血泪经验”——任何后续做 1T+ 规模 QAT 的团队都要重新踩这些坑(除非他们直接借鉴 V4 的经验)。


14.9·补·补 QAT 工程师速记

关键概念

  • QAT = 训练时模拟推理的量化误差
  • act_quant:把激活动态量化到 FP8 / FP4
  • STE:Straight-Through Estimator,反向梯度透传
  • in-place fake quant:节省显存的量化-反量化对

V4 中 act_quant 调用点

  • KV cache 写入:FP8 + block 64
  • linear 输入:FP8 + block 128
  • Indexer query:FP4 + block 32
  • Compressor (rotate=True):FP4 + block 32

调试 QAT 的 5 个症状

  1. loss 早期 NaN —— scale 计算溢出
  2. 训练曲线”阶梯感” —— scale 跳级,正常
  3. 推理输出与训练 reference 偏差大 —— QAT 没真正生效,回查 act_quant 是否启用
  4. fine-tune 后量化误差变大 —— fine-tune 时关掉了 QAT
  5. 某些层精度异常 —— 该层的 block_size 与 GEMM tile 不对齐

与 PyTorch 官方 quantization 的差异

  • V4 用 DeepGEMM 自带 act_quant,不用 torch.ao.quantization
  • per-block ue8m0 scale,不是 per-tensor
  • in-place 操作(autograd 友好需要自定义 Function)

生产监控指标

  • 推理输出与 BF16 reference 的 KL 散度(应稳定 < 0.01)
  • 每层 act_quant 的 amax 分布(应稳定,无突跳)
  • 部署后的精度漂移(每周跑 golden test)

14.9·延展 QAT 与”反向传播误差累积” 的边界

V4 在 1.6T + 32T tokens 的训练规模下,QAT 的反向传播误差累积是个真实问题。把这个边界讲清楚。

误差累积来源 1:STE 的不精确

Straight-Through Estimator 把 round 的反向梯度透传——前向数值已经被 round,反向梯度按”未 round” 算。这个不一致让某些极端方向的梯度估计有偏差。

在小模型上偏差可以忽略——大模型 + 长训练后偏差累积成可见的精度损失。

误差累积来源 2:FP4 / FP8 GEMM 的累加误差

虽然 GEMM 内部用 FP32 累加,但累加几千次 token 的输出后,FP32 自身的舍入误差也会累积。

误差累积来源 3:跨 layer 误差放大

第 1 层的小误差经过 60 层会被放大——理论上约 1.0001^60 ≈ 1.006,即 0.6% 放大。这是 dense 模型的情况;sparse + HC + 多种精度的复杂结构里,误差放大可能是 2-5%。

V4 的 4 个对抗机制

  1. HC 的 Sinkhorn 双随机性:让 4 路 hidden 的能量守恒——抑制误差累积
  2. per-layer 精度复位:每层 RMSNorm 把 hidden 重新归一——切断误差链
  3. attn_sink 的兜底:稀疏 attention 选错时不会让数值崩溃
  4. swiglu_limit 的 clip:防止 SwiGLU 输出爆炸

这 4 个机制让 V4 的误差累积在可接受范围(最终模型推理 ≈ BF16 reference 的 99%+)。

理解这个边界的意义

如果你想从 V4 借鉴 QAT 设计到自己的模型,必须同时引入这 4 个对抗机制——否则 QAT 在大模型上会失败。光有 act_quant 不够,必须配套架构 + 数值兜底。


14.10 本章小结

  • V4 用 QAT 而非 PTQ——训练时就让模型”看到”量化误差
  • act_quant / fp4_act_quant 在 forward 路径上做”假量化”(量化 + 反量化),backward 用 STE 透传梯度
  • V4 源码里 5 个 act_quant 调用点:KV、Q、Compressor 双路径、linear 输入——每个点的 block_size 与对应 GEMM tile 对齐
  • in_place=True 节省显存——配合自定义 autograd 函数兼容梯度
  • QAT 调试 4 个常见坑:粒度不对齐、scale 类型错、STE 梯度爆炸、初期不稳定
  • V4 的 QAT 工具链与 PyTorch 官方 quantization 工具差异显著——是 DeepGEMM 自带的特化方案

第 15 章我们离开精度链路,进入 V4 的分布式工程:Tensor / Expert 并行的实现细节。

评论 0