第9章 nn.Module:参数注册、Hooks、序列化

“nn.Module looks magical because it does so many things in __setattr__. Once you understand that one method, the whole system becomes obvious.”

—— PyTorch core team note

本章要点

  • nn.Module.__setattr__ 拦截赋值:检测 value 类型决定放到 _parameters / _buffers / _modules 中的哪一个。这是”self.linear = nn.Linear(...) 自动注册”的全部魔法
  • __getattr__ 是补救机制:当用户访问的属性不在 __dict__ 时,从三个 dict 里查找。linear.weight 实际是从 linear._parameters['weight'] 取出
  • state_dict() 递归收集:先调本模块的 _save_to_state_dict,再递归 _modules.items()。返回的字典 key 是 dotted path(如 'layer1.linear.weight'
  • load_state_dict 严格匹配 + 名字模糊容忍strict=True 默认要求 key 完全匹配,但提供详细的 missing/unexpected 报告供用户排查
  • 8 种 hook 协同工作:forward_pre / forward / backward / full_backward / full_backward_pre / state_dict_pre / state_dict / load_state_dict_pre。让用户能在 Module 生命周期的任何点插入自定义逻辑
  • __call__ = _wrapped_call_implmodule(x) 不等于 module.forward(x),它包裹了 hooks 调用、autograd 反向 hook 注入、错误处理

9.1 一段神奇的代码

每个 PyTorch 用户都写过这样的模型:

import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 20)
        self.linear2 = nn.Linear(20, 30)
        self.bn = nn.BatchNorm1d(20)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.linear1(x)
        x = self.bn(x)
        x = self.linear2(x)
        return x

model = MyModel()
print(list(model.parameters()))   # 自动列出 linear1.weight, linear1.bias, linear2.weight, ..., bn.weight, bn.bias
print(model.state_dict().keys())   # 自动收集 'linear1.weight', 'bn.running_mean', ...
torch.save(model.state_dict(), 'ckpt.pt')   # 序列化

魔法到处都是:

  1. self.linear1 = nn.Linear(...) 怎么自动加入 model.parameters()
  2. bn.running_mean(不是 Parameter)怎么也进了 state_dict?
  3. state_dict() 的 key 'linear1.weight' 是怎么拼出来的?
  4. model(x)forward(x) 还是别的?

这些问题的答案都在 torch/nn/modules/module.py(3054 行)。本章拆它的设计。

9.1.1 nn.Module 的设计目标

Module 不只是”参数容器” —— 它是 PyTorch 模型的统一抽象,要解决几个问题:

  • 递归组合:模型由层组成,层又由更小层组成。需要无限递归
  • 参数自动管理:用户写 self.linear = nn.Linear(...) 自动让 optimizer 看到所有权重
  • 状态保存与加载:训练 ckpt、模型部署都需要”把所有可变状态存到磁盘”
  • 跨 device 迁移.cuda() / .cpu() / .half() 要递归整个网络
  • 训练 / 推理模式切换:BN / Dropout 等需要在 train / eval 时不同行为
  • 可扩展性:第三方框架(HuggingFace、PyTorch Lightning、DDP 等)能 hook 进 Module 生命周期

这些目标共同塑造了 Module 的内部结构。理解每个设计选择背后的目标,能让你看 nn.Module 源码不再迷失在 3054 行细节里。

9.2 __init__:建立四个核心容器

打开 torch/nn/modules/module.py:483

class Module:
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        ...
        super().__setattr__("_parameters", {})
        super().__setattr__("_buffers", {})
        super().__setattr__("_non_persistent_buffers_set", set())
        super().__setattr__("_backward_pre_hooks", OrderedDict())
        super().__setattr__("_backward_hooks", OrderedDict())
        super().__setattr__("_is_full_backward_hook", None)
        super().__setattr__("_forward_hooks", OrderedDict())
        super().__setattr__("_forward_hooks_with_kwargs", OrderedDict())
        super().__setattr__("_forward_hooks_always_called", OrderedDict())
        super().__setattr__("_forward_pre_hooks", OrderedDict())
        super().__setattr__("_forward_pre_hooks_with_kwargs", OrderedDict())
        super().__setattr__("_state_dict_hooks", OrderedDict())
        super().__setattr__("_state_dict_pre_hooks", OrderedDict())
        super().__setattr__("_load_state_dict_pre_hooks", OrderedDict())
        super().__setattr__("_load_state_dict_post_hooks", OrderedDict())
        super().__setattr__("_modules", {})

注意每行都用 super().__setattr__ 而不是 self.xxx = ... —— 这是为了绕过自己重写的 __setattr__(直接给 __dict__ 赋值)。否则会触发无限递归。

四个核心容器:

字段类型用途
_parametersdict[str, Parameter]可学习参数(weight、bias 等)
_buffersdict[str, Tensor]不可学习但需保存的状态(如 BN 的 running_mean)
_non_persistent_buffers_setset[str]不存进 state_dict 的 buffer 名(如 dropout 的随机数生成器)
_modulesdict[str, Module]子模块(自动递归整个网络)

剩下的 12 个字段都是 hook 容器(OrderedDict 保证调用顺序)。Module 的所有”高级特性”都建立在这 16 个字段之上。

为什么用 OrderedDict 而不是 list?因为 hook 注册返回 RemovableHandle,handle 持有一个 id,用户调 handle.remove() 时通过 id 从 dict 删除。dict 的随机访问 O(1) 比 list 的线性查找方便。OrderedDict 保留插入顺序让多个 hook 按注册顺序调用 —— 这是用户期望的语义。

至于 _parameters / _buffers / _modules 用普通 dict 而不是 OrderedDict:v3.7+ 普通 dict 已经保证插入顺序,不需要显式 OrderedDict。但 PyTorch 在保留某些 OrderedDict 是为了与老 Python 兼容,加上代码风格惯性。

graph TB
    M[nn.Module 实例]
    M --> P["_parameters dict<br/>weight / bias / ..."]
    M --> B["_buffers dict<br/>running_mean / num_batches_tracked"]
    M --> Mods["_modules dict<br/>子模块递归"]
    M --> H["8 种 _xxx_hooks<br/>前/后/状态/反向"]

    Mods --> Sub1[子 Module 1]
    Mods --> Sub2[子 Module 2]

    Sub1 --> SubP1[子 Module 1 的 _parameters]
    Sub1 --> SubB1[子 Module 1 的 _buffers]
    Sub1 --> SubM1[子 Module 1 的 _modules]

    style P fill:#dbeafe,stroke:#3b82f6
    style B fill:#fef3c7,stroke:#f59e0b
    style Mods fill:#dcfce7,stroke:#22c55e

9.3 __setattr__:拦截赋值的核心魔法

flowchart TD
    Set["self.x = value"]
    Set --> Q1{value 是 Parameter?}
    Q1 -->|是| P1["注册到 _parameters dict<br/>从 _modules / _buffers 清理"]
    Q1 -->|否| Q2{value 是 Module?}
    Q2 -->|是| P2["注册到 _modules dict<br/>建立子模块树"]
    Q2 -->|否| Q3{value 是 Buffer?}
    Q3 -->|是| P3["注册到 _buffers dict<br/>跟随 .cuda 等迁移"]
    Q3 -->|否| Q4{name 在某个 dict 里?}
    Q4 -->|是| P4["更新对应 dict<br/>类型必须匹配"]
    Q4 -->|否| P5["走 object.__setattr__<br/>普通 Python 属性"]

    style P1 fill:#fef3c7
    style P2 fill:#dcfce7
    style P3 fill:#dbeafe
    style P5 fill:#f3e8ff

module.py:1972__setattr__ 是 nn.Module 最重要的方法,简化版:

def __setattr__(self, name: str, value):
    params = self.__dict__.get("_parameters")

    if isinstance(value, Parameter):
        # 1. value 是 Parameter → 注册到 _parameters
        remove_from(self.__dict__, self._buffers, self._modules, ...)
        self.register_parameter(name, value)

    elif params is not None and name in params:
        # 2. 已存在的 param 被赋 None → 删除
        if value is not None:
            raise TypeError("cannot assign ... as parameter ... ")
        self.register_parameter(name, value)

    else:
        modules = self.__dict__.get("_modules")
        if isinstance(value, Module):
            # 3. value 是 Module → 注册到 _modules
            remove_from(self.__dict__, self._parameters, self._buffers, ...)
            modules[name] = value

        elif modules is not None and name in modules:
            # 4. 已存在的子模块被赋值
            modules[name] = value

        else:
            buffers = self.__dict__.get("_buffers")
            if isinstance(value, Buffer) or (buffers is not None and name in buffers):
                # 5. value 是 Buffer 或已存在的 buffer 名 → 注册到 _buffers
                ...

            else:
                # 6. 普通属性 → 走默认 __setattr__
                object.__setattr__(self, name, value)

——这就是”self.linear1 = nn.Linear(...) 自动注册”的真相。Python 的 __setattr__ 协议被拦截,根据 value 类型路由到不同容器

9.3.0.5 为什么要拦截赋值:备选方案的对比

PyTorch 的__setattr__ 拦截不是唯一方案。其他框架处理”自动注册参数”的不同思路:

  • Keras(早期):要求用户在 build() 方法里显式调 self.add_weight(...)。冗长但显式
  • JAX/Flax:用 dataclass 字段声明,外部 framework 通过 reflection 收集字段。需要 flax.linen.Module 装饰器
  • TensorFlow:用 tf.Variable 在某个 tf.Module 子类上自动追踪
  • PaddlePaddle:与 PyTorch 类似,__setattr__ 拦截

PyTorch 选择 __setattr__ 拦截的好处是 用户写代码与普通 Python 完全一致 —— 不需要装饰器、不需要 build 方法、不需要 add_weight 调用。代价是 nn.Module 是个”魔法重”的类,新人学起来需要一段时间适应它的行为不是普通 Python 类。这种”直观但魔法”vs”显式但冗长”是框架设计的永恒权衡。

9.3.1 一个具体追踪:self.linear = nn.Linear(10, 20)

MyModel.__init__ 里写 self.linear = nn.Linear(10, 20) 时:

  1. Python 解释器调 MyModel.__setattr__('linear', nn.Linear(10, 20))
  2. MyModel.__setattr__ 实际是继承的 Module.__setattr__
  3. 检查 value 类型:isinstance(nn.Linear(...), Parameter) → False,isinstance(nn.Linear(...), Module) → True
  4. 走分支 3:self._modules['linear'] = nn.Linear(...)
  5. 注意 self.linear 这个名字根本没进 self.__dict__

这就是为什么后面 model.linear 还能取到值 —— 通过 __getattr__ 的兜底机制。

注意 __setattr__ 里的 remove_from(...) 调用 —— 它从其他容器里删除同名条目。这是为了避免”同一个名字同时在 _parameters 和 _modules 里”的不一致状态。如果你先 self.x = nn.Parameter(...)self.x = nn.Linear(...),Parameter 会从 _parameters 删除、Linear 进入 _modules。这种”覆盖式赋值”的语义和普通 Python 类的 self.x = ... 完全一致,用户感受不到区别。

9.3.1.5 一个反直觉的细节:覆盖式赋值如何工作

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.x = nn.Linear(10, 20)   # 进 _modules
        self.x = nn.Parameter(torch.randn(5))   # 此时 _modules 里的旧值被删, 进 _parameters

第二次赋值时 __setattr__ 检测 value 是 Parameter,先调 remove_from(self.__dict__, self._buffers, self._modules, ...) 删除其他容器里的同名旧值,再 register_parameter。这套”先清后加”语义让覆盖式赋值表现得和普通 Python 对象一致 —— 用户感受不到内部容器切换。

9.3.2 为什么 Parameter 要单独继承

torch.nn.Parametertorch.Tensor 的子类,但本身不加任何新功能

class Parameter(torch.Tensor):
    pass    # 就是个空壳子类

它存在的唯一意义是 当作 isinstance 检查的标记。让 __setattr__ 能区分”这是要训练的张量”和”这是普通中间张量”。如果你写:

self.weight = torch.randn(10, 20, requires_grad=True)   # ❌ 不会进 _parameters
self.weight = nn.Parameter(torch.randn(10, 20))         # ✅ 进 _parameters

第一种写法 __setattr__ 检测不到 Parameter 类型,直接进 __dict__不会被 model.parameters() 列出,optimizer 找不到它。这是新手常踩的坑。

9.4 __getattr__:从三个 dict 兜底

module.py:1955__getattr__ 处理”读”路径:

def __getattr__(self, name: str):
    if "_parameters" in self.__dict__:
        _parameters = self.__dict__["_parameters"]
        if name in _parameters:
            return _parameters[name]
    if "_buffers" in self.__dict__:
        _buffers = self.__dict__["_buffers"]
        if name in _buffers:
            return _buffers[name]
    if "_modules" in self.__dict__:
        modules = self.__dict__["_modules"]
        if name in modules:
            return modules[name]
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

Python 的 __getattr__ 只在标准属性查找失败时才被调。所以普通属性走 __dict__ 查找正常路径,特殊属性(参数/buffer/子模块)走兜底。

linear.weight 的查找流程:

  1. Python 先查 linear.__dict__['weight'] —— 没有(因为 __setattr__ 没把它放这)
  2. type(linear).__dict__['weight'] —— 没有(不是类属性)
  3. __getattr__('weight') —— 在 _parameters['weight'] 里找到,返回

这种”读 / 写不对称”是 Python 元编程的经典套路:写靠 __setattr__ 路由到正确容器,读靠 __getattr__ 从容器兜底取出。用户全程无感。

9.4.1 性能取舍:每次属性访问都过 __getattr__ 慢吗

新手经常担心:每次 self.linear 都过 __getattr__,会不会比普通属性访问慢?

实测一下:CPython 普通属性访问约 30 ns,过 __getattr__ 约 80 ns。差距 50 ns。对每秒上百万次属性访问的代码(如递归遍历模块树)累积可观,但对单次 forward 调用(几毫秒到几百毫秒)几乎不可见。

PyTorch 的 trick 是 forward 里访问的子模块已经被局部变量缓存

def forward(self, x):
    x = self.linear1(x)   # 第一次访问 linear1, 走 __getattr__
    x = self.linear2(x)   # 第二次访问 linear2, 也走 __getattr__
    return x

每次只查一次。如果你写 for _ in range(1000): x = self.linear(x),每次循环都过 __getattr__ —— 这种代码在 PyTorch 里少见但确实有性能损失。torch.compile 在编译时会把这种重复 attribute access 特殊化到具体的子模块对象,消除 __getattr__ 开销。第 12 章 Dynamo 章会展开。

9.5 __call__:模型调用的真实入口

很多人以为 model(x) 等价于 model.forward(x)。看 module.py:1918

__call__: Callable[..., Any] = _wrapped_call_impl

_wrapped_call_impl 不直接调 forward,它包裹了一层逻辑:

def _wrapped_call_impl(self, *args, **kwargs):
    # 1. 如果有 forward_pre_hooks, 跑
    # 2. 如果有 backward_hooks, 设置 BackwardHook 包装输入
    # 3. 调 forward
    # 4. 如果有 forward_hooks, 跑
    # 5. 如果有 non_full_backward_hooks, 给输出的 grad_fn 注册 hook
    # 6. 错误处理 + always_called hooks
    return result

所以 model(x)model.forward(x)。两者的差别在于前者跑 hooks,后者不跑。这就是为什么生产代码要用 model(x) 而不是 model.forward(x) —— hooks 是 PyTorch 各种工具(profiler、Hook-based pruning、量化等)的扩展点,绕过它们会导致这些工具失效。

9.5.1 _compiled_call_impl 与 torch.compile 的协作

module.py 还有一个重要变量 _compiled_call_impl。当用户对 module 调 torch.compile(model) 时,PyTorch 把编译后的 forward 实现挂到这里,__call__ 优先调用 _compiled_call_impl 而不是普通 forward。

这种”插槽式”设计让 torch.compile 可以透明加速 module,用户代码不用改:

model = MyModel()
model = torch.compile(model)    # 替换内部 _compiled_call_impl
out = model(x)                  # 走编译后的 forward

第 15 章 torch.compile 端到端会展开这套协作。注意 __getstate__ 在序列化时会剔除 _compiled_call_impl(因为它依赖运行时编译产物,不能跨进程传递)。这种”序列化时剥离运行时态”是 Module 兼容多种使用场景的细节。

9.5.2 forward 是 Python 函数还是 Module 方法

很多人没意识到:forward 是普通 Python 方法,不是任何特殊声明。Module 的 __call__ 通过 forward_call = self.forward 拿到这个方法,调用时传递参数。这意味着:

  • 你可以在 forward 里写任何 Python 代码(if/for/while/try)
  • forward 可以接受任何参数(位置 / 关键字 / 默认值)
  • forward 可以返回任何东西(tensor、tuple、dict、list、None)

这种自由度是 PyTorch 动态图的体现。代价是:forward 必须在每次 module(x) 都被解释执行(除非用 torch.compile),无法像静态图框架那样跨调用复用编译产物。

子类可以选择重写 __call__ 而非 forward(如 PyTorch Lightning 的 LightningModule)—— 但绝大多数代码遵循 “只重写 forward” 的惯例,让 hook 等机制不被破坏。

9.5.3 reentrant 的 forward 调用

如果在 forward 里再调一次 self(x)(递归调用同一个 module),会怎样?

class RecursiveModule(nn.Module):
    def forward(self, x, depth):
        if depth == 0:
            return x
        return self(x, depth - 1)   # 递归调用自己

这是合法的,会跑两次 hooks(外层和内层各一次)。但 backward 时反向图会反映出递归结构 —— 同一个 grad_fn 实例可能在反向链上出现多次。这种”算子复用”在某些 RNN 实现里出现,PyTorch 完全支持。

sequenceDiagram
    autonumber
    participant U as user: model(x)
    participant W as _wrapped_call_impl
    participant Pre as forward_pre_hooks
    participant F as forward(x)
    participant BH as BackwardHook
    participant Post as forward_hooks

    U->>W: __call__(x)
    W->>Pre: 跑所有 pre-hook
    Pre-->>W: 可能修改 args
    W->>BH: setup_input_hook (注册反向 hook)
    W->>F: forward(args)
    F-->>W: result
    W->>BH: setup_output_hook
    W->>Post: 跑所有 post-hook
    Post-->>W: 可能修改 result
    W-->>U: result

9.6 state_dict:递归收集状态

module.py:2195state_dict 是模型保存的核心。简化逻辑:

def state_dict(self, *, destination=None, prefix="", keep_vars=False):
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()

    local_metadata = dict(version=self._version)
    destination._metadata[prefix[:-1]] = local_metadata

    for hook in self._state_dict_pre_hooks.values():
        hook(self, prefix, keep_vars)

    self._save_to_state_dict(destination, prefix, keep_vars)

    # 递归子模块
    for name, module in self._modules.items():
        if module is not None:
            module.state_dict(
                destination=destination,
                prefix=prefix + name + ".",
                keep_vars=keep_vars,
            )

    return destination

_save_to_state_dict 把当前 module 的 _parameters 和 persistent _buffers 写到 destination:

def _save_to_state_dict(self, destination, prefix, keep_vars):
    for name, param in self._parameters.items():
        if param is not None:
            destination[prefix + name] = param if keep_vars else param.detach()
    for name, buf in self._buffers.items():
        if buf is not None and name not in self._non_persistent_buffers_set:
            destination[prefix + name] = buf if keep_vars else buf.detach()

注意 递归 + prefix 是 dotted path 的来源:

  • 调用 model.state_dict(),prefix = ""
  • model.linear1.state_dict(prefix="linear1.")
  • 在 linear1 里 _save_to_state_dict_parameters['weight'] 存到 destination['linear1.weight']
  • 类似产生 'linear1.bias''linear2.weight''bn.running_mean'

这种”前缀 + 递归”的设计让任意深度的 Module 树扁平化为一个 dict[str, Tensor],序列化非常友好。

值得注意一个工程细节:state_dict 返回的是 OrderedDict + _metadata 字段_metadata 是个隐藏字段(用 _metadata 而不是 metadata,避免冲突),存每个模块的”版本号”。_load_from_state_dict 在加载时会检查 version,调用对应的迁移逻辑 —— 让 PyTorch 能在某个 module 类升级 API 时仍然加载旧 ckpt。这套版本号机制大多数用户感受不到,但是 PyTorch 跨版本 ckpt 兼容的关键。

9.5.4 forward 与 no_grad 的细微交互

如果你在 forward 里写 with torch.no_grad(): 包住一段代码,会怎样?

def forward(self, x):
    with torch.no_grad():
        prepared = self.preprocess(x)   # 这段不参与反向
    return self.main(prepared)

注意 prepared 没有 grad_fn —— 它从 no_grad 块出来。self.main(prepared) 反向时反向链就在 prepared 这里断了:main 的反向能算出 grad 给 prepared,但 prepared.grad_fn 是 None,反向终止,不会传给 preprocess 的参数。

这是 PyTorch 给”前处理 / 特征提取不参与训练”的标准模式。Vision Transformer 的”frozen ViT backbone + 训练 head”等冻结预训练模型场景就靠这个。比手动设置 param.requires_grad = False 更直接、也更难误用。

9.6.0.5 一个常见误解:state_dict 是不是模型本身

新手以为 torch.save(model.state_dict()) 等价于”保存了整个模型”。:state_dict 只保存权重和 buffer 值,不保存模型类定义。要重建模型必须先 model = MyModel(...) 用 Python 构造,再 model.load_state_dict(...) 灌入权重。

如果想”完整保存模型对象”,要 torch.save(model)(不传 state_dict)—— 这会用 pickle 序列化整个 model 对象。但这种方式有个大坑:pickle 依赖类定义。如果你的 MyModel 类定义改了(哪怕只是改了 method 实现),加载 pickle 时会报错或行为不一致。

所以生产代码强烈建议用 state_dict 而非整对象 pickle:让模型架构(Python 代码)与权重(state_dict)完全解耦,各自独立演进。HuggingFace、PyTorch Lightning 等所有正经框架都遵循这个原则。

9.6.1 keep_vars=True 的妙用

默认 keep_vars=False:state_dict 里的张量都被 .detach() —— 与原模型断开 autograd 关系。这是为了避免序列化时把整个反向图也存进去。

但训练用的”checkpoint averaging”或者 “Polyak averaging” 等技术需要 state_dict 仍然连着 autograd 图:

# averaged model 的 weights 仍然是 differentiable
avg_dict = {k: v for k, v in model.state_dict(keep_vars=True).items()}

这是为研究场景留的逃生口。普通用户用默认 detach 就够。

9.6.2 _non_persistent_buffers_set 的设计

不是所有 buffer 都该进 state_dict。比如 BatchNorm 的 num_batches_tracked 是必要的(要存),但某些用户自己加的”调试用 running stats”可能不该存。Module.register_buffer(name, tensor, persistent=False) 让用户标记某 buffer 不进 state_dict —— 名字加入 _non_persistent_buffers_set_save_to_state_dict 时跳过。

这种”buffer 是否持久化”的细分在 v1.x 之后才出现,是社区反馈”我有些 buffer 不想存”后加的功能。

9.6.3 自定义 state_dict 输出格式

某些场景需要修改 state_dict 的 key(如要转成与 HuggingFace 兼容的命名)。可以注册 _state_dict_hook

def rename_keys(module, state_dict, prefix, local_metadata):
    # 把 module.weight 改成 module.W
    if prefix + 'weight' in state_dict:
        state_dict[prefix + 'W'] = state_dict.pop(prefix + 'weight')

model._register_state_dict_hook(rename_keys)

Hugging Face Transformers 库的 from_pretrained / save_pretrained 内部就有大量这类 hook,把 PyTorch 原始命名转换成 HuggingFace 标准命名(如 bert.encoder.layer.0.attention.self.query ↔ HF 风格)。这是 ecosystem 互操作性的工程基础。

9.7 load_state_dict:参数加载与名字匹配

module.py:2531load_state_dict 反向操作。它做:

  1. _load_state_dict_pre_hooks
  2. 递归遍历模块树,每层 module 的 _load_from_state_dict 把 state_dict 里对应 prefix 的 key 拷到自己的 _parameters / _buffers
  3. _load_state_dict_post_hooks
  4. 收集 missing_keys / unexpected_keys / error_msgs,根据 strict 参数决定 raise 还是 warn

最常见的”加载报错”场景:训练时模型结构变了,旧 ckpt 的 key 与新模型对不上。strict=False 让加载继续但记录差异,方便迁移学习场景:

# 加载预训练 ResNet, 替换最后的 fc 层
model = MyResNet(num_classes=10)
state_dict = torch.load('resnet50.pth')
missing, unexpected = model.load_state_dict(state_dict, strict=False)
# 通常 unexpected = ['fc.weight', 'fc.bias'] (旧 fc,被新 num_classes 改了)

第 19 章序列化章会展开 torch.save / torch.load 与 state_dict 的协作,包括 safetensors 格式的优势。

9.7.0.5 missing_keys 与 unexpected_keys 的诊断价值

load_state_dict 返回的 missing_keys / unexpected_keys 是诊断”模型 vs ckpt 不匹配”的金矿:

  • missing:当前模型有但 ckpt 没有的 key —— 通常是模型新加了层,ckpt 是老版本
  • unexpected:ckpt 有但当前模型没有的 key —— 通常是模型删了层,或 ckpt 多保存了不必要的 buffer

调试技巧:

missing, unexpected = model.load_state_dict(state_dict, strict=False)
print('Missing keys:', missing[:5])
print('Unexpected keys:', unexpected[:5])
print(f'Match rate: {1 - (len(missing) + len(unexpected)) / len(state_dict):.1%}')

匹配率 > 95% 通常说明 ckpt 大体兼容,少量差异可能是 BN 的 num_batches_tracked 或 dropout 的 random state 这种次要 buffer。匹配率 < 80% 通常意味着模型架构有重大变化,加载后行为可能不正常。

9.7.1 加载流程的细节

load_state_dict 实际工作流:

flowchart TB
    Start[load_state_dict 入口]
    Start --> Pre[跑 _load_state_dict_pre_hooks]
    Pre --> Load[递归 _load:<br/>对每个 module 调 _load_from_state_dict]

    Load --> Each["_load_from_state_dict 做:<br/>1. 跑 _load_state_dict_pre_hooks (per-module)<br/>2. 把 state_dict 里 prefix+name 的 tensor 拷到 _parameters/_buffers<br/>3. 处理 missing/unexpected keys<br/>4. 跑 _load_state_dict_post_hooks (per-module)"]

    Each --> Post[跑全局 _load_state_dict_post_hooks]
    Post --> Check{strict=True?}
    Check -->|是, 有 missing/unexpected| Err[raise RuntimeError]
    Check -->|否| Return[返回 missing_keys, unexpected_keys]

    style Each fill:#fef3c7,stroke:#f59e0b

注意 数据拷贝是 inplace:state_dict 里的 tensor 用 param.copy_(loaded_tensor) 拷到现有的 Parameter,不是创建新对象。这样原模型的 Parameter 引用保持有效(optimizer 的 param_groups 不会失效),但权重值已经更新。这是为什么 load 后不需要重新 build optimizer。

这种 “in-place loading” 设计有个隐形约束:state_dict 里的 tensor 必须形状与现有 Parameter 完全一致。形状不一致会报错。如果你想换形状(如改 num_classes),要先重新构造 model 或者用 strict=False 跳过那一项。

9.8 8 种 hook:模块生命周期的扩展点

nn.Module 提供 8 种 hook,按调用时机分类:

Hook注册方法触发时机
forward_pre_hookregister_forward_pre_hookforward 调用前,可改 args
forward_hookregister_forward_hookforward 调用后,可改 result
full_backward_pre_hookregister_full_backward_pre_hook反向开始前
full_backward_hookregister_full_backward_hook反向结束后,能拿 grad_input / grad_output
backward_hook (deprecated)register_backward_hook早期版本,与 full_backward 不兼容
state_dict_pre_hook_register_state_dict_hookstate_dict 调用前
state_dict_hook_register_state_dict_hookstate_dict 调用后,可改返回值
load_state_dict_pre/post_hook_register_load_state_dict_hookload_state_dict 前后

每种 hook 都是为特定扩展场景设计的:

  • profiler / 性能分析:用 forward_pre/post hook 记时间戳
  • 混合精度 cast:用 forward_pre hook 把 fp32 输入转 fp16
  • 梯度统计 / 调试:用 full_backward_hook 检查 grad_input / grad_output
  • DDP gradient bucketing:用 backward hook 检测某 grad 算完后立即触发 AllReduce(第 17 章)
  • 量化 observer:forward_hook 收集激活分布,PT2E 量化的核心
  • 检查点平均:state_dict_hook 在 save 时拦截

每条 hook 都是 PyTorch 与生态对接的关键点。理解 hook 机制能让你优雅扩展任何 nn.Module 行为,不需要改源码。

9.8.0 hook 与 dispatcher mode 的关系

第 5 章 §5.7 我们看到 TorchDispatchMode 也是一种”拦截每个算子”的扩展点。Module hook 与它的区别:

维度Module hookTorchDispatchMode
拦截粒度一整个 Module 的 forward每一个 ATen 算子
注册方式module.register_forward_hook(fn)with MyMode():
看到的输入Module 的输入参数ATen 算子的张量参数
适合场景”在某层之后做 X""对所有算子做 X”

两者不是互斥的 —— 你可以同时用。Module hook 是粗粒度(按层),DispatchMode 是细粒度(按算子)。生产代码里大部分扩展用 Module hook(更易理解),少数极致 hack 用 DispatchMode(如假执行、量化 observer)。

9.8.2 RemovableHandle 设计

每个 register_xxx_hook 都返回 RemovableHandle

handle = model.register_forward_hook(my_hook)
# 后续...
handle.remove()    # 移除

RemovableHandle 是个简陋的 dataclass:持有 hook id 和注册它的字典引用。remove() 就是从字典 pop 出 id。这种”返回 handle 让用户可以撤销”是 Python 标准库 weakref.WeakValueDictionary 等也用的模式。

为什么不直接给用户 hook 函数引用让他们传回来?因为同一个函数可能被注册多次(在不同 hook 集上),handle 的 id 让”第 N 次注册”可识别可撤销。这种”id-based 注册系统”在 RxJava、JS 事件监听里都有体现。

9.8.5 一个高级技巧:用 hook 实现 model surgery

“model surgery” 是指在不改源码的前提下修改训练好的模型行为。例如:把 ResNet 的某层 ReLU 换成 GELU,但只改一层、保持其他层原样。

def replace_relu_with_gelu(module):
    for name, child in module.named_children():
        if isinstance(child, nn.ReLU):
            setattr(module, name, nn.GELU())
        else:
            replace_relu_with_gelu(child)

replace_relu_with_gelu(model)

这段代码递归遍历所有 module,遇到 ReLU 就替换成 GELU。setattr(module, name, ...) 触发 __setattr__,自动从 _modules 删除旧的 ReLU、加入新的 GELU。整套 surgery 操作零侵入、不需要改原模型代码

这种能力让 PyTorch 在研究场景里非常灵活 —— 可以快速尝试”如果某层换成 X 会怎样”。HuggingFace 的 PEFT (LoRA / Prefix-tuning) 库就大量用 model surgery 把”普通 Linear 替换成 LoRALinear”。

9.8.1 一个真实例子:grad clipping with hook

def grad_clip_hook(module, grad_input, grad_output):
    return tuple(g.clamp(-1.0, 1.0) if g is not None else g for g in grad_input)

model.register_full_backward_hook(grad_clip_hook)

这种实现比 torch.nn.utils.clip_grad_norm_ 更精细 —— 在每个 module 反向时立即裁,不等所有梯度算完。但需要小心 inplace clamp 与 autograd 的交互,所以生产代码通常仍用标准 API。

9.8.3 hook 执行的隐形顺序

每种 hook 之间的调用顺序是有确定 spec 的,按 _wrapped_call_impl 实现的顺序:

  1. global forward_pre_hook(用 nn.modules.module.register_module_forward_pre_hook 注册,对所有 module 生效)
  2. self.forward_pre_hook(per-module)
  3. backward_pre_hook 设置(如果有,会插入 BackwardHook 包装输入)
  4. self.forward(args)
  5. global forward_hook
  6. self.forward_hook
  7. backward_hook 设置(包装输出)
  8. non_full_backward_hook 注册到 grad_fn

理解这套顺序很重要。比如你写一个 forward_hook 想看到”经过其他所有 hook 修改之后的最终 result”,要确保你的 hook 注册时机最晚(dict 是按插入顺序),或者用 prepend=True 参数。这种顺序敏感性是 Module hook 与 dispatch mode 的另一个区别 —— DispatchMode 用栈结构、最近的 mode 先拦截,逻辑更清晰。

9.8.4 用 hook 实现 GPT-style cache:一个高级例子

class KVCache:
    def __init__(self):
        self.cache = {}

    def hook(self, module, input, output):
        # 假设 module 是 attention 层
        layer_id = id(module)
        self.cache[layer_id] = output[1]   # K, V

cache = KVCache()
for layer in model.transformer.layers:
    layer.attention.register_forward_hook(cache.hook)

这是 transformer 推理时手动实现 KV cache 的 toy 写法。生产代码的 KV cache(如 vLLM)当然更精巧(PagedAttention),但思路一致:用 hook 拦截每层 attention 输出,存下 K/V,下一次 forward 时复用。Module hook 让”对模型的非侵入式扩展”非常直接。

9.9 train() / eval():模式切换

def train(self, mode: bool = True):
    self.training = mode
    for module in self.children():
        module.train(mode)
    return self

这个方法把 self.training 设为 True/False 并递归子模块。BatchNorm、Dropout 等模块在 forward 时检查 self.training 决定行为:

class Dropout(Module):
    def forward(self, input):
        if self.training:
            return F.dropout(input, self.p, ...)
        else:
            return input    # eval 模式下 dropout 是 identity

新手最常见的 bug 就是 inference 时忘了 model.eval() —— BN 仍然用 batch 统计、Dropout 仍然丢神经元、模型行为非确定。生产 inference 服务必须 model.eval()

注意 eval()no_grad()正交的概念

  • eval():切到推理模式(影响 BN/Dropout 行为)
  • no_grad():关掉 autograd(影响梯度计算)

生产 inference 通常两者都要:

model.eval()
with torch.no_grad():
    out = model(x)

或者更高效的:

model.eval()
with torch.inference_mode():
    out = model(x)

9.9.1 model.children() vs model.modules()

两个递归方法容易混:

  • children() 只返回直接子模块(不递归到子模块的子模块)
  • modules() 递归返回所有模块(包括 self、子子模块……)

类似还有 named_children()named_modules() 带名字版本。生产代码里:

# 给所有 Conv2d 的 weight 做 init
for m in model.modules():
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight)

modules() 返回深度优先遍历整个网络。这是 PyTorch 给”按类型批量处理参数”留的最方便接口。

children()modules() 看似冗余,但场景不同:

  • 想给”每个直接子层”应用变换(如最外层 wrap 一下):用 children()
  • 想搜整个网络里所有某类型层(如所有 BN):用 modules()
  • 想知道某层在网络里的命名位置:用 named_modules()

apply(fn) 内部就是 for m in self.modules(): fn(m) 加 self return。这种”递归遍历 + 用户提供闭包”模式让 PyTorch 与训练流程框架(如 Lightning)能轻松实现”对所有参数做 X”的统一接口。

9.9.5 train/eval 在分布式与 inference 服务中的微妙性

model.train() / model.eval() 看似简单,分布式训练里有些坑:

1. DDP 包装后的 train/eval:DDP 把 model 包成 DistributedDataParallel(model),调 ddp_model.train()递归设置 inner module 的 training。但有些第三方库的 wrapper 不递归,需要显式 ddp_model.module.train()

2. 模型分阶段切换:典型场景是先训练后用一段 inference 子图。常见错误:

model.eval()
out = model.encoder(x)   # eval 模式
loss = some_loss(out)
loss.backward()           # 仍然会反向, 但 BN 用的是 running stats 而非 batch stats
                          # 训练效果会很差

正确做法:明确分开训练阶段和推理阶段,不要混合。

3. 推理服务的并发安全model.eval() 设置了 self.training=False,多线程并发推理时不会互相影响(因为每个推理只读 training)。但如果你某条路径里又调了 model.train(),并发的其他推理会被污染。生产服务里不要在请求处理路径切换 train/eval

9.10 跨书关联

  • 《vLLM 内核探秘》第 7 章 模型加载:vLLM 加载 HuggingFace 模型时直接构建 PyTorch nn.Module 树,再用自定义 weight loader 喂权重 —— 完全依赖本章的 state_dict 机制
  • 《Tokio 异步运行时》第 X 章 任务上下文:Tokio Task 的 spawn / hook 注册机制与 nn.Module 的 hook 思想类似,都是”在统一抽象上预留扩展点”
  • 《Serde 元编程》派生宏:Serde 的 #[derive(Serialize)] 在编译期生成的逻辑,与 nn.Module 在运行期通过 __setattr__ 拦截做的事是同一目的:自动收集类的字段
  • 《MCP 协议剖析》Server 注册:MCP server 用 decorator 注册 method handler 与 nn.Module hook 注册思想极相似 —— 都是”对类的字段做声明式处理”

9.10.5 一个对照:torch.fx vs nn.Module

最后值得提一下 torch.fx。它是 PyTorch 用来”把 nn.Module 转成符号化 IR”的工具:

import torch.fx

model = MyModel()
gm = torch.fx.symbolic_trace(model)
# gm 是一个 torch.fx.GraphModule, 它的 forward 是一段被记录的 IR
print(gm.graph)

fx.symbolic_trace__torch_function__ 协议拦截每个算子调用,把它记录成 IR 节点。这套 IR 是 torch.compile 的早期基础(v1.x 时代),现在被 Dynamo 部分取代但仍在量化、模型变换等场景广泛使用。

理解 fx 与 nn.Module 的关系:nn.Module 是面向人的 API(layers + forward),fx Graph 是面向编译器的 IR。两套表示互相可转换:fx.symbolic_trace 把 Module 转成 Graph;fx.GraphModule 又是一种 Module(继承自 nn.Module)。这种”原始 + 编译”两套表示并存是 PyTorch 框架的经典哲学,也是第 12-15 章 torch.compile 章节的根本设计。

注意 fx 现在已经被 Dynamo 取代主导地位(Dynamo 更鲁棒、能处理更多 Python 控制流),但 fx 在某些特定场景仍然广泛使用:

  • 量化 (quantization):PT2E 量化的 prepare / convert 阶段大量用 fx
  • 模型变换 / pruning:剪枝、重参数化等需要明确的 Graph IR 才好做
  • 简单的 trace 场景:用户想看”我的 Module 实际跑了哪些算子”时,symbolic_trace 比 Dynamo 简单得多

第 14 章 Inductor 章会展开 fx 与 Dynamo 之间的接力。

9.11 几条工程建议

实战 nn.Module 相关的最佳实践:

1. 永远 super().__init__()__init__ 第一行:忘了会让 _parameters 等字典没初始化,后续赋值报错”cannot assign parameters before Module.init()”

2. 学习参数用 nn.Parameter:直接 self.weight = torch.randn(...) 不会被自动注册为参数

3. 不参与训练但要保存的状态用 register_buffer:如 BN 的 running stats、自定义的位置编码

4. 临时中间状态不要存为 attribute:会被 state_dict 收进来,让 ckpt 变大。用局部变量。但注意:纯 Python 对象(list、dict、自定义类)不会被 state_dict 收 —— 只有 Tensor / Parameter / Buffer / Module 才被自动收集

5. 子模块要写到 __init__ 而非 forward:在 forward 里 nn.Linear() 每次都新建模块,权重不会被训练

6. 注意 hook 的 order 影响:多个 hook 按注册顺序调用,写的时候要明确”我能假设其他 hook 已经跑过吗”

7. model.to(device) 是 inplace 操作:返回 self,但同时也改了 self。new = model.to(device)model.to(device) 等价

8. nn.ModuleList / nn.ModuleDict 而非 Python list / dictself.layers = [nn.Linear(10,10) for _ in range(5)] 不会自动注册(普通 list 不被 __setattr__ 检测),self.layers = nn.ModuleList([...]) 才会

9. nn.Parameter 注册时机:必须在 __init__ 期间注册。运行期动态加 Parameter(如 if cond: self.weight2 = nn.Parameter(...))会让 optimizer 拿不到(因为 optimizer 在 model 创建后就读了一次 parameters())

10. model.apply(fn) 递归调用apply 给每个 module 调一次 fn(module)。常用于初始化:model.apply(init_weights)。它的实现就是 for m in self.modules(): fn(m) 加 self return

11. forward 里不要用 self.training 之外的 mode 切换:除了 BN / Dropout 这种内置约定,自定义 module 不要在 forward 里读其他 self.xxx 切换行为。这样会让 torch.compile 看到的图依赖运行时变量,触发 graph break

12. 模型结构包含可变 list / Tensor list:用 ParameterList / BufferList 而非普通 Python list。例如 self.scales = nn.ParameterList([nn.Parameter(torch.ones(1)) for _ in range(K)]),自动注册

9.11.1 给 Module 设计师的”通用启示”

把 nn.Module 思想抽象到任何”组件容器”系统:

第一__setattr__ 拦截 + __getattr__ 兜底是 Python 元编程的黄金组合 —— 让用户感受不到框架在背后做的事,写代码像写普通类一样自然

第二子节点用专门容器(dict / list)而非 self.x = child:让递归遍历、序列化、迁移操作都有统一入口

第三hook 系统让框架可扩展:在生命周期关键点开放可注册的回调列表,第三方工具就能集成进来。RemovableHandle 让用户能撤销注册

第四state 与 metadata 分离:state_dict 存”值”,_metadata 存”版本号 / 类型信息”。前者让数据可移植,后者让兼容性成为可能

第五train/eval 是模型生命周期的两个不同阶段:用一个 self.training flag 切换,相关组件查这个 flag 调整行为。这种”全局上下文 + 组件就地查询”模式比”为不同模式写两份代码”清晰得多

把这五条记住,你写自己的”组件框架”(如 GUI 框架、游戏引擎的 Entity-Component 系统、CRM 的工作流模板)能少走很多弯路。

第六让递归遍历成为一等接口:parameters() / modules() / children() / named_xxx() 这种”按容器维度遍历”接口让框架的所有变换、保存、迁移操作都有标准入口

第七train/eval / persistence / device 这三种”全局开关”互不干扰:每种开关由不同字段表示,Module 在合适的层面响应。这种”正交概念分离”让用户能任意组合而不互相破坏

9.11.5 性能数字:nn.Module 调用开销

具体数字(H100,PyTorch v2.11,单线程 CPU):

操作开销
一次 module(x) 调用(无 hook)~3-5 us
一次 forward_hook 触发+0.5-1 us
一次 backward_hook 触发+1-2 us
module.named_parameters() 遍历(10 层网络)~10-20 us
module.state_dict()(10 层网络)~50-100 us
module.load_state_dict()(10 层网络,无形状变化)~200-500 us

对一次 forward 包含几十层 module 的大模型,nn.Module 调用开销加起来约 几十微秒到几百微秒,相对于实际计算(毫秒级)不显著。但对小模型(如 MNIST classifier)反复 forward / backward 时,nn.Module 自身开销可能占总时间 10-20%。

torch.compile 把整个 module(x) 调用编进单个 Triton kernel,这部分 nn.Module 开销直接归零 —— 这是 compile 在小模型上能拿到 50%+ 加速的关键。

9.12 nn.Module 的 to/cuda/half 实现

model.cuda() 把整个模型搬到 GPU。它的实现就是用 _apply 递归遍历每个 parameter / buffer,对每个张量调 .cuda()

def cuda(self, device=None):
    return self._apply(lambda t: t.cuda(device))

def _apply(self, fn, recurse=True):
    if recurse:
        for module in self.children():
            module._apply(fn)
    # apply to all parameters
    for key, param in self._parameters.items():
        if param is not None:
            with torch.no_grad():
                param_applied = fn(param)
            self._parameters[key] = Parameter(param_applied, ...)
    # apply to all buffers
    for key, buf in self._buffers.items():
        if buf is not None:
            self._buffers[key] = fn(buf)
    return self

_apply 是 nn.Module 里最通用的一个抽象 —— 它接受任意”对张量的变换”函数,递归整个网络。.cuda() / .cpu() / .half() / .float() / .bfloat16() / .to(device, dtype) 全部用它实现。

注意细节:参数变换后是用 Parameter() 包装回去(保留 requires_grad 等元数据),buffer 直接覆盖。这是因为 buffer 不需要 Parameter wrapper。

为什么参数变换后要重新包成 Parameter 而不是直接 inplace?因为某些变换(如 .to(device))会返回新张量(device 改变时数据需要重新分配)。新张量是普通 Tensor,需要重新包装回 Parameter 以保留 requires_grad 等元信息。如果是同 device、同 dtype 的变换,PyTorch 会做一些优化避免重新分配,但接口语义上始终是”返回新对象”。

9.12.0.5 _apply 与 autograd 的微妙交互

.cuda() 等变换内部用 with torch.no_grad(): 包住,避免变换本身被 autograd 记录。否则你 model.cuda() 的瞬间会触发 _to_copy 算子的反向图,最后整个网络复制到 GPU 的过程被记入反向图 —— 显存爆炸。

这种”对底层 housekeeping 操作显式禁用 autograd”是 PyTorch 内部很多基础设施的共同做法。optimizer.step() 内部也用 no_grad 包,避免参数更新被反向追踪。第 10 章会再看到这种模式。

9.12.1 一个有趣的 _apply 应用:模型量化

def quantize_to_int8(t):
    return t.to(torch.int8) if t.is_floating_point() else t

quantized = model._apply(quantize_to_int8)

类似机制让 PyTorch 的 quantization、AMP autocast、device offloading 等高级功能都能站在 _apply 这套递归基础上写。第 20 章量化会展开。

9.12.5 _save_to_state_dict vs _load_from_state_dict 的对偶性

为什么 PyTorch 把 save 和 load 写成 _save_to_state_dict / _load_from_state_dict 这种”_私有方法”模式而不是公开 API?

因为 它们是 module 子类可以重写的扩展点。如果你的 module 有特殊持久化需求(如某个张量需要量化后存、或者把多个张量打包成一个),可以重写这两个方法:

class MyModule(nn.Module):
    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # 自定义存储格式
        destination[prefix + 'packed'] = self._pack_for_save()

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, ...):
        # 自定义加载逻辑
        if prefix + 'packed' in state_dict:
            self._unpack(state_dict[prefix + 'packed'])

这种”框架提供 hook 点 + 子类重写细节”是 Template Method 设计模式的应用。PyTorch 的 quantization、稀疏张量、甚至 nn.Embedding 都用了重写 _load_from_state_dict 来处理特殊情况。

9.13 一个完整的 Module 树:HuggingFace BERT 解构

实战看一段 HuggingFace BertModel 的 model.named_modules() 输出(节选):

'' (BertModel)
'embeddings' (BertEmbeddings)
'embeddings.word_embeddings' (Embedding)
'embeddings.position_embeddings' (Embedding)
'embeddings.token_type_embeddings' (Embedding)
'embeddings.LayerNorm' (LayerNorm)
'encoder' (BertEncoder)
'encoder.layer.0' (BertLayer)
'encoder.layer.0.attention.self.query' (Linear)
'encoder.layer.0.attention.self.key' (Linear)
'encoder.layer.0.attention.self.value' (Linear)
'encoder.layer.0.attention.output.dense' (Linear)
'encoder.layer.0.intermediate.dense' (Linear)
'encoder.layer.0.output.dense' (Linear)
'encoder.layer.1' (BertLayer)
...

每条路径就是 dotted prefix,对应 state_dict 里的 key。这种”树结构 + dotted path”的设计让 BERT 这种 12 层 / 24 层的复杂模型能用扁平 dict 完整表达,加载、迁移、部署都极简单。HuggingFace、Lightning 等所有上层框架都建立在 nn.Module 这套机制之上。

观察一个细节:encoder.layer.0 这种带数字的路径来自 nn.ModuleList。它实现了 __setitem__ 把列表索引转成字符串名字('0''1' 等),存到 _modules 字典。所以 ModuleList 表面上是列表,底层仍然是 dict —— 索引化路径自然形成。

类似地,nn.ModuleDict 让用户用字符串 key 索引子模块,但它的所有 key 都直接进 _modules理解 ModuleList / ModuleDict / Sequential 都只是 dict 的语法糖,能让你写自己的”结构化 module 容器”时心里有数。

nn.Sequential 是另一个重要 ModuleList 派生:它把子模块用整数索引按顺序串起来,forward 自动 for m in self: x = m(x)。常用于线性堆叠的网络:

mlp = nn.Sequential(
    nn.Linear(10, 100),
    nn.ReLU(),
    nn.Linear(100, 10)
)

mlp[0] 是第一个 Linear、mlp[-1] 是最后一个 Linear。这是 nn.Module 给”简单堆叠模型”的便利包装。但 Sequential 不能处理需要分支 / 跳连接的复杂结构,那些场景要写完整的 forward

9.14 一个收官的反思:nn.Module 的成功是 PyTorch 的成功

某种程度上,PyTorch 之所以击败 TF 1.x 赢得学术界,不只是动态图,还因为 nn.Module 比 tf.keras.Model 更自然

tf.keras.Model 早期要求用户走 functional API(x = Dense(10)(x))或者 subclass + build() 方法 —— 都比 PyTorch 的 self.linear = nn.Linear(...) 多了一层心智负担。研究者要快速实验新架构时,PyTorch 的 nn.Module 写起来更接近”在草稿纸上画的网络结构图”。

这条经验给所有想做”用户友好框架”的设计师一个启示:让用户写代码的方式贴近他思考的方式。如果用户脑子里想的是”线性层 → ReLU → 线性层”,框架就该让他这样写代码。中间多一个抽象层都是认知摩擦。

PyTorch 的 nn.Module 把”魔法量”控制在 两个 dunder 方法 + 几个内部 dict 这个最小集 —— 已经能让用户体验最自然,又不至于黑盒到完全不可调试。这种 “隐藏复杂度但保留可观察性” 的工程平衡,是 nn.Module 设计上最高明的一笔。

下一章拆 optimizer:optim.SGD / optim.Adam 等怎么实现,包括最近引入的 Foreach / Fused / Capturable 三种性能模式。

评论 0