第19章 序列化:torch.save / torch.load 与权重格式
“Saving a tensor sounds trivial. Saving 70 billion of them across 32 GPUs without 1TB of RAM —— that’s another story.”
—— PyTorch DCP design doc
本章要点
torch.save默认用 pickle + zip 容器:每个 tensor 的 storage 单独存一个 zip entry,让 mmap 加载成为可能weights_only=True是新默认(v2.6+):只允许加载白名单类型,防止 pickle RCE 攻击safetensors是 HuggingFace 推动的替代格式:纯字节布局 + JSON header,零代码执行风险,加载比 pickle 快torch.save(model.state_dict())vstorch.save(model):前者只存权重(推荐),后者 pickle 整个对象(依赖类定义)- Distributed Checkpoint (DCP):千卡训练每 rank 只存自己那 1/N,避免 rank 0 写 1TB 单文件
- mmap 加载 (
mmap=True) 让超大 ckpt 加载几乎瞬间完成:OS 把文件映射到虚拟内存,按需 page-in
19.1 torch.save 的真实格式
打开一个 .pt 文件用 unzip -l 能看到:
Archive: model.pt
data.pkl <- pickle 序列化的 Python 对象图 (无 tensor 数据)
data/0 <- 第 1 个 tensor 的字节数据
data/1 <- 第 2 个 tensor
...
version <- 格式版本号
byteorder <- 字节序标记
torch.save 用的是 zip 容器 + 拆分 storage 的设计:
- 把对象树(dict / list / nn.Module)用 pickle 序列化,但每个 tensor 写到一个占位符而非数据
- 每个 storage 的实际字节数据单独存到 zip 里一个 entry(
data/0、data/1…) - 加载时先反序列化 pickle 拿对象骨架 + 占位符,再按需读 zip entry 填字节
MAGIC_NUMBER = 0x1950A86A20F9469CFC6C(serialization.py:65)是文件头标记。
19.1.1 为什么不直接 pickle 整个 dict
朴素 pickle:把所有 tensor 的字节当 bytes 序列化进 pickle 流。问题:
- pickle 流是顺序的,没法 mmap(要全部读到内存才能解析)
- 修改 / 添加单个 tensor 要重写整个文件
- 跨进程共享时序化开销大
zip 容器格式:
- 每个 storage 在 zip 里有独立 entry,支持 mmap
- zip 提供 random access:直接跳到第 N 个 storage
- 兼容现有工具(unzip / zipfile 都能查看)
代价是文件比纯 pickle 略大(zip header 开销)。但对 1GB+ 模型 ckpt 几乎不可见。
19.2 mmap 加载:超大 ckpt 几秒完成
torch.load(path, mmap=True) 让 OS 把整个 zip 文件映射到虚拟内存,实际不读到 RAM:
# 没有 mmap: 必须把整个文件读到内存
state = torch.load('llama-70b.pt') # 280 GB, 单机 RAM 不够
# 有 mmap: OS 按需 page-in
state = torch.load('llama-70b.pt', mmap=True) # 几乎秒返回
weights = state['model.layer1.weight'] # 用到时才真读
mmap 的几个优势:
- 多个进程加载同一个 ckpt 时共享 OS page cache,省一份内存
- 加载时间几乎为零(懒加载)
- 不连续访问的部分根本不读盘
vLLM / HuggingFace 加载大模型时大量用 mmap。生产代码里这是 ckpt 加载的标准操作。
19.3 weights_only:防止 pickle RCE
pickle 是 Python 的”任意对象序列化”格式,加载时会执行嵌入的代码。恶意 ckpt 可以 __reduce__ 触发 os.system("rm -rf /")。
PyTorch v2.6+ 默认 weights_only=True:只允许加载白名单类型(Tensor、Parameter、各种内置类型),其他全部拒绝。这让 ckpt 加载从”等同于 eval(file)”变成”等同于读 JSON”。
如果你的 ckpt 含自定义类(如自定义 Optimizer),要么改用白名单允许:
torch.serialization.add_safe_globals([MyOptimizer])
state = torch.load('ckpt.pt')
要么显式 weights_only=False 退回老行为(不推荐生产用)。
社区从 HuggingFace 的”safetensors 推动”开始一直在向”任何 ckpt 都不该 RCE”靠拢。weights_only 默认 True 是 PyTorch 跟进的安全升级。
19.4 safetensors:HuggingFace 推动的替代格式
safetensors 不是 PyTorch 内置(要 pip install safetensors),但已经成了 HuggingFace 模型分发的事实标准。文件格式:
[8 字节: header 长度]
[N 字节: JSON header — 描述每个 tensor 的 name / dtype / shape / 字节范围]
[剩余: 紧凑 tensor 字节数据,按 header 里描述的偏移排列]
关键设计:
- 零代码执行:纯字节 + JSON,加载就是”按 header 切字节”。安全 100%
- 比 pickle 快:没有反序列化开销,直接 mmap
- 零拷贝加载:tensor 直接 view 在 mmap 内存上
- 跨语言友好:Rust / JS / Go 都能读
PyTorch 端用 safetensors:
from safetensors.torch import save_file, load_file
save_file(model.state_dict(), 'model.safetensors')
loaded = load_file('model.safetensors')
HuggingFace Hub 上的现代模型几乎都用 safetensors(如 Llama / Qwen / DeepSeek 的官方权重)。.bin(pickle)格式仍然存在但被标记为”legacy”。
实测对 70B 模型:
- pickle (.bin) 加载:~120 秒
- safetensors 加载:~30 秒
- safetensors + mmap:< 1 秒
这是为什么 vLLM 默认要求 safetensors。
19.5 model.state_dict() vs torch.save(model)
# 推荐
torch.save(model.state_dict(), 'weights.pt')
# 不推荐
torch.save(model, 'whole_model.pt')
差异:
state_dict() | 整对象 pickle |
|---|---|
| 只存权重数据 | 存整个对象(含类定义引用) |
| 加载时要先 build model 类 | 加载时自动 reconstruct |
| 修改类定义不影响加载 | 改类定义可能加载失败 |
| 文件小 | 文件大(含 Python 类元数据) |
| 跨 PyTorch 版本兼容好 | 跨版本风险 |
第 9 章 §9.6.0.5 讨论过这个区别。生产代码强烈建议 state_dict。整对象 pickle 只在快速保存中间状态时偶尔用。
19.6 Distributed Checkpoint (DCP)
70B 模型 ckpt ≈ 280 GB(参数)+ 560 GB(optimizer state)+ … ≈ 1 TB。如果 rank 0 收集所有 rank 数据再写一个文件:
- rank 0 单机要 1 TB RAM 收集
- rank 0 串行写 1 TB 到磁盘几十分钟
- 加载时反过来 —— 又要几十分钟
DCP(torch.distributed.checkpoint)的解法:每个 rank 只写自己持有的数据,并行写到一个目录:
import torch.distributed.checkpoint as dcp
# 保存
state_dict = {'model': model.state_dict(), 'optim': optimizer.state_dict()}
dcp.save(state_dict, checkpoint_id='ckpt_step_1000')
# 加载 (跨任意 rank 数恢复)
dcp.load(state_dict, checkpoint_id='ckpt_step_1000')
DCP 输出的不是单文件,而是一个目录:
ckpt_step_1000/
__0_0.distcp <- rank 0 的数据
__1_0.distcp <- rank 1 的数据
...
metadata.json <- 描述每片数据的逻辑映射
每 rank 写自己的 __N_0.distcp 文件,完全并行。加载时 DCP 根据 metadata.json 把每片正确路由回需要它的 rank(即便加载时 rank 数与保存时不同 —— 自动 reshard)。
这套机制让大规模训练 ckpt 从”几十分钟串行写”变成”几分钟并行写”,是 70B+ 训练的工程基础。
19.6.1 DCP 的 reshard 能力
最强的特性:保存时 8 卡 ZeRO-3,加载时 16 卡 ZeRO-3,DCP 能自动重新切分。
实现机制:metadata.json 用 逻辑张量坐标(不是物理 rank)描述每片数据。加载时根据当前 rank 拓扑重新映射 —— 保存时 rank 0 持有 [0:1000] 切片,加载时如果 rank 0 现在应该持有 [0:500],DCP 自动切。
这种”加载时 reshard”让训练流水线灵活得多 —— 训练用 64 卡,fine-tuning 切到 8 卡只需重新加载,不需要离线 resharding。
19.6.5 storage 共享检测:避免存重复字节
PyTorch 张量经常共享 storage(第 2 章 §2.6 view 机制)。torch.save 不会傻傻地把每个 tensor 的字节都存一份 —— 它做了storage 共享检测:
a = torch.randn(10, 10)
b = a.view(100) # b 与 a 共享 storage
c = a[0:5] # c 也共享
state = {'a': a, 'b': b, 'c': c}
torch.save(state, 'shared.pt')
# 文件只含 1 份 storage 字节, 不是 3 份
实现机制是 pickle 的 persistent_id / persistent_load 钩子:
- save 时:每个 tensor 序列化前调
persistent_id(tensor)把它的 storage 用一个 ID 替代(同一 storage 返回同一 ID)。zip 里每个唯一 ID 对应一个data/Nentry - load 时:每个 ID 解析时调
persistent_load(id),从 zip 取出对应字节、构造 storage,再多个 tensor 共用这一份 storage
实测:3 个共享 storage 的 tensor save 后,文件大小 ≈ 1 份 tensor 大小(不是 3 份)。这套机制对 view-heavy 模型(如 transformer 的 QKV 投影)能省可观空间。
但有个反直觉的坑:view 张量也存了它的”完整 base storage”,不是它自己的子集。如果你 b = a[0:5]; del a; torch.save({'b': b}),文件大小是整个 a 的,不是 5 个元素。要省空间得 b = a[0:5].clone() 让 b 有独立 storage。
19.6.6 PyTorchFileReader / PyTorchFileWriter
torch._C.PyTorchFileReader / PyTorchFileWriter(C++ 在 caffe2/serialize/inline_container.cc)是 zip 容器的底层。它们不是用标准 zip 库,是 PyTorch 自家实现的精简版 —— 只支持 STORE(不压缩)模式,因为:
- tensor 字节本身已经接近随机,压缩率低(几乎 1.0)
- 压缩开销让 save / load 慢几倍
- 不压缩让 mmap 直接映射文件 = 内存(mmap 对压缩文件无效)
副作用:.pt 文件比同样数据的 .npz(用 zip DEFLATE 压缩)大约 1.0-1.05x,但加载快 3-5x。这是性能 vs 空间的明确选择。
PyTorchFileWriter 还支持 multi-stream write:多个 storage 字节并行写到 zip 里不同 entry。70B 模型 ckpt 用 4 个 IO thread 写比单线程快 2-3x。torch.save 默认开启这个优化。
19.6.7 DCP 的 SavePlan / LoadPlan
torch/distributed/checkpoint/planner.py:105/115 的 SavePlan / LoadPlan 是 DCP 的 IR:
@dataclass
class SavePlan:
items: list[WriteItem] # 每个 item 描述一段要写的数据
storage_data: Any # backend-specific 数据
planner_data: Any # planner 状态
@dataclass
class WriteItem:
index: MetadataIndex # 逻辑索引 (fqn + offsets)
type: WriteItemType # TENSOR / BYTE_IO
tensor_data: TensorWriteData | None
...
DCP save 流程:
- 每个 rank 调
planner.create_local_plan(state_dict)生成自己的 SavePlan - rank 0 收集所有 SavePlan、调
planner.create_global_plan(plans)做去重 + 优化 - broadcast 全局 plan 回所有 rank
- 每个 rank 按 plan 调
storage.write_data(plan)写自己负责的部分 - rank 0 写 metadata.json
LoadPlan 反向:每 rank 描述”我需要哪些 tensor”,DCP 根据 metadata.json 路由到正确的物理 chunk。
这套 plan-based 设计让 DCP 能换 storage 后端:默认 FileSystemWriter(写本地 / NFS),也可换 S3Writer(写云存储)、GCSWriter 等。第三方实现 StorageWriter / StorageReader 接口即可,业务层无感知。
FileSystemReader(filesystem.py:837)与 FileSystemWriter(:965)是默认实现 —— 写到目录里若干 .distcp 文件 + 一个 .metadata。
19.6.8 Resharding 算法:8 卡 ckpt 加载到 16 卡
§19.6.1 提到 DCP 能 reshard,具体怎么做?看 torch/distributed/checkpoint/_dedup_save_plans.py + _resharding.py。
场景:保存时 8 卡 FSDP,每 rank 持有 weight 的 1/8 切片;加载时 16 卡 FSDP,每 rank 应该持有 1/16 切片。算法:
graph TB
Save["保存时 (8 卡):<br/>rank 0: weight[0:N/8]<br/>rank 1: weight[N/8:2N/8]<br/>..."]
Save --> M["metadata.json:<br/>'weight' chunks 用逻辑坐标:<br/>chunk[0:N/8], chunk[N/8:2N/8], ..."]
M --> Load["加载时 (16 卡):<br/>每 rank 算自己需要 weight[i*N/16:(i+1)*N/16]"]
Load --> P{逻辑切片<br/>vs 物理 chunk 关系}
P --> Calc["rank 0 需要 weight[0:N/16]<br/>= 物理 chunk[0:N/8] 的前一半"]
Calc --> Read[读 rank 0 chunk 文件<br/>取前 N/16 字节]
style M fill:#fef3c7
style Calc fill:#dbeafe
具体步骤(load.py:LoadPlanner._populate_load_items):
- 解析 metadata.json 拿到所有物理 chunk 的逻辑坐标范围
- 对当前每 rank 的需求范围 R_i,找出它与哪些物理 chunk 有交集
- 为每个交集生成一个 LoadItem:从某物理 chunk 的某偏移读多少字节、写到本地 tensor 的某偏移
- 调度 LoadItem:尽量批量读同一文件、最小化 IO
特殊情况处理:
- 新 rank 数 > 旧:每物理 chunk 被多个 rank 共享读(多读不冲突)
- 新 rank 数 < 旧:每 rank 读多个物理 chunk 拼起来
- dim 改变:DCP 不支持自动改 sharding 维度(如从 dim 0 sharding 改成 dim 1),需要离线工具
实测:1024 卡训练保存 → 64 卡 fine-tune 加载,DCP 自动 reshard 用时 5-10 分钟(取决于网络)。无需”先 gather 到全量再切” —— 直接按 chunk 重新映射。
理解 reshard 让你看到 DCP 不是”分布式 save/load”那么简单,而是”逻辑坐标系统 + 物理 chunk 路由”的工程。这是大模型生命周期管理(pretrain → fine-tune → distill)的关键基础设施。
19.6.9 Async Checkpoint:训练不被 ckpt 阻塞
torch/distributed/checkpoint/state_dict_saver.py:async_save(v2.4+)让 ckpt 写入与训练并发:
import torch.distributed.checkpoint as dcp
# 同步版本: 训练等 ckpt 写完才继续
dcp.save(state_dict, ...)
# next training step starts...
# 异步版本: 训练立刻继续, 后台慢慢写
future = dcp.async_save(state_dict, ...)
# next training step starts immediately
# ...
future.result() # 偶尔等一次, 确认上次 ckpt 写完
实现机制:
graph LR
T[训练主线程] --> CP[snapshot 整个 state_dict<br/>1. 把 GPU tensor 异步 copy 到 CPU<br/>2. 用同一份 CPU 内存的 ref 创建 staging dict]
CP --> RT[训练继续<br/>下一 step]
CP --> BG[后台线程<br/>写 staging dict 到磁盘]
BG --> Done[完成 future]
style CP fill:#fef3c7
style BG fill:#dcfce7
关键点:snapshot 阶段必须同步(要确保后台写的是当前 step 状态、不是几 step 后的)。snapshot 用 cudaMemcpyAsync 把 GPU 数据拷到 CPU pinned memory —— 几百 ms。然后训练继续;后台线程慢慢把 CPU 数据写盘 —— 几十秒。
实战收益:每 100 step 写一次 ckpt,每次 ckpt 阻塞 60s → 训练吞吐损失 60/(100×step_time)。step_time = 1s 时 → 损失 60% 吞吐!开 async ckpt 后损失降到 1-2%。
代价:需要 2x CPU 内存(一份给当前训练 + 一份 staging)。70B 模型 ckpt 占 140 GB CPU 内存 —— 单机 256 GB RAM 通常够。如果 CPU 紧张,不能开 async。
理解 async ckpt 让你看 v2.x 大模型训练的真实工程取舍:宁愿多吃 CPU 内存也不让 ckpt 写阻塞训练,因为 GPU 时间贵得多。
19.6.10 PEFT / LoRA 的 ckpt 增量保存
LoRA / Adapter / Prefix-tuning 等参数高效微调(PEFT)的核心是”只训练少量参数”。ckpt 也只需保存这些参数:
# 全量 fine-tune: state_dict 含所有 7B 参数
torch.save(model.state_dict(), 'full_ft.pt') # ~14 GB
# LoRA: 只存 LoRA 模块的参数
lora_state_dict = {k: v for k, v in model.state_dict().items() if 'lora' in k}
torch.save(lora_state_dict, 'lora.pt') # ~50 MB
peft 库(HuggingFace)做了完整封装:model.save_pretrained() 自动只保存 LoRA 权重 + adapter_config.json(描述秩、target_modules 等)。
加载流程:
- 先加载 base model(如 Llama-7B 全量权重)
- 加载 adapter_config.json 知道在哪些 module 上加 LoRA
- 加载 lora.pt 把 LoRA 权重塞进对应 module
这种”base model + small adapter”模式的工程意义:
- 存储成本:1 个 base + 100 个 adapter(每个 50 MB)= 14 GB + 5 GB = 19 GB;vs 100 个全量 = 1.4 TB
- 分发:用户可单独下载 adapter,不重复下 base
- 服务:vLLM 等支持 base + 多 adapter 共存推理(Multi-LoRA Serving),只切换 adapter 不重 load base
对应到 PyTorch 序列化层面:state_dict 是 dict,支持任意子集 —— 这是 LoRA 这种增量保存的基础设施。load_state_dict(strict=False) 允许部分 key 不存在 → 只覆盖匹配的 key。这层设计让 PEFT 能优雅落地。
19.6.11 断点续训:随机状态完整保存
光保存 model + optim 还不够。完整可恢复 ckpt 必须包括:
ckpt = {
'model': model.state_dict(),
'optim': optimizer.state_dict(),
'scheduler': lr_scheduler.state_dict(),
'sampler': sampler.state_dict(), # §11.9.18
'epoch': epoch,
'step': step,
'rng_states': {
'cpu': torch.get_rng_state(),
'cuda': torch.cuda.get_rng_state_all(), # 所有 GPU 各自
'numpy': np.random.get_state(),
'python': random.getstate(),
},
'amp_scaler': scaler.state_dict() if scaler else None,
}
torch.save(ckpt, 'full_ckpt.pt')
为什么 RNG 也要存?两个场景:
- augment 复现:
RandomCrop用torch.rand决定 crop 位置,rng 不一致 → 同一 sample 生成不同 augmented view → 不可复现 - dropout / noise injection:reproducibility 要求
torch.set_rng_state 加载时恢复 CPU RNG;torch.cuda.set_rng_state_all 恢复每个 GPU 的 RNG。但 DataLoader worker 的 RNG 是另一套——每个 worker 启动时基于 base_seed + worker_id 计算,所以保存 sampler.state_dict 已经隐含了 worker RNG 的种子。
实战:HuggingFace Trainer / Accelerate 内置完整 RNG 保存。如果你自写训练循环,要么照抄上面的字典、要么直接用这些库。漏一个 RNG 类型就会导致”resume 后 loss 跳变”——很难诊断的 bug。
19.6.12 fsspec:直接读写云存储
PyTorch v2.x+ 集成了 fsspec(Python 文件系统抽象库),让 torch.save / torch.load 直接支持 S3 / GCS / Azure Blob / HDFS:
torch.save(state, 's3://my-bucket/ckpts/step_1000.pt')
state = torch.load('s3://my-bucket/ckpts/step_1000.pt')
底层调用 fsspec.open() 拿到一个 file-like 对象,PyTorch 不区分本地 / 云。
实战注意点:
- 认证:要先配好 boto3 / gcloud CLI 让 fsspec 用上 credentials
- 大文件:S3 multipart upload 自动启用(每段 100 MB),70B ckpt 上传几分钟(取决于带宽)
- list / glob 慢:S3 list 操作是 O(N),DCP load 时如果用 globbing 找文件会慢;DCP 直接读 metadata.json 避免这个
- eventually consistent:S3 写完不能立刻 list 出(几秒延迟)。生产代码 save 后等 5 秒再发”ckpt ready”通知
DCP 的 FileSystemWriter v2.6+ 支持 fsspec 路径:
dcp.save(state_dict, checkpoint_id='s3://bucket/ckpt/step_1000')
让 1024 卡训练每 rank 直接写 S3、不需要本地落盘 + 后台同步。云原生训练的标准模式。
19.6.13 GGUF / GGML:推理优化的 weight 格式
safetensors 是训练 / 通用分发场景,端侧 / CPU 推理有更专门的格式:GGUF(GGML 的下一代)。它专为 llama.cpp 等 CPU/Apple Silicon 推理框架设计:
| 维度 | safetensors | GGUF |
|---|---|---|
| 量化支持 | 不内置 | 内置(int4/int5/int8 等多种) |
| 元信息 | tensor 名 + shape | 加上 chat template、stop tokens、tokenizer 配置 |
| 加载方式 | mmap 友好 | mmap 友好 + 端侧推理优化 |
| 编辑器支持 | HuggingFace Hub | llama.cpp / Ollama 生态 |
| 适用场景 | 训练 / 服务器推理 | CPU / 端侧推理 |
PyTorch 这层不直接支持 GGUF,但社区有 transformers → GGUF 转换工具:
python convert.py --input model.safetensors --output model.gguf --quantize q4_k_m
转换后用 llama.cpp 加载、在 CPU / Mac M-series 上跑推理。这条路径让 PyTorch 训练的模型能落地到端侧 —— 训练用 PyTorch、部署给社区用 GGUF。
理解 GGUF 在 PyTorch 生态外的存在让你看到 model 序列化是个多格式共存的世界。每种格式服务一类场景:safetensors 通用、.pt2 (AOTI) 服务器推理、.pte (ExecuTorch) 移动端、GGUF 端侧 CPU。PyTorch 自身只覆盖前几个,社区补足后两个。
19.6.14 HuggingFace from_pretrained:业界实际加载流程
PyTorch 用户极少手写 torch.load,更常见的是:
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3-8B")
这一行的内部流程:
graph TB
Call[from_pretrained] --> DL[1. 从 Hub 下载文件<br/>config.json / model.safetensors / tokenizer]
DL --> Cache[本地 cache: ~/.cache/huggingface/hub/]
Cache --> Cfg[2. 解析 config.json<br/>找到 model class + dtype + ...]
Cfg --> Build[3. 实例化空 model<br/>只 build 网络结构, 权重 init 为 meta tensor]
Build --> Sharded{4. 单文件还是 sharded?}
Sharded -->|单文件| Single[加载 model.safetensors]
Sharded -->|多文件| Multi[加载 model-00001-of-00007.safetensors<br/>+ model.safetensors.index.json]
Single --> Load[5. load_state_dict to model<br/>用 mmap, 一次填一个 tensor]
Multi --> Load
Load --> Done[ready model]
style Build fill:#fef3c7
style Load fill:#dcfce7
几个工程精巧点:
- meta tensor build:第 3 步用
torch.device('meta')创建模型骨架,无显存分配。第 5 步把 mmap 字节直接 cast 到 GPU,跳过 “CPU full copy → GPU copy” 双重内存压力 - sharded 文件命名:
model-00001-of-00007.safetensors是事实标准,每个 5-20 GB(safetensors 单文件 size 上限相关) - index.json:含
{"weight_map": {"model.layers.0.weight": "model-00001-of-00007.safetensors"}}的反查表 —— 加载某个权重时知道去哪个文件读 - streaming load:边下载边加载(v0.20+ 的 transformers),适合大文件 first-time use
理解这套流程让你看 vLLM / SGLang / 自家推理服务的”模型加载几分钟”开销主要在哪 —— 一是网络下载、二是 GPU memory 分配、三是 cuda 上下文 init。真正的 deserialization 反而最快(mmap + zero-copy)。
19.6.15 sharded safetensors:单文件大小限制怎么处理
safetensors 单文件理论无限制,但实际 HuggingFace Hub 限制 每文件 ≤ 50 GB(防止下载超时)。70B model = 140 GB → 必须切到 3+ 个文件。
切分逻辑(transformers 的 save_sharded 函数):
- 按 fqn 顺序遍历所有 tensor
- 累计 size > shard_size_limit (默认 5GB) 时切到下一个文件
- 每个 tensor 不跨文件(一个 tensor 必须完整在一个文件里)
- 写完所有文件后生成
model.safetensors.index.json
注意点:
- shard 大小:太小(< 1GB)→ 文件多、下载并发开销;太大(> 50GB)→ Hub 不接受。5-10 GB 最优
- 顺序敏感:fqn 顺序变化会让 sharding 结果不一致 → ckpt 不可复现。HF 用
sorted(state_dict.keys())保证一致 - 加载时按 index 找:每加载一个 weight,从 index 反查在哪个文件、再从该文件 mmap → 一次只读必要的字节
实战:HF Hub 的 Llama-70B 分成 8 个文件(17 GB 每个)。如果用户带宽 100Mbps、并发 4 → 总下载时间 = 17×8 / (100/8 × 4) ≈ 4 分钟。这就是 from_pretrained 等几分钟的根本原因。
19.6.16 跨设备加载:map_location 与 device_map
ckpt 保存时 tensor 在 GPU,加载时机器没那么多 GPU 怎么办?map_location 解决这个:
# 把所有 tensor 加载到 CPU
state = torch.load('ckpt.pt', map_location='cpu')
# 把所有 tensor 加载到 cuda:0 (即便保存时在 cuda:7)
state = torch.load('ckpt.pt', map_location='cuda:0')
# 自定义函数: 按 tensor name 决定去哪
def loader(storage, loc):
return storage.cuda(0)
state = torch.load('ckpt.pt', map_location=loader)
底层:torch.load 在恢复 storage 时调 map_location 决定目标 device。
device_map 是 HuggingFace transformers 的高级版本:
model = AutoModelForCausalLM.from_pretrained(
"Llama-70B",
device_map="auto", # 自动分配 layer 到多 GPU
)
device_map="auto" 算法:
- 估算每 layer 占多少 memory
- 按层数把 layer 平均分到所有 GPU
- 加载 weight 时直接 map 到目标 GPU(不经过 CPU)
实战:8 卡机器加载 Llama-70B(140 GB),device_map=“auto” 让每卡 ~17.5 GB。配合 offload_folder 还能让多余 layer 落到磁盘 → 单卡推理 7B、双卡 70B、单卡 + offload 70B 都能跑。
理解这层让你看到”内存不够也能 load 模型”不是魔法,是 PyTorch + HF 联合做的精细 device 管理。
19.6.17 ckpt 完整性校验
长时间训练 ckpt 损坏会让”resume 后 loss 直接 NaN”。生产建议加 checksum:
import hashlib
def save_with_checksum(state_dict, path):
torch.save(state_dict, path)
with open(path, 'rb') as f:
sha256 = hashlib.sha256(f.read()).hexdigest()
with open(path + '.sha256', 'w') as f:
f.write(sha256)
def load_with_checksum(path):
with open(path, 'rb') as f:
actual = hashlib.sha256(f.read()).hexdigest()
with open(path + '.sha256') as f:
expected = f.read().strip()
if actual != expected:
raise ValueError(f"checksum mismatch: {path}")
return torch.load(path)
DCP 的 metadata.json 内置每 chunk 的 sha256 hash(v2.6+),加载时自动校验。这是大模型训练防”silent corruption”的工程标准。
corruption 来源:
- 磁盘 bit flip:宇宙射线 / 硬件老化、概率约 1e-12 per byte,1 TB ckpt 平均每 1000 次出 1 bit 错
- 网络传输错误:TCP checksum 不能保证 100%,云存储传输偶发 corruption
- 进程 crash 中断写入:写一半 SIGKILL、ckpt 不完整
每种都让 ckpt 看起来”加载正常但 resume 后 loss 异常”。checksum 让这种 silent failure 立刻暴露。生产大模型训练必须有 checksum,否则浪费几小时训练在坏 ckpt 上。
19.6.18 序列化的内存峰值
很多人以为 torch.save(state_dict, path) 内存占用 = state_dict size。实际峰值更高:
graph LR
SD[state_dict<br/>1x size] --> Pic[pickle 序列化<br/>+ 一份临时副本<br/>2x size]
Pic --> Zip[zip 容器<br/>3x size 临时]
Zip --> File[最终文件<br/>~1x size]
style Pic fill:#fef3c7
style Zip fill:#fee2e2
70B 模型 state_dict 占 140 GB GPU + 140 GB CPU;save 时峰值可能 280-420 GB CPU 内存(看 PyTorch 版本)。机器要预留 3x ckpt size 的 RAM 才安全。
PyTorch v2.4+ 加了 _disable_tensor_pickling,让 tensor 直接以 zip entry 写入、不走 pickle 临时副本 → 峰值降到 1.2x ckpt size。新代码默认启用。
DCP 的 async save 让峰值进一步降低:snapshot 时 staging 占 1x,写盘后释放,且每 rank 只要 1/N。
如果你遇到”save ckpt 时进程 OOM”,先检查 PyTorch 版本是否 >= 2.4(自动避免临时副本)。还不行就 chunked save:把 state_dict 切成几份分别 save,最后合并。
19.6.19 Optimizer state:被忽视的 ckpt 大头
很多人保存 ckpt 只想到 model.state_dict,但Optimizer state 占用通常是 model 的 2-3 倍。
Adam 优化器每个参数维护:
exp_avg(一阶矩):与 weight 同 shape、同 dtype(fp32)exp_avg_sq(二阶矩):与 weight 同 shape、同 dtype(fp32)step(当前步数):标量
70B 模型 weight in bf16 = 140 GB;Adam state in fp32 = 70B × 4 字节 × 2 = 560 GB。Adam state 占 ckpt 80%。
# state_dict 的真实结构 (Adam)
{
'state': {
0: {'exp_avg': tensor[1024,1024], 'exp_avg_sq': tensor[1024,1024], 'step': 1000},
1: {'exp_avg': ..., 'exp_avg_sq': ..., 'step': 1000},
... # 每个参数对应一个 dict
},
'param_groups': [...] # 优化器配置 (lr, betas, weight_decay 等)
}
param_groups 用整数索引引用 state 中的参数。这种”间接寻址”让 optimizer 与 model 解耦,但加载时必须保证 model 参数顺序与保存时一致——否则 state[0] 错配到 weight[1],训练立刻崩盘。
DCP 也支持优化器 state ckpt 与 reshard。FSDP-2 的 optimizer state 按 rank sharded → DCP 保存每 rank 自己那 1/N → reload 时按当前 rank 数重新 reshard。这是大模型训练能从 1024 卡 fine-tune 到 64 卡推理的根本前提。
8-bit / FP8 Optimizer(如 bitsandbytes 的 8-bit Adam):把 state 量化到 8-bit 存储,state size 降到 1/4。生产代码越来越常见 —— 把 560 GB Adam state 压到 140 GB,单机训练成本立刻可承受。
19.6.20 大文件支持:ZIP64 与 stream
标准 zip 单文件最大 4 GB(header offset 是 32-bit)。70B model 单文件 280 GB → 远超限制。PyTorch 用 ZIP64 扩展(PyTorchFileWriter 自动启用):
- 文件大小用 64-bit 表示
- entry offset 也用 64-bit
- 外层结构兼容标准 zip(小文件仍可被 unzip 工具打开)
实战:280 GB .pt 文件能正常 save / load,无需用户手动切分。但有些工具不支持 ZIP64:
unzip -l:会报”file too large”错(GNU unzip 默认不带 ZIP64)zipinfo:同样限制bsdtar(macOS / FreeBSD):支持 ZIP64
诊断 .pt 文件时用 Python 的 zipfile 模块(强制 ZIP64 兼容):
import zipfile
with zipfile.ZipFile('model.pt') as z:
print(z.namelist()) # 列文件
print([info.file_size for info in z.infolist()])
这层细节不常用,但调 ckpt 问题时能避免”为什么 unzip 看不了”的困惑。生产代码强烈建议直接用 DCP(多文件,每文件 < 5GB),避开 ZIP64 的工具兼容问题。
19.6.21 自定义 reduce:让自家类能被 save
如果你的 state_dict 含自定义类(如 LR scheduler 的 metadata),默认 pickle 可能失败。解法:实现 __reduce__:
class MyScheduler:
def __init__(self, base_lr, warmup_steps):
self.base_lr = base_lr
self.warmup_steps = warmup_steps
self._cache = ... # 不想 pickle 的私有状态
def __reduce__(self):
# pickle 时只保存关键参数, 加载时重建
return (MyScheduler, (self.base_lr, self.warmup_steps))
__reduce__ 返回 (callable, args)。pickle 调用时调 callable(*args) 重建对象。这种”自定义重建”让你完全控制序列化字段。
但 weights_only=True(v2.6+ 默认)下 自定义 __reduce__ 不被信任 —— 出于安全考虑只允许白名单类型。要让自定义类支持加载:
import torch.serialization
torch.serialization.add_safe_globals([MyScheduler])
# 现在可以加载
state = torch.load('ckpt.pt')
或者用 safetensors 替代 torch.save —— 它根本不支持 Python 对象(只支持 tensor 字典),这种”局限性”反而让安全成为默认。
实战建议:state_dict 应该只含 tensor 与原生 Python 类型(int/float/str/list/dict),不含自定义类。如果非要保存复杂对象(如 scheduler 的 internal state),用 JSON 序列化简单 metadata、复杂逻辑在加载侧重建。这是从 pickle 安全 + 跨版本兼容角度都最好的实践。
19.6.22 ckpt 版本管理:HuggingFace Hub 的 git-like 模型
大模型项目都需要”版本管理”——v1.0 的 ckpt、v1.1 的 ckpt、experimental 分支的 ckpt 共存。HuggingFace Hub 用 git LFS 提供 git-like 版本管理:
# Hub 上一个 model repo 实际是 git repo
git clone https://huggingface.co/meta-llama/Llama-3-8B
ls -la
# .git/ <- git 历史
# config.json <- 元信息
# model.safetensors <- 权重 (LFS-tracked, 大文件不在 git history)
# tokenizer.json
# README.md
每次模型更新 = git commit + LFS push。用户可以:
git checkout v1.0:拿历史版本git diff main..v2.0 config.json:看不同版本配置变化git log --oneline:看模型更新历史
对应到 PyTorch 序列化层面:ckpt 不再是单文件,而是一组协同的文件(safetensors + config + tokenizer + chat template + …)。HF 把这套打包成”model card”概念,让模型分发等价于 git repo 分发。
实战影响:
- 不要把
from_pretrained("path/to/local")当成”加载一个文件”,它实际加载一个目录 + 解析 config + … - ckpt 兼容性:升级 transformers 库时旧 ckpt 可能加载失败(因为 config schema 变了)
- 自家训练的 ckpt 上传 Hub 的工程:用
model.push_to_hub()自动打包 + LFS push
理解这套生态让你看到”序列化”在 LLM 时代不只是”save tensor”,是”模型生命周期管理”的入口。PyTorch 自身只覆盖 tensor 字节,HF 生态补足版本管理 + 元信息 + 分发协议。
19.6.23 ckpt 内容 inspector
调试 ckpt 问题时常需”打开看里面有啥”,但又不想真加载几百 GB。常用工具:
# 1. 看 zip 结构 (不读 tensor 数据)
import zipfile
with zipfile.ZipFile('model.pt') as z:
for name in z.namelist():
print(name)
# 2. 看 metadata 不加载 tensor (skip_data 上下文)
with torch.serialization.skip_data():
state = torch.load('model.pt') # tensor 是 placeholder, 不读字节
print(state.keys()) # 看 fqn 列表
print({k: v.shape for k, v in state.items()})
# 3. safetensors header 直接解析
from safetensors import safe_open
with safe_open('model.safetensors', framework='pt') as f:
print(f.metadata()) # JSON header
print(list(f.keys())) # tensor 名列表
# 4. DCP 的 metadata.json 直接看
import json
with open('ckpt_dir/metadata.json') as f:
meta = json.load(f)
print(meta['storage_metadata']) # 每 chunk 的逻辑坐标
实战场景:
- 比较两个 ckpt 是否一致:第 2 种方式拿 fqn list + shape 对比
- 找某个 weight 在哪个 shard 文件:safetensors 的
weight_map(model.safetensors.index.json) - 查 ckpt 训练步数:
state['step']或state['epoch'] - 诊断 reshard 失败:看 DCP metadata.json 的 chunk 坐标是否符合预期
把这些命令做成 inspect_ckpt.py 脚本,是大模型训练运维必备的工具。
19.6.24 ONNX 互操作:跨框架部署的中间格式
ONNX(Open Neural Network Exchange)是跨框架部署的标准 IR。PyTorch 的 ONNX 导出:
import torch.onnx
torch.onnx.export(
model,
args=(example_input,),
f="model.onnx",
input_names=['input'],
output_names=['output'],
dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}},
opset_version=18,
)
torch.onnx 内部有两条路径:
torch.onnx.dynamo_export(v2.0+):用 Dynamo 抓 fx graph 转 ONNX。能处理动态控制流,是默认推荐torch.onnx.export(v1.x 老版):用torch.jit.trace抓静态图。仍可用但被标记为 legacy
ONNX 的价值:
- 跨框架推理:训练用 PyTorch,推理用 TensorRT / OpenVINO / Core ML / DirectML
- 嵌入式 / 边缘:Coral / Jetson 等支持 ONNX runtime
- 优化器生态:ONNX Runtime 自带 graph optimization、量化 toolkit
代价:
- 算子覆盖度:PyTorch 几千个 op,ONNX 标准化几百个。冷门 op 导出失败
- dynamic shape 支持有限:导出时 dynamic_axes 必须显式声明,否则 shape 写死
- trace 不支持复杂控制流:含
if x.sum() > 0的 model 用 trace 导出会丢分支
实战建议:首选 .pt2 (AOTI) 服务器推理、ExecuTorch 移动端、ONNX 仅作为”必须用 TensorRT/OpenVINO” 的桥梁。PyTorch 团队近年把投入从 ONNX 转向自家 export 路径(torch.export)—— 让 PyTorch 模型在 PyTorch 生态内闭环。
19.6.25 ckpt rollback:训练失败的恢复策略
长时间训练偶尔出现 NaN loss / divergence / OOM 等问题,需要 rollback 到更早 ckpt。生产标准做法:
graph LR
Step1[step 1000<br/>ckpt A] --> Step2[step 2000<br/>ckpt B]
Step2 --> Step3[step 3000<br/>ckpt C]
Step3 --> Step4[step 4000<br/>NaN loss!]
Step4 -.rollback.-> Step3
Step3 --> Step5[step 4000 retry<br/>不同 lr / batch]
style Step4 fill:#fee2e2
style Step5 fill:#dcfce7
工程做法:
- 保留多个最近 ckpt:rolling window,如最近 5 个 step 的 ckpt
- NaN detection:每 step 后检查
loss.isnan(),立刻 abort - 自动 rollback:检测到 NaN → 自动 load 上一个 ckpt → 调小 lr 重试 → 继续
LLM training infra 都内置这套。Megatron-LM、HuggingFace Trainer 都有 --save_total_limit=5(保留 5 个)+ --resume_from_checkpoint=latest(自动找最新)+ NaN guard。
对应到 PyTorch 序列化层面:ckpt 路径要含 step 号方便排序:
ckpt_step_1000/
ckpt_step_2000/
ckpt_step_3000/
ckpt_step_4000/ <- corrupted, skip
rollback 脚本逻辑:
import os, glob
ckpts = sorted(glob.glob('ckpt_step_*'),
key=lambda x: int(x.split('_')[-1]))
for c in reversed(ckpts):
if validate(c): # checksum + 加载 + 跑一个 batch 不出 NaN
return c
print(f"corrupted: {c}, trying older")
这种”层层回滚找最近完好版本”是大模型训练的运维标准动作。理解这层让你看到 ckpt 不只是”备份”,是 training infra 的核心 control plane。
19.6.26 推理服务的 ckpt 加载优化
vLLM / SGLang 等推理引擎的”启动慢”问题常被诟病。70B 模型加载 5-10 分钟。生产优化方向:
1. 模型预热池
服务集群保持 N 个 worker 进程,每个预先加载好 model。新请求来时直接路由到 worker,不重新 load。代价是常驻 RAM 高(每 worker 一份 weight)。
2. 共享内存加载
多个 worker 进程共享同一份 weight:master 进程 mmap 一次,子进程 fork 后共享 page。Linux CoW 让”读”操作不复制内存,N 个 worker 仅占 1 份 weight。
# master
weight_mmap = torch.load('model.pt', mmap=True)
# fork 子进程
for _ in range(num_workers):
pid = os.fork()
if pid == 0:
# 子进程: weight_mmap 与 master 共享
serve(weight_mmap)
CUDA tensor 复杂些 —— GPU memory 不能 fork 共享。要用 torch.multiprocessing.share_memory_() 让 cuda tensor 走 IPC。但 vLLM、SGLang 实测表明这套能把 8 worker × 70B 从 1120 GB 降到 140 GB。
3. layered loading
启动时只加载 embedding + 第一层,后续 layer lazy load(首请求时 page-in)。让首响应时间提前 5-10 秒,但代价是首请求慢。适合”长尾请求 + 低 QPS”场景。
4. tensor parallel pre-shard
ckpt 保存时按 TP 维度预先切好(如 8 个文件、每文件 1/8)。各 GPU 直接读自己负责的文件,跳过 master broadcast。生产部署的标准模式。
理解这套让你看 vLLM 启动几秒到几十秒不是恒定数字,是”工程取舍 × 硬件配置”的输出。每种优化都有 tradeoff:内存 vs 启动时间 vs 复杂度。
19.6.27 跨格式转换矩阵
PyTorch 生态有多种 ckpt 格式,转换关系:
| 源格式 → 目标 | 工具 | 备注 |
|---|---|---|
.pt (pickle) → .safetensors | safetensors 库 | 安全 + 加载快,强烈推荐 |
.safetensors → .pt (pickle) | safetensors 库 | 兼容老代码 |
.pt / .safetensors → .pt2 (AOTI) | torch._inductor.aoti_compile_and_package | 服务器推理 |
.pt / .safetensors → .pte (ExecuTorch) | executorch 库 | 移动端 / 嵌入式 |
.safetensors → ONNX | transformers 内置 | 跨框架 |
.safetensors → GGUF | llama.cpp/convert.py | CPU / Mac 推理 |
.bin(HF legacy)→ .safetensors | transformers 库 | HF 现在默认输出后者 |
TF SavedModel → .safetensors | 自家工具 | 偶尔用,质量参差 |
没有”通用万能格式”。每种服务一类场景:
graph TB
Train[训练阶段] --> Pt[.pt / .safetensors]
Pt -->|服务器推理| Aoti[.pt2 AOTI]
Pt -->|移动端| Pte[.pte ExecuTorch]
Pt -->|跨框架| Onnx[ONNX]
Pt -->|CPU 推理| Gguf[GGUF]
Pt -->|分发| Hub[HuggingFace Hub]
style Pt fill:#fef3c7
style Aoti fill:#dcfce7
style Pte fill:#dbeafe
style Gguf fill:#fce7f3
工程决策树:
- 训练 → 还要继续训练:用 DCP,支持 reshard
- 训练 → 单机推理 (GPU 服务器):safetensors + AOTI 编译
- 训练 → 移动端:ExecuTorch
- 训练 → 别的框架:ONNX(最后选择)
- 训练 → 社区分发:safetensors + HF Hub
- 训练 → CPU / Mac 端推理:转 GGUF
理解这个矩阵让你不会盲目”转 ONNX 部署”——很多场景 .pt2 / .pte 更优。每条转换链路都有 tradeoff,用最少的转换跳转达到目标平台是工程正解。
19.6.28 ckpt 加密:付费模型的工程做法
商业模型(如付费版本的微调模型)需要防止权重被直接复制使用。常见方案:
1. ckpt 整体加密
import cryptography
key = derive_key(license) # 用 license 算出对称密钥
encrypted = AES.encrypt(open('model.safetensors', 'rb').read(), key)
open('model.enc', 'wb').write(encrypted)
# 解密 + 加载
data = AES.decrypt(open('model.enc', 'rb').read(), key)
state = safetensors.load(BytesIO(data))
代价:解密时整个模型在内存里,运行时机器仍能 dump RAM 拿到明文。
2. weight 分层加密
只加密关键 layer(如 attention),公开 backbone。攻击者拿到部分 weight 不够推理,提高破解成本。但密码学上仍是”obfuscation 而非真加密”。
3. TEE / SGX 内运行
模型 weight 在 trusted execution environment 内解密 + 推理,OS 都看不到明文。代价:只在特定硬件可用 + 性能损失大。
4. 服务化 (model as service)
最常用的”加密”是不分发 weight、只分发 API。客户付费调用 API、不能拿权重。OpenAI、Anthropic 都用这种。
PyTorch 序列化层不直接支持加密(设计哲学是”开放”),上层应用自己实现。safetensors 的 metadata 字段可以存 license 信息,但不强制 enforce。
理解这层让你看到”模型保护”是个产品 / 商业问题,不是技术问题 —— PyTorch 把权重存成可读字节是正确选择,分发策略另说。
19.6.29 ckpt 设计的演进时间线
PyTorch 序列化的几个关键节点:
| 版本 | 改进 | 意义 |
|---|---|---|
| v0.4 | torch.save 用 zip 容器 | mmap 加载成为可能 |
| v1.0 | API 稳定 | 生产可用 |
| v1.6 | new zip format(_use_new_zipfile_serialization) | 更可靠的 zip 实现 |
| v1.10 | torch.load 支持 fsspec 路径 | 云存储集成 |
| v1.12 | DCP 引入(实验性) | 分布式 ckpt 雏形 |
| v2.0 | DCP 稳定 + 内置 | 千卡训练成为可能 |
| v2.4 | async DCP 支持 | 训练不被 ckpt 阻塞 |
| v2.4 | DCP reshard 稳定 | 加载时灵活切片 |
| v2.6 | weights_only=True 默认 | RCE 防御默认开 |
| v2.8 | safetensors 集成进 .pt2 | 安全 + 性能默认 |
| v2.10 | DCP storage backends 完善(S3/GCS native) | 云原生训练 |
| v2.11 | API 稳定,生态成熟 | 大模型训练标配 |
整体趋势:
- v0.x-v1.x:从单文件 pickle 走向 zip 容器,让 mmap / 大文件可行
- v1.x-v2.x:从单机走向分布式(DCP),让千卡训练 ckpt 可行
- v2.x:从可用走向安全(weights_only)+ 性能(async)+ 云原生
理解这条演进让你看 PyTorch 团队对 ckpt 的持续投入 —— 每个 minor version 都有 ckpt 相关改进。这是因为大模型时代 ckpt 性能 = 训练成本 —— 每 100 step ckpt 一次,写 ckpt 慢 1 分钟 × 1000 次 step = 16 小时浪费。改进 ckpt 性能直接省钱。
19.6.30 一个 70B 训练 ckpt 配置参考
把全章的工程实践合起来,给一个 70B 训练 ckpt 的真实配置:
import torch
import torch.distributed.checkpoint as dcp
from datetime import timedelta
CKPT_INTERVAL = 200 # 每 200 step 保存一次
CKPT_RETAIN = 5 # 保留最近 5 个
CKPT_DIR = "s3://my-org-bucket/training/run_xxx/ckpts"
def save_checkpoint(model, optimizer, scheduler, step, sampler):
state_dict = {
'model': model.state_dict(),
'optim': optimizer.state_dict(),
'scheduler': scheduler.state_dict(),
'sampler': sampler.state_dict(),
'step': step,
'rng_states': {
'cpu': torch.get_rng_state(),
'cuda': torch.cuda.get_rng_state_all(),
},
}
# async + DCP + 写 S3
future = dcp.async_save(
state_dict,
checkpoint_id=f"{CKPT_DIR}/step_{step}",
)
return future
def load_latest_checkpoint(model, optimizer, ...):
ckpts = list_s3_dir(CKPT_DIR)
for ckpt_id in sorted(ckpts, key=lambda x: int(x.split('_')[-1]), reverse=True):
try:
state_dict = init_state_dict(model, optimizer, ...)
dcp.load(state_dict, checkpoint_id=ckpt_id)
validate_loaded(state_dict) # checksum + 一个 dummy forward
return state_dict
except Exception as e:
print(f"corrupted: {ckpt_id}, trying older: {e}")
raise RuntimeError("no valid ckpt found")
def cleanup_old_checkpoints():
ckpts = sorted(list_s3_dir(CKPT_DIR), key=lambda x: int(x.split('_')[-1]))
for old in ckpts[:-CKPT_RETAIN]:
delete_s3_dir(old)
主训练循环:
ckpt_future = None
for step, batch in enumerate(loader):
if step % CKPT_INTERVAL == 0:
if ckpt_future is not None:
ckpt_future.result() # 等上一个 ckpt 写完
ckpt_future = save_checkpoint(model, optim, sched, step, sampler)
cleanup_old_checkpoints()
loss = train_step(batch)
if torch.isnan(loss):
# NaN 检测 + 自动 rollback
state = load_latest_checkpoint(...)
# 调小 lr 重试
for g in optim.param_groups:
g['lr'] *= 0.5
这套配置覆盖了:
- DCP(千卡 reshard 友好)
- async(不阻塞训练)
- S3(云原生,无需本地落盘)
- rolling retention(控制存储成本)
- NaN guard(自动 rollback)
- RNG 完整保存(resume 可复现)
- checksum 验证(silent corruption 防御)
理解每条配置的 why,让你能为自家训练任务量身定制。这套配置在 HuggingFace Trainer / Megatron-LM / DeepSpeed 等主流训练框架的 ckpt 默认实现中都能找到对应实现,是公开社区在大模型训练中沉淀下来的工程实践。
19.6.31 自定义 DCP storage backend:写一个云存储后端
StorageWriter / StorageReader 是 DCP 的扩展点。第三方实现这两个抽象类即可让 DCP 写到任意存储。简化示例:
from torch.distributed.checkpoint.storage import StorageWriter, StorageReader
class MyS3Writer(StorageWriter):
def __init__(self, bucket: str, prefix: str):
self.bucket = bucket
self.prefix = prefix
self.s3 = boto3.client('s3')
def reset(self, checkpoint_id):
self.checkpoint_id = checkpoint_id
def set_up_storage_writer(self, is_coordinator):
# rank 0 (coordinator) 创建 prefix 目录
if is_coordinator:
self.s3.put_object(Bucket=self.bucket,
Key=f"{self.prefix}/{self.checkpoint_id}/")
def write_data(self, plan, planner):
# 把 plan 中所有 WriteItem 写到 S3
for item in plan.items:
data = planner.resolve_data(item)
key = f"{self.prefix}/{self.checkpoint_id}/__{item.index.fqn}.bin"
self.s3.put_object(Bucket=self.bucket, Key=key, Body=data)
return Future() # 完成 future
def finish(self, metadata, results):
# 写 metadata.json
self.s3.put_object(
Bucket=self.bucket,
Key=f"{self.prefix}/{self.checkpoint_id}/metadata.json",
Body=json.dumps(metadata).encode(),
)
# 用法
dcp.save(state_dict,
storage_writer=MyS3Writer('my-bucket', 'training/run_1'))
实战中 v2.10+ 的 PyTorch 已经内置 fsspec-based S3 支持,自家写 backend 仅在特殊需求时用:
- 自家对象存储(如内部 ceph cluster)
- 加密存储(写前 encrypt)
- 跨区域复制(写多个 region 提高可用性)
- 进度监控(每 chunk 写完发 metric 到监控系统)
理解 storage backend 的接口让你看 DCP 不是绑死本地文件系统的,是真正的 pluggable。生产场景如果对 ckpt 存储有特殊需求,扩展这层就够、不需要 fork DCP 自身。
19.7 几条工程经验
1. 默认 weights_only=True:从 v2.6 起这是默认。如果遇到加载报错说”unsupported global”,先想想是否真的需要那个类,能换掉就换
2. 分发模型给社区用 safetensors:兼容性、安全性、加载速度全面优于 pickle
3. 大模型加载默认 mmap=True:v2.4+ 已经在改善默认行为,但显式开总是稳
4. 训练 ckpt 推荐 DCP:单 rank 写大文件已经过时
5. torch.serialization.skip_data 上下文:load 时只想看元数据(如查 ckpt 含哪些 key)不想读数据,进这个上下文跳过实际字节加载
6. ckpt 校验:写完 ckpt 后立刻 load 一遍验证,避免训练几小时后才发现 ckpt 坏了。可以用单独 rank 异步验证
7. 跨 PyTorch 版本加载:state_dict 通常向后兼容。整对象 pickle 跨版本可能 deserialize 失败 —— 又一个用 state_dict 的理由
8. fsspec 路径:torch.save(state, 's3://bucket/ckpt.pt') 现在直接支持云存储(v2.x+)。生产训练直接写 S3 / OSS 不需要本地落盘
19.8 跨书关联
- 第 9 章 §9.6 nn.Module.state_dict:序列化的”上游”,本章是它的”下游”
- 第 18 章 FSDP:FSDP 训练必用 DCP 否则 ckpt 不可行
- 《vLLM 内核探秘》第 7 章 模型加载:vLLM 推理时大量用 safetensors + mmap 加载预训练模型
19.9 设计启示
序列化设计的核心思想:
第一:容器格式 + 数据分离:zip + tensor entries 让 mmap、并行加载、增量更新成为可能。比单一 blob 序列化灵活得多
第二:weights_only 把”加载”从代码执行降级到字节解析:这条原则在所有”加载第三方数据”场景都成立 —— JSON / Protobuf / safetensors 都比 pickle 安全。能不执行就不执行,是数据格式设计的底线
第三:逻辑坐标而非物理位置:DCP 用逻辑张量坐标解耦 “保存时拓扑” 与 “加载时拓扑”,让 ckpt 可重 reshard。这条思想在分布式数据库分片管理也常见,是”让数据脱离硬件假设”的工程标准做法
第四:state_dict 让模型架构与权重解耦:架构是 Python 代码(活的、可演进),权重是数据(死的、可移植)。这种”代码 + 数据”分离是任何”长寿模型”的工程基础
第五:异步、并行、增量是大规模训练的三件套:async save 让训练不阻塞、DCP 让每 rank 并行写、PEFT/LoRA 让分发只发增量。三个机制叠加才让千卡训练 ckpt 从”工程难题”变成”标准操作”。这条”在多个维度同时下功夫”的设计思路,对任何重型 IO 系统的设计都成立 —— 不要指望一招制胜,要在每条流水线上都挖性能
下一章拆量化与混合精度训练 —— FSDP 已经给了精度策略的初步配置,这章把整个 PyTorch 量化生态摊开,看 autocast 的 CastPolicy 决策、GradScaler 的动态范围管理、PT2E 量化框架怎么把整个量化栈现代化。
评论 0
还没有评论,来说两句吧。
评论加载失败,刷新重试。