第19章 序列化:torch.save / torch.load 与权重格式

“Saving a tensor sounds trivial. Saving 70 billion of them across 32 GPUs without 1TB of RAM —— that’s another story.”

—— PyTorch DCP design doc

本章要点

  • torch.save 默认用 pickle + zip 容器:每个 tensor 的 storage 单独存一个 zip entry,让 mmap 加载成为可能
  • weights_only=True 是新默认(v2.6+):只允许加载白名单类型,防止 pickle RCE 攻击
  • safetensors 是 HuggingFace 推动的替代格式:纯字节布局 + JSON header,零代码执行风险,加载比 pickle 快
  • torch.save(model.state_dict()) vs torch.save(model):前者只存权重(推荐),后者 pickle 整个对象(依赖类定义)
  • Distributed Checkpoint (DCP):千卡训练每 rank 只存自己那 1/N,避免 rank 0 写 1TB 单文件
  • mmap 加载 (mmap=True) 让超大 ckpt 加载几乎瞬间完成:OS 把文件映射到虚拟内存,按需 page-in

19.1 torch.save 的真实格式

打开一个 .pt 文件用 unzip -l 能看到:

Archive:  model.pt
  data.pkl                    <- pickle 序列化的 Python 对象图 (无 tensor 数据)
  data/0                      <- 第 1 个 tensor 的字节数据
  data/1                      <- 第 2 个 tensor
  ...
  version                     <- 格式版本号
  byteorder                   <- 字节序标记

torch.save 用的是 zip 容器 + 拆分 storage 的设计:

  1. 把对象树(dict / list / nn.Module)用 pickle 序列化,但每个 tensor 写到一个占位符而非数据
  2. 每个 storage 的实际字节数据单独存到 zip 里一个 entry(data/0data/1 …)
  3. 加载时先反序列化 pickle 拿对象骨架 + 占位符,再按需读 zip entry 填字节

MAGIC_NUMBER = 0x1950A86A20F9469CFC6Cserialization.py:65)是文件头标记。

19.1.1 为什么不直接 pickle 整个 dict

朴素 pickle:把所有 tensor 的字节当 bytes 序列化进 pickle 流。问题:

  • pickle 流是顺序的,没法 mmap(要全部读到内存才能解析)
  • 修改 / 添加单个 tensor 要重写整个文件
  • 跨进程共享时序化开销大

zip 容器格式:

  • 每个 storage 在 zip 里有独立 entry,支持 mmap
  • zip 提供 random access:直接跳到第 N 个 storage
  • 兼容现有工具(unzip / zipfile 都能查看)

代价是文件比纯 pickle 略大(zip header 开销)。但对 1GB+ 模型 ckpt 几乎不可见。

19.2 mmap 加载:超大 ckpt 几秒完成

torch.load(path, mmap=True) 让 OS 把整个 zip 文件映射到虚拟内存,实际不读到 RAM

# 没有 mmap: 必须把整个文件读到内存
state = torch.load('llama-70b.pt')   # 280 GB, 单机 RAM 不够

# 有 mmap: OS 按需 page-in
state = torch.load('llama-70b.pt', mmap=True)   # 几乎秒返回
weights = state['model.layer1.weight']           # 用到时才真读

mmap 的几个优势:

  • 多个进程加载同一个 ckpt 时共享 OS page cache,省一份内存
  • 加载时间几乎为零(懒加载)
  • 不连续访问的部分根本不读盘

vLLM / HuggingFace 加载大模型时大量用 mmap。生产代码里这是 ckpt 加载的标准操作。

19.3 weights_only:防止 pickle RCE

pickle 是 Python 的”任意对象序列化”格式,加载时会执行嵌入的代码。恶意 ckpt 可以 __reduce__ 触发 os.system("rm -rf /")

PyTorch v2.6+ 默认 weights_only=True:只允许加载白名单类型(Tensor、Parameter、各种内置类型),其他全部拒绝。这让 ckpt 加载从”等同于 eval(file)”变成”等同于读 JSON”。

如果你的 ckpt 含自定义类(如自定义 Optimizer),要么改用白名单允许:

torch.serialization.add_safe_globals([MyOptimizer])
state = torch.load('ckpt.pt')

要么显式 weights_only=False 退回老行为(不推荐生产用)。

社区从 HuggingFace 的”safetensors 推动”开始一直在向”任何 ckpt 都不该 RCE”靠拢。weights_only 默认 True 是 PyTorch 跟进的安全升级。

19.4 safetensors:HuggingFace 推动的替代格式

safetensors 不是 PyTorch 内置(要 pip install safetensors),但已经成了 HuggingFace 模型分发的事实标准。文件格式:

[8 字节: header 长度]
[N 字节: JSON header — 描述每个 tensor 的 name / dtype / shape / 字节范围]
[剩余: 紧凑 tensor 字节数据,按 header 里描述的偏移排列]

关键设计:

  • 零代码执行:纯字节 + JSON,加载就是”按 header 切字节”。安全 100%
  • 比 pickle 快:没有反序列化开销,直接 mmap
  • 零拷贝加载:tensor 直接 view 在 mmap 内存上
  • 跨语言友好:Rust / JS / Go 都能读

PyTorch 端用 safetensors:

from safetensors.torch import save_file, load_file

save_file(model.state_dict(), 'model.safetensors')
loaded = load_file('model.safetensors')

HuggingFace Hub 上的现代模型几乎都用 safetensors(如 Llama / Qwen / DeepSeek 的官方权重)。.bin(pickle)格式仍然存在但被标记为”legacy”。

实测对 70B 模型:

  • pickle (.bin) 加载:~120 秒
  • safetensors 加载:~30 秒
  • safetensors + mmap:< 1 秒

这是为什么 vLLM 默认要求 safetensors。

19.5 model.state_dict() vs torch.save(model)

# 推荐
torch.save(model.state_dict(), 'weights.pt')

# 不推荐
torch.save(model, 'whole_model.pt')

差异:

state_dict()整对象 pickle
只存权重数据存整个对象(含类定义引用)
加载时要先 build model 类加载时自动 reconstruct
修改类定义不影响加载改类定义可能加载失败
文件小文件大(含 Python 类元数据)
跨 PyTorch 版本兼容好跨版本风险

第 9 章 §9.6.0.5 讨论过这个区别。生产代码强烈建议 state_dict。整对象 pickle 只在快速保存中间状态时偶尔用。

19.6 Distributed Checkpoint (DCP)

70B 模型 ckpt ≈ 280 GB(参数)+ 560 GB(optimizer state)+ … ≈ 1 TB。如果 rank 0 收集所有 rank 数据再写一个文件:

  • rank 0 单机要 1 TB RAM 收集
  • rank 0 串行写 1 TB 到磁盘几十分钟
  • 加载时反过来 —— 又要几十分钟

DCP(torch.distributed.checkpoint)的解法:每个 rank 只写自己持有的数据,并行写到一个目录

import torch.distributed.checkpoint as dcp

# 保存
state_dict = {'model': model.state_dict(), 'optim': optimizer.state_dict()}
dcp.save(state_dict, checkpoint_id='ckpt_step_1000')

# 加载 (跨任意 rank 数恢复)
dcp.load(state_dict, checkpoint_id='ckpt_step_1000')

DCP 输出的不是单文件,而是一个目录:

ckpt_step_1000/
  __0_0.distcp                <- rank 0 的数据
  __1_0.distcp                <- rank 1 的数据
  ...
  metadata.json               <- 描述每片数据的逻辑映射

每 rank 写自己的 __N_0.distcp 文件,完全并行。加载时 DCP 根据 metadata.json 把每片正确路由回需要它的 rank(即便加载时 rank 数与保存时不同 —— 自动 reshard)。

这套机制让大规模训练 ckpt 从”几十分钟串行写”变成”几分钟并行写”,是 70B+ 训练的工程基础。

19.6.1 DCP 的 reshard 能力

最强的特性:保存时 8 卡 ZeRO-3,加载时 16 卡 ZeRO-3,DCP 能自动重新切分。

实现机制:metadata.json 用 逻辑张量坐标(不是物理 rank)描述每片数据。加载时根据当前 rank 拓扑重新映射 —— 保存时 rank 0 持有 [0:1000] 切片,加载时如果 rank 0 现在应该持有 [0:500],DCP 自动切。

这种”加载时 reshard”让训练流水线灵活得多 —— 训练用 64 卡,fine-tuning 切到 8 卡只需重新加载,不需要离线 resharding。

19.6.5 storage 共享检测:避免存重复字节

PyTorch 张量经常共享 storage(第 2 章 §2.6 view 机制)。torch.save 不会傻傻地把每个 tensor 的字节都存一份 —— 它做了storage 共享检测

a = torch.randn(10, 10)
b = a.view(100)             # b 与 a 共享 storage
c = a[0:5]                  # c 也共享

state = {'a': a, 'b': b, 'c': c}
torch.save(state, 'shared.pt')
# 文件只含 1 份 storage 字节, 不是 3 份

实现机制是 pickle 的 persistent_id / persistent_load 钩子:

  • save 时:每个 tensor 序列化前调 persistent_id(tensor) 把它的 storage 用一个 ID 替代(同一 storage 返回同一 ID)。zip 里每个唯一 ID 对应一个 data/N entry
  • load 时:每个 ID 解析时调 persistent_load(id),从 zip 取出对应字节、构造 storage,再多个 tensor 共用这一份 storage

实测:3 个共享 storage 的 tensor save 后,文件大小 ≈ 1 份 tensor 大小(不是 3 份)。这套机制对 view-heavy 模型(如 transformer 的 QKV 投影)能省可观空间。

但有个反直觉的坑:view 张量也存了它的”完整 base storage”,不是它自己的子集。如果你 b = a[0:5]; del a; torch.save({'b': b}),文件大小是整个 a 的,不是 5 个元素。要省空间得 b = a[0:5].clone() 让 b 有独立 storage。

19.6.6 PyTorchFileReader / PyTorchFileWriter

torch._C.PyTorchFileReader / PyTorchFileWriter(C++ 在 caffe2/serialize/inline_container.cc)是 zip 容器的底层。它们不是用标准 zip 库,是 PyTorch 自家实现的精简版 —— 只支持 STORE(不压缩)模式,因为:

  • tensor 字节本身已经接近随机,压缩率低(几乎 1.0)
  • 压缩开销让 save / load 慢几倍
  • 不压缩让 mmap 直接映射文件 = 内存(mmap 对压缩文件无效)

副作用:.pt 文件比同样数据的 .npz(用 zip DEFLATE 压缩)大约 1.0-1.05x,但加载快 3-5x。这是性能 vs 空间的明确选择。

PyTorchFileWriter 还支持 multi-stream write:多个 storage 字节并行写到 zip 里不同 entry。70B 模型 ckpt 用 4 个 IO thread 写比单线程快 2-3x。torch.save 默认开启这个优化。

19.6.7 DCP 的 SavePlan / LoadPlan

torch/distributed/checkpoint/planner.py:105/115SavePlan / LoadPlan 是 DCP 的 IR:

@dataclass
class SavePlan:
    items: list[WriteItem]            # 每个 item 描述一段要写的数据
    storage_data: Any                  # backend-specific 数据
    planner_data: Any                  # planner 状态

@dataclass
class WriteItem:
    index: MetadataIndex               # 逻辑索引 (fqn + offsets)
    type: WriteItemType                 # TENSOR / BYTE_IO
    tensor_data: TensorWriteData | None
    ...

DCP save 流程:

  1. 每个 rank 调 planner.create_local_plan(state_dict) 生成自己的 SavePlan
  2. rank 0 收集所有 SavePlan、调 planner.create_global_plan(plans) 做去重 + 优化
  3. broadcast 全局 plan 回所有 rank
  4. 每个 rank 按 plan 调 storage.write_data(plan) 写自己负责的部分
  5. rank 0 写 metadata.json

LoadPlan 反向:每 rank 描述”我需要哪些 tensor”,DCP 根据 metadata.json 路由到正确的物理 chunk。

这套 plan-based 设计让 DCP 能换 storage 后端:默认 FileSystemWriter(写本地 / NFS),也可换 S3Writer(写云存储)、GCSWriter 等。第三方实现 StorageWriter / StorageReader 接口即可,业务层无感知。

FileSystemReaderfilesystem.py:837)与 FileSystemWriter:965)是默认实现 —— 写到目录里若干 .distcp 文件 + 一个 .metadata

19.6.8 Resharding 算法:8 卡 ckpt 加载到 16 卡

§19.6.1 提到 DCP 能 reshard,具体怎么做?看 torch/distributed/checkpoint/_dedup_save_plans.py + _resharding.py

场景:保存时 8 卡 FSDP,每 rank 持有 weight 的 1/8 切片;加载时 16 卡 FSDP,每 rank 应该持有 1/16 切片。算法:

graph TB
    Save["保存时 (8 卡):<br/>rank 0: weight[0:N/8]<br/>rank 1: weight[N/8:2N/8]<br/>..."]
    Save --> M["metadata.json:<br/>'weight' chunks 用逻辑坐标:<br/>chunk[0:N/8], chunk[N/8:2N/8], ..."]
    M --> Load["加载时 (16 卡):<br/>每 rank 算自己需要 weight[i*N/16:(i+1)*N/16]"]
    Load --> P{逻辑切片<br/>vs 物理 chunk 关系}
    P --> Calc["rank 0 需要 weight[0:N/16]<br/>= 物理 chunk[0:N/8] 的前一半"]
    Calc --> Read[读 rank 0 chunk 文件<br/>取前 N/16 字节]

    style M fill:#fef3c7
    style Calc fill:#dbeafe

具体步骤(load.py:LoadPlanner._populate_load_items):

  1. 解析 metadata.json 拿到所有物理 chunk 的逻辑坐标范围
  2. 对当前每 rank 的需求范围 R_i,找出它与哪些物理 chunk 有交集
  3. 为每个交集生成一个 LoadItem:从某物理 chunk 的某偏移读多少字节、写到本地 tensor 的某偏移
  4. 调度 LoadItem:尽量批量读同一文件、最小化 IO

特殊情况处理

  • 新 rank 数 > 旧:每物理 chunk 被多个 rank 共享读(多读不冲突)
  • 新 rank 数 < 旧:每 rank 读多个物理 chunk 拼起来
  • dim 改变:DCP 不支持自动改 sharding 维度(如从 dim 0 sharding 改成 dim 1),需要离线工具

实测:1024 卡训练保存 → 64 卡 fine-tune 加载,DCP 自动 reshard 用时 5-10 分钟(取决于网络)。无需”先 gather 到全量再切” —— 直接按 chunk 重新映射。

理解 reshard 让你看到 DCP 不是”分布式 save/load”那么简单,而是”逻辑坐标系统 + 物理 chunk 路由”的工程。这是大模型生命周期管理(pretrain → fine-tune → distill)的关键基础设施。

19.6.9 Async Checkpoint:训练不被 ckpt 阻塞

torch/distributed/checkpoint/state_dict_saver.py:async_save(v2.4+)让 ckpt 写入与训练并发

import torch.distributed.checkpoint as dcp

# 同步版本: 训练等 ckpt 写完才继续
dcp.save(state_dict, ...)
# next training step starts...

# 异步版本: 训练立刻继续, 后台慢慢写
future = dcp.async_save(state_dict, ...)
# next training step starts immediately
# ...
future.result()  # 偶尔等一次, 确认上次 ckpt 写完

实现机制:

graph LR
    T[训练主线程] --> CP[snapshot 整个 state_dict<br/>1. 把 GPU tensor 异步 copy 到 CPU<br/>2. 用同一份 CPU 内存的 ref 创建 staging dict]
    CP --> RT[训练继续<br/>下一 step]
    CP --> BG[后台线程<br/>写 staging dict 到磁盘]
    BG --> Done[完成 future]

    style CP fill:#fef3c7
    style BG fill:#dcfce7

关键点:snapshot 阶段必须同步(要确保后台写的是当前 step 状态、不是几 step 后的)。snapshot 用 cudaMemcpyAsync 把 GPU 数据拷到 CPU pinned memory —— 几百 ms。然后训练继续;后台线程慢慢把 CPU 数据写盘 —— 几十秒。

实战收益:每 100 step 写一次 ckpt,每次 ckpt 阻塞 60s → 训练吞吐损失 60/(100×step_time)。step_time = 1s 时 → 损失 60% 吞吐!开 async ckpt 后损失降到 1-2%。

代价:需要 2x CPU 内存(一份给当前训练 + 一份 staging)。70B 模型 ckpt 占 140 GB CPU 内存 —— 单机 256 GB RAM 通常够。如果 CPU 紧张,不能开 async

理解 async ckpt 让你看 v2.x 大模型训练的真实工程取舍:宁愿多吃 CPU 内存也不让 ckpt 写阻塞训练,因为 GPU 时间贵得多。

19.6.10 PEFT / LoRA 的 ckpt 增量保存

LoRA / Adapter / Prefix-tuning 等参数高效微调(PEFT)的核心是”只训练少量参数”。ckpt 也只需保存这些参数:

# 全量 fine-tune: state_dict 含所有 7B 参数
torch.save(model.state_dict(), 'full_ft.pt')   # ~14 GB

# LoRA: 只存 LoRA 模块的参数
lora_state_dict = {k: v for k, v in model.state_dict().items() if 'lora' in k}
torch.save(lora_state_dict, 'lora.pt')          # ~50 MB

peft 库(HuggingFace)做了完整封装:model.save_pretrained() 自动只保存 LoRA 权重 + adapter_config.json(描述秩、target_modules 等)。

加载流程:

  1. 先加载 base model(如 Llama-7B 全量权重)
  2. 加载 adapter_config.json 知道在哪些 module 上加 LoRA
  3. 加载 lora.pt 把 LoRA 权重塞进对应 module

这种”base model + small adapter”模式的工程意义:

  • 存储成本:1 个 base + 100 个 adapter(每个 50 MB)= 14 GB + 5 GB = 19 GB;vs 100 个全量 = 1.4 TB
  • 分发:用户可单独下载 adapter,不重复下 base
  • 服务:vLLM 等支持 base + 多 adapter 共存推理(Multi-LoRA Serving),只切换 adapter 不重 load base

对应到 PyTorch 序列化层面:state_dict 是 dict,支持任意子集 —— 这是 LoRA 这种增量保存的基础设施。load_state_dict(strict=False) 允许部分 key 不存在 → 只覆盖匹配的 key。这层设计让 PEFT 能优雅落地。

19.6.11 断点续训:随机状态完整保存

光保存 model + optim 还不够。完整可恢复 ckpt 必须包括:

ckpt = {
    'model': model.state_dict(),
    'optim': optimizer.state_dict(),
    'scheduler': lr_scheduler.state_dict(),
    'sampler': sampler.state_dict(),         # §11.9.18
    'epoch': epoch,
    'step': step,
    'rng_states': {
        'cpu': torch.get_rng_state(),
        'cuda': torch.cuda.get_rng_state_all(),  # 所有 GPU 各自
        'numpy': np.random.get_state(),
        'python': random.getstate(),
    },
    'amp_scaler': scaler.state_dict() if scaler else None,
}
torch.save(ckpt, 'full_ckpt.pt')

为什么 RNG 也要存?两个场景:

  • augment 复现RandomCroptorch.rand 决定 crop 位置,rng 不一致 → 同一 sample 生成不同 augmented view → 不可复现
  • dropout / noise injection:reproducibility 要求

torch.set_rng_state 加载时恢复 CPU RNG;torch.cuda.set_rng_state_all 恢复每个 GPU 的 RNG。但 DataLoader worker 的 RNG 是另一套——每个 worker 启动时基于 base_seed + worker_id 计算,所以保存 sampler.state_dict 已经隐含了 worker RNG 的种子。

实战:HuggingFace Trainer / Accelerate 内置完整 RNG 保存。如果你自写训练循环,要么照抄上面的字典、要么直接用这些库。漏一个 RNG 类型就会导致”resume 后 loss 跳变”——很难诊断的 bug。

19.6.12 fsspec:直接读写云存储

PyTorch v2.x+ 集成了 fsspec(Python 文件系统抽象库),让 torch.save / torch.load 直接支持 S3 / GCS / Azure Blob / HDFS:

torch.save(state, 's3://my-bucket/ckpts/step_1000.pt')
state = torch.load('s3://my-bucket/ckpts/step_1000.pt')

底层调用 fsspec.open() 拿到一个 file-like 对象,PyTorch 不区分本地 / 云。

实战注意点:

  • 认证:要先配好 boto3 / gcloud CLI 让 fsspec 用上 credentials
  • 大文件:S3 multipart upload 自动启用(每段 100 MB),70B ckpt 上传几分钟(取决于带宽)
  • list / glob 慢:S3 list 操作是 O(N),DCP load 时如果用 globbing 找文件会慢;DCP 直接读 metadata.json 避免这个
  • eventually consistent:S3 写完不能立刻 list 出(几秒延迟)。生产代码 save 后等 5 秒再发”ckpt ready”通知

DCP 的 FileSystemWriter v2.6+ 支持 fsspec 路径:

dcp.save(state_dict, checkpoint_id='s3://bucket/ckpt/step_1000')

让 1024 卡训练每 rank 直接写 S3、不需要本地落盘 + 后台同步。云原生训练的标准模式。

19.6.13 GGUF / GGML:推理优化的 weight 格式

safetensors 是训练 / 通用分发场景,端侧 / CPU 推理有更专门的格式:GGUF(GGML 的下一代)。它专为 llama.cpp 等 CPU/Apple Silicon 推理框架设计:

维度safetensorsGGUF
量化支持不内置内置(int4/int5/int8 等多种)
元信息tensor 名 + shape加上 chat template、stop tokens、tokenizer 配置
加载方式mmap 友好mmap 友好 + 端侧推理优化
编辑器支持HuggingFace Hubllama.cpp / Ollama 生态
适用场景训练 / 服务器推理CPU / 端侧推理

PyTorch 这层不直接支持 GGUF,但社区有 transformers → GGUF 转换工具:

python convert.py --input model.safetensors --output model.gguf --quantize q4_k_m

转换后用 llama.cpp 加载、在 CPU / Mac M-series 上跑推理。这条路径让 PyTorch 训练的模型能落地到端侧 —— 训练用 PyTorch、部署给社区用 GGUF

理解 GGUF 在 PyTorch 生态外的存在让你看到 model 序列化是个多格式共存的世界。每种格式服务一类场景:safetensors 通用、.pt2 (AOTI) 服务器推理、.pte (ExecuTorch) 移动端、GGUF 端侧 CPU。PyTorch 自身只覆盖前几个,社区补足后两个。

19.6.14 HuggingFace from_pretrained:业界实际加载流程

PyTorch 用户极少手写 torch.load,更常见的是:

from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3-8B")

这一行的内部流程:

graph TB
    Call[from_pretrained] --> DL[1. 从 Hub 下载文件<br/>config.json / model.safetensors / tokenizer]
    DL --> Cache[本地 cache: ~/.cache/huggingface/hub/]
    Cache --> Cfg[2. 解析 config.json<br/>找到 model class + dtype + ...]
    Cfg --> Build[3. 实例化空 model<br/>只 build 网络结构, 权重 init 为 meta tensor]
    Build --> Sharded{4. 单文件还是 sharded?}
    Sharded -->|单文件| Single[加载 model.safetensors]
    Sharded -->|多文件| Multi[加载 model-00001-of-00007.safetensors<br/>+ model.safetensors.index.json]
    Single --> Load[5. load_state_dict to model<br/>用 mmap, 一次填一个 tensor]
    Multi --> Load
    Load --> Done[ready model]

    style Build fill:#fef3c7
    style Load fill:#dcfce7

几个工程精巧点:

  • meta tensor build:第 3 步用 torch.device('meta') 创建模型骨架,无显存分配。第 5 步把 mmap 字节直接 cast 到 GPU,跳过 “CPU full copy → GPU copy” 双重内存压力
  • sharded 文件命名model-00001-of-00007.safetensors 是事实标准,每个 5-20 GB(safetensors 单文件 size 上限相关)
  • index.json:含 {"weight_map": {"model.layers.0.weight": "model-00001-of-00007.safetensors"}} 的反查表 —— 加载某个权重时知道去哪个文件读
  • streaming load:边下载边加载(v0.20+ 的 transformers),适合大文件 first-time use

理解这套流程让你看 vLLM / SGLang / 自家推理服务的”模型加载几分钟”开销主要在哪 —— 一是网络下载、二是 GPU memory 分配、三是 cuda 上下文 init。真正的 deserialization 反而最快(mmap + zero-copy)。

19.6.15 sharded safetensors:单文件大小限制怎么处理

safetensors 单文件理论无限制,但实际 HuggingFace Hub 限制 每文件 ≤ 50 GB(防止下载超时)。70B model = 140 GB → 必须切到 3+ 个文件。

切分逻辑(transformerssave_sharded 函数):

  1. 按 fqn 顺序遍历所有 tensor
  2. 累计 size > shard_size_limit (默认 5GB) 时切到下一个文件
  3. 每个 tensor 不跨文件(一个 tensor 必须完整在一个文件里)
  4. 写完所有文件后生成 model.safetensors.index.json

注意点:

  • shard 大小:太小(< 1GB)→ 文件多、下载并发开销;太大(> 50GB)→ Hub 不接受。5-10 GB 最优
  • 顺序敏感:fqn 顺序变化会让 sharding 结果不一致 → ckpt 不可复现。HF 用 sorted(state_dict.keys()) 保证一致
  • 加载时按 index 找:每加载一个 weight,从 index 反查在哪个文件、再从该文件 mmap → 一次只读必要的字节

实战:HF Hub 的 Llama-70B 分成 8 个文件(17 GB 每个)。如果用户带宽 100Mbps、并发 4 → 总下载时间 = 17×8 / (100/8 × 4) ≈ 4 分钟。这就是 from_pretrained 等几分钟的根本原因。

19.6.16 跨设备加载:map_location 与 device_map

ckpt 保存时 tensor 在 GPU,加载时机器没那么多 GPU 怎么办?map_location 解决这个:

# 把所有 tensor 加载到 CPU
state = torch.load('ckpt.pt', map_location='cpu')

# 把所有 tensor 加载到 cuda:0 (即便保存时在 cuda:7)
state = torch.load('ckpt.pt', map_location='cuda:0')

# 自定义函数: 按 tensor name 决定去哪
def loader(storage, loc):
    return storage.cuda(0)
state = torch.load('ckpt.pt', map_location=loader)

底层:torch.load 在恢复 storage 时调 map_location 决定目标 device。

device_map 是 HuggingFace transformers 的高级版本:

model = AutoModelForCausalLM.from_pretrained(
    "Llama-70B",
    device_map="auto",      # 自动分配 layer 到多 GPU
)

device_map="auto" 算法:

  1. 估算每 layer 占多少 memory
  2. 按层数把 layer 平均分到所有 GPU
  3. 加载 weight 时直接 map 到目标 GPU(不经过 CPU)

实战:8 卡机器加载 Llama-70B(140 GB),device_map=“auto” 让每卡 ~17.5 GB。配合 offload_folder 还能让多余 layer 落到磁盘 → 单卡推理 7B、双卡 70B、单卡 + offload 70B 都能跑。

理解这层让你看到”内存不够也能 load 模型”不是魔法,是 PyTorch + HF 联合做的精细 device 管理。

19.6.17 ckpt 完整性校验

长时间训练 ckpt 损坏会让”resume 后 loss 直接 NaN”。生产建议加 checksum:

import hashlib

def save_with_checksum(state_dict, path):
    torch.save(state_dict, path)
    with open(path, 'rb') as f:
        sha256 = hashlib.sha256(f.read()).hexdigest()
    with open(path + '.sha256', 'w') as f:
        f.write(sha256)

def load_with_checksum(path):
    with open(path, 'rb') as f:
        actual = hashlib.sha256(f.read()).hexdigest()
    with open(path + '.sha256') as f:
        expected = f.read().strip()
    if actual != expected:
        raise ValueError(f"checksum mismatch: {path}")
    return torch.load(path)

DCP 的 metadata.json 内置每 chunk 的 sha256 hash(v2.6+),加载时自动校验。这是大模型训练防”silent corruption”的工程标准。

corruption 来源:

  • 磁盘 bit flip:宇宙射线 / 硬件老化、概率约 1e-12 per byte,1 TB ckpt 平均每 1000 次出 1 bit 错
  • 网络传输错误:TCP checksum 不能保证 100%,云存储传输偶发 corruption
  • 进程 crash 中断写入:写一半 SIGKILL、ckpt 不完整

每种都让 ckpt 看起来”加载正常但 resume 后 loss 异常”。checksum 让这种 silent failure 立刻暴露。生产大模型训练必须有 checksum,否则浪费几小时训练在坏 ckpt 上。

19.6.18 序列化的内存峰值

很多人以为 torch.save(state_dict, path) 内存占用 = state_dict size。实际峰值更高:

graph LR
    SD[state_dict<br/>1x size] --> Pic[pickle 序列化<br/>+ 一份临时副本<br/>2x size]
    Pic --> Zip[zip 容器<br/>3x size 临时]
    Zip --> File[最终文件<br/>~1x size]

    style Pic fill:#fef3c7
    style Zip fill:#fee2e2

70B 模型 state_dict 占 140 GB GPU + 140 GB CPU;save 时峰值可能 280-420 GB CPU 内存(看 PyTorch 版本)。机器要预留 3x ckpt size 的 RAM 才安全

PyTorch v2.4+ 加了 _disable_tensor_pickling,让 tensor 直接以 zip entry 写入、不走 pickle 临时副本 → 峰值降到 1.2x ckpt size。新代码默认启用。

DCP 的 async save 让峰值进一步降低:snapshot 时 staging 占 1x,写盘后释放,且每 rank 只要 1/N。

如果你遇到”save ckpt 时进程 OOM”,先检查 PyTorch 版本是否 >= 2.4(自动避免临时副本)。还不行就 chunked save:把 state_dict 切成几份分别 save,最后合并。

19.6.19 Optimizer state:被忽视的 ckpt 大头

很多人保存 ckpt 只想到 model.state_dict,但Optimizer state 占用通常是 model 的 2-3 倍

Adam 优化器每个参数维护:

  • exp_avg(一阶矩):与 weight 同 shape、同 dtype(fp32)
  • exp_avg_sq(二阶矩):与 weight 同 shape、同 dtype(fp32)
  • step(当前步数):标量

70B 模型 weight in bf16 = 140 GB;Adam state in fp32 = 70B × 4 字节 × 2 = 560 GB。Adam state 占 ckpt 80%

# state_dict 的真实结构 (Adam)
{
    'state': {
        0: {'exp_avg': tensor[1024,1024], 'exp_avg_sq': tensor[1024,1024], 'step': 1000},
        1: {'exp_avg': ..., 'exp_avg_sq': ..., 'step': 1000},
        ...                    # 每个参数对应一个 dict
    },
    'param_groups': [...]      # 优化器配置 (lr, betas, weight_decay 等)
}

param_groups 用整数索引引用 state 中的参数。这种”间接寻址”让 optimizer 与 model 解耦,但加载时必须保证 model 参数顺序与保存时一致——否则 state[0] 错配到 weight[1],训练立刻崩盘。

DCP 也支持优化器 state ckpt 与 reshard。FSDP-2 的 optimizer state 按 rank sharded → DCP 保存每 rank 自己那 1/N → reload 时按当前 rank 数重新 reshard。这是大模型训练能从 1024 卡 fine-tune 到 64 卡推理的根本前提。

8-bit / FP8 Optimizer(如 bitsandbytes 的 8-bit Adam):把 state 量化到 8-bit 存储,state size 降到 1/4。生产代码越来越常见 —— 把 560 GB Adam state 压到 140 GB,单机训练成本立刻可承受。

19.6.20 大文件支持:ZIP64 与 stream

标准 zip 单文件最大 4 GB(header offset 是 32-bit)。70B model 单文件 280 GB → 远超限制。PyTorch 用 ZIP64 扩展PyTorchFileWriter 自动启用):

  • 文件大小用 64-bit 表示
  • entry offset 也用 64-bit
  • 外层结构兼容标准 zip(小文件仍可被 unzip 工具打开)

实战:280 GB .pt 文件能正常 save / load,无需用户手动切分。但有些工具不支持 ZIP64:

  • unzip -l:会报”file too large”错(GNU unzip 默认不带 ZIP64)
  • zipinfo:同样限制
  • bsdtar(macOS / FreeBSD):支持 ZIP64

诊断 .pt 文件时用 Python 的 zipfile 模块(强制 ZIP64 兼容):

import zipfile
with zipfile.ZipFile('model.pt') as z:
    print(z.namelist())                    # 列文件
    print([info.file_size for info in z.infolist()])

这层细节不常用,但调 ckpt 问题时能避免”为什么 unzip 看不了”的困惑。生产代码强烈建议直接用 DCP(多文件,每文件 < 5GB),避开 ZIP64 的工具兼容问题。

19.6.21 自定义 reduce:让自家类能被 save

如果你的 state_dict 含自定义类(如 LR scheduler 的 metadata),默认 pickle 可能失败。解法:实现 __reduce__

class MyScheduler:
    def __init__(self, base_lr, warmup_steps):
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        self._cache = ...    # 不想 pickle 的私有状态

    def __reduce__(self):
        # pickle 时只保存关键参数, 加载时重建
        return (MyScheduler, (self.base_lr, self.warmup_steps))

__reduce__ 返回 (callable, args)。pickle 调用时调 callable(*args) 重建对象。这种”自定义重建”让你完全控制序列化字段。

weights_only=True(v2.6+ 默认)下 自定义 __reduce__ 不被信任 —— 出于安全考虑只允许白名单类型。要让自定义类支持加载:

import torch.serialization
torch.serialization.add_safe_globals([MyScheduler])

# 现在可以加载
state = torch.load('ckpt.pt')

或者用 safetensors 替代 torch.save —— 它根本不支持 Python 对象(只支持 tensor 字典),这种”局限性”反而让安全成为默认。

实战建议:state_dict 应该只含 tensor 与原生 Python 类型(int/float/str/list/dict),不含自定义类。如果非要保存复杂对象(如 scheduler 的 internal state),用 JSON 序列化简单 metadata、复杂逻辑在加载侧重建。这是从 pickle 安全 + 跨版本兼容角度都最好的实践。

19.6.22 ckpt 版本管理:HuggingFace Hub 的 git-like 模型

大模型项目都需要”版本管理”——v1.0 的 ckpt、v1.1 的 ckpt、experimental 分支的 ckpt 共存。HuggingFace Hub 用 git LFS 提供 git-like 版本管理:

# Hub 上一个 model repo 实际是 git repo
git clone https://huggingface.co/meta-llama/Llama-3-8B
ls -la
# .git/                    <- git 历史
# config.json              <- 元信息
# model.safetensors        <- 权重 (LFS-tracked, 大文件不在 git history)
# tokenizer.json
# README.md

每次模型更新 = git commit + LFS push。用户可以:

  • git checkout v1.0:拿历史版本
  • git diff main..v2.0 config.json:看不同版本配置变化
  • git log --oneline:看模型更新历史

对应到 PyTorch 序列化层面:ckpt 不再是单文件,而是一组协同的文件(safetensors + config + tokenizer + chat template + …)。HF 把这套打包成”model card”概念,让模型分发等价于 git repo 分发。

实战影响:

  • 不要把 from_pretrained("path/to/local") 当成”加载一个文件”,它实际加载一个目录 + 解析 config + …
  • ckpt 兼容性:升级 transformers 库时旧 ckpt 可能加载失败(因为 config schema 变了)
  • 自家训练的 ckpt 上传 Hub 的工程:用 model.push_to_hub() 自动打包 + LFS push

理解这套生态让你看到”序列化”在 LLM 时代不只是”save tensor”,是”模型生命周期管理”的入口。PyTorch 自身只覆盖 tensor 字节,HF 生态补足版本管理 + 元信息 + 分发协议。

19.6.23 ckpt 内容 inspector

调试 ckpt 问题时常需”打开看里面有啥”,但又不想真加载几百 GB。常用工具:

# 1. 看 zip 结构 (不读 tensor 数据)
import zipfile
with zipfile.ZipFile('model.pt') as z:
    for name in z.namelist():
        print(name)

# 2. 看 metadata 不加载 tensor (skip_data 上下文)
with torch.serialization.skip_data():
    state = torch.load('model.pt')   # tensor 是 placeholder, 不读字节
print(state.keys())                   # 看 fqn 列表
print({k: v.shape for k, v in state.items()})

# 3. safetensors header 直接解析
from safetensors import safe_open
with safe_open('model.safetensors', framework='pt') as f:
    print(f.metadata())               # JSON header
    print(list(f.keys()))             # tensor 名列表

# 4. DCP 的 metadata.json 直接看
import json
with open('ckpt_dir/metadata.json') as f:
    meta = json.load(f)
print(meta['storage_metadata'])       # 每 chunk 的逻辑坐标

实战场景:

  • 比较两个 ckpt 是否一致:第 2 种方式拿 fqn list + shape 对比
  • 找某个 weight 在哪个 shard 文件:safetensors 的 weight_map (model.safetensors.index.json)
  • 查 ckpt 训练步数state['step']state['epoch']
  • 诊断 reshard 失败:看 DCP metadata.json 的 chunk 坐标是否符合预期

把这些命令做成 inspect_ckpt.py 脚本,是大模型训练运维必备的工具。

19.6.24 ONNX 互操作:跨框架部署的中间格式

ONNX(Open Neural Network Exchange)是跨框架部署的标准 IR。PyTorch 的 ONNX 导出:

import torch.onnx

torch.onnx.export(
    model,
    args=(example_input,),
    f="model.onnx",
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}},
    opset_version=18,
)

torch.onnx 内部有两条路径:

  1. torch.onnx.dynamo_export(v2.0+):用 Dynamo 抓 fx graph 转 ONNX。能处理动态控制流,是默认推荐
  2. torch.onnx.export(v1.x 老版):用 torch.jit.trace 抓静态图。仍可用但被标记为 legacy

ONNX 的价值:

  • 跨框架推理:训练用 PyTorch,推理用 TensorRT / OpenVINO / Core ML / DirectML
  • 嵌入式 / 边缘:Coral / Jetson 等支持 ONNX runtime
  • 优化器生态:ONNX Runtime 自带 graph optimization、量化 toolkit

代价:

  • 算子覆盖度:PyTorch 几千个 op,ONNX 标准化几百个。冷门 op 导出失败
  • dynamic shape 支持有限:导出时 dynamic_axes 必须显式声明,否则 shape 写死
  • trace 不支持复杂控制流:含 if x.sum() > 0 的 model 用 trace 导出会丢分支

实战建议:首选 .pt2 (AOTI) 服务器推理、ExecuTorch 移动端、ONNX 仅作为”必须用 TensorRT/OpenVINO” 的桥梁。PyTorch 团队近年把投入从 ONNX 转向自家 export 路径(torch.export)—— 让 PyTorch 模型在 PyTorch 生态内闭环。

19.6.25 ckpt rollback:训练失败的恢复策略

长时间训练偶尔出现 NaN loss / divergence / OOM 等问题,需要 rollback 到更早 ckpt。生产标准做法:

graph LR
    Step1[step 1000<br/>ckpt A] --> Step2[step 2000<br/>ckpt B]
    Step2 --> Step3[step 3000<br/>ckpt C]
    Step3 --> Step4[step 4000<br/>NaN loss!]
    Step4 -.rollback.-> Step3
    Step3 --> Step5[step 4000 retry<br/>不同 lr / batch]

    style Step4 fill:#fee2e2
    style Step5 fill:#dcfce7

工程做法:

  • 保留多个最近 ckpt:rolling window,如最近 5 个 step 的 ckpt
  • NaN detection:每 step 后检查 loss.isnan(),立刻 abort
  • 自动 rollback:检测到 NaN → 自动 load 上一个 ckpt → 调小 lr 重试 → 继续

LLM training infra 都内置这套。Megatron-LM、HuggingFace Trainer 都有 --save_total_limit=5(保留 5 个)+ --resume_from_checkpoint=latest(自动找最新)+ NaN guard。

对应到 PyTorch 序列化层面:ckpt 路径要含 step 号方便排序

ckpt_step_1000/
ckpt_step_2000/
ckpt_step_3000/
ckpt_step_4000/    <- corrupted, skip

rollback 脚本逻辑:

import os, glob
ckpts = sorted(glob.glob('ckpt_step_*'),
               key=lambda x: int(x.split('_')[-1]))
for c in reversed(ckpts):
    if validate(c):    # checksum + 加载 + 跑一个 batch 不出 NaN
        return c
    print(f"corrupted: {c}, trying older")

这种”层层回滚找最近完好版本”是大模型训练的运维标准动作。理解这层让你看到 ckpt 不只是”备份”,是 training infra 的核心 control plane。

19.6.26 推理服务的 ckpt 加载优化

vLLM / SGLang 等推理引擎的”启动慢”问题常被诟病。70B 模型加载 5-10 分钟。生产优化方向:

1. 模型预热池

服务集群保持 N 个 worker 进程,每个预先加载好 model。新请求来时直接路由到 worker,不重新 load。代价是常驻 RAM 高(每 worker 一份 weight)。

2. 共享内存加载

多个 worker 进程共享同一份 weight:master 进程 mmap 一次,子进程 fork 后共享 page。Linux CoW 让”读”操作不复制内存,N 个 worker 仅占 1 份 weight

# master
weight_mmap = torch.load('model.pt', mmap=True)
# fork 子进程
for _ in range(num_workers):
    pid = os.fork()
    if pid == 0:
        # 子进程: weight_mmap 与 master 共享
        serve(weight_mmap)

CUDA tensor 复杂些 —— GPU memory 不能 fork 共享。要用 torch.multiprocessing.share_memory_() 让 cuda tensor 走 IPC。但 vLLM、SGLang 实测表明这套能把 8 worker × 70B 从 1120 GB 降到 140 GB。

3. layered loading

启动时只加载 embedding + 第一层,后续 layer lazy load(首请求时 page-in)。让首响应时间提前 5-10 秒,但代价是首请求慢。适合”长尾请求 + 低 QPS”场景。

4. tensor parallel pre-shard

ckpt 保存时按 TP 维度预先切好(如 8 个文件、每文件 1/8)。各 GPU 直接读自己负责的文件,跳过 master broadcast。生产部署的标准模式。

理解这套让你看 vLLM 启动几秒到几十秒不是恒定数字,是”工程取舍 × 硬件配置”的输出。每种优化都有 tradeoff:内存 vs 启动时间 vs 复杂度。

19.6.27 跨格式转换矩阵

PyTorch 生态有多种 ckpt 格式,转换关系:

源格式 → 目标工具备注
.pt (pickle) → .safetensorssafetensors安全 + 加载快,强烈推荐
.safetensors.pt (pickle)safetensors兼容老代码
.pt / .safetensors.pt2 (AOTI)torch._inductor.aoti_compile_and_package服务器推理
.pt / .safetensors.pte (ExecuTorch)executorch移动端 / 嵌入式
.safetensors → ONNXtransformers 内置跨框架
.safetensors → GGUFllama.cpp/convert.pyCPU / Mac 推理
.bin(HF legacy)→ .safetensorstransformersHF 现在默认输出后者
TF SavedModel → .safetensors自家工具偶尔用,质量参差

没有”通用万能格式”。每种服务一类场景:

graph TB
    Train[训练阶段] --> Pt[.pt / .safetensors]
    Pt -->|服务器推理| Aoti[.pt2 AOTI]
    Pt -->|移动端| Pte[.pte ExecuTorch]
    Pt -->|跨框架| Onnx[ONNX]
    Pt -->|CPU 推理| Gguf[GGUF]
    Pt -->|分发| Hub[HuggingFace Hub]

    style Pt fill:#fef3c7
    style Aoti fill:#dcfce7
    style Pte fill:#dbeafe
    style Gguf fill:#fce7f3

工程决策树:

  • 训练 → 还要继续训练:用 DCP,支持 reshard
  • 训练 → 单机推理 (GPU 服务器):safetensors + AOTI 编译
  • 训练 → 移动端:ExecuTorch
  • 训练 → 别的框架:ONNX(最后选择)
  • 训练 → 社区分发:safetensors + HF Hub
  • 训练 → CPU / Mac 端推理:转 GGUF

理解这个矩阵让你不会盲目”转 ONNX 部署”——很多场景 .pt2 / .pte 更优。每条转换链路都有 tradeoff,用最少的转换跳转达到目标平台是工程正解。

19.6.28 ckpt 加密:付费模型的工程做法

商业模型(如付费版本的微调模型)需要防止权重被直接复制使用。常见方案:

1. ckpt 整体加密

import cryptography
key = derive_key(license)         # 用 license 算出对称密钥
encrypted = AES.encrypt(open('model.safetensors', 'rb').read(), key)
open('model.enc', 'wb').write(encrypted)

# 解密 + 加载
data = AES.decrypt(open('model.enc', 'rb').read(), key)
state = safetensors.load(BytesIO(data))

代价:解密时整个模型在内存里,运行时机器仍能 dump RAM 拿到明文。

2. weight 分层加密

只加密关键 layer(如 attention),公开 backbone。攻击者拿到部分 weight 不够推理,提高破解成本。但密码学上仍是”obfuscation 而非真加密”。

3. TEE / SGX 内运行

模型 weight 在 trusted execution environment 内解密 + 推理,OS 都看不到明文。代价:只在特定硬件可用 + 性能损失大。

4. 服务化 (model as service)

最常用的”加密”是不分发 weight、只分发 API。客户付费调用 API、不能拿权重。OpenAI、Anthropic 都用这种。

PyTorch 序列化层不直接支持加密(设计哲学是”开放”),上层应用自己实现。safetensors 的 metadata 字段可以存 license 信息,但不强制 enforce。

理解这层让你看到”模型保护”是个产品 / 商业问题,不是技术问题 —— PyTorch 把权重存成可读字节是正确选择,分发策略另说。

19.6.29 ckpt 设计的演进时间线

PyTorch 序列化的几个关键节点:

版本改进意义
v0.4torch.save 用 zip 容器mmap 加载成为可能
v1.0API 稳定生产可用
v1.6new zip format(_use_new_zipfile_serialization)更可靠的 zip 实现
v1.10torch.load 支持 fsspec 路径云存储集成
v1.12DCP 引入(实验性)分布式 ckpt 雏形
v2.0DCP 稳定 + 内置千卡训练成为可能
v2.4async DCP 支持训练不被 ckpt 阻塞
v2.4DCP reshard 稳定加载时灵活切片
v2.6weights_only=True 默认RCE 防御默认开
v2.8safetensors 集成进 .pt2安全 + 性能默认
v2.10DCP storage backends 完善(S3/GCS native)云原生训练
v2.11API 稳定,生态成熟大模型训练标配

整体趋势:

  • v0.x-v1.x:从单文件 pickle 走向 zip 容器,让 mmap / 大文件可行
  • v1.x-v2.x:从单机走向分布式(DCP),让千卡训练 ckpt 可行
  • v2.x:从可用走向安全(weights_only)+ 性能(async)+ 云原生

理解这条演进让你看 PyTorch 团队对 ckpt 的持续投入 —— 每个 minor version 都有 ckpt 相关改进。这是因为大模型时代 ckpt 性能 = 训练成本 —— 每 100 step ckpt 一次,写 ckpt 慢 1 分钟 × 1000 次 step = 16 小时浪费。改进 ckpt 性能直接省钱。

19.6.30 一个 70B 训练 ckpt 配置参考

把全章的工程实践合起来,给一个 70B 训练 ckpt 的真实配置:

import torch
import torch.distributed.checkpoint as dcp
from datetime import timedelta

CKPT_INTERVAL = 200            # 每 200 step 保存一次
CKPT_RETAIN = 5                # 保留最近 5 个

CKPT_DIR = "s3://my-org-bucket/training/run_xxx/ckpts"

def save_checkpoint(model, optimizer, scheduler, step, sampler):
    state_dict = {
        'model': model.state_dict(),
        'optim': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'sampler': sampler.state_dict(),
        'step': step,
        'rng_states': {
            'cpu': torch.get_rng_state(),
            'cuda': torch.cuda.get_rng_state_all(),
        },
    }

    # async + DCP + 写 S3
    future = dcp.async_save(
        state_dict,
        checkpoint_id=f"{CKPT_DIR}/step_{step}",
    )
    return future

def load_latest_checkpoint(model, optimizer, ...):
    ckpts = list_s3_dir(CKPT_DIR)
    for ckpt_id in sorted(ckpts, key=lambda x: int(x.split('_')[-1]), reverse=True):
        try:
            state_dict = init_state_dict(model, optimizer, ...)
            dcp.load(state_dict, checkpoint_id=ckpt_id)
            validate_loaded(state_dict)        # checksum + 一个 dummy forward
            return state_dict
        except Exception as e:
            print(f"corrupted: {ckpt_id}, trying older: {e}")
    raise RuntimeError("no valid ckpt found")

def cleanup_old_checkpoints():
    ckpts = sorted(list_s3_dir(CKPT_DIR), key=lambda x: int(x.split('_')[-1]))
    for old in ckpts[:-CKPT_RETAIN]:
        delete_s3_dir(old)

主训练循环:

ckpt_future = None
for step, batch in enumerate(loader):
    if step % CKPT_INTERVAL == 0:
        if ckpt_future is not None:
            ckpt_future.result()              # 等上一个 ckpt 写完
        ckpt_future = save_checkpoint(model, optim, sched, step, sampler)
        cleanup_old_checkpoints()

    loss = train_step(batch)
    if torch.isnan(loss):
        # NaN 检测 + 自动 rollback
        state = load_latest_checkpoint(...)
        # 调小 lr 重试
        for g in optim.param_groups:
            g['lr'] *= 0.5

这套配置覆盖了:

  • DCP(千卡 reshard 友好)
  • async(不阻塞训练)
  • S3(云原生,无需本地落盘)
  • rolling retention(控制存储成本)
  • NaN guard(自动 rollback)
  • RNG 完整保存(resume 可复现)
  • checksum 验证(silent corruption 防御)

理解每条配置的 why,让你能为自家训练任务量身定制。这套配置在 HuggingFace Trainer / Megatron-LM / DeepSpeed 等主流训练框架的 ckpt 默认实现中都能找到对应实现,是公开社区在大模型训练中沉淀下来的工程实践。

19.6.31 自定义 DCP storage backend:写一个云存储后端

StorageWriter / StorageReader 是 DCP 的扩展点。第三方实现这两个抽象类即可让 DCP 写到任意存储。简化示例:

from torch.distributed.checkpoint.storage import StorageWriter, StorageReader

class MyS3Writer(StorageWriter):
    def __init__(self, bucket: str, prefix: str):
        self.bucket = bucket
        self.prefix = prefix
        self.s3 = boto3.client('s3')

    def reset(self, checkpoint_id):
        self.checkpoint_id = checkpoint_id

    def set_up_storage_writer(self, is_coordinator):
        # rank 0 (coordinator) 创建 prefix 目录
        if is_coordinator:
            self.s3.put_object(Bucket=self.bucket,
                              Key=f"{self.prefix}/{self.checkpoint_id}/")

    def write_data(self, plan, planner):
        # 把 plan 中所有 WriteItem 写到 S3
        for item in plan.items:
            data = planner.resolve_data(item)
            key = f"{self.prefix}/{self.checkpoint_id}/__{item.index.fqn}.bin"
            self.s3.put_object(Bucket=self.bucket, Key=key, Body=data)
        return Future()    # 完成 future

    def finish(self, metadata, results):
        # 写 metadata.json
        self.s3.put_object(
            Bucket=self.bucket,
            Key=f"{self.prefix}/{self.checkpoint_id}/metadata.json",
            Body=json.dumps(metadata).encode(),
        )

# 用法
dcp.save(state_dict,
         storage_writer=MyS3Writer('my-bucket', 'training/run_1'))

实战中 v2.10+ 的 PyTorch 已经内置 fsspec-based S3 支持,自家写 backend 仅在特殊需求时用:

  • 自家对象存储(如内部 ceph cluster)
  • 加密存储(写前 encrypt)
  • 跨区域复制(写多个 region 提高可用性)
  • 进度监控(每 chunk 写完发 metric 到监控系统)

理解 storage backend 的接口让你看 DCP 不是绑死本地文件系统的,是真正的 pluggable。生产场景如果对 ckpt 存储有特殊需求,扩展这层就够、不需要 fork DCP 自身。

19.7 几条工程经验

1. 默认 weights_only=True:从 v2.6 起这是默认。如果遇到加载报错说”unsupported global”,先想想是否真的需要那个类,能换掉就换

2. 分发模型给社区用 safetensors:兼容性、安全性、加载速度全面优于 pickle

3. 大模型加载默认 mmap=True:v2.4+ 已经在改善默认行为,但显式开总是稳

4. 训练 ckpt 推荐 DCP:单 rank 写大文件已经过时

5. torch.serialization.skip_data 上下文:load 时只想看元数据(如查 ckpt 含哪些 key)不想读数据,进这个上下文跳过实际字节加载

6. ckpt 校验:写完 ckpt 后立刻 load 一遍验证,避免训练几小时后才发现 ckpt 坏了。可以用单独 rank 异步验证

7. 跨 PyTorch 版本加载:state_dict 通常向后兼容。整对象 pickle 跨版本可能 deserialize 失败 —— 又一个用 state_dict 的理由

8. fsspec 路径torch.save(state, 's3://bucket/ckpt.pt') 现在直接支持云存储(v2.x+)。生产训练直接写 S3 / OSS 不需要本地落盘

19.8 跨书关联

  • 第 9 章 §9.6 nn.Module.state_dict:序列化的”上游”,本章是它的”下游”
  • 第 18 章 FSDP:FSDP 训练必用 DCP 否则 ckpt 不可行
  • 《vLLM 内核探秘》第 7 章 模型加载:vLLM 推理时大量用 safetensors + mmap 加载预训练模型

19.9 设计启示

序列化设计的核心思想:

第一容器格式 + 数据分离:zip + tensor entries 让 mmap、并行加载、增量更新成为可能。比单一 blob 序列化灵活得多

第二weights_only 把”加载”从代码执行降级到字节解析:这条原则在所有”加载第三方数据”场景都成立 —— JSON / Protobuf / safetensors 都比 pickle 安全。能不执行就不执行,是数据格式设计的底线

第三逻辑坐标而非物理位置:DCP 用逻辑张量坐标解耦 “保存时拓扑” 与 “加载时拓扑”,让 ckpt 可重 reshard。这条思想在分布式数据库分片管理也常见,是”让数据脱离硬件假设”的工程标准做法

第四state_dict 让模型架构与权重解耦:架构是 Python 代码(活的、可演进),权重是数据(死的、可移植)。这种”代码 + 数据”分离是任何”长寿模型”的工程基础

第五异步、并行、增量是大规模训练的三件套:async save 让训练不阻塞、DCP 让每 rank 并行写、PEFT/LoRA 让分发只发增量。三个机制叠加才让千卡训练 ckpt 从”工程难题”变成”标准操作”。这条”在多个维度同时下功夫”的设计思路,对任何重型 IO 系统的设计都成立 —— 不要指望一招制胜,要在每条流水线上都挖性能

下一章拆量化与混合精度训练 —— FSDP 已经给了精度策略的初步配置,这章把整个 PyTorch 量化生态摊开,看 autocast 的 CastPolicy 决策、GradScaler 的动态范围管理、PT2E 量化框架怎么把整个量化栈现代化。

评论 0