第9章 nn.Module:参数注册、Hooks、序列化
“nn.Module looks magical because it does so many things in
__setattr__. Once you understand that one method, the whole system becomes obvious.”—— PyTorch core team note
本章要点
nn.Module.__setattr__拦截赋值:检测 value 类型决定放到_parameters/_buffers/_modules中的哪一个。这是”self.linear = nn.Linear(...)自动注册”的全部魔法__getattr__是补救机制:当用户访问的属性不在__dict__时,从三个 dict 里查找。linear.weight实际是从linear._parameters['weight']取出state_dict()递归收集:先调本模块的_save_to_state_dict,再递归_modules.items()。返回的字典 key 是 dotted path(如'layer1.linear.weight')load_state_dict严格匹配 + 名字模糊容忍:strict=True默认要求 key 完全匹配,但提供详细的 missing/unexpected 报告供用户排查- 8 种 hook 协同工作:forward_pre / forward / backward / full_backward / full_backward_pre / state_dict_pre / state_dict / load_state_dict_pre。让用户能在 Module 生命周期的任何点插入自定义逻辑
__call__ = _wrapped_call_impl:module(x)不等于module.forward(x),它包裹了 hooks 调用、autograd 反向 hook 注入、错误处理
9.1 一段神奇的代码
每个 PyTorch 用户都写过这样的模型:
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(10, 20)
self.linear2 = nn.Linear(20, 30)
self.bn = nn.BatchNorm1d(20)
self.dropout = nn.Dropout(0.5)
def forward(self, x):
x = self.linear1(x)
x = self.bn(x)
x = self.linear2(x)
return x
model = MyModel()
print(list(model.parameters())) # 自动列出 linear1.weight, linear1.bias, linear2.weight, ..., bn.weight, bn.bias
print(model.state_dict().keys()) # 自动收集 'linear1.weight', 'bn.running_mean', ...
torch.save(model.state_dict(), 'ckpt.pt') # 序列化
魔法到处都是:
self.linear1 = nn.Linear(...)怎么自动加入model.parameters()?bn.running_mean(不是 Parameter)怎么也进了 state_dict?state_dict()的 key'linear1.weight'是怎么拼出来的?model(x)调forward(x)还是别的?
这些问题的答案都在 torch/nn/modules/module.py(3054 行)。本章拆它的设计。
9.1.1 nn.Module 的设计目标
Module 不只是”参数容器” —— 它是 PyTorch 模型的统一抽象,要解决几个问题:
- 递归组合:模型由层组成,层又由更小层组成。需要无限递归
- 参数自动管理:用户写
self.linear = nn.Linear(...)自动让 optimizer 看到所有权重 - 状态保存与加载:训练 ckpt、模型部署都需要”把所有可变状态存到磁盘”
- 跨 device 迁移:
.cuda()/.cpu()/.half()要递归整个网络 - 训练 / 推理模式切换:BN / Dropout 等需要在 train / eval 时不同行为
- 可扩展性:第三方框架(HuggingFace、PyTorch Lightning、DDP 等)能 hook 进 Module 生命周期
这些目标共同塑造了 Module 的内部结构。理解每个设计选择背后的目标,能让你看 nn.Module 源码不再迷失在 3054 行细节里。
9.2 __init__:建立四个核心容器
打开 torch/nn/modules/module.py:483:
class Module:
def __init__(self, *args: Any, **kwargs: Any) -> None:
...
super().__setattr__("_parameters", {})
super().__setattr__("_buffers", {})
super().__setattr__("_non_persistent_buffers_set", set())
super().__setattr__("_backward_pre_hooks", OrderedDict())
super().__setattr__("_backward_hooks", OrderedDict())
super().__setattr__("_is_full_backward_hook", None)
super().__setattr__("_forward_hooks", OrderedDict())
super().__setattr__("_forward_hooks_with_kwargs", OrderedDict())
super().__setattr__("_forward_hooks_always_called", OrderedDict())
super().__setattr__("_forward_pre_hooks", OrderedDict())
super().__setattr__("_forward_pre_hooks_with_kwargs", OrderedDict())
super().__setattr__("_state_dict_hooks", OrderedDict())
super().__setattr__("_state_dict_pre_hooks", OrderedDict())
super().__setattr__("_load_state_dict_pre_hooks", OrderedDict())
super().__setattr__("_load_state_dict_post_hooks", OrderedDict())
super().__setattr__("_modules", {})
注意每行都用 super().__setattr__ 而不是 self.xxx = ... —— 这是为了绕过自己重写的 __setattr__(直接给 __dict__ 赋值)。否则会触发无限递归。
四个核心容器:
| 字段 | 类型 | 用途 |
|---|---|---|
_parameters | dict[str, Parameter] | 可学习参数(weight、bias 等) |
_buffers | dict[str, Tensor] | 不可学习但需保存的状态(如 BN 的 running_mean) |
_non_persistent_buffers_set | set[str] | 不存进 state_dict 的 buffer 名(如 dropout 的随机数生成器) |
_modules | dict[str, Module] | 子模块(自动递归整个网络) |
剩下的 12 个字段都是 hook 容器(OrderedDict 保证调用顺序)。Module 的所有”高级特性”都建立在这 16 个字段之上。
为什么用 OrderedDict 而不是 list?因为 hook 注册返回 RemovableHandle,handle 持有一个 id,用户调 handle.remove() 时通过 id 从 dict 删除。dict 的随机访问 O(1) 比 list 的线性查找方便。OrderedDict 保留插入顺序让多个 hook 按注册顺序调用 —— 这是用户期望的语义。
至于 _parameters / _buffers / _modules 用普通 dict 而不是 OrderedDict:v3.7+ 普通 dict 已经保证插入顺序,不需要显式 OrderedDict。但 PyTorch 在保留某些 OrderedDict 是为了与老 Python 兼容,加上代码风格惯性。
graph TB
M[nn.Module 实例]
M --> P["_parameters dict<br/>weight / bias / ..."]
M --> B["_buffers dict<br/>running_mean / num_batches_tracked"]
M --> Mods["_modules dict<br/>子模块递归"]
M --> H["8 种 _xxx_hooks<br/>前/后/状态/反向"]
Mods --> Sub1[子 Module 1]
Mods --> Sub2[子 Module 2]
Sub1 --> SubP1[子 Module 1 的 _parameters]
Sub1 --> SubB1[子 Module 1 的 _buffers]
Sub1 --> SubM1[子 Module 1 的 _modules]
style P fill:#dbeafe,stroke:#3b82f6
style B fill:#fef3c7,stroke:#f59e0b
style Mods fill:#dcfce7,stroke:#22c55e
9.3 __setattr__:拦截赋值的核心魔法
flowchart TD
Set["self.x = value"]
Set --> Q1{value 是 Parameter?}
Q1 -->|是| P1["注册到 _parameters dict<br/>从 _modules / _buffers 清理"]
Q1 -->|否| Q2{value 是 Module?}
Q2 -->|是| P2["注册到 _modules dict<br/>建立子模块树"]
Q2 -->|否| Q3{value 是 Buffer?}
Q3 -->|是| P3["注册到 _buffers dict<br/>跟随 .cuda 等迁移"]
Q3 -->|否| Q4{name 在某个 dict 里?}
Q4 -->|是| P4["更新对应 dict<br/>类型必须匹配"]
Q4 -->|否| P5["走 object.__setattr__<br/>普通 Python 属性"]
style P1 fill:#fef3c7
style P2 fill:#dcfce7
style P3 fill:#dbeafe
style P5 fill:#f3e8ff
module.py:1972 的 __setattr__ 是 nn.Module 最重要的方法,简化版:
def __setattr__(self, name: str, value):
params = self.__dict__.get("_parameters")
if isinstance(value, Parameter):
# 1. value 是 Parameter → 注册到 _parameters
remove_from(self.__dict__, self._buffers, self._modules, ...)
self.register_parameter(name, value)
elif params is not None and name in params:
# 2. 已存在的 param 被赋 None → 删除
if value is not None:
raise TypeError("cannot assign ... as parameter ... ")
self.register_parameter(name, value)
else:
modules = self.__dict__.get("_modules")
if isinstance(value, Module):
# 3. value 是 Module → 注册到 _modules
remove_from(self.__dict__, self._parameters, self._buffers, ...)
modules[name] = value
elif modules is not None and name in modules:
# 4. 已存在的子模块被赋值
modules[name] = value
else:
buffers = self.__dict__.get("_buffers")
if isinstance(value, Buffer) or (buffers is not None and name in buffers):
# 5. value 是 Buffer 或已存在的 buffer 名 → 注册到 _buffers
...
else:
# 6. 普通属性 → 走默认 __setattr__
object.__setattr__(self, name, value)
——这就是”self.linear1 = nn.Linear(...) 自动注册”的真相。Python 的 __setattr__ 协议被拦截,根据 value 类型路由到不同容器。
9.3.0.5 为什么要拦截赋值:备选方案的对比
PyTorch 的__setattr__ 拦截不是唯一方案。其他框架处理”自动注册参数”的不同思路:
- Keras(早期):要求用户在
build()方法里显式调self.add_weight(...)。冗长但显式 - JAX/Flax:用 dataclass 字段声明,外部 framework 通过 reflection 收集字段。需要
flax.linen.Module装饰器 - TensorFlow:用
tf.Variable在某个tf.Module子类上自动追踪 - PaddlePaddle:与 PyTorch 类似,
__setattr__拦截
PyTorch 选择 __setattr__ 拦截的好处是 用户写代码与普通 Python 完全一致 —— 不需要装饰器、不需要 build 方法、不需要 add_weight 调用。代价是 nn.Module 是个”魔法重”的类,新人学起来需要一段时间适应它的行为不是普通 Python 类。这种”直观但魔法”vs”显式但冗长”是框架设计的永恒权衡。
9.3.1 一个具体追踪:self.linear = nn.Linear(10, 20)
在 MyModel.__init__ 里写 self.linear = nn.Linear(10, 20) 时:
- Python 解释器调
MyModel.__setattr__('linear', nn.Linear(10, 20)) MyModel.__setattr__实际是继承的Module.__setattr__- 检查 value 类型:
isinstance(nn.Linear(...), Parameter)→ False,isinstance(nn.Linear(...), Module)→ True - 走分支 3:
self._modules['linear'] = nn.Linear(...) - 注意
self.linear这个名字根本没进self.__dict__
这就是为什么后面 model.linear 还能取到值 —— 通过 __getattr__ 的兜底机制。
注意 __setattr__ 里的 remove_from(...) 调用 —— 它从其他容器里删除同名条目。这是为了避免”同一个名字同时在 _parameters 和 _modules 里”的不一致状态。如果你先 self.x = nn.Parameter(...) 再 self.x = nn.Linear(...),Parameter 会从 _parameters 删除、Linear 进入 _modules。这种”覆盖式赋值”的语义和普通 Python 类的 self.x = ... 完全一致,用户感受不到区别。
9.3.1.5 一个反直觉的细节:覆盖式赋值如何工作
class M(nn.Module):
def __init__(self):
super().__init__()
self.x = nn.Linear(10, 20) # 进 _modules
self.x = nn.Parameter(torch.randn(5)) # 此时 _modules 里的旧值被删, 进 _parameters
第二次赋值时 __setattr__ 检测 value 是 Parameter,先调 remove_from(self.__dict__, self._buffers, self._modules, ...) 删除其他容器里的同名旧值,再 register_parameter。这套”先清后加”语义让覆盖式赋值表现得和普通 Python 对象一致 —— 用户感受不到内部容器切换。
9.3.2 为什么 Parameter 要单独继承
torch.nn.Parameter 是 torch.Tensor 的子类,但本身不加任何新功能:
class Parameter(torch.Tensor):
pass # 就是个空壳子类
它存在的唯一意义是 当作 isinstance 检查的标记。让 __setattr__ 能区分”这是要训练的张量”和”这是普通中间张量”。如果你写:
self.weight = torch.randn(10, 20, requires_grad=True) # ❌ 不会进 _parameters
self.weight = nn.Parameter(torch.randn(10, 20)) # ✅ 进 _parameters
第一种写法 __setattr__ 检测不到 Parameter 类型,直接进 __dict__,不会被 model.parameters() 列出,optimizer 找不到它。这是新手常踩的坑。
9.4 __getattr__:从三个 dict 兜底
module.py:1955 的 __getattr__ 处理”读”路径:
def __getattr__(self, name: str):
if "_parameters" in self.__dict__:
_parameters = self.__dict__["_parameters"]
if name in _parameters:
return _parameters[name]
if "_buffers" in self.__dict__:
_buffers = self.__dict__["_buffers"]
if name in _buffers:
return _buffers[name]
if "_modules" in self.__dict__:
modules = self.__dict__["_modules"]
if name in modules:
return modules[name]
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
Python 的 __getattr__ 只在标准属性查找失败时才被调。所以普通属性走 __dict__ 查找正常路径,特殊属性(参数/buffer/子模块)走兜底。
linear.weight 的查找流程:
- Python 先查
linear.__dict__['weight']—— 没有(因为__setattr__没把它放这) - 查
type(linear).__dict__['weight']—— 没有(不是类属性) - 走
__getattr__('weight')—— 在_parameters['weight']里找到,返回
这种”读 / 写不对称”是 Python 元编程的经典套路:写靠 __setattr__ 路由到正确容器,读靠 __getattr__ 从容器兜底取出。用户全程无感。
9.4.1 性能取舍:每次属性访问都过 __getattr__ 慢吗
新手经常担心:每次 self.linear 都过 __getattr__,会不会比普通属性访问慢?
实测一下:CPython 普通属性访问约 30 ns,过 __getattr__ 约 80 ns。差距 50 ns。对每秒上百万次属性访问的代码(如递归遍历模块树)累积可观,但对单次 forward 调用(几毫秒到几百毫秒)几乎不可见。
PyTorch 的 trick 是 forward 里访问的子模块已经被局部变量缓存:
def forward(self, x):
x = self.linear1(x) # 第一次访问 linear1, 走 __getattr__
x = self.linear2(x) # 第二次访问 linear2, 也走 __getattr__
return x
每次只查一次。如果你写 for _ in range(1000): x = self.linear(x),每次循环都过 __getattr__ —— 这种代码在 PyTorch 里少见但确实有性能损失。torch.compile 在编译时会把这种重复 attribute access 特殊化到具体的子模块对象,消除 __getattr__ 开销。第 12 章 Dynamo 章会展开。
9.5 __call__:模型调用的真实入口
很多人以为 model(x) 等价于 model.forward(x)。错。看 module.py:1918:
__call__: Callable[..., Any] = _wrapped_call_impl
_wrapped_call_impl 不直接调 forward,它包裹了一层逻辑:
def _wrapped_call_impl(self, *args, **kwargs):
# 1. 如果有 forward_pre_hooks, 跑
# 2. 如果有 backward_hooks, 设置 BackwardHook 包装输入
# 3. 调 forward
# 4. 如果有 forward_hooks, 跑
# 5. 如果有 non_full_backward_hooks, 给输出的 grad_fn 注册 hook
# 6. 错误处理 + always_called hooks
return result
所以 model(x) ≠ model.forward(x)。两者的差别在于前者跑 hooks,后者不跑。这就是为什么生产代码要用 model(x) 而不是 model.forward(x) —— hooks 是 PyTorch 各种工具(profiler、Hook-based pruning、量化等)的扩展点,绕过它们会导致这些工具失效。
9.5.1 _compiled_call_impl 与 torch.compile 的协作
module.py 还有一个重要变量 _compiled_call_impl。当用户对 module 调 torch.compile(model) 时,PyTorch 把编译后的 forward 实现挂到这里,__call__ 优先调用 _compiled_call_impl 而不是普通 forward。
这种”插槽式”设计让 torch.compile 可以透明加速 module,用户代码不用改:
model = MyModel()
model = torch.compile(model) # 替换内部 _compiled_call_impl
out = model(x) # 走编译后的 forward
第 15 章 torch.compile 端到端会展开这套协作。注意 __getstate__ 在序列化时会剔除 _compiled_call_impl(因为它依赖运行时编译产物,不能跨进程传递)。这种”序列化时剥离运行时态”是 Module 兼容多种使用场景的细节。
9.5.2 forward 是 Python 函数还是 Module 方法
很多人没意识到:forward 是普通 Python 方法,不是任何特殊声明。Module 的 __call__ 通过 forward_call = self.forward 拿到这个方法,调用时传递参数。这意味着:
- 你可以在 forward 里写任何 Python 代码(if/for/while/try)
- forward 可以接受任何参数(位置 / 关键字 / 默认值)
- forward 可以返回任何东西(tensor、tuple、dict、list、None)
这种自由度是 PyTorch 动态图的体现。代价是:forward 必须在每次 module(x) 都被解释执行(除非用 torch.compile),无法像静态图框架那样跨调用复用编译产物。
子类可以选择重写 __call__ 而非 forward(如 PyTorch Lightning 的 LightningModule)—— 但绝大多数代码遵循 “只重写 forward” 的惯例,让 hook 等机制不被破坏。
9.5.3 reentrant 的 forward 调用
如果在 forward 里再调一次 self(x)(递归调用同一个 module),会怎样?
class RecursiveModule(nn.Module):
def forward(self, x, depth):
if depth == 0:
return x
return self(x, depth - 1) # 递归调用自己
这是合法的,会跑两次 hooks(外层和内层各一次)。但 backward 时反向图会反映出递归结构 —— 同一个 grad_fn 实例可能在反向链上出现多次。这种”算子复用”在某些 RNN 实现里出现,PyTorch 完全支持。
sequenceDiagram
autonumber
participant U as user: model(x)
participant W as _wrapped_call_impl
participant Pre as forward_pre_hooks
participant F as forward(x)
participant BH as BackwardHook
participant Post as forward_hooks
U->>W: __call__(x)
W->>Pre: 跑所有 pre-hook
Pre-->>W: 可能修改 args
W->>BH: setup_input_hook (注册反向 hook)
W->>F: forward(args)
F-->>W: result
W->>BH: setup_output_hook
W->>Post: 跑所有 post-hook
Post-->>W: 可能修改 result
W-->>U: result
9.6 state_dict:递归收集状态
module.py:2195 的 state_dict 是模型保存的核心。简化逻辑:
def state_dict(self, *, destination=None, prefix="", keep_vars=False):
if destination is None:
destination = OrderedDict()
destination._metadata = OrderedDict()
local_metadata = dict(version=self._version)
destination._metadata[prefix[:-1]] = local_metadata
for hook in self._state_dict_pre_hooks.values():
hook(self, prefix, keep_vars)
self._save_to_state_dict(destination, prefix, keep_vars)
# 递归子模块
for name, module in self._modules.items():
if module is not None:
module.state_dict(
destination=destination,
prefix=prefix + name + ".",
keep_vars=keep_vars,
)
return destination
_save_to_state_dict 把当前 module 的 _parameters 和 persistent _buffers 写到 destination:
def _save_to_state_dict(self, destination, prefix, keep_vars):
for name, param in self._parameters.items():
if param is not None:
destination[prefix + name] = param if keep_vars else param.detach()
for name, buf in self._buffers.items():
if buf is not None and name not in self._non_persistent_buffers_set:
destination[prefix + name] = buf if keep_vars else buf.detach()
注意 递归 + prefix 是 dotted path 的来源:
- 调用
model.state_dict(),prefix ="" - 进
model.linear1.state_dict(prefix="linear1.") - 在 linear1 里
_save_to_state_dict把_parameters['weight']存到destination['linear1.weight'] - 类似产生
'linear1.bias'、'linear2.weight'、'bn.running_mean'…
这种”前缀 + 递归”的设计让任意深度的 Module 树扁平化为一个 dict[str, Tensor],序列化非常友好。
值得注意一个工程细节:state_dict 返回的是 OrderedDict + _metadata 字段。_metadata 是个隐藏字段(用 _metadata 而不是 metadata,避免冲突),存每个模块的”版本号”。_load_from_state_dict 在加载时会检查 version,调用对应的迁移逻辑 —— 让 PyTorch 能在某个 module 类升级 API 时仍然加载旧 ckpt。这套版本号机制大多数用户感受不到,但是 PyTorch 跨版本 ckpt 兼容的关键。
9.5.4 forward 与 no_grad 的细微交互
如果你在 forward 里写 with torch.no_grad(): 包住一段代码,会怎样?
def forward(self, x):
with torch.no_grad():
prepared = self.preprocess(x) # 这段不参与反向
return self.main(prepared)
注意 prepared 没有 grad_fn —— 它从 no_grad 块出来。self.main(prepared) 反向时反向链就在 prepared 这里断了:main 的反向能算出 grad 给 prepared,但 prepared.grad_fn 是 None,反向终止,不会传给 preprocess 的参数。
这是 PyTorch 给”前处理 / 特征提取不参与训练”的标准模式。Vision Transformer 的”frozen ViT backbone + 训练 head”等冻结预训练模型场景就靠这个。比手动设置 param.requires_grad = False 更直接、也更难误用。
9.6.0.5 一个常见误解:state_dict 是不是模型本身
新手以为 torch.save(model.state_dict()) 等价于”保存了整个模型”。错:state_dict 只保存权重和 buffer 值,不保存模型类定义。要重建模型必须先 model = MyModel(...) 用 Python 构造,再 model.load_state_dict(...) 灌入权重。
如果想”完整保存模型对象”,要 torch.save(model)(不传 state_dict)—— 这会用 pickle 序列化整个 model 对象。但这种方式有个大坑:pickle 依赖类定义。如果你的 MyModel 类定义改了(哪怕只是改了 method 实现),加载 pickle 时会报错或行为不一致。
所以生产代码强烈建议用 state_dict 而非整对象 pickle:让模型架构(Python 代码)与权重(state_dict)完全解耦,各自独立演进。HuggingFace、PyTorch Lightning 等所有正经框架都遵循这个原则。
9.6.1 keep_vars=True 的妙用
默认 keep_vars=False:state_dict 里的张量都被 .detach() —— 与原模型断开 autograd 关系。这是为了避免序列化时把整个反向图也存进去。
但训练用的”checkpoint averaging”或者 “Polyak averaging” 等技术需要 state_dict 仍然连着 autograd 图:
# averaged model 的 weights 仍然是 differentiable
avg_dict = {k: v for k, v in model.state_dict(keep_vars=True).items()}
这是为研究场景留的逃生口。普通用户用默认 detach 就够。
9.6.2 _non_persistent_buffers_set 的设计
不是所有 buffer 都该进 state_dict。比如 BatchNorm 的 num_batches_tracked 是必要的(要存),但某些用户自己加的”调试用 running stats”可能不该存。Module.register_buffer(name, tensor, persistent=False) 让用户标记某 buffer 不进 state_dict —— 名字加入 _non_persistent_buffers_set,_save_to_state_dict 时跳过。
这种”buffer 是否持久化”的细分在 v1.x 之后才出现,是社区反馈”我有些 buffer 不想存”后加的功能。
9.6.3 自定义 state_dict 输出格式
某些场景需要修改 state_dict 的 key(如要转成与 HuggingFace 兼容的命名)。可以注册 _state_dict_hook:
def rename_keys(module, state_dict, prefix, local_metadata):
# 把 module.weight 改成 module.W
if prefix + 'weight' in state_dict:
state_dict[prefix + 'W'] = state_dict.pop(prefix + 'weight')
model._register_state_dict_hook(rename_keys)
Hugging Face Transformers 库的 from_pretrained / save_pretrained 内部就有大量这类 hook,把 PyTorch 原始命名转换成 HuggingFace 标准命名(如 bert.encoder.layer.0.attention.self.query ↔ HF 风格)。这是 ecosystem 互操作性的工程基础。
9.7 load_state_dict:参数加载与名字匹配
module.py:2531 的 load_state_dict 反向操作。它做:
- 跑
_load_state_dict_pre_hooks - 递归遍历模块树,每层 module 的
_load_from_state_dict把 state_dict 里对应 prefix 的 key 拷到自己的_parameters/_buffers - 跑
_load_state_dict_post_hooks - 收集 missing_keys / unexpected_keys / error_msgs,根据
strict参数决定 raise 还是 warn
最常见的”加载报错”场景:训练时模型结构变了,旧 ckpt 的 key 与新模型对不上。strict=False 让加载继续但记录差异,方便迁移学习场景:
# 加载预训练 ResNet, 替换最后的 fc 层
model = MyResNet(num_classes=10)
state_dict = torch.load('resnet50.pth')
missing, unexpected = model.load_state_dict(state_dict, strict=False)
# 通常 unexpected = ['fc.weight', 'fc.bias'] (旧 fc,被新 num_classes 改了)
第 19 章序列化章会展开 torch.save / torch.load 与 state_dict 的协作,包括 safetensors 格式的优势。
9.7.0.5 missing_keys 与 unexpected_keys 的诊断价值
load_state_dict 返回的 missing_keys / unexpected_keys 是诊断”模型 vs ckpt 不匹配”的金矿:
- missing:当前模型有但 ckpt 没有的 key —— 通常是模型新加了层,ckpt 是老版本
- unexpected:ckpt 有但当前模型没有的 key —— 通常是模型删了层,或 ckpt 多保存了不必要的 buffer
调试技巧:
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print('Missing keys:', missing[:5])
print('Unexpected keys:', unexpected[:5])
print(f'Match rate: {1 - (len(missing) + len(unexpected)) / len(state_dict):.1%}')
匹配率 > 95% 通常说明 ckpt 大体兼容,少量差异可能是 BN 的 num_batches_tracked 或 dropout 的 random state 这种次要 buffer。匹配率 < 80% 通常意味着模型架构有重大变化,加载后行为可能不正常。
9.7.1 加载流程的细节
load_state_dict 实际工作流:
flowchart TB
Start[load_state_dict 入口]
Start --> Pre[跑 _load_state_dict_pre_hooks]
Pre --> Load[递归 _load:<br/>对每个 module 调 _load_from_state_dict]
Load --> Each["_load_from_state_dict 做:<br/>1. 跑 _load_state_dict_pre_hooks (per-module)<br/>2. 把 state_dict 里 prefix+name 的 tensor 拷到 _parameters/_buffers<br/>3. 处理 missing/unexpected keys<br/>4. 跑 _load_state_dict_post_hooks (per-module)"]
Each --> Post[跑全局 _load_state_dict_post_hooks]
Post --> Check{strict=True?}
Check -->|是, 有 missing/unexpected| Err[raise RuntimeError]
Check -->|否| Return[返回 missing_keys, unexpected_keys]
style Each fill:#fef3c7,stroke:#f59e0b
注意 数据拷贝是 inplace:state_dict 里的 tensor 用 param.copy_(loaded_tensor) 拷到现有的 Parameter,不是创建新对象。这样原模型的 Parameter 引用保持有效(optimizer 的 param_groups 不会失效),但权重值已经更新。这是为什么 load 后不需要重新 build optimizer。
这种 “in-place loading” 设计有个隐形约束:state_dict 里的 tensor 必须形状与现有 Parameter 完全一致。形状不一致会报错。如果你想换形状(如改 num_classes),要先重新构造 model 或者用 strict=False 跳过那一项。
9.8 8 种 hook:模块生命周期的扩展点
nn.Module 提供 8 种 hook,按调用时机分类:
| Hook | 注册方法 | 触发时机 |
|---|---|---|
forward_pre_hook | register_forward_pre_hook | forward 调用前,可改 args |
forward_hook | register_forward_hook | forward 调用后,可改 result |
full_backward_pre_hook | register_full_backward_pre_hook | 反向开始前 |
full_backward_hook | register_full_backward_hook | 反向结束后,能拿 grad_input / grad_output |
backward_hook (deprecated) | register_backward_hook | 早期版本,与 full_backward 不兼容 |
state_dict_pre_hook | _register_state_dict_hook | state_dict 调用前 |
state_dict_hook | _register_state_dict_hook | state_dict 调用后,可改返回值 |
load_state_dict_pre/post_hook | _register_load_state_dict_hook | load_state_dict 前后 |
每种 hook 都是为特定扩展场景设计的:
- profiler / 性能分析:用 forward_pre/post hook 记时间戳
- 混合精度 cast:用 forward_pre hook 把 fp32 输入转 fp16
- 梯度统计 / 调试:用 full_backward_hook 检查 grad_input / grad_output
- DDP gradient bucketing:用 backward hook 检测某 grad 算完后立即触发 AllReduce(第 17 章)
- 量化 observer:forward_hook 收集激活分布,PT2E 量化的核心
- 检查点平均:state_dict_hook 在 save 时拦截
每条 hook 都是 PyTorch 与生态对接的关键点。理解 hook 机制能让你优雅扩展任何 nn.Module 行为,不需要改源码。
9.8.0 hook 与 dispatcher mode 的关系
第 5 章 §5.7 我们看到 TorchDispatchMode 也是一种”拦截每个算子”的扩展点。Module hook 与它的区别:
| 维度 | Module hook | TorchDispatchMode |
|---|---|---|
| 拦截粒度 | 一整个 Module 的 forward | 每一个 ATen 算子 |
| 注册方式 | module.register_forward_hook(fn) | with MyMode(): |
| 看到的输入 | Module 的输入参数 | ATen 算子的张量参数 |
| 适合场景 | ”在某层之后做 X" | "对所有算子做 X” |
两者不是互斥的 —— 你可以同时用。Module hook 是粗粒度(按层),DispatchMode 是细粒度(按算子)。生产代码里大部分扩展用 Module hook(更易理解),少数极致 hack 用 DispatchMode(如假执行、量化 observer)。
9.8.2 RemovableHandle 设计
每个 register_xxx_hook 都返回 RemovableHandle:
handle = model.register_forward_hook(my_hook)
# 后续...
handle.remove() # 移除
RemovableHandle 是个简陋的 dataclass:持有 hook id 和注册它的字典引用。remove() 就是从字典 pop 出 id。这种”返回 handle 让用户可以撤销”是 Python 标准库 weakref.WeakValueDictionary 等也用的模式。
为什么不直接给用户 hook 函数引用让他们传回来?因为同一个函数可能被注册多次(在不同 hook 集上),handle 的 id 让”第 N 次注册”可识别可撤销。这种”id-based 注册系统”在 RxJava、JS 事件监听里都有体现。
9.8.5 一个高级技巧:用 hook 实现 model surgery
“model surgery” 是指在不改源码的前提下修改训练好的模型行为。例如:把 ResNet 的某层 ReLU 换成 GELU,但只改一层、保持其他层原样。
def replace_relu_with_gelu(module):
for name, child in module.named_children():
if isinstance(child, nn.ReLU):
setattr(module, name, nn.GELU())
else:
replace_relu_with_gelu(child)
replace_relu_with_gelu(model)
这段代码递归遍历所有 module,遇到 ReLU 就替换成 GELU。setattr(module, name, ...) 触发 __setattr__,自动从 _modules 删除旧的 ReLU、加入新的 GELU。整套 surgery 操作零侵入、不需要改原模型代码。
这种能力让 PyTorch 在研究场景里非常灵活 —— 可以快速尝试”如果某层换成 X 会怎样”。HuggingFace 的 PEFT (LoRA / Prefix-tuning) 库就大量用 model surgery 把”普通 Linear 替换成 LoRALinear”。
9.8.1 一个真实例子:grad clipping with hook
def grad_clip_hook(module, grad_input, grad_output):
return tuple(g.clamp(-1.0, 1.0) if g is not None else g for g in grad_input)
model.register_full_backward_hook(grad_clip_hook)
这种实现比 torch.nn.utils.clip_grad_norm_ 更精细 —— 在每个 module 反向时立即裁,不等所有梯度算完。但需要小心 inplace clamp 与 autograd 的交互,所以生产代码通常仍用标准 API。
9.8.3 hook 执行的隐形顺序
每种 hook 之间的调用顺序是有确定 spec 的,按 _wrapped_call_impl 实现的顺序:
- global forward_pre_hook(用
nn.modules.module.register_module_forward_pre_hook注册,对所有 module 生效) - self.forward_pre_hook(per-module)
- backward_pre_hook 设置(如果有,会插入 BackwardHook 包装输入)
- self.forward(args)
- global forward_hook
- self.forward_hook
- backward_hook 设置(包装输出)
- non_full_backward_hook 注册到 grad_fn
理解这套顺序很重要。比如你写一个 forward_hook 想看到”经过其他所有 hook 修改之后的最终 result”,要确保你的 hook 注册时机最晚(dict 是按插入顺序),或者用 prepend=True 参数。这种顺序敏感性是 Module hook 与 dispatch mode 的另一个区别 —— DispatchMode 用栈结构、最近的 mode 先拦截,逻辑更清晰。
9.8.4 用 hook 实现 GPT-style cache:一个高级例子
class KVCache:
def __init__(self):
self.cache = {}
def hook(self, module, input, output):
# 假设 module 是 attention 层
layer_id = id(module)
self.cache[layer_id] = output[1] # K, V
cache = KVCache()
for layer in model.transformer.layers:
layer.attention.register_forward_hook(cache.hook)
这是 transformer 推理时手动实现 KV cache 的 toy 写法。生产代码的 KV cache(如 vLLM)当然更精巧(PagedAttention),但思路一致:用 hook 拦截每层 attention 输出,存下 K/V,下一次 forward 时复用。Module hook 让”对模型的非侵入式扩展”非常直接。
9.9 train() / eval():模式切换
def train(self, mode: bool = True):
self.training = mode
for module in self.children():
module.train(mode)
return self
这个方法把 self.training 设为 True/False 并递归子模块。BatchNorm、Dropout 等模块在 forward 时检查 self.training 决定行为:
class Dropout(Module):
def forward(self, input):
if self.training:
return F.dropout(input, self.p, ...)
else:
return input # eval 模式下 dropout 是 identity
新手最常见的 bug 就是 inference 时忘了 model.eval() —— BN 仍然用 batch 统计、Dropout 仍然丢神经元、模型行为非确定。生产 inference 服务必须 model.eval()。
注意 eval() 和 no_grad() 是正交的概念:
eval():切到推理模式(影响 BN/Dropout 行为)no_grad():关掉 autograd(影响梯度计算)
生产 inference 通常两者都要:
model.eval()
with torch.no_grad():
out = model(x)
或者更高效的:
model.eval()
with torch.inference_mode():
out = model(x)
9.9.1 model.children() vs model.modules()
两个递归方法容易混:
children()只返回直接子模块(不递归到子模块的子模块)modules()递归返回所有模块(包括 self、子子模块……)
类似还有 named_children() 和 named_modules() 带名字版本。生产代码里:
# 给所有 Conv2d 的 weight 做 init
for m in model.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
modules() 返回深度优先遍历整个网络。这是 PyTorch 给”按类型批量处理参数”留的最方便接口。
children() 与 modules() 看似冗余,但场景不同:
- 想给”每个直接子层”应用变换(如最外层 wrap 一下):用
children() - 想搜整个网络里所有某类型层(如所有 BN):用
modules() - 想知道某层在网络里的命名位置:用
named_modules()
apply(fn) 内部就是 for m in self.modules(): fn(m) 加 self return。这种”递归遍历 + 用户提供闭包”模式让 PyTorch 与训练流程框架(如 Lightning)能轻松实现”对所有参数做 X”的统一接口。
9.9.5 train/eval 在分布式与 inference 服务中的微妙性
model.train() / model.eval() 看似简单,分布式训练里有些坑:
1. DDP 包装后的 train/eval:DDP 把 model 包成 DistributedDataParallel(model),调 ddp_model.train() 会递归设置 inner module 的 training。但有些第三方库的 wrapper 不递归,需要显式 ddp_model.module.train()
2. 模型分阶段切换:典型场景是先训练后用一段 inference 子图。常见错误:
model.eval()
out = model.encoder(x) # eval 模式
loss = some_loss(out)
loss.backward() # 仍然会反向, 但 BN 用的是 running stats 而非 batch stats
# 训练效果会很差
正确做法:明确分开训练阶段和推理阶段,不要混合。
3. 推理服务的并发安全:model.eval() 设置了 self.training=False,多线程并发推理时不会互相影响(因为每个推理只读 training)。但如果你某条路径里又调了 model.train(),并发的其他推理会被污染。生产服务里不要在请求处理路径切换 train/eval。
9.10 跨书关联
- 《vLLM 内核探秘》第 7 章 模型加载:vLLM 加载 HuggingFace 模型时直接构建 PyTorch nn.Module 树,再用自定义 weight loader 喂权重 —— 完全依赖本章的 state_dict 机制
- 《Tokio 异步运行时》第 X 章 任务上下文:Tokio Task 的 spawn / hook 注册机制与 nn.Module 的 hook 思想类似,都是”在统一抽象上预留扩展点”
- 《Serde 元编程》派生宏:Serde 的
#[derive(Serialize)]在编译期生成的逻辑,与 nn.Module 在运行期通过__setattr__拦截做的事是同一目的:自动收集类的字段 - 《MCP 协议剖析》Server 注册:MCP server 用 decorator 注册 method handler 与 nn.Module hook 注册思想极相似 —— 都是”对类的字段做声明式处理”
9.10.5 一个对照:torch.fx vs nn.Module
最后值得提一下 torch.fx。它是 PyTorch 用来”把 nn.Module 转成符号化 IR”的工具:
import torch.fx
model = MyModel()
gm = torch.fx.symbolic_trace(model)
# gm 是一个 torch.fx.GraphModule, 它的 forward 是一段被记录的 IR
print(gm.graph)
fx.symbolic_trace 用 __torch_function__ 协议拦截每个算子调用,把它记录成 IR 节点。这套 IR 是 torch.compile 的早期基础(v1.x 时代),现在被 Dynamo 部分取代但仍在量化、模型变换等场景广泛使用。
理解 fx 与 nn.Module 的关系:nn.Module 是面向人的 API(layers + forward),fx Graph 是面向编译器的 IR。两套表示互相可转换:fx.symbolic_trace 把 Module 转成 Graph;fx.GraphModule 又是一种 Module(继承自 nn.Module)。这种”原始 + 编译”两套表示并存是 PyTorch 框架的经典哲学,也是第 12-15 章 torch.compile 章节的根本设计。
注意 fx 现在已经被 Dynamo 取代主导地位(Dynamo 更鲁棒、能处理更多 Python 控制流),但 fx 在某些特定场景仍然广泛使用:
- 量化 (quantization):PT2E 量化的 prepare / convert 阶段大量用 fx
- 模型变换 / pruning:剪枝、重参数化等需要明确的 Graph IR 才好做
- 简单的 trace 场景:用户想看”我的 Module 实际跑了哪些算子”时,
symbolic_trace比 Dynamo 简单得多
第 14 章 Inductor 章会展开 fx 与 Dynamo 之间的接力。
9.11 几条工程建议
实战 nn.Module 相关的最佳实践:
1. 永远 super().__init__() 在 __init__ 第一行:忘了会让 _parameters 等字典没初始化,后续赋值报错”cannot assign parameters before Module.init()”
2. 学习参数用 nn.Parameter 包:直接 self.weight = torch.randn(...) 不会被自动注册为参数
3. 不参与训练但要保存的状态用 register_buffer:如 BN 的 running stats、自定义的位置编码
4. 临时中间状态不要存为 attribute:会被 state_dict 收进来,让 ckpt 变大。用局部变量。但注意:纯 Python 对象(list、dict、自定义类)不会被 state_dict 收 —— 只有 Tensor / Parameter / Buffer / Module 才被自动收集
5. 子模块要写到 __init__ 而非 forward:在 forward 里 nn.Linear() 每次都新建模块,权重不会被训练
6. 注意 hook 的 order 影响:多个 hook 按注册顺序调用,写的时候要明确”我能假设其他 hook 已经跑过吗”
7. model.to(device) 是 inplace 操作:返回 self,但同时也改了 self。new = model.to(device) 与 model.to(device) 等价
8. nn.ModuleList / nn.ModuleDict 而非 Python list / dict:self.layers = [nn.Linear(10,10) for _ in range(5)] 不会自动注册(普通 list 不被 __setattr__ 检测),self.layers = nn.ModuleList([...]) 才会
9. nn.Parameter 注册时机:必须在 __init__ 期间注册。运行期动态加 Parameter(如 if cond: self.weight2 = nn.Parameter(...))会让 optimizer 拿不到(因为 optimizer 在 model 创建后就读了一次 parameters())
10. model.apply(fn) 递归调用:apply 给每个 module 调一次 fn(module)。常用于初始化:model.apply(init_weights)。它的实现就是 for m in self.modules(): fn(m) 加 self return
11. forward 里不要用 self.training 之外的 mode 切换:除了 BN / Dropout 这种内置约定,自定义 module 不要在 forward 里读其他 self.xxx 切换行为。这样会让 torch.compile 看到的图依赖运行时变量,触发 graph break
12. 模型结构包含可变 list / Tensor list:用 ParameterList / BufferList 而非普通 Python list。例如 self.scales = nn.ParameterList([nn.Parameter(torch.ones(1)) for _ in range(K)]),自动注册
9.11.1 给 Module 设计师的”通用启示”
把 nn.Module 思想抽象到任何”组件容器”系统:
第一:__setattr__ 拦截 + __getattr__ 兜底是 Python 元编程的黄金组合 —— 让用户感受不到框架在背后做的事,写代码像写普通类一样自然
第二:子节点用专门容器(dict / list)而非 self.x = child:让递归遍历、序列化、迁移操作都有统一入口
第三:hook 系统让框架可扩展:在生命周期关键点开放可注册的回调列表,第三方工具就能集成进来。RemovableHandle 让用户能撤销注册
第四:state 与 metadata 分离:state_dict 存”值”,_metadata 存”版本号 / 类型信息”。前者让数据可移植,后者让兼容性成为可能
第五:train/eval 是模型生命周期的两个不同阶段:用一个 self.training flag 切换,相关组件查这个 flag 调整行为。这种”全局上下文 + 组件就地查询”模式比”为不同模式写两份代码”清晰得多
把这五条记住,你写自己的”组件框架”(如 GUI 框架、游戏引擎的 Entity-Component 系统、CRM 的工作流模板)能少走很多弯路。
第六:让递归遍历成为一等接口:parameters() / modules() / children() / named_xxx() 这种”按容器维度遍历”接口让框架的所有变换、保存、迁移操作都有标准入口
第七:train/eval / persistence / device 这三种”全局开关”互不干扰:每种开关由不同字段表示,Module 在合适的层面响应。这种”正交概念分离”让用户能任意组合而不互相破坏
9.11.5 性能数字:nn.Module 调用开销
具体数字(H100,PyTorch v2.11,单线程 CPU):
| 操作 | 开销 |
|---|---|
一次 module(x) 调用(无 hook) | ~3-5 us |
| 一次 forward_hook 触发 | +0.5-1 us |
| 一次 backward_hook 触发 | +1-2 us |
module.named_parameters() 遍历(10 层网络) | ~10-20 us |
module.state_dict()(10 层网络) | ~50-100 us |
module.load_state_dict()(10 层网络,无形状变化) | ~200-500 us |
对一次 forward 包含几十层 module 的大模型,nn.Module 调用开销加起来约 几十微秒到几百微秒,相对于实际计算(毫秒级)不显著。但对小模型(如 MNIST classifier)反复 forward / backward 时,nn.Module 自身开销可能占总时间 10-20%。
torch.compile 把整个 module(x) 调用编进单个 Triton kernel,这部分 nn.Module 开销直接归零 —— 这是 compile 在小模型上能拿到 50%+ 加速的关键。
9.12 nn.Module 的 to/cuda/half 实现
model.cuda() 把整个模型搬到 GPU。它的实现就是用 _apply 递归遍历每个 parameter / buffer,对每个张量调 .cuda():
def cuda(self, device=None):
return self._apply(lambda t: t.cuda(device))
def _apply(self, fn, recurse=True):
if recurse:
for module in self.children():
module._apply(fn)
# apply to all parameters
for key, param in self._parameters.items():
if param is not None:
with torch.no_grad():
param_applied = fn(param)
self._parameters[key] = Parameter(param_applied, ...)
# apply to all buffers
for key, buf in self._buffers.items():
if buf is not None:
self._buffers[key] = fn(buf)
return self
_apply 是 nn.Module 里最通用的一个抽象 —— 它接受任意”对张量的变换”函数,递归整个网络。.cuda() / .cpu() / .half() / .float() / .bfloat16() / .to(device, dtype) 全部用它实现。
注意细节:参数变换后是用 Parameter() 包装回去(保留 requires_grad 等元数据),buffer 直接覆盖。这是因为 buffer 不需要 Parameter wrapper。
为什么参数变换后要重新包成 Parameter 而不是直接 inplace?因为某些变换(如 .to(device))会返回新张量(device 改变时数据需要重新分配)。新张量是普通 Tensor,需要重新包装回 Parameter 以保留 requires_grad 等元信息。如果是同 device、同 dtype 的变换,PyTorch 会做一些优化避免重新分配,但接口语义上始终是”返回新对象”。
9.12.0.5 _apply 与 autograd 的微妙交互
.cuda() 等变换内部用 with torch.no_grad(): 包住,避免变换本身被 autograd 记录。否则你 model.cuda() 的瞬间会触发 _to_copy 算子的反向图,最后整个网络复制到 GPU 的过程被记入反向图 —— 显存爆炸。
这种”对底层 housekeeping 操作显式禁用 autograd”是 PyTorch 内部很多基础设施的共同做法。optimizer.step() 内部也用 no_grad 包,避免参数更新被反向追踪。第 10 章会再看到这种模式。
9.12.1 一个有趣的 _apply 应用:模型量化
def quantize_to_int8(t):
return t.to(torch.int8) if t.is_floating_point() else t
quantized = model._apply(quantize_to_int8)
类似机制让 PyTorch 的 quantization、AMP autocast、device offloading 等高级功能都能站在 _apply 这套递归基础上写。第 20 章量化会展开。
9.12.5 _save_to_state_dict vs _load_from_state_dict 的对偶性
为什么 PyTorch 把 save 和 load 写成 _save_to_state_dict / _load_from_state_dict 这种”_私有方法”模式而不是公开 API?
因为 它们是 module 子类可以重写的扩展点。如果你的 module 有特殊持久化需求(如某个张量需要量化后存、或者把多个张量打包成一个),可以重写这两个方法:
class MyModule(nn.Module):
def _save_to_state_dict(self, destination, prefix, keep_vars):
# 自定义存储格式
destination[prefix + 'packed'] = self._pack_for_save()
def _load_from_state_dict(self, state_dict, prefix, local_metadata, ...):
# 自定义加载逻辑
if prefix + 'packed' in state_dict:
self._unpack(state_dict[prefix + 'packed'])
这种”框架提供 hook 点 + 子类重写细节”是 Template Method 设计模式的应用。PyTorch 的 quantization、稀疏张量、甚至 nn.Embedding 都用了重写 _load_from_state_dict 来处理特殊情况。
9.13 一个完整的 Module 树:HuggingFace BERT 解构
实战看一段 HuggingFace BertModel 的 model.named_modules() 输出(节选):
'' (BertModel)
'embeddings' (BertEmbeddings)
'embeddings.word_embeddings' (Embedding)
'embeddings.position_embeddings' (Embedding)
'embeddings.token_type_embeddings' (Embedding)
'embeddings.LayerNorm' (LayerNorm)
'encoder' (BertEncoder)
'encoder.layer.0' (BertLayer)
'encoder.layer.0.attention.self.query' (Linear)
'encoder.layer.0.attention.self.key' (Linear)
'encoder.layer.0.attention.self.value' (Linear)
'encoder.layer.0.attention.output.dense' (Linear)
'encoder.layer.0.intermediate.dense' (Linear)
'encoder.layer.0.output.dense' (Linear)
'encoder.layer.1' (BertLayer)
...
每条路径就是 dotted prefix,对应 state_dict 里的 key。这种”树结构 + dotted path”的设计让 BERT 这种 12 层 / 24 层的复杂模型能用扁平 dict 完整表达,加载、迁移、部署都极简单。HuggingFace、Lightning 等所有上层框架都建立在 nn.Module 这套机制之上。
观察一个细节:encoder.layer.0 这种带数字的路径来自 nn.ModuleList。它实现了 __setitem__ 把列表索引转成字符串名字('0'、'1' 等),存到 _modules 字典。所以 ModuleList 表面上是列表,底层仍然是 dict —— 索引化路径自然形成。
类似地,nn.ModuleDict 让用户用字符串 key 索引子模块,但它的所有 key 都直接进 _modules。理解 ModuleList / ModuleDict / Sequential 都只是 dict 的语法糖,能让你写自己的”结构化 module 容器”时心里有数。
nn.Sequential 是另一个重要 ModuleList 派生:它把子模块用整数索引按顺序串起来,forward 自动 for m in self: x = m(x)。常用于线性堆叠的网络:
mlp = nn.Sequential(
nn.Linear(10, 100),
nn.ReLU(),
nn.Linear(100, 10)
)
mlp[0] 是第一个 Linear、mlp[-1] 是最后一个 Linear。这是 nn.Module 给”简单堆叠模型”的便利包装。但 Sequential 不能处理需要分支 / 跳连接的复杂结构,那些场景要写完整的 forward。
9.14 一个收官的反思:nn.Module 的成功是 PyTorch 的成功
某种程度上,PyTorch 之所以击败 TF 1.x 赢得学术界,不只是动态图,还因为 nn.Module 比 tf.keras.Model 更自然。
tf.keras.Model 早期要求用户走 functional API(x = Dense(10)(x))或者 subclass + build() 方法 —— 都比 PyTorch 的 self.linear = nn.Linear(...) 多了一层心智负担。研究者要快速实验新架构时,PyTorch 的 nn.Module 写起来更接近”在草稿纸上画的网络结构图”。
这条经验给所有想做”用户友好框架”的设计师一个启示:让用户写代码的方式贴近他思考的方式。如果用户脑子里想的是”线性层 → ReLU → 线性层”,框架就该让他这样写代码。中间多一个抽象层都是认知摩擦。
PyTorch 的 nn.Module 把”魔法量”控制在 两个 dunder 方法 + 几个内部 dict 这个最小集 —— 已经能让用户体验最自然,又不至于黑盒到完全不可调试。这种 “隐藏复杂度但保留可观察性” 的工程平衡,是 nn.Module 设计上最高明的一笔。
下一章拆 optimizer:optim.SGD / optim.Adam 等怎么实现,包括最近引入的 Foreach / Fused / Capturable 三种性能模式。
评论 0
还没有评论,来说两句吧。
评论加载失败,刷新重试。