CUDA 算子工程:手写 FlashAttention v2 之路
第 13 章 CUTLASS 3.x 设计哲学
第 13 章 CUTLASS 3.x 设计哲学
"CUTLASS is not a library you learn — it is a vocabulary you adopt. Once you can speak in 'CollectiveOps' and 'CuTe layouts', the rest of modern CUDA programming starts speaking back." ——CUTLASS 团队的设计宣言
13.1 CUTLASS 是什么
CUTLASS(CUDA Templates for Linear Algebra Subroutines)是 NVIDIA 官方维护的开源 C++ 模板库,第一版 2017 年 Volta 时代发布,到现在第 3.x 版本已经成熟到:
- 30 万行代码(统计 GitHub 仓库 src 目录)。
- 覆盖 Volta / Turing / Ampere / Hopper / Blackwell 五代架构。
- 支持几十种数据类型组合(FP16/BF16/TF32/FP8/INT8/INT4 任意组合)。
- 是 cuBLAS、FlashAttention v3、Transformer Engine、PyTorch ATen GEMM 后端的共同基础。
简单说:今天你跑的几乎所有 LLM GEMM 算子,都是 CUTLASS 直接或间接写的。
但 CUTLASS 也是出名的难学。30 万行模板代码、深度嵌套的类型层级、上百个 traits class——新人打开 CUTLASS 源码常常会陷入"看 5 分钟模板就晕"的状态。
这一章的目标不是教读者用 CUTLASS(它的 API 还在演化),而是教读者理解 CUTLASS 的设计哲学——CuTe Layout、CollectiveOp 三段式、Kernel Schedule。理解这些之后,读者再打开 CUTLASS 源码会发现"哦原来这里在做这件事"。
13.2 CUTLASS 三代演进
CUTLASS 的设计经过了三次大重构,每一次都是对硬件能力的重新抽象:
flowchart LR
subgraph V1 [CUTLASS 1.x · 2017-2018]
V1A[Volta · WMMA 16×16×16]
V1B[平铺 GEMM 模板]
V1C[Bottom-up: 用底层指令组合]
end
subgraph V2 [CUTLASS 2.x · 2019-2022]
V2A[Turing/Ampere · mma.sync 16×8×16]
V2B[Threadblock-level GEMM 抽象]
V2C[Iterator + Pipeline 模式]
end
subgraph V3 [CUTLASS 3.x · 2023+]
V3A[Hopper · WGMMA + TMA]
V3B[CuTe Layout + CollectiveOp]
V3C[Top-down: 描述意图,自动展开]
end
V1 --> V2 --> V3
13.2.1 1.x 时代:模板地狱
CUTLASS 1.x 是 Volta 架构的产物。它的核心抽象是"GEMM Pipeline"——把 GEMM 拆成 prologue / mainloop / epilogue 三段,每段都是模板类。
代码风格大致这样:
template <
typename ElementA, typename LayoutA,
typename ElementB, typename LayoutB,
typename ElementC, typename LayoutC,
int ThreadBlockShapeM, int ThreadBlockShapeN, int ThreadBlockShapeK,
int WarpShapeM, int WarpShapeN, int WarpShapeK,
int MmaShapeM, int MmaShapeN, int MmaShapeK,
typename EpilogueOp
>
class Gemm { ... };
模板参数 30+ 个是常态。能写出来,但巨难看懂、巨难改。
13.2.2 2.x 时代:Iterator 与 Pipeline
CUTLASS 2.x 引入了几个关键抽象:
- Iterator:把"从某个 Tensor 中以某种 stride/layout 读出 fragment"的过程模板化。Iterator 隐藏了地址计算和 vectorized load。
- Pipeline:把 prologue + mainloop + epilogue 用模板组合起来,自动处理 double buffer。
代码可读性大幅提升,但还是有大量隐式约定("layout 必须满足 contract X"),新人难入门。
13.2.3 3.x 时代:CuTe + CollectiveOp
CUTLASS 3.x 是真正的范式重构。两个核心抽象:
- CuTe (CUDA Tensor):一个独立的子库,专门做 layout 代数。把"什么样的数据怎么放"用一种通用语言描述出来。
- CollectiveOp:把 GEMM 拆成 CollectiveMainloop(核心循环)和 CollectiveEpilogue(输出阶段),每段是一个高层抽象类。
3.x 的代码大致长这样:
using CollectiveMainloop = cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
half_t, LayoutA, 8,
half_t, LayoutB, 8,
float,
Shape<_128, _128, _32>,
Shape<_1, _1, _1>, // ClusterShape
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelTmaWarpSpecialized
>::CollectiveOp;
模板参数还是多,但每个参数的含义清晰,且大部分有合理默认。这就是 3.x 的进步——降低使用门槛,但保留自定义能力。
13.3 CuTe Layout:CUTLASS 3.x 的灵魂
CuTe 是 CUTLASS 3.x 中所有概念的基础。它的核心是把"如何排布张量"形式化。
13.3.1 Layout 的定义
CuTe Layout 是 (Shape, Stride) 的 pair:
using L = Layout<Shape<_4, _8>, Stride<_1, _4>>; // 4×8 矩阵, row-major
Shape (4, 8) 表示 4 行 8 列。Stride (1, 4) 表示行步长 1、列步长 4——也就是 row-major 布局:
矩阵地址: 元素 (i, j) 的偏移:
(0,0) -> 0 (0,1) -> 4 ... offset = i * 1 + j * 4
(1,0) -> 1 (1,1) -> 5 ...
这是基本的 layout。CuTe 强大在于能组合 layout:
// 嵌套 layout: 4×8 矩阵被分成 2×4 个 (2×2) 子块, 每个子块内部 row-major
using L = Layout<
Shape<Shape<_2, _2>, Shape<_2, _4>>,
Stride<Stride<_4, _16>, Stride<_1, _2>>
>;
复杂吗?是的。但它能描述任何复杂布局——swizzle 布局、Tensor Core fragment 布局、TMA 的 multicast layout——都能用 (Shape, Stride) 嵌套表示。
13.3.2 Layout 代数
CuTe 提供一组对 Layout 的操作:
compose(A, B):把 A 通过 B 重映射。tile(A, B):把 A 切成多个 B 形状的 tile。product(A, B):A 和 B 的笛卡尔积。get_layout(tensor):取张量的 layout。
这些操作让你可以像写代数公式一样描述 GEMM 的数据流:
auto gA = local_tile(mA, Tile<BM, BK>{}, blockIdx.x); // 每 block 拿 A 的一块
auto gB = local_tile(mB, Tile<BN, BK>{}, blockIdx.y); // 每 block 拿 B 的一块
auto sA = make_tensor(make_smem_ptr(smem_a),
decltype(gA)::layout_type{}); // 在 SMEM 里建对应 tensor
copy(gA, sA); // CuTe 自动决定怎么 copy
copy 函数会根据 source 和 dest 的 layout 自动选择最高效的 copy 方式——可能是 cp.async、可能是 TMA、可能是 vectorized load。程序员不需要手写 cp.async 指令。
这就是 CuTe 最革命性的地方——用一种声明式语言描述数据流,让编译器/库自动选择最优指令。
13.3.3 Swizzle 在 CuTe 中
第 12 章讲的 swizzle layout 在 CuTe 中是一个一等公民:
using SwizzleAtom = decltype(
composition(Swizzle<3, 3, 3>{}, Layout<Shape<_8, _BK>, Stride<_BK, _1>>{})
);
Swizzle<3, 3, 3> 是参数化的 swizzle 函数(M-bits, S-bits, B-bits)。具体含义复杂,但 CUTLASS 提供了几组预定义的 SwizzleAtom,覆盖典型 GEMM 需要的所有 swizzle。
13.4 CollectiveMainloop:核心循环的抽象
CollectiveMainloop 描述 GEMM 的核心循环——从 HBM 拉 A/B tile 到 SMEM,从 SMEM 读 fragment 到寄存器,调用 mma 累加。
CUTLASS 3.x 提供了一组 Kernel Schedule:
KernelTmaWarpSpecialized // Hopper, TMA + Warp Specialized
KernelTmaWarpSpecializedPingpong // Hopper, ping-pong on accumulator
KernelTmaWarpSpecializedCooperative // Hopper, two warpgroups cooperate
KernelMultistage // Ampere, multi-stage cp.async pipeline
KernelPipelined // 2.x style, classic double buffer
每个 Schedule 对应一种 mainloop 实现。KernelTmaWarpSpecialized 是 Hopper 上的推荐 schedule——它实现了第 2 章讲的 Producer/Consumer warp specialization。
伪代码:
// 内部展开后的 KernelTmaWarpSpecialized mainloop
__global__ void mainloop() {
if (warp_group == 0) {
// Producer warp group: 持续发起 TMA
for (int k = 0; k < K_tiles; ++k) {
wait_pipeline_empty(k);
cp_async_bulk_tensor(smem_a[k % stages], tma_desc_a, k);
cp_async_bulk_tensor(smem_b[k % stages], tma_desc_b, k);
mbarrier_arrive(pipeline[k % stages]);
}
} else {
// Consumer warp groups: 等数据并算 WGMMA
for (int k = 0; k < K_tiles; ++k) {
mbarrier_wait(pipeline[k % stages]);
wgmma_mma_async(c_acc, smem_a[k % stages], smem_b[k % stages]);
wgmma_commit_group();
wgmma_wait_group(0);
release_pipeline(k % stages);
}
}
}
读者不需要手写这段代码——KernelTmaWarpSpecialized 把它实现好了。读者只要在 CollectiveBuilder 模板参数里指定就行。
13.5 CollectiveEpilogue:输出阶段的抽象
GEMM 的 epilogue 是 D = alpha * C_acc + beta * C_old + bias + activation 这种最终阶段。CUTLASS 3.x 把它抽象成 CollectiveEpilogue:
using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
Shape<_128, _128, _32>,
Shape<_1, _1, _1>,
cutlass::epilogue::collective::EpilogueTileAuto,
float, float,
half_t, LayoutC, 8,
half_t, LayoutD, 8,
cutlass::epilogue::TmaWarpSpecializedCooperative,
cutlass::epilogue::fusion::LinearCombination<half_t, float>
>::CollectiveOp;
最后一个参数 LinearCombination 是融合操作。可以替换成自定义 epilogue:
using FusedAddRelu = cutlass::epilogue::fusion::LinearCombinationRelu<half_t, float>;
CUTLASS 提供几十种预定义 epilogue:ReLU / GELU / SiLU / Softmax / Quantize / LoRA。如果不够用,可以用 EVT(Epilogue Visitor Tree)DSL 自定义复杂 epilogue。
13.6 把 GEMM 拼起来
完整的 CUTLASS 3.x GEMM 调用:
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int>, // ProblemShape (M, N, K)
CollectiveMainloop, // 上面定义的 mainloop
CollectiveEpilogue // 上面定义的 epilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
Gemm gemm;
typename Gemm::Arguments args = {
cutlass::gemm::GemmUniversalMode::kGemm,
{M, N, K, 1},
{ptr_A, stride_A, ptr_B, stride_B},
{{alpha, beta}, ptr_C, stride_C, ptr_D, stride_D}
};
gemm.initialize(args, workspace, stream);
gemm.run(stream);
这就是用户视角的 CUTLASS 3.x。看起来仍然有很多模板,但每个模板参数都有清晰含义和合理默认。
13.7 怎么读 CUTLASS 源码
最后给读者一份"CUTLASS 源码导航指南"。GitHub 仓库 https://github.com/NVIDIA/cutlass,关键目录:
include/cutlass/gemm/:GEMM 主体。device/:device-level API(用户调用层)。kernel/:kernel-level(GemmUniversal 这一层)。collective/:CollectiveOp 实现。threadblock//warp/:底层 building block。
include/cutlass/epilogue/:epilogue 实现。include/cute/:CuTe 子库。layout.hpp:Layout 定义。tensor.hpp:Tensor 抽象。algorithm/copy.hpp:copy 算法(自动选 cp.async / TMA)。
examples/:可工作的示例代码。从examples/00_basic_gemm开始读。
阅读建议:
- 从 example 开始:
examples/48_hopper_warp_specialized_gemm是 Hopper 上的标准例子。读它的 main + 内嵌的 collective 配置。 - 追到 mainloop:从 example 进入
cutlass::gemm::collective::CollectiveBuilder,看它如何根据模板参数选择具体的 collective 实现(SM90 + WarpSpecialized路径)。 - 看 mainloop 体:找到
cutlass/gemm/collective/sm90_mma_tma_gmma_warpspecialized.hpp(或类似文件),读mainloop()函数的展开。 - 看 CuTe 操作:mainloop 里大量的
copy、gemm、partition_*都是 CuTe 函数,跳到include/cute看。
第一遍读会很慢(一个 example 一周)。但读懂之后回头看 FA3 源码、cuBLAS Hopper kernel、Marlin GEMM,会发现它们都是 CUTLASS 的应用——风格一致。
13.8 第三篇收官:从 700 GFLOPs 到 750 TFLOPs
第三篇我们走完了 GEMM 的优化全程:
| 章节 | 路径 | 性能 |
|---|---|---|
| 第 10 章 | 朴素 GEMM | 700 GFLOPs |
| 第 11 章 | Tiled GEMM (SIMT 极限) | ~25 TFLOPs |
| 第 12 章 | Tensor Core HGEMM 骨架 | ~600 TFLOPs |
| 第 13 章 | CUTLASS 工业级 | ~750-800 TFLOPs |
从 700 GFLOPs 到 750 TFLOPs,1000 倍的性能差距。这就是"会写 CUDA"和"懂现代 GPU"的分野。
但 GEMM 不是最终目的。LLM 推理真正的难点是 Attention——它内部包含两个 GEMM(QK^T、PV),中间夹一个 softmax,且数据形态特殊(causal mask、长序列、KV cache)。
第 14-18 章的第四篇我们会把整个第二、三篇的工艺集中到一件事上:手写 FlashAttention v2 到 SOTA。从访存瓶颈分析(第 14 章),到 FA2 前向(第 15 章),到反向(第 16 章),到 Hopper 上的 TMA + Warp Specialization 优化(第 17 章),到 Persistent Kernel(第 18 章)。读完第 18 章,读者会拥有完整的 FA2 实现能力,并能看懂 FlashAttention v3 论文的所有创新。
本章动手练习:
- 在 H100 上跑 CUTLASS 的
examples/48_hopper_warp_specialized_gemm。把 ProblemShape 改成 LLaMA-7B 的 GEMM 尺寸 (M=4096, N=11008, K=4096),看实际性能。- 阅读 CuTe 的
Layout定义,理解Shape和Stride的嵌套语法。试着用 CuTe 描述一个 16×16 fp16 矩阵的 swizzle 布局。- 找到 CUTLASS 中 LinearCombination Epilogue 的实现,理解它是如何在最后阶段做 alpha/beta 的。