第22章 自定义算子与 C++ 扩展

“Writing a custom op in 2024 is @torch.library.custom_op. Forget everything you knew about Variable / autograd.Function / TORCH_LIBRARY in 2018.”

—— PyTorch dev podcast,custom ops 现代教程

本章要点

  • v2.4+ 推荐 torch.library.custom_op 装饰器:一个 API 注册算子 + 自动接入 dispatcher / autograd / torch.compile
  • register_fake 给 FakeTensor 路径:torch.compile / FSDP 在 trace 时需要”shape 推导而不真算”
  • register_autograd 加反向规则:用类似 autograd.Function.backward 的写法
  • C++ 扩展走 TORCH_LIBRARY + pybind11:性能敏感时手写 C++ / CUDA kernel
  • 完整生态接入:自定义算子能与 dispatcher / autograd / FX / Inductor / DDP / FSDP 全部协作
  • 替代老 APIautograd.Function 还能用,但 torch.compile 兼容性差,新代码用 custom_op

22.1 何时需要自定义算子

PyTorch 内置 3000+ 算子,但仍有缺口:

  • 新硬件指令:自家芯片有特殊指令(如 NPU 的 fused attention),想用就要写 kernel 包成 PyTorch op
  • 新算子:论文里某个新激活函数、特殊归一化,PyTorch 还没收
  • 性能极致:某段热路径手写 CUDA 比组合 ATen 算子快 30%+
  • 第三方库集成:FlashAttention、xformers、Triton kernel 想暴露成 torch op

如何让自家 kernel 像内置算子一样工作 —— autograd 自动反向、torch.compile 能编译、profiler 能看到、FSDP 能正确处理 —— 是本章主题。

22.2 现代标配:torch.library.custom_op

v2.4+ 推荐写法:

import torch

@torch.library.custom_op("mylib::mymul", mutates_args=())
def my_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x * y

@my_mul.register_fake
def _(x, y):
    # FakeTensor 路径: 只返回正确 shape/dtype 的 empty tensor
    return torch.empty_like(x)

def my_mul_backward(ctx, grad):
    x, y = ctx.saved_tensors
    return grad * y, grad * x

def my_mul_setup_context(ctx, inputs, output):
    x, y = inputs
    ctx.save_for_backward(x, y)

my_mul.register_autograd(my_mul_backward, setup_context=my_mul_setup_context)

——这一段做了三件事:

  1. @custom_opmy_mul 注册到 dispatcher 的 mylib::mymul schema
  2. register_fake 给 FakeTensor 路径提供 shape 推导
  3. register_autograd 给反向规则

之后 my_mul(x, y) 就像内置算子一样工作。

22.2.1 schema 字符串

"mylib::mymul" 是命名空间 + 算子名。mutates_args=() 表示”不修改任何输入”(如果修改了 x,要写 mutates_args=("x",))。完整 schema 由 PyTorch 从函数 type hint 自动推导:

mylib::mymul(Tensor x, Tensor y) -> Tensor

如果你的 op 改了输入张量,schema 用 Tensor(a!) x 标记 alias。这套语法第 6 章 §6.2 讲过。

22.2.2 register_fake 的角色

FakeTensor 在第 5 章 §5.7 与第 13 章 AOTAutograd 出现过。几乎所有现代 PyTorch 高级特性都依赖 fake 路径

  • torch.compile 用它做 graph capture
  • FSDP 用它做 lazy init / shape 推导
  • export 用它做 torch.export(model)
  • meta tensor (无数据张量) 也走 fake

所以没注册 fake 函数的自定义算子在 torch.compile 下会 graph break。register_fake 不是可选 —— 现代代码必须有。

fake 函数只允许调用 shape 操作(empty_like / zeros / view / 算 shape),不能做实际数值计算。第 6 章 §6.4.2.5 警告过这条。

22.2.3 register_autograd:反向规则

register_autograd 接受两个函数:backwardsetup_context。语义与 autograd.Function 类似,但分开成两步:

  • setup_context(ctx, inputs, output):保存反向需要的张量(在 forward 完成后调用)
  • backward(ctx, *grads):算反向

PyTorch 内部把这套包成 autograd Node,与第 7 章讲的 XxxBackward0 完全等价。自定义算子的反向图与内置算子的反向图无差别,能被 autograd Engine(第 8 章)调度、被 AOTAutograd(第 13 章)capture。

22.3 Triton kernel 作为 custom_op 的实现

如果你想用 Triton 写 kernel(性能比纯 Python 高 10x+),可以让 custom_op 内部调 Triton:

import triton
import triton.language as tl

@triton.jit
def my_kernel(x_ptr, y_ptr, out_ptr, n: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * 128 + tl.arange(0, 128)
    mask = offsets < n
    x = tl.load(x_ptr + offsets, mask)
    y = tl.load(y_ptr + offsets, mask)
    tl.store(out_ptr + offsets, x * y, mask)

@torch.library.custom_op("mylib::triton_mul", mutates_args=())
def triton_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    grid = lambda meta: (triton.cdiv(n, 128),)
    my_kernel[grid](x, y, out, n)
    return out

注意:torch.compile 看到 triton_mul 时会 inline 调用进生成的 fused kernel,不会再拆开它。这种”自定义 Triton kernel + custom_op”是 FlashAttention 等 SOTA 算子的标准接入方式。

22.4 C++ / CUDA 扩展

性能极敏感时手写 C++(含 CUDA)。流程:

  1. .cpp 文件,用 TORCH_LIBRARY 注册算子
  2. setup.pytorch.utils.cpp_extension.CUDAExtension
  3. python setup.py install 编译成 .so
  4. Python 端 import 即可

C++ 端:

// my_ops.cpp
#include <torch/extension.h>
#include <torch/library.h>

at::Tensor my_mul_cpu(const at::Tensor& x, const at::Tensor& y) {
    return x * y;
}

at::Tensor my_mul_cuda(const at::Tensor& x, const at::Tensor& y) {
    // 实际 CUDA kernel launch
    auto out = at::empty_like(x);
    my_mul_cuda_kernel<<<grid, block>>>(x.data_ptr<float>(), y.data_ptr<float>(),
                                          out.data_ptr<float>(), x.numel());
    return out;
}

TORCH_LIBRARY(mylib, m) {
    m.def("mymul(Tensor x, Tensor y) -> Tensor");
}

TORCH_LIBRARY_IMPL(mylib, CPU, m) {
    m.impl("mymul", my_mul_cpu);
}

TORCH_LIBRARY_IMPL(mylib, CUDA, m) {
    m.impl("mymul", my_mul_cuda);
}

setup.py

from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension

setup(
    name='mylib',
    ext_modules=[CUDAExtension('mylib', ['my_ops.cpp', 'my_kernel.cu'])],
    cmdclass={'build_ext': BuildExtension},
)

加载后 Python 端:

import torch.ops.mylib
out = torch.ops.mylib.mymul(x, y)

C++ 扩展是国内 AI 芯片厂商接 PyTorch 的标准路径 —— 在 cpp 端用自家 SDK 写 kernel,注册到 dispatcher 的 PrivateUse1 key。

22.5 老 API:autograd.Function

老的 v1.x 写法仍然支持:

class MyMul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        return x * y

    @staticmethod
    def backward(ctx, grad):
        x, y = ctx.saved_tensors
        return grad * y, grad * x

out = MyMul.apply(x, y)

简单直接,调试方便。但缺点:

  • torch.compile 看到 apply 通常 graph break:Inductor 不知道怎么编译 Python autograd.Function
  • 没有 schema:没法走 dispatcher,对 FSDP / FX 不友好
  • 没有 fake 实现:torch.compile / export 走不通

如果你只是研究阶段快速写个 op、不上 compile:autograd.Function 够用。如果生产代码 + 想 torch.compile 加速:必须用 torch.library.custom_op

第 7 章 §7.8.1 我们对比过两套接口,结论一致。

22.6 完整集成检查清单

写一个生产级自定义算子,要做的事:

flowchart TB
    Op[custom_op 装饰器]
    Op --> Fake[register_fake<br/>shape 推导]
    Op --> Auto[register_autograd<br/>反向规则]
    Op --> Cpu[CPU kernel<br/>register_kernel device='cpu']
    Op --> Cuda[CUDA kernel<br/>register_kernel device='cuda']

    Fake --> Compile[✓ torch.compile 兼容]
    Cpu --> Eager[✓ eager 路径]
    Cuda --> Eager
    Auto --> Eng[✓ autograd Engine]

    Style0[注册到 dispatcher<br/>自动获得]
    Op --> Style0
    Style0 --> Disp[dispatch 调度]
    Style0 --> Prof[profiler 自动看到]
    Style0 --> Fsdp[FSDP / DDP 兼容]

    style Op fill:#fef3c7,stroke:#f59e0b,stroke-width:2px

清单:

  1. @custom_op 装饰器声明
  2. register_fake 给每个 op
  3. register_autograd 如果可微
  4. register_kernel(..., 'cpu') + register_kernel(..., 'cuda') 各自实现
  5. ✅ 写单元测试用 torch.library.opcheck 自动验证(PyTorch 提供的算子合规性检查)

22.6.5 opcheck:自定义算子的合规性测试矩阵

torch/library.py:1632torch.library.opcheck 是自定义算子的”质保检查”。它跑 5 项测试,确认 op 与 PyTorch 各子系统兼容:

import torch
from torch.library import opcheck

opcheck(my_mul, args=(x, y), test_utils=("test_schema", "test_autograd_registration",
                                          "test_faketensor", "test_aot_dispatch_static",
                                          "test_aot_dispatch_dynamic"))

5 项测试的具体职责:

测试检查什么
test_schemaschema 字符串与实际实现的输入输出 dtype / shape 是否一致
test_autograd_registration注册了 autograd 后反向规则是否数值正确(用 gradcheck 比对数值梯度)
test_faketensorfake 函数返回的 shape / dtype 是否与真实 kernel 输出一致
test_aot_dispatch_static在 AOTAutograd(静态 shape 模式)下能否正确 trace 与编译
test_aot_dispatch_dynamic同上但 dynamic shape 模式(更严格,要求 fake 函数能处理 SymInt)

生产级自定义算子必须 opcheck 通过。社区贡献到 PyTorch 主仓的 op PR 都被要求附 opcheck 测试。这套自动化检查避免了”自定义 op 在 eager 跑得对、torch.compile 编译错”等隐蔽 bug。

opcheck 内部用 torch._library.fake_class_registry 验证 fake 实现、用 torch.autograd.gradcheck 验证反向、用 torch._dynamo 跑 trace 验证 compile 路径。一次调用覆盖整个生态的兼容性。

22.6.6 Library 低级 API

@custom_op 是高级糖,底层是 torch.library.Librarylibrary.py:68)。它提供更细粒度的算子注册:

from torch.library import Library

# 创建一个 library (类似 C++ 端的 TORCH_LIBRARY)
lib = Library("mylib", "DEF")

# 注册 schema (没有实现, 等下注册)
lib.define("mymul(Tensor x, Tensor y) -> Tensor")

# 给特定 dispatch key 注册实现
lib.impl("mymul", lambda x, y: x * y, "CPU")
lib.impl("mymul", my_cuda_kernel, "CUDA")
lib.impl("mymul", my_meta_kernel, "Meta")    # FakeTensor 也是 Meta key

第二个参数 "DEF" 是 library 的 kind:

  • DEF:定义新算子(创建 schema)
  • IMPL:给已有算子加新 dispatch key 实现
  • FRAGMENT:在已有 library 里追加新 op(可多次)

@custom_op 装饰器内部就是构造 Library 然后调 define / impl。直接用 Library 时你能精确控制每个 dispatch key 的实现 —— 适合需要”给 PrivateUse1 注册新 backend”等高级场景。

22.6.7 PrivateUse1:国产芯片接入完整路径

PyTorch 给厂商扩展自家硬件留了 3 个 dispatch key:PrivateUse1 / PrivateUse2 / PrivateUse3(第 3 章 §3.5)。完整接入流程:

# 1. 给 PrivateUse1 起个有意义的名字
torch.utils.rename_privateuse1_backend("npu")
# 之后用户可以写 tensor.to('npu') 而非 'privateuseone'

# 2. 给 PrivateUse1 注册所有 ATen 算子的实现
@torch.library.impl("aten::add.Tensor", "PrivateUse1")
def npu_add(self, other, alpha=1):
    # 调你家硬件 SDK 的 add kernel
    return _npu_runtime.add(self, other, alpha)

# ... 给几百个常用算子各注册一个 impl ...

# 3. 提供 generate_methods_for_privateuse1_backend 让 tensor.npu() 等方法可用
torch.utils.generate_methods_for_privateuse1_backend()

torch/utils/backend_registration.py:20rename_privateuse1_backendPrivateUse1 重命名 + :362 的 generator 自动给 Tensor 添加 .npu() / .is_npu / .npu() 等方法。这套 API 让国产芯片厂商可以做出完整 PyTorch 体验而不修改主仓代码。

实际工作量:给 PyTorch 全部 3000+ 算子各写一个 backend impl 是几十人月的工程,但**torchgen/gen_backend_stubs.py(第 6 章 §6.10.5)能从一份”目标算子列表 YAML”自动生成 stub 代码**,厂商只需要填实现细节 —— 工作量降到几百算子级。

torch_npu(华为)、torch_mlu(寒武纪)、torch_xpu 等都走这条路。开源在 GitHub 能看到完整模板。

22.6.8 allow_in_graphdisable:torch.compile 的两个逃生口

写自定义算子时常遇到 Dynamo 不会 trace 的代码(如调了第三方 C 扩展、动态行为太复杂)。PyTorch 提供两个装饰器作为逃生口:

@torch.compiler.allow_in_graphtorch/compiler/__init__.py:72):

@torch.compiler.allow_in_graph
def my_special_function(x, y):
    # Dynamo 不 trace 这个函数体
    # 把整个调用当作"一个不透明 op"加入 graph
    return some_external_lib.do_magic(x, y)

效果:Dynamo 看到调用 my_special_function(x, y) 时,把它当作单个不透明算子放进 FX Graph(不展开内部)。Inductor 等后端会调用原始函数,跳过编译。

@torch._dynamo.disable

@torch._dynamo.disable
def my_complex_logic(x):
    # Dynamo 看到这个调用直接 graph break, 退回 eager
    if x.sum() > 0:
        return some_python_heavy_logic(x)
    else:
        return another_branch(x)

效果:Dynamo 在调用处触发 graph break,整段函数用 eager 跑,break 之后再开始新 trace。

两者关键区别:

装饰器Dynamo 行为适合场景
allow_in_graph当作不透明 op 留在 graph 里函数行为是确定的 tensor 计算,但 Dynamo trace 不动(如调了某 C 扩展)
disable触发 graph break,退回 eager函数有复杂 Python 逻辑(动态控制流 / 大量 dict 操作 / print),不希望 Dynamo 浪费时间分析

实际工程里:

  • 写自定义 Triton kernel + register_fake:用 custom_op(§22.2),不需要这两个装饰器
  • 集成第三方 C 扩展(如 FlashAttention v1 的私有 wrapper):用 allow_in_graph 把它当黑盒
  • 训练循环里的 logging / metric reporting 函数:用 disable 让 Dynamo 不要试图分析

torch/_dynamo/decorators.py 还提供更细的开关:disallow_in_graph(强制某 op 触发 graph break)、mark_static_address(声明 tensor 地址不会变,让 CUDA Graph 能复用)等。生产代码里写自定义算子的 escape hatch,理解这套装饰器家族能让你优雅处理”compile 不动”的边角情况。

22.6.9 inplace 与多输出算子的注册

@custom_op 默认假设 op 是”纯函数”(无副作用、单输出)。两种特殊形态需要额外配置:

inplace 算子(mutate input)

@torch.library.custom_op("mylib::add_inplace_", mutates_args=("x",))
def add_inplace_(x: torch.Tensor, y: torch.Tensor) -> None:
    x.add_(y)
    # 没有 return: schema 是 (Tensor(a!) x, Tensor y) -> ()

mutates_args=("x",) 让 schema 里 x 标 alias Tensor(a!)。functionalize(§13.4)看到这个标记后会重写代码:把 add_inplace_(x, y) 变成 x_new = x + y; x = x_new 这种纯函数版本。这是 v2.x 让 inplace op 与 compile 共存的关键。

不写 mutates_args 但实际 mutate 输入 → 隐蔽 bug:torch.compile 假设无副作用、生成的 kernel 不会复制 x,运行时 x 被修改但 graph 看不到 → 后续算子拿到错的 x。

多输出算子

@torch.library.custom_op("mylib::topk_with_idx", mutates_args=())
def topk_with_idx(x: torch.Tensor, k: int) -> tuple[torch.Tensor, torch.Tensor]:
    values, indices = torch.topk(x, k)
    return values, indices

@topk_with_idx.register_fake
def _(x, k):
    new_shape = list(x.shape)
    new_shape[-1] = k
    return torch.empty(new_shape, dtype=x.dtype), torch.empty(new_shape, dtype=torch.int64)

返回 Tuple[Tensor, ...] 时 schema 自动是 -> (Tensor, Tensor)。fake 函数也返回 tuple。

inplace + 多输出组合

@torch.library.custom_op("mylib::layernorm_inplace", mutates_args=("x", "running_mean"))
def layernorm_inplace(
    x: torch.Tensor,
    running_mean: torch.Tensor,
    weight: torch.Tensor,
) -> torch.Tensor:
    # 修改 x 与 running_mean, 返回新 tensor
    ...

复杂场景里这套语法要小心写。schema 错了 → AOTAutograd 会在 trace 时报”functionalize 失败”。opcheck 内置 functionalize 检查能在 commit 前发现这类问题(§22.6.5)。

22.6.10 register_kernel:每个 device 单独注册

@custom_op 的函数体是 op 的默认实现(CompositeImplicitAutograd key)。如果你想为不同 device 写专门 kernel,用 register_kernel

@torch.library.custom_op("mylib::mymul", mutates_args=())
def mymul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # 默认实现 (eager 路径用)
    return x * y

@mymul.register_kernel("cuda")
def _(x, y):
    # CUDA 专用: 调 Triton kernel
    out = torch.empty_like(x)
    grid = (triton.cdiv(x.numel(), 128),)
    my_triton_kernel[grid](x, y, out, x.numel())
    return out

@mymul.register_kernel("cpu")
def _(x, y):
    # CPU 专用: 调 OpenMP kernel
    out = torch.empty_like(x)
    my_cpp_extension.cpu_mul(x, y, out)
    return out

@mymul.register_kernel("xpu")
def _(x, y):
    # Intel XPU 专用
    return x * y    # 通用 fallback

dispatcher(§5.x)根据 input device 自动路由。这套机制让”一个 op 多 backend”不需要写 dispatch 逻辑、PyTorch 帮你做。

实战例子:FlashAttention 的 PyTorch 接入:

  • 默认实现:调用 F.scaled_dot_product_attention(fallback)
  • CUDA:调自家 CUDA kernel(Hopper / Ampere 各一份)
  • CPU:调 PyTorch 通用 attention(性能差但能跑)

理解 register_kernel 让你看到自定义算子的”多后端”不需要复杂代码 —— 装饰器 + dispatcher 自动协作。

22.6.11 JIT 加载 C++ 扩展:开发期免编译

§22.4 用 setup.py 编译 C++ 扩展,每次改完要重新 build。开发期更方便的方式是 torch.utils.cpp_extension.load

import torch.utils.cpp_extension as cpp_ext

my_ops = cpp_ext.load(
    name='my_ops',
    sources=['my_ops.cpp', 'my_kernel.cu'],
    extra_cflags=['-O3'],
    extra_cuda_cflags=['-O3', '-arch=sm_90'],
    verbose=True,
)

# 直接用
out = my_ops.mymul(x, y)

load 内部:

  1. 把 sources 编译成 .so(首次几十秒)
  2. 缓存到 ~/.cache/torch_extensions/
  3. 后续相同 sources 命中缓存(毫秒级)
  4. 改了 source 自动重编

适合开发场景:写 / 改 / 测的循环里不用每次跑 setup.py install

进阶:load_inline 让你直接传 C++ source 字符串、不用文件:

my_ops = cpp_ext.load_inline(
    name='inline_ops',
    cpp_sources='''
        torch::Tensor add_one(torch::Tensor x) {
            return x + 1;
        }
    ''',
    functions=['add_one'],
)

适合写小 demo / unit test。生产代码仍用 setup.py + .so 文件(避免每次进程启动都编译)。

实战:研究迭代算法时,load_inline + Jupyter notebook 让你能像写 Python 一样快速迭代 C++ kernel。这套工程便利极大降低了”写 C++ 扩展”的心智门槛。

22.6.12 ABI 兼容性:跨 PyTorch 版本的痛点

C++ 扩展编译出的 .so 对 PyTorch 版本敏感。原因:

  • libtorch C++ ABI 不冻结:PyTorch 团队在 v2.x 多次重构内部 API
  • CUDA Toolkit 版本:编译用 12.4、运行时 12.5+ OK;但 12.4 → 11.8 不行
  • Compiler ABI:gcc 7 编译的 .so 在 gcc 11 系统上可能报 undefined symbol

实战遇到的 ABI 错误:

ImportError: undefined symbol: _ZN3c104impl21py_handle_tdiFEPN10pybind11_4dictE

——pybind11 内部 symbol 在 PyTorch v2.4 与 v2.6 之间改了 mangling。

解决方案:

1. 锁版本 + per-version build

# 用户安装时根据 PyTorch 版本下载对应 wheel
pip install my-extension==0.1.0+pt2.6
pip install my-extension==0.1.0+pt2.4

每个 PyTorch 主版本编一份 wheel。

2. 用 LIBTORCH_USE_GLIBCXX_ABI

# 编译时指定 ABI
TORCH_CUDA_ARCH_LIST="8.0;9.0" \
LIBTORCH_USE_GLIBCXX_ABI=1 \
python setup.py bdist_wheel

让生成的 .so 与 PyTorch 内部 ABI 对齐。

3. JIT load (§22.6.11)

绕过 ABI 问题:用户机器现场编 → 自动用当前 PyTorch 的 ABI。代价是首次启动慢。

4. AOTI 路径

把自定义算子打包进 .pt2(§15.6.21),让 AOTI runtime 加载。AOTI 内部把 ABI 抽象掉,跨版本兼容性更好。

实战:开源 PyTorch 扩展(如 FlashAttention、xformers)维护团队都把”per-PyTorch-version build matrix”放在 CI 里。生产代码部署时锁住 PyTorch + 扩展版本。这是 C++ 扩展不可避免的工程税,优先用纯 Python + Triton(§22.3)能完全避开 ABI 问题

22.6.13 Composite Implicit Autograd:算子的 decomposition

PyTorch 内置 op 有几类 autograd 处理方式:

Autograd Key含义
Autograd显式注册反向规则(如 mmlinear,硬编码反向)
CompositeImplicitAutogradop 内部调其他 op,autograd 自动追踪(不需要写反向)
CompositeExplicitAutogradcomposite 但显式标 autograd-eligible
AutogradPrivateUse1厂商自家硬件的 autograd 实现

自定义算子默认是 CompositeImplicitAutograd —— 函数体调其他可微算子,autograd 自动追踪。这种 op 不需要写 register_autograd

@torch.library.custom_op("mylib::my_attention", mutates_args=())
def my_attention(q, k, v):
    # 内部调 ATen 算子, 全部可微
    scores = q @ k.transpose(-2, -1)
    attn = scores.softmax(-1)
    return attn @ v
# 不需要 register_autograd! autograd 自动通过 mm + softmax + mm 追踪

如果用了 Triton kernel / C++ kernel,autograd 看不到内部 op,必须 register_autograd:

@torch.library.custom_op("mylib::triton_attention", mutates_args=())
def triton_attention(q, k, v):
    # Triton kernel 内部 op autograd 看不到
    return my_triton_kernel(q, k, v)

# 必须显式注册反向
def backward(ctx, grad_out):
    q, k, v = ctx.saved_tensors
    return triton_backward_kernel(q, k, v, grad_out)

理解这两套路径让你写自定义算子时知道”何时需要 register_autograd”。简单 Python composite → 不需要;Triton/C++ kernel → 必须。

PyTorch 内部很多算子是 CompositeImplicitAutograd,让 ATen 代码生成不需要为每个 op 写反向。这套设计让 PyTorch 几千算子的反向规则维护成本可控。

22.6.14 Triton autotune:让 kernel 自动找最优配置

写 Triton kernel 时关键参数(block size / num_warps / num_stages)需要为每个硬件 / shape 调优。手动调耗时,Triton 内置 autotune 自动搜索

import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 128}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_SIZE': 256}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_SIZE': 256}, num_warps=8, num_stages=3),
        triton.Config({'BLOCK_SIZE': 512}, num_warps=8, num_stages=4),
        # ... 更多配置
    ],
    key=['n'],    # n 不同时重新选 config
)
@triton.jit
def my_kernel(x_ptr, y_ptr, out_ptr, n: tl.constexpr,
              BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n
    x = tl.load(x_ptr + offsets, mask)
    y = tl.load(y_ptr + offsets, mask)
    tl.store(out_ptr + offsets, x * y, mask)

工作机制:

  1. 第一次某个 n 调用时,autotune 跑所有 configs、测每个的 GPU 时间
  2. 选最快的 config
  3. 缓存到 (kernel, n) → best_config 映射
  4. 后续相同 n 直接用 best_config

cost:第一次跑慢几十 ms(要试几个 config),后续命中缓存零开销。生产代码 warmup 阶段触发 autotune、之后稳态运行。

进阶:

  • prune_configs_by 让你写自定义函数过滤掉 illegal config(如 BLOCK_SIZE 太大超 shared memory)
  • reset_to_zero 让某些 input 在每次 autotune trial 后清零(避免累积副作用)
  • do_bench 自定义 benchmark 函数

实战:FlashAttention v2 / v3 内部用了 几十个 config × 几十种 shape 的 autotune 矩阵,让单个 kernel 在不同 GPU + 不同 shape 都接近 hardware peak。理解 autotune 让你看到现代 SOTA kernel 的工程实质:不是手写一个完美 kernel,是搜索空间 + 自动调优

22.6.15 vmap × custom_op:批量化的自动支持

vmap(functorch / torch.func.vmap)让 op 自动批量化:

def add(x, y):
    return x + y

batched_add = torch.func.vmap(add)
# batched_add 接受 [B, ...] 输入, 内部 batched 算 add

PyTorch 内置 op 的 vmap 规则已经写好。自定义 op 默认 vmap 会失败

@torch.library.custom_op("mylib::mymul", mutates_args=())
def mymul(x, y):
    return x * y

torch.func.vmap(mymul)(x, y)
# 报错: vmap rule not registered for mylib::mymul

需要 register_vmap

@mymul.register_vmap
def _(info, in_dims, x, y):
    # in_dims: 输入 tensor 沿哪个维度 batch
    # 实现: 把 vmap 输入展开成 normal call
    x_dim, y_dim = in_dims
    if x_dim is not None and y_dim is None:
        y = y.unsqueeze(x_dim).expand_as(x)
    elif y_dim is not None and x_dim is None:
        x = x.unsqueeze(y_dim).expand_as(y)
    out = mymul(x, y)
    out_dim = x_dim if x_dim is not None else y_dim
    return out, out_dim

实战工作量:复杂 op 的 vmap rule 比 forward 还难写。简单做法:默认 register_vmap 不实现,文档说”vmap 不支持”,让用户避开 vmap。生产代码里 vmap 用户少(functorch 主要给研究用),多数自定义 op 不写 vmap rule 也能跑。

理解 vmap 的存在让你知道 PyTorch 的”自动批量化”也是抽象层 + 各 op 单独支持。custom_op 想完整融入 PyTorch 生态需要 fake / autograd / vmap / dispatch 多层注册。

22.6.16 自定义 op 注册到 Inductor lowering

torch.compile 看到自定义 op 时,默认走 fallback 路径:直接调用原 op、不与周围算子 fuse。如果你想让 Inductor 真正编译你的 op(fuse 到 Triton kernel 里),用 register_lowering

from torch._inductor.lowering import lowerings, register_lowering
from torch._inductor.ir import Pointwise

@register_lowering(torch.ops.mylib.mymul)
def mymul_lowering(x, y):
    # 返回 Inductor IR (Pointwise)
    return Pointwise.create(
        device=x.get_device(),
        dtype=x.get_dtype(),
        inner_fn=lambda idx: x.make_loader()(idx) * y.make_loader()(idx),
        ranges=x.get_size(),
    )

效果:torch.compile 看到 mymul(a, b) + c 时,不是”调 mymul kernel + 调 add kernel”,而是直接把 mymul 的语义编译进同一个 fused Triton kernel —— 真正的 op fusion。

适用:

  • 简单算子(pointwise / reduction):写 lowering 让 Inductor 优化
  • 复杂算子(attention / GEMM):保留 fallback,让 Inductor 当黑盒处理

PyTorch 内置 ATen op 都有 lowering,自定义 op 默认没有。写 lowering 是性能极致场景才做的工作 —— FlashAttention 等 SOTA op 已经够快、不需要再 fuse 进周围算子;普通 element-wise op 写 lowering 收益巨大。

理解 lowering 让你看 Inductor 不是”魔法编译器”,是 lowering registry 驱动的代码生成器。每个 op 一行 lowering 让它进入编译路径。

22.6.17 完整 FlashAttention 接入路径

把全章话题合起来看 FlashAttention 这种 SOTA op 怎么完整接入 PyTorch:

graph TB
    FA[FlashAttention CUDA kernel]
    FA --> CO[custom_op 装饰器<br/>schema: q, k, v -> out]
    CO --> Fake[register_fake<br/>shape 推导]
    CO --> Auto[register_autograd<br/>反向 = 另一个 FA backward kernel]
    CO --> Cuda[register_kernel cuda<br/>调实际 CUDA kernel]
    CO --> Cpu[register_kernel cpu<br/>调用 fallback SDPA]

    Fake --> Compile[torch.compile 兼容]
    Auto --> Backward[autograd Engine 调度反向]
    Cuda --> Eager[eager 路径]

    style FA fill:#fef3c7,stroke:#f59e0b,stroke-width:2px
    style Compile fill:#dcfce7

代码骨架:

@torch.library.custom_op("mylib::flash_attention", mutates_args=())
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # 默认实现 (CPU fallback)
    return F.scaled_dot_product_attention(q, k, v)

@flash_attention.register_kernel("cuda")
def _(q, k, v):
    # 调真实 CUDA kernel (FlashAttention v3)
    return _flash_attn_v3.forward(q, k, v)

@flash_attention.register_fake
def _(q, k, v):
    return torch.empty_like(q)

def fa_backward(ctx, grad_out):
    q, k, v, out, lse = ctx.saved_tensors    # lse = log-sum-exp, FA 内部产物
    grad_q, grad_k, grad_v = _flash_attn_v3.backward(grad_out, q, k, v, out, lse)
    return grad_q, grad_k, grad_v

def fa_setup_context(ctx, inputs, output):
    q, k, v = inputs
    out, lse = output_with_lse(q, k, v)    # 实际场景里 forward 输出 lse
    ctx.save_for_backward(q, k, v, out, lse)

flash_attention.register_autograd(fa_backward, setup_context=fa_setup_context)

# 测试
import torch
opcheck(flash_attention, args=(torch.randn(2, 8, 1024, 64, device='cuda'),) * 3)

部署后用法与内置 op 完全一致:

out = torch.ops.mylib.flash_attention(q, k, v)
# 或者 monkey-patch F.scaled_dot_product_attention 让全局透明用 FA

理解这套接入让你看 FlashAttention 不是”独立库”,是经过 PyTorch custom_op 接入的 first-class 算子。所有 PyTorch 用户能像用 mm 一样用它。custom_op 是 PyTorch 生态吸纳新 SOTA 算子的标准接口

22.6.18 自家 AI 芯片完整接入 PyTorch 的工程

国产 AI 芯片厂商把硬件接进 PyTorch 是几十人月的系统工程。完整路径:

第 1 阶段:基础 backend

# 1. 注册 PrivateUse1 → "npu"
torch.utils.rename_privateuse1_backend("npu")

# 2. 实现 device guard / stream / event 抽象
class NPUStream(...): ...
class NPUEvent(...): ...

# 3. 注册到 PyTorch
torch._C._jit_register_npu_backend(...)

# 4. tensor.npu() 等 method
torch.utils.generate_methods_for_privateuse1_backend()

第 2 阶段:算子实现

# 给最常用的 200-500 个算子各写 NPU impl
# 用 codegen 减少手写代码
@torch.library.impl("aten::add.Tensor", "PrivateUse1")
def npu_add(self, other, alpha=1):
    return _npu_runtime.add(self, other, alpha)

@torch.library.impl("aten::mm", "PrivateUse1")
def npu_mm(self, mat2):
    return _npu_runtime.gemm(self, mat2)

# ... 几百个算子 ...

第 3 阶段:CommunicationBackend (NCCL 替代)

class NPUCommBackend(ProcessGroup):
    def allreduce(self, tensors, opts): ...
    def allgather(self, output_tensors, input_tensors, opts): ...
    # ... 实现完整 c10d ProcessGroup 接口 ...

torch.distributed.Backend.register_backend("hccl", create_npu_comm)

第 4 阶段:编译栈集成

# 给 torch.compile 注册自家 backend
@torch._dynamo.register_backend
def npu_compiler(fx_graph, example_inputs):
    # 调自家编译器把 fx_graph 编译成 NPU binary
    return npu_compile(fx_graph, example_inputs)

# 用法
@torch.compile(backend="npu_compiler")
def model(x): ...

第 5 阶段:训练 / 推理生态

  • FSDP / DDP 适配(用 hccl backend)
  • AMP / bf16 支持
  • safetensors / DCP 集成
  • profile + Kineto 自家 backend

整个工程量级:

阶段工程量说明
基础 backend1-2 人月设备 / 流抽象
算子实现6-12 人月200+ 算子
通信 backend1-2 人月完整 c10d 接口
编译集成3-6 人月自家 graph compiler
生态适配2-4 人月FSDP / AMP / profile
合计15-25 人月一个团队 5 人 3-5 个月

torch_npu(华为)、torch_mlu(寒武纪)、torch_xpu(Intel)都走过这条路。开源代码可以 GitHub 看完整例子。custom_op + PrivateUse1 是国内 AI 芯片厂商生态参与 PyTorch 的核心入口,不需要 fork 主仓代码。

理解这条路径让你看 PyTorch 不是 NVIDIA 专属,是真正”硬件中立”的开放生态。

22.6.19 自定义算子的演进时间线

PyTorch 自定义算子 API 的几个关键节点:

版本主流 API特点
v0.4 (2018)autograd.Function简单但与编译栈不兼容
v1.0 (2018)+ torch.utils.cpp_extensionC++ kernel 接入
v1.5 (2020)+ TORCH_LIBRARY C++ 宏注册到 dispatcher
v1.10 (2021)+ torch.library.Library Python API替代部分 C++ 宏
v1.13 (2022)+ meta tensor / fake 概念编译路径前置
v2.0 (2023)+ register_fake 等torch.compile 兼容
v2.4 (2024)torch.library.custom_op 装饰器现代标配
v2.4+ opcheck 自动测试合规性检查
v2.6 (2025)+ register_kernel 优化 + lowering 接口完善与 Inductor 深度集成
v2.10 (2025)+ 完整 functorch/vmap 集成全 PyTorch 生态兼容
v2.11 (2026)API 稳定生产级别成熟

整体趋势:

  • v1.x:从 autograd.Function(仅 autograd)到 TORCH_LIBRARY(完整 dispatcher)
  • v2.x:从分散 API 收敛到 custom_op 装饰器一站式
  • v2.4+:与编译栈、量化、distributed 深度集成

理解时间线让你看到自定义 op 不是一开始就这么好用 —— 经过几年迭代才达到”10 行 Python 装饰器”的体验。生产代码用最新 API(custom_op)能省最多事。

22.6.20 常见 bug 排查 cheat sheet

实战写自定义 op 遇到的报错与解法:

报错根因解决
RuntimeError: ... shape mismatch 在 compile 但不在 eagerfake 函数 shape 推导错检查 fake 返回 shape 是否与真实 kernel 一致
Expected at most 0 ... got Xschema 字符串与函数签名不匹配type hint 改对 / schema 显式
mutates_argsfunctionalize 假设无副作用、kernel 实际 mutate加正确 mutates_args=("x",)
Dynamo Unsupported: ... graph break未注册 fake / Dynamo 看不进 opregister_fake 或 allow_in_graph
gradcheck 失败反向规则数值不对用 finite-diff 一步步验证、或 torch.autograd.functional.jacobian 比对
inductor lowering not registered没注册 Inductor lowering(fallback 到 eager)写 register_lowering 或接受 fallback
ABI undefined symbolC++ 扩展与 PyTorch 版本不匹配重新编译 / 用 JIT load
vmap rule 没注册functorch 不知道怎么批量化 opregister_vmap 或文档声明不支持
autograd 反向时 saved_tensors 是 Nonesetup_context 没保存setup_context 里调 ctx.save_for_backward
opcheck test_aot_dispatch_dynamic failfake 函数没处理 SymInt 输入fake 里所有 shape 操作改用 SymInt-friendly API

把这张表存到内部 wiki,新人写自定义 op 时遇到报错对照查 → 节省至少 3 天试错时间。

22.6.21 export 与自定义算子

torch.export(§12.8.28)把 model 导成 ExportedProgram,给部署用。自定义算子在 export 路径的处理:

@torch.library.custom_op("mylib::flash_attention", mutates_args=())
def flash_attention(q, k, v): ...

class MyModel(nn.Module):
    def forward(self, x):
        q, k, v = split(x)
        return torch.ops.mylib.flash_attention(q, k, v)

# 导出
exported = torch.export.export(MyModel(), example_inputs)

# ExportedProgram 内部的 fx graph 含 mylib::flash_attention 节点
print(exported.graph)

ExportedProgram 内部用 op 的完整 fqnmylib::flash_attention)记录,而不是 inline op body。部署时:

  • AOTI:把 mylib::flash_attention 编译进 .so,runtime 调原 kernel
  • ExecuTorch:让 op 走 delegate 到目标硬件
  • ONNX:自定义 op 没标准化 → 报错(除非用 onnx custom domain)

为让自定义 op 能 export:

  1. 必须有 register_fake(export 用 FakeTensor 跑)
  2. schema 要稳定(不能动态加参数)
  3. 不能有 graph break(complex Python logic)

实战:FlashAttention 等 SOTA op 都已 export-friendly。自家研究算子如果要部署,一开始就按 export 兼容写。v2.x 之后”导得出 vs 导不出”是判断 op 工程级别的关键 metric

22.6.22 自定义 op 性能调优 flow

写完一个 custom_op 跑通后,通常发现”比预期慢”。调优流程:

flowchart TD
    Slow[op 慢]
    Slow --> P1[1. profile 看 op 在 trace 里占多少]
    P1 --> Q1{是 op 内部慢, 还是 op 外部 dispatch 慢?}

    Q1 -->|op 内部| Q2{kernel 是否 launch 多次?}
    Q1 -->|dispatch 多| FUSE[让 op 接受更大 input<br/>减少 dispatch 次数]

    Q2 -->|是| Bundle[bundle 多次小 launch 成一次大 launch]
    Q2 -->|否| Q3{Tensor Core 利用率?}

    Q3 -->|低| Align[shape padding 到 16 倍数<br/>+ 用 fp16/bf16]
    Q3 -->|高| MB[memory bound<br/>看能否减少 read/write]

    Slow --> P2[2. 看 with torch.compile 是否能 fuse]
    P2 --> Lower[实现 register_lowering<br/>让 op 进入 fusion]

    Slow --> P3[3. 比对竞品 baseline]
    P3 --> Algo[换更优算法<br/>FA v2 → v3 → ...]

    style P1 fill:#fef3c7
    style P2 fill:#dcfce7
    style P3 fill:#dbeafe

实战 case(自家写的 fused RMSNorm op):

第 1 轮 profile:op 占总时间 8%,但 RMSNorm 数学上只是 mean + rsqrt + scale → 应该 < 1%。 看 trace:op 内部 launch 4 个 kernel(mean / sqrt / rsqrt / scale)→ 应该 fuse 成 1 个。 修复:手写 Triton kernel 把 4 步合到一个 → 1.5%。

第 2 轮:仍比 NVIDIA TransformerEngine 的 RMSNorm 慢 30%。 profile metric:SM 占用率 70%(对方 95%)。 修复:调 BLOCK_SIZE / num_warps(autotune),找到最优配置 → 性能匹配 TE。

整套调优 1-2 天。关键是 profile 驱动——每步看数据找根因,不靠猜。

22.6.23 multi-level dispatch:算子的多层 fallback

dispatcher(§5.x)按 priority 调用算子:从最具体 device 找到最通用 fallback。custom_op 也参与这套机制。

graph TB
    Call[mymul x y]
    Call --> D[dispatcher]
    D --> D1{x is on CUDA?}
    D1 -->|是| K1[找 CUDA impl]
    K1 -->|找到| RunCuda[运行 CUDA kernel]
    K1 -->|没有| K2[找 CompositeImplicitAutograd]
    K2 -->|找到| RunComp[运行默认实现]
    K2 -->|没有| Fail[报错: 没注册]

    D1 -->|是 PrivateUse1| KP[找 PrivateUse1 impl]
    KP -->|找到| RunNpu[运行 NPU kernel]
    KP -->|没有| K2

    style RunCuda fill:#dcfce7
    style RunComp fill:#fef3c7
    style RunNpu fill:#dbeafe

priority 顺序(精简):

  1. AutogradXxx(具体 device):训练时优先
  2. Xxx(具体 device):CPU / CUDA / MPS / PrivateUse1
  3. CompositeImplicitAutograd:用其他 op 拼出来的默认实现
  4. CompositeExplicitAutograd:显式标记的 composite

每层都可以注册自家 impl。fallback 链让自定义 op 在缺失某 device 实现时仍能跑(虽然慢):

@torch.library.custom_op("mylib::mymul", mutates_args=())
def mymul(x, y):
    return x * y    # 默认 (CompositeImplicitAutograd)

@mymul.register_kernel("cuda")
def _(x, y):
    return my_cuda_kernel(x, y)    # CUDA fast path

# 没注册 cpu impl?
# CPU input 调 mymul → 找不到 CPU impl → 走默认 (composite) → x * y

CPU 用户能跑(虽然慢),CUDA 用户用快路径。优雅 fallback 让自定义 op 通用。

理解 multi-level dispatch 让你看 PyTorch 的”扩展性”——每个 op 可以为 N 个 device 注册 N 份实现,dispatcher 自动选最快的。这是单一 codebase 支持几十种硬件的工程基础。

22.6.24 SOTA op 接入示例:开源生态中的 5 个典型 case

把全章话题落到具体例子,5 个开源 SOTA op 的接入方式:

1. FlashAttention (Tri Dao)

  • 路径:CUDA kernel → flash_attn Python wrapper → custom_op 注册到 PyTorch
  • 全套:fake / autograd / register_kernel(“cuda”) + (“cpu” fallback)
  • v2.4+ PyTorch 内置 SDPA 自动用 FA v2/v3

2. xformers

  • 路径:CUDA + Triton kernel → 自家 wrapper → 部分注册成 PyTorch op
  • 不全用 custom_op(早于 v2.4 出现),有些走 autograd.Function
  • v2.x 时代逐步迁到 custom_op

3. Liger Kernel (Linkedin)

  • 路径:纯 Triton kernel(fused RMSNorm / GeGLU / RoPE 等)
  • 全 Python:@triton.jit + @torch.library.custom_op
  • 标杆”Triton + custom_op”现代实践

4. bitsandbytes (8-bit / 4-bit ops)

  • 路径:自家 CUDA kernel → C++ extension
  • 部分注册成 PyTorch op,部分仍是函数式
  • 走 PrivateUse1 / 自定义 dtype 路径

5. Apex (NVIDIA)

  • 路径:纯 CUDA kernel + setup.py 编译 .so
  • 老一代实践,许多 op 是 autograd.Function
  • 现代被 PyTorch 内置取代(fused LayerNorm 等已进 mainline)

观察:

  • 新项目都用 Triton + custom_op:比 CUDA + setup.py 简单 10x
  • 老项目逐步迁移:Apex 等老库的功能逐渐进 PyTorch 主仓
  • 企业级 (NVIDIA / Meta / Google)仍写 CUDA kernel:性能极致 + 控制 ABI

理解这些案例让你看到 PyTorch 自定义 op 生态的全貌:研究项目 → Triton + Python,生产 SOTA → CUDA + 完整 custom_op,硬件厂商 → PrivateUse1 完整接入。每条路径有自己的 trade-off。

22.6.25 functorch 高阶变换:grad / jacrev / vmap 组合

functorch(v1.13+ 内置 torch.func)提供”函数变换”:把可微函数变成它的梯度、Jacobian、Hessian 等。custom_op 想被这些 transform 用,需要满足条件:

import torch
from torch.func import grad, jacrev, vmap

@torch.library.custom_op("mylib::squared", mutates_args=())
def squared(x: torch.Tensor) -> torch.Tensor:
    return x ** 2

@squared.register_fake
def _(x):
    return torch.empty_like(x)

def squared_backward(ctx, grad_out):
    x, = ctx.saved_tensors
    return 2 * x * grad_out

def squared_setup(ctx, inputs, output):
    ctx.save_for_backward(inputs[0])

squared.register_autograd(squared_backward, setup_context=squared_setup)

# 现在能用 functorch transforms
gradient_fn = grad(squared)
print(gradient_fn(torch.tensor(3.0)))    # 6.0 = 2 × 3

jacobian_fn = jacrev(squared)
print(jacobian_fn(torch.tensor([1.0, 2.0, 3.0])))    # diag([2, 4, 6])

工作机制:functorch 通过 dispatcher 调 register_autograd 注册的反向规则。只要 register_autograd 正确,所有 functorch transform 自动可用 —— 不需要单独 register_grad / register_jacrev。

特殊情况:

  • vmap(grad(f)) 这种组合需要 register_vmap(§22.6.15)
  • 二阶导数 (grad(grad(f))) 要求反向函数自身可微 —— register_autograd 的 backward 函数里调的 op 都得是可微 op,不能是 detached value
  • forward-mode AD (jvp) 需要 register_jvp(实验性 API)

实战:研究项目用 functorch 多,custom_op 写正确反向就够。生产 LLM 训练几乎不用 jacrev / hessian(model 太大算不动),functorch 主要给 second-order optimizer / 物理模拟等场景。

理解 functorch 兼容性让你看 custom_op 的”完整生态接入”含义 —— 不只是 forward + backward,还要支持函数变换。

22.6.26 ABI-stable C++ 扩展:v2.6+ 实验性新路径

§22.6.12 提了 ABI 兼容性是 C++ 扩展的痛点。PyTorch v2.6+ 在 torch.csrc.stable namespace 引入 ABI-stable API

#include <torch/csrc/stable/library.h>

// 用 stable API 而非内部 ABI
TORCH_LIBRARY(mylib, m) {
    m.def("mymul(Tensor x, Tensor y) -> Tensor");
}

// stable API 不暴露内部数据结构
torch::stable::Tensor my_mul_cuda(torch::stable::Tensor x, torch::stable::Tensor y) {
    return torch::stable::ops::mul(x, y);
}

TORCH_LIBRARY_IMPL(mylib, CUDA, m) {
    m.impl("mymul", my_mul_cuda);
}

保证:

  • 跨 minor version 兼容:v2.6 编的 .so 在 v2.7 + 加载 OK
  • 不暴露内部 type:仅 stable_tensor / stable_scalar 等
  • 限制 API 集合:只能用 stable namespace 里的函数(约 200 个,覆盖常用场景)

代价:

  • API 比内部 ABI 受限,复杂操作要回退到 unstable
  • 性能略低 1-2%(额外 ABI 转换开销)
  • 仍在实验,几个版本可能调整

适用场景:长期维护的开源 PyTorch 扩展(如 FlashAttention、xformers)—— 不用每次 PyTorch 升级都 rebuild。

短命扩展(自家研究 prototype)继续用普通 C++ 扩展即可。理解这条新路径让你看 PyTorch 团队对 “ABI 痛点”的工程响应——把痛点收编进框架本身解决,而不是让用户每家自己处理。

22.6.27 distributed 训练里的 custom op

custom_op 在分布式训练里要注意:

1. collective 算子用 functional API

# 错误: 用老 inplace API
@torch.library.custom_op("mylib::ring_attention", mutates_args=())
def ring_attention(q, k, v, group):
    out = q @ k.transpose(-2, -1)
    dist.all_reduce(out, group=group)        # ← inplace, functionalize 会失败
    return out @ v

# 正确: 用 functional collectives (§16.7.9)
import torch.distributed._functional_collectives as funcol

@torch.library.custom_op("mylib::ring_attention", mutates_args=())
def ring_attention(q, k, v, group):
    out = q @ k.transpose(-2, -1)
    out = funcol.all_reduce(out, "sum", group)    # ← functional, compile 友好
    return out @ v

2. process_group 不能直接放 schema

ProcessGroup 不是 tensor,不能作为 op 输入。变通:用 group_name (str) 在 op 内部 lookup:

@torch.library.custom_op("mylib::ring_attention", mutates_args=())
def ring_attention(q, k, v, group_name: str):
    group = dist.distributed_c10d._resolve_process_group(group_name)
    ...

3. FSDP-2 / DTensor 协作

DTensor(§18.6.6)有 placement 概念。custom_op 默认不支持 DTensor 输入:

@register_dtensor_dispatch(torch.ops.mylib.ring_attention)
def _(q_dt, k_dt, v_dt, group_name):
    # 显式处理 DTensor placement
    ...

实战:如果 custom_op 要在 FSDP-2 / DTensor 模型里用,必须实现 DTensor dispatch,否则 placement 信息丢失。

4. NCCL communicator caching

custom_op 内部如果调 NCCL,要确保用同一个 communicator(§16.7.5)。lookup 一次后 cache:

_comm_cache = {}

def get_comm(group_name):
    if group_name not in _comm_cache:
        group = dist.distributed_c10d._resolve_process_group(group_name)
        _comm_cache[group_name] = init_nccl_comm(group)
    return _comm_cache[group_name]

理解分布式 custom_op 的这些坑让你写”适配多卡”的自定义算子时不会踩雷。生产 LLM 训练里 custom_op 与 FSDP / TP / PP 协作是真实需求(如自家 attention 实现要兼容现有训练栈)。

22.6.28 推理引擎中的 custom_op:vLLM / SGLang 实例

LLM 推理引擎 vLLM / SGLang / TensorRT-LLM 都大量用自定义 op。具体实现观察:

vLLM 的 attention kernel

# vllm/attention/backends/flash_attn.py
@torch.library.custom_op("vllm::flash_attn_varlen", mutates_args=())
def flash_attn_varlen(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
    cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int, max_seqlen_k: int,
) -> torch.Tensor:
    return _flash_attn_v3.varlen_forward(...)

@flash_attn_varlen.register_fake
def _(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k):
    return torch.empty_like(q)

vLLM 把所有”长度变化的 attention”包成 custom_op,让 torch.compile 能 capture,配合 piecewise CUDA Graph(§15.6.16)实现高吞吐推理。

SGLang 的 paged attention

@torch.library.custom_op("sglang::paged_attn", mutates_args=("output",))
def paged_attn(
    output: torch.Tensor,    # mutates 输出 tensor (KV cache 持续累积)
    query: torch.Tensor,
    key_cache: torch.Tensor, value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
) -> None:
    _sglang_kernel.paged_attention(output, query, key_cache, value_cache,
                                     block_tables, seq_lens)

注意 mutates_args=("output",) 让 output 显式标 inplace。这是推理引擎与训练 op 不同的地方:推理时 KV cache 持续累积、必须 inplace 写入,无法走纯函数路径。

实战经验:

  • 推理引擎的 op 不需要 register_autograd:推理无反向,省工作
  • 必须 register_fake:CUDA Graph 与 torch.compile 都需要
  • mutates_args 要写正确:KV cache mutation 必须显式标
  • register_kernel(“cuda”) 调真实 CUDA kernel;CPU fallback 可选

理解推理引擎的 custom_op 用法让你看到 LLM 推理优化与 PyTorch 自定义 op 接口深度耦合。理解这套接口能让你看 vLLM / SGLang 源码不困惑,甚至自己往里加新算子。

22.6.29 算子注册的”产品哲学”

把全章合起来看,custom_op 接口的设计反映了 PyTorch 团队的几个产品决策:

1. “扩展是用户体验的一部分”

老 PyTorch(v1.x)扩展接口分散:autograd.Function、TORCH_LIBRARY、Library.impl()……每条路径覆盖一部分场景。结果:用户写自定义 op 痛苦、社区贡献 PR 质量参差不齐。

v2.4 收敛到 @torch.library.custom_op 一站式接口,把”如何扩展”变成产品的核心 UX。这是 PyTorch 从”研究框架”成熟为”工业级 ML 平台”的标志。

2. “fake / shape inference 是底座”

v2.x 把 fake 函数从可选变成”几乎必填”。这看起来增加了用户负担,实际是强制让所有 op 都能进入编译路径。否则 LLM 时代 torch.compile 会被零散的 op 不兼容拖累。

这条决策背后是产品判断:“未来所有人都会用 torch.compile”。所以提前要求 op 注册时声明 fake,保证生态顺滑迁移。

3. “Triton 取代 CUDA”

v1.x 时代写自定义 op 必经 C++ + CUDA。v2.x 推 Triton 作为首选,让 Python 工程师都能写 GPU kernel。降低门槛后社区贡献的高性能 op(Liger Kernel 等)数量爆增。

4. “PrivateUse1 给硬件中立”

不绑死 NVIDIA。提供完整 backend extension API 让国产 / 第三方芯片厂商接进来。这条决策让 PyTorch 在 NVIDIA 之外的硬件市场(华为、寒武纪、Intel Arc)保持竞争力。

理解这些产品决策让你看自定义 op 接口不只是”技术 API”,是 PyTorch 团队对”开放生态”的具体实现。每条接口设计选择背后都有商业 / 战略考量。

22.6.30 一段实战脚本:从零到生产 op

把全章的步骤合并成一个完整的实战脚本,写一个 fused “GeLU + Linear” op:

import torch
import triton
import triton.language as tl

# 第 1 步: Triton kernel
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_warps=8),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def fused_gelu_linear_kernel(
    x_ptr, w_ptr, b_ptr, out_ptr,
    M, N, K,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # ... 实现 GeLU(X @ W + b) 的 fused kernel ...
    # 略 (完整实现 30+ 行)

# 第 2 步: custom_op 包装
@torch.library.custom_op("mylib::fused_gelu_linear", mutates_args=())
def fused_gelu_linear(
    x: torch.Tensor, w: torch.Tensor, b: torch.Tensor,
) -> torch.Tensor:
    M, K = x.shape
    K2, N = w.shape
    assert K == K2

    out = torch.empty(M, N, device=x.device, dtype=x.dtype)
    grid = (triton.cdiv(M, 128), triton.cdiv(N, 128))
    fused_gelu_linear_kernel[grid](x, w, b, out, M, N, K)
    return out

# 第 3 步: register_fake
@fused_gelu_linear.register_fake
def _(x, w, b):
    M, K = x.shape
    K2, N = w.shape
    return torch.empty(M, N, device=x.device, dtype=x.dtype)

# 第 4 步: register_autograd
def fgl_backward(ctx, grad_out):
    x, w, b = ctx.saved_tensors
    # ... 实现反向 ...
    grad_x = grad_out @ w.T * gelu_grad(x @ w + b)
    grad_w = x.T @ (grad_out * gelu_grad(x @ w + b))
    grad_b = grad_out.sum(0)
    return grad_x, grad_w, grad_b

def fgl_setup(ctx, inputs, output):
    ctx.save_for_backward(*inputs)

fused_gelu_linear.register_autograd(fgl_backward, setup_context=fgl_setup)

# 第 5 步: opcheck 验证
from torch.library import opcheck
x = torch.randn(64, 256, device='cuda', requires_grad=True)
w = torch.randn(256, 128, device='cuda', requires_grad=True)
b = torch.randn(128, device='cuda', requires_grad=True)
opcheck(fused_gelu_linear, args=(x, w, b))    # 通过

# 第 6 步: 集成到模型
class FusedFFN(torch.nn.Module):
    def __init__(self, dim, hidden):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(dim, hidden))
        self.b = torch.nn.Parameter(torch.randn(hidden))

    def forward(self, x):
        return torch.ops.mylib.fused_gelu_linear(x, self.w, self.b)

# 第 7 步: torch.compile 验证
model = FusedFFN(256, 1024).cuda()
compiled = torch.compile(model)
out = compiled(x)
loss = out.sum()
loss.backward()    # 反向自动调 fgl_backward, fused 进 inductor graph

整套约 100 行 Python(不含 Triton kernel 实现)。从研究 idea 到生产 op 一周可达

  • Day 1:写 Triton kernel + 跑通 forward
  • Day 2:register_fake + register_autograd + opcheck
  • Day 3:vmap / Inductor lowering(若需要)
  • Day 4-5:性能调优 + autotune + benchmark
  • Day 6:集成到模型 + 与 baseline 对比 accuracy
  • Day 7:写 unit test + CI 集成

理解这套完整脚本让你看到”自定义 op”在 v2.x 时代不再是几人月的工程,而是一周的开发任务。门槛降低 → 创新加速 —— Triton + custom_op 让大量论文中的新算子能快速进 PyTorch 生态。

22.6.31 自定义 op 的版本兼容性策略

随着 PyTorch / 自家库迭代,自定义 op 的 schema 可能变化。生产代码必须考虑兼容性:

1. schema 演进的安全规则

修改是否 break 兼容
新增 op不 break(旧代码不调用就行)
新增 op 的可选参数(带默认值)不 break
新增 op 的必选参数break(旧代码不传新参数报错)
重命名 opbreak
改 input dtypebreak(schema 校验失败)
改 output shape 推导隐性 break(compile 后行为变)

实战做法:

  • 新功能加可选参数def my_op(x, y, *, optional_flag: bool = False) -> Tensor
  • deprecated 老 op,加新 op:保留 mylib::v1_op,新 ckpt 用 mylib::v2_op
  • schema 重大变化:bumping namespace(mylib::opmylib_v2::op

2. 与 PyTorch 版本的兼容

import torch
if torch.__version__ >= "2.4":
    @torch.library.custom_op("mylib::myop", mutates_args=())
    def myop(...):
        ...
else:
    # v2.4 之前的 fallback 写法
    class MyOp(torch.autograd.Function):
        ...

或用 try/except 兜底:

try:
    from torch.library import custom_op
except ImportError:
    # 老 PyTorch 没有这个 API
    custom_op = None

3. ckpt 兼容性

如果 op 是 model 的一部分,model state_dict 没区别(op 的实现不在 state_dict 里)。但用户代码必须能 import 到 op——升级时确保自家 op 库一并升级。

4. 渐进式 deprecation

import warnings

@torch.library.custom_op("mylib::old_op", mutates_args=())
def old_op(x, y):
    warnings.warn(
        "mylib::old_op is deprecated, use mylib::new_op instead",
        DeprecationWarning, stacklevel=2,
    )
    return new_op(x, y)

让用户有时间迁移,几个月后正式删除。

理解这些策略让你写自定义 op 时考虑”长期维护”,不是只考虑 v1。生产 op 一旦上线就要支持多年(用户的 ckpt 还在用),向前 / 向后兼容性是必修课。

22.6.32 op 注册的内部数据结构:从 schema 到 dispatcher

把全章话题落到底层数据结构。@torch.library.custom_op("mylib::mymul", ...) 在 PyTorch 内部最终落到几张表:

graph TB
    Decorator[custom_op 装饰器]
    Decorator --> Lib[Library 对象 mylib<br/>Python 层 wrapper]
    Lib --> CppLib[C++ Library<br/>持有 schema list]
    CppLib --> Dispatcher[Dispatcher 全局表<br/>OperatorHandle]

    Dispatcher --> Schema["schema string<br/>mymul(Tensor, Tensor) -> Tensor"]
    Dispatcher --> Kernels{各 dispatch key 实现表}
    Kernels --> CPU[CPU: lambda x,y: x*y]
    Kernels --> CUDA[CUDA: triton_kernel]
    Kernels --> Auto[AutogradCUDA: 自动包 backward]
    Kernels --> Fake[Meta/FakeTensor: register_fake fn]

    Decorator --> AutogradReg[autograd info<br/>setup_context + backward fn]
    AutogradReg --> Auto

    style Dispatcher fill:#fef3c7
    style Kernels fill:#dcfce7

具体源码位置(v2.x):

  • Python wrappertorch/library.py:CustomOpDef
  • C++ Librarytorch/csrc/api/include/torch/library.h:Library
  • Dispatcheraten/src/ATen/core/dispatch/Dispatcher.h:Dispatcher
  • OperatorHandleaten/src/ATen/core/dispatch/OperatorHandle.h

调用 mymul(x, y) 的内部路径:

  1. Python 调 torch.ops.mylib.mymul(x, y)
  2. C++ OperatorHandle.callBoxed(stack)
  3. Dispatcher 查 dispatch key set(input device + autograd state + …)
  4. 选最高 priority 的 kernel:典型 AutogradCUDA(如果 input requires_grad + 在 CUDA)
  5. AutogradCUDA kernel 是 PyTorch 自动生成的 wrapper:调 forward + 注册反向 Node
  6. forward 调底层 CUDA kernel(用户写的 Triton kernel)
  7. 反向时 autograd Engine(§8.x)调度 Node、最终调 register_autograd 注册的 backward fn

每一步都用 §5.x dispatcher 章讲过的同一套机制 —— 自定义 op 与内置 op 走完全相同的路径。这就是为什么”扩展与内置无差别”(§22.9 第一条设计启示)。理解这套数据结构让你看 PyTorch 的扩展机制不是黑盒,而是清晰的注册 + 查表系统。

22.7 几条工程经验

1. v2.4+ 用 torch.library.custom_op:替代老 TORCH_LIBRARY 宏 + Library.impl() 等手动调用

2. torch.library.opcheck(my_op, args) 是合规性测试:自动检查 fake / autograd / schema 等是否一致。生产 op 必跑

3. Triton kernel + custom_op 是写新算子的最优组合:性能、灵活性、与 compile 兼容性都好

4. mutates_args= 一定写正确:错了 functionalize 会出问题、torch.compile 编译错代码

5. 不要在 fake 函数里做实际计算:会让 torch.compile / FSDP 内存爆 / 性能崩

6. C++ 扩展跨 PyTorch 版本要重编:libtorch ABI 不保证版本兼容。每升级 PyTorch 重建 .so

7. PrivateUse1 是国产芯片接入路径:注册成新 backend 而非新算子,让所有现有算子都能跑

8. torch._dynamo.allow_in_graph 给某些函数特殊白名单:如果你的代码有 Dynamo 不识别但实际 trace-friendly 的部分,用这个绕过 graph break

9. 推理引擎用的 op 不需要 register_autograd:推理无反向,省一步工作。但 register_fake 仍必须

10. 跨 PyTorch 版本部署用 ABI-stable API(v2.6+):避免每升级 PyTorch 都重新编 .so 的工程税

11. distributed 训练里的 collective 必须用 functional APItorch.distributed._functional_collectives 替代 dist.all_reduce,不然 functionalize 会失败

12. 写 Triton kernel 必加 @triton.autotune:让 BLOCK_SIZE / num_warps 自动搜索,避免手调

22.8 跨书关联

  • 第 5 章 dispatcher:自定义 op 注册的底层机制
  • 第 6 章 ATen 代码生成:内置 op 是 codegen,自定义 op 是 register —— 两条路殊途同归
  • 第 7 章 autogradregister_autogradautograd.Function.backward 等价语义
  • 第 12-14 章 编译器栈:fake 函数让自定义 op 进入编译路径,register_lowering 让 op 真正被 Inductor fuse 而非走 fallback
  • 第 16 章 ProcessGroup:分布式训练里 custom_op 与 functional collectives 的协作
  • 第 18 章 FSDP-2 / DTensor:DTensor placement 与 custom_op 的 dispatch 协作
  • 第 21 章 Profiler:opcheck 与 profile 共同保证 op 正确性 + 性能符合预期

22.9 设计启示

PyTorch 自定义算子接口的核心思想:

第一让”扩展”与”内置”无差别:自定义 op 一旦注册就和 torch.add 一样工作。所有上层特性(autograd / compile / FSDP)零修改支持

第二fake 函数是高级特性的入场券:v2.x 之后任何 op 都得能 fake,否则被现代生态边缘化。这条变化看似增加用户负担,实际是 PyTorch 团队对”未来所有 op 都要进编译路径”的产品判断

第三多种 device 各注册一份 kernel:PrivateUse1 给国产芯片厂商完整的扩展能力,不需要 fork PyTorch 主仓,让硬件中立性成为生态扩展的基础设施

第四用装饰器替代宏 / Python 替代 C++:现代 API 让”写自定义 op”从需要 C++ + 宏的工程任务,降级到 10 行 Python 装饰器。这种”降低门槛 + 保留性能”的设计思想让 PyTorch 自定义 op 生态空前繁荣

第五fake / vmap / lowering 是”完整生态接入”的多个维度:每个新维度让 op 与一类 PyTorch 高级特性兼容(compile / functorch / fusion)。理解这种”渐进接入”让你知道 op 想用得上 X 特性需要注册哪个对应 hook

第六opcheck 把”扩展正确性”自动化:以前自家测 op 行为靠人工写测试,opcheck 自动覆盖 schema/autograd/fake/AOT 多条路径。这种”质量基础设施”的存在让社区能持续贡献高质量 op

22.10 跨章呼应:自定义 op 是这本书的”集大成”

把全章合起来看,自定义 op 几乎需要全书前面所有章节的知识:

写自定义 op 时用到对应章节
schema / IValue / ATen§6 ATen 代码生成
dispatcher 注册§5 dispatcher
TensorImpl / Storage§2 Tensor 数据结构
autograd Function / Engine§7-8 autograd
AOTAutograd functionalize§13 AOT Autograd
FakeTensor / register_fake§5.7 + §13
Inductor lowering / fusion§14 Inductor
torch.compile 协作§12-15 编译栈
AMP custom_fwd§20.5.19
FSDP / DTensor / collective§16-18 分布式
profile + opcheck§21 Profiler

写一个生产级 custom_op = 整本书的综合实践。这就是为什么把它放在最后一章(除 23 章哲学收束外)—— 它是检验前面知识掌握程度的”期末考试”。

新人写自定义 op 卡在哪一步,对应回去复习对应章节。这是本书的内部 cross-reference 网络的最后一环。

下一章是收官章 —— 拆 PyTorch 整体设计哲学与未来演进,把 22 章的内容串成一条主线,看从 Tensor 到 custom_op 这条 trace 上 PyTorch 团队留下了什么共通的设计原则。

评论 0