CUDA 算子工程:手写 FlashAttention v2 之路

第 11 章 Tiled GEMM:Shared Memory 与 Double Buffer

作者 杨艺韬 · 2,165 字

第 11 章 Tiled GEMM:Shared Memory 与 Double Buffer

"Tiling is to GPU what divide-and-conquer is to algorithms — once you see it, you start seeing it everywhere." ——CUDA 教学传统

11.1 为什么需要 Tiling

第 10 章看到朴素 GEMM 只能跑到 1% 峰值,瓶颈是实际算术强度只有 ~1。要让算力发挥出来,必须把 A 和 B 的元素复用——同一个元素被读到 SMEM 后,被多个线程多次使用,再写回。

Tiling(瓦片化)就是这种复用的工程实现。基本结构是:

flowchart TB
  subgraph HBM_Layer [HBM 层]
    A[A 矩阵 M×K]
    B[B 矩阵 K×N]
    C[C 矩阵 M×N]
  end
  subgraph SMEM_Layer [SMEM 层 · per Block]
    SA[A_tile 128×Tk]
    SB[B_tile Tk×128]
  end
  subgraph Reg_Layer [Register 层 · per Warp]
    RA[A_frag 64×Tk_inner]
    RB[B_frag Tk_inner×64]
    RC[C_acc 64×64]
  end
  HBM_Layer -->|Block iteration<br/>每次 K 维移动 Tk| SMEM_Layer
  SMEM_Layer -->|Warp iteration<br/>每次 K 维移动 Tk_inner| Reg_Layer
  Reg_Layer -->|FMA 累加| Reg_Layer

每个 Block 处理 C 的一个 M_block × N_block 子块(典型 128×128),扫过 K 维度时不断从 HBM 加载新的 A_tile / B_tile 到 SMEM。每个 Warp 处理 block 内的一个 M_warp × N_warp 子块(典型 64×64),从 SMEM 读 fragment 到寄存器。最内层是寄存器中的 8×8 累加(每线程)。

11.2 Block Tile:SMEM 中的复用

定义 tile 大小:

constexpr int BM = 128;  // Block tile M 维
constexpr int BN = 128;  // Block tile N 维
constexpr int BK = 16;   // Block tile K 维

每个 Block 处理 BM × BN 的 C 子块。内层 K 维分块 BK,每次从 HBM 拉 BM × BK 的 A_tile 和 BK × BN 的 B_tile 到 SMEM。

SMEM 占用:

A_tile: 128 × 16 × 4 = 8192 bytes = 8 KB
B_tile: 16 × 128 × 4 = 8192 bytes = 8 KB
合计:                   16 KB

H100 单 SM SMEM 228 KB,能放 14 个 block 的 single-buffer(实际 occupancy 受寄存器限制更严,2-4 个 block)。

11.2.1 加载 A_tile / B_tile 到 SMEM

每个 Block 启动 BM × BN / (TM × TN) = 256 个线程(每线程算 8×8 = 64 个 C 元素)。256 线程协作加载 8 KB 的 A_tile 和 8 KB 的 B_tile。

__shared__ float sA[BM][BK];   // 128 × 16
__shared__ float sB[BK][BN];   // 16 × 128

const int tid = threadIdx.y * blockDim.x + threadIdx.x;
const int row = blockIdx.y * BM;
const int col = blockIdx.x * BN;

// A_tile 加载: 256 线程加载 128 × 16 = 2048 个元素
// 每线程加载 2048 / 256 = 8 个元素 (但 128 行 × 16 列, 安排成每 16 线程一行)
for (int load = 0; load < BM * BK / 256; load++) {
    int idx = tid + load * 256;
    int load_row = idx / BK;
    int load_col = idx % BK;
    sA[load_row][load_col] = A[(row + load_row) * K + (k_step + load_col)];
}

类似地加载 B_tile。

11.2.2 内层乘法:从 SMEM 到寄存器

每个 Warp(32 线程)处理 64 × 64 子块,每线程算 8 × 8 个 C 元素:

constexpr int WM = 64;   // Warp tile M
constexpr int WN = 64;   // Warp tile N
constexpr int TM = 8;    // Thread tile M
constexpr int TN = 8;    // Thread tile N

float c[TM][TN] = {0};   // 64 个寄存器

for (int kk = 0; kk < BK; ++kk) {
    float a[TM], b[TN];
    #pragma unroll
    for (int i = 0; i < TM; ++i) a[i] = sA[warp_m * WM + thread_m * TM + i][kk];
    #pragma unroll
    for (int j = 0; j < TN; ++j) b[j] = sB[kk][warp_n * WN + thread_n * TN + j];

    #pragma unroll
    for (int i = 0; i < TM; ++i)
        #pragma unroll
        for (int j = 0; j < TN; ++j)
            c[i][j] += a[i] * b[j];
}

每个线程从 SMEM 读 8 个 a + 8 个 b = 16 个浮点(64 字节),做 64 次 mul-add(128 FLOPs)。单线程算术强度 = 128 / 64 = 2 FLOPs/byte——比朴素 GEMM 提升 8×。

11.3 完整的 Tiled GEMM Kernel

把上面的 piece 组装起来:

template <int BM, int BN, int BK, int WM, int WN, int TM, int TN>
__global__ void gemm_tiled(
    const float* A, const float* B, float* C,
    int M, int N, int K
) {
    __shared__ float sA[BM][BK];
    __shared__ float sB[BK][BN];

    const int tid = threadIdx.y * blockDim.x + threadIdx.x;
    const int warp_id = tid / 32;
    const int lane_id = tid % 32;
    const int warp_m = warp_id / (BN / WN);
    const int warp_n = warp_id % (BN / WN);
    const int thread_m = lane_id / (WN / TN);
    const int thread_n = lane_id % (WN / TN);

    const int block_row = blockIdx.y * BM;
    const int block_col = blockIdx.x * BN;

    float c[TM][TN] = {0};

    // 沿 K 维度迭代
    for (int k_step = 0; k_step < K; k_step += BK) {
        // 1. 协作加载 A_tile, B_tile 到 SMEM
        #pragma unroll
        for (int i = tid; i < BM * BK; i += blockDim.x * blockDim.y) {
            int r = i / BK, c_ = i % BK;
            sA[r][c_] = A[(block_row + r) * K + (k_step + c_)];
        }
        #pragma unroll
        for (int i = tid; i < BK * BN; i += blockDim.x * blockDim.y) {
            int r = i / BN, c_ = i % BN;
            sB[r][c_] = B[(k_step + r) * N + (block_col + c_)];
        }
        __syncthreads();

        // 2. 每线程算 TM × TN 个累加
        for (int kk = 0; kk < BK; ++kk) {
            float a[TM], b[TN];
            #pragma unroll
            for (int i = 0; i < TM; ++i)
                a[i] = sA[warp_m * WM + thread_m * TM + i][kk];
            #pragma unroll
            for (int j = 0; j < TN; ++j)
                b[j] = sB[kk][warp_n * WN + thread_n * TN + j];

            #pragma unroll
            for (int i = 0; i < TM; ++i)
                #pragma unroll
                for (int j = 0; j < TN; ++j)
                    c[i][j] += a[i] * b[j];
        }
        __syncthreads();
    }

    // 3. 写 C
    #pragma unroll
    for (int i = 0; i < TM; ++i) {
        #pragma unroll
        for (int j = 0; j < TN; ++j) {
            int r = block_row + warp_m * WM + thread_m * TM + i;
            int c_ = block_col + warp_n * WN + thread_n * TN + j;
            if (r < M && c_ < N) C[r * N + c_] = c[i][j];
        }
    }
}

// Launch
dim3 block(16, 16);  // 256 线程
dim3 grid(N / BN, M / BM);
gemm_tiled<128, 128, 16, 64, 64, 8, 8><<<grid, block>>>(A, B, C, M, N, K);

实测:H100, M=N=K=4096, FP32:

朴素 GEMM:        ~700 GFLOPs    (1%)
Thread tile 4×4:  ~2800 GFLOPs   (4%)
Tiled GEMM:       ~12 TFLOPs     (18%)

12 TFLOPs,提升到 18%。但还远没到 SIMT 上限。

11.4 优化 1:Double Buffer + Async Copy

上面的代码有一个明显的"同步等待":每次 K 迭代完,要 __syncthreads() 等所有线程算完,才能加载下一个 K_tile。这段时间 SM 处于"等待加载"状态——计算单元闲置。

Double buffer(双缓冲)让计算和加载重叠:

__shared__ float sA[2][BM][BK];   // 两份缓冲
__shared__ float sB[2][BK][BN];

// 预加载第一个 buffer
load_to_smem(sA[0], sB[0], k_step=0);
__syncthreads();

for (int k_step = BK; k_step < K; k_step += BK) {
    int cur = ((k_step / BK) - 1) % 2;
    int next = (k_step / BK) % 2;

    // 异步加载下一个 buffer
    cp_async_global_to_shared(sA[next], A_addr_at(k_step));
    cp_async_global_to_shared(sB[next], B_addr_at(k_step));
    cp_async_commit();

    // 同时计算当前 buffer
    compute_on_smem(sA[cur], sB[cur], &c);

    cp_async_wait_all();
    __syncthreads();
}

// 算最后一个 buffer
compute_on_smem(sA[(K/BK - 1) % 2], sB[(K/BK - 1) % 2], &c);

这里用到了 Ampere+ 的 cp.async.cg.shared.global 指令——异步地从 HBM 拷贝到 SMEM,期间 SIMT cores 可以继续算。

实测带来 ~30% 提升:12 TFLOPs → 16 TFLOPs(24%)

11.5 优化 2:解决 SMEM Bank Conflict

sA[warp_m * WM + thread_m * TM + i][kk] 这一句,KK 固定时,TM=8 个连续行:

thread_m=0: sA[0][kk], sA[1][kk], ..., sA[7][kk]
            分别落到 bank (0, 16, 0, 16, 0, 16, 0, 16) (重复!)

每行 BK=16 个 float,跨行步长 16 个 float = 16 banks。所以 sA[i][kk] 和 sA[i+1][kk] 落在同一个 bank。一个 warp 内 32 线程读 8 个连续行 → 严重 bank conflict。

解决方法:+1 paddingswizzled 布局

11.5.1 +1 Padding

__shared__ float sA[BM][BK + 1];   // 17 列而不是 16

每行 17 列后,sA[i][kk] 和 sA[i+1][kk] 落到 bank (kk % 32, (kk + 17) % 32)——错开了。但每行多一列浪费 SMEM 6%。

11.5.2 Swizzled Layout

更高级的做法是按"对角线"放置元素:

__device__ __forceinline__ int swizzle(int row, int col) {
    return col ^ (row & 0x7);
}
sA[i][swizzle(i, kk)] = ...;

把行号和列号 XOR 一下,让相邻行的元素自动落到不同 bank。这是 CUTLASS 中的标准技巧,第 12 章会详细看。

实测引入 padding 后:16 TFLOPs → 21 TFLOPs(31%)

11.6 优化 3:寄存器 Blocking 与读取顺序

最内层的 mul-add 循环:

for (int kk = 0; kk < BK; ++kk) {
    for (int i = 0; i < TM; ++i) a[i] = sA[...][kk];
    for (int j = 0; j < TN; ++j) b[j] = sB[kk][...];
    for (int i = 0; i < TM; ++i)
        for (int j = 0; j < TN; ++j)
            c[i][j] += a[i] * b[j];
}

这段代码每次 kk 都要重新 load ab。如果 BK=16,那总共 16 次 load × (TM + TN) = 256 次 SMEM 访问。

更好的方式是外层 unroll kk 几步,复用 ab 寄存器

for (int kk = 0; kk < BK; kk += 4) {
    float a[4][TM], b[4][TN];
    #pragma unroll
    for (int u = 0; u < 4; ++u) {
        for (int i = 0; i < TM; ++i) a[u][i] = sA[...][kk + u];
        for (int j = 0; j < TN; ++j) b[u][j] = sB[kk + u][...];
    }
    #pragma unroll
    for (int u = 0; u < 4; ++u) {
        #pragma unroll
        for (int i = 0; i < TM; ++i)
            #pragma unroll
            for (int j = 0; j < TN; ++j)
                c[i][j] += a[u][i] * b[u][j];
    }
}

或者用更流行的"Outer Product 累加"方式。这些细节在 CUTLASS 里都有现成实现。

11.7 性能演进表

把所有优化加上:

版本 优化 TFLOPs (H100, FP32) % FP32 SIMT 峰值
v0 朴素 0.7 1%
v1 Thread tile 4×4 2.8 4%
v2 + Block tile + SMEM 12 18%
v3 + Double buffer (cp.async) 16 24%
v4 + Bank conflict fix 21 31%
v5 + 寄存器 blocking + unroll 25 37%
目标 SIMT 极限 (cuBLAS SGEMM) ~50 ~75%

到 v5 我们达到了 SIMT 写法的合理水平(37%),距离 cuBLAS 还有 2× 差距。剩下的 2× 差距,70% 来自 Tensor Core——SIMT 单核的浮点率根本打不过 Tensor Core,FP32 SIMT 峰值 67 TFLOPs vs FP16 Tensor Core 989 TFLOPs

这就是为什么第 12 章必须引入 Tensor Core——SIMT 路线已经到顶了

11.8 这一章的小结与下一章

Tiled GEMM 是 GEMM 优化的"地基":

  1. 三层 tile(block / warp / thread)让数据在 SMEM 和寄存器中层层复用,把实际算术强度从 1 提到 ~100+。
  2. Double buffer + cp.async 让 HBM 拷贝和计算重叠,隐藏 HBM 延迟。
  3. Bank conflict 处理(padding 或 swizzle)确保 SMEM 带宽。
  4. 寄存器 blocking 和 unroll 减少 SMEM 访问。
  5. SIMT 路径的极限是 ~37%——不是技巧不够,是 FP32 SIMT 算力本身就打不过 Tensor Core。

第 12 章我们引入 Tensor Core——mma.sync 指令、ldmatrix 指令、layout swizzle。这是 GEMM 性能再提升 2× 的关键。读完第 12 章,读者写出的 HGEMM kernel 能达到 cuBLAS 80%+ 的水平。

本章动手练习

  1. 把 v0..v5 都实现一遍,记录性能演进。
  2. 用 Nsight Compute 看 v2 vs v4 的 smsp__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum 指标,验证 padding 消除了 bank conflict。
  3. 思考:为什么 BM=BN=128 比 BM=BN=64 更优?(提示:复用率与 SMEM 占用)