第3章 c10 核心抽象:Device、DType、Layout、intrusive_ptr

intrusive_ptr<T> is an alternative to shared_ptr<T> that has better performance because it does the refcounting intrusively (i.e. in a member of the object itself).”

—— c10/util/intrusive_ptr.h:117-120 类文档第一句

本章要点

  • Device 只有 2 字节 —— (DeviceType: int8, DeviceIndex: int8),但用一个简单的 hash 就能和 21 种硬件后端配对
  • ScalarType 是 16-bit 索引,背后是 TypeMeta 类型擦除架构 —— 既支持 POD 数值类型,也保留了对老式 caffe2 对象张量的兼容
  • Layout 是 PyTorch 张量”形状语义”的开关Strided(普通)/ Sparse / SparseCsr / Mkldnn / Jagged 八种值
  • DispatchKey 是个 64-bit bitmap:底部 ~16 位是 BackendComponent(CPU/CUDA/HIP/XLA/MPS/…),上部位是 Functionality 维度(Dense/Quantized/Sparse/Autograd/…),二维笛卡尔积压成一维
  • intrusive_ptr 是 PyTorch 替代 std::shared_ptr 的自研智能指针:refcount 嵌入对象本身、单原子 64-bit 同时管 strong/weak/PyObject 标志、专门为 GIL 与跨语言生命周期设计
  • 这一章的所有抽象都不是”为了好看”,而是为了 每秒上百万次张量创建 / 调度 / 析构 这个数量级的性能预算服务

3.1 c10 是什么:PyTorch 的”原子层”

打开 PyTorch 仓库,你会看到三个并列的目录:

pytorch/
├── c10/        ← 最底层,几乎不依赖外部
├── aten/       ← ATen 算子库,依赖 c10
└── torch/      ← Python 端 + 用户接口,依赖 aten + c10

c10 是 PyTorch 的”原子层”:它定义了那些”如果改了,半个 PyTorch 都要重新编”的核心抽象 —— Tensor 的指针类型、张量元数据、设备类型、dtype、内存分配器接口、引用计数。

c10 这个名字本身是 “Caffe 2 + ATen” 的混合缩写(C2 + Aten = c10),来自 2018 年 Caffe2 与 ATen 合并的历史。今天它已经成长为一个独立的”基础设施层”,目标是 不依赖任何深度学习概念(autograd / 算子等都不在 c10 里),只关心”张量需要什么样的容器、生命周期管理、跨平台兼容性”。

这一章我们把 c10 里最关键的五块抽象逐一拆开。第 2 章我们看了 StorageTensorImpl;这一章下沉到它们底下的”零件供应商”。

graph TB
    subgraph TI[TensorImpl 字段]
        S1[Storage]
        S2[DispatchKeySet]
        S3[ScalarType]
        S4[Device]
        S5[intrusive_ptr_target]
    end

    subgraph C10[c10 核心抽象]
        D1[Device + DeviceType]
        D2[ScalarType + TypeMeta]
        D3[Layout]
        D4[DispatchKey + DispatchKeySet]
        D5[intrusive_ptr / intrusive_ptr_target]
    end

    S1 --> D5
    S2 --> D4
    S3 --> D2
    S4 --> D1
    S5 --> D5

    style C10 fill:#dbeafe,stroke:#3b82f6,stroke-width:2px
    style TI fill:#fef3c7,stroke:#f59e0b,stroke-width:2px

3.2 Device:两个字节的硬件地理学

打开 c10/core/Device.h:31

struct C10_API Device final {
  using Type = DeviceType;
  ...
 private:
  DeviceType type_;          // 1 byte
  DeviceIndex index_ = -1;   // 1 byte (using DeviceIndex = int8_t)
};

——Device 总大小只有 2 字节。但就是这 2 字节,要表达 PyTorch 支持的所有硬件后端:CPU、NVIDIA CUDA、AMD HIP、Intel XPU、Apple MPS、Google TPU/XLA、Graphcore IPU、华为 HPU、NEC SX-Aurora VE、Meta MTIA、Microsoft MAIA、Lazy Tensor、Meta Tensor、Vulkan、Metal —— 还有保留位 PrivateUse1/2/3 给厂商扩展。

完整列表在 torch/headeronly/core/DeviceType.h:35

enum class DeviceType : int8_t {
  CPU = 0,
  CUDA = 1,
  MKLDNN = 2,    // 已弃用,保留为兼容值
  OPENGL = 3,    // 同上
  OPENCL = 4,
  IDEEP = 5,
  HIP = 6,       // AMD ROCm
  FPGA = 7,
  MAIA = 8,      // ONNX Runtime / Microsoft 自家
  XLA = 9,       // Google TPU
  Vulkan = 10,
  Metal = 11,    // Apple Metal
  XPU = 12,      // Intel
  MPS = 13,      // Apple Metal Performance Shaders
  Meta = 14,     // 仅元数据张量,无实际数据
  HPU = 15,      // Habana Labs (华为 HPU 也借这个键)
  VE = 16,       // NEC SX-Aurora
  Lazy = 17,     // Lazy Tensor (用于 PyTorch/XLA 的 tracer)
  IPU = 18,      // Graphcore
  MTIA = 19,     // Meta 自家 inference/training accelerator
  PrivateUse1 = 20,
  COMPILE_TIME_MAX_DEVICE_TYPES = 21,
};

注意几个细节:

  • Meta 是一个特殊设备:上面的张量”没有真正的数据”,只有 sizes/strides/dtype。它用于”假执行” —— 第 12 章 TorchDynamo 大量用它做 graph capture
  • Lazy 也是一个特殊设备:它把每次操作记录成一个 IR 节点而不是真做,用于 XLA 后端
  • PrivateUse1/2/3 是给厂商扩展的”占位符”:当国内厂商(华为昇腾、寒武纪、壁仞)想接进 PyTorch 时,先用 PrivateUse1 实现,跑稳后社区再讨论是否升级到独立 device type

3.2.1 为什么 Device 这么紧凑

DeviceTypeint8_tDeviceIndex 也是 int8_t。整个 Device 2 字节。这种紧凑不是装饰,而是因为 Device 嵌入到 TensorImpl 里(见第 2 章),每多一字节都会乘上每秒上百万的张量创建。

std::hash<Device> 也设计得很有意思(c10/core/Device.h:194-215):把 type 和 index 拼成一个 32-bit 整数再 hash。注释里特别警告”小心 sign extension” —— 因为 DeviceIndex = int8_t 有可能是 -1(“当前设备”哨兵值),如果不小心把它直接强转 uint32_t,符号扩展会让上半部分变成全 1,hash 退化。这种”细节决定正确性”的注释在 c10 源码里随处可见,是 PyTorch 工程文化的体现。

3.2.2 DeviceIndex = -1 的”当前设备”语义

Device(CUDA)Device(CUDA, 0) 是不同的:

  • Device(CUDA)index_ 是 -1,表示当前设备cudaGetDevice 的返回值)
  • Device(CUDA, 0) 显式指定 GPU 0

这种”两态”设计让用户可以写 tensor.to('cuda') 而不强制指定卡号 —— PyTorch 在执行时再读当前 stream 绑定的设备。但它也带来一个真实坑:如果你在多 GPU 训练里没有显式指定 device index,张量可能会跑在不同的卡上,最终某个 op 拿到 device 不一致的输入而抛错。生产代码里推荐总是显式 Device('cuda', rank),把 -1 留给临时实验。

源码上 validate() 函数(c10/core/Device.h:172-185)有一行特别的注释:“Removing these checks in release builds noticeably improves performance in micro-benchmarks.” —— 它在 release 编译里直接删掉了 device index 合法性检查,只留 DEBUG_ONLY 断言。原因还是那条 “每秒上百万次张量创建”的预算 —— 即使是一次 index_ >= -1 的比较,乘上百万也成了可观开销。这种”在 release 里相信调用者”的工程取舍,是 c10 性能哲学的鲜明体现。

3.3 ScalarTypeTypeMeta:dtype 的双层表达

PyTorch 的 dtype 有两层 C++ 表示:

  • ScalarTypec10/core/ScalarType.h):一个简单的枚举,列出所有合法 dtype
  • TypeMetac10/util/typeid.h):一个 16-bit 索引 + 类型元信息(size、构造、析构)

简单 dtype 用 ScalarType 就够了:

// c10/core/ScalarType.h 中通过宏批量定义的 enum
enum class ScalarType : int8_t {
  Byte,        // uint8_t
  Char,        // int8_t
  Short,       // int16_t
  Int,         // int32_t
  Long,        // int64_t
  Half,        // at::Half (fp16)
  Float,       // float (fp32)
  Double,      // double (fp64)
  ComplexHalf,
  ComplexFloat,
  ComplexDouble,
  Bool,
  QInt8,       // 量化 int8
  QUInt8,
  QInt32,
  BFloat16,
  ...
  NumOptions,
};

TensorImpl::data_type_ 字段存的不是 ScalarType,而是 TypeMeta。为什么?

3.3.1 历史包袱:从”对象张量”到”数值张量”

PyTorch 早期(1.x 之前)继承自 Caffe2,允许张量存任意类型的对象(不仅是数值),包括 caffe2::Blob 这种带构造函数和析构函数的对象。这要求张量元数据保留”如何构造、如何析构”这一类类型擦除信息。

TypeMeta 就是这个抽象。它内部其实只是一个 16-bit ID(指向全局类型表),但通过这个 ID 可以查到:

  • itemsize() —— 单个元素多大字节
  • placementNew() —— 怎么”原地构造”一个新对象
  • placementDelete() —— 怎么”原地析构”
  • copy() —— 怎么复制
  • name() —— 类型的字符串名

今天的 PyTorch 已经把支持的 dtype 收敛到一组 POD 数值类型(fp32/fp16/bf16/int* 等),不再支持任意对象张量。但 TypeMeta 这个类型擦除接口还保留着,作为历史遗产。

源码里的兼容办法:TypeMeta 的内部表示就是一个 uint16_t index_,所有 POD dtype 在初始化时被注册到全局表的前几十个槽位。需要 ScalarType 时,TypeMeta 直接转成 ScalarType(一个简单的查表)。

// 简化版的 TypeMeta → ScalarType 转换
ScalarType TypeMeta::toScalarType() const {
    return ScalarType(index_);
}

为什么不直接用 ScalarType 而要包一层?因为:

  • 历史 ABI 兼容:早期源码大量使用 TypeMeta,全部改成 ScalarType 会破坏太多代码
  • 将来扩展空间:如果 PyTorch 重新需要”用户自定义 dtype”(如新硬件支持一种独有的 fp4 类型),TypeMeta 比硬编码的枚举好扩展

理解这种”双层表达”,你看源码里 data_type_scalar_type() 来回转的代码就不会迷惑。

3.3.2 几个看似冗余但有意义的 dtype

PyTorch 支持的 dtype 列表里有几个看起来”冗余”的:

  • Half(fp16)vs BFloat16:精度都是 16 位但指数位不同。Half 用 5 位指数 + 10 位尾数(IEEE 754 半精度),BFloat16 用 8 位指数 + 7 位尾数(与 fp32 同指数,更适合训练时的梯度)。NVIDIA Tensor Core 早期主推 Half,Google TPU 主推 BFloat16,今天 H100 / GB200 都原生支持两种
  • ComplexHalf / ComplexFloat / ComplexDouble:复数类型,用于信号处理与量子模拟。每个元素是两个浮点数,PyTorch 通过 strides + 巧妙的内存布局让复数张量能直接调用底层数学库
  • QInt8 / QUInt8 / QInt32:量化整数类型,分别表示对称量化的 int8、非对称量化的 uint8 和 int32 累加器。第 20 章会详细讲量化

每多一个 dtype,dispatcher 就要在每个算子的 dispatch 表里多一行(虽然代码生成出来)。这是为什么 PyTorch 团队对引入新 dtype 极为谨慎 —— 即便硬件厂商呼吁”加 fp4 / fp8 的某种变体”,社区也通常先要求把对应硬件落地到 PrivateUse1 上、跑过几个版本再考虑提升为正式 dtype。

3.4 Layout:除了 strides 还有谁

第 2 章我们说了 PyTorch 张量是 “strides + offset” 模型。但这只是默认情况 —— Layout::Strided。PyTorch 实际支持多种张量布局,全部列在 c10/core/Layout.h

Layout含义内部表示
Strided第 2 章讨论的标准 strides 模型sizes + strides + offset
Sparse稀疏 COO 张量indices + values 两个张量
SparseCsr稀疏 CSR (Compressed Sparse Row)crow_indices + col_indices + values
SparseCsc稀疏 CSC (列优先版本)ccol_indices + row_indices + values
SparseBsr / SparseBsc块稀疏 (Block Sparse)类似 CSR/CSC,但 values 是块矩阵
MkldnnIntel MKL-DNN 私有格式不透明字节缓冲 + MKL 元信息
Jagged嵌套张量,每行长度可变用于 RNN 变长输入、graph batching

每个 Layout 对应一个或多个 TensorImpl 子类(见第 2 章 §2.9.7)。Layout 字段本身只是个标签,决定后续算子怎么解释这块”数据”。

3.4.1 为什么 Sparse 需要单独的 Layout

考虑一个稀疏矩阵:100 万行 × 100 万列,但只有 1000 个非零元素。如果用普通 strided 张量,需要 4 TB 显存;用 COO 表示,只需要 2000 个 indices + 1000 个 values = 12 KB。

PyTorch 用 Layout 让 同一个张量类型 (torch.Tensor) 可以表达稠密和稀疏数据。算子根据 Layout 走不同实现:add(strided_tensor, sparse_tensor) 会被 dispatcher 路由到 add_sparse_dense kernel,而不是普通的 add_kernel

第 1 章 §1.3 我们提过 dispatcher 的多分派 —— 现在你看到 dispatcher 实际上是根据 (Device, Layout, DType, Functionality) 四个维度的笛卡尔积选择 kernel 的。Layout 是其中一个轴。

3.4.2 Jagged Layout:一个值得停下来看的新成员

Jagged 是相对较新的 Layout(v2.1+)。它的核心场景是 变长序列张量,典型的就是 NLP 里”一个 batch 包含 8 个句子,每句长度不同”。

传统做法是 padding:把所有句子补到最长长度,padding 位置写 0。代价是 padding 的部分占内存、占算力。

Jagged 张量直接把变长这件事 first-class:

sizes = [batch=4, var=*, dim=128]
       ↑              ↑
   有限维度        变长维度(每行长度不同)

底层是一个 values 张量(所有有效元素拼一起)+ 一个 offsets 数组(每行起止位置)。算子在 Jagged 上可以只对有效元素计算,跳过 padding,吞吐能提升 1.5-2x。

PyTorch 的 NestedTensor 在 v2.x 里大量基于 Jagged Layout 实现。第 9 章 nn.Module 章会有专门小节讲它。这是 Layout 维度作为”数据组织语义开关”的最新例子。

3.5 DispatchKeyDispatchKeySet:64 位的多维身份证

到第 5 章我们会详细拆 dispatcher 的查找逻辑。这一章只需要知道一件事:DispatchKeySet 是一个 64-bit bitmap,每一位代表一个 DispatchKey

打开 c10/core/DispatchKey.h:136

enum class DispatchKey : uint16_t {
  Undefined = 0,
  CatchAll = Undefined,

  // 可分派功能键 (Functionality keys, 高位)
  Dense,
  Quantized,
  Sparse,
  SparseCsr,
  NestedTensor,
  AutogradFunctionality,
  ...

  // 内存格式 / Mode 键
  AutogradOther,
  Tracer,
  AutocastCPU,
  AutocastCUDA,
  Functionalize,
  Python,
  ...

  // 后端键 (Backend keys, 低位)
  CPU,
  CUDA,
  HIP,
  XLA,
  MPS,
  ...
};

DispatchKey 总数大概 100+,但 DispatchKeySet 是 64-bit —— 装不下所有 key。怎么办?

PyTorch 用了一个聪明的设计:两段拼接

3.5.1 BackendComponent + Functionality 的笛卡尔积压一维

graph LR
    Naive["朴素方案<br/>每 key 一位<br/>需要 100+ bits"] -.装不下.-> Fail[超出 64 位]

    Smart["PyTorch 方案<br/>16 BackendComponent × 48 Functionality<br/>= 768 个语义"] --> Fit[64 bits 装下所有]

    style Fail fill:#fee2e2
    style Fit fill:#dcfce7

DispatchKeySet 的 64 位被这样切分:

位 63 .. 16          位 15 .. 0
[ Functionality bits ] [ BackendComponent bits ]
   ~48 个功能键          ~16 个后端键
  • BackendComponent(低 16 位):CPU、CUDA、HIP、XLA、MPS、IPU、XPU、HPU、VE、Lazy、MTIA、MAIA、PrivateUse1/2/3、Meta
  • Functionality(高 48 位):Dense、Quantized、Sparse、SparseCsr、NestedTensor、Autograd 等正交维度

当一个张量是”CUDA 上的稠密 fp32 张量、需要梯度、当前在 autocast 上下文里”,它的 DispatchKeySet 会同时点亮:

  • BackendComponent: CUDA
  • Functionality: Dense
  • Functionality: AutogradCUDA(实际上是 AutogradFunctionality + CUDA 的合成)
  • Functionality: AutocastCUDA

每次进 dispatcher,PyTorch 计算最高优先级的 functionality 与最高优先级的 backend,用乘法表查到对应的具体 kernel —— 这是为什么 100+ keys 能塞进 64 bits 的秘密。

3.5.2 这个设计的代价与收益

代价:dispatcher 的查找逻辑比”单一 enum 查表”要复杂,每次 lookup 多几次位操作。

收益

  1. 新增一个 backend(如某国产芯片)只要扩 BackendComponent 一项,不用改 functionality 维度
  2. 新增一个 functionality(如 vmap 这种 mode)只要扩 functionality 维度,自动支持所有 backend
  3. DispatchKeySet 的位操作(OR、AND、highest_bit)天然 O(1),吞吐高

第 5 章会讲 dispatcher 的具体 lookup 算法。这一章你只需要把 DispatchKeySet 当作”张量身份证”理解就好。

3.5.3 优先级与 highestPriorityTypeId

DispatchKeySet 提供一个核心方法 highestPriorityTypeId(),它的实现就是 找最高位的 1。在硬件上 x86 / ARM 都有原生 BSR / CLZ 指令做这件事,复杂度严格 O(1)。

PyTorch 编排 DispatchKey 的”位序”的方式特别讲究:

  • 越高位 = 越高优先级 —— autograd-related keys、Mode keys 在最高位,因为它们必须先于普通 backend 命中
  • 正交维度合理隔开 —— functionality keys 与 backend keys 在不同段位,避免互相干扰
  • 新增 key 优先在低位 —— 让既有代码的优先级关系不变

打开 c10/core/DispatchKey.h 顶部那段长达数百行的注释,你会看到 Edward Yang 把每一个 key 的位置都解释了一遍。这份注释是”PyTorch 设计的活档案” —— 第 5 章我们会逐段拆。

实际值得记住的优先级直觉:

高 → 低
PythonDispatcher (最高,给 Mode 系统)
FuncTorchDynamicLayer (vmap/grad 等函数变换)
AutogradXxx (反向图记录)
AutocastXxx (混合精度)
BackendSpecific (CUDA/CPU/...)
Undefined (最低)

这就是为什么 with torch.no_grad(): 关闭 autograd 后,algorithm 还能继续跑 —— autograd key 被压低,下一层 backend key 立刻命中。

3.6 intrusive_ptr:PyTorch 自研的智能指针

终于到本章的重头戏。PyTorch 不用 std::shared_ptr,而是自己造了一套 c10::intrusive_ptr —— 这是写过 C++ 的人最容易困惑的设计决定之一。让我们彻底拆开。

3.6.1 std::shared_ptr 的”控制块”设计

回顾一下 std::shared_ptr<T> 的内存结构:

graph TB
    SP[shared_ptr&lt;T&gt;<br/>16 字节: 数据指针 + 控制块指针]
    SP --> CB[控制块<br/>──────────<br/>strong refcount<br/>weak refcount<br/>deleter<br/>allocator]
    SP --> OBJ[T 对象]
    CB -. 析构时调用 deleter .-> OBJ

    style CB fill:#fee2e2,stroke:#ef4444
    style OBJ fill:#dbeafe,stroke:#3b82f6

shared_ptr<T> 把引用计数放在一个独立的”控制块”对象里,控制块和 T 对象是两块分离的堆内存。这种设计的问题:

  1. 额外一次堆分配:除非用 make_shared 把对象和控制块合并,否则每次 new T + shared_ptr<T>(p) 是两次 malloc
  2. 额外一次指针跳转:每次访问 refcount 要先解引用 shared_ptr 拿到控制块指针,再访问其中的 atomic
  3. Cache 不友好:T 对象和它的 refcount 在内存里不相邻,每次 incref / decref 都打散一条 cache line

对一个 普通 C++ 项目,这些代价可以忽略。但 PyTorch 张量每秒被创建/析构上百万次 —— 每次省一次 cache miss 都能换来几个百分点的整体性能。

3.6.2 intrusive_ptr 的”内嵌”设计

c10::intrusive_ptr<T> 反过来把 refcount 嵌到 T 对象自己身上

graph TB
    IP[intrusive_ptr&lt;T&gt;<br/>8 字节: 数据指针]
    IP --> OBJ["T 对象 (继承 intrusive_ptr_target)<br/>──────────<br/>std::atomic&lt;uint64_t&gt; combined_refcount_<br/>(strong + weak + PyObject 标志)<br/>──────────<br/>用户字段..."]

    style OBJ fill:#dcfce7,stroke:#22c55e,stroke-width:2px

要点:

  1. 指针只有 8 字节(不是 16)—— intrusive_ptr<T> 内部就是一个原始指针
  2. 没有控制块 —— refcount 直接嵌在 T 里
  3. 少一次堆分配 —— make_intrusive<T>(...) 一次 malloc 拿到对象 + refcount
  4. Cache 友好 —— 访问 T 字段时 refcount 也在同一个 cache line 上

代价是 T 必须继承 intrusive_ptr_target,引入了侵入式约束。但对 PyTorch 这种”所有共享对象都已经是大对象”的场景,这个约束几乎零代价。

源码里这段哲学的注释(c10/util/intrusive_ptr.h:117-120)只有一句话:

intrusive_ptr<T> is an alternative to shared_ptr<T> that has better performance because it does the refcounting intrusively.

简洁、直接,没有废话。

3.6.3 单原子 64-bit 同时管 strong / weak / PyObject

打开 c10/util/intrusive_ptr.h:188

mutable std::atomic<uint64_t> combined_refcount_;
static_assert(sizeof(std::atomic<uint64_t>) == 8);
static_assert(alignof(std::atomic<uint64_t>) == 8);
static_assert(std::atomic<uint64_t>::is_always_lock_free);

一个 atomic uint64 同时存 strong refcount + weak refcount + PyObject 标志位。布局:

位 63              位 32-62          位 0-31
[ kHasPyObject 1bit ] [ weakcount 31bits ] [ strong refcount 32bits ]

代码里的常量定义(intrusive_ptr.h:34-41):

constexpr uint64_t kImpracticallyHugeReferenceCount = 0x0FFFFFFF;
constexpr uint64_t kImpracticallyHugeWeakReferenceCount =
    (kImpracticallyHugeReferenceCount << 32);
constexpr uint64_t kReferenceCountOne = 1;
constexpr uint64_t kWeakReferenceCountOne = (kReferenceCountOne << 32);
constexpr uint64_t kUniqueRef = (kReferenceCountOne | kWeakReferenceCountOne);
constexpr uint64_t kHasPyObject = (uint64_t(1) << 63);

这种”单原子打包”的好处:

  • incref/decref strong 和 weak 是同一个原子操作 —— 在 std::shared_ptr 里需要两个 atomic
  • kUniqueRef 检查(“我是唯一持有者吗”)只需要一次原子读 —— 用于优化 inplace 操作

is_always_lock_free 这个 static assert 保证了 64-bit atomic 在所有目标平台都能用 lock-free 实现(不退化成 mutex)—— 这是 PyTorch 假设的硬件下限。

3.6.4 kImpracticallyHugeReferenceCount 哨兵:防 overflow 的智慧

注意上面常量里有一个 kImpracticallyHugeReferenceCount = 0x0FFFFFFF(约 2.7 亿)。这是干嘛的?

它是 “析构进行中”的哨兵值。当某个对象的 refcount 减到 0、即将开始析构时,PyTorch 会先把 refcount 设为这个”巨大值”,然后开始析构。如果析构期间又有人尝试 incref(例如析构函数里调了某个回调,回调里又拿到了对象引用)—— refcount 不会从 0 变成 1(导致”复活”),而是从 2.7 亿变成 2.7 亿+1 —— 这是个明显异常的值,能被检查出来。

这种”哨兵值”的设计在并发引用计数库里很常见(boost/intrusive_ptr 也有类似机制),但 PyTorch 的实现把它整合进单原子方案里非常优雅。

3.6.5 Note [Stack allocated intrusive_ptr_target safety]

intrusive_ptr.h:126-141 有一段 Edward Yang 写的注释,解释一个微妙的安全问题:

A well known problem with std::enable_shared_from_this is that it allows you to create a std::shared_ptr from a stack allocated object, which is totally bogus because the object will die once you return from the stack.

std::shared_ptr 允许你这样写:

struct Foo : std::enable_shared_from_this<Foo> { ... };
Foo foo;                              // 栈上分配
auto p = foo.shared_from_this();      // 编译通过,但 p 析构时会 delete foo —— UB!

intrusive_ptr 怎么防?答案是 intrusive_ptr_target 的默认 refcount 是 0。当你用原始指针构造 intrusive_ptr 时,PyTorch 会断言 refcount > 0(说明对象是被 make_intrusive 创建的,已经在堆上)。如果你把栈对象传进 intrusive_ptr,refcount 还是 0,断言挂掉,“no intrusive_ptr for you!”

这种”主动防御”是 PyTorch 工程文化的另一例。

源码里 intrusive_ptr 还提供 unsafe_steal_from_new 等显式”我知道我在做什么”的接口,给到那些真的需要从原始指针构造(如某些 FFI 边界)的高级用户。这些 unsafe API 名字里带 unsafe 三个字母,逼用户在写代码时就在每一处看到风险提示 —— Rust 风格的不安全标记在 C++ 里也奏效。

3.6.6 make_intrusive 与单次堆分配

make_intrusive<T>(args...)intrusive_ptr 的”标准创建方式”,对标 std::make_shared。它的实现简洁:

// 简化的 make_intrusive
template <class T, class... Args>
intrusive_ptr<T> make_intrusive(Args&&... args) {
    auto* ptr = new T(std::forward<Args>(args)...);
    return intrusive_ptr<T>::reclaim_copy(ptr);
}

注意它做的事只是 new T(...) 然后包成 intrusive_ptr —— 因为 intrusive_ptr_target 把 refcount 嵌入了 T 自身,没有独立控制块要分配,所以严格只有一次 new

对比 std::make_shared 虽然也只调一次 operator new,但分配的内存大小是 sizeof(T) + sizeof(ControlBlock),对象布局依赖编译器实现。intrusive_ptr 直接 new T,对象布局完全可控 —— 这对底层调试和性能分析更友好。

3.6.7 实际训练场景里 intrusive_ptr 的影子

当你写训练代码时,intrusive_ptr 的影子无处不在:

  • 每次 tensor.clone() —— 新建 TensorImpl + 新建 StorageImpl,两次 make_intrusive
  • 每次 tensor.view(...) —— 新建 TensorImpl + 复用现有 StorageImpl,一次 make_intrusive + 一次 incref
  • 每次 optimizer.step() 里更新一个参数张量 —— 涉及参数 / grad / momentum / variance 多个张量的 intrusive_ptr 同时操作
  • 每次 backward 释放一个反向图节点 —— 一次 decref,可能触发析构

一个真实的性能数字:v1.x 的 intrusive_ptr 在某些场景被发现 incref/decref 占了 ATen 总时间的 5-8%。v2.0 之后通过把 strong/weak 合并到单原子(就是本章讨论的 combined_refcount_)把这部分降到了 2-3%。这就是看似纯工程的 ref counting 优化能换来的真金白银。

3.7 PyObject 保活:为 Python 而生的 64 位顶位

combined_refcount_ 的位 63 是 kHasPyObject 标志位 —— 这是 intrusive_ptr 区别于普通 shared_ptr 的最大特性,也是 PyTorch 能让 Python 端的 torch.Tensor 保持身份的核心机制。

3.7.1 问题:Python 包装对象的”双重身份”

考虑一段代码:

import torch

a = torch.randn(3, 3)
a.my_attr = 'hello'        # 给 Python 端的 Tensor 加自定义属性

b = a + 0                  # 触发 C++ 端创建新 TensorImpl,再包成 Python Tensor
print(a.my_attr)           # 'hello' —— 但这是 a 的属性,不是 b 的

有时候 PyTorch 在 C++ 端会通过某些操作”丢失” Python 端的引用

del a                      # Python 端 a 被删,但 C++ 端 TensorImpl 可能还在
                           # (某个 grad_fn 还在持有 it)
# ...过一会儿,C++ 端通过某种路径又把那个 TensorImpl 暴露回 Python
# 这次包装出来的 Tensor 还能找回 my_attr 吗?

如果 PyTorch 每次都重新创建 PyObject —— 不行,老的 my_attr__torch_function__ 注册等一切自定义都丢了。

3.7.2 解法:combined_refcount_ 的位 63

kHasPyObject 位的语义是”这个 C++ 对象当前有一个 Python wrapper”。当 C++ 端的 TensorImpl 第一次被包成 THPVariable(Python 对象),位 63 被置 1。

关键规则(注释在 intrusive_ptr.h:171-186):

  • PyObject 持有 C++ 对象的 strong reference
  • 当 C++ 端 refcount 从 1 涨到 2 时,PyTorch 同时给 PyObject 加一次 Python refcount(让 Python 端不释放 wrapper)
  • 当 C++ 端 refcount 从 2 降到 1 时,给 PyObject 减一次 Python refcount

这种”同步引用计数”机制让 C++ 与 Python 两端的对象生命周期保持一致:只要 C++ 还有引用,Python wrapper 就活着;wrapper 上挂的属性、子类、注册都不会丢

实现上还要应付 GIL —— try_incref_pyobject() 用 acquire-release 内存序处理 race condition,避免 Python GC 和 C++ 析构同时跑时的双 free 风险。这是 intrusive_ptrshared_ptr 复杂得多的根本原因。

3.7.3 这个机制带来的能力

理解了 PyObject 保活,几个 PyTorch 用法的”魔法感”就消失了:

  • __torch_function__ 协议:用户可以在 Python 端继承 torch.Tensor 写一个子类,重写 __torch_function__ 拦截算子调用 —— 这个子类在算子里来回流动后还能保持身份,就靠 PyObject 保活
  • tensor.requires_grad_(True) 在 with 块里 detach 又 attach 不会丢自定义属性
  • DataLoader 的 worker 把张量从子进程传回主进程时,能保持原 Python 对象语义

第 5 章的 TorchDispatchMode 也依赖这套机制 —— 如果 Python 端的 wrapper 不稳定,Mode 就无法跨多次 dispatch 调用持续生效。

3.7.4 incref 用 relaxed、decref 用 acq-rel 的内存序选择

如果你仔细读 intrusive_ptr.h:75-101,会看到一段对 atomic 内存序的精彩讨论:

// The only requirement for refcount increment is that it happens-before
// decrement, so no additional memory ordering is needed.
inline uint64_t atomic_combined_refcount_increment(...) {
    return combined_refcount.fetch_add(inc, std::memory_order_relaxed) + inc;
}

// All non-final decrements must synchronize-with the final decrement.
// So all non-final decrements have to store-release while the final
// decrement has to load-acquire ... it's easiest just to have all
// decrements be acq-rel. And it turns out, on modern architectures
// and chips, it's also fastest.
inline uint64_t atomic_combined_refcount_decrement(...) {
    return combined_refcount.fetch_sub(dec, std::memory_order_acq_rel) - dec;
}

这是性能敏感代码里典型的内存序选择:

  • incref 只需 relaxed:因为 incref 只是”我认领一份所有权”,不要求看到对象的最新状态。任何线程对对象的修改在 decref 时才需要可见
  • decref 必须 acq-rel:保证最后一次 decref(refcount → 0)能看到所有先前 incref 之前的写操作,避免 use-after-free

这种”非对称内存序”在 Rust Arc 里是同样的设计(参考 Rust std 文档:increment relaxed、decrement Release,析构前用 Acquire fence)。最新一致的硬件平台(x86 / ARM v8.1+)上,acq-rel 几乎和 relaxed 一样快,这就是注释里”on modern architectures and chips, it’s also fastest”的依据。

如果你写自己的并发引用计数库,这套内存序选择是基本功,照搬即可。

3.7.5 PyObject 保活的反例:什么时候它会失败

kHasPyObject 机制不是万能的。它仅在张量从 Python 进入 C++、再回到 Python 时起作用。如果某个张量纯在 C++ 创建并消费、从未暴露给 Python,PyObject 永远不会创建,位 63 不会被点亮。

考虑一个真实场景:

class MyTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data, *args, **kwargs):
        return torch.Tensor._make_subclass(cls, data)

a = MyTensor(torch.randn(3, 3))
b = a.matmul(a.t())     # 内部 C++ 计算后回到 Python
print(type(b))          # 期望是 MyTensor 子类

PyObject 保活机制让 b 的类型仍然是 MyTensor(只要 PyTorch 在算子里没有显式擦除子类信息)。但有些算子在生成中间张量时会显式调用 _make_wrapper_subclass(torch.Tensor, ...) —— 把子类擦除回基类。这是为什么有些自定义子类的”子类传染性”会失效。

__torch_function__ 协议本质就是给用户一个”在子类被擦除前接管控制”的钩子。第 5 章会讲。

3.7.6 Allocator 接口:c10 与硬件后端的契约

c10 还有一个重要抽象本章必须提到:c10::Allocatorc10/core/Allocator.h)。它定义了 PyTorch 与任何”分配字节缓冲”的硬件后端之间的契约。

// 简化的 Allocator 接口
struct C10_API Allocator {
  virtual ~Allocator() = default;
  virtual DataPtr allocate(size_t n) = 0;     // 分配 n 字节
  virtual DeleterFnPtr raw_deleter() const;
  ...
};

每个设备有自己的 Allocator 实现:

设备Allocator 实现关键文件
CPUDefaultCPUAllocator (走 malloc/free)c10/core/CPUAllocator.cpp
CPU pinnedPinnedCPUAllocator (cudaMallocHost)aten/src/ATen/cuda/CachingHostAllocator.cpp
CUDACUDACachingAllocator (本书第 4 章主角)c10/cuda/CUDACachingAllocator.cpp
MPSMPSAllocator (Apple Metal)aten/src/ATen/mps/MPSAllocator.mm
MetaMetaAllocator (假分配,不真给内存)c10/core/Allocator.cpp

Allocator 通过全局函数 c10::GetAllocator(device_type) 注册和查找。新增一个硬件后端时,第一件事就是实现 Allocator 并注册到 c10

这种”分配器即接口”的设计让 PyTorch 后端扩展极其规整 —— 厂商不需要改 c10 核心代码,只在自己的库里实现 Allocator 子类、注册即可。第 22 章自定义算子会再讲。

Allocator::allocate 返回的不是裸 void*,而是 DataPtr —— 一个带”析构器”的智能指针。这是为什么不同后端的内存释放方式(freecudaFreepool_release)都能在 PyTorch 内部以同一接口工作 —— 释放策略嵌在 DataPtr 的 deleter 函数指针里。

3.7.7 一个细节:DispatchKey 是怎么从 Tensor 计算出来的

回到第 2 章 §2.4:TensorImpl::key_set_ 字段就是这个张量的 DispatchKeySet。但它是怎么算出来的

简化的计算逻辑(c10/core/TensorImpl.cpp_set_key_setset_storage_keep_dtype 等函数交互):

  1. 从 Device 推 backend key:CPU → DispatchKey::CPU、CUDA → DispatchKey::CUDA
  2. 从 Layout 推 functionality key:Strided → DispatchKey::Dense、Sparse → DispatchKey::Sparse
  3. 从 dtype 是否量化推:QInt8 → DispatchKey::Quantized
  4. requires_grad:true → AutogradXxx 系列

每次张量创建时,构造函数计算这些键并 OR 到 key_set_。一旦计算出来基本不变(除非用户调 requires_grad_(...),会重新刷新)。

这个 “张量的所有元数据共同决定它的 dispatch 身份” 的设计是 PyTorch 多分派系统的精髓。第 5 章会拆 dispatcher 怎么把 key_set_ 与 thread-local 的 mode keys OR 后做 lookup。

3.8 横向对比:shared_ptr / Arc / Caffe2 旧引用计数

把三种方案放在一起:

维度std::shared_ptr<T>c10::intrusive_ptr<T>Rust Arc<T>
引用计数位置单独的控制块嵌入 T 对象单独的”内部”分配(与 T 一起 malloc)
Refcount 大小通常 strong + weak 各 32-bit (两个 atomic)单 atomic 64-bit (打包)strong + weak 两个 atomic
堆分配次数2 次(除非 make_shared)1 次1 次
侵入式是(必须继承 intrusive_ptr_target)
Python 集成有(kHasPyObject 位)
Cache locality差(控制块分离)

PyTorch 选 intrusive_ptr极致性能 + Python 集成 + 侵入式可接受 三个条件叠加的产物。这个选择不是普世正确的 —— 普通 C++ 项目用 shared_ptr 就足够。但对一个每秒上百万次张量创建的库,这种”卡尺级别”的优化是性能差距的源泉之一。

值得一提的是 Boost 早就提供了 boost::intrusive_ptr<T> —— 思想几乎相同。PyTorch 的 c10 没用 boost 而是从头写一份,原因有两条:boost 是个庞大依赖、c10 想保持”几乎不依赖外部”原则;boost 没有 PyObject 保活机制、kHasPyObject 这一类 PyTorch 特有需求要从头实现。所以 c10 的 intrusive_ptr 是”boost 思想 + PyTorch 需求”的二次创作,不是凭空发明。把它和 boost 放一起读,就能看清”通用智能指针”和”专为深度学习框架优化的智能指针”的区别。

Caffe2 时代曾经用过另一种引用计数(caffe2::Blob 系列),后来在合并 PyTorch 时被统一到 intrusive_ptr。今天 c10 的 intrusive_ptr 是 PyTorch + 已死的 Caffe2 两套生态融合后的产物,做到了”取精华、去糟粕”。

3.9 跨书关联

  • 《Tokio 异步运行时》第 X 章 智能指针与并发:Tokio 大量用 Arc<T>,与 PyTorch 的 intrusive_ptr 是不同语言里相同思想的体现。值得对照看的细节是 Rust Arc 的 weak count 实现 —— 它也用了”如果 weak count 是 0 但 strong count > 0 时不可能 race”的 invariant
  • 《vLLM 内核探秘》第 6 章 Worker 与 Executor:vLLM 在 Python 端用 multiprocessing.shared_memory 做跨进程张量传递,shared_memory 的 refcount 与 PyTorch 的 intrusive_ptr 在跨进程时的协议是个有趣的话题
  • 《Rust 编译器之路》第 X 章 编译期检查:Rust 编译器用 Lrc<T>(Local Rc)替代 Rc<T> 优化 single-thread 场景。c10 的 intrusive_ptr 没有这个优化(始终用 atomic),原因是 PyTorch 几乎从不做”我确定 single-thread”的假设

3.10 一个练习:手写一个 mini-intrusive_ptr

为了内化本章内容,写一个 30 行的 mini 版本:

#include <atomic>
#include <cassert>
#include <utility>

// mini 版 intrusive_ptr_target
struct MyTarget {
    mutable std::atomic<uint64_t> combined_refcount_{0};

    void incref() const { combined_refcount_.fetch_add(1, std::memory_order_relaxed); }
    bool decref() const {
        // 返回 true 表示这是最后一次 decref,可以析构
        return combined_refcount_.fetch_sub(1, std::memory_order_acq_rel) == 1;
    }

    virtual ~MyTarget() = default;
};

template <class T>
class my_intrusive_ptr {
    T* ptr_ = nullptr;
public:
    my_intrusive_ptr(T* p) : ptr_(p) { if (ptr_) ptr_->incref(); }
    my_intrusive_ptr(const my_intrusive_ptr& o) : ptr_(o.ptr_) { if (ptr_) ptr_->incref(); }
    my_intrusive_ptr& operator=(const my_intrusive_ptr& o) {
        if (this != &o) { reset(); ptr_ = o.ptr_; if (ptr_) ptr_->incref(); }
        return *this;
    }
    void reset() {
        if (ptr_ && ptr_->decref()) delete ptr_;
        ptr_ = nullptr;
    }
    ~my_intrusive_ptr() { reset(); }
    T* get() const { return ptr_; }
    T* operator->() const { return ptr_; }
};

// 使用
struct Foo : MyTarget { int x; Foo(int v) : x(v) {} };

int main() {
    my_intrusive_ptr<Foo> a(new Foo(42));
    {
        auto b = a;                 // refcount 1 → 2
        assert(a->x == b->x);
    }                               // b 析构, refcount 2 → 1
}                                   // a 析构, refcount 1 → 0, delete Foo

这 30 行就是 c10::intrusive_ptr 最朴素的样子。把它跑起来、加 weak_intrusive_ptr 试试、再把 strong 和 weak 打包到单 atomic 里 —— 你就重走了 PyTorch 团队做这套设计的全部步骤。

3.10.4 weak_intrusive_ptr 与”我能否复活引用”

intrusive_ptr 还有一个孪生兄弟 weak_intrusive_ptr<T> —— 弱引用版本。语义类似 std::weak_ptr:持有它不阻止对象析构,但可以尝试”升级”成 strong intrusive_ptr。

它的关键 API 是 lock()

intrusive_ptr<T> lock() const noexcept {
    // 原子地尝试把 strong refcount 从 N (N > 0) 增加到 N+1
    // 如果当前 strong refcount 是 0,返回 null(对象已死)
}

实现要点是比较交换循环(CAS loop):

while (true) {
    auto current = combined_refcount_.load(std::memory_order_relaxed);
    if (refcount(current) == 0) return nullptr;     // 对象已死
    if (combined_refcount_.compare_exchange_weak(
            current, current + kReferenceCountOne,
            std::memory_order_relaxed)) {
        return intrusive_ptr<T>::reclaim(this);     // 升级成功
    }
    // CAS 失败,说明其他线程改了 refcount,重试
}

CAS loop 是无锁并发编程的标准技巧。它保证 lock() 不会出现”我刚检查 refcount 还是 1,下一刻就掉到 0 然后我还成功 incref” 的 race condition。

在 PyTorch 实际使用中,weak_intrusive_ptr 的最大场景是 autograd 的 grad_fn 节点保存输入张量:grad_fn 用 weak ref 持有输入张量,避免反向图无限保持张量存活;反向时 lock() 看张量还在不在,不在就直接抛”input was modified inplace”错。这是为什么有时候你写 loss.backward() 报错说某个张量已被释放。

3.10.5 一个真实的踩坑:intrusive_ptr 与多进程

intrusive_ptr 在单进程 C++ 里几乎没坑,但跨进程时有趣的问题就出现了。考虑 PyTorch DataLoader 的 worker:

loader = DataLoader(dataset, num_workers=4)
for batch in loader:
    ...   # batch 是从 worker 子进程传过来的

worker 子进程里创建一个 tensor,要传到主进程。底层路径:

  1. worker 创建 TensorImpl + StorageImpl,refcount = 1
  2. 通过 torch.multiprocessing 用 shared memory 把 storage 的字节缓冲共享
  3. 主进程拿到一个新的 TensorImpl,指向同一段共享内存
  4. 主进程消费完,refcount 降到 0,析构

问题:worker 端和主进程端的 TensorImpl 是两个独立对象(不可能跨进程共享 C++ 对象),但底层字节缓冲是共享的。如果 worker 端的 storage 已经先析构了 refcount → 0,再来主进程析构,cudaFree 是不是会被调两次?

PyTorch 的解法:StorageImpl 的 deleter 在跨进程场景里被换成”共享内存计数”语义。worker 的 deleter 调用 shm_unlink 让计数 -1,主进程的 deleter 也调用 shm_unlink,只有最后一个调到的进程才真 free。这是 data_ptr_DeleterFnPtr 抽象的真实战场。

如果你写过 torch.multiprocessing 出现”shared memory leak”或”Bad file descriptor”错误,根源往往是这套机制的某个边界条件没被覆盖。社区在 v1.7-1.10 期间修了一系列这类 issue,今天大部分场景已经工作得很好。

3.10.6 intrusive_ptr 在调试时怎么用

实战提示:调试某个张量”为什么不被释放”时,可以用以下技巧:

import torch, sys

a = torch.randn(1000, 1000)
print(sys.getrefcount(a))           # Python 端 refcount
# 在 C++ 端可以打印 a._typed_storage()._cdata 的 refcount

sys.getrefcount 给的是 Python 端 PyObject 的 refcount,不是 C++ 端 TensorImpl 的 refcount —— 但通过 PyObject preservation 机制,两者紧密相关:Python refcount 触底意味着 C++ refcount 减 1。

如果你想看 C++ refcount,PyTorch 在 debug 模式下提供 tensor.use_count() 接口(实际上调用的是 intrusive_ptr.use_count())。生产代码里这个接口不暴露,但调试构建是趁手的工具。

第 21 章 Profiler 那章会讲怎么用 PyTorch 的内存分析工具找出”哪些张量被某个对象 hold 住了”。理解了 intrusive_ptr 的 refcount 模型,那些诊断报告才能读懂。

3.11 c10 抽象的”几条不成文规则”

读完整章,整理一下 PyTorch 团队设计 c10 时反复体现的几条原则:

  1. “每秒上百万次”是性能预算的下限 —— 任何加在 TensorImpl / Storage / 引用计数路径上的开销都要乘上百万
  2. release 模式相信调用者 —— 大量检查在 DEBUG 模式下断言、release 模式下删除(如 Device::validate)
  3. 历史向前兼容、内部向后演进 —— typed/untyped storage 共存、TypeMeta 保留 caffe2 时代抽象、Variable/Tensor 合并后老命名仍在
  4. 位操作优于查表 —— DispatchKeySet 用 64-bit bitmap、refcount 用单原子打包
  5. 侵入式优于通用 —— intrusive_ptr 牺牲了”任何 T 都能用”换来”零控制块开销”
  6. 全局单例做”空”哨兵 —— UndefinedTensorImpl、kHasPyObject 哨兵值
  7. 注释是设计文档 —— c10 关键文件顶部数百行注释是 Edward Yang 等核心维护者留下的设计史,比任何博客都权威

把这七条记在心里,后面再读 aten/torch/csrc/autograd/torch/_dynamo/ 时你能预判很多决策的方向。

3.12 一个 C++ 工程的”通用启示”

如果你不写 PyTorch、只是借鉴它的设计思想到自己的项目,本章三条最值得带走:

第一侵入式引用计数在性能敏感场景胜过 std::shared_ptr —— 但前提是你能控制基类、能接受”必须继承 X”的耦合。如果你的项目里有”必须高频共享、生命周期复杂、对 cache 敏感”的对象(比如游戏引擎里的 GameObject、数据库引擎里的 Page),照着 c10 的 intrusive_ptr 写一份非常划算。

第二enum + bitmapenum + std::variant 适合多维 dispatch —— 当你的对象需要按 N 个正交维度的笛卡尔积选择行为,与其用 std::variant<A, B, C, ...> 做 N-级 visit,不如用 N 个一维 enum 拼成一个 bitmap,然后做位操作。这种设计能把”多分派”的运行时开销压到”一次原子读 + 几次位运算”。

第三release 模式下”相信调用者”是可以的工程取舍 —— 但前提是 DEBUG 模式下要充分自检。c10 大量函数都是 DEBUG_ONLY 检查、release 模式删除,这种”双模式编译”的纪律不是偷懒,而是在性能敏感库里平衡安全与速度的标准手法。

第四用注释而非外部文档承载设计意图 —— c10/core/DispatchKey.h 顶部 200 行的注释、intrusive_ptr.h 中那段关于位 63 与 PyObject 保活的注释 —— 这些是 c10 演进过程中沉淀下来的”为什么”。把设计文档放在源码里,新维护者改代码时一定会看到、一定会被提醒。这种”文档贴近代码”的工程纪律比 wiki 上的”设计文档”有效得多。

下一章拆 CUDA Caching Allocator —— PyTorch 在 GPU 显存管理上做的最重要的工程优化之一,也是为什么 torch.cuda.empty_cache() 在大多数时候是无效的根本原因。

评论 0