CUDA 算子工程:手写 FlashAttention v2 之路

第 20 章 PTX 与 SASS:编译器到底干了什么

作者 杨艺韬 · 2,249 字

第 20 章 PTX 与 SASS:编译器到底干了什么

"PTX is what you write down. SASS is what actually runs. The gap between them is where compiler magic — and compiler tragedy — lives." ——CUDA 工程师的常见自嘲

20.1 CUDA 编译流程

CUDA C++ 到 GPU 跑代码经过三层抽象:

flowchart LR
  CXX[CUDA C++ source] -->|nvcc 前端| PTX[PTX · 虚拟 ISA]
  PTX -->|nvcc 后端 ptxas| CUBIN[CUBIN · SASS 机器码]
  CUBIN -->|GPU 加载| EXEC[GPU 执行]
  PTX -.JIT 编译.-> CUBIN

编译时:

nvcc -arch=sm_90 my_kernel.cu -o my_kernel

-arch=sm_90 告诉编译器目标是 Hopper。nvcc 会生成同时包含 PTX 和 sm_90 SASS 的 fatbin。

可以用 --keep 保留中间产物:

nvcc -arch=sm_90 --keep my_kernel.cu
# 生成 my_kernel.ptx, my_kernel.cubin, my_kernel.sass 等

20.2 PTX 的语法

PTX 是一种汇编语言,但不是 GPU 真正的机器码——它是"假装的汇编",给编译器后端处理用的。

PTX 的几个关键语法元素:

// 加载一个浮点数到寄存器
ld.global.f32 %r1, [%rd1];     // %r1 = *(float*)%rd1

// 浮点加法
add.f32 %r3, %r1, %r2;          // %r3 = %r1 + %r2

// 浮点 fma (融合乘加)
fma.rn.f32 %r4, %r1, %r2, %r3;  // %r4 = %r1 * %r2 + %r3

// 存储
st.global.f32 [%rd2], %r4;

// 控制流
@%p1 bra TARGET;                // if (%p1) goto TARGET

PTX 寄存器:

PTX 是无限寄存器的——程序员(或编译器)可以申明任意多个 %r0..%rN,最终 ptxas 后端会把它们映射到物理寄存器。

20.3 看 PTX 输出

编译时加 -ptx 让 nvcc 只生成 PTX:

nvcc -arch=sm_90 -ptx my_kernel.cu -o my_kernel.ptx

或者 -keep 保留。打开看:

.entry _Z11my_kernelPfS_i(
    .param .u64 _Z11my_kernelPfS_i_param_0,   // float* a
    .param .u64 _Z11my_kernelPfS_i_param_1,   // float* b
    .param .u32 _Z11my_kernelPfS_i_param_2    // int n
)
{
    .reg .b64       %rd<10>;
    .reg .b32       %r<20>;
    .reg .f32       %f<20>;
    .reg .pred      %p<5>;

    ld.param.u64    %rd1, [_Z11my_kernelPfS_i_param_0];
    ld.param.u64    %rd2, [_Z11my_kernelPfS_i_param_1];
    ld.param.u32    %r1,  [_Z11my_kernelPfS_i_param_2];

    mov.u32         %r2, %ctaid.x;
    mov.u32         %r3, %ntid.x;
    mov.u32         %r4, %tid.x;
    mad.lo.s32      %r5, %r3, %r2, %r4;       // r5 = blockIdx.x * blockDim.x + threadIdx.x

    setp.ge.s32     %p1, %r5, %r1;
    @%p1 bra        L0;

    cvta.to.global.u64  %rd3, %rd1;
    mul.wide.s32    %rd4, %r5, 4;
    add.s64         %rd5, %rd3, %rd4;
    ld.global.f32   %f1, [%rd5];
    add.f32         %f2, %f1, 1.0f;
    cvta.to.global.u64  %rd6, %rd2;
    add.s64         %rd7, %rd6, %rd4;
    st.global.f32   [%rd7], %f2;

L0: ret;
}

读 PTX 的几个要点:

  1. .reg 申明寄存器。.b32 是 32-bit,.f32 是浮点。
  2. %ctaid / %ntid / %tid 是 blockIdx / blockDim / threadIdx 的内置寄存器。
  3. mad.lo.s32 是 32-bit 整数 multiply-add(a*b+c)。.lo 表示取低 32 位。
  4. @%p1 bra 是有条件跳转。
  5. cvta.to.global.u64 是 generic 指针转 global 指针(CUDA 有 generic/global/shared 多种地址空间)。

20.4 看 SASS 输出

SASS 是真正的机器码。用 cuobjdumpnvdisasm 反汇编:

cuobjdump --dump-sass my_kernel.cubin
# 或者直接对 .o 文件:
cuobjdump --dump-sass my_kernel.o

输出(Hopper SASS 示例):

Function : _Z11my_kernelPfS_i
.headerflags @"EF_CUDA_SM90_TEXMODE_UNIFIED EF_CUDA_64BIT_ADDRESS EF_CUDA_SM90 ..."
        /*0000*/                   IMAD.MOV.U32 R1, RZ, RZ, c[0x0][0x28] ;
        /*0010*/                   S2R R0, SR_CTAID.X ;
        /*0020*/                   ULDC.64 UR4, c[0x0][0x208] ;
        /*0030*/                   S2R R3, SR_TID.X ;
        /*0040*/                   IMAD R0, R0, c[0x0][0x0], R3 ;
        /*0050*/                   ISETP.GE.AND P0, PT, R0, c[0x0][0x178], PT ;
        /*0060*/             @P0   EXIT ;
        /*0070*/                   IMAD.WIDE R2, R0, 0x4, c[0x0][0x160] ;
        /*0080*/                   LDG.E R4, [R2.64] ;
        /*0090*/                   FADD R5, R4, 1 ;
        /*00a0*/                   IMAD.WIDE R2, R0, 0x4, c[0x0][0x168] ;
        /*00b0*/                   STG.E [R2.64], R5 ;
        /*00c0*/                   EXIT ;

读 SASS 的关键点:

  1. IMAD.MOV.U32 是 fused multiply-add。Hopper 上经常用 IMAD 替代 MOV 因为更省功耗。
  2. S2R R0, SR_CTAID.X 把 special register(CTAID)读到 R0。
  3. c[0x0][0x178] 是 constant memory 访问,c[bank][offset]
  4. LDG.E 是 32-bit Load Global,E 表示 elegant addressing。
  5. @P0 EXIT 是有条件退出(基于 predicate P0)。
  6. IMAD.WIDE 是 32×32→64 乘加,常用于地址计算。

注意 SASS 比 PTX 少了几条指令——编译器把 PTX 中的 cvta 等"虚拟操作"消掉了。

20.5 看 SASS 找性能问题

SASS 能看到 ncu 看不到的细节:

20.5.1 寄存器 spill 检测

如果 SASS 里看到大量 LDL/STL(local memory load/store),说明编译器把寄存器溢出到 local memory(实际是 HBM)。这是性能毒药。

LDL.LU.U8 R10, [R1+0x80] ;       // ← 从 local memory 加载, 慢!

修复:减少局部变量、降低 unroll 程度、用 __launch_bounds__ 提示编译器降低寄存器使用。

20.5.2 Bank Conflict 实证

SMEM 访问指令是 LDSSTS。如果 SASS 里看到这些指令,可以用 ncu 测 bank conflict 数。但有些情况 SASS 能直接告诉你模式:

LDS.U.128 R4, [R20] ;   // 一次读 128 bit (4 个 fp32)
LDS.U.128 R8, [R20+0x80] ;
LDS.U.128 R12, [R20+0x100] ;

如果几个 LDS 的地址间隔是 128 bytes(32 banks × 4 bytes),它们一定 bank conflict。

20.5.3 控制依赖与延迟

Hopper SASS 里每条指令前有一个调度信息(control codes),格式像 --:-:-:-:1。它告诉硬件这条指令需要等多少 cycle 才能发下一条。如果看到很长的 stall(比如 :8:),说明这条指令有依赖等待。

编译器通常做得不错,但偶尔会看到次优的调度——这时候手写 PTX 调整指令顺序可能有用。

20.6 Inline PTX:何时用、怎么用

CUDA C++ 大部分情况下足够。但以下情况需要 inline PTX:

20.6.1 用还没暴露 C++ API 的硬件特性

Hopper 引入的很多新指令(TMA、WGMMA、setmaxnreg)在 C++ 层面只有部分包装,最完整的接口是 PTX。例如 wgmma.mma_async

asm volatile(
    "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 "
    "{%0, %1, ..., %63}, %64, %65, p, 0, 0;\n"
    : "+f"(c[0]), "+f"(c[1]), ..., "+f"(c[63])
    : "l"(desc_a), "l"(desc_b), "n"(scale_d)
);

CUDA 12.x 才开始有 cuda::ptx::wgmma_* 包装,且不一定覆盖所有变体。FA3 和 CUTLASS 大量手写 PTX。

20.6.2 控制编译器无法表达的优化

某些 PTX 指令的"flag"在 C++ 层无法表达。比如 ld.global.cs (cache streaming) 指示这个 load 不应该污染 L2 cache。Triton/CUTLASS 里常见:

// 用 streaming load 避免污染 L2
asm("ld.global.cs.b128 {%0,%1,%2,%3}, [%4];\n"
    : "=r"(v0), "=r"(v1), "=r"(v2), "=r"(v3)
    : "l"(addr));

20.6.3 强制特定指令

编译器有时会"自作聪明"——把你的 int4 解码成多条指令而不是一条 lop3.b32。这时手写 PTX 强制用 lop3 能省指令带宽(第 9 章 Marlin 的例子)。

20.6.4 让 fragment 直接进 mma

mma.sync 的输入要求 fragment 在特定寄存器。直接用 inline PTX 控制寄存器分配,比让 C++ 编译器折腾更可靠。

20.7 一组高频 PTX/SASS Pattern

LLM kernel 中高频出现的指令:

A. mma.sync (Tensor Core 矩阵乘)

mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
    {%0, %1, %2, %3},
    {%4, %5, %6, %7},
    {%8, %9},
    {%0, %1, %2, %3};

SASS:

HMMA.16816.F32 R4, R8, R12, R4 ;

B. cp.async (Ampere+ 异步拷贝)

cp.async.cg.shared.global [%0], [%1], 16;
cp.async.commit_group;
cp.async.wait_group 0;

SASS:

LDGSTS.E.128 [R10], [R20.64] ;
LDGDEPBAR ;
DEPBAR.LE SB0, 0x0 ;

C. cp.async.bulk.tensor (Hopper TMA)

cp.async.bulk.tensor.2d.shared::cluster.global.tile.mbarrier::complete_tx::bytes
    [%0], [%1, {%2, %3}], [%4];

SASS(Hopper):

UTMACPY.2D.SP /* TMA Async copy 2D */

D. ldmatrix (从 SMEM 加载 fragment)

ldmatrix.sync.aligned.m8n8.x4.shared.b16
    {%0, %1, %2, %3}, [%4];

SASS:

LDSM.16.MT88.4 R4, [R20] ;

E. lop3 (三输入逻辑运算)

lop3.b32 %0, %1, 0x000F000F, 0x64006400, 0xea;

SASS:

LOP3.LUT R4, R8, 0xf000f, 0x64006400, 0xea, !PT ;

熟悉这些 pattern 后,读 SASS 会变成"翻译"——每条 SASS 都对应一个清晰的工程意图。

20.8 一个有趣的实战:让编译器生成正确的指令

经常见到这种情况:你写了一段看起来高效的 C++ 代码,但 SASS 显示它生成了次优指令。

例:从 int8_tfloat

int8_t x = ...;
float f = (float)x;

朴素期望:1 条指令。

实际 SASS(某些情况下):

I2F.F32.S8 R4, R8 ;   // OK

但如果是 uint8_t:

uint8_t x = ...;
float f = (float)x;

SASS:

PRMT R4, R8, ... ;    // permute 把 8-bit 提取出来
I2F.F32.S32 R4, R4 ; // 然后 32→f32

两条指令而不是一条!原因是 SASS 没有 unsigned 8-bit 直接转 float 的指令。要省指令,用 int8_t 而不是 uint8_t,或者手写 inline PTX 用 cvt.rn.f32.u8

这种"编译器没生成你期望的指令"的事经常发生。读 SASS 是发现这种问题的唯一方法。

20.9 这一章的小结与下一章

PTX/SASS 是 CUDA 工程师的"反汇编技能":

  1. PTX 是虚拟 ISA,SASS 是真实机器码:JIT 把 PTX 转成 SASS,每代架构 SASS 不同。
  2. -keep 保留 PTXcuobjdump --dump-sass 看 SASS:这是日常工具。
  3. 读 SASS 找寄存器 spill / bank conflict / 次优指令:ncu 看不出的细节用 SASS。
  4. Inline PTX 是高级优化的最后一招:当 C++ 编译器不够聪明时,手写 PTX 强制特定指令。
  5. 熟悉高频指令 pattern:mma、cp.async、ldmatrix、lop3 等 LLM kernel 的关键指令。

第 21 章我们结束第五篇(也是本书的核心内容)——讲性能陷阱与反模式。一些常见的"看起来对、实际上慢"的写法,把它们罗列出来作为读者的避坑指南。读完第 21 章读者就完成了从理论到实战到诊断的完整训练。

本章动手练习

  1. 编译你之前写的某个 kernel,用 cuobjdump --dump-sass 看 SASS。找一行 SASS,对应回 CUDA C++ 源码。
  2. 写一段故意有寄存器 spill 的 kernel(比如几百个局部变量),看 SASS 中的 LDL/STL 指令。
  3. 用 inline PTX 写一条 lop3.b32,对比让编译器自动生成的版本。