CUDA 算子工程:手写 FlashAttention v2 之路

第 10 章 朴素 GEMM 与 Roofline 分析

作者 杨艺韬 · 2,280 字

第 10 章 朴素 GEMM 与 Roofline 分析

"GEMM is to GPU what printf is to C — the canonical example, the first benchmark, and the hidden gateway to mastering the platform." ——经典 CUDA 工程师箴言

10.1 GEMM 是什么,为什么它占据 LLM 算力的 90%

GEMM = General Matrix Multiplication,通用矩阵乘:

Cm,n=αk=1KAm,kBk,n+βCm,nC_{m,n} = \alpha \sum_{k=1}^{K} A_{m,k} B_{k,n} + \beta C_{m,n}

LLM 推理 / 训练中无处不在:

一个 LLaMA-7B 单 token decoding 的浮点操作分布大约是:

GEMM:          ~85-90%
Attention:     ~5-8%   (但 attention 的内部本身也是 GEMM)
LayerNorm:     ~1-2%
Element-wise:  ~1-2%
其他:           <1%

如果把 attention 内部的 GEMM 也算进去,95%+ 的浮点运算都在 GEMM 上。所以 GEMM 性能直接决定 LLM 性能。

cuBLAS 的 SGEMM (FP32 GEMM) 在 H100 上 M=N=K=4096 能跑到 ~50 TFLOPs,约 75% 的 FP32 SIMT 峰值。Tensor Core 的 HGEMM (FP16 GEMM) 能跑到 ~750 TFLOPs,约 76% 的 FP16 Tensor Core 峰值。这个性能不是凭空来的——它是几十年 GPU 架构演进 + 几十万行 CUTLASS 代码堆出来的。

要理解 GEMM 性能的来路,得从最简单的版本开始。

10.2 朴素 GEMM:教科书式的三重循环

最直观的写法:每个线程算 C 的一个元素。

__global__ void gemm_naive(
    const float* A, const float* B, float* C,
    int M, int N, int K
) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;

    if (row >= M || col >= N) return;

    float sum = 0.0f;
    for (int k = 0; k < K; ++k) {
        sum += A[row * K + k] * B[k * N + col];
    }
    C[row * N + col] = sum;
}

启动配置:

dim3 block(16, 16);
dim3 grid((N + 15) / 16, (M + 15) / 16);
gemm_naive<<<grid, block>>>(A, B, C, M, N, K);

这段代码逻辑正确,能算对 GEMM。但性能呢?

10.2.1 实测性能

H100, M=N=K=4096, FP32:

朴素 GEMM:         ~700 GFLOPs    (1% of FP32 peak)
cuBLAS SGEMM:      ~50 TFLOPs     (75% of FP32 peak)
差距:              70x

70 倍的差距。这就是"会写"和"写好"的距离。

10.3 用 Roofline 找瓶颈

朴素 GEMM 的瓶颈在哪?用第 4 章学的 Roofline 模型分析。

10.3.1 朴素 GEMM 的算术强度

每个线程计算 C 的一个元素,需要:

每个 thread 的算术强度:

AIthread=2K4K+4K+4=2K8K+414 FLOPs/byte\text{AI}_{\text{thread}} = \frac{2K}{4K + 4K + 4} = \frac{2K}{8K + 4} \approx \frac{1}{4} \text{ FLOPs/byte}

K 大时趋近于 1/4。这是非常低的算术强度——比 LayerNorm 和 RMSNorm 还低!

整个 kernel 的算术强度不是这样算的。多个线程之间会复用 A 和 B 的不同部分

如果 cache 完美命中,整个 kernel 的总访存只需要:

整体算术强度:

AIkernel=2MNK4(MK+KN+MN)\text{AI}_{\text{kernel}} = \frac{2 M N K}{4(MK + KN + MN)}

M=N=K=4096 时:

AI=2409634340962=2409612683 FLOPs/byte\text{AI} = \frac{2 \cdot 4096^3}{4 \cdot 3 \cdot 4096^2} = \frac{2 \cdot 4096}{12} \approx 683 \text{ FLOPs/byte}

683 FLOPs/byte 的算术强度,远超过 H100 的临界点 295。这意味着 GEMM 在大尺寸下应该是完全 compute-bound 的——理论上能跑到算力峰值。

但朴素 GEMM 只跑了 1%。问题不是算力或带宽不够,而是访存模式让 cache 完全失效

10.3.2 朴素 GEMM 为什么 cache 失效

看朴素 GEMM 的访存模式:

sum += A[row * K + k] * B[k * N + col];

线程 (row, col=0) 和 (row, col=1) 都读 A 的同一行,但它们在不同的 SM 上(gridDim 维度大),它们之间的 A 数据在 L2 cache 里有可能命中,但远不是必然

更致命的是 B 的访问。线程 (row, col) 读 B[k * N + col]

B 的访问步长是 16 KB——意味着每读一次 B,都跳过 4 个 cache line 大小。L1 cache 完全无效(线性扫描下 cache 永远 miss)。L2 cache 50 MB,N=4096 时 B 整个矩阵 64 MB——也放不下。

最终结果:朴素 GEMM 的 B 矩阵几乎每次都要从 HBM 读,算术强度从理论 683 退化到约 1(每读 1 字节算 1 FLOP)。带宽 bound,跑不快。

10.3.3 朴素 GEMM 的 Roofline 位置

                  实际 TFLOPs/s

              989 ─────┼──────────────────────  Tensor Core FP16 峰值

                  500 ─┤

                       │  cuBLAS SGEMM
                       │  AI~683, perf~50
                  100 ─┤


                  10  ─┤        ╱
                       │       ╱
                       │  朴素 GEMM
                       │  实际 AI ~1, perf ~0.7
                  1  ──┤
                       └─────────────────────►  算术强度
                       1     10   100  1000
                       带宽屋顶 (3.35 TB/s)

朴素 GEMM 的"实际算术强度"约 1,落在带宽屋顶下,所以无法发挥算力优势。优化的方向是把算术强度从 1 提到 ~683——这要求充分复用 SMEM 和寄存器中的 A/B 数据,让"理论算术强度"和"实际算术强度"接近。

10.4 GEMM 优化的三大武器

要让 GEMM 跑到算力峰值,需要三层 tile:

flowchart TB
  HBM[A, B 在 HBM] -->|每次 K_tile 列| GTile[Block-level tile]
  GTile -->|放进 SMEM| SMEM[SMEM tile A, B]
  SMEM -->|每次 mma_k 列| WTile[Warp-level tile]
  WTile -->|放进寄存器| REG[Register fragment]
  REG -->|喂给 Tensor Core| TC[mma.sync 一次<br/>16×8×16 矩阵乘]
  TC -->|累加| ACC[寄存器中的 C accumulator]

三层 tile:

  1. Block-level tile:一个 Block 处理 C 的一个 M_tile × N_tile 子块(典型 128×128)。一个 SM 上 active 的 block 数 = SMEM_total / (block_smem_size)。
  2. Warp-level tile:一个 warp 处理 block tile 内的一个 M_warp × N_warp 子块(典型 64×64)。一个 block 通常 8 个 warp。
  3. Tensor Core mma:一个 mma.sync 操作 16×8×16(Hopper 上是 64×N×16 with WGMMA)。

这三层 tile 配合 double buffering(异步拷贝下一个 K tile 时算当前),构成现代 GEMM 的核心结构。

10.4.1 数据复用计算

理想情况下:

这种层层复用让每个从 HBM 读上来的字节被多次使用,有效算术强度 接近理论上限。

10.4.2 Tensor Core 的角色

第 12 章会详细讲,这里先建立直觉:

这就是 H100 FP16 Tensor Core 峰值的来路。要跑满这个峰值,每个 SM 每周期必须发出一条 mma 指令——这要求所有数据准备好、SMEM/寄存器完美配合。任何一个环节 stall(cache miss、bank conflict、寄存器 spill),峰值就打不到。

10.5 优化路径预告

后三章我们会一步步从朴素 GEMM 推进到 cuBLAS 性能:

每一步都对应一组真实的工程技巧,每一组技巧都把性能往上推一截。这条优化路径同时也是后续 FA2 优化的基础——FA2 内部的 QK^T 和 PV 矩阵乘,都基于这套 GEMM 框架。

10.6 一个朴素优化:Block 内复用

在进入第 11 章之前,先做一个小优化作为"warm-up"——让 thread 算多个 C 元素

template <int TM = 4, int TN = 4>  // 每线程算 TM × TN 个 C 元素
__global__ void gemm_thread_tile(
    const float* A, const float* B, float* C,
    int M, int N, int K
) {
    int row = (blockIdx.y * blockDim.y + threadIdx.y) * TM;
    int col = (blockIdx.x * blockDim.x + threadIdx.x) * TN;

    float c[TM][TN] = {0};

    for (int k = 0; k < K; ++k) {
        float a[TM], b[TN];
        #pragma unroll
        for (int i = 0; i < TM; ++i) a[i] = A[(row + i) * K + k];
        #pragma unroll
        for (int j = 0; j < TN; ++j) b[j] = B[k * N + (col + j)];
        #pragma unroll
        for (int i = 0; i < TM; ++i)
            #pragma unroll
            for (int j = 0; j < TN; ++j)
                c[i][j] += a[i] * b[j];
    }

    #pragma unroll
    for (int i = 0; i < TM; ++i)
        #pragma unroll
        for (int j = 0; j < TN; ++j)
            C[(row + i) * N + (col + j)] = c[i][j];
}

每线程算 4×4=16 个 C 元素。读取 A 的 4 个元素 + B 的 4 个元素 = 8 次访存,做 16 次 mul-add。算术强度变成 32 / 32 = 1 FLOPs/byte——比朴素的 0.25 提升 4 倍。

实测性能:朴素 GEMM 700 GFLOPs → thread tile 4×4 ≈ 2800 GFLOPs(4× 提升)。

但这远没到目标。下一步要做的是让多个 thread 共享 A/B——这就是 Block tile + SMEM 的舞台。

10.7 这一章的小结与下一章

这一章我们建立了 GEMM 优化的"地图":

  1. GEMM 是 LLM 算力消耗的 95%+:所有优化的最大投资点。
  2. 朴素 GEMM 只跑 1% 峰值:不是算力或带宽不够,而是 cache 完全失效。
  3. 理论算术强度高(~683),实际算术强度低(~1):必须靠 SMEM + 寄存器层层复用提升实际 AI。
  4. 三层 tile 是 GEMM 优化的核心结构:Block tile / Warp tile / Tensor Core mma。
  5. 简单的 thread tile 就能 4× 提升:但还远没到峰值,需要 SMEM 和 Tensor Core。

第 11 章我们正式进入 SMEM 优化的世界——把 Block tile 放进 SMEM,配合 double buffering 让 SMEM 和 HBM 流水起来。这个版本(不用 Tensor Core)能达到 ~30-40% 算力峰值,是 SIMT GEMM 的合理目标。第 12 章再加上 Tensor Core 后,性能会一举跃升到 ~80%。

本章动手练习

  1. 实现朴素 GEMM 和 thread tile 4×4 GEMM,对比性能。
  2. 用 Nsight Compute 看朴素 GEMM 的 dram__sectors_read.sum,估算实际 HBM 流量,验证"实际算术强度 ~1"的判断。
  3. 思考:为什么 thread tile 要选 4×4 而不是 8×8?(提示:寄存器压力)