第7章 Autograd 设计原理与反向图构建
“Autograd is what makes PyTorch feel magical. Take any Python computation, append
.backward(), and gradients appear. The ‘magic’ is a careful collaboration between the dispatcher, a hidden graph builder, and a multithreaded engine.”—— Soumith Chintala, “PyTorch internals” PyCon 2019
本章要点
- 反向图不是显式建的:每次前向算子被调用时,autograd 中间层(dispatcher 上的 Autograd key)偷偷创建一个
Node,挂到输出张量的grad_fn_上 - Node + Edge 双向链表构成 DAG:每个 Node 持有
next_edges_指向 上游 Node。loss.backward()从 loss 的 grad_fn 出发反向遍历整张图 AutogradMeta是张量的”反向身份证”:藏在 TensorImpl 里,存grad_/grad_fn_/grad_accumulator_/fw_grad_。inference 张量的 autograd_meta_ 是 nullptr,autograd 路径完全不走SavedVariable让前向中间值”穿越到反向”:带 version 检查(防 inplace 污染)、支持 hooks(用于显存优化)、weak_ptr(防循环引用)AccumulateGrad是叶子张量的梯度”沉淀池”:用户参数张量都对应一个 AccumulateGrad 节点,反向时所有梯度汇入这里autograd.Function是用户级自定义 autograd 接口:写forward+backward两个静态方法,和 PyTorch 内置算子的 autograd 路径完全等价
7.1 一行代码引发的迷思
每个 PyTorch 用户都写过这样的代码:
import torch
a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(3, 2, requires_grad=True)
c = a @ b
loss = c.sum()
loss.backward()
print(a.grad) # 神奇地有了!
print(b.grad) # 也有了!
loss.backward() 这一行做了什么?严格地讲:
- 它怎么知道反向应该从
loss开始? - 它怎么知道
loss是从c.sum()来的?c是从a @ b来的? - 它怎么算梯度?matmul 的反向公式从哪里来的?
- 它怎么知道 把
a的梯度填到a.grad?
答案是 PyTorch 在前向时偷偷建了一张反向图。这张图记录了”loss 是怎么从 a、b 算出来的”,反向时沿图反走,应用每个算子的反向公式。本章把这张图的构建过程彻底拆开。
graph BT
subgraph Forward["前向 (a @ b → c.sum() → loss)"]
F1[a 张量<br/>requires_grad=True]
F2[b 张量<br/>requires_grad=True]
F3[c = a @ b]
F4[loss = c.sum]
F1 --> F3
F2 --> F3
F3 --> F4
end
subgraph Backward["反向图 (在前向时偷偷建好)"]
B0[GraphRoot<br/>backward 起点]
B1[SumBackward0]
B2[MmBackward0]
B3[AccumulateGrad<br/>对应 a]
B4[AccumulateGrad<br/>对应 b]
B0 --> B1
B1 --> B2
B2 --> B3
B2 --> B4
end
F4 -.grad_fn.-> B1
F3 -.grad_fn.-> B2
F1 -.grad_acc.-> B3
F2 -.grad_acc.-> B4
style B0 fill:#fef3c7,stroke:#f59e0b
style B3 fill:#dcfce7,stroke:#22c55e
style B4 fill:#dcfce7,stroke:#22c55e
7.2 AutogradMeta:张量身上的”反向身份证”
回到第 2 章 §2.4:TensorImpl 有一个 autograd_meta_ 字段:
// c10/core/TensorImpl.h
std::unique_ptr<c10::AutogradMetaInterface> autograd_meta_ = nullptr;
它是 unique_ptr 默认 nullptr —— 不需要梯度的张量根本不分配 autograd_meta。需要时它指向 torch::autograd::AutogradMeta:
// torch/csrc/autograd/variable.h
struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
std::string name_;
Variable grad_; // 累积的梯度
std::shared_ptr<Node> grad_fn_; // 产生这个张量的 Node (中间张量)
std::weak_ptr<Node> grad_accumulator_; // 累积梯度的 Node (叶子张量)
mutable std::shared_ptr<ForwardGrad> fw_grad_; // forward-mode 切线
std::vector<std::unique_ptr<FunctionPreHook>> hooks_;
...
};
——这就是张量的”反向身份证”。理解几个字段:
grad_fn_:如果张量是某个算子的输出(interior),这个字段指向产生它的Node。a @ b产生的c,c.grad_fn就是MmBackward0grad_accumulator_:如果张量是用户创建的叶子(leaf),这个字段指向一个AccumulateGrad节点 —— 反向时所有汇入的梯度都加到grad_上grad_:累积的梯度。loss.backward()跑完后用户读a.grad就是读这个字段fw_grad_:forward-mode 切线(jvp 用)
grad_fn_ 和 grad_accumulator_ 是互斥的:一个张量要么是某算子的输出(有 grad_fn)、要么是叶子(有 grad_accumulator),不会同时有。
为什么用 weak_ptr 持有 grad_accumulator?因为 leaf 张量本身可能很多(如所有模型参数),而 AccumulateGrad 节点的存活由反向图的引用关系控制。如果叶子张量强引用 AccumulateGrad,AccumulateGrad 又通过 next_edges 上溯连到反向图 —— 形成长寿循环。用 weak_ptr 让 AccumulateGrad 在反向图被释放时自动消失,下次反向再 lazy 创建。这种”按需复活”是 PyTorch 内存管理上又一个细节优化。
7.2.1.5 lazy materialize_autograd_meta
AutogradMeta 不是张量构造时就分配,而是 lazy 分配。打开 torch/csrc/autograd/variable.h:
TORCH_API AutogradMeta* materialize_autograd_meta(const at::TensorBase& self);
只有第一次需要写 autograd 信息(如设置 grad_fn)时才分配。这种 lazy 模式让 inference 路径上的张量完全不付 autograd 元数据的内存代价 —— 一个 inference 张量(不需要梯度)的 TensorImpl 比 training 张量小约 100 字节。在大模型推理里这能省下数百 MB 显存。
7.2.1 leaf vs interior:autograd 视角下的两类张量
a = torch.randn(3, requires_grad=True) # leaf — 用户创建
b = torch.randn(3, requires_grad=True) # leaf
c = a + b # interior — 算子结果
d = c * 2 # interior
print(a.is_leaf) # True
print(c.is_leaf) # False
print(a.grad_fn) # None
print(c.grad_fn) # <AddBackward0 object>
leaf 张量有梯度累积器(grad_accumulator)但没有 grad_fn。interior 张量反过来。这种区分是反向传播终止条件的体现 —— 反向沿图走到 AccumulateGrad 节点就把梯度填到张量的 grad_ 字段,停下来。
7.3 Node 与 Edge:反向图的骨架
打开 torch/csrc/autograd/function.h:113:
struct TORCH_API Node : std::enable_shared_from_this<Node> {
public:
virtual variable_list apply(variable_list&& inputs) = 0;
// 指向"上游" Node 的边
edge_list next_edges_;
uint64_t sequence_nr_; // 创建顺序,反向时按 priority 排序
uint32_t topological_nr_;
...
};
每个 Node 是一个反向函数对象,apply 方法接收”输出梯度”返回”输入梯度”。next_edges_ 是反向图的核心:
// torch/csrc/autograd/edge.h:14
struct Edge {
std::shared_ptr<Node> function; // 指向上游 Node
uint32_t input_nr; // 是上游 Node 的第几个输入
};
Edge = (Node 指针, input 索引)。一个 Node 的 next_edges 是它”上游”Nodes 的列表,每条边告诉反向引擎:当我把梯度往上传时,应该传给哪个 Node 的第几个输入。
graph BT
subgraph CN["MmBackward0 (c = a @ b 的反向 Node)"]
N["MmBackward0<br/>──────────<br/>apply(grad) {<br/> return [<br/> grad @ b.T, ← input 0<br/> a.T @ grad ← input 1<br/> ]<br/>}<br/>──────────<br/>next_edges_:<br/> [(AccumulateGrad_a, 0),<br/> (AccumulateGrad_b, 0)]"]
end
AGA[AccumulateGrad_a]
AGB[AccumulateGrad_b]
N -- "next_edges[0] (input_nr=0)" --> AGA
N -- "next_edges[1] (input_nr=0)" --> AGB
style N fill:#fef3c7,stroke:#f59e0b,stroke-width:2px
注意 input_nr 字段的作用:当多个 Node 的输出共同流到同一个上游 Node 的多个输入时,每条 Edge 必须知道”我对应的是哪个输入位”。否则梯度会错乱。
7.3.1 反向图是 DAG,不是树
考虑 c = a + a:
graph BT
AB[AddBackward0]
AGA[AccumulateGrad_a]
AB -->|input 0| AGA
AB -->|input 1| AGA
AddBackward0 的两条 next_edges 都指向同一个 AccumulateGrad_a —— 因为两个输入都是 a。反向时,AddBackward0 计算出 grad_input0 和 grad_input1,两者会被加起来累积到 a.grad。这种”多边汇入同一节点”是 DAG 的标准结构。
function.h:78-81 的注释明确说了这件事:
When two or more
Edges (from different sources) point at the same input to aNode, the values produced along all of these edges are implicitly summed prior to being forwarded to the targetNode.
反向引擎自动处理这种隐式求和,用户代码完全感受不到。
这一行隐式求和语义是反向传播链式法则在多路径场景下的精确实现 —— 数学上 dL/dx = ∑ (dL/dy_i) (dy_i/dx),引擎用”汇入同一节点的多条边自动相加”实现这条公式。理解这条规则,你就能预测任意 DAG 的反向行为,不需要逐个手动求导。
7.3.2 sequence_nr 与 topological_nr
每个 Node 还有两个看似冗余的字段:
sequence_nr_:thread-local 单调递增的创建顺序号。反向时按它做优先级排序(让较新创建的 Node 优先执行,启发式地把”靠近 loss”的反向先做完)topological_nr_:拓扑序号,从叶子开始计数。它让反向引擎能在 O(1) 判断”两个节点中哪个更接近 root”,用于实现next_edges的提前剪枝
这两个号在大型反向图(如 Llama-70B 的反向)里影响调度效率几个百分点。它们是 PyTorch 反向引擎在保持”DAG 拓扑遍历”语义同时榨性能的工程细节。
7.3.3 反向图的”惰性构建”特性
值得强调一个反直觉的事实:前向时建立的反向图,在 loss.backward() 调用前都不真正消耗计算。它就是一些 shared_ptr<Node> 互相指向的对象图。每个 Node 持有几个张量引用(SavedVariable 里)和几条 Edge,但不做任何数学计算。
这意味着即便你前向跑了几百层 transformer 创建了几千个反向 Node,在 loss.backward() 调用前的 CPU 开销基本可以忽略。真正的 backward 计算延迟到 backward 时才发生。这种”前向定义 + 反向执行”的拆分让 PyTorch 能精确报告反向时的内存峰值与时间分布,是 profiler 工具的工作前提。
7.4 反向图怎么在前向时被偷偷建
sequenceDiagram
participant U as 用户代码
participant V as VariableType wrapper
participant N as 新 Node MmBackward0
participant K as ATen kernel mm
participant T as 输出 Tensor
U->>V: c = a @ b 触发 mm
V->>V: 1. collect_next_edges: 收集 a/b 的 grad_fn
V->>N: 2. new MmBackward0
V->>N: 3. set_next_edges 连到 a.grad_fn b.grad_fn
V->>N: 4. SavedVariable 保存 a/b 给反向用
V->>K: 5. redispatch 到真实 mm
K-->>V: 输出 tensor
V->>T: 6. set_history: c.grad_fn = MmBackward0
V-->>U: c
Note over T,N: 反向图已经悄悄建好<br/>c.backward 时沿 grad_fn 反走
到了核心问题:前向算子调用怎么变出反向图?答案在第 5 章 §5.5 提过的 VariableType 包装。每个可微算子在 tools/autograd/derivatives.yaml 声明反向规则,由 tools/autograd/gen_variable_type.py 自动生成包装函数。
简化的 VariableType::mm 长这样:
// 简化版生成代码
Tensor mm(const Tensor& self, const Tensor& mat2) {
// 1. 收集 next_edges (上游 Node)
auto& self_meta = impl::get_autograd_meta(self);
auto& mat2_meta = impl::get_autograd_meta(mat2);
edge_list next_edges = collect_next_edges(self, mat2);
// 2. 创建反向 Node
auto grad_fn = std::shared_ptr<MmBackward0>(
new MmBackward0(),
deleteNode);
grad_fn->set_next_edges(std::move(next_edges));
// 3. 保存反向需要的张量
grad_fn->self_ = SavedVariable(self, false);
grad_fn->mat2_ = SavedVariable(mat2, false);
// 4. redispatch 到下一层 (实际计算)
auto result = at::redispatch::mm(after_autograd_keyset, self, mat2);
// 5. 把 grad_fn 挂到输出
set_history(result, grad_fn);
return result;
}
每一步都要看清楚:
7.4.1 collect_next_edges:连接上游
collect_next_edges 遍历输入张量,每个 requires_grad=true 的张量贡献一条 Edge:
- 如果输入是 interior(有
grad_fn_):Edge 指向那个 grad_fn - 如果输入是 leaf(有
grad_accumulator_):Edge 指向 AccumulateGrad
这就是反向图边连接的全部魔法。前向时每个算子各自决定自己的 next_edges,整张反向图就在不知不觉中拼好了。
为什么这种设计如此简洁?因为反向图的边**只取决于”哪个张量是这次 op 的输入”**这一信息,而这是每个算子在调用时一定知道的。不需要全局状态、不需要事后分析、不需要 trace pass —— 每个算子各自管自己的 next_edges,整张图自然涌现。这是声明式 vs 命令式哲学的胜利:定义”每条边怎么决定”,整张图自然存在。
7.4.2 set_history:把 grad_fn 钉在输出上
set_history(result, grad_fn) 等价于:
materialize_autograd_meta(result)->grad_fn_ = grad_fn;
result.unsafeGetTensorImpl()->set_requires_grad(true);
这一步让 result 也变成 autograd 体系的一员 —— 它有 grad_fn、有 requires_grad,下一个算子调用时能从 result 推出 next_edges。整个反向图就这样”链式自传染”。
7.4.3 “如果输入都不 requires_grad,跳过整个反向构建”
VariableType::mm 真实代码(生成在 torch/csrc/autograd/generated/VariableType_4.cpp)的第一步其实是检查:
auto _any_requires_grad = compute_requires_grad(self, mat2);
std::shared_ptr<MmBackward0> grad_fn;
if (_any_requires_grad) {
grad_fn = std::shared_ptr<MmBackward0>(new MmBackward0(), deleteNode);
// ... save tensors, set next_edges ...
}
auto result = at::redispatch::mm(...);
if (grad_fn) {
set_history(result, grad_fn);
}
return result;
如果所有输入都不 requires_grad,跳过 grad_fn 创建。这种短路让”参数张量混着 inference 张量”的常见 inference 路径不为 autograd 付任何额外开销。理解这条路径,你能精确推理:当模型 weight 关闭 grad、只跑 inference 时,autograd 的 dispatcher 中间层走的是 fallthrough + 这条短路 —— 性能与”完全不带 autograd”几乎一致。
7.4.4 反向图随调用栈生长,不随时间增长
很多新手会担心”长跑训练的反向图会不会越来越大”。答案是不会 —— 反向图的大小由当前 forward 的复杂度决定,每次 loss.backward() 跑完后整张反向图被释放(因为 loss.grad_fn 引用消失)。下一个 batch 的 forward 重新建一张全新的反向图。
例外是 retain_graph=True:用户显式说”我还要再 backward 一次”,反向图保留下来。这种用法在 second-order derivatives、GAN 训练等场景出现,但绝大多数训练不需要。第 8 章 Engine 章会讲 retain_graph 怎么影响调度。
7.4.5 跟踪一个真实的 forward → backward graph
让我们用一段最简单的代码完整跟踪反向图构建:
import torch
a = torch.tensor([1., 2.], requires_grad=True) # leaf
b = torch.tensor([3., 4.], requires_grad=True) # leaf
c = a + b # → AddBackward0
d = c * 2 # → MulBackward0
e = d.sum() # → SumBackward0
e.backward()
每行做的事:
a = torch.tensor(..., requires_grad=True):构造 leaf 张量。AutogradMeta 被分配,grad_fn_=None,grad_accumulator_是 weak_ptr → 待 lazy 创建的AccumulateGrad_ab = ...:同理,对应AccumulateGrad_bc = a + b:dispatcher 命中VariableType::add→compute_requires_grad(a, b)返回 true- 创建
AddBackward0Node collect_next_edges(a, b)返回[Edge(AccumulateGrad_a, 0), Edge(AccumulateGrad_b, 0)]grad_fn->set_next_edges(next_edges)- redispatch 调真 add,得到 c
set_history(c, AddBackward0)—— c.grad_fn = AddBackward0
d = c * 2:类似地 →MulBackward0,next_edges =[Edge(AddBackward0, 0)](c 是 interior,grad_fn 就是上游 Node)。注意常数2不参与反向,所以只有一条 edgee = d.sum():SumBackward0,next_edges =[Edge(MulBackward0, 0)]e.backward():从 e.grad_fn (SumBackward0) 出发反向遍历,最终梯度填到 a.grad / b.grad
整张反向图:
SumBackward0 → MulBackward0 → AddBackward0 → AccumulateGrad_a
↓
AccumulateGrad_b
每个节点的 next_edges 完全在前向时被正确填好。用户写的就是 4 行 Python 代码,PyTorch 在背后偷偷拼出了完整的反向 DAG。
7.5 SavedVariable:前向中间值穿越到反向
很多反向计算需要前向时的中间值。mm 的反向需要 mat2 才能算 grad_self = grad @ mat2.T。但 mat2 是普通 Python 变量,怎么保证它在反向时还活着?
答案是 SavedVariable(torch/csrc/autograd/saved_variable.h:22 类定义,v2.11 实测):
class TORCH_API SavedVariable {
private:
at::Tensor data_; // 保存的张量数据
std::shared_ptr<ForwardGrad> fw_grad_; // forward AD 切线
std::weak_ptr<Node> weak_grad_fn_; // 防循环引用 (inplace view 场景)
uint32_t saved_version_ = 0; // 保存时的 version_counter
uint32_t output_nr_ = 0;
bool was_default_constructed_ = true;
bool is_inplace_on_view_ = false;
bool saved_original_ = false;
bool is_leaf_ = false;
bool is_output_ = false;
std::unique_ptr<SavedVariableHooks> hooks_;
...
};
它的关键责任:
7.5.1 防 inplace 污染:saved_version_
第 3 章 §3.10.5 提过的”version_counter”在这里登场。SavedVariable 在保存时记下 saved_version_:
SavedVariable::SavedVariable(const Variable& v, bool is_output, ...) {
data_ = v;
saved_version_ = v.unsafeGetTensorImpl()->version_counter().current_version();
...
}
反向 unpack 时检查:
Variable SavedVariable::unpack(...) const {
if (data_.unsafeGetTensorImpl()->version_counter().current_version() != saved_version_) {
throw std::runtime_error(
"one of the variables needed for gradient computation has been "
"modified by an inplace operation");
}
return data_;
}
——这就是那个所有 PyTorch 训练老手都见过的经典报错的来源!它防止用户对 mat2 做 inplace 修改后反向算出错误梯度。
7.5.2 防循环引用:weak_grad_fn_
考虑 inplace view 场景:
x = torch.randn(3, requires_grad=True)
y = x[0] # view
y *= 2 # inplace modify x[0]
y.sum().backward()
y *= 2 触发的反向 Node 需要保存 y。但 y 又是 MulBackward0 节点的输入 —— 强引用会形成循环:MulBackward0 → SavedVariable(y) → AutogradMeta(y).grad_fn_ → MulBackward0。
SavedVariable 用 weak_grad_fn_ 打破这个循环:当保存的张量与 grad_fn 形成潜在循环时,存 weak_ptr 而不是 shared_ptr。反向 unpack 时 weak_grad_fn_.lock() 拿回 grad_fn —— 如果中间没人持有,循环引用自动解除。
7.5.3 saved variable hooks:用计算换显存
PyTorch 还提供 SavedVariable::register_hooks 接口(saved_variable.h:49,v2.11 实测)—— 用户可以为每个保存的张量挂自定义 pack / unpack 钩子:
# 把保存的中间值放到 CPU 而非 GPU,节省显存
def pack_to_cpu(t):
return t.cpu()
def unpack_from_cpu(t):
return t.cuda()
with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
out = model(x)
loss = out.sum()
loss.backward()
这是大模型训练显存优化的最重要工具之一:把中间值从 GPU 换到 CPU 甚至磁盘,反向时再换回。activation checkpointing(梯度检查点)的底层就是基于这套 hook 实现的。第 18 章 FSDP / 第 20 章 量化训练会再用到。
具体的钩子语义:pack_to_cpu 在 SavedVariable 构造时被调用,用户决定怎么”压缩”待保存的张量;unpack_from_cpu 在 backward 用到这个张量时被调用,把”压缩”格式还原。PyTorch 不强制 pack 一定是 device-to-host 拷贝 —— 你完全可以做更激进的事,比如把张量量化成 int8 存下来再反向时反量化(牺牲精度换显存),或者写到 NVMe 磁盘(牺牲速度换更大显存)。HuggingFace Accelerate 库里有完整的 cpu_offload 实现,背后就是这套 hook。
7.5.3.5 一个真实优化案例:activation_checkpoint
torch.utils.checkpoint.checkpoint(fn, *args) 的实现就是基于 SavedVariable hooks 的”伪 forward + 真 forward 重算”:
class CheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, fn, *args):
ctx.fn = fn
ctx.save_for_backward(*args)
with torch.no_grad(): # 前向不建反向图
return fn(*args)
@staticmethod
def backward(ctx, *grad_outputs):
args = ctx.saved_tensors
with torch.enable_grad():
outputs = ctx.fn(*args) # 反向时重新前向,这次建图
return torch.autograd.grad(outputs, args, grad_outputs)
实际 PyTorch 的 checkpoint 实现更复杂(处理 RNG state、AMP 上下文等),但核心就是这套思想。前向时不保存中间值,反向时重算一遍换得显存峰值大幅下降。在 Llama-70B 训练上,开 checkpoint 能让显存峰值降 30-50%,代价是 backward 时间增加 30%(额外的重算)。
7.5.4 SavedVariable 的内存账
每一个 SavedVariable 持有 at::Tensor data_(一个引用计数指针,约 8 字节),但它强引用底层 storage。这意味着只要反向图存活,所有 SavedVariable 引用的中间值都不能被释放。
举个例子:一个 70 层 transformer,每层保存约 5 个中间张量(attention QKV 输出、MLP 中间、norm 输入等),每个张量几十 MB —— 反向图存活期间持有的总显存可能达到 30-50 GB。这是为什么大模型训练显存峰值出现在 forward 完成、backward 即将开始那一刻 —— 所有 SavedVariable 都已经在那里等着。
activation_checkpoint 的核心思想就是:在前向时把某些区段的 SavedVariable 释放掉,反向到那个区段时重新前向算一遍取回中间值。用计算换显存,对大模型几乎是必备技术。第 18 章 FSDP 章会展开。
7.6 AccumulateGrad:叶子张量的梯度”沉淀池”
leaf 张量的 grad_accumulator 是一个特殊 Node 类:AccumulateGrad(torch/csrc/autograd/functions/accumulate_grad.h)。它的 apply 实现就是把传入的梯度加到张量的 grad_ 字段:
// 简化的 AccumulateGrad::apply
variable_list AccumulateGrad::apply(variable_list&& grads) {
auto new_grad = grads[0];
auto* meta = impl::get_autograd_meta(variable);
if (!meta->grad_.defined()) {
meta->grad_ = new_grad; // 第一次设置
} else {
meta->grad_ += new_grad; // 累积
}
return {}; // 没有输出 (sink)
}
注意 AccumulateGrad 没有 next_edges_(反向遍历到这里就停了)也没有输出。它是反向图的终点。
7.6.1 为什么训练前要 optimizer.zero_grad()
理解了 AccumulateGrad,你就明白为什么 PyTorch 训练要在每个 batch 开始时调 zero_grad():
for batch in dataloader:
optimizer.zero_grad() # ← 关键
loss = model(batch).sum()
loss.backward()
optimizer.step()
不调 zero_grad 的话,AccumulateGrad 会把当前 batch 的梯度加到上一个 batch 的梯度上,而不是替换。这是 PyTorch 选择”梯度累积式”语义的设计 —— 让”梯度累积”(gradient accumulation,把多个小 batch 的梯度合并成大 batch)这种用法天然支持。
7.6.0.5 leaf 张量的 grad_accumulator 是 lazy 创建的
回到 §7.2:leaf 张量的 grad_accumulator_ 字段是 weak_ptr<Node>。它什么时候才被实际创建?
答案是第一次有算子需要它的时候。当 c = a + b 触发 VariableType::add,需要为 a 创建 next_edge。代码大致:
// 第一次需要时 lazy 创建
auto a_acc = impl::try_get_grad_accumulator(a);
if (!a_acc) {
a_acc = std::make_shared<AccumulateGrad>(a);
impl::set_grad_accumulator(a, a_acc);
}
edge_for_a = Edge(a_acc, 0);
第二次再用 a 做某算子时,try_get_grad_accumulator(a).lock() 拿回上次创建的 accumulator,复用。
这种 lazy 创建让”创建了 leaf 张量但从没用过”的场景不付 AccumulateGrad 的开销 —— 比如 inference 模式下定义了 weight 张量但根本不反向,AccumulateGrad 永远不创建。
7.6.1 AccumulateGrad 的”叶子梯度仅在 leaf 累积”语义
新手常困惑:为什么 c.grad 是 None?c 是 interior 张量,不应该也有梯度吗?
答案是:默认情况下 PyTorch 只在 leaf 张量上累积梯度。这是有意为之的优化 —— 训练里你只关心模型参数(leaf)的梯度,中间张量的梯度只是”传递的中转值”,存下来浪费显存。
如果你确实想看 c 的梯度(如调试),调 c.retain_grad():
c = a + b
c.retain_grad() # 告诉 PyTorch 反向时也保留 c.grad
loss = (c * 2).sum()
loss.backward()
print(c.grad) # 现在有值了
retain_grad 的实现是在 c 身上挂一个 hook,反向流到 c 时把梯度同时存到 c.grad 字段。代价是显存上升(多保留一个张量)。生产代码不该用,调试时很方便。
7.6.1.5 多次 backward 与 retain_graph
默认情况下 loss.backward() 跑完后反向图被释放(每个 SavedVariable 调 reset_data,shared_ptr 解引用)。如果你想再跑一次反向(如 second-order derivatives),要传 retain_graph=True:
loss.backward(retain_graph=True) # 保留反向图
# 现在还能再跑
torch.autograd.grad(loss, model.parameters()) # 用 functional 接口再算一次
retain_graph=True 的代价是反向图保留,所有 SavedVariable 仍然持有中间值 —— 显存峰值翻倍。所以只在必要时用,且尽快释放。
第 8 章 Engine 章会展开 GraphTask 怎么处理 retain_graph 标志。
7.6.0 optimizer.zero_grad(set_to_none=True) 与 set_to_none=False
老版本 PyTorch 的 zero_grad() 默认行为是 grad_.zero_() —— 一次 inplace 写零。但这要遍历每个参数张量、调一次 zero kernel,开销不小。新版本(1.7+)默认 set_to_none=True:直接把 grad_ 设为 None,省掉一次 zero kernel。
set_to_none=True 改变了一个细微语义:之前的”梯度永远是 0 张量”变成”第一次累积前梯度是 None”。这影响一些手工读 param.grad 的代码 —— 它们要先判断 param.grad is None。绝大多数 PyTorch 代码不受影响,但写 grad clip 或自定义 optimizer 时要留意。
性能差距:在 70B 模型训练里,set_to_none=True 比 set_to_none=False 每 step 节省约 5-10ms。看似小,长跑训练加起来是几十小时的差距。
但代价就是要手动 zero_grad。新手忘了 zero_grad 会发现损失曲线非常奇怪 —— 这是 PyTorch 训练里最经典的 bug 之一。
7.7 requires_grad / no_grad / inference_mode 的精确语义
第 5 章 §5.6.3 我们对比过这三种 mode,autograd 视角下它们的差异更清楚:
| 特性 | 普通张量 | no_grad 上下文 | inference_mode 上下文 |
|---|---|---|---|
autograd_meta_ | nullptr 或 AutogradMeta | 同左 | 永远 nullptr |
requires_grad | 用户控制 | 临时 false | 张量结构上不支持 |
| dispatcher 路径 | 走 AutogradXxx kernel | 走 fallthrough | 完全不带 Autograd key |
| 创建反向图 | 是 | 否 | 否 |
| dispatcher 开销 | 高 (含 redispatch) | 中 (走 fallthrough) | 低 (无 autograd key) |
inference_mode 比 no_grad 快 5-10% 在小算子上。代价是 inference 张量”不可逆”地不能再参与反向 —— 即使 with 块外赋 requires_grad 也不行。
实战建议:生产推理服务用 inference_mode(性能优势真实),调试 / 评估用 no_grad(更宽容、可以临时再开 grad)。这种”两档 mode”是 PyTorch 给不同场景留的灵活性。
7.7.0.5 三种 mode 的实际选择建议
实战里遇到三种 mode 的选择困难,按场景给一份决策表:
| 场景 | 推荐 |
|---|---|
| 训练循环里只是 forward 看 loss 值 | no_grad |
| 训练循环里临时跑一段不参与反向的子图 | no_grad |
| 调试 / 单元测试 | no_grad |
| 生产 inference server 的 hot loop | inference_mode |
| Dataloader / 预处理 pipeline | inference_mode(数据张量本来就不反向) |
| 只读访问参数(如导出权重、统计参数大小) | inference_mode |
| 在 inference 里偶尔需要反向(如对抗样本生成) | no_grad 配合临时 enable_grad |
特别提醒:Dataloader 里如果你在 worker 进程里创建 tensor,用 inference_mode 创建可以避免后续 H2D 时附带的 autograd 元数据开销。这是大多数 dataloader 没注意但能挤出几个百分点 throughput 的优化。
7.7.1.5 一个不常被提到的细节:grad_mode 与 dispatcher 的协作
with torch.no_grad(): 的实现是修改 c10::AutogradState 的 thread-local flag,dispatcher 在 getDispatchKeySetUnboxed 时把 Autograd* 加入 excluded 集合(第 5 章 §5.3.1)。它不直接改张量元数据 —— 用户在 with 块内访问张量,仍能看到 requires_grad=True。这种”不改张量、改 dispatcher”的设计让 grad_mode 成为线程级开关,不影响张量的全局状态。
inference_mode 的实现完全不同:它改张量构造时的 key_set_,让构造出来的张量不带 Autograd key。这个区别对调试很重要 —— no_grad 退出后原本的 autograd 行为恢复;inference_mode 创建的张量永久”无法反向”。理解这条区别能避免一些反直觉的 bug。
7.7.1 requires_grad 在 dispatcher 上的体现
每次构造张量时 PyTorch 检查输入张量的 requires_grad:如果任一输入 requires_grad,输出张量的 requires_grad 自动 true,且 key_set_ 加入 AutogradXxx 位。这就是 requires_grad 传染性的实现 —— 沿计算图自动传播,用户不需要每次手动设。
7.8 autograd.Function:用户级自定义 autograd
如果你想给 PyTorch 加一个新的可微操作 —— 比如自定义 attention,PyTorch 提供 torch.autograd.Function(Python 接口)和 torch::autograd::Function(C++ 接口):
class MyExp(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return x.exp()
@staticmethod
def backward(ctx, grad_output):
(x,) = ctx.saved_tensors
return grad_output * x.exp()
# 使用
x = torch.randn(3, requires_grad=True)
y = MyExp.apply(x)
y.sum().backward()
这套接口的底层实现就是把用户写的 forward 和 backward 包装成一个 Node 类,注入到 PyTorch 的反向图体系。ctx.save_for_backward 内部就是创建 SavedVariable;return grad_output * x.exp() 就是 Node::apply 的返回值。
任何用 autograd.Function 写的算子和 PyTorch 内置算子在反向图层面是平等公民。这套接口让用户能在不改 PyTorch 主仓的前提下扩展 autograd。第 22 章自定义算子会再讲。
7.8.1 autograd.Function 与 torch.library.custom_op 的区别
PyTorch 现在有两套自定义算子接口:
| 接口 | 适用场景 |
|---|---|
torch.autograd.Function | 纯 Python 实现,需要自定义反向规则 |
torch.library.custom_op (v2.4+) | 与 dispatcher / torch.compile 完全对接的”主流”接口 |
autograd.Function 是历史悠久的接口,简单直接,但它对 torch.compile 的支持较弱 —— Inductor 看到 MyExp.apply(x) 时往往会 graph break,回到 eager 路径。
custom_op 是新接口,要求用户多注册几样东西(FakeTensor 实现、autograd 实现等),但保证 torch.compile 能正确编译。所以新代码推荐 custom_op,旧代码继续用 autograd.Function 也无妨。
第 22 章自定义算子会完整对比这两套接口的优劣与适用场景。
7.8.2 ctx.save_for_backward 的版本检查
ctx.save_for_backward(x) 在 C++ 层就是创建 SavedVariable(x, ...)。所以 §7.5.1 讨论的版本检查在用户级 autograd.Function 里同样生效 —— 如果你保存了 x,然后在 forward 完成后又对 x 做 inplace 修改,反向时同样会触发”one of the variables needed for gradient computation has been modified”。
很多新手在写 autograd.Function 时遇到这个错,原因是 forward 里做了类似 x.add_(1) 的事 —— 而 PyTorch 默认期望 forward 是纯函数式的。把 inplace 改成 out-of-place 通常就能解决。
7.9 几条容易被忽略的细节
实战里几个易混淆点:
1. tensor.detach() 与 tensor.data 的区别:detach() 返回一个新张量,与原张量共享 storage 但 requires_grad=False 且 grad_fn=None,version_counter 是新的(不会污染原张量的反向)。tensor.data 是危险接口,直接绕过版本检查,错用容易引发反向错乱
2. torch.set_grad_enabled(False):等价于 no_grad() 但函数式风格,可以在条件分支里更灵活地控制
3. tensor.requires_grad_(True) 的传染性:对 leaf 张量调可以,对 interior 张量(有 grad_fn)调会报错,因为这破坏反向图一致性
4. with torch.no_grad(): 内部创建的张量:这些张量 grad_fn 全是 None。即便 with 块外面再设 requires_grad=True 也无济于事 —— 反向图被永久砍掉了
5. tensor.retain_grad():默认情况下 interior 张量没有 grad_ 字段(反向时梯度只流向 leaf)。如果你想看某个中间张量的梯度,要在前向时调 retain_grad()。它的实现是给那个 interior 张量挂一个特殊 hook,反向流到这一步时把梯度同时存到 grad_ 字段里。代价是显存上升
6. autograd 与多 device 张量:如果一个 Node 的输出是 CPU 张量但 next_edges 指向的上游 Node 输出是 CUDA 张量,反向时 PyTorch 会自动 H2D 拷贝。这很方便,但容易隐藏性能问题 —— 用 profiler 找到这种”隐式 H2D”是优化训练吞吐的常见动作
7. with torch.enable_grad(): 嵌套在 no_grad() 里:这是合法的,让你在 inference 大块代码里临时开启 autograd。第 9 章 nn.Module 章会演示一些场景(如 BatchNorm 的 running stats)
7.9.5 backward 时的钩子 (hook) 系统
autograd 还支持在反向流程的特定时刻插入用户回调:
| Hook 类型 | 注册方式 | 触发时机 |
|---|---|---|
| Tensor backward hook | tensor.register_hook(fn) | 反向流到这个张量时(grad 已经算出但还没传给 next_edges) |
| Module hook | module.register_full_backward_hook(fn) | 反向经过整个 nn.Module 时 |
| Pre-grad hook | param.register_post_accumulate_grad_hook(fn) | leaf 张量的 grad 累积完成后 |
最常见的用法是梯度裁剪 (gradient clipping):
# 经典用法:裁剪超过阈值的梯度
for p in model.parameters():
if p.requires_grad:
p.register_hook(lambda grad: torch.clamp(grad, -1.0, 1.0))
但生产代码通常不这样用 —— 因为 torch.nn.utils.clip_grad_norm_ 提供了更标准的接口。Hook 主要在调试时有用:在某个层的反向时打印梯度统计、检查 NaN 是否在某层出现等。
这套 hook 机制的实现是给 Node 挂一个 vector<unique_ptr<FunctionPostHook>>,反向流到那个 Node 时按顺序调用所有 hook。源码在 torch/csrc/autograd/function.h 与 function_hook.h。
7.10 跨书关联
- 《Tokio 异步运行时》第 X 章 work-stealing 调度器:autograd Engine 的多线程模型与 Tokio 调度器的 work-stealing 思想极其相似。第 8 章会做详细对照
- 《Rust 编译器之路》第 X 章 编译期求导:Rust 生态有
enzyme等 LLVM 编译期 AD 库。它们与 PyTorch 的运行期 AD 形成对照 —— 各有取舍 - 《MCP 协议剖析》第 X 章 上下文传递:MCP 在分布式调用里传递 context 与 PyTorch 在 dispatcher 里通过
c10::DispatchKeySet+ thread-local 传递 autograd 状态有相通思想
7.11 一个练习:手画一段代码的反向图
import torch
a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
c = a * b
d = c.sum()
e = a.exp()
f = e.sum()
loss = d + f
试着画出反向图:每个张量是哪种类型(leaf / interior)、每个 grad_fn 是什么 Node 类、每条 next_edge 指向哪里、a 的梯度怎么从 loss 流回来。
提示:a 出现在两条路径上(a*b 和 a.exp()),所以 AccumulateGrad_a 会有两条 in-edge。loss.backward() 时这两路梯度会自动求和到 a.grad。
画完之后用 torch.autograd.grad 或者 loss.backward() 验证:
loss.backward()
print(a.grad)
# 应该是 b + a.exp() (因为 d=a*b, f=a.exp().sum, loss=d+f → dL/da = b + e^a)
7.12 一些 autograd 关键细节的总结
把 autograd 的精髓压成几条:
- autograd 是 dispatcher 的中间层:通过
Autograd*DispatchKey 注入到正常算子调用链,redispatch 后让真正的 backend kernel 接手 - 反向图在前向时建好:每个可微算子调用都创建一个
Node,next_edges 指向上游 Node,让整张图自然链起来 - Node 是 base class,每个算子有自己的
XxxBackward子类:apply方法实现具体反向数学 - leaf 张量有 grad_accumulator (AccumulateGrad),interior 张量有 grad_fn,两者互斥
- SavedVariable 用 version_counter 防 inplace 污染,用 weak_ptr 防循环引用
requires_grad沿计算图自传染,no_grad/inference_mode提供两档”关闭 autograd”的语义
如果你能在脑子里画清”a + b 触发的 dispatcher → VariableType::add → 创建 AddBackward0 → next_edges 指向 a/b 的 grad_fn → redispatch 调真 add → set_history → 返回 c”这一连串过程,本章就内化了。
7.12.5 Compiled Autograd:把反向图也编译
PyTorch 2.4+ 引入了 Compiled Autograd(在 torch/_dynamo/compiled_autograd.py),让反向 Engine 也能被 Inductor 编译。简单原理:
- 用户照常 forward +
loss.backward() - PyTorch 把”按 next_edges 反向遍历 + 调每个 Node.apply”这套调度逻辑自身捕捉成 FX graph
- 把这个 graph 喂给 Inductor 编译成 Triton kernel
- 后续 backward 直接调编译后的 binary,跳过 Engine 的多线程 work-stealing 开销
效果:在小算子密集的反向(典型如 LSTM 反向、MoE 模型反向)上能让 backward 提速 30-50%。代价是首次编译几秒,后续命中缓存。
启用方式(v2.4+):
import torch._dynamo
torch._dynamo.config.compiled_autograd = True
它仍是实验性功能,第 13 章 AOTAutograd 章会展开。理解 compiled autograd 的工作原理需要先吃透本章(autograd 怎么建反向图)+ 第 8 章(engine 怎么执行)+ 第 12 章(Dynamo 怎么 trace)—— 它是这几个系统协同的产物。
7.13 几条工程经验
实战里 autograd 相关 issue 大致分三类,附常见诊断方法:
类型 A:反向报错 “one of the variables … has been modified by an inplace operation”
诊断:用 torch.autograd.set_detect_anomaly(True) 让 PyTorch 在前向时记录每个 Node 的创建栈,反向报错时打印精确位置。修法:把可疑的 inplace 操作(*=、+=、add_、relu_)改成 out-of-place。
类型 B:训练 loss 不下降 / NaN 不收敛
诊断:可能是忘了 zero_grad 让梯度累积。或者某个 Node 的反向规则有 bug 导致梯度方向错。用 torch.autograd.gradcheck 可以数值验证某个算子的反向是否正确。
类型 C:显存峰值过高
诊断:通常是 SavedVariable 持有太多中间值。用 torch.cuda.memory._record_memory_history() snapshot 看(第 4 章 §4.11)。修法:activation_checkpoint、混合精度(第 20 章)、CPU offload(saved_tensors_hooks)。
7.14 跨书关联补充
- 《vLLM 内核探秘》第 4 章 PagedAttention:vLLM 的推理路径完全不构建反向图(pure inference)。理解 PyTorch autograd 的”非侵入式”设计后再看 vLLM 跳过 autograd 的方式,就能理解 PyTorch 在训练 / 推理两侧的统一性
- 《Serde 元编程》派生宏:
tools/autograd/gen_variable_type.py自动从derivatives.yaml生成VariableType::add等包装函数,与 Serde 派生宏(自动生成 serialize/deserialize 代码)思想完全一致
7.14.1 几条 autograd 设计的”通用启示”
如果你设计自己的 AD 系统(深度学习外的领域:金融衍生品定价、物理模拟),本章思想能直接迁移:
第一:反向图作为 DAG,每个节点写自己的 apply —— 不要试图把整张图集中在一个数据结构里管理。每个 Node 自治、用 Edge 连接,是最 modular 的设计
第二:分离声明(forward)与执行(backward) —— forward 时只建 IR,不真做反向计算。这让 forward 的代码路径专注于生成图,backward 路径专注于调度执行
第三:用 weak_ptr 打破循环引用 —— 任何”图节点 + 张量元数据互引”的设计都要小心循环引用,weak_ptr 是经典武器
第四:version_counter 防 inplace 污染 —— 任何”延迟计算”系统都要有”我保存的值是不是还有效”的校验机制,version 号是最简方案
第五:lazy materialize_autograd_meta —— 元数据按需分配。99% 的对象不需要那块元数据,给所有对象都分配是巨大浪费
把这五条记下来,写自己的 AD 系统能少走很多弯路。
第六:version_counter / weak_ptr / lazy 创建是性能预算的”三件套” —— 任何”延迟构造 + 多对象共享 + 安全检查”的系统都会用到这三件套,PyTorch 是它们配合极致的范例
7.15 一个反思:为什么 PyTorch 选了”运行时建图”
最后回到一个高维度问题:为什么 PyTorch 选择前向时偷偷建反向图这条路,而不是 JAX 那种”先 trace 再编译”或者 TF 1.x 那种”显式定义图”?
PyTorch 设计者的回答(参考 NeurIPS 2019 论文 PyTorch: An Imperative Style, High-Performance Deep Learning Library):
- define-by-run 让控制流自由:用 Python 的 if/for 写动态网络,反向图自然跟着控制流走,不需要
tf.cond/lax.scan这类图原语 - 构图开销可承受:每个算子建一次 Node + 几条 Edge 的开销在毫秒级,相对于数学计算可忽略
- 代码即调试器:反向图等同于”前向跑过的代码路径”,调试时打开 stack trace 立刻看到结构
代价是反向图每次都重建(不能像静态图那样跨 step 复用编译)。torch.compile 在 v2.0 之后用”AOTAutograd”机制把反向图也编译进 Inductor,弥补了这个性能缺口(第 13 章会展开)。这种”动态图 + 可选编译”是 PyTorch 哲学的核心,也是它赢得学术界 → 工业界飞轮的根本。
值得对比的是:JAX 的 grad 是函数变换 g = jax.grad(f),把整个 f 编译成一个新的反向函数。它要求 f 是纯函数(不能有 inplace、不能有副作用),换来的是反向函数的极致编译优化。PyTorch 的 loss.backward() 是引擎驱动的图遍历,允许 inplace 与各种动态行为,代价是反向时不能像 JAX 那样跨调用复用编译。
两条路在哲学上是镜像的:JAX 把”什么是反向”拆成数学问题让编译器解;PyTorch 把”反向怎么调度”拆成数据结构让引擎跑。前者依赖严格函数式假设,后者依赖运行时元信息。理解这两种 AD 设计的本质差异,你能在选型时做出有依据的判断 —— 学术研究 / 动态网络偏 PyTorch,函数式建模 / 静态优化偏 JAX。
下一章拆 Engine:这张反向图怎么被多线程引擎执行 —— 看 PyTorch 的 work-stealing 调度器。
评论 0
还没有评论,来说两句吧。
评论加载失败,刷新重试。