第15章 Tensor / Expert 并行:ColumnParallel 与 RowParallel

“A 1.6T model lives across machines. Every forward is a coordinated dance.” —— V3 工程师内部分享

V4 的并行策略不是事后补丁,而是写在每个 Linear 类里的”出生属性”。


15.1 引子:1.6T 模型的部署算式

V4 Pro 总参 1.6T。即便用 FP4 + FP8 混合精度(平均约 0.7 字节 / 参数),权重大小 ~1.1 TB。

单卡 H100 80GB 显然装不下。即使是 H200 141GB 也不行。V4 必须用分布式部署——典型方案:

  • 8 卡 NVLink 节点(H100 / H200):模型权重切到 8 卡,每卡约 140 GB
  • 16 卡多节点(2 × 8 卡 IB 互联):权重切到 16 卡,每卡约 70 GB
  • 32 卡 / 64 卡:通常给极大 batch 或低延迟需求

无论哪种部署,模型必须在张量层面切分——不能简单复制。V4 用两个核心机制:

  1. Tensor Parallel (TP):把每层 Linear 的权重沿 row / column 切到不同卡
  2. Expert Parallel (EP):把 384 个 routed expert 分配到不同卡,每卡持有一部分

V4 源码 (inference/model.py) 通过 world_sizerank 这对全局变量,把切分逻辑直接写在每个并行类里。本章拆这 4 个类。


15.2 全局并行状态

V4 在文件顶部声明:

world_size = 1
rank = 0
block_size = 128

这三个变量在 Transformer.__init__ 被覆盖:

def __init__(self, args: ModelArgs):
    global world_size, rank, default_dtype, scale_fmt, scale_dtype
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    rank = dist.get_rank() if dist.is_initialized() else 0
    ...

V4 用 PyTorch 的 torch.distributed 做通信——world_size = 总 GPU 数、rank = 当前 GPU 编号。这两个值在每个进程里是一致的——通过 torchrunmpirun 启动时由 launcher 设置。

V4 让这两个变量是全局可变,而不是每个 module 持有自己的引用——这种”全局可变变量”的设计违反传统软件工程纪律,但在 V4 这种”模块嵌套深 + 配置一次贯穿全程”的场景下能大幅简化代码。


15.3 ColumnParallelLinear:输出维度切分

class ColumnParallelLinear(Linear):
    """Shards output dim across TP ranks. No all-reduce needed on output."""
    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        assert out_features % world_size == 0
        self.part_out_features = out_features // world_size
        super().__init__(in_features, self.part_out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return linear(x, self.weight, self.bias)

ColumnParallelLinear 的语义:[in, out] 矩阵沿 out 维度切成 world_size。每个 rank 持有 [in, out / world_size] 的局部矩阵。

forward 时:每个 rank 独立计算局部输出 [B, S, out / world_size]——不需要通信

输出维度被切了,怎么办?要么:

  • 后续操作能直接处理切分后的输出(如紧接 RowParallelLinear,正好把切分维度作为输入维度)
  • 或者外部显式 all_gather 重组

V4 的典型用法:q、k、v、wq_b 等都是 ColumnParallelLinear——输出的 head 维度被天然切分到不同 rank,后续 attention 计算完后通过 RowParallelLinear 自然合并。


15.4 RowParallelLinear:输入维度切分

class RowParallelLinear(Linear):
    """Shards input dim across TP ranks. All-reduce on output to sum partial results."""
    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        assert in_features % world_size == 0
        self.part_in_features = in_features // world_size
        super().__init__(self.part_in_features, out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = linear(x, self.weight, None)
        if world_size > 1:
            y = y.float()
            dist.all_reduce(y)
        if self.bias is not None:
            y += self.bias
        return y.type_as(x)

RowParallelLinear 的语义:[in, out] 矩阵沿 in 维度切。每个 rank 持有 [in / world_size, out] 的局部矩阵。

forward 时:每个 rank 用局部 in 输入算局部输出,输出形状 [B, S, out](完整 out 维度),但是数值上只是”部分和”。用 all_reduce 把所有 rank 的部分和加起来,得到完整输出

注意 V4 的实现:先转 float32 再 all_reduce、加 bias、最后转回原 dtype。这是为了:

  • all_reduce 的精度敏感——不同 rank 的部分和量级可能不同,FP32 累加更稳
  • bias 在 reduce 之后加,避免被 reduce 重复 N 倍

ColumnParallel + RowParallel 的经典组合

ColumnParallelLinear(in=D, out=H × heads)     # 输出按 head 维度切
   ↓ (输出: [B, S, head/world × D_h], 每个 rank 各持一部分 head)
ColumnParallel attention 内部计算

RowParallelLinear(in=H × heads, out=D)        # 输入按 head 切,输出 reduce
   ↓ all_reduce 后得到完整 [B, S, D]

这套组合在 V4 的 Attention 类里直接可见——wq_b 是 ColumnParallel,wo_a 是 ColumnParallel,wo_b 是 RowParallel。


15.4·补 ColumnParallel + RowParallel 经典组合的张量流

V4 attention 内部的并行流转最典型地体现了这套组合:

flowchart LR
  X["x: [B,S,7168]<br/>每 rank 完整副本"] --> ColLin["ColumnParallel<br/>wq_b: 切 out_features"]
  ColLin --> Q["q: [B,S, n_heads/world × head_dim]<br/>每 rank 不同 head 切片"]
  Q --> Compute["attention 计算<br/>(rank 内独立)"]
  Compute --> O["o: [B,S, n_heads/world × head_dim]<br/>仍是 head 切片"]
  O --> RowLin["RowParallel<br/>wo_b: 切 in_features"]
  RowLin --> Partial["partial_output: [B,S,7168]<br/>每 rank 仅算了部分和"]
  Partial --> AllReduce[("all_reduce<br/>跨 rank 累加")]
  AllReduce --> Final["完整 x: [B,S,7168]<br/>每 rank 都有相同结果"]
  
  classDef parallel fill:#312e81,stroke:#a78bfa,color:#ede9fe
  classDef sync fill:#7c2d12,stroke:#fb923c,color:#ffedd5
  class ColLin,RowLin parallel
  class AllReduce sync

整段 attention 只有最后一次 all_reduce 通信——通信量 O(B×S×D),远小于 attention 内部的 O(B×S×n_heads×head_dim) 计算量。这就是”TP 在 attention 上几乎免费”的工程账。


15.5 ParallelEmbedding:vocab 维度切分

class ParallelEmbedding(nn.Module):
    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        ...
        assert vocab_size % world_size == 0
        self.part_vocab_size = (vocab_size // world_size)
        self.vocab_start_idx = rank * self.part_vocab_size
        self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
        self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if world_size > 1:
            mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
            x = x - self.vocab_start_idx
            x[mask] = 0
        y = F.embedding(x, self.weight)
        if world_size > 1:
            y[mask] = 0
            dist.all_reduce(y)
        return y

ParallelEmbedding 的语义:把 vocab 切到不同 rank——rank 0 持有 token id [0, V/8),rank 1 持有 [V/8, 2V/8) …

forward 的关键技巧:

  • 每个 rank 算”自己负责的 token id”的 embedding——不属于本 rank 的 token id 被 mask 成 0
  • dist.all_reduce 把所有 rank 的部分 embedding 累加——因为每个 token id 只有一个 rank 真正算了,all_reduce 实际等于”取那个 rank 的输出”

这种”mask + all_reduce”的写法看起来浪费——为什么不直接用 dist.all_to_all 或者点对点通信?答案是 all_reduce 的硬件支持最成熟、延迟最低——比 all_to_all 快得多。在 V4 这种 vocab=129280、单 token embedding 只有 7168 维的场景下,all_reduce 几乎是零开销。


15.6 ParallelHead:lm_head 的 vocab gather

class ParallelHead(nn.Module):
    def __init__(self, vocab_size: int, dim: int, ...):
        super().__init__()
        ...
        self.part_vocab_size = (vocab_size // world_size)
        self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim, dtype=torch.float32))

    def get_logits(self, x):
        return F.linear(x[:, -1].float(), self.weight)

    def forward(self, x, hc_fn, hc_scale, hc_base, norm):
        x = self.hc_head(x, hc_fn, hc_scale, hc_base)
        logits = self.get_logits(norm(x))
        if world_size > 1:
            all_logits = [torch.empty_like(logits) for _ in range(world_size)]
            dist.all_gather(all_logits, logits)
            logits = torch.cat(all_logits, dim=-1)
        return logits

ParallelHead 与 ParallelEmbedding 对偶:vocab 切到不同 rank、每个 rank 算”自己负责的部分 logits”、最后 all_gather 把所有 rank 的 logits 拼起来。

注意几个细节:

细节 1:只算最后一个 token 的 logits

return F.linear(x[:, -1].float(), self.weight)

x[:, -1] 取每条序列的最后一个 token——LLM 推理通常只关心下一个 token 的 logits。这避免了为序列中间 token 算 logits 的浪费。

细节 2:weight 是 float32

V4 的 lm_head 权重保 FP32(不像 attention / FFN 走 FP4 / FP8)。这是因为 logits 直接决定 next-token 概率,精度损失会被反复采样放大——必须保 FP32。

细节 3:all_gather 而非 all_reduce

vocab 在不同 rank 上是互不相交的——每个 rank 算的 logits 是 vocab 不同切片的真实结果,不是部分和。all_gather 把这些”真实切片”拼起来;如果用 all_reduce,会把不同切片错误相加。


15.7 Expert Parallel:384 expert 怎么切到 8 卡

V4 的 MoE 类里:

class MoE(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        ...
        assert args.n_routed_experts % world_size == 0
        self.n_routed_experts = args.n_routed_experts
        self.n_local_experts = args.n_routed_experts // world_size
        self.experts_start_idx = rank * self.n_local_experts
        self.experts_end_idx = self.experts_start_idx + self.n_local_experts
        ...
        self.experts = nn.ModuleList([
            Expert(...) if self.experts_start_idx <= i < self.experts_end_idx else None
            for i in range(self.n_routed_experts)
        ])
        self.shared_experts = Expert(...)   # 每 rank 都有完整 shared expert

8 卡 TP 部署 V4 Pro:384 / 8 = 48 个 expert / rank。每 rank 的 nn.ModuleList 里:

  • 自己持有的 48 个位置是真实 Expert 实例
  • 其他 336 个位置是 None

forward 时:

for i in range(self.experts_start_idx, self.experts_end_idx):
    if counts[i] == 0:
        continue
    expert = self.experts[i]
    idx, top = torch.where(indices == i)
    y[idx] += expert(x[idx], weights[idx, top, None])

if world_size > 1:
    dist.all_reduce(y)

y += self.shared_experts(x)

每 rank 只跑自己持有的 48 个 expert,输出贡献到 y。最后 all_reduce 把所有 rank 的部分输出累加。

通信成本:每层 MoE 一次 all_reduce on [B*S, D]。对于 V4 Pro 在 1M context + batch=8 下:

  • y 大小:8 × 1048576 × 7168 × 2 bytes (BF16) = 120 GB / step
  • 显然这不可能——all_reduce 不可能在每层 layer 上传 120 GB

实际上 V4 在生产部署时不会让 1M 序列经过所有 61 层 MoE——会用 prefill / decode 分阶段,且 MoE 的 all_reduce 在 decode 时只走 1 个 token(D / 8 卡 ≈ 数 KB)。

第 16 章会深入这部分通信优化——DeepEP 给 MoE 提供了比 NCCL all_reduce 更快的”专用通信库”。


15.8 一段 Attention 的并行流转

把 V4 的 Attention 在 8 卡 TP 下的并行流转走一遍:

flowchart TB
  X["x: [B, S, 7168]<br/>每 rank 都有完整副本"] --> ColLin1["ColumnParallelLinear (wq_b)<br/>每 rank 持 [1536, 16 heads × 512]"]
  ColLin1 --> Q["q: [B, S, 16 local heads, 512]<br/>每 rank 不同 head"]
  X --> Lin["Linear (wkv)<br/>每 rank 完整副本"]
  Lin --> KV["kv: [B, S, 512]<br/>每 rank 完整副本 (MQA)"]
  Q --> SparseAttn["sparse_attn<br/>每 rank 只算自己的 head"]
  KV --> SparseAttn
  SparseAttn --> O["o: [B, S, 16 local heads, 512]"]
  O --> ColLin2["wo_a (ColumnParallel)<br/>每 rank 持 [n_heads × 512 / 16, 16 × 1024]"]
  ColLin2 --> OLora["o_lora: [B, S, 16 / 8 groups, 1024]"]
  OLora --> RowLin["wo_b (RowParallel)<br/>每 rank 持 [16 × 1024 / 8, 7168]"]
  RowLin --> Reduce["all_reduce"]
  Reduce --> XOut["x_out: [B, S, 7168]<br/>每 rank 都有完整结果"]

整个 attention 的通信只有最后一次 all_reduce on [B, S, 7168]——其他都是 rank 内独立计算。这是 TP 在 attention 上的高效之处:通信量 O(B × S × D),远小于参数量 O(D²)。


15.9 与 Megatron-LM 的对比

V4 的并行类与 Megatron-LM 的 ColumnParallelLinear / RowParallelLinear 有同源思路(Megatron-LM 是这套方案的工业化先驱):

维度Megatron-LMV4 inference/model.py
ColumnParallel完整支持完整支持
RowParallel完整支持完整支持
EmbeddingVocabParallelEmbeddingParallelEmbedding (同名不同实现)
LM HeadVocabParallelOutputLayerParallelHead (含 hc_head 处理)
Sequence Parallel支持不支持(V4 用稀疏 attention 不需要)
Pipeline Parallel支持不支持(推理代码无 PP)
代码量数千行~50 行

V4 的并行实现非常简洁——只覆盖推理需要的部分。训练时的 Pipeline Parallel、Sequence Parallel 等更复杂机制由训练框架(不公开的内部代码)处理,与 inference/model.py 解耦。

这种”训练 / 推理代码分开”的设计让公开的推理代码极其简洁——读者不需要被训练栈的复杂性淹没。


15.10 通信开销估算

把 V4 Pro 在 8 卡 TP / 16 卡 TP+EP 下的通信开销估算一下(每 token decode):

通信操作张量大小频率总通信量 / token
Attention all_reduce[1, 7168] BF16 = 14 KB每层 1 次 × 61 层854 KB
MoE expert all_reduce[1, 7168] BF16 = 14 KB每层 1 次 × 61 层854 KB
LM Head all_gather[vocab/8] FP32 = 16 KB × 81 次130 KB
Embedding all_reduce[1, 7168] BF16 = 14 KB1 次14 KB
总计--~1.85 MB / token

NVLink 带宽约 600 GB/s(H100 之间),所以单 token 通信耗时约 3 μs——对 50 ms / token 的 decode 几乎可以忽略。

但在 prefill 阶段,序列长度 S 进入公式——通信量乘以 S。1M context 的 prefill 一次通信量约 2 GB——这时通信成本变得显著。

第 16 章 DeepEP 主要针对 prefill 的 all-to-all 优化。


15.11 动手实验:跑通最小 TP 推理

# 启动 2 卡 TP(在单机上模拟)
# torchrun --nproc-per-node=2 --master-port=29500 minimal_tp.py

import os
import torch
import torch.distributed as dist
import torch.nn as nn

dist.init_process_group(backend='nccl')
world_size = dist.get_world_size()
rank = dist.get_rank()
device = torch.device(f'cuda:{rank}')

class ColParallel(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        assert out_dim % world_size == 0
        self.weight = nn.Parameter(torch.randn(out_dim // world_size, in_dim, device=device))
    def forward(self, x):
        return x @ self.weight.T

class RowParallel(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        assert in_dim % world_size == 0
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim // world_size, device=device))
    def forward(self, x):
        y = x @ self.weight.T
        dist.all_reduce(y)
        return y

# 2 卡上测试
m1 = ColParallel(128, 256)
m2 = RowParallel(256, 128)
x = torch.randn(4, 128, device=device)
y = m2(m1(x))
print(f"Rank {rank}: y shape = {y.shape}, mean = {y.mean().item():.4f}")
dist.destroy_process_group()

跑这段代码,会在 2 卡上看到一致的 y mean——证明 ColumnParallel + RowParallel 的通信正确组合得到了与单卡等价的结果。


15.11·补 V4 并行策略的”代码极简” 哲学

把 V4 的并行实现与 Megatron-LM / DeepSpeed 等成熟训练框架对比,最显著的差异是代码量——V4 的 4 个并行类合计约 50 行核心代码,Megatron-LM 同等功能要数千行。这种极简来自几个设计选择:

选择 1:把训练复杂度外置

Megatron-LM 是训练框架,必须处理优化器状态分片、梯度累积、Pipeline Parallel、混合精度训练等复杂场景。V4 的 inference/model.py 是推理代码——所有训练复杂度被外置到不公开的训练栈,只保留推理需要的最小并行抽象。

选择 2:放弃通用 API

Megatron-LM 的 ColumnParallelLinear 支持丰富配置——gather_output / async_tensor_model_parallel_allreduce / sequence_parallel 等十几个开关。V4 的版本只有一种行为——按 out_features 切分、forward 不通信。这种”放弃通用性”让代码极简、性能最优。

选择 3:全局变量代替依赖注入

world_size / rank 是文件级全局变量——不是每个 module 持有自己的 ProcessGroup 引用。这违反了”显式依赖” 的传统软件工程纪律,但避免了在 200 处方法签名里都加 group 参数。

选择 4:MQA-style 让 KV 不切

V4 的 num_key_value_heads=1 让 KV 在 TP 多卡间不切——每 rank 持有完整 KV cache。这绕过了”KV 切分时的复杂索引数学”——代码大幅简化。代价是显存冗余 N 倍,但 V4 的 KV cache 已经被 Compressor 压到很小(每序列 ~8 GB),8 卡冗余只占总显存的小份额。

这些选择让 V4 的并行代码”看一眼就懂”——这种可读性对开源项目极重要。读者不需要花一周读代码才能理解 V4 怎么工作,几小时就够了。


15.11·补·补 部署 V4 时的并行策略选择

把 V4 部署到生产时,TP / EP / DP 的配比是关键工程决策。给一个决策树:

决策点 1:序列长度

  • < 32K:可以用更激进的 TP(16+)让单序列延迟最低
  • 32K - 256K:8 卡 TP 是甜区——平衡延迟与吞吐
  • 256K - 1M:必须 8 卡 TP + 充足 KV 显存——可能需要把 KV cache 切到主机内存

决策点 2:并发量(batch size)

  • 低并发(batch < 4):TP 比 EP 重要——让单个 sequence 跑得最快
  • 中并发(batch 4-16):TP=8 + DP(多副本)——每个副本服务一部分用户
  • 高并发(batch > 16):TP=8 + EP=16/32(跨节点)—— 单一大模型实例服务大量并发

决策点 3:硬件拓扑

  • 单节点 8 卡:TP=8、不需要 EP(expert 全部在 NVLink 内)
  • 双节点 16 卡:TP=8 + EP=2(每节点 192 expert,跨节点 EP=2)
  • 多节点 32+ 卡:TP=8 + EP 跨节点 + DP 多副本

决策点 4:延迟 vs 吞吐 trade-off

  • 延迟优先:增加 TP(更多并行算 single token)
  • 吞吐优先:增加 batch + 保持 TP 适中

实际部署中,最常见的配置是 单节点 8 卡 TP=8——这是 V4 在 H100 上的最佳”性价比”配置。多节点部署仅在”高并发 + 延迟可接受”的场景下才有 ROI。


15.11·延展 V4 的并行抽象与 vLLM PagedAttention 的协同

V4 的并行类(ColumnParallel / RowParallel / ParallelEmbedding / ParallelHead)与 vLLM 的 PagedAttention 在工程层有微妙的协同关系。

协同点 1:KV cache 不切

V4 的 num_key_value_heads=1 让 KV 在所有 TP rank 上是完整副本。这与 vLLM PagedAttention 的”每 rank 持有全部 block 的本地副本”语义一致——每 rank 都能直接访问完整 KV,不需要跨 rank 通信。

协同点 2:Q 切分到 head

V4 的 wq_b 是 ColumnParallel——Q head 被切分到不同 rank。vLLM 的 PagedAttention 接收”本 rank 的 Q heads”,与”全副本的 KV” 做 attention 计算,输出 ColumnParallel 形式的 attention 输出。

协同点 3:O 投影做 reduce

V4 的 wo_b 是 RowParallel——attention 输出经过它后做 all_reduce 得到完整 hidden。vLLM 的 attention backend 在这一步可能直接调用 RowParallelLinear 的 forward——无缝衔接。

协同点 4:MoE 的 expert parallel

V4 的 384 expert 切到不同 rank。vLLM 适配时需要让 ModelRunner 知道每 rank 持有哪些 expert——这与 V4 的 experts_start_idx / experts_end_idx 一致。

这些协同点意味着:vLLM 的 V4 适配 PR 的”模型并行”部分非常薄——主要是把 V4 的并行类注册到 vLLM 的 ParallelState 系统。绝大多数代码可以从 V4 的 inference/model.py 直接复制——这是 V4 设计的工程红利。

具体的协同细节会在《vLLM 推理内核深度解析》第 14 章 “Tensor 并行” 后续更新中展开。


15.11·拓展 V4 并行实现的”边界条件”清单

V4 并行实现里有几个重要的边界条件——必须满足才能正确工作。把它们整理成一份”启动前检查清单”:

边界 1:out_features % world_size == 0

ColumnParallelLinear 要求输出维度能被 world_size 整除。V4 的 out_features 都是 128 倍数(如 128 head × 512 head_dim = 65536),可以整除 8 / 16 / 32 等常见 world_size。但如果你 fine-tune 时改了 head 数(比如减到 96),可能违反这个约束——必须挑能整除的 world_size。

边界 2:in_features % world_size == 0

RowParallelLinear 同理。V4 的 in_features 也都是 128 倍数,对常见 world_size 都满足。

边界 3:vocab_size % world_size == 0

ParallelEmbedding / ParallelHead 要求 vocab 能被 world_size 整除。V4 vocab=129280,可以整除 8(=16160 per rank)但不能整除 16(129280 / 16 = 8080)——意味着 16 卡 TP 部署需要扩展 vocab 或换其他切分方式。

边界 4:n_routed_experts % world_size == 0

MoE 要求 expert 数能被 world_size 整除。V4 的 384 expert 整除 8(48 per rank)、16(24 per rank)、32(12 per rank)——常见配置都满足。

边界 5:world_size == 1 时不通信

V4 的并行类有 if world_size > 1: 守卫——单卡模式下完全跳过通信。这让你可以单卡跑 V4 (small variant) 做调试 + 多卡跑 V4 Pro 做生产,同一份代码兼容两种模式。

边界 6:dist 必须先 init

world_size = dist.get_world_size() if dist.is_initialized() else 1——V4 的 Transformer.init 检查 dist 是否初始化。如果你忘了 dist.init_process_group(...),V4 会默默退化到 single-rank 模式,可能导致部署不正确。

把这 6 个边界做成 deployment 的 pre-check,可以在 V4 部署的第一天避免大部分配置错误。


15.12 延伸阅读

  • Megatron-LM 论文(arXiv:1909.08053):TP 的工业化先驱
  • DeepSpeed Ulysses(arXiv:2309.14509):sequence parallelism
  • Tensor Parallelism in Distributed Training(NVIDIA 文档):TP 的硬件视角
  • 本书第 16 章:DeepEP——MoE 的专用 all-to-all 通信库
  • 本书《vLLM 推理内核深度解析》第 14 章:vLLM 中的 TP 实现

15.12·补 V4 与 ZeRO 优化器分片的协同

V4 的并行类是”Tensor Parallel + Expert Parallel” 的组合。生产训练通常还会叠加 ZeRO 优化器分片——把优化器状态切到不同 rank。把这套叠加的工程细节梳理一下。

ZeRO-1:仅切分优化器状态(如 Muon 的 momentum buffer)。每 rank 持有所有 weight 但只持有部分优化器状态。

ZeRO-2:切分优化器状态 + 梯度。反向传播时梯度先在本 rank 计算,再 reduce_scatter 到对应 rank。

ZeRO-3:切分优化器状态 + 梯度 + weight。每 rank 只持有部分 weight—— forward 时需要 all_gather 把 weight 拼起来。

V4 与 ZeRO 的叠加规则:

  • TP 切分维度(如 head 维度)与 ZeRO 切分维度(如 optimizer state 的 layer 维度)必须正交——避免冲突
  • TP rank 之间共享 ZeRO rank ID—— 比如 8 卡 TP + 4 节点,每节点 8 卡是同一个 ZeRO rank
  • Expert Parallel 的 expert 不参与 ZeRO 切分——expert weight 已经分布到不同 rank,再切就乱了

V4 的训练大概率用 ZeRO-1(最保守)或 ZeRO-2——ZeRO-3 在 1.6T 模型上 weight all_gather 的通信开销过大,不划算。

这部分配置大多数公司不会自己实现——直接用 DeepSpeed 或 FSDP 的现成支持。但理解原理让你能 debug 配置错误。


15.12·补·补 V4 与 Sequence Parallel / Pipeline Parallel 的关系

V4 的并行类只覆盖 Tensor Parallel + Expert Parallel。但生产训练通常还有 Sequence Parallel 和 Pipeline Parallel——把它们与 V4 的关系说清楚。

Sequence Parallel (SP)

把 sequence 维度切到不同 rank。每 rank 处理 sequence 的一段——主要在 activation memory 上节省(不需要每 rank 存完整 sequence)。

V4 的并行类没有原生 SP 支持——inference/model.py 不切 sequence。但 V4 训练时大概率有 SP(不公开训练栈)——只是 inference 用不到。

为什么 inference 不用 SP

inference 的瓶颈是 KV cache + GEMM,不是 activation。SP 节省 activation 的好处在 inference 上没用——而 SP 的通信开销反而拖慢推理。所以 V4 inference 路径不带 SP。

Pipeline Parallel (PP)

把不同 layer 切到不同 rank——前半模型在 rank 0、后半在 rank 1。每 token 顺序经过 rank。

V4 的 inference/model.py 没有 PP——所有 layer 都在同一 rank(每 rank 持有全部 layer 的 TP 部分)。这是 inference 路径的设计选择。

为什么 inference 不用 PP

PP 引入”流水线泡沫”(pipeline bubble)——某些 rank 在等其他 rank。短上下文 + 大 batch 下泡沫小,PP 可行;但 V4 的目标是长上下文 + 适中 batch,泡沫会成为延迟瓶颈。

训练 vs 推理的并行差异

维度训练推理(V4 inference)
TP
EP
SP❌(不需要)
PP✅(大模型必须)❌(延迟敏感)
ZeRO❌(推理无优化器状态)
FSDP✅(部分)

理解这种差异让你正确选并行配置——训练时该用 PP 就用,推理时不要照搬训练配置。


15.13 本章小结

  • V4 用 4 个并行类:ColumnParallel / RowParallel / ParallelEmbedding / ParallelHead
  • ColumnParallel 切 out 维度,无 reduce;RowParallel 切 in 维度,需 all_reduce
  • ParallelEmbedding / ParallelHead 切 vocab 维度,分别用 mask+all_reduce / all_gather
  • Expert Parallel 用稀疏 ModuleList + None 占位 + per-rank 循环的”代码简洁”实现
  • Attention 的 ColumnParallel + RowParallel 经典组合让通信量降到 O(B × S × D)
  • 与 Megatron-LM 同源,但 V4 推理代码极简(~50 行)——训练复杂度解耦到不公开训练栈

第 16 章:DeepEP——V4 给 MoE all-to-all 量身定制的通信库。

评论 0