CUDA 算子工程:手写 FlashAttention v2 之路
第 20 章 PTX 与 SASS:编译器到底干了什么
第 20 章 PTX 与 SASS:编译器到底干了什么
"PTX is what you write down. SASS is what actually runs. The gap between them is where compiler magic — and compiler tragedy — lives." ——CUDA 工程师的常见自嘲
20.1 CUDA 编译流程
CUDA C++ 到 GPU 跑代码经过三层抽象:
flowchart LR CXX[CUDA C++ source] -->|nvcc 前端| PTX[PTX · 虚拟 ISA] PTX -->|nvcc 后端 ptxas| CUBIN[CUBIN · SASS 机器码] CUBIN -->|GPU 加载| EXEC[GPU 执行] PTX -.JIT 编译.-> CUBIN
- CUDA C++ source:你写的
.cu文件。 - PTX (Parallel Thread Execution):NVIDIA 的虚拟 ISA。前向兼容——同一份 PTX 可以在 Volta/Ampere/Hopper 上跑(运行时再 JIT 成具体 SASS)。
- SASS (Streaming ASsembler):真正在硬件上跑的机器码。每代架构的 SASS 不一样(Hopper 的 SASS 和 Ampere 的不同)。
编译时:
nvcc -arch=sm_90 my_kernel.cu -o my_kernel
-arch=sm_90 告诉编译器目标是 Hopper。nvcc 会生成同时包含 PTX 和 sm_90 SASS 的 fatbin。
可以用 --keep 保留中间产物:
nvcc -arch=sm_90 --keep my_kernel.cu
# 生成 my_kernel.ptx, my_kernel.cubin, my_kernel.sass 等
20.2 PTX 的语法
PTX 是一种汇编语言,但不是 GPU 真正的机器码——它是"假装的汇编",给编译器后端处理用的。
PTX 的几个关键语法元素:
// 加载一个浮点数到寄存器
ld.global.f32 %r1, [%rd1]; // %r1 = *(float*)%rd1
// 浮点加法
add.f32 %r3, %r1, %r2; // %r3 = %r1 + %r2
// 浮点 fma (融合乘加)
fma.rn.f32 %r4, %r1, %r2, %r3; // %r4 = %r1 * %r2 + %r3
// 存储
st.global.f32 [%rd2], %r4;
// 控制流
@%p1 bra TARGET; // if (%p1) goto TARGET
PTX 寄存器:
%r0..%rN:32-bit 通用寄存器%rd0..%rdN:64-bit 寄存器(指针)%fX:浮点寄存器%pX:谓词(条件)寄存器
PTX 是无限寄存器的——程序员(或编译器)可以申明任意多个 %r0..%rN,最终 ptxas 后端会把它们映射到物理寄存器。
20.3 看 PTX 输出
编译时加 -ptx 让 nvcc 只生成 PTX:
nvcc -arch=sm_90 -ptx my_kernel.cu -o my_kernel.ptx
或者 -keep 保留。打开看:
.entry _Z11my_kernelPfS_i(
.param .u64 _Z11my_kernelPfS_i_param_0, // float* a
.param .u64 _Z11my_kernelPfS_i_param_1, // float* b
.param .u32 _Z11my_kernelPfS_i_param_2 // int n
)
{
.reg .b64 %rd<10>;
.reg .b32 %r<20>;
.reg .f32 %f<20>;
.reg .pred %p<5>;
ld.param.u64 %rd1, [_Z11my_kernelPfS_i_param_0];
ld.param.u64 %rd2, [_Z11my_kernelPfS_i_param_1];
ld.param.u32 %r1, [_Z11my_kernelPfS_i_param_2];
mov.u32 %r2, %ctaid.x;
mov.u32 %r3, %ntid.x;
mov.u32 %r4, %tid.x;
mad.lo.s32 %r5, %r3, %r2, %r4; // r5 = blockIdx.x * blockDim.x + threadIdx.x
setp.ge.s32 %p1, %r5, %r1;
@%p1 bra L0;
cvta.to.global.u64 %rd3, %rd1;
mul.wide.s32 %rd4, %r5, 4;
add.s64 %rd5, %rd3, %rd4;
ld.global.f32 %f1, [%rd5];
add.f32 %f2, %f1, 1.0f;
cvta.to.global.u64 %rd6, %rd2;
add.s64 %rd7, %rd6, %rd4;
st.global.f32 [%rd7], %f2;
L0: ret;
}
读 PTX 的几个要点:
.reg申明寄存器。.b32是 32-bit,.f32是浮点。%ctaid/%ntid/%tid是 blockIdx / blockDim / threadIdx 的内置寄存器。mad.lo.s32是 32-bit 整数 multiply-add(a*b+c)。.lo表示取低 32 位。@%p1 bra是有条件跳转。cvta.to.global.u64是 generic 指针转 global 指针(CUDA 有 generic/global/shared 多种地址空间)。
20.4 看 SASS 输出
SASS 是真正的机器码。用 cuobjdump 或 nvdisasm 反汇编:
cuobjdump --dump-sass my_kernel.cubin
# 或者直接对 .o 文件:
cuobjdump --dump-sass my_kernel.o
输出(Hopper SASS 示例):
Function : _Z11my_kernelPfS_i
.headerflags @"EF_CUDA_SM90_TEXMODE_UNIFIED EF_CUDA_64BIT_ADDRESS EF_CUDA_SM90 ..."
/*0000*/ IMAD.MOV.U32 R1, RZ, RZ, c[0x0][0x28] ;
/*0010*/ S2R R0, SR_CTAID.X ;
/*0020*/ ULDC.64 UR4, c[0x0][0x208] ;
/*0030*/ S2R R3, SR_TID.X ;
/*0040*/ IMAD R0, R0, c[0x0][0x0], R3 ;
/*0050*/ ISETP.GE.AND P0, PT, R0, c[0x0][0x178], PT ;
/*0060*/ @P0 EXIT ;
/*0070*/ IMAD.WIDE R2, R0, 0x4, c[0x0][0x160] ;
/*0080*/ LDG.E R4, [R2.64] ;
/*0090*/ FADD R5, R4, 1 ;
/*00a0*/ IMAD.WIDE R2, R0, 0x4, c[0x0][0x168] ;
/*00b0*/ STG.E [R2.64], R5 ;
/*00c0*/ EXIT ;
读 SASS 的关键点:
IMAD.MOV.U32是 fused multiply-add。Hopper 上经常用 IMAD 替代 MOV 因为更省功耗。S2R R0, SR_CTAID.X把 special register(CTAID)读到 R0。c[0x0][0x178]是 constant memory 访问,c[bank][offset]。LDG.E是 32-bit Load Global,E表示 elegant addressing。@P0 EXIT是有条件退出(基于 predicate P0)。IMAD.WIDE是 32×32→64 乘加,常用于地址计算。
注意 SASS 比 PTX 少了几条指令——编译器把 PTX 中的 cvta 等"虚拟操作"消掉了。
20.5 看 SASS 找性能问题
SASS 能看到 ncu 看不到的细节:
20.5.1 寄存器 spill 检测
如果 SASS 里看到大量 LDL/STL(local memory load/store),说明编译器把寄存器溢出到 local memory(实际是 HBM)。这是性能毒药。
LDL.LU.U8 R10, [R1+0x80] ; // ← 从 local memory 加载, 慢!
修复:减少局部变量、降低 unroll 程度、用 __launch_bounds__ 提示编译器降低寄存器使用。
20.5.2 Bank Conflict 实证
SMEM 访问指令是 LDS 和 STS。如果 SASS 里看到这些指令,可以用 ncu 测 bank conflict 数。但有些情况 SASS 能直接告诉你模式:
LDS.U.128 R4, [R20] ; // 一次读 128 bit (4 个 fp32)
LDS.U.128 R8, [R20+0x80] ;
LDS.U.128 R12, [R20+0x100] ;
如果几个 LDS 的地址间隔是 128 bytes(32 banks × 4 bytes),它们一定 bank conflict。
20.5.3 控制依赖与延迟
Hopper SASS 里每条指令前有一个调度信息(control codes),格式像 --:-:-:-:1。它告诉硬件这条指令需要等多少 cycle 才能发下一条。如果看到很长的 stall(比如 :8:),说明这条指令有依赖等待。
编译器通常做得不错,但偶尔会看到次优的调度——这时候手写 PTX 调整指令顺序可能有用。
20.6 Inline PTX:何时用、怎么用
CUDA C++ 大部分情况下足够。但以下情况需要 inline PTX:
20.6.1 用还没暴露 C++ API 的硬件特性
Hopper 引入的很多新指令(TMA、WGMMA、setmaxnreg)在 C++ 层面只有部分包装,最完整的接口是 PTX。例如 wgmma.mma_async:
asm volatile(
"wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 "
"{%0, %1, ..., %63}, %64, %65, p, 0, 0;\n"
: "+f"(c[0]), "+f"(c[1]), ..., "+f"(c[63])
: "l"(desc_a), "l"(desc_b), "n"(scale_d)
);
CUDA 12.x 才开始有 cuda::ptx::wgmma_* 包装,且不一定覆盖所有变体。FA3 和 CUTLASS 大量手写 PTX。
20.6.2 控制编译器无法表达的优化
某些 PTX 指令的"flag"在 C++ 层无法表达。比如 ld.global.cs (cache streaming) 指示这个 load 不应该污染 L2 cache。Triton/CUTLASS 里常见:
// 用 streaming load 避免污染 L2
asm("ld.global.cs.b128 {%0,%1,%2,%3}, [%4];\n"
: "=r"(v0), "=r"(v1), "=r"(v2), "=r"(v3)
: "l"(addr));
20.6.3 强制特定指令
编译器有时会"自作聪明"——把你的 int4 解码成多条指令而不是一条 lop3.b32。这时手写 PTX 强制用 lop3 能省指令带宽(第 9 章 Marlin 的例子)。
20.6.4 让 fragment 直接进 mma
mma.sync 的输入要求 fragment 在特定寄存器。直接用 inline PTX 控制寄存器分配,比让 C++ 编译器折腾更可靠。
20.7 一组高频 PTX/SASS Pattern
LLM kernel 中高频出现的指令:
A. mma.sync (Tensor Core 矩阵乘)
mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
{%0, %1, %2, %3},
{%4, %5, %6, %7},
{%8, %9},
{%0, %1, %2, %3};
SASS:
HMMA.16816.F32 R4, R8, R12, R4 ;
B. cp.async (Ampere+ 异步拷贝)
cp.async.cg.shared.global [%0], [%1], 16;
cp.async.commit_group;
cp.async.wait_group 0;
SASS:
LDGSTS.E.128 [R10], [R20.64] ;
LDGDEPBAR ;
DEPBAR.LE SB0, 0x0 ;
C. cp.async.bulk.tensor (Hopper TMA)
cp.async.bulk.tensor.2d.shared::cluster.global.tile.mbarrier::complete_tx::bytes
[%0], [%1, {%2, %3}], [%4];
SASS(Hopper):
UTMACPY.2D.SP /* TMA Async copy 2D */
D. ldmatrix (从 SMEM 加载 fragment)
ldmatrix.sync.aligned.m8n8.x4.shared.b16
{%0, %1, %2, %3}, [%4];
SASS:
LDSM.16.MT88.4 R4, [R20] ;
E. lop3 (三输入逻辑运算)
lop3.b32 %0, %1, 0x000F000F, 0x64006400, 0xea;
SASS:
LOP3.LUT R4, R8, 0xf000f, 0x64006400, 0xea, !PT ;
熟悉这些 pattern 后,读 SASS 会变成"翻译"——每条 SASS 都对应一个清晰的工程意图。
20.8 一个有趣的实战:让编译器生成正确的指令
经常见到这种情况:你写了一段看起来高效的 C++ 代码,但 SASS 显示它生成了次优指令。
例:从 int8_t 转 float:
int8_t x = ...;
float f = (float)x;
朴素期望:1 条指令。
实际 SASS(某些情况下):
I2F.F32.S8 R4, R8 ; // OK
但如果是 uint8_t:
uint8_t x = ...;
float f = (float)x;
SASS:
PRMT R4, R8, ... ; // permute 把 8-bit 提取出来
I2F.F32.S32 R4, R4 ; // 然后 32→f32
两条指令而不是一条!原因是 SASS 没有 unsigned 8-bit 直接转 float 的指令。要省指令,用 int8_t 而不是 uint8_t,或者手写 inline PTX 用 cvt.rn.f32.u8。
这种"编译器没生成你期望的指令"的事经常发生。读 SASS 是发现这种问题的唯一方法。
20.9 这一章的小结与下一章
PTX/SASS 是 CUDA 工程师的"反汇编技能":
- PTX 是虚拟 ISA,SASS 是真实机器码:JIT 把 PTX 转成 SASS,每代架构 SASS 不同。
-keep保留 PTX,cuobjdump --dump-sass看 SASS:这是日常工具。- 读 SASS 找寄存器 spill / bank conflict / 次优指令:ncu 看不出的细节用 SASS。
- Inline PTX 是高级优化的最后一招:当 C++ 编译器不够聪明时,手写 PTX 强制特定指令。
- 熟悉高频指令 pattern:mma、cp.async、ldmatrix、lop3 等 LLM kernel 的关键指令。
第 21 章我们结束第五篇(也是本书的核心内容)——讲性能陷阱与反模式。一些常见的"看起来对、实际上慢"的写法,把它们罗列出来作为读者的避坑指南。读完第 21 章读者就完成了从理论到实战到诊断的完整训练。
本章动手练习:
- 编译你之前写的某个 kernel,用
cuobjdump --dump-sass看 SASS。找一行 SASS,对应回 CUDA C++ 源码。- 写一段故意有寄存器 spill 的 kernel(比如几百个局部变量),看 SASS 中的
LDL/STL指令。- 用 inline PTX 写一条
lop3.b32,对比让编译器自动生成的版本。