第13章 AOTAutograd:函数化与正反向图划分
“AOTAutograd is the bridge between Dynamo (capture) and Inductor (codegen). It takes a forward graph, traces backward through it, makes everything functional, and partitions into forward/backward subgraphs.”
—— PyTorch dev docs
本章要点
- AOTAutograd 在 Dynamo 之后、Inductor 之前:把 forward-only graph 扩展成 (forward + backward) joint graph
- Functionalization 把 inplace 操作改写成纯函数式:
x.add_(y)→x = x.add(y),让下游 IR 没有 mutation 副作用 - min-cut partition 用图论里的最小割算法决定”哪些中间张量保留给反向用”:在显存与重算之间找最优平衡
- 三个产物:forward graph + backward graph + saved tensor list
- AOTAutograd 也会被 Inductor 用作 backend:直接
aot_eagerbackend 不上 Inductor,等同于”trace 反向但不编译”,用于调试 - 它是 PyTorch 编译器栈最复杂的一段:~6000 行 Python,处理无数 corner case(subclass、view、inplace、AMP、checkpoint)
13.1 Dynamo 留下的烂摊子
第 12 章我们看到 Dynamo 输出 forward FX Graph + Guards + example_inputs。但 Inductor 想要的更多:
- Joint forward-backward graph:训练时 backward 也是热路径,必须一起编译
- 纯函数式 IR:Inductor 内部图变换不希望处理 inplace 副作用
- 明确的 saved tensor 列表:哪些中间值要保留给 backward
这三件事都不在 Dynamo 的职责内。AOTAutograd(Ahead Of Time Autograd)补上这一段。源码主要在 torch/_functorch/aot_autograd.py(v2.11 实测 1878 行)和 torch/_functorch/_aot_autograd/ 目录(实测 ~16500 行)。整个 torch/_functorch/ namespace 共 ~32000 行,是 PyTorch v2.x 编译栈的核心组件。
13.2 整体流程
flowchart TB
Dy[Dynamo 输出<br/>forward FX Graph]
Dy --> Trace[1 反向 trace<br/>用 functorch.grad / vjp 在 fake tensor 上跑]
Trace --> Joint[joint graph<br/>forward + backward 节点混在一起]
Joint --> Func[2 functionalization<br/>inplace → 纯函数]
Func --> Part[3 min-cut partition<br/>切成 fw_module + bw_module]
Part --> FW[forward 子图]
Part --> BW[backward 子图]
Part --> Save[saved tensors 列表]
FW --> Ind[Inductor 编译 forward]
BW --> Ind2[Inductor 编译 backward]
style Joint fill:#fef3c7,stroke:#f59e0b
style Part fill:#dbeafe,stroke:#3b82f6,stroke-width:2px
四步:trace → functionalize → partition → 各自送 Inductor。
13.3 反向 trace:用 fake tensor 跑出 backward
Dynamo 给的是 forward graph(一组 ATen 算子调用)。AOTAutograd 怎么得到 backward?答案是 再 trace 一遍:
def joint_fn(primals, tangents):
out = forward_graph(*primals)
grads = torch.autograd.grad(out, primals, tangents)
return out, grads
primals 是输入张量,tangents 是输出梯度。autograd.grad 触发反向 —— 但所有张量都是 FakeTensor(第 5 章 §5.7 提过 FakeTensor 是只算 shape/dtype 的”假张量”)。整个反向不真做计算,只生成节点。
trace 完拿到一张 joint graph:包含 forward 节点 + backward 节点 + 所有中间值。这张图就是 后续两步的输入。
13.4 Functionalization:消除 inplace
joint graph 里可能有 x.add_(y)、x.copy_(y) 等 inplace 操作。这种操作让图分析变难(节点 X 的输出可能被节点 Y 偷偷修改)。Functionalization 把它们重写成纯函数式:
inplace 版: x.add_(y) # x 被原地修改
functional: x_new = x.add(y) # x 不变, 返回新张量
实现机制:用 FunctionalTensorMode(第 5 章 §5.7 讲的 TorchDispatchMode 之一)拦截每个 inplace 算子,把它替换成对应的 out-of-place 版本 + 一个 copy_back(如果用户原本期待原地修改)。这套改写对用户透明,但下游 IR 里所有节点都变成纯函数。
为什么这件事重要?因为 Inductor 的算子融合 / 重排算法假设节点没有副作用。有 inplace 时,“a + b 节点”的输入可能在它执行前被另一个节点 inplace 改了 —— 重排顺序就崩。Functionalization 一刀斩断 mutation,让下游图变换可以放心做。
代价是显存上升(每个 inplace 变成 out-of-place 多一份张量)。但 Inductor 后面有 buffer reuse pass 能把这些”逻辑上独立但实际可复用”的张量重新合并 inplace —— 等同于先抹掉 inplace 让分析简单、再恢复 inplace 让运行高效。这种”先简化再优化”是编译器经典做法。
13.5 min-cut partition:在显存与重算之间找最优
joint graph 里 forward 与 backward 节点混在一起。要切成两段子图,关键问题是:哪些 forward 中间张量要保留给 backward 用?
直观选项有两个极端:
- 全保留:所有 forward 中间值都保留 → 显存爆炸
- 全重算:backward 时从 input 重新跑一遍 forward → 计算翻倍
实际最优在两者之间:保留”算起来贵的”中间值,丢弃”重算便宜的”。这就是经典的最小割问题。
torch/_functorch/partitioners.py:min_cut_rematerialization_partition(v2.11 中位于该文件 line 3449 的核心函数)建模为:
- 每个 forward 节点是图中的点,边连着它的依赖
- 给”保留某节点的输出”赋值
节点输出张量大小(保留代价 = 显存) - 给”重算某节点”赋值
节点计算 cost(重算代价 = 计算量) - 用网络流算法(max-flow / min-cut)找出”切割集”,让总代价最小
切割集就是要保留的张量集合。算法保证 forward 之后只把这些张量传给 backward,其他全部丢弃;backward 时从这些保留的张量出发,必要时重算。
这套算法的优雅之处:它自动决定 activation_checkpoint 应该 checkpoint 什么。用户不需要手动标”这层重算”,min-cut 自己算出来。
实测在 Llama 类 transformer 上,min-cut 自动选择保留 attention 输出 + LayerNorm 输出,丢弃 QKV 投影中间值(QKV 是 GEMM,重算便宜;attention softmax 算起来贵)。这与人手设计的 activation_checkpoint 策略高度一致。
13.6 输出三件套
partition 完拿到:
fw_module:forward 子图 —— 接收 primals,返回 outputs + saved tensorsbw_module:backward 子图 —— 接收 saved tensors + tangents,返回 gradientssaved_tensor_indices:保留张量索引 —— 告诉 runtime 哪些 fw 输出要喂回 bw
Runtime 流程(forward 调用时):
fw_outputs = fw_module(*primals)
saved = [fw_outputs[i] for i in saved_tensor_indices]
real_outputs = [fw_outputs[i] for i in real_output_indices]
# 用户拿到 real_outputs
# saved 暂存等 backward 用
# backward 调用时
gradients = bw_module(*saved, *tangents)
这套 runtime 由 _aot_autograd/runtime_wrappers.py(3034 行)实现。它包了一层 autograd.Function,让编译产物对接到 PyTorch 的 backward Engine(第 8 章)—— 从用户角度看跟普通 module 一样,loss.backward() 自动触发 bw_module。
13.6.5 runtime_wrappers 的层叠
_aot_autograd/runtime_wrappers.py(3034 行)有十几个 Wrapper 类,每个负责一类问题。它们像洋葱一样层层包裹编译产物:
| Wrapper | 行号 | 职责 |
|---|---|---|
RuntimeWrapper | 165 | 最外层,分发到 forward / backward |
FunctionalizedRngRuntimeWrapper | 731 | 处理 RNG state(让随机算子也能 functionalize) |
FakifiedOutWrapper | 813 | output tensor 与 fake tensor 的对齐 |
AOTDispatchSubclassWrapper | 906 | tensor subclass(DTensor / FunctionalTensor)的 unwrap / rewrap |
EffectTokensWrapper | 993 | side-effect 算子(如 print / NCCL)的 token 串联 |
AOTDedupeWrapper | 1106 | 去重:同一张量作为多个 input 时合并 |
AOTSyntheticBaseWrapper | 1356 | view 合成 base:多个 view 共享 storage 时统一处理 |
AOTDispatchAutograd | 2303 | 把 fw + bw 包成 autograd.Function 接入 Engine |
DebugAssertWrapper | 2940 | debug 模式下的运行时断言 |
每一层解决一个具体 corner case。运行时调用 compiled_fn(*args) 时,从外到内穿过所有 wrapper:
flowchart LR
User[用户调用 compiled_fn]
User --> R1[RuntimeWrapper]
R1 --> R2[AOTDispatchAutograd<br/>包成 autograd.Function]
R2 --> R3[AOTDedupeWrapper<br/>去重]
R3 --> R4[AOTSyntheticBaseWrapper<br/>view 合并]
R4 --> R5[AOTDispatchSubclassWrapper<br/>unwrap subclass]
R5 --> R6[EffectTokensWrapper<br/>串 side effect]
R6 --> Real[真正的 fw_module / bw_module]
style R2 fill:#fef3c7,stroke:#f59e0b,stroke-width:2px
这种”每个 wrapper 一类问题”的设计是 AOTAutograd 6000+ 行代码却仍可读的关键。读源码时按层定位 —— 看到具体错误信息搜对应 Wrapper 类名能精准找到出问题的层。
13.6.6 AOTConfig:编译策略的中央配置
schemas.py:1066 的 AOTConfig 是 AOTAutograd 的配置 hub:
@dataclass
class AOTConfig:
fw_compiler: Callable # forward backend (通常是 Inductor)
bw_compiler: Callable # backward backend
partition_fn: Callable # 默认 min_cut_rematerialization_partition
decompositions: dict # ATen op 拆解规则 (大 op 拆成小 op)
num_params_buffers: int
aot_id: int
keep_inference_input_mutations: bool
dynamic_shapes: bool
# ... 几十个字段
几个特别值得注意的字段:
decompositions:把”难编译的大 op”拆成”易编译的小 op 组合”。比如aten::native_layer_norm默认拆成 mean + var + sub + div + mul + add 等。每个 backend 有自己的 decomposition table(Inductor 的、aot_eager 的不同)。decomp 让 AOTAutograd 给 backend 一致的”低级 IR”keep_inference_input_mutations:inference 路径下 input mutation 不需要 functionalize(反正不反向)。这个 flag 让 inference 编译跳过 functionalize 开销dynamic_shapes:shape 是符号还是具体值。影响整个 trace + partition + codegen
AOTConfig 像 OpenSSH 的 ssh_config:所有 knob 集中在一个对象,调编译策略时改这里。
13.6.7 min-cut 的 banned / recomputable 列表
partitioners.py:must_recompute 与相关函数控制哪些节点必须 recompute、哪些必须 save。具体规则(节选):
Banned (必须 save,不能 recompute):
aten.rand/aten.randn等随机算子(recompute 会得到不同值)aten.scatter_/aten.index_put_等 inplace 操作- 标记为
CheckpointPolicy.MUST_SAVE的(用户 activation_checkpoint 显式标记的) - 涉及 mutable state 的(如 BN 的 running_mean)
Force recompute (必须 recompute,不能 save):
aten.view/aten.expand/aten.permute等 view ops(save view 等于 save base,浪费)- 用户标记
CheckpointPolicy.MUST_RECOMPUTE的(手动标记某区段必 recompute) - 标记为
prims.convert_element_type之类的 trivial cast(重算几乎免费)
Negotiable (min-cut 算法决定):
- 大部分 GEMM、attention、norm 等”计算贵但输出大”的算子。min-cut 综合输入大小、输出大小、FLOPs 估算总代价决定 save 还是 recompute
这套 banned / recomputable / negotiable 三态分类让 min-cut 算法只在合理空间搜索,不会把 view 当 negotiable 浪费时间、也不会试图 recompute rand 破坏正确性。
源码里 BANNED_OPS_FOR_CHECKPOINT 集合 + must_recompute / must_be_saved 的位标记是这套三态机制的实现入口。
13.6.8 decomposition table:把大算子拆成小算子
AOTConfig.decompositions 是把 ATen “大算子”自动拆成若干”小算子组合”的规则表。源码主仓在 torch/_decomp/decompositions.py(5490 行),用 @register_decomposition(aten.xxx) 装饰器批量注册。
为什么要拆?因为 Inductor 想要看到的 IR 比 ATen 更”基础”。比如 aten.native_layer_norm 是一个复合算子(含 mean / var / sub / div / mul / add),如果直接送给 Inductor,它得给每个复合算子写一份 lowering 规则,工作量爆炸。decomposition 把 layer_norm 拆成几个简单 op 后,Inductor 只需要 lower 这几个简单 op,整个 layer_norm 自动用上 fusion / autotune 机会。
具体例子(精简版):
@register_decomposition(aten.native_layer_norm)
def native_layer_norm(x, normalized_shape, weight, bias, eps):
mean = torch.mean(x, dim=..., keepdim=True)
var = torch.var(x, dim=..., keepdim=True, unbiased=False)
rstd = torch.rsqrt(var + eps)
out = (x - mean) * rstd
if weight is not None:
out = out * weight
if bias is not None:
out = out + bias
return out, mean, rstd
——一行 aten.native_layer_norm 调用被拆成 8-10 个基础算子。Inductor 的 fusion 算法把这 10 个算子合并成 1 个 Triton kernel,最终的 GPU 行为与 PyTorch 内置的 fused layer_norm kernel 几乎等价、有时更快(因为 Inductor 能与周边算子继续 fuse)。
torch/_decomp/__init__.py:291 的 core_aten_decompositions() 返回 PyTorch 维护的”标准 decomp table” —— 大约 200+ 算子的拆解规则。AOTAutograd 默认用这套表。第三方后端可以只用部分 decomp 或加自家规则:
from torch._decomp import get_decompositions
# 我家硬件原生支持 layer_norm, 不要拆
my_decomps = get_decompositions([aten.relu, aten.silu, ...]) # 不含 layer_norm
decomp table 这套机制是 PyTorch 高层 API 与编译器后端解耦的关键:用户写 torch.layer_norm,AOTAutograd 拆成 primitives,每个后端选自家友好的拆法。这与第 5 章 §5.2.2 的 CompositeImplicitAutograd 思想一致 —— 但前者是编译期拆、后者是运行期拆。
理解了 decomp table,你就理解了为什么 Inductor 能在不写”layer_norm 专用 codegen”的前提下生成高质量的 layer_norm kernel —— 拆 + fusion 自动覆盖。这条思想是编译器栈”少写专用代码、多用通用算法”的工程范本。
13.6.9 三种 dispatch 入口:autograd / export / inference
AOTAutograd 不是单一入口,它有三种主要 trace 模式:
aot_dispatch_autograd:训练路径。trace forward + backward joint graph,partition 后给 Inductor。最常用aot_dispatch_inference:推理路径。只 trace forward(不需要反向),跳过 partition。第 15 章 §15.6.7 的 AOTI 走这条aot_dispatch_export:torch.export路径。trace 出严格静态 graph,不允许 graph break、不依赖 dynamic shape
三者共享 95% 代码,差异在最后阶段:
| 阶段 | autograd | inference | export |
|---|---|---|---|
| trace joint | ✓ | ✗ (只 forward) | ✗ (只 forward) |
| functionalize | ✓ | ✓ | ✓ (更严格) |
| min-cut partition | ✓ | ✗ (无 backward 可分) | ✗ |
| 接 autograd.Function | ✓ | ✗ | ✗ |
| dynamic shape | 支持 | 支持 | 严格静态 |
理解三种入口让你看 PyTorch 源码时不困惑:相同 trace 流程会被三个 dispatch 函数复用,每个有自己的 entry。第 12 章 §12.7 提过 OutputGraph.compile_subgraph 的 backend 调用就在 autograd / inference 间选择。
13.6.10 FunctionalTensor 与 FunctionalTensorMode
§13.4 提了 functionalization 用 FunctionalTensorMode 拦截 inplace。具体实现:
torch/_subclasses/functional_tensor.py:59 的 FunctionalTensor 是 torch.Tensor 子类,每个实例包一个底层 tensor + 一个”version 编号”。inplace op(如 x.add_(y))被 FunctionalTensorMode(:335)拦截:
# 简化版拦截逻辑
class FunctionalTensorMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args, kwargs):
if is_inplace(func):
# 把 inplace 改成 out-of-place
new_value = ops.add(x, y) # 不 inplace
x._update(new_value) # 更新 wrapper 持有的底层
return x
...
x._update(new_value) 让 FunctionalTensor 看起来像被 inplace 修改(x.data_ptr() 仍能取,但底层张量已经换了)。外部代码看到的语义与原始 inplace 一致,但内部已经是纯函数式。
这套机制的工程精妙之处:用户原本写 x.add_(y) 期望”x 被原地改”。functionalize 后底层 ATen 看到的是 x = x.add(y),但 x 这个 wrapper 仍指向更新后的值 —— 用户层面没感知。
13.6.11 copy_back 机制:functionalize 后处理 mutation
functionalize 把 inplace 全改 out-of-place,但有些场景用户确实需要 mutation 生效(如训练时 BN 的 running_mean.copy_(new_mean)、optimizer 的 param.data.copy_(...))。
AOTAutograd 的解法:在 graph 末尾插入 copy_ 节点把 functionalize 后的新值写回原位置:
原始: param.data.add_(grad, alpha=-lr)
functionalize 后:
new_param = param.add(grad, alpha=-lr) # out-of-place
... (更多计算)
param.copy_(new_param) # 末尾的 copy_back
copy_back 让”用户原本期望的 inplace”在 graph 末尾生效。Inductor 看到这种 copy_ 时不会再 functionalize(已经在 graph 末尾、没下游使用)—— 直接生成内存拷贝指令。
AOTConfig.keep_inference_input_mutations(§13.6.6)控制是否保留这些 copy_back。inference 路径下用户不能依赖 mutation 生效,所以这个 flag 默认 False、跳过 copy_back 节省一次 copy。训练路径必须保留。
13.6.12 min-cut 算法的网络流实现
§13.5 讲了 min-cut 的”图论建模”。具体怎么解?partitioners.py 用 NetworkX 库的 max_flow 算法。
建图:
- 源 source 节点:joint graph 的 forward 输入
- 汇 sink 节点:backward 输入(即”saved tensors 区域”)
- 中间节点:每个 forward op 的输出
- 边 weight:=保留这个 tensor 的 byte 数(rematerialize 时不存就要重算)
最小割集:所有”被切的边”对应的 tensor 必须保留给 backward。NetworkX 的 minimum_cut 用 Edmonds-Karp 算法(多项式时间复杂度)。
实战上 min-cut 算几百节点的图(典型 transformer block)只要几十毫秒。整个 70B model 的 partition 总耗时几秒,与 trace + functionalize 时间相比可忽略。
如果你想看 min-cut 的具体决策,开 TORCH_LOGS=joint_graph,partitioner 会打印每个 tensor 的”保留 vs 重算”决策与 cost 估算。这是调试 activation_checkpoint 自动决策的金钥匙。
13.6.13 AOTAutograd × dynamic shape
第 12 章 §12.5 讲了 Dynamo 的 guards 处理 dynamic shape。AOTAutograd 这一层支持是用 SymInt 在 fake tensor trace 中传递:
# Dynamo 看到 batch dim 是 SymInt s0
input.shape = (s0, 768)
# AOTAutograd trace 时 fake input 也是 (s0, 768)
fake_input = FakeTensor(shape=(s0, 768), ...)
# trace 出来的 fx graph 里所有 shape 表达式都是 SymInt
output = forward_graph(fake_input)
# output.shape = (s0, 768) # 仍是符号
# min-cut partition 在符号 shape 上做 cost 估算
cost = SymInt_to_int_estimate(output.shape) * dtype_size
代价:cost 估算是”近似”(SymInt 不是具体值),partition 决策可能不是真实最优。但实战上这种近似足够好 —— LLM 训练的 dynamic 范围通常是 batch / seq_len,每个具体值的最优 partition 差异不大。
AOTConfig.dynamic_shapes(§13.6.6)开启后整条链路用 SymInt;关闭后用具体值(每个具体 shape 重新 trace + partition)。生产推荐开启 dynamic_shapes —— shape 范围有限时缓存几个 graph 配置就够。
13.6.14 forward 输出的打包:saved tensors 与 user outputs
trace 完 forward 后,输出包含两类张量:用户实际要的 outputs + backward 需要的 saved tensors。AOTAutograd 把它们一起返回:
# fw_module 的实际输出
fw_outputs = fw_module(*primals)
# = (user_output_0, user_output_1, ..., saved_tensor_0, saved_tensor_1, ...)
AOTConfig 里的 num_user_outputs / num_saved 标记前几个是用户、后几个是 saved。runtime_wrappers (§13.6.5) 解包:
real_outputs = fw_outputs[:num_user_outputs] # 给用户
saved = fw_outputs[num_user_outputs:] # 给 backward
把 saved 与 user output 同一次 forward kernel 计算让二者共享中间状态、避免重复 kernel launch。Inductor 在 codegen 时也按这套语义生成 wrapper。
13.6.15 AOTAutograd × autocast
第 20 章讲过 autocast。AOTAutograd 怎么处理用户的 with autocast(...): out = model(x)?
机制:autocast 上下文在 trace 时已经被 Dynamo 捕获(作为 thread-local state),AOTAutograd 拿到的 fx graph 里每个 op 已经是 autocast 后的 dtype。functionalize 与 partition 都在 dtype-correct 的图上做,不需要二次处理。
但有个微妙点:autocast 下的 mm 算子是 fp16,但保存给 backward 的 mat2 仍要 fp32(精度需求)。AOTAutograd 在保存 SavedVariable 时显式 cast 回 fp32 再保存。这套”forward 用 fp16、backward 中间值 fp32”的精度控制是 mixed precision 训练正确的关键。
13.6.15.5 AOTAutograd 与 functorch transform 的协作
第 7 章 §7.15 提过 functorch 的 vmap / grad / jvp 函数变换。AOTAutograd 与这些变换可以叠加:
import torch
from torch.func import vmap, grad
@torch.compile
def per_sample_grad(x_batch, y_batch):
return vmap(grad(loss_fn))(x_batch, y_batch)
这段代码里:
vmap(grad(...))是双重函数变换 —— 每个样本独立算 grad@torch.compile编译整段- AOTAutograd 看到 vmap 与 grad 是 HigherOrderOperator(§13.7.9.5),递归 trace 它们的内部
- 最终编译产物里 vmap 被展开成 batched op,grad 被展开成 backward 链
实测 per-sample gradient(用于 differential privacy 训练)通过这条路径能拿到接近朴素循环 100x 的加速。AOTAutograd 把”复杂的函数变换”变成”普通可编译 op”,让用户能用 JAX 风格 API + PyTorch eager 体验 + Inductor 性能。
13.6.16 性能数字:AOTAutograd 在编译时间里的占比
完整 torch.compile 链路时间分解(70B Llama,第一次 forward):
| 阶段 | 时长 | 占比 |
| Dynamo trace | 2s | 10% |
| AOTAutograd trace + functionalize | 3s | 15% |
| AOTAutograd partition | 1s | 5% |
| Inductor lowering | 2s | 10% |
| Inductor scheduling | 2s | 10% |
| Triton compile (并发, async) | 10s | 50% |
| 总编译时间 | 20s | 100% |
AOTAutograd 占 30% 编译时间。这是它在编译栈里的”位置成本” —— 不是最重的(Triton 编译最贵),但显著。FxGraphCache(§14.6.12)让重启时这部分时间归零,是生产服务必开。
13.6.17 一个完整 trace 例子:F.linear(x, w, b).relu()
把整章串起来看一段简单代码的 AOTAutograd trace:
@torch.compile
def f(x, w, b):
return F.linear(x, w, b).relu()
AOTAutograd 收到 fx graph 后:
- trace joint(forward + backward):
primals = [x, w, b] (x 不需要 grad, w/b 需要)
tangents = [grad_output] # 反向输入
def joint(primals, tangents):
linear_out = primals[0] @ primals[1].T + primals[2] # forward
relu_out = relu(linear_out)
grad_relu = tangents[0] * (relu_out > 0) # relu backward
grad_linear = grad_relu # 梯度透明传递
grad_x = grad_linear @ primals[1] # x 的梯度
grad_w = grad_linear.T @ primals[0] # w 的梯度
grad_b = grad_linear.sum(0) # b 的梯度
return relu_out, grad_x, grad_w, grad_b
-
functionalize:原始 graph 没 inplace,跳过
-
partition (min-cut):哪些中间值保留给 backward?
linear_out:保留(relu_backward 要用,重算贵)relu_out:保留(既是用户输出也是 saved)- 决策:保留
linear_out+relu_out
-
拆出 fw / bw:
fw_module: primals → (relu_out, [linear_out, relu_out for save])
bw_module: (saved_linear_out, saved_relu_out, tangents) → (grad_x, grad_w, grad_b)
- 送 Inductor:fw_module 编成一个 Triton kernel(linear + relu fuse),bw_module 编成另一个
整套流程几百毫秒完成。用户写 3 行 Python,AOTAutograd 自动展开成正反向 trace + 分片 + 编译,是 PyTorch 训练加速最复杂的中间环节。
13.6.17.5 trace 阶段的 fake_mode 详解
§13.3 提了”用 fake tensor 跑 autograd.grad”。具体怎么实现?
FakeTensorMode 是个 TorchDispatchMode(第 5 章 §5.7)。进入 mode 后所有算子调用被拦截:
with FakeTensorMode() as fake_mode:
fake_x = fake_mode.from_tensor(real_x) # 把真实 tensor 转成 fake
fake_y = fake_x + 1 # 算子被拦截, 只算 shape, 不真做加
print(fake_y.shape) # OK, 但 fake_y.data 不存在
AOTAutograd trace 时全程在 FakeTensorMode 里。fake_y 看起来像 tensor 但只有 shape / dtype / device 元信息 —— 没分配显存、没真做计算。这让 trace 70B model 在小显存机器上也能跑(trace 不需要装下完整模型)。
autograd.grad(fake_loss, fake_inputs) 在 fake tensor 上跑,触发反向链 trace。反向规则的”中间值”也是 fake,最终 fx graph 上每个节点是一个 fake op。
整套 fake trace 让”先 trace 出 graph、再用真实 tensor 编译运行”成为可能。这是 AOTAutograd 能处理超大模型的根本。
13.6.18 flat_args 与 FlatArgsAdapter:嵌套结构展平
用户的 forward 输入可能是嵌套结构(dict / list / tuple):
def model(x, kwargs={'attn_mask': mask, 'cache': (k, v)}):
...
AOTAutograd / Inductor 内部处理纯 tensor 列表 —— 嵌套结构必须先 flatten。schemas.py 的 FlatArgsAdapter 做这件事:
flat_args = [x, mask, k, v] # 扁平 tensor 列表
spec = TreeSpec(...) # 描述如何重建嵌套
# trace 时 forward 接 flat_args, 内部用 spec unflatten 还原嵌套
def fw_module(*flat_args):
x, mask, k, v = flat_args
kwargs = {'attn_mask': mask, 'cache': (k, v)}
return model(x, **kwargs)
pytree(PyTorch 的嵌套结构工具,torch/utils/_pytree.py)提供 flatten / unflatten 实现。这套机制让”任意嵌套 input”都能进编译路径,不需要用户改 forward 签名。
13.6.19 synthetic_base:处理 view 别名
AOTSyntheticBaseWrapper(§13.6.5 提过)处理一个微妙场景:多个 input 张量是同一个 base 的不同 view:
def f(x, y):
# 用户传入 x = base[0:5], y = base[5:10] (共享 storage)
...
f(base[:5], base[5:]) # 两个 input 共享 storage
如果 trace 时把 x / y 当独立 tensor,functionalize 后两者的 mutation 不会互相影响 —— 但真实运行时它们共享 storage、修改 x 会影响 y。Inductor 编译产物在这种场景下行为错误。
AOTSyntheticBaseWrapper 检测这种共享 base 场景,把多个 view input 合并成”传 base + 多个 slice 偏移”。trace 时 base 是 leaf input,view 在 graph 内部用 slice op 重新构造 —— functionalize 与 mutation 分析在 base 层面正确进行。
这套机制让”用户传 view 给 compiled function” 在边角情况下仍正确。但额外开销让性能略降,所以 AOTConfig 有 flag 控制是否启用此优化(默认开)。
13.6.20 AOTDedupeWrapper:相同张量去重
另一个 corner case:用户把同一个张量作为多个 input 传:
def f(x, y):
return x + y
result = f(t, t) # x 和 y 是同一对象
trace 时 graph 里会同时引用 t 两次。AOTDedupeWrapper(runtime_wrappers.py:1106)检测这种情况,trace 时把 t 只列一次输入:
# trace 后的简化 graph
def f_traced(t): # 只一个 input
return t + t # 同一 input 用两次
去重让生成的 graph 更简洁、Inductor 编译产物对相同张量的处理一致。这个 wrapper 在 RNN 等”weight 在多 step 复用”场景下尤其有用。
13.6.21 RNG 处理:让 functionalize 与重算一致
§13.6.7 的 banned ops 列表里有 aten.rand 等随机算子。但这只解决了”min-cut 不重算 rand” —— 真正的 RNG 处理更复杂。
问题:activation_checkpoint 在反向时重 forward 取 activation。如果 forward 包含 dropout,第二次重 forward 时 RNG state 不同,得到的 dropout mask 也不同 —— 梯度计算错乱。
AOTAutograd 的解法:RNG state 也作为 saved tensor。第一次 forward 时记录 dropout 调用前的 RNG state,反向重 forward 时先恢复 RNG state,保证两次 forward 的随机数一致。
fx_passes/replace_random.py(§14.6.5 的 Pass)把 aten.rand 替换成 aten.rand_with_seed(seed),让 seed 显式作为输入。这样 functionalize / partition 都能在符号 RNG 上进行。
理解这条机制让你看到 torch.compile 下的 dropout 与 eager 模式行为完全一致 —— 不会因为 checkpoint 重算引入随机数不一致。
13.6.22 AOTAutograd 历史演进:从 functorch 到主仓
AOTAutograd 不是从零写的,是 functorch 项目(2022 年发布)的核心组件之一。functorch 借鉴 JAX 的函数变换思想(vmap / grad / jvp),把 PyTorch 的 autograd 改造成”函数式” —— 让”对函数求导”成为编译期可处理的对象。
时间线:
- v1.13 (2022 末):functorch 作为独立包发布,含 AOTAutograd 雏形
- v2.0 (2023):functorch 整合进主仓
torch._functorch,AOTAutograd 成为 torch.compile 的核心组件 - v2.4 (2024):AOTAutograd 与 DTensor / FSDP-2 / export 深度集成
- v2.6+:Compiled Autograd —— 反向也能被 Dynamo trace + Inductor 编译(不只是被 partition 后送 Inductor)
理解这条演进让你看 torch._functorch 模块的命名时不困惑:functorch 这个老名字保留下来作为 namespace,但实际功能远超原始 functorch(vmap / grad)—— 它是整个 AOT 编译框架的代号。
13.6.23 AOTAutograd × Compiled Autograd
v2.4+ 引入的新机制:把反向 Engine(第 8 章)也 trace 进编译路径。普通 AOTAutograd 把 backward 拆成 bw_module 编译,但调用 bw_module 仍由 Engine 调度。Compiled Autograd 让 Engine 调度本身被 Dynamo trace、Inductor 编译。
import torch._dynamo
torch._dynamo.config.compiled_autograd = True
# 之后 loss.backward() 的整个调度也被编译
收益:消除 Engine 调度开销(第 8 章 §8.12.5 的 ~2ms / step)。对小模型反向(如 LSTM)能再加速 30-50%。
实现机制:Compiled Autograd 用 torch._functorch.compiled_autograd 拦截 Engine 的 task 调度,把每个 backward node 的 apply 调用 trace 成 fx graph。
这是 PyTorch 编译器栈的”最后一块拼图” —— 之前 forward / backward / optimizer step 都能编译,现在 Engine 调度也加入。理解了 §13 + §8 + §14,你就理解了为什么 Compiled Autograd 是自然演进的下一步。
13.6.24 AOTAutograd 错误诊断速查表
| 症状 | 可能原因 | 诊断 |
|---|---|---|
| trace 时报”some tensor was modified inplace but …“ | functionalize 失败 | 自定义算子缺 register_fake |
| backward 数值与 eager 不一致 | RNG 处理出错 | 检查 dropout / random op 是否被正确 functionalize |
| 编译后 OOM 但 eager 不会 | min-cut 决策不优 | 调小 unit 粒度或显式 activation_checkpoint |
| 看到 “input is a view of …” 警告 | synthetic_base 触发 | 改用 .clone() 让 input 独立 |
| AOTAutograd trace 时间长 | model 太大 + min-cut 复杂 | 用 fx_graph_cache 缓存(§14.6.12) |
把这套表配合 TORCH_LOGS=aot,aot_graphs,partitioner 用,能快速定位 AOTAutograd 阶段的问题。
13.6.25 AOTAutograd 中的”trace 的 trace”
AOTAutograd 的 trace 阶段实际上调用了 PyTorch 的 dispatcher trace 机制(第 5 章 §5.7 TorchDispatchMode)—— 用 ProxyMode 把每次算子调用记录成 fx Node。这与 Dynamo trace 是两个不同层面:
- Dynamo trace:在 Python 字节码层,把 Python 函数翻译成 fx Graph
- AOTAutograd trace:在 dispatcher 层,把 fx Graph 上每个 ATen op 通过 fake tensor + ProxyMode 重新 trace 成更底层的 fx Graph(含反向)
所以 AOTAutograd 是 “trace 的 trace” —— Dynamo 的输出是 AOTAutograd 的输入,AOTAutograd 再往下 trace 一层。这种”多层 trace 串联”是 PyTorch 编译器栈的核心设计。
13.7 几个 corner case
AOTAutograd 复杂度的来源是无数 corner case:
- Tensor subclass:FakeTensor、TwoTensor、自定义 subclass 的 trace 路径
- View ops:
x.view(...)在 functionalize 时要重新计算 view 关系 - AMP autocast:joint graph 要保留 amp 上下文,让 backward 用同样的 dtype 策略
- Checkpoint interaction:用户已经在用
torch.utils.checkpoint时,AOTAutograd 要识别并尊重 checkpoint 边界 - Inplace foreach:optimizer 的
_foreach_*inplace 集合操作要特殊 functionalize - Side-effect ops:rng / print / collective ops(NCCL)不能被随便 reorder
每个 corner case 在源码里都有专门的 if 分支。这是 AOTAutograd 6000+ 行的来源 —— 绝大多数代码不在描述编译算法,而在处理边界情况。
13.7.5 effect tokens:side effect 算子的顺序保留
某些算子有 side effect:print / NCCL collective / 某些 stateful op(如 RNG state 的 reads)。functionalize 不能简单消除它们,否则程序行为变化。
EffectTokensWrapper(§13.6.5)的解法:把 side effect 表达成 token 依赖链:
op_with_effect_1 → token_1
↓
op_with_effect_2 → token_2
↓
op_with_effect_3 → token_3
每个 side effect op 接收 input token、产出 output token。后续 op 必须等前面的 token —— 形成显式依赖链。这套机制让 functionalize 能”假装这些 op 是纯函数”(只是依赖 token),但实际执行顺序与原始一致。
实战影响:collective ops(如 all_reduce)在 AOTAutograd 路径上通过 effect tokens 保持顺序。FSDP-2 的 collective + compute 重排(§14.6.16 reorder_communication_preserving_pin)就建立在这套机制上 —— 重排不能跨越 token 边界。
13.7.6 fx graph 上的 CSE 与简化
trace 完成后 fx graph 可能含大量冗余(同一个表达式重复算)。AOTAutograd 在送给 Inductor 前跑几个简化 pass:
- CSE (Common Subexpression Elimination):相同表达式只算一次(与第 14 章 §14.6.16 的 IndexExprCSE 不同 —— 这是 fx 层的、那是 Inductor IR 层的)
- dead code elimination:没用到的中间值删掉
- constant folding:编译期能算的常量提前算
- algebraic simplification:
x * 1 → x、x + 0 → x等
这些 pass 减小 graph 大小,让后续 Inductor 工作量降低。整套简化在 torch._functorch._aot_autograd.utils 里实现。
CSE 在 AOTAutograd 这层做的好处是:functionalize 后会引入大量重复(如多个 view 都涉及 base + offset 计算),CSE 立即消除让送给 Inductor 的图清爽。
13.7.7 AOTAutograd 与 Hooks 的协作
第 7 章 §7.5.3 / 第 9 章 §9.8 提过 hooks(saved_tensors_hooks / forward_hook / backward_hook 等)。AOTAutograd 编译路径下这些 hook 怎么处理?
- saved_tensors_hooks (activation_checkpoint):trace 时被识别 + 直接转成 functionalize 路径,hook 自身不出现在 graph 里
- forward_hook / backward_hook:trace 时 Dynamo 把 hook 调用 inline 进 graph(被当成普通函数调用 trace)—— 编译产物里 hook 自动展开
register_post_accumulate_grad_hook:FSDP / DDP 用,AOTAutograd 不直接支持 —— 这些 hook 在 graph 之外触发,与编译路径协作通过 collective op 与 functional collectives(第 16 章 §16.7.9)
理解 hook 在编译路径下的处理让你看到”用户写的 hook 在 compile 后是否还生效”。多数 hook 仍生效(trace 时被 inline);少数与训练循环协作的 hook(如 DDP)通过专门机制保留。
13.7.8 一段 LLM Attention 的 AOTAutograd trace 简要
把 §13.6.17 的简单例子升级到 LLM attention:
@torch.compile
def attention(q, k, v, mask):
# q, k, v: [B, H, S, D]
scores = q @ k.transpose(-2, -1) / sqrt(D) # [B, H, S, S]
scores = scores + mask
attn = F.softmax(scores, dim=-1)
out = attn @ v # [B, H, S, D]
return out
AOTAutograd 处理:
- trace joint:捕获 forward + backward。backward 中 attention 反向涉及 softmax 反向(complex jacobian)
- functionalize:scores 的 inplace add 被改成 out-of-place
- partition (min-cut):决定保留
attn(softmax 输出)给 backward 用 —— 这是 LLM 训练 activation 的主要内存压力来源 - Inductor:识别 SDPA pattern + 替换为
aten._scaled_dot_product_flash_attention(如果 shape 满足条件)
最终编译产物里 attention 整段被 FlashAttention kernel 替代,反向也用 FlashAttention 的反向 kernel。用户写朴素 attention 实现就拿到 SOTA 性能 —— 这是 AOTAutograd + Inductor 配合的典型工程价值。
13.7.9 AOTAutograd 与第三方 backend 的协作
§12.7 提过 aot_eager / cudagraphs / inductor 三种 backend。AOTAutograd 是这些 backend 的”上游” —— trace 完后把 graph 给 backend:
# 用户调 torch.compile(model, backend='inductor')
# Dynamo 输出 forward graph → AOTAutograd trace + partition →
# 把 fw_module 与 bw_module 分别送给 Inductor compile_fx_inner
第三方 backend 接入流程:
@torch._dynamo.register_backend
def my_backend(gm, example_inputs):
# gm 是 GraphModule (AOTAutograd 已经 trace + partition 后的 fw_module 或 bw_module)
return my_compiler(gm) # 用户的 compile 函数
# 注册后用
model = torch.compile(model, backend='my_backend')
这套接入路径让国内厂商能给 PyTorch 加自家 backend,不需要修改主仓代码。AOTAutograd 给所有 backend 提供”标准化的可编译 IR”(functionalize + partition 后的 fx graph),是 PyTorch 编译器生态的工程基石。
13.7.9.5 AOTAutograd × Higher Order Operator (HOP)
第 14 章 §14.6.14 提过 HigherOrderOperator —— 算子的”参数”是另一段计算图(如 cond / while_loop / flex_attention)。AOTAutograd 处理 HOP 时有特殊逻辑:
- 不直接 functionalize HOP 内部:HOP 的”子图”作为黑盒处理,trace 时保留为单个节点
- 递归 trace 子图:每个 HOP 的 subgraph 单独走一次 AOTAutograd,产出独立的 fw / bw 子图
- token 串连主图与子图:通过 effect tokens 让 HOP 调用与主图节点的执行顺序正确
HOP 让 PyTorch 能 trace “动态控制流”代码(如 if cond: branch_a else branch_b),同时让每个分支被独立编译。这是与 JAX 的 lax.cond / lax.while_loop 思想一致 —— 控制流作为 first-class IR 节点。
实战:FlexAttention(自定义 attention pattern)就是用 HOP 实现的。用户写 score_mod 函数描述自定义 attention 行为,HOP 把它 trace 成子图、AOTAutograd 给它产出反向、Inductor 编成 Triton kernel。整套机制让自定义 attention 享受与内置 SDPA 同等的编译优化。
13.7.10 AOTAutograd 与 export 的特殊处理
torch.export 路径走 aot_dispatch_export(§13.6.9)。它与 autograd / inference 路径的关键差异:
- 不允许 graph break:export 要求”严格静态 graph”,遇到 unsupported op 直接报错,不能 fallback eager
- 不允许 dynamic shape(除非用户显式
dynamic_shapes={...}标记) - 不 trace backward:export 的目标是部署,不是训练
- 保留更多语义信息:export 输出的 ExportedProgram 含 module hierarchy、原始 Python 调用栈,比纯 fx graph 信息丰富
ExportedProgram 是 v2.x 部署的核心 IR。AOTI(第 15 章 §15.6.7)从 ExportedProgram 出发编译。HuggingFace 推理服务、移动端部署都基于 export。理解 AOTAutograd 在 export 路径上的角色让你能跟踪”训练时的 model 怎么变成部署时的 .so”完整链路。
13.7.11 AOTAutograd 的并发与多线程
AOTAutograd trace 本身是单线程(CPython GIL 限制),但它产出的 graph 编译阶段(Inductor)是多进程的(§14.6.7 AsyncCompile)。整体编译时间分布:
- AOTAutograd trace + functionalize:单线程,几百 ms 到几秒
- AOTAutograd partition:单线程,几百 ms
- Inductor lowering + scheduling:单线程,几秒
- Triton 编译:多进程并发,几秒到几十秒
整体编译时间瓶颈在 Triton(§13.6.16 表格)。AOTAutograd 的单线程性质不是问题 —— trace 阶段比 Triton compile 快得多。
未来方向:AOTAutograd trace 可能也并行化(用 thread pool 同时 trace 不同 subgraph),但实现复杂度高、收益相对小。当前设计够用。
13.7.12 AOTAutograd 在 PyTorch 演进路线上的位置
AOTAutograd 在 PyTorch 整体演进中的位置:
graph LR
Eager[v1.x eager]
JIT[v1.x TorchScript]
Functorch[v1.13 functorch<br/>函数变换]
Compile[v2.0 torch.compile<br/>= Dynamo + AOTAutograd + Inductor]
Compiled[v2.6+ Compiled Autograd]
Future[v3.x ?<br/>更激进编译 / DTensor 全自动并行]
Eager --> JIT
Eager --> Functorch
Functorch --> Compile
Compile --> Compiled
Compiled --> Future
style Compile fill:#fef3c7,stroke:#f59e0b,stroke-width:2px
AOTAutograd 是 v2.0 编译栈的核心 —— 没有它就没有 torch.compile 的工程实现。理解 AOTAutograd 让你能解读 PyTorch 团队公开的 RFC、跟上未来演进路线。
13.8 几条工程经验
实战 AOTAutograd 相关:
1. 调试时用 aot_eager backend:torch.compile(fn, backend='aot_eager') 跳过 Inductor,只跑 AOTAutograd。如果这步 OK 但 Inductor 阶段错,就能快速定位
2. TORCH_LOGS=aot,aot_graphs 看到 trace 出来的 joint graph + 切分结果:知道哪些张量被保留、哪些被重算
3. 有的算子 functionalize 失败:通常是用户写的自定义算子缺 register_fake。第 22 章自定义算子会展开
4. recompute_max_memory 控制重算上限:通过 inductor.config 调整 partition 偏好(更多重算 vs 更多显存)
5. 与 torch.utils.checkpoint 的协作:AOTAutograd 优先尊重用户标记的 checkpoint 区段,min-cut 只在这些区段内部决策
13.8.5 AOTAutograd 在用户代码里的”可见性”
普通用户写 @torch.compile 时不知道 AOTAutograd 的存在 —— 它对用户透明。但有几个 API 让用户能间接感知 / 控制:
torch._dynamo.export(model, ...)内部走aot_dispatch_exporttorch.export.export(model, ...)是_dynamo.export的高级封装torch._functorch.aot_function(fn, fw_compiler, bw_compiler)直接暴露 AOTAutograd 接口(不经 Dynamo)—— 用于研究 / 调试AOTConfig可以通过torch._inductor.config间接控制(如decomposition_freezing)
这些 API 让”想直接玩 AOTAutograd”的高级用户有入口。生产代码用 @torch.compile 即可,不需要直接接触 AOTAutograd —— 这本身就是它的设计目标:对常规用户透明、对高级用户可控。
13.9 跨书关联
- 《Rust 编译器之路》编译器中端:AOTAutograd 的 functionalize → partition → optimize 与 LLVM 的 SSA → mem2reg → CFG simplification 是同一思想 —— 把 mutation 化简、再做图变换
- 《Serde 元编程》函数式派生:functionalization 让算子变纯函数,与 Serde 的”派生宏生成纯函数式 serde”思路相通
- 第 7 章 §7.5.3 saved_tensors_hooks:AOTAutograd 的 saved tensors 与 eager autograd 的 SavedVariable 是相同概念在两个执行模式下的对应
13.9.5 AOTAutograd 的”3 段式”工作流总览
把整章串起来,AOTAutograd 内部是清晰的”3 段式”流程:
flowchart TB
Input[Dynamo 输出: forward FX Graph + Guards]
Input --> S1[段 1: trace<br/>用 fake tensor + autograd.grad<br/>产出 joint forward+backward graph]
S1 --> F1[FunctionalTensorMode 拦截 inplace]
F1 --> S2[段 2: rewrite<br/>functionalize + decomp + dedup<br/>产出纯函数式 IR]
S2 --> S3[段 3: partition<br/>min-cut 切 fw / bw<br/>用 SymInt 估算 cost]
S3 --> O[输出: fw_module + bw_module + saved_indices]
style S1 fill:#fef3c7
style S2 fill:#dbeafe
style S3 fill:#dcfce7
每段对应一组源码文件:
- 段 1:
_functorch/aot_autograd.py的create_aot_state+_functorch/_aot_autograd/dispatch_and_compile_graph.py的 trace 函数 - 段 2:
_subclasses/functional_tensor.py+_decomp/decompositions.py+_aot_autograd/runtime_wrappers.py的 dedupe / synthetic_base - 段 3:
_functorch/partitioners.py的 min-cut 算法
这种”段-段-段”工作流让 AOTAutograd 的代码组织清晰、错误能精确定位到段。
13.9.6 设计层面的 AOTAutograd vs JAX
AOTAutograd 的思想直接受 JAX 影响。对比:
| 维度 | JAX | AOTAutograd |
|---|---|---|
| trace 时机 | jit 时 | torch.compile 第一次调用时 |
| autograd 风格 | 函数变换 (grad(f) 返回新函数) | 命令式 (loss.backward()) |
| inplace 处理 | 不允许 inplace | functionalize 后允许 |
| 输入 shape | trace 时绑定(dynamic 用 abstract shape) | SymInt 支持 dynamic |
| 编译后端 | XLA | Inductor |
JAX 的”函数式”哲学让 AOTAutograd 设计成立 —— 没有 functionalize 这套机制,PyTorch 的 mutation-heavy 代码根本编不动。AOTAutograd 在保留 PyTorch 命令式 API 的前提下,用 functionalize 做”内部函数式”,是连接 eager / compile 两种范式的关键工程胶水。
理解这条 JAX → AOTAutograd 演进让你能预判:未来 PyTorch 若引入更激进的”函数式优化”(如自动 vmap / grad 组合),AOTAutograd 是天然落地点。
13.9.7 AOTAutograd 在 LLM 训练里的实际收益
具体数字(H100,70B Llama 训练):
| 配置 | 单 step 时间 |
| pure eager (无 compile) | 5000 ms |
| compile (Dynamo only) | 4500 ms | ← 减少 dispatcher 开销
| + AOTAutograd partition | 3500 ms | ← min-cut activation_checkpoint
| + Inductor fusion | 2800 ms | ← 算子融合
| + CUDA Graph (reduce-overhead) | 2500 ms | ← 消除 launch overhead
AOTAutograd 单独贡献约 20% 加速,主要来自 min-cut 自动 activation_checkpoint 优化。这部分省下的不是 compute(重算反而多)—— 是显存峰值,让用户能开更大 batch_size、训练吞吐间接提升。
理解 AOTAutograd 的工程价值不是”它直接让训练快多少”,是 它让”自动决定 activation_checkpoint 粒度”成为可能。手动 checkpoint 调到最优是几小时工作 + 容易错;AOTAutograd 自动决策几乎无需调参。
13.9.8 AOTAutograd × profiler
第 21 章讲过 PyTorch profiler。AOTAutograd 编译产物在 profiler 里有特殊处理:
- 编译产物的 fw / bw 在 trace 里显示为单个事件(
CompiledFunction.forward) - 内部 Triton kernel 显示为
triton_per_fused_xxx(第 14 章 §14.9.5.5) - AOTAutograd 自身的 trace 阶段(编译时)不在训练 trace 里(编译完成后才训练)
如果你的训练慢,profiler 看到 AOTAutograd 相关事件占比大 —— 那是编译时间没缓存(FxGraphCache 没命中)。开启 TORCHINDUCTOR_CACHE_DIR=... 持久化 cache 通常能解决。
13.9.9 AOTAutograd 不能做的事
明确 AOTAutograd 的边界:
- 不处理纯 Python 控制流:那是 Dynamo 的职责,AOTAutograd 拿到的已经是确定的 fx graph
- 不直接生成 kernel:那是 Inductor 的职责,AOTAutograd 输出的是 fx graph,不是 Triton 代码
- 不优化单算子内部:算子级优化(如 fp16 GEMM 是用 cuBLAS 还是 Triton template)由 Inductor 决定
- 不处理跨进程通信:那是 ProcessGroup 与 functional collectives 的职责(第 16 章),AOTAutograd 把 collective 当作普通算子(带 effect token)
明确边界让你不会”在 AOTAutograd 找其他模块的功能”。当编译某段代码出错时,先判断是哪个模块的问题:Dynamo(trace 失败 / graph break)、AOTAutograd(functionalize / partition 错)、还是 Inductor(codegen 失败)。
13.9.10 AOTAutograd × FSDP-2 的协作
第 18 章 §18.6.17 提过 FSDP-2 与 torch.compile 兼容好。具体到 AOTAutograd 这层:
- DTensor 是 first-class:AOTAutograd 看到 DTensor 时知道它是 placement-aware tensor,trace 时保留 placement 信息
- functional collectives 在 trace 中保留:FSDP-2 的 AllGather / ReduceScatter 用 functional collectives(第 16 章 §16.7.9),AOTAutograd 把它们当作普通算子加进 graph
- Effect tokens 维持顺序:collective 通过 token 链保持执行顺序,partition 不会重排
这套机制让 FSDP-2 训练 graph 能完整经过 AOTAutograd → Inductor 编译。FSDP-1 的 FlatParameter 在 AOTAutograd 路径上引发各种特殊处理,是它与 compile 兼容性差的根本。FSDP-2 用 DTensor 让 AOTAutograd 不需要”识别 FSDP” —— 普通 trace 流程就能正确处理分片张量。
13.9.11 整章信息密度小结
读完 Ch 13 你应该能:
- 理解:AOTAutograd 在编译栈的位置(连接 Dynamo 与 Inductor)
- 追踪:joint trace → functionalize → partition 三段式工作流
- 决策:何时用 dynamic shape、何时手动 activation_checkpoint vs 让 min-cut 自动决定
- 诊断:
TORCH_LOGS=aot_graphs,partitioner看每段产物 - 扩展:知道第三方 backend 怎么接进 AOTAutograd
AOTAutograd 是编译栈最复杂的中间环节,也是 PyTorch v2.0 演进路线上最大的工程投入之一。它把”动态图 PyTorch”与”静态优化编译”的鸿沟填上 —— 这是 PyTorch 能在 LLM 时代继续领先的关键。
13.10 设计启示
AOTAutograd 的核心思想可迁移:
第一:用图论算法(min-cut)做”显存 vs 计算”权衡:人手调 activation_checkpoint 是不可扩展的,让算法决定。这是编译器自动化的胜利
第二:先 functionalize 再优化:mutation 让分析变难,编译期把 mutation 抹掉、做完优化再恢复。这套”先简化、再优化、最后还原”是编译器中端经典模式
第三:joint graph 比独立 forward + backward 更优:在 joint 上做切分能找到全局最优,而不是 forward 与 backward 各自局部最优
第四:runtime wrapper 让编译产物兼容 eager:不需要破坏现有 autograd Engine,让 fw/bw 子图通过 autograd.Function 接入
下一章拆 Inductor —— AOTAutograd 把切好的子图送给 Inductor,Inductor 生成 Triton kernel。这是整个编译栈的”代码生成”环节。
评论 0
还没有评论,来说两句吧。
评论加载失败,刷新重试。