CUDA 算子工程:手写 FlashAttention v2 之路

附录 B · CUDA C++ vs Triton

作者 杨艺韬 · 1,005 字

附录 B · CUDA C++ vs Triton

B.1 Triton 是什么

Triton 是 OpenAI 在 2019 年开源的 GPU 编程 DSL(Domain-Specific Language)。核心理念是用 Python 语法、tile-level 抽象、编译器自动调优写 GPU kernel。

一个 Triton 风格的 GEMM:

import triton
import triton.language as tl

@triton.jit
def gemm_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc.to(tl.float16))

短得惊人——~30 行 Python 写出一个 GEMM。等价的 CUDA C++ 大概 200+ 行。

B.2 Triton 的设计哲学

Triton 提供 tile-level(瓦片级)抽象:

这种"声明式"抽象比 CUDA C++ 高一层——程序员描述意图("我要一个 BLOCK_M×BLOCK_N 的 GEMM tile"),编译器负责实现(thread 分配、SMEM 布局、寄存器分配)。

B.3 性能对比

工业级测量数据(不同 kernel 不同情况):

算子 CUTLASS (CUDA C++) Triton 差距
GEMM (大尺寸) 100% baseline 85-95% -5 to -15%
GEMM (小尺寸) 100% 70-90% -10 to -30%
LayerNorm 100% 100-110% 持平或更快
Softmax 100% 95-105% 几乎相同
FlashAttention 100% 80-90% -10 to -20%

总体来说:

B.4 什么时候用 Triton

Triton 的"甜区":

  1. 原型验证:新算法快速实现,看效果。
  2. 中等复杂度算子:fused softmax、dropout、新激活函数等。
  3. PyTorch 生态torch.compile 内部就用 Triton 生成 kernel。
  4. 快速迭代:调 BLOCK 大小不需要重编译几分钟。

B.5 什么时候坚持 CUDA C++

CUDA C++ 的"甜区":

  1. 生产 GEMM / Attention:cuBLAS、CUTLASS、FA3 都是 CUDA C++。极致性能必须 C++。
  2. 复杂 epilogue fusion:CUTLASS 的 EVT 比 Triton 灵活。
  3. 跨硬件代际:CUTLASS 同一代码支持 Volta/Ampere/Hopper/Blackwell;Triton 后端对老硬件支持差。
  4. TMA / WGMMA / Warp Specialization 极致优化:Triton 还在追赶这些特性。
  5. 库的作者:cuBLAS、Megatron-LM 这种"被无数下游用"的库不能容忍 5-10% 性能损失。

B.6 工业上怎么共存

vLLM 是个有趣的案例:

vLLM 架构 (~2026):
├── PagedAttention kernel:  CUDA C++ (手写)
├── Marlin INT4 GEMM:        CUDA C++ (CUTLASS-based)
├── Fused RMS norm:          CUDA C++ + 部分 Triton
├── RoPE / Activation:       Triton
├── Mixture-of-Experts:      Triton
└── LoRA epilogue:           CUDA C++ (CUTLASS)

核心高性能 kernel 用 CUDA C++,外围灵活 kernel 用 Triton——这是工业上最常见的组合。

B.7 PyTorch 2.x 的 TorchInductor

PyTorch 2.x 的 torch.compile 内部用 TorchInductor 把 PyTorch 计算图编译成 Triton kernel。这意味着普通 PyTorch 代码加一行 model = torch.compile(model) 就能享受 Triton 的优化。

但 TorchInductor 不能替代手写 kernel——对 attention、GEMM 这些核心算子,它会调用预编译的 cuBLAS/cuDNN/FA,自己不重写。

B.8 这个附录的小结

CUDA C++ 和 Triton 是互补关系:

  1. Triton 是高效的 DSL:5x 开发速度,5-15% 运行差距。
  2. CUDA C++ 是极致性能的最后一公里:cuBLAS、CUTLASS、FA3 都是 C++。
  3. 工业上两者并存:核心 kernel C++,外围 kernel Triton。
  4. 学习路径:先学 CUDA C++(理解硬件),再学 Triton(提升生产力)。读完本书,再去看 Triton 会非常顺。