CUDA 算子工程:手写 FlashAttention v2 之路
附录 B · CUDA C++ vs Triton
附录 B · CUDA C++ vs Triton
B.1 Triton 是什么
Triton 是 OpenAI 在 2019 年开源的 GPU 编程 DSL(Domain-Specific Language)。核心理念是用 Python 语法、tile-level 抽象、编译器自动调优写 GPU kernel。
一个 Triton 风格的 GEMM:
import triton
import triton.language as tl
@triton.jit
def gemm_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, K, BLOCK_K):
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k)
acc += tl.dot(a, b)
a_ptrs += BLOCK_K * stride_ak
b_ptrs += BLOCK_K * stride_bk
c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
tl.store(c_ptrs, acc.to(tl.float16))
短得惊人——~30 行 Python 写出一个 GEMM。等价的 CUDA C++ 大概 200+ 行。
B.2 Triton 的设计哲学
Triton 提供 tile-level(瓦片级)抽象:
- 程序员不需要管 thread——你只看到 BLOCK_M × BLOCK_N 大小的 tile。
- Triton 编译器自动决定 tile 内每个元素分配给哪个 thread/warp。
- Triton 编译器自动选择 SMEM 布局、ldmatrix vs cp.async、是否用 Tensor Core。
- Triton 有
triton.autotune装饰器自动搜索最优 tile size。
这种"声明式"抽象比 CUDA C++ 高一层——程序员描述意图("我要一个 BLOCK_M×BLOCK_N 的 GEMM tile"),编译器负责实现(thread 分配、SMEM 布局、寄存器分配)。
B.3 性能对比
工业级测量数据(不同 kernel 不同情况):
| 算子 | CUTLASS (CUDA C++) | Triton | 差距 |
|---|---|---|---|
| GEMM (大尺寸) | 100% baseline | 85-95% | -5 to -15% |
| GEMM (小尺寸) | 100% | 70-90% | -10 to -30% |
| LayerNorm | 100% | 100-110% | 持平或更快 |
| Softmax | 100% | 95-105% | 几乎相同 |
| FlashAttention | 100% | 80-90% | -10 to -20% |
总体来说:
- Triton 在简单算子上接近或超过 CUDA C++:编译器把 element-wise + reduce 优化得很好。
- Triton 在复杂算子上落后 5-20%:比如 GEMM 和 FA,CUTLASS 用了 Hopper 全部新特性(TMA、WGMMA、Warp Specialization),Triton 编译器还在追赶(Triton 3.x 开始原生支持,但还不完美)。
- Triton 的最大优势是开发速度:5x 快的开发,5-15% 慢的运行——很多时候是好的 trade-off。
B.4 什么时候用 Triton
Triton 的"甜区":
- 原型验证:新算法快速实现,看效果。
- 中等复杂度算子:fused softmax、dropout、新激活函数等。
- PyTorch 生态:
torch.compile内部就用 Triton 生成 kernel。 - 快速迭代:调 BLOCK 大小不需要重编译几分钟。
B.5 什么时候坚持 CUDA C++
CUDA C++ 的"甜区":
- 生产 GEMM / Attention:cuBLAS、CUTLASS、FA3 都是 CUDA C++。极致性能必须 C++。
- 复杂 epilogue fusion:CUTLASS 的 EVT 比 Triton 灵活。
- 跨硬件代际:CUTLASS 同一代码支持 Volta/Ampere/Hopper/Blackwell;Triton 后端对老硬件支持差。
- TMA / WGMMA / Warp Specialization 极致优化:Triton 还在追赶这些特性。
- 库的作者:cuBLAS、Megatron-LM 这种"被无数下游用"的库不能容忍 5-10% 性能损失。
B.6 工业上怎么共存
vLLM 是个有趣的案例:
vLLM 架构 (~2026):
├── PagedAttention kernel: CUDA C++ (手写)
├── Marlin INT4 GEMM: CUDA C++ (CUTLASS-based)
├── Fused RMS norm: CUDA C++ + 部分 Triton
├── RoPE / Activation: Triton
├── Mixture-of-Experts: Triton
└── LoRA epilogue: CUDA C++ (CUTLASS)
核心高性能 kernel 用 CUDA C++,外围灵活 kernel 用 Triton——这是工业上最常见的组合。
B.7 PyTorch 2.x 的 TorchInductor
PyTorch 2.x 的 torch.compile 内部用 TorchInductor 把 PyTorch 计算图编译成 Triton kernel。这意味着普通 PyTorch 代码加一行 model = torch.compile(model) 就能享受 Triton 的优化。
但 TorchInductor 不能替代手写 kernel——对 attention、GEMM 这些核心算子,它会调用预编译的 cuBLAS/cuDNN/FA,自己不重写。
B.8 这个附录的小结
CUDA C++ 和 Triton 是互补关系:
- Triton 是高效的 DSL:5x 开发速度,5-15% 运行差距。
- CUDA C++ 是极致性能的最后一公里:cuBLAS、CUTLASS、FA3 都是 C++。
- 工业上两者并存:核心 kernel C++,外围 kernel Triton。
- 学习路径:先学 CUDA C++(理解硬件),再学 Triton(提升生产力)。读完本书,再去看 Triton 会非常顺。