第12章 FP4 e2m1 + FP8 e4m3:混合精度的几何
“Precision is not a single dial. It’s a mosaic of dials, each turned just right.” —— NVIDIA 一位资深 mixed-precision 工程师
V4 的精度配方有 5 个维度:weight 精度、activation 精度、scale 精度、block 大小、计算累加精度。每个维度都被独立调到最优。
12.1 引子:从 BF16 到 FP4 的精度阶梯
把 V4 涉及的浮点格式列在一张表上:
| 格式 | 总位数 | 符号 | 指数 | 尾数 | 动态范围 | V4 中的位置 |
|---|---|---|---|---|---|---|
| BF16 | 16 | 1 | 8 | 7 | ~1e±38 | 激活默认 / shared expert |
| FP8 e4m3 | 8 | 1 | 4 | 3 | ~1e±2 (实际) | linear 权重 / GEMM 输入 |
| FP8 e5m2 | 8 | 1 | 5 | 2 | ~1e±15 | (V3 用过,V4 主要用 e4m3) |
| FP4 e2m1 | 4 | 1 | 2 | 1 | [-6, 6] 离散 | routed expert 权重 |
| ue8m0 | 8 | 0 | 8 | 0 | 2^-127 ~ 2^128 | scale 张量 |
| FP32 | 32 | 1 | 8 | 23 | ~1e±38 | 累加器 / norm 计算 |
V4 的精度策略是”精度阶梯化”——不同算子用不同精度,让”精度成本”最优分配。具体地:
- 存储用极低精度(FP4 expert / FP8 linear)—— 把 1.6T 参数压到几百 GB
- 计算用 FP32 累加—— 保证乘加结果的精度
- 激活值默认 BF16—— 在每个算子之间保持稳定
- scale 用 ue8m0—— 一种”纯指数”格式,记录每块的缩放因子
flowchart LR
subgraph Storage["存储 (省内存)"]
FP4w["expert weight: FP4 e2m1"]
FP8w["linear weight: FP8 e4m3"]
UE8M0["scale: ue8m0"]
BF16Act["activation: BF16"]
end
subgraph Compute["计算 (保精度)"]
Quant["act_quant: BF16 → FP8/FP4"]
GEMM["FP8/FP4 GEMM<br/>累加器: FP32"]
Dequant["输出: BF16"]
end
subgraph Numeric["数值稳定"]
Norm["RMSNorm: FP32"]
Soft["softmax: FP32"]
Sink["attn_sink: FP32"]
end
Storage --> Compute --> Numeric
12.2 FP4 e2m1 的位级表示
FP4 e2m1 总共 4 位:
[ 1 sign | 2 exponent | 1 mantissa ]
可以表示的所有值:
| 二进制 | 数值 | 二进制 | 数值 |
|---|---|---|---|
| 0000 | +0 | 1000 | -0 |
| 0001 | +0.5 | 1001 | -0.5 |
| 0010 | +1.0 | 1010 | -1.0 |
| 0011 | +1.5 | 1011 | -1.5 |
| 0100 | +2.0 | 1100 | -2.0 |
| 0101 | +3.0 | 1101 | -3.0 |
| 0110 | +4.0 | 1110 | -4.0 |
| 0111 | +6.0 | 1111 | -6.0 |
只有 16 个离散值。最大正值 6.0,最小非零正值 0.5——动态范围 [0.5, 6.0]。
FP4 的工程意义:
- 单参数仅占 0.5 字节——把 1.6T 参数压到 800 GB(实际 V4 用 fp4_e2m1fn_x2 打包,2 个 fp4 = 1 个字节)
- 单参数动态范围有限(0.5 ~ 6)——必须配合 per-block scale 才能表达大数值
- 离散度大(相邻值差异 0.5)——直接量化精度损失约 12-25%
V4 用 FP4 存 routed expert 权重——因为 384 个 expert 的”稀疏激活”摊销了这部分精度损失。
12.3 FP8 e4m3 的位级表示
FP8 e4m3 总共 8 位:
[ 1 sign | 4 exponent | 3 mantissa ]
理论动态范围 [2^-6, 2^8] ≈ [0.016, 256],最大值约 448(NVIDIA 标准定义)。
V4 把 FP8 e4m3 用于 linear 权重和 GEMM 输入。具体配置(来自 config.json):
"quantization_config": {
"fmt": "e4m3",
"scale_fmt": "ue8m0",
"weight_block_size": [128, 128],
"activation_scheme": "dynamic"
}
weight_block_size = [128, 128]:每 128×128 的权重块共用一个 ue8m0 scale。
activation_scheme = dynamic:激活值的 scale 在每次 forward 时动态计算(不是预定的)。
为什么 V4 选 e4m3 而不是 e5m2?
- e4m3 的尾数多 1 位(3 vs 2),精度更高
- e5m2 的指数多 1 位,动态范围更大但精度低
- Linear 权重的数值分布相对集中(不像梯度那样跨多个数量级),精度比动态范围更重要
- 配合 ue8m0 per-block scale,e4m3 的有限动态范围被 scale 补偿
V3 在反向梯度上用 e5m2(梯度跨数量级大),V4 在 inference 时只用 e4m3(前向激活相对集中)。
12.4 ue8m0:纯指数 scale
ue8m0 是一种特殊的 8 位浮点:
[ 8 exponent | 0 mantissa ]
无符号、无尾数、纯指数——只能表示 2 的幂:2^-127 到 2^128。
ue8m0 的设计哲学:scale 的作用是”缩放数值的量级”,不需要精度——只需要”指数级”的覆盖范围。把 8 位全用来存指数,恰好覆盖 IEEE 754 float32 的指数范围。
V4 的工程意义:
- 每个 weight block 一个 ue8m0 scale(1 字节)
- 每 128×128 块只多 1 字节——overhead 0.0006%
- ue8m0 与 FP32 之间的转换是简单的位移,硬件实现极快
- ue8m0 × FP4 / FP8 的”反量化”是位级 add(指数相加)
flowchart LR
subgraph 存储["磁盘 / GPU 显存"]
FP4Weight["FP4 e2m1: 4 bit"]
UE8M0Scale["ue8m0 scale: 8 bit (per 128 elements)"]
end
subgraph 加载["加载到 SMEM"]
Combined["反量化: weight × 2^scale"]
end
subgraph 计算["TensorCore"]
BF16Activation["BF16 activation"]
GEMM["fp4_gemm 或 dequant + bf16 GEMM"]
end
FP4Weight --> Combined
UE8M0Scale --> Combined
Combined --> GEMM
BF16Activation --> GEMM
12.5 block-wise scale 的几何
weight_block_size = [128, 128] 的几何含义:把权重矩阵分成 128×128 的块,每块共享一个 scale。
考虑 V4 的一个 expert 矩阵 w1 / w3:[hidden=7168, inter_dim=3072]。划分成块:
- 行方向:7168 / 128 = 56 块
- 列方向:3072 / 128 = 24 块
- 总块数:56 × 24 = 1344 块
- 每块 1 个 ue8m0 scale = 1344 字节 = ~1.3 KB
权重总大小:7168 × 3072 × 0.5 字节 (FP4) = 11 MB Scale 总大小:1.3 KB Scale overhead = 0.012%
这种 per-block scale 的关键好处:不同块可以有不同的 scale 范围——某块全是大数值(scale=2^4),相邻块全是小数值(scale=2^-4)——两块独立缩放,不会互相干扰。
如果用 per-tensor scale(整个矩阵一个 scale),FP4 的有限范围 [0.5, 6] 必须覆盖矩阵中所有数值——只要矩阵里有一个异常大值,所有其他值就会被压扁到 0。per-block 完美解决这个问题。
为什么是 128 而不是 32 或 256?
- 32:scale overhead 上升到 0.05%——可承受但偏高
- 256:scale 数量减半,但每块跨度大,可能导致块内异常值
- 128:与 GPU 的 TensorCore tile size 对齐(H100 / B200 的 WGMMA tile 通常是 64 或 128),硬件友好
12.6 V4 内部各模块的精度分配
把 V4 的精度分配画成表:
| 模块 / 张量 | 存储精度 | 计算精度 | 累加精度 | 备注 |
|---|---|---|---|---|
| Embedding 权重 | BF16 | - | - | 词表小,不量化 |
| Q/K/V Linear 权重 | FP8 e4m3 | FP8 GEMM | FP32 | weight_block 128×128 |
| Routed Expert 权重 | FP4 e2m1 | FP4 GEMM | FP32 | scale block 32 |
| Shared Expert 权重 | BF16 | BF16 GEMM | FP32 | 不量化(精度敏感) |
| Compressor wkv/wgate | FP32 | - | - | 数量小,保 FP32 |
| HC hc_attn_fn 等 | FP32 | - | - | Sinkhorn 精度敏感 |
| Indexer wq_b | FP4 e2m1 | FP4 GEMM | FP32 | 与 expert 同精度 |
| Indexer weights_proj | BF16 | BF16 GEMM | FP32 | 显式指定 BF16 |
| LM Head | BF16 | BF16 GEMM | FP32 | 输出精度敏感 |
| Activation (默认) | BF16 | - | - | 进 GEMM 时 act_quant |
| KV Cache | BF16 + FP8 | - | - | rope 部分 BF16, nope FP8 |
| RMSNorm 内部 | - | FP32 | - | 必须 FP32 防溢出 |
| Softmax 内部 | - | FP32 | - | 防 exp 溢出 |
| RoPE freqs_cis | complex64 | - | - | 1M 接近 float32 极限 |
读这张表能看到 V4 的精度策略:“高频访问 + 数量大 → 用低精度,低频访问 + 关键路径 → 保高精度”。
384 个 routed expert 是数量最大的存储——压到 FP4 节省最多。LM head 是关键路径但参数量相对小(vocab × dim),保 BF16。HC / Compressor / Indexer 的部分参数因为对精度敏感(Sinkhorn 等),保 FP32。
12.7 反量化路径:从 FP4 到 BF16
V4 的 linear 函数处理 FP4 / FP8 / BF16 的统一调度:
def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
assert bias is None
if weight.dtype == torch.float4_e2m1fn_x2:
x, s = act_quant(x, block_size, scale_fmt, scale_dtype)
return fp4_gemm(x, s, weight, weight.scale, scale_dtype)
elif weight.dtype == torch.float8_e4m3fn:
x, s = act_quant(x, block_size, scale_fmt, scale_dtype)
return fp8_gemm(x, s, weight, weight.scale, scale_dtype)
else:
return F.linear(x, weight)
读这段代码的关键:
weight.dtype 决定路径:
- FP4:先把 x 量化到 FP8(不是 FP4!),然后 fp4_gemm 处理”FP8 x FP4 → BF16”
- FP8:把 x 量化到 FP8,然后 fp8_gemm 处理”FP8 x FP8 → BF16”
- 其他:标准 BF16 GEMM
为什么 FP4 weight 配 FP8 activation? 因为:
- FP8 激活已经够低精度了,再压到 FP4 会让”激活的精度损失”超过”权重的精度损失”,得不偿失
- FP4 GEMM 在硬件上其实是把 FP4 “升采样”成 FP8 后做 FP8 GEMM 的——所以 activation 也是 FP8
weight.scale 是个隐藏属性:
# 在 Linear.__init__ 里
self.weight.scale = self.scale = nn.Parameter(...)
V4 把 scale 直接挂到 weight 张量上(weight.scale = ...),让 linear 函数能拿到它。这是个 PyTorch 的非主流用法——通常 attribute 被挂到 Module,但 V4 直接挂到 Parameter,简化了调用。
12.8 与其他低精度方案的对比
把 V4 的精度方案与几个主流方案对比:
| 方案 | Weight | Activation | Scale | 适用模型 |
|---|---|---|---|---|
| GPTQ (INT4) | INT4 | FP16 | per-row | Llama / OPT 推理 |
| AWQ (INT4) | INT4 | FP16 | per-channel | Llama / Qwen |
| FP8 (NVIDIA) | FP8 e4m3 | FP8 e4m3 | per-tensor | H100 / B200 通用 |
| MXFP6 / MXFP4 | FP6 / FP4 | FP6 / FP4 | per-32 | 即将到来 |
| DeepSeek V3 | FP8 e4m3 | FP8 e4m3 | ue8m0 per-128 | V3 训练 + 推理 |
| DeepSeek V4 | FP4 expert + FP8 linear | FP8 | ue8m0 per-128 | V4 训练 + 推理 |
V4 的精度方案在两个维度上独特:
- 混合 FP4 + FP8:唯一在 1.6T 规模工业化的混合方案
- ue8m0 scale:与 micro-scaling format(OCP 标准)兼容,硬件友好
未来 NVIDIA 的 B200 + 后续架构会原生支持 MXFP4 / MXFP6——V4 的方案大致与 OCP 标准对齐,未来硬件的”原生支持”能让 V4 的 throughput 进一步提升。
12.9 动手实验:用 PyTorch 模拟 FP4 量化
import torch
def fake_fp4_quant(x, block_size=32):
"""模拟 FP4 e2m1 量化 + 反量化,看精度损失"""
# FP4 e2m1 的 16 个值
fp4_values = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6])
shape = x.shape
x_flat = x.flatten()
blocks = x_flat.split(block_size)
out_blocks = []
for b in blocks:
# 找到 block 的最大绝对值,作为 scale 基准
amax = b.abs().max()
# ue8m0 scale = 2^ceil(log2(amax / 6))
if amax > 0:
scale_exp = torch.ceil(torch.log2(amax / 6.0))
scale = torch.pow(2.0, scale_exp)
else:
scale = torch.tensor(1.0)
# 量化:找最近的 FP4 值
b_scaled = b / scale # 现在 |b_scaled| <= 6
# 找每个元素最近的 fp4 值
diffs = (b_scaled.unsqueeze(-1) - fp4_values.unsqueeze(0)).abs()
nearest = fp4_values[diffs.argmin(dim=-1)]
# 反量化
b_reconstructed = nearest * scale
out_blocks.append(b_reconstructed)
return torch.cat(out_blocks).view(shape)
# 测试
x = torch.randn(128, 128) * 3
x_fp4 = fake_fp4_quant(x, block_size=32)
err = (x - x_fp4).abs().mean().item()
relative_err = err / x.abs().mean().item()
print(f"Mean abs error: {err:.4f}")
print(f"Relative error: {relative_err*100:.2f}%")
跑这个实验:典型的 FP4 量化误差约 8-12% 相对误差。对应 V4 文档里说”FP4 的精度损失被 384 expert 的稀疏激活摊销”——12% 误差被 1/64(每 token 见 6/384 个 expert)稀释,最终对模型输出的影响小很多。
12.9·补 V4 精度策略的”纵深防御”哲学
V4 的精度配方看起来复杂——5 种格式、不同 block 大小、不同模块用不同精度。这种复杂背后是一种”纵深防御” 的工程哲学:不让任何单个精度策略承担全部抗误差责任。
把这种纵深防御的层次拆开看:
第一层防御:FP4 / FP8 的 per-block scale
per-block scale 让每个 128×128 块独立缩放——某块的 outlier 不会拖累其他块。这是最基础的防御层,覆盖 99% 的精度损失。
第二层防御:QAT 让模型自适应量化误差
QAT 让模型在训练中”看到” FP4/FP8 的精度损失,主动学习”避开量化盲区”。这层防御覆盖了”数据分布特殊导致 per-block scale 不够”的情况。
第三层防御:FP32 累加 + BF16 中间
GEMM 内部累加用 FP32——避免低精度反复累加导致的舍入误差爆炸。结果输出 BF16——保留下游计算所需的精度。这层防御覆盖了”长链 GEMM 累积误差” 的问题。
第四层防御:HC / Sinkhorn / RMSNorm 的 FP32 计算
HC 的混合矩阵、Sinkhorn 的归一化、RMSNorm 都强制用 FP32——这些”算法关键点”不允许任何精度损失。这层防御兜住了”低精度让算法本身崩溃”的极端情况。
第五层防御:attn_sink / swiglu_limit / 数值 clamp
模型架构层面的”硬约束”——attn_sink 防止 attention 数值崩溃、swiglu_limit 防止 FFN 爆炸。这层是”最后一道墙”——即便前 4 层都失效,模型也不会塌陷。
5 层防御互相独立、各司其职。任何单层失效,其他层兜底——这就是为什么 V4 在 FP4 expert 这种极端低精度下仍能保持 1.6T 模型的表达力。
12.9·补·补 V4 精度方案与硬件演进的对齐
V4 的精度选择不是凭空设计的——它与 NVIDIA 硬件的演进路线深度对齐。把对齐关系摆出来:
对齐 1:FP8 e4m3 与 H100 / B200 的 TensorCore 原生支持
H100 是第一个原生支持 FP8 TensorCore 的 GPU。NVIDIA 的 FP8 标准是 e4m3 + e5m2 双格式——V4 选 e4m3(用于推理 forward)正好在 H100 / B200 的”快路径”上。如果 V4 选了非标准 FP8 格式,会失去硬件加速。
对齐 2:FP4 e2m1 与 B200 的原生支持
B200 是 NVIDIA 第一个原生支持 FP4 TensorCore 的 GPU(2025 年发布)。V4 在 2026 年发布——刚好赶上 B200 的成熟期。这种”模型 + 硬件” 的发布时机对齐让 V4 在 B200 上能跑出最优 FP4 性能。
对齐 3:ue8m0 scale 与 OCP Microscaling 标准
OCP(Open Compute Project)2024 年发布了 Microscaling Format 标准——定义了 MXFP4 / MXFP6 等”低精度数据 + 共享 scale” 的格式标准。V4 的 FP4 + ue8m0 与这个标准几乎兼容——意味着未来支持 OCP 标准的所有硬件(不止 NVIDIA)都能跑 V4。
对齐 4:block_size=128 与 H100 / B200 的 SMEM / TensorCore tile
H100 的 WGMMA 指令支持 128 列的 tile size。B200 进一步扩展到 128/256 灵活 tile。V4 的 weight_block_size=128 与硬件 tile 完美对齐——每个 GEMM tile 用 1 对 scale,硬件层面无浪费。
对齐 5:BF16 activation 与所有现代 GPU
BF16 是 A100 / H100 / B200 / 甚至 AMD MI300 的标准激活精度。V4 的激活默认 BF16 让模型在所有现代硬件上都”能跑”——只是 FP4 / FP8 路径在不同硬件上速度不同。
这 5 个对齐让 V4 在硬件演进的当下时刻处于”最优位置”——现在的硬件直接受益、未来的硬件因为标准对齐自然受益。这是工程设计的”长期主义”——不为短期方便牺牲与硬件标准的对齐。
12.9·延展 PyTorch 底层与 V4 精度的接口
V4 的 FP4 / FP8 / ue8m0 在 PyTorch 端依赖几个相对新的 dtype 与算子。把这些依赖梳理出来:
PyTorch dtype 依赖:
torch.float8_e4m3fn:PyTorch 2.0+ 引入。V4 用于 linear weight。torch.float8_e5m2:V3 用过、V4 主要不用。torch.float4_e2m1fn_x2:PyTorch 2.4+ 引入。V4 用于 routed expert weight。注意_x2后缀——表示”两个 FP4 打包到 1 字节”。torch.float8_e8m0fnu:V4 的 ue8m0 scale 类型。PyTorch 较新版本支持。
PyTorch 算子依赖:
torch.view_as_complex/torch.polar:用于 RoPE 的复指数计算torch.distributed.all_reduce等通信原语:与 V4 的并行类配合nn.functional.silu/nn.functional.softplus:SwiGLU 与 sqrtsoftplus 的基础
版本要求:V4 的 inference/model.py 顶部 transformers_version = 4.57.1 ——意味着官方测试用 transformers 4.57.1。PyTorch 版本要求约 2.4+(为 FP4 dtype)。如果你 PyTorch 版本太老,FP4 dtype 会报错。
未来兼容:PyTorch 在 OCP MX 标准成熟后会引入 torch.float6_e3m2 / torch.float6_e2m3 等新格式——V4 是为 OCP MX 标准设计的,未来 PyTorch 演进对 V4 完全兼容。
如果你想深入理解这些 dtype 在 PyTorch 内部如何实现(特别是 _x2 打包格式、ue8m0 的位级表示),可以参考《PyTorch 内核源码剖析》中关于”低精度数据类型” 的章节——那一卷会从 PyTorch 的 c10::ScalarType 一路讲到 ATen 的 FP4/FP8 dispatch。
12.10 延伸阅读
- OCP Microscaling Formats(ocp-microscaling-formats):MXFP4 / MXFP6 标准
- FP8 Format(NVIDIA whitepaper):H100 的 FP8 格式
- GPTQ 论文(arXiv:2210.17323):INT4 后训练量化
- AWQ 论文(arXiv:2306.00978):activation-aware 量化
- 本书第 13 章:DeepGEMM 内部如何做 FP4/FP8 GEMM
- 本书第 14 章:QAT 训练时的 act_quant 路径
12.10·补 V4 精度方案的”迁移到其他模型”指南
V4 的精度方案(FP4 expert + FP8 linear + ue8m0 + 128 block)是为它的特定结构设计的。如果你想把这套方案迁移到其他模型,几个关键点:
迁移点 1:weight 是否符合 block 量化假设
V4 的所有 weight 维度都是 128 倍数——天然 block 友好。如果你的模型 weight 维度不是 128 倍数(如 hidden=4096 但 inter_dim=11008),block_size 必须用最大公约数(在这个例子是 32)——精度损失增加。
迁移点 2:是否有对应规模的 expert
FP4 适合 expert——因为 384 expert 的稀疏激活摊销精度损失。如果你的模型只有 8 个 expert(如 Mixtral),FP4 损失会被显著放大,反而不如 FP8。
迁移点 3:训练成本
V4 的 QAT 是从预训练第一天就启用——大模型 from-scratch 训练才能用 QAT 把精度损失训”消化”。如果你做的是 fine-tune(小数据),不能从 BF16 切到 FP4 ——必须用 PTQ,但 PTQ 在 FP4 上几乎肯定塌陷。
迁移点 4:硬件支持
H100 / B200 是 V4 精度方案的最优硬件。如果部署到 A100(无 FP8 / 无 FP4 原生),需要软件模拟——性能损失可能 50% 以上。
迁移点 5:DeepGEMM 依赖
V4 精度的高效实现依赖 DeepGEMM 库——cuBLAS 不支持。如果你的项目不能引入 DeepGEMM 作为依赖,迁移时需要自己写或者放弃这套精度方案。
总的迁移建议:V4 精度方案最适合”V4-like 模型 + H100/B200 + DeepGEMM 可用” 三个条件同时满足的场景。否则建议用更通用的 FP8 e4m3 + per-tensor scale 方案,损失部分性能换取通用性。
12.10·补·补 一段精度配方的”代价对账”
V4 的精度配方看似”白送性能”——FP4 expert + FP8 linear 让显存与 FLOPs 都减小,似乎没有代价。但工程上没有免费午餐。把代价算清楚:
代价 1:训练复杂度
QAT 全程参与 + Muon 优化器 + 多种 dtype 并存——训练栈的工程复杂度比 BF16 训练高 3-5x。需要专门的 ML infra 团队维护。
代价 2:硬件锁定
V4 的精度配方在 H100 / B200 上最优——A100 / 消费级 GPU 上要么慢、要么跑不了。这种硬件锁定让”用旧硬件部署 V4” 困难。
代价 3:调试难度
低精度引入的数值问题不容易排查——某个 expert 输出异常可能是 FP4 反量化错、可能是 ue8m0 scale 错、可能是 act_quant 块大小错。每种错的症状相似但根因不同。
代价 4:fine-tune 限制
如前文(§12.10·补)所述,fine-tune 不能简单地从 BF16 切换到 FP4——必须保持训练时的精度配方。这限制了”用 BF16 fine-tune 然后量化部署” 这种常见 workflow。
代价 5:版本耦合
V4 的精度依赖 DeepGEMM / FlashMLA 的特定版本——这两个库升级时必须保证 V4 仍能跑。这种版本耦合让”用最新工具” 与”保持模型稳定” 之间有 trade-off。
理解这些代价让你做”V4 vs 其他模型” 的选择更清醒——不是”V4 的精度配方更先进就该用 V4”,而是”V4 适合愿意付这些代价的场景”。
如果你的项目对硬件 / 工具链 / 调试都要求灵活性,可能 BF16 dense 模型(即便能力略弱)反而是更好选择。
12.10·延展 V4 精度配方在 inference 与 training 之间的”对称性”
V4 的精度配方有一个常被忽略的特性:inference 与 training 的精度路径完全对称。这种对称性是工程纪律的体现。
对称点 1:相同的 dtype 集合
- weight:FP4 expert + FP8 linear(训练 / 推理一致)
- activation:BF16 默认(训练 / 推理一致)
- scale:ue8m0(训练 / 推理一致)
- 累加器:FP32(训练 / 推理一致)
对称点 2:相同的 act_quant 调用点
训练时 forward 路径上的每个 act_quant 调用,推理时也有完全相同的调用——位置、block_size、dtype 都一致。
对称点 3:相同的 GEMM 实现
DeepGEMM 同时服务训练和推理——训练时 GEMM 用 fp8_gemm,推理时也用。两者底层是同一个 CUDA kernel。
对称点 4:相同的 RMSNorm / Sinkhorn
数值稳定相关的 op(RMSNorm、Sinkhorn)在训练 / 推理都用 FP32 计算,避免任何精度差异。
为什么这种对称性重要:
- 零分布漂移:训练时见到的精度 = 推理时见到的精度,模型不会因部署而行为变化
- 可复现的 bug:训练时遇到的数值问题在推理时一致出现,便于调试
- fine-tune 的可预测性:fine-tune 用同样的精度配方,结果对部署可预测
对称性的代价:
如果你想在不同设备上部署不同精度(如 H100 用 FP8、A100 用 BF16),这种对称性会被破坏——A100 部署的模型行为与 H100 训练版本会有差异。V4 选择”保持对称”——意味着它锁定在 H100/B200 平台上,不下沉到老硬件。
这种”为了对称放弃跨硬件灵活性” 的工程权衡是 V4 的特点——它把质量保证置于跨硬件兼容性之上。
12.11 本章小结
- V4 用 5 维精度策略:weight (FP4/FP8) + activation (BF16/FP8) + scale (ue8m0) + block (128×128) + 累加 (FP32)
- FP4 e2m1 仅 16 个离散值,配合 per-block scale 才能表达大动态范围
- FP8 e4m3 比 e5m2 精度高、动态范围小——配合 scale 后整体范围能覆盖
- ue8m0 是纯指数 scale,硬件实现极快(位级 add 指数)
- block 128×128 与 H100/B200 的 TensorCore tile size 对齐
- 精度阶梯化:高频大量数据用低精度(routed expert FP4),关键路径保高精度(HC / RMSNorm / softmax 用 FP32)
第 13 章我们进入 DeepGEMM——V4 这套精度配方背后的 CUDA kernel 实现。
评论 0
还没有评论,来说两句吧。
评论加载失败,刷新重试。