第18章 FSDP:ZeRO 风格的参数分片

“FSDP changes the question from ‘how do I parallelize this model’ to ‘how much GPU memory can I trade for communication’.”

—— PyTorch FSDP design doc

本章要点

  • DDP 在 70B 模型时崩溃:每张卡需要 280GB(params) + 280GB(grad) + 560GB(Adam state)= 1120GB,单 H100 才 80GB
  • FSDP 灵感来自 DeepSpeed ZeRO:把 params / grads / optimizer state 都切到 N 张卡,每张卡只持有 1/N
  • 核心机制:forward 时 AllGather 把分片重建成完整 param、用完后立即 reshard;backward 时同理 + 最后 ReduceScatter 同步梯度
  • 5 种 ShardingStrategyNO_SHARD(=DDP)、SHARD_GRAD_OP(ZeRO-2)、FULL_SHARD(ZeRO-3)、HYBRID_SHARD_HYBRID_SHARD_ZERO2
  • FSDP-1 vs FSDP-2:v2.4+ 推出新接口 torch.distributed.fsdp.fully_shard,模块化 + 与 torch.compile 兼容性好得多
  • prefetch 是性能命脉:在 layer N forward 时 prefetch layer N+1 的 unshard,隐藏通信延迟

18.1 为什么 DDP 不够用

第 17 章 DDP 假设”每张卡放得下完整模型”。70B 模型:

  • params:70B × 4B(fp32) = 280 GB
  • grads:280 GB
  • Adam optimizer state (m + v):560 GB
  • 加上 activations 几十 GB
  • 合计 1120+ GB

H100 80GB / GB200 192GB —— 单卡都装不下。这是 DDP 在大模型时代的天花板。

FSDP(Fully Sharded Data Parallel)的解法:把这堆显存按 rank 数 N 均匀切分到 N 张卡。每张卡只持有 1/N 份数据,需要时通过通信临时凑齐完整张量。

8 卡 FSDP:
  每张卡 params: 35 GB
  每张卡 grads:  35 GB
  每张卡 optimizer: 70 GB
  合计 ~140 GB / 卡  ← 80GB 仍然装不下
16 卡:  ~70 GB / 卡  ← 接近但还要 activation
32 卡:  ~35 GB / 卡  ← 舒服了

70B 模型至少 32 张卡才能跑稳。FSDP 的”分片粒度”是这套机制的核心。

18.2 5 种 ShardingStrategy

graph TB
    DDP["DDP / NO_SHARD<br/>每卡完整 params + grads + optimizer"]
    Z2["ZeRO-2 / SHARD_GRAD_OP<br/>params 完整,但 grads + optimizer 切片"]
    Z3["ZeRO-3 / FULL_SHARD<br/>params + grads + optimizer 全切片<br/>显存最省"]
    HZ["HYBRID_SHARD<br/>同节点内 ZeRO-3,跨节点 DDP<br/>跨节点带宽不足时的折中"]

    DDP -.显存大.-> Z2
    Z2 -.通信增多.-> Z3
    Z3 -.通信占比可能太高.-> HZ

    style Z3 fill:#dcfce7,stroke:#22c55e,stroke-width:2px
    style HZ fill:#fef3c7,stroke:#f59e0b

通信代价:

Strategyparams 通信grads 通信optimizer 通信显存 / 卡
NO_SHARD (DDP)0AllReduce0完整
SHARD_GRAD_OP0ReduceScatter01/N grads + opt
FULL_SHARD2× AllGather (fw + bw)ReduceScatter0(params 已分片)全部 1/N
HYBRID_SHARD节点内 AllGather节点内 RS + 跨节点 AR-1/N (节点内)

FULL_SHARD 显存最省、通信最多。生产代码里 70B+ 通常用 FULL_SHARD 或 HYBRID_SHARD。

18.3 unshard / reshard 的精确时机

FSDP 把参数按 N 切片后,每张卡平时只持有 1/N。但算子需要完整 param 才能算(Linear 需要完整 weight 矩阵做 GEMM)。所以 forward 前要 AllGather 把分片凑成完整 param,算完立即 reshard(释放完整副本,回到 1/N):

sequenceDiagram
    autonumber
    participant Layer as Layer N forward
    participant Comm as AllGather
    participant GPU as GPU 显存

    Layer->>Comm: 发起 AllGather (异步)
    Note over GPU: 此时 GPU 上 layer N 的 param 是分片
    Comm->>GPU: AllGather 完成 → 完整 param 在显存
    Layer->>GPU: 跑 layer N 的 GEMM
    GPU->>GPU: reshard: 立即释放完整 param
    Note over GPU: 回到 1/N 状态, 等下次 forward

backward 同理:算每层 grad 前再 AllGather 一次 param,算完 reshard,最后用 ReduceScatter 把这层的 grad 切片同步:

Layer N backward:
  1. AllGather param (因为反向也要用 weight 算 grad_input)
  2. 计算 grad_input + grad_param
  3. ReduceScatter grad_param: 8 卡各自只保留自己的 1/N 切片
  4. 释放完整 param

ReduceScatter 是 AllReduce 的”半成品” —— 每个 rank 只拿到 reduce 结果的 1/N 切片,不再 AllGather。这刚好对应 FSDP 的需求:每个 rank 只更新自己持有的 1/N 参数。

18.4 prefetch:隐藏通信延迟

朴素实现里:

fw layer 1 = AllGather + GEMM + reshard (串行)
fw layer 2 = AllGather + GEMM + reshard
...

每层都等 AllGather 完成才开始算 —— 通信延迟全暴露。prefetch 优化:

fw layer 1: AllGather_1 → GEMM_1 (期间 prefetch AllGather_2)
fw layer 2: GEMM_2 (用已经 prefetch 完的 param) → prefetch AllGather_3
...

每层 GEMM 进行时,下一层的 AllGather 在另一个 stream 上跑。理想情况下通信完全 overlap 到计算里,FSDP 性能接近 DDP(如果带宽足够)。

实现上 FSDP 用 dedicated CUDA stream(_unshard_stream_pre_unshard_stream_runtime_utils.py:263-269)跑 collective,与训练主 stream 并发。

prefetch 深度由 forward_prefetch=True/False 与 backward 的 backward_prefetch=BACKWARD_PRE/POST 控制。BACKWARD_PRE 在前一层 backward 开始前 prefetch;BACKWARD_POST 在前一层 backward 完成后 prefetch(更安全但 overlap 少)。

18.5 FSDP-1 vs FSDP-2

PyTorch v2.4+ 推出了新一代 FSDPtorch.distributed.fsdp.fully_shard(也叫 FSDP-2,源码在 torch/distributed/fsdp/_fully_shard/)。它解决了 FSDP-1 的几个痛点:

1. 模块化:FSDP-1 把整个模型 wrap 成一个 FullyShardedDataParallel(model),所有逻辑放一个大类(2167 行)。FSDP-2 用 fully_shard(submodule) API 给每个 submodule 单独 wrap,更精细控制:

# FSDP-2
from torch.distributed.fsdp import fully_shard

for layer in model.layers:
    fully_shard(layer)            # 每个 transformer 层单独 shard
fully_shard(model)                # root module

2. 与 torch.compile 兼容:FSDP-1 有不少 nn.Module.__setattr__ 黑魔法,让 Dynamo trace 时 graph break 严重。FSDP-2 重新设计了参数管理,能完整被 Inductor 编译。

3. DTensor 后端:FSDP-2 内部用 DTensor(distributed tensor)抽象,让张量分片成为 first-class concept。

生产代码里 v2.4+ 推荐用 FSDP-2。FSDP-1 仍然支持但被标记为”老 API”,新功能(如 fully_shardmesh 接口)只在 FSDP-2 上加。

18.6 mixed precision:FSDP 的另一大优化

FSDP 提供专门的 MixedPrecision 配置:

from torch.distributed.fsdp import MixedPrecision

mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,         # AllGather 时用 bf16 通信
    reduce_dtype=torch.float32,         # ReduceScatter 时用 fp32 (避免梯度精度损失)
    buffer_dtype=torch.bfloat16,
)

model = FSDP(model, mixed_precision=mp_policy)

为什么这个比朴素 model.bfloat16() 好?因为 FSDP 控制每种张量的 dtype 独立:

  • params 在 bf16(省一半 AllGather 带宽)
  • gradients 在 fp32 reduce(保证数值稳定)
  • master weights 仍是 fp32(让 Adam 更新精确)

这套”分别控制”是大模型训练 fp32+bf16 混合精度的标配。第 20 章量化与混合精度会展开。

18.6.5 DeviceMesh:拓扑的 first-class 表达

FSDP-2 之前,分布式训练靠 ProcessGroup 表达”哪些 rank 参与同一通信”。混合并行(DP + TP + PP)时要手工管理多个 group,容易出错。

torch.distributed.device_mesh(1553 行)引入 DeviceMesh 抽象:把 N 个 rank 排成 K 维网格,每维有自己的 ProcessGroup。

from torch.distributed.device_mesh import init_device_mesh

# 32 卡 = 4 节点 × 8 卡
# 节点内 8 卡 ZeRO-3, 跨节点 4 个 replica DDP
mesh_2d = init_device_mesh(
    device_type="cuda",
    mesh_shape=(4, 8),
    mesh_dim_names=("replica", "shard"),
)

# 取出某一维的 ProcessGroup
shard_pg = mesh_2d["shard"].get_group()    # 节点内 8 卡 group
replica_pg = mesh_2d["replica"].get_group() # 跨节点 4 group

每个 rank 在 mesh 里有 K 个坐标(如 (replica=2, shard=5))。在某一维上的 collective(如 shard.all_gather)只在那一维的 group 内做,不涉及其他维。

DeviceMesh 让 HSDP / 3D parallel(DP + TP + PP)的实现从”手动管 N 个 group”变成”声明一个 mesh 然后取维度”。这是 FSDP-2 / DTensor 的基础。

18.6.6 DTensor:分片张量的 first-class 类型

FSDP-2 内部不直接持有 plain Tensor,而是 DTensor(distributed tensor)。一个 DTensor 由三部分组成:

from torch.distributed._tensor import DTensor, Shard, Replicate

dtensor = DTensor.from_local(
    local_tensor,         # 本 rank 持有的切片
    device_mesh,          # 它在哪个 mesh 上
    placements=[Shard(0)] # 在 mesh 维度上怎么分布: Shard / Replicate / Partial
)

三种 Placement:

  • Shard(dim):按某 dim 切到 mesh 上,每 rank 持有 1/N
  • Replicate():每 rank 都有完整副本
  • Partial:每 rank 持有”部分和”,下次访问前需 reduce

DTensor 算子(dtensor + dtensor)会自动选最优通信策略:如果两边都是 Shard(0),相加无需通信;如果一个 Shard(0) 一个 Replicate,自动 AllGather + 加。这套自动 dispatch 让用户不用手写 collective。

FSDP-2 把每个 nn.Parameter 包成 DTensor(local_shard, mesh, [Shard(0)]),forward 时调用 to_replicate() 触发 AllGather 凑成完整 param、用完转回 ShardFSDP-2 的 unshard / reshard 逻辑就是 DTensor 的 placement 转换,比 FSDP-1 的手写更通用、与 torch.compile 兼容好。

18.6.7 HSDP 的具体拓扑

HYBRID_SHARD 的关键是用 2D mesh:节点内 shard、跨节点 replicate:

mesh = init_device_mesh("cuda", (2, 8), mesh_dim_names=("inter", "intra"))

# 配 FSDP-2: 节点内 (intra) shard, 跨节点 (inter) replicate
fully_shard(model, mesh=mesh)

通信模式:

  • forward 的 AllGather 只在 intra 维(节点内),走 NVLink(带宽 ~600 GB/s)
  • backward 的 ReduceScatter 也在 intra 维
  • backward 完成后额外做一次 AllReduce 在 inter 维(跨节点)同步梯度

这条策略让”高带宽节点内通信”承担分片代价、“低带宽跨节点”只做一次梯度同步。32 卡 H100 集群上 HSDP 比 FULL_SHARD 通常快 20-40%,前提是节点内带宽远大于跨节点。

18.6.8 _unshard 的 stream 编排

打开 _runtime_utils.py:277_unshard

def _unshard(state, handle, unshard_stream, pre_unshard_stream):
    with state._device_handle.stream(pre_unshard_stream):
        # 1. 在 pre-unshard stream 上准备 buffer (alloc 完整 param 大小的 tensor)
        pad_for_unshard(handle)

    with state._device_handle.stream(unshard_stream):
        # 2. 在 unshard stream 上发起 AllGather
        # event 让 unshard stream 等 pre-unshard 完成
        unshard_stream.wait_stream(pre_unshard_stream)
        all_gather_into_tensor(...)

两个 stream 的分工:

  • pre_unshard_stream:跑 alloc / pad 等”准备工作”
  • unshard_stream:跑 AllGather 本身

为什么要两个 stream?因为如果 alloc 与 AllGather 都在主 stream 跑,会阻塞下一层 forward。两个独立 stream 让它们与主 stream 完全 overlap。

wait_stream 是 CUDA 的 stream-to-stream 同步原语:让 unshard stream 等 pre-unshard stream 上之前发的 op 完成。这种”多 stream 协同”是 FSDP 性能的关键 —— 配合 prefetch(§18.4),让通信完全隐藏到计算里。

18.6.9 activation_checkpoint × FSDP 的协作

第 7 章 §7.5.3 讲过 activation_checkpoint:前向不保存中间激活,反向时重新 forward 一遍取回。这套机制配 FSDP 用时多一层复杂度。

考虑 FSDP+checkpoint 的反向流程:

  1. 正常 forward(FSDP layer N):
    • AllGather 凑完整 param → forward 计算(不存 activation)→ reshard 释放完整 param
  2. 反向走到 layer N
    • 需要重新 forward 取 activation → 但 param 已经被 reshard
    • 必须 再次 AllGather 凑完整 param → 重 forward → 拿到 activation → backward
    • 完成后又 reshard
sequenceDiagram
    autonumber
    participant FW as Forward layer N
    participant Mem as 显存
    participant BW as Backward layer N

    FW->>Mem: AllGather param (1)
    FW->>FW: 跑 forward, 不存 activation
    FW->>Mem: reshard param (释放完整副本)

    Note over BW: 反向到 layer N

    BW->>Mem: AllGather param (2) ← 重 forward 还要再 unshard 一次!
    BW->>BW: 重 forward 拿到 activation
    BW->>BW: backward 算梯度
    BW->>Mem: reshard param
    BW->>Mem: ReduceScatter grad

每个 checkpointed layer 在反向触发 2 次 AllGather(一次为 backward 本身、一次为重 forward)。这是 FSDP+checkpoint 比纯 FSDP 通信量上升的根本。

torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py:112CheckpointWrapper 是 PyTorch 给 FSDP 的官方 checkpoint 接口。它的关键设计:

  • apply_activation_checkpointing(model, ...) (:239) 递归给每个匹配 check_fn 的 module 套上 CheckpointWrapper
  • use_reentrant=False 是新版默认(v2.0+):用 saved_tensors_hooks(第 7 章 §7.5.3)实现,与 FSDP / torch.compile 兼容性好
  • 老的 use_reentrant=Truetorch.autograd.Function(第 7 章 §7.8),与 FSDP 反向 hook 互动有 corner case,新代码避开

实战决策:开 FSDP 后是否再加 activation_checkpoint?看显存。70B 训练即使用 FULL_SHARD,activation 显存(forward 中间激活)仍然几十 GB。activation_checkpoint 能再省 70%+ activation 显存,代价是反向多 2x AllGather。如果跨节点带宽足够,这笔账划算 —— 大模型训练几乎都开。

apply_activation_checkpointingcheck_fn 通常设成”是 transformer block 就 wrap”:

from functools import partial
apply_activation_checkpointing(
    model,
    check_fn=lambda m: isinstance(m, TransformerBlock),
)

这种”按 block 粒度 checkpoint”是 LLM 训练的标准做法。block 内部不 checkpoint —— 否则 attention / mlp 各自重 forward 通信量爆炸。

18.6.10 FlatParameter (FSDP-1):把多个 param 平铺成一个大张量

FSDP-1 内部不直接管理用户的 nn.Parameter,而是把同一 wrap unit 内的多个 param 平铺合并成一个 FlatParameter

# 用户 module 有 3 个 param
linear.weight  shape [768, 768]    fp32
linear.bias    shape [768]         fp32
norm.weight    shape [768]         fp32

# FSDP-1 合并 (flatten + concat)
flat_param = FlatParameter(torch.cat([
    linear.weight.flatten(),     # 589824 元素
    linear.bias.flatten(),       # 768
    norm.weight.flatten(),       # 768
]))                              # 共 591360 元素

_flat_param.py:202FlatParameternn.Parameter 子类,元数据里记着每个原始 param 的 offset / shape / dtype。访问 linear.weight 时通过 view 从 FlatParameter 取出对应区段。

为什么要 flatten?

  • AllGather 能一次拉所有 param(vs 分别拉每个 param 的 N 次 collective)
  • shard 时只切一次(按 1/N 切 FlatParameter,不是切几十个 param)
  • memory 连续让 GPU memory bandwidth 利用率高

代价是 view 重建复杂、与 torch.compile 兼容性差。FSDP-2 抛弃 FlatParameter,每个 param 用 DTensor 直接 shard —— 与 compile 兼容性大幅提升,但 collective 数量变多(靠 group AllGather 等优化补偿)。

理解 FlatParameter 让你看到 FSDP-1 的 ckpt 文件里那些”奇怪 key”(如 _fsdp_wrapped_module._flat_param_0)时不困惑 —— 那是 flat 后的合并张量。

18.6.11 auto_wrap_policy:决定 sharding 粒度

fully_shard(model, auto_wrap_policy=...) 让 FSDP 自动决定哪些子 module 各自成为一个 wrap unit。粒度选择:

  • 太粗(整个 model 一个 unit):unshard 一次拉全部参数,显存峰值与不分片相同 —— FSDP 退化成 DDP
  • 太细(每个 Linear 一个 unit):每层 unshard 一次 collective,通信开销爆炸
  • 合适粒度:每个 transformer block 一个 unit,平衡 collective 数量与显存峰值

wrap.py:178ModuleWrapPolicy 是常用 policy:按 module 类型选择 wrap unit。

from torch.distributed.fsdp.wrap import ModuleWrapPolicy

policy = ModuleWrapPolicy({TransformerBlock})    # 每个 TransformerBlock 一个 unit
fully_shard(model, auto_wrap_policy=policy)

其他 policy:

  • size_based_auto_wrap_policy:按参数量阈值(如 >100M 一个 unit)
  • transformer_auto_wrap_policy:transformer 友好的版本
  • 自定义 callable policy

LLM 训练几乎都是 ModuleWrapPolicy({TransformerBlock}) —— 与 activation_checkpoint 的 wrap 粒度对齐,让两者协作最优。

18.6.12 lazy init 与 meta device:训练前不分配显存

70B 模型直接 model = LlamaModel(config) 会立即在 GPU 上分配 280GB params —— 单卡装不下、初始化崩。

torch.device("meta") 是个”假 device”,张量只有 shape / dtype 元信息、没有实际数据。FSDP-2 配 meta device 流程:

with torch.device("meta"):
    model = LlamaModel(config)    # 不分配显存

fully_shard(model, mesh=mesh, ...)   # FSDP 用 meta module 构造 sharded param

model.to_empty(device='cuda')         # 把 meta param 替换成真实显存 (1/N 大小)
init_model_weights(model)             # 用户实现的 init 函数 (按 rank 自己 init 那 1/N)

整个流程全程没有”完整 280GB”在任何 rank 显存里出现。FSDP-2 的 lazy init 让 70B+ 模型能在 80GB 单卡上启动训练 —— 之前是不可能的。

to_empty() 是个特殊 API:把 meta param 替换成同 shape 的真实张量但不初始化。后续用户调 init_model_weights 在每个 rank 上各自 init 自己持有的那 1/N 数据。这种”分片 init”避免了”先 init 完整 model 再切”的中间显存峰值。

18.6.13 CPU Offload:参数 / optimizer state 卸到 CPU

显存极紧张时(如单卡训练超大模型),FSDP 提供 cpu_offload

from torch.distributed.fsdp import CPUOffload

fully_shard(model, cpu_offload=CPUOffload(offload_params=True))

机制:

  • params 平时存在 CPU RAM(host memory)
  • forward 时 CPU → GPU 拷贝 + AllGather + GPU 跑 forward + reshard 回 GPU + GPU → CPU 拷贝释放 GPU 副本
  • backward 同理

代价是每 forward 多两次 H2D / D2H 拷贝(PCIe 受限,几十毫秒)。整个训练吞吐通常降 50-70%,但能让”放不下的模型放下”。

更激进的 optim_state_offload(FSDP-2):optimizer state(exp_avg / exp_avg_sq)也卸到 CPU。step 时 CPU → GPU 取 + 计算 + 写回。再省一份显存(约参数量 × 2 字节,70B 大约 560GB)。

实战:能用多卡分摊就用多卡(HSDP),cpu_offload 是”硬件不够时的最后兜底”。生产 70B 训练几乎从不用 cpu_offload —— 性能损失太大、不如多租几张卡。

18.6.14 state_dict_type:完整 vs 分片

FSDP 提供两种 state_dict 视图:

  • StateDictType.FULL_STATE_DICTapi.py:293):rank 0 收集所有 rank 的分片,组装成完整 state_dict(与未 shard 的 model 视图等价)
  • StateDictType.SHARDED_STATE_DICTapi.py:340):每 rank 输出自己持有的 1/N 切片
from torch.distributed.fsdp import StateDictType

# 完整 state_dict (rank 0 持有完整, 其他 rank 空)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
    sd = model.state_dict()
    if rank == 0:
        torch.save(sd, "ckpt.pt")

完整 state_dict 简单但 rank 0 要装下完整模型(70B = 280GB),实战不可行。

# 分片 state_dict (每 rank 写自己那 1/N)
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    sd = model.state_dict()
    dcp.save(sd, checkpoint_id="ckpt_dir")

分片 state_dict 与第 19 章 §19.6.7 的 DCP 配合,每 rank 并行写自己那份 —— 这是 70B+ 训练 ckpt 的标配。SHARDED_STATE_DICT 写出的不是 ddp model 的视图,而是”DTensor 字典”,每个 entry 是 DTensor(local_shard, mesh, placement)。DCP 知道怎么序列化 DTensor + 加载时按当前 mesh 重建。

18.6.15 summon_full_params:临时凑齐完整参数

某些操作(如 model surgery、debug print 完整 weight)需要完整参数。FSDP.summon_full_params context manager 让 FSDP 临时 unshard 整个 group:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

with FSDP.summon_full_params(model, recurse=True):
    # 此时所有 param 都是完整的 (每 rank 都有完整副本)
    print(model.linear.weight.shape)    # [hidden, hidden] 完整尺寸

实现机制:进入 with 块时触发对所有 wrap unit 的 AllGather → 完整副本临时驻留显存 → 退出时立即 reshard。显存峰值临时翻 N 倍,所以 70B 模型 8 卡 unshard 后单卡瞬间需要 280GB —— 不可行。生产 summon_full_params 用于小 unit / 测试。

writeback=True 让退出时把完整 param 写回到分片(用于 model surgery 后保存改动)。offload_to_cpu=True 让完整副本驻留 CPU 而非 GPU(省 GPU 但慢)。

18.6.16 FSDP × clip_grad_norm

第 17 章 §17.8.31 提过 DDP 下 clip_grad_norm_ 直接用就行(grad 已经全局平均)。FSDP 不行:每 rank 持有 1/N 的 grad,本地 norm 不是全局 norm。

FSDP-2 提供专用接口:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

FSDP.clip_grad_norm_(model, max_norm=1.0)

内部实现:

  1. 每 rank 算 local grad norm²(即 sum of squares)
  2. AllReduce sum 得全局 norm²
  3. 算 sqrt 得全局 norm
  4. 每 rank 用全局 norm 缩放自己的 grad 切片

只多一次 AllReduce 标量(极小开销),数学上与 single-rank clip 等价。FSDP-1 / FSDP-2 都提供这个 API,名字略有差异(FSDP-2 用顶层 torch.nn.utils.clip_grad_norm_ 自动识别 DTensor 走分布式路径)。

18.6.17 FSDP × torch.compile:v2.4+ 的兼容路径

FSDP-1 与 torch.compile 兼容性差(FlatParameter view + 各种 hook 让 Dynamo 频繁 graph break)。FSDP-2 重新设计后与 compile 几乎完美兼容:

fully_shard(model, mesh=mesh)
compiled_model = torch.compile(model)

实现关键:

  • DTensor 算子在 dispatcher 层有完整注册(第 5 章),Dynamo 能 trace
  • AllGather / ReduceScatter 算子用 functional collectives(第 16 章 §16.7.9),与 functionalize 兼容
  • FSDP-2 的 hook 用 register_post_accumulate_grad_hook,与 compiled autograd 兼容

实测 70B FSDP-2 + torch.compile 比 FSDP-1 + compile 快 1.5-2x —— FSDP-1 的兼容代价是巨大的。所以新代码强烈用 FSDP-2,老代码尽快迁移。

18.6.18 BACKWARD_PRE vs BACKWARD_POST prefetch

§18.4 提过 prefetch 有两种模式:

  • BackwardPrefetch.BACKWARD_PRE(推荐):layer N+1 backward 开始前 prefetch layer N。最大化 overlap、安全性高
  • BackwardPrefetch.BACKWARD_POST:layer N+1 backward 结束后才 prefetch layer N。overlap 少但显存峰值更低

BACKWARD_PRE 是默认,几乎所有场景最优。BACKWARD_POST 仅在显存极紧张时考虑(少 overlap 但少一份 unsharded param 同时驻留)。

forward_prefetch=True 是另一档:forward 时也做 prefetch(默认关)。开启后 forward 速度提升 5-15%,代价是显存峰值轻微上升。生产 LLM 训练通常开启。

18.6.19 ZeRO-2 (SHARD_GRAD_OP) 的具体节省机制

§18.2 表里讲 ZeRO-2 是”params 完整、grads + optimizer 切片”。具体怎么实现?

  • forward:因为 params 完整(每 rank 都有),forward 不需要 AllGather params —— 直接跑 → 显存与 DDP 相同
  • backward:每 rank 算自己 batch 的本地 grad → ReduceScatter(不是 AllReduce!)让每 rank 拿到全局 grad 的 1/N 切片
  • optimizer step:每 rank 只更新自己持有的 1/N 参数(用 1/N grad + 1/N optimizer state)
  • step 后 AllGather 更新后的参数:让所有 rank 重新拿到完整 params,准备下一次 forward

这套流程让 grads 与 optimizer state 各只占 1/N(与 ZeRO-3 同),但 params 完整(vs ZeRO-3 的 1/N)。ZeRO-2 显存比 DDP 省 ~67%(grads + optimizer state 各 1/N,2 / 3 显存),但保留完整 params 的灵活性。

适用场景:模型刚好能装下 params 但 optimizer state(Adam 是 params × 2)装不下。比如 13B 模型 fp32 params 52GB(8 卡能装下)但 Adam state 104GB(装不下)—— ZeRO-2 完美。

18.6.20 NO_SHARD:FSDP 退化成 DDP

ShardingStrategy.NO_SHARD 让 FSDP 不分片任何东西,等价于 DDP

fully_shard(model, sharding_strategy=ShardingStrategy.NO_SHARD)
# 等价于
DDP(model)

这个看似无意义的选项有真实工程价值:

  • 统一接口:训练框架可以”用 FSDP 一套接口表达 DDP / ZeRO-2 / ZeRO-3 / HSDP”,配置切换简单
  • fallback 路径:HSDP 的跨节点维度本质是 NO_SHARD(跨节点不切,节点内才切)
  • debugging:怀疑 FSDP 引入的 bug 时切到 NO_SHARD 看是否消失

NO_SHARD 让 FSDP 的接口涵盖 DDP 全部功能,FSDP-2 在 v3.0+ 可能成为统一的多卡训练 API。

18.6.21 use_orig_params=True:兼容老 optimizer

FSDP-1 用 FlatParameter(§18.6.10)会让用户原本的 nn.Parameter 引用失效 —— optimizer 创建时拿到的是 user param,FSDP wrap 后这些 param 引用指向已经”被合并”的位置。

use_orig_params=True 让 FSDP-1 保留原始 param 引用:用户构造 optimizer 时传 model.parameters(),FSDP 内部建立 orig_param ↔ FlatParameter 切片的映射。optimizer step 时按 orig_param 更新,FSDP 自动写回 FlatParameter。

这条选项让”现有 optimizer 代码不改”就能上 FSDP。代价是内部多一层映射开销(小)。生产代码强烈建议开启。FSDP-2 默认行为就这样(每个 param 独立 DTensor,根本不需要这个选项)。

18.6.22 FSDP × 3D parallel (DP + TP + PP)

70B+ 训练经常需要 3D parallel

  • Data Parallel (DP):FSDP 在 DP 维度上分片
  • Tensor Parallel (TP):把单层 weight 切到多卡(如 attention QKV 切 4 路)
  • Pipeline Parallel (PP):把多层切到多卡,流水线执行

DeviceMesh(§18.6.5)让 3D 配置变简单:

mesh = init_device_mesh("cuda", (PP, DP, TP), mesh_dim_names=("pp", "dp", "tp"))

# FSDP 在 dp 维度上 shard
fully_shard(model, mesh=mesh["dp"])

# TP 在 tp 维度上 shard
parallelize_module(model, mesh["tp"], TPParallelStyle())

# PP 用 mesh["pp"] 配 PipelineStage

3D parallel 让 1024 卡训练 405B 模型(如 Llama-3 405B)成为可能。每个维度的 group 用 functional collectives(第 16 章 §16.7.9)通信,互不干扰。

实战参数选择:8 卡节点内 TP=8(NVLink 高带宽)+ 4 节点 DP(FSDP)+ 4 PP(跨更多节点)= 128 卡训练。具体取决于模型与硬件拓扑。

18.6.23 FSDP × LoRA:高效微调

LoRA(Low-Rank Adaptation)只训练注入的低秩矩阵 A、B,冻结原始 weight。FSDP 与 LoRA 的协作:

# 冻结 base model 参数
for p in base_model.parameters():
    p.requires_grad = False

# 注入 LoRA adapter (有 grad)
inject_lora(base_model, rank=8)

# FSDP wrap (只 shard 有 grad 的部分?)
fully_shard(base_model, ...)

关键问题:FSDP 默认 shard 所有 params,包括 frozen 的 base weights。但 frozen weights 不需要 grad / optimizer state,shard 浪费空间吗?

实际不浪费 —— frozen weights 只是 inference 时需要完整 unshard、不进 backward / optimizer,显存账上反而占更少(没有 grad / optimizer state copy)。FSDP 在内部识别 requires_grad=False 的 param、跳过它们的 grad 处理。

LoRA + FSDP 是 fine-tuning 70B 模型的标配 —— base model 冻结后只训几十兆 LoRA 参数,FSDP 让 base model 分片到多卡能装下。

18.6.24 FSDP × gradient accumulation

显存不够调 batch size 时常用 gradient accumulation:跑 N 个 micro batch 累积梯度、再一次 step。

FSDP 下的微妙点:每 micro batch 默认会触发 ReduceScatter 同步梯度。N 次 micro batch = N 次 collective —— 浪费。

fully_shard(..., reshard_after_forward=False) + with model.no_sync(): 上下文:

with model.no_sync():
    for i in range(accumulation_steps - 1):
        loss = model(batch[i])
        loss.backward()        # 不触发 ReduceScatter, 只本地累积

# 最后一个 micro batch 正常触发 ReduceScatter
loss = model(batch[-1])
loss.backward()
optimizer.step()

no_sync 让 N-1 次 micro batch 跳过 collective,只在最后一次同步累积总梯度。通信量从 N 次降到 1 次,gradient accumulation 几乎免费。

DDP 也有同名 no_sync API,思想一致。这是 PyTorch 多卡训练的标准 gradient accumulation 模式。

18.6.25 FSDP × 量化训练 / 推理

第 20 章讲过量化。FSDP 与量化协作有几个工程点:

训练时 fp8:FSDP MixedPrecision(param_dtype=torch.float8_e4m3fn) 让 AllGather 通信用 fp8,比 bf16 再省一半带宽。但 fp8 数值范围窄,需要 per-tensor scale,FSDP-2 与 TransformerEngine(NVIDIA 库)配合才能正确处理 scale。

推理时 INT8:FSDP shard 的 model 直接量化遇到障碍 —— 量化要看完整张量做 calibration。一般做法:训练用 FSDP / 量化前把 model 收到单卡 / 量化后再加载(不分片或者用 PT2E 静态图)。

LLM 推理几乎不用 FSDP(vLLM / SGLang 用 TP + KV cache 管理替代)。FSDP 主要是训练时的工具,推理时通常切换到其他并行策略。

18.6.26 FSDP root module 的特殊性

FSDP 中的”root module”(最外层 fully_shard 的 module)有几个特殊职责:

  • 管理整个 FSDP unit 树:root 持有所有子 unit 的引用,调度它们的 unshard / reshard 时机
  • 管理 stream:root 创建 _unshard_stream / _pre_unshard_stream 等,子 unit 共用
  • 触发 root_pre_forward:第一次 forward 时 root 做整体初始化(首次 AllGather、CUDA stream sync 等)
  • forward 完成后调 reshard:root 决定 root unit 自身的 reshard 时机(最后一个 reshard 在整 forward 完成后)

fully_shard(model)(不指定 mesh)默认让最外层 model 成为 root。如果用 fully_shard(layer) 给每层都 wrap,最外层依然是 root,每层是 sub-unit。root 的 lifecycle 决定整体训练流程,理解这点能解释为什么”在 root 之外的代码看到的 param 是分片的、root forward 内部看到完整的”。

18.6.27 FSDP × torch.export 与部署

FSDP 是训练工具,部署时通常不带 FSDP 直接 export。部署流程:

  1. 训练完成后用 FSDP.summon_full_params(model) 或者 state_dict_type=FULL_STATE_DICT 收集完整权重
  2. 在单卡(或推理用的并行配置)重建 model(不带 FSDP wrap)
  3. 加载完整权重
  4. torch.export(model, ...) + AOTI(第 15 章 §15.6.7)

这条流程让训练时的 FSDP 与部署时的 AOTI 完全解耦。FSDP 不存在于部署 binary 里 —— 它是纯粹的训练时工具。

实战:HuggingFace Transformers 的 FSDP 训练流程结尾通常调 unwrap_model(model) 取出原始 nn.Module、再保存 state_dict。这是与 FSDP 解耦的标准做法。

18.6.28 FSDP-2 与 DTensor 的层次关系

FSDP-2 内部把每个 nn.Parameter 替换成 DTensor(§18.6.6)。这意味着:

  • 用户视角:model.linear.weight 是 DTensor(local_shard 是 1/N)
  • forward 视角:DTensor 的算子自动触发必要的 placement 转换(如从 Shard(0) 转 Replicate 触发 AllGather)
  • backward 视角:DTensor 的反向规则自动产生对应的反向 collective

这套机制让 FSDP-2 不需要写很多手动的 unshard / reshard 代码 —— DTensor 自己处理 placement 转换。FSDP-2 的核心代码(_fully_shard 目录)只有几千行,远比 FSDP-1 的 fully_sharded_data_parallel.py + _runtime_utils.py 共 4000+ 行少。

DTensor 思想的胜利:把分片表示为类型而非协议。Tensor 的 placement 是 Tensor 类型的一部分,编译器 / 运行时都能利用这个信息做优化。这是 PyTorch 分布式训练演进的下一代方向。

18.6.29 完整训练 step 时间分解

70B Llama 训练单 step 时间(H100,32 卡 4 节点 HSDP,bf16):

| 阶段                            | 占比   | 时长   |
| forward (compute)              | 30%   | 1500ms |
|   ├─ AllGather params          | 15%   | (overlap 在 forward 计算里)
|   └─ forward 实际计算           | 15%   |
| backward (compute)             | 45%   | 2250ms |
|   ├─ AllGather params          | 12%   | (overlap)
|   ├─ backward 实际计算         | 25%   |
|   └─ ReduceScatter grads       | 8%    | (overlap)
| 跨节点 AllReduce (HSDP only)    | 10%   | 500ms  |
| optimizer.step                  | 5%    | 250ms  |
| 单 step 总时间                  | 100%  | ~5000ms|

对比 DDP 70B 单 step ~2000ms(§17.10.6),FSDP 慢 ~2.5x —— 代价是显存从需要 1120GB 降到每卡 35-40GB。这就是 FSDP 的工程哲学:用通信换显存

如果显存够(中等模型),用 DDP 性能更好;如果装不下,FSDP 是唯一选择。混合 HSDP 是中间值 —— 跨节点不分(DDP 风格)、节点内分(FSDP 风格),平衡通信与显存。

18.6.29.5 内存账:FSDP vs DDP 的精确对比

7B 模型 fp32 训练详细内存账(每卡):

| 项                    | DDP    | FSDP-2 ZeRO-3 (8 卡) |
| params                | 28 GB  |  3.5 GB              |
| grads                 | 28 GB  |  3.5 GB              |
| optimizer (Adam m+v)  | 56 GB  |  7 GB                |
| activations           |  8 GB  |  8 GB (相同)         |
| 临时 buffer            |  4 GB  | 28 GB (unsharded peak)|
| 单卡总               |124 GB  | 50 GB                |

注意 FSDP-2 的”临时 buffer”列:unsharded params 在 forward 那一瞬间需要完整 28GB(虽然只 1 ms 后就 reshard)。这就是为什么 FSDP-2 的实际显存峰值不是简单的 1/N。

70B 模型同样表(32 卡 ZeRO-3):

DDP:  params 280 + grads 280 + opt 560 + activ 几十 + buffer = 1100+ GB → 装不下
FSDP-2: params 8.75 + grads 8.75 + opt 17.5 + activ 几十 + buffer 280 (unsharded) = ~330 GB

FSDP-2 在 70B 上让”单卡装下成为可能”。但 buffer 的 unsharded peak 仍是 280GB —— 这告诉我们 unit 粒度不能太大(每 unit unshard 后的临时显存是关键约束)。

18.6.29.7 实测带宽 vs 计算的临界点

FSDP 性能取决于”通信能不能 overlap 进计算”。临界点:

  • 节点内 NVLink (~600 GB/s):fp32 5GB AllGather 约 8 ms,足够 overlap 进十几 layer 的 backward 计算
  • 节点间 IB (~400 GB/s 双向):5GB AllGather 约 12 ms,仍能 overlap
  • 节点间 100 Gbps ethernet (12 GB/s):5GB AllGather 约 400 ms,完全无法 overlap —— 这是为什么 ethernet 集群用 HSDP 而非纯 FSDP

判断你的硬件是否适合纯 FULL_SHARD:

  • 算每 backward step 的 GPU 计算时间(用 profiler 看)
  • 算每个 wrap unit 的 AllGather 通信时间(unit_param_size / 带宽)
  • 比较两者,AllGather > 计算时间就要切 HSDP 或减小 unit 粒度

实战:H100 + NVLink,70B 训练每 unit 约 2GB params、forward 50ms / unit。AllGather 时间 ~3ms 远小于 50ms compute → 完美 overlap。

18.6.30 FSDP × torch.compile 的协作细节

FSDP-2 + compile 的具体内部协作:

fully_shard(model, mesh=mesh)
compiled = torch.compile(model)

out = compiled(x)

发生的事:

  1. 第一次调用:Dynamo trace 看到 model(x),遇到 DTensor 张量
  2. DTensor 算子 dispatch:每个 op 在 dispatcher 层有 DTensor 的特殊实现,知道如何处理 placement
  3. AOTAutograd functionalize:DTensor 的 mutation(unshard / reshard)被识别 + functionalize
  4. min-cut partition:考虑 collective 的 cost(AllGather 比普通 op 贵)做 fusion 决策
  5. Inductor codegen:生成的 Triton kernel 直接调用 NCCL collective 函数,与 compute fuse

最终编译产物里 collective 与 compute 在同一个 kernel 链路,CPU 几乎不参与。这与 FSDP-1 + compile(频繁 graph break)相比有本质提升。

实测 70B FSDP-2 + compile 比纯 FSDP-2 快 1.3-1.5x。第 14 章 §14.9.5 的 transformer 加速比就建立在这条路径上。

18.6.31 FSDP × Pipeline Parallel

Pipeline Parallel (PP) 把模型按层切到多卡,与 FSDP 在不同维度上分。3D parallel 中两者协同:

mesh = init_device_mesh("cuda", (PP, DP), mesh_dim_names=("pp", "dp"))

# 每 pipeline stage 内用 FSDP
for stage_id in range(PP):
    if local_pp_rank == stage_id:
        layers = model.layers[stage_id*L:(stage_id+1)*L]
        fully_shard(layers, mesh=mesh["dp"])

# Pipeline schedule 处理跨 stage 通信
pipe = PipelineStage(layers, ...)

每个 PP stage 内部用 FSDP shard 自己那部分 layers。stage 之间用 P2P send/recv 传 activation。这种”stage 内 FSDP、stage 间 PP”是 405B 训练的标准配置。

实战调优:PP 的”bubble”(pipeline 启动 / 结束的空闲时间)与 FSDP 的 collective 时间互动复杂,需要 profiler(第 21 章)实测每个 stage 的 timeline 优化。

18.6.31.5 FSDP × evaluation:临时切到完整模型

训练循环里偶尔要做 eval(计算 val loss / metrics)。FSDP 模型默认是分片状态,eval 时需要完整 forward。两条路:

A. 保持 FSDP 状态做 eval(推荐):

model.eval()
with torch.no_grad():
    for batch in val_loader:
        out = model(batch)    # FSDP 仍 unshard / reshard, 与训练相同路径

eval 期间 FSDP 仍触发 AllGather + reshard,但因为 no_grad、不做 reshard backward、不需要 ReduceScatter grads。开销是单纯的 forward unshard。

B. summon_full_params 一次性 unshard

with FSDP.summon_full_params(model, recurse=True):
    model.eval()
    with torch.no_grad():
        for batch in val_loader:
            out = model(batch)

完整模型驻留显存,eval 期间不触发 collective —— 但单卡显存压力翻 N 倍。只在 eval batch 多 + 模型小时用,70B 训练绝对不能用(OOM)。

实战推荐 A 路线:FSDP eval 速度足够,省心。

18.6.32 FSDP 错误诊断速查表

症状可能原因诊断
反向卡在 AllGather某 rank 进度不一致 / NCCL hangTORCH_NCCL_ASYNC_ERROR_HANDLING=1 看哪个 rank 先报错
OOM 在 forward 第一层wrap policy 太粗(整 model 一个 unit)改用 ModuleWrapPolicy({TransformerBlock})
训练慢比 DDP 慢一倍HSDP 跨节点带宽不够HYBRID_SHARD_ZERO2 或者升级网络
RuntimeError: Tensors of type DTensor must have device_mesh attributemodel surgery 后 wrap 出错检查是否在 fully_shard 之后改了 model
ckpt save 时 rank 0 OOM用了 FULL_STATE_DICT切到 SHARDED_STATE_DICT + DCP
加载 HF safetensors 后训练异常weight 没正确切片到 DTensordistributed_state_dict_helper 帮助加载
forward 与 eager 数值不一致precision 配置错检查 MixedPrecision policy 的 reduce_dtype

把这套表打印贴工位,FSDP 调试效率能提升 5x 以上。

18.6.33 FSDP 的 unshard 时机决策

FSDP-2 内部的 unshard 调度算法:

flowchart TB
    Start[forward 开始]
    Start --> Stage{当前 unit?}
    Stage -->|root unit| RootUnshard[unshard root + 立即 prefetch unit 1]
    Stage -->|sub unit| SubCheck{已被 prefetch?}
    SubCheck -->|是| Use[直接用]
    SubCheck -->|否| LazyUnshard[紧急 unshard]

    Use --> Compute[执行 forward]
    Compute --> Reshard[reshard 当前 unit]
    Reshard --> Prefetch[启动 prefetch 下下个 unit]
    Prefetch --> Stage

    style RootUnshard fill:#fef3c7
    style LazyUnshard fill:#fee2e2

核心规则:当前 unit 计算时,下一个 unit 的 AllGather 应该已经在飞forward_prefetch=True 让这套机制自动工作。

如果实际跑出来 prefetch 没赶上(latency-bound 场景),会触发 lazy unshard —— 当前 unit 等 AllGather 完成,整个 timeline 出现 gap。这时 profiler 的 distributed view 能直接看到。优化:调整 wrap policy 让 unit 更小(每个 unit AllGather 时间更短、prefetch 更容易及时)。

18.6.33.5 grad accumulation 的 FSDP 显存账

§18.6.24 提了 no_sync 跳过中间 ReduceScatter。但还有显存账值得展开:

正常 FSDP grad accumulation 流程:

micro_step 1: forward (unshard / reshard) + backward (有 ReduceScatter)
              → grad 保留在 1/N 切片
micro_step 2: 同上, grad 累积到 1/N 切片
...
micro_step N: 同上 + optimizer.step

每 micro step 都触发 ReduceScatter,累积 N 次通信。no_sync 模式:

with model.no_sync():    # 内部不触发 ReduceScatter
    micro_step 1: forward + backward (grad 保留为完整, N 倍显存)
    micro_step 2: grad 累积, 仍是完整
    ...
    micro_step N-1: 同上

# 退出 no_sync
micro_step N: forward + backward (这次触发 ReduceScatter, grad 切到 1/N)
optimizer.step()

关键差异:no_sync 期间 grad 是完整的(不分片),显存 N 倍上升。如果 N=8 + grad 大小 280GB → 单 rank 持有 2240GB(不可能装下)。

所以 FSDP 下用 no_sync 必须 配 cpu_offload 或者 ckpt 卸载 grad 才能避免显存爆炸。生产代码里 FSDP grad accumulation 通常不用 no_sync,接受 N 次通信开销 —— 这是 FSDP 与 DDP 工程取舍的不同。

18.6.34 FSDP 与其他并行框架对比

框架核心思想对比
PyTorch FSDP-2DTensor + DeviceMesh与 PyTorch 生态深度集成、与 compile 兼容
DeepSpeed ZeROShard params/grads/state更早实现,是 FSDP 的灵感来源;与 PyTorch 集成靠 wrapper
Megatron-LMTP + PP(无 DP shard)NVIDIA 主推,专攻 TP/PP,DP 用 DDP
ColossalAI多策略统一抽象国产框架,支持多种并行策略
OneFlow全局视角自动并行算子级自动决定并行策略

PyTorch FSDP-2 的工程优势是与 torch.compile / torch.export / DTensor 共生 —— 不是独立框架而是 PyTorch 生态的一等公民。其他框架要么是 wrapper(依赖 PyTorch 但侵入式扩展),要么是独立运行时(迁移成本高)。

实战选择:

  • 用 PyTorch 训练 → 用 FSDP-2
  • 已有 DeepSpeed 代码 → 继续用 DeepSpeed(迁移成本高、收益不一定值得)
  • 极致 TP/PP 优化 → Megatron-LM
  • 国产芯片支持 → ColossalAI / 厂商定制

18.6.35 FSDP × HF Transformers / Lightning 实战

HuggingFace Trainer 与 PyTorch Lightning 都支持 FSDP,但配置方式不同:

HuggingFace TrainingArguments

training_args = TrainingArguments(
    fsdp="full_shard auto_wrap",
    fsdp_config={
        "transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"],
        "min_num_params": 0,
    },
)

PyTorch Lightning Strategy

trainer = pl.Trainer(strategy="fsdp", devices=8, ...)
# 或者更细粒度
trainer = pl.Trainer(strategy=FSDPStrategy(
    auto_wrap_policy={LlamaDecoderLayer},
    sharding_strategy="FULL_SHARD",
))

两个框架内部都调 fully_shard —— 只是包装好让用户不用手写 wrap_policy。但理解原始 API 让你能在 Trainer / Lightning 配置出错时定位是哪一层包装的问题。

国内 Llama-Factory、Firefly 等微调框架也内置 FSDP 支持,配置类似 HF Trainer。

18.6.35.5 FSDP 与 elastic 训练的协作

第 17 章 §17.8.25 / §17.9.7 讲过 elastic(torchrun + max_restarts)。FSDP 与 elastic 协作时有几个细节:

  • 重启时 mesh 配置必须一致:如果重启时用了不同 world_size,必须用 DCP 加载(自动 reshard 到新 mesh)
  • ckpt 频率要够高:失败重启等于丢失从最近 ckpt 到现在的所有计算。70B 训练通常每 1000 步存一次(约 30 分钟)
  • DCP 的写入要 robust:跨节点写文件可能因为某 rank 网络抖动失败,DCP 要支持”部分 rank 重写”

实战:torchrun + DCP + 每 1000 步 ckpt 是 LLM 训练的标配。失败时整 job 重启、加载最近 ckpt 继续训。一周训练 70B 模型期间故障 5-10 次是常态,没有 elastic 自动恢复就要全人工介入。

18.6.36 实战训练流程的完整推荐

70B 模型从零训练的标准 FSDP-2 流程:

# 1. 设置环境
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard, MixedPrecision
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
)

# 2. 初始化 process group
dist.init_process_group(backend='nccl')
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)

# 3. 创建 DeviceMesh (32 卡 4 节点 HSDP)
mesh = init_device_mesh("cuda", (4, 8), mesh_dim_names=("inter", "intra"))

# 4. meta device 构造模型
with torch.device("meta"):
    model = LlamaModel(config)

# 5. activation_checkpoint 每层
apply_activation_checkpointing(
    model,
    check_fn=lambda m: isinstance(m, LlamaDecoderLayer),
)

# 6. FSDP-2 wrap (HSDP)
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.bfloat16,
)
for layer in model.layers:
    fully_shard(layer, mesh=mesh, mp_policy=mp_policy)
fully_shard(model, mesh=mesh, mp_policy=mp_policy)

# 7. lazy init weights
model.to_empty(device='cuda')
init_model_weights(model)

# 8. compile
model = torch.compile(model)

# 9. optimizer 与 scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, fused=True)
scheduler = ...

# 10. 训练循环
for batch in loader:
    optimizer.zero_grad()
    loss = model(batch).loss
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)   # FSDP 自动处理 DTensor
    optimizer.step()
    scheduler.step()

整套流程把全书前面章节的内容串起来:DeviceMesh(§18.6.5)+ FSDP-2(§18.5)+ MixedPrecision(§18.6 + 第 20 章)+ activation_checkpoint(§18.6.9)+ HSDP(§18.6.7)+ torch.compile(第 12-15 章)+ fused optimizer(第 10 章)+ DTensor grad clip(§18.6.16)。

这条流程是当今 70B+ LLM 训练的事实标准。理解整章后回到这段代码,每行配置都对应某个具体的工程决策。

18.6.37 FSDP 调试:常用日志与工具

调 FSDP 训练的几个常用工具:

TORCH_DISTRIBUTED_DEBUG=DETAIL:开启后所有 collective 在 NCCL kernel 前后打印 tag + tensor 信息,能精确定位”哪个 rank 卡在哪个 collective”。

TORCH_LOGS=fsdp:打印 FSDP 的 unshard / reshard / prefetch 决策,看是否符合预期。

torch.distributed.fsdp.FullyShardedDataParallel.print_runtime_summary(model):训练几个 step 后打印每个 unit 的统计(unshard 次数、平均时间、prefetch 命中率等)。是 FSDP 自带的”轻量 profiler”。

chrome trace + distributed view(第 21 章 §21.9):终极调试工具。能看到每 rank 的 collective timeline + 各 rank 等待时间。

实战调试流程:先开 TORCH_LOGS=fsdp 看决策对不对、再用 chrome trace 看实际 timeline。多数 FSDP 性能问题(如 unshard 没及时 prefetch)能在这两层定位。

18.6.38 FSDP 设计上的几个隐形约束

FSDP 不是万能的。几个不能突破的工程约束:

1. 每个 unit 的 unshard 必须能装下完整 params:FSDP 的最大单 unit param 大小不能超过单 rank GPU 显存。70B 单层 1B 参数 → 4GB(fp32)单 rank 装得下。但 1T 模型单层 100B → 400GB 装不下,FSDP 必须配合 TP 才能跑。

2. 同一 unit 的所有 param 必须同 dtype:FlatParameter(FSDP-1)合并要求 dtype 一致;FSDP-2 用 DTensor 没这个限制但一个 unit 同 dtype 仍是最优。

3. unit 边界不能跨越复杂的 module 控制流:if-branch / loop 内的 module 不适合自己成为 wrap unit,因为不是每次 forward 都被调到,prefetch 决策困难。

4. FSDP-2 的 mesh 一旦决定就不能改:训练中途换 mesh 配置(如从 8 卡变 16 卡)需要先 ckpt 落盘 → 用新 mesh 重 init → DCP load 时自动 reshard。不能”在线变”。

理解这些约束让你设计训练架构时心里有数。“一切都用 FSDP” 不是答案 —— 极限场景下还要 TP / PP / 自定义并行。

18.7 几条工程经验

1. ShardingStrategy 选择:30B 以下用 SHARD_GRAD_OP(ZeRO-2,省 1/2 显存);70B 以上用 FULL_SHARD;跨节点带宽不足用 HYBRID_SHARD

2. 用 v2.4+ 的 FSDP-2 (fully_shard):除非你已经有大量 FSDP-1 代码,否则新代码直接用 FSDP-2

3. forward_prefetch=True + backward_prefetch=BACKWARD_PRE 默认开启:是 overlap 的核心

4. wrapping policy 设到合适粒度:每个 transformer block 单独 shard 通常最优。粒度太细(每个 Linear 都 shard)通信开销过大;太粗(整个模型一个 shard)overlap 不好

5. use_orig_params=True:让 FSDP 不破坏 param 引用,原 optimizer 能直接复用。FSDP-1 的兼容选项,FSDP-2 默认就这样

6. cpu_offload=CPUOffload(offload_params=True):参数卸到 CPU,省更多显存,代价是 H2D 拷贝拖慢训练。仅在显存极紧张时用

7. checkpoint 与 FSDP:FSDP 的 ckpt 用 torch.distributed.checkpoint(DCP)存分布式格式。每 rank 只存自己那 1/N,避免 rank 0 写 1TB 文件的瓶颈

8. activation_checkpoint 与 FSDP 叠加用:FSDP 省 params/grads/optimizer 显存,checkpoint 省 activation 显存。两者正交、可叠加

18.8 实战决策路径

flowchart TD
    Start[要训练大模型]
    Start --> Size{模型 / GPU?}
    Size -->|装得下| DDP[用 DDP]
    Size -->|装不下| Bandwidth{跨节点带宽?}
    Bandwidth -->|InfiniBand 高带宽| Full[FULL_SHARD]
    Bandwidth -->|普通 ethernet| Hybrid[HYBRID_SHARD<br/>节点内 shard 跨节点 DDP]
    Full --> Mp[配 MixedPrecision]
    Hybrid --> Mp
    Mp --> Ckpt[启用 activation_checkpoint]
    Ckpt --> Compile[用 torch.compile 提速]

    style Full fill:#dcfce7,stroke:#22c55e
    style Hybrid fill:#fef3c7,stroke:#f59e0b

70B Llama 训练的典型配置:32 张 H100、HYBRID_SHARD(同 8 卡节点 ZeRO-3、跨 4 节点 DDP)、bf16 通信 + fp32 reduce、每个 transformer block 单独 fully_shard、activation_checkpoint 每 4 层一次。

18.8.5 FSDP 与 PyTorch 演进路线

FSDP 在 PyTorch 演进路线上的位置:

  • v1.10 (2021):FSDP-1 实验性引入,灵感来自 DeepSpeed ZeRO
  • v1.12 (2022):FSDP-1 成为 stable,开始被生产使用
  • v2.0 (2023):FSDP-1 与 torch.compile 集成尝试,但兼容性差
  • v2.2 (2024 初):FSDP-2 (fully_shard) prototype 推出
  • v2.4 (2024 中):FSDP-2 成为推荐 API,FSDP-1 进入维护模式
  • v2.11 (2026):FSDP-2 与 DTensor / DeviceMesh 深度整合,成为新一代分布式训练标准

FSDP-1 大概率会在 v3.x 时代被完全弃用 —— FSDP-2 在性能、灵活性、工具兼容性上全面胜出。但因为 FSDP-1 在生产训练里使用广泛,PyTorch 团队承诺至少维护到 v3.0+。

国内训练框架(如华为 MindSpore 的并行模式、字节 ByteCheckpoint 等)也大量借鉴 FSDP / ZeRO 思想。理解 FSDP 不只是理解 PyTorch 一个工具,是理解整个大模型训练时代的工程基础

18.8.6 完整决策树:用 DDP 还是 FSDP-2?

flowchart TB
    Start[要训练大模型]
    Start --> SizeCheck{单 rank 装得下完整 params + grads + optimizer state?}
    SizeCheck -->|是| DDP[DDP - 性能最优]
    SizeCheck -->|否, 装不下 optimizer state| ZERO2[FSDP SHARD_GRAD_OP - ZeRO-2]
    SizeCheck -->|否, 连 params 也装不下| ZERO3[FSDP FULL_SHARD - ZeRO-3]

    ZERO2 --> Bandwidth{跨节点带宽足够?}
    ZERO3 --> Bandwidth

    Bandwidth -->|是, NVLink/IB| Flat[平面 mesh: 单维度 shard]
    Bandwidth -->|否, ethernet| Hybrid[HSDP 2D mesh: 节点内 shard 跨节点 replicate]

    Flat --> Compile{要 torch.compile?}
    Hybrid --> Compile
    Compile -->|是| FSDP2[FSDP-2 fully_shard API]
    Compile -->|否, 用老代码| FSDP1["FSDP-1 (维护模式)"]

    style DDP fill:#dcfce7
    style FSDP2 fill:#fef3c7,stroke:#f59e0b,stroke-width:2px
    style FSDP1 fill:#fee2e2

这条决策树覆盖 95% 的大模型训练场景。剩下 5% 是极端场景(如 1T+ 模型必须配 TP/PP)需要专门设计。

18.9 跨书关联

  • 第 16 章 ProcessGroup:FSDP 的 AllGather / ReduceScatter 都通过 ProcessGroupNCCL
  • 第 4 章 Caching Allocator:FSDP 的 unshard 触发临时分配大量显存,对 allocator 压力大,常需 expandable_segments=True
  • 第 7 章 §7.5.3 saved_tensors_hooks:FSDP 的 activation checkpoint 与这套 hook 配合
  • 第 13 章 AOTAutograd:FSDP-2 的 collective 也走 AOTAutograd,让通信 op 被 Inductor 编译

18.9.5 FSDP × Engine:与第 8 章的协作

第 17 章 §17.10.7 讲过 DDP × Engine。FSDP 的协作更复杂:

  • forward 时插入 unshard hook:FSDP 通过 register_forward_pre_hook(第 9 章 §9.8)在 forward 前触发 unshard
  • forward 后插入 reshard hookregister_forward_hook 触发 reshard
  • backward 时通过 autograd.Function 重新触发 unshard:与 §13 AOTAutograd 类似的机制
  • backward 完成的 ReduceScatter 通过 grad accumulator post hook 触发(与 DDP §17.8.10 类似但用 ReduceScatter 而非 AllReduce)

整套机制让 FSDP 把分片调度寄生在 PyTorch 的 hook 体系上。Engine 完全不知道 FSDP 存在,只是按 DAG 调度反向 —— hook 触发的 collective 是 Engine 视角的”普通副作用”。

这种”分布式策略寄生在框架 hook 之上”的设计是 PyTorch 分布式训练能持续演进的根本(DDP / FSDP / 用户自定义并行都共享 hook 接口)。第 23 章设计哲学会再回到这条线索。

18.9.6 整章信息密度的小结

读完 Ch 18 你应该能:

  • 决策:DDP / ZeRO-2 / ZeRO-3 / HSDP 各自适用场景(§18.8.6 决策树)
  • 配置:选 wrap_policy 粒度、配 MixedPrecision、决定是否开 cpu_offload
  • 理解:unshard / reshard 时机、prefetch 怎么 overlap、min-cut activation_checkpoint 怎么决定
  • 调试:错误诊断速查表 + chrome trace 看 timeline
  • 接生态:HF Trainer / Lightning / DCP / torch.compile 怎么协作
  • 预判演进:FSDP-1 → FSDP-2 → 未来与 DTensor 进一步融合

70B Llama 训练的实战配置(§18.6.36)把整章串起来 —— 每行配置都对应某节的工程决策。这是当今大模型训练的事实标准,理解它就理解了”为什么 LLM 训练长这样、不能简化成更朴素的形式”。

18.10 设计启示

FSDP 的核心思想:

第一显存与计算 / 通信可以互换:FSDP 把显存压力转成通信压力,让”训练大模型”从”硬件极限”变成”调优问题”

第二分片 + 临时凑齐是分布式数据结构的通用模式:DTensor / 分布式哈希表 / 分片数据库都用这套思路。每节点常态只持有部分数据,需要时凑齐

第三prefetch 是 overlap 通信的标配:任何”计算 - 通信 - 计算”链路都要考虑 prefetch,让通信延迟隐藏到计算时间里

第四模块级 wrap 比模型级 wrap 灵活:FSDP-2 的 fully_shard(submodule) 思想可以借鉴到所有”框架增强 module”的场景

下一章拆序列化:torch.save / torch.load + safetensors + Distributed Checkpoint,看大模型训练的 ckpt 怎么管理。

评论 0