CUDA 算子工程:手写 FlashAttention v2 之路
第 15 章 FA2 前向:Tiling 与 Online Softmax
第 15 章 FA2 前向:Tiling 与 Online Softmax
"FA2 is to FA1 what a sequel often is to its predecessor — same plot, better cinematography. The math didn't change; the parallelism did." ——一种常见的 FA2 评价
15.1 FA2 vs FA1:parallelism 的重新分配
FlashAttention v2 论文(Dao, 2023)相比 v1 的核心改动不是新算法,而是新的并行划分。
FA1 的并行结构:
- 一个 thread block 处理一个 (Q_block) × (所有 K_block) 的循环。
- block 之间在不同 query 行上并行。
FA1 的问题:在 forward 时 Q_block 数 << K_block 数(典型 N=4096 时 Q_block=64 个,K_block=64 个,但 forward 在 query 维度切分),导致 GPU SM 数远多于 block 数,occupancy 严重不足。
FA2 的修正:
- 把 K_block 也作为并行轴:(Q_block) × (K_block 维度的 reduce) 仍由一个 block 处理,但 multiple Q heads × multiple Q blocks 组成 grid。
- 把 warp-level 的工作划分从"每 warp 算一行"改成"每 warp 算一组连续行"——更友好的 fragment 复用。
实际效果:在 A100 上 FA2 比 FA1 快 1.7-2x,主要来自更好的 occupancy 和 warp-level 并行。
15.2 Tile 大小的选择
FA2 forward 的 tile 大小有几个关键参数:
Br : Q block 的行数 (典型 64 或 128)
Bc : K/V block 的列数 (典型 64)
d : head dim (典型 64, 128, 256)
约束:
- SMEM 容量:要放下 Q_block (Br × d) + K_block (Bc × d) + V_block (Bc × d)。FP16 时 = 2 × (Br + 2 × Bc) × d 字节。
- 寄存器:每个 thread 要存 (Br/N_warps × d/8) 个 fp32 累加器。
- 算术强度:tile 越大,每次 K_tile 的 GEMM 越大,Tensor Core 利用率越好。
典型配置(H100, FP16, d=64):
| Br | Bc | SMEM 占用 | Block 内 warp 数 | 适用 |
|---|---|---|---|---|
| 64 | 64 | ~24 KB | 4 (128 thread) | 短序列 (N≤2048) |
| 128 | 64 | ~32 KB | 4 | 中序列 (N=4096) |
| 128 | 128 | ~64 KB | 8 (256 thread) | 长序列 (N=8192+) |
不同 head dim 也要调整。d=128 时所有 SMEM 翻倍,需要降低 Bc 维持 SMEM 预算。
15.3 完整的 FA2 Forward Kernel 骨架
下面是一个工作的 FA2 forward kernel。简化了一些边界处理和 swizzle 细节,但结构反映工业级实现:
template <int Br, int Bc, int d, int N_WARPS = 4>
__global__ void flash_attn_fwd(
const half* __restrict__ Q, // [B, H, N, d]
const half* __restrict__ K, // [B, H, N, d]
const half* __restrict__ V, // [B, H, N, d]
half* __restrict__ O, // [B, H, N, d]
float* __restrict__ LSE, // [B, H, N], for backward
int N,
float softmax_scale, // 1 / sqrt(d)
bool is_causal
) {
// Grid: (N / Br, H, B)
// Block: 128 threads (4 warps)
int q_tile_idx = blockIdx.x;
int head_idx = blockIdx.y;
int batch_idx = blockIdx.z;
int tid = threadIdx.x;
int warp_id = tid / 32;
int lane_id = tid % 32;
// ============ SMEM allocation ============
extern __shared__ half smem[];
half* sQ = smem; // [Br, d]
half* sK = sQ + Br * d; // [Bc, d]
half* sV = sK + Bc * d; // [Bc, d]
// ============ Pointer offset for this (B, H) ============
int q_offset = ((batch_idx * H + head_idx) * N + q_tile_idx * Br) * d;
int kv_offset = (batch_idx * H + head_idx) * N * d;
int o_offset = q_offset;
int lse_offset = (batch_idx * H + head_idx) * N + q_tile_idx * Br;
// ============ Load Q tile to SMEM (one-time) ============
cp_async_load_q_tile(sQ, Q + q_offset);
cp_async_commit_and_wait();
__syncthreads();
// ============ Output accumulator (in registers) ============
// Each warp handles Br/N_WARPS rows of Q.
// Each warp keeps Br_warp × d output accumulator in fp32 fragments.
constexpr int Br_warp = Br / N_WARPS;
constexpr int MMAS_M = Br_warp / 16; // # of mma.m16n8k16 in M
constexpr int MMAS_D = d / 8; // # of mma in N (output dim)
float O_acc[MMAS_M][MMAS_D][4] = {0}; // 累加 fragment
float row_max[MMAS_M][2] = {-INFINITY}; // 每 warp 持有 16 行 × 2 = 32 行的 max
// 实际每 mma 16 行, lane 持 2 行
float row_sum[MMAS_M][2] = {0};
// ============ Loop over K tiles ============
int k_tile_end = is_causal
? min(N, (q_tile_idx + 1) * Br)
: N;
for (int k_tile = 0; k_tile < k_tile_end; k_tile += Bc) {
// ---------- Load K, V tile to SMEM ----------
cp_async_load_kv_tile(sK, K + kv_offset + k_tile * d);
cp_async_load_kv_tile(sV, V + kv_offset + k_tile * d);
cp_async_commit_and_wait();
__syncthreads();
// ---------- 1) S_block = Q @ K^T (using mma.sync) ----------
// S_block is [Br_warp, Bc] in fragments (per warp).
constexpr int MMAS_N = Bc / 8;
float S_frag[MMAS_M][MMAS_N][4] = {0};
for (int kk = 0; kk < d; kk += 16) {
// ldmatrix Q fragments
unsigned q_frag[MMAS_M][4];
for (int i = 0; i < MMAS_M; ++i) {
int row = warp_id * Br_warp + i * 16;
ldmatrix_x4(sQ, row, kk, q_frag[i]);
}
// ldmatrix K fragments (with .trans for K^T)
unsigned k_frag[MMAS_N][2];
for (int j = 0; j < MMAS_N; ++j) {
int col = j * 8;
ldmatrix_x2_trans(sK, col, kk, k_frag[j]);
}
// mma accumulate into S_frag
for (int i = 0; i < MMAS_M; ++i)
for (int j = 0; j < MMAS_N; ++j)
mma_m16n8k16(S_frag[i][j], q_frag[i], k_frag[j]);
}
// ---------- 2) Apply softmax_scale and causal mask ----------
for (int i = 0; i < MMAS_M; ++i) {
for (int j = 0; j < MMAS_N; ++j) {
#pragma unroll
for (int e = 0; e < 4; ++e) {
S_frag[i][j][e] *= softmax_scale;
// Causal mask: lane_id 0/4/8/12... 持有不同列, 需精确计算
int my_row = q_tile_idx * Br + warp_id * Br_warp
+ i * 16 + (lane_id / 4) + (e / 2) * 8;
int my_col = k_tile + j * 8 + (lane_id % 4) * 2 + (e % 2);
if (is_causal && my_col > my_row) {
S_frag[i][j][e] = -INFINITY;
}
}
}
}
// ---------- 3) Online softmax: update row_max, row_sum ----------
// Each warp computes max/sum across columns within the tile.
// S_frag layout: lane (l/4, l%4) holds (rows [l/4, l/4+8], cols [2*(l%4), 2*(l%4)+1]).
// We need: for each row, find max and sum across all Bc columns.
for (int i = 0; i < MMAS_M; ++i) {
// 每 mma 块持有 16 行, 每 lane 拥有 2 行 (l/4 and l/4+8)
// 拿到 row 内所有 col 的 max
for (int row_local = 0; row_local < 2; ++row_local) {
float m_block = -INFINITY;
for (int j = 0; j < MMAS_N; ++j) {
for (int col_local = 0; col_local < 2; ++col_local) {
m_block = fmaxf(m_block,
S_frag[i][j][row_local * 2 + col_local]);
}
}
// Warp 内 reduce: 每行的 max 分布在 4 个 lane 上 (相同 lane_id/4)
// 用 shfl_xor 在 4 个 lane 之间归约
m_block = fmaxf(m_block, __shfl_xor_sync(0xFFFFFFFF, m_block, 1));
m_block = fmaxf(m_block, __shfl_xor_sync(0xFFFFFFFF, m_block, 2));
float m_old = row_max[i][row_local];
float m_new = fmaxf(m_old, m_block);
float alpha = expf(m_old - m_new);
// 更新 P_frag (in-place 改写 S_frag, 同时算 row sum 增量)
float l_inc = 0.0f;
for (int j = 0; j < MMAS_N; ++j) {
for (int col_local = 0; col_local < 2; ++col_local) {
float p = expf(
S_frag[i][j][row_local * 2 + col_local] - m_new);
S_frag[i][j][row_local * 2 + col_local] = p;
l_inc += p;
}
}
// Warp 归约 l_inc
l_inc += __shfl_xor_sync(0xFFFFFFFF, l_inc, 1);
l_inc += __shfl_xor_sync(0xFFFFFFFF, l_inc, 2);
row_sum[i][row_local] = row_sum[i][row_local] * alpha + l_inc;
row_max[i][row_local] = m_new;
// ---------- 4) 缩放之前累积的 O_acc ----------
for (int j_d = 0; j_d < MMAS_D; ++j_d) {
for (int e = 0; e < 4; ++e) {
// 只有 row_local 对应的元素需要 alpha 缩放
if ((e / 2) == row_local) {
O_acc[i][j_d][e] *= alpha;
}
}
}
}
}
// ---------- 5) O_acc += P @ V ----------
// P 是 fp32 fragment, 要先 cast 回 fp16 给 mma 用
// 实际上 FA2 选择 fp16 P × fp16 V, fp32 累加
// P_frag (fp16) layout: [Br_warp, Bc] = MMAS_M × MMAS_N × 4
unsigned P_frag_fp16[MMAS_M][MMAS_N];
for (int i = 0; i < MMAS_M; ++i) {
for (int j = 0; j < MMAS_N; ++j) {
// 把 4 个 fp32 打包成 4 个 fp16 (保留高 16 位结构)
__half2 p01, p23;
p01.x = __float2half(S_frag[i][j][0]);
p01.y = __float2half(S_frag[i][j][1]);
p23.x = __float2half(S_frag[i][j][2]);
p23.y = __float2half(S_frag[i][j][3]);
P_frag_fp16[i][j * 2 + 0] = *reinterpret_cast<unsigned*>(&p01);
// ...
}
}
// mma: O_acc[i][d_j] += P_frag[i][j] @ V_frag[j][d_j]
for (int j = 0; j < MMAS_N; j += 2) { // P fragment 一组 16 列
unsigned v_frag[MMAS_D][2];
for (int d_j = 0; d_j < MMAS_D; ++d_j) {
int v_row = j * 8;
int v_col = d_j * 8;
ldmatrix_x2(sV, v_row, v_col, v_frag[d_j]);
}
for (int i = 0; i < MMAS_M; ++i) {
for (int d_j = 0; d_j < MMAS_D; ++d_j) {
unsigned p_input[4] = {
P_frag_fp16[i][j], P_frag_fp16[i][j + 1],
P_frag_fp16[i][j], P_frag_fp16[i][j + 1]
};
mma_m16n8k16(O_acc[i][d_j], p_input, v_frag[d_j]);
}
}
}
__syncthreads();
}
// ============ Final normalization: O = O_acc / row_sum ============
for (int i = 0; i < MMAS_M; ++i) {
for (int row_local = 0; row_local < 2; ++row_local) {
float scale = 1.0f / row_sum[i][row_local];
for (int d_j = 0; d_j < MMAS_D; ++d_j) {
for (int e = 0; e < 4; ++e) {
if ((e / 2) == row_local) {
O_acc[i][d_j][e] *= scale;
}
}
}
}
}
// ============ Write O to HBM ============
// (omit detailed lane->location mapping; conceptually each lane writes its 8 values)
for (int i = 0; i < MMAS_M; ++i) {
for (int d_j = 0; d_j < MMAS_D; ++d_j) {
int row = q_tile_idx * Br + warp_id * Br_warp + i * 16 + (lane_id / 4);
int col = d_j * 8 + (lane_id % 4) * 2;
half2 v;
v.x = __float2half(O_acc[i][d_j][0]);
v.y = __float2half(O_acc[i][d_j][1]);
*reinterpret_cast<half2*>(&O[o_offset + row * d + col]) = v;
// ... 写其他元素
}
}
// ============ Write LSE for backward ============
if (warp_id == 0) {
for (int i = 0; i < MMAS_M; ++i) {
for (int row_local = 0; row_local < 2; ++row_local) {
int row = q_tile_idx * Br + i * 16 + (lane_id / 4) + row_local * 8;
if (lane_id % 4 == 0 && row < N) {
LSE[lse_offset + row]
= row_max[i][row_local] + logf(row_sum[i][row_local]);
}
}
}
}
}
注意:上面是简化骨架,省略了 swizzle layout、cp.async pipeline depth、edge case 边界处理。完整实现请参考
flash-attention/csrc/flash_attn/flash_fwd_kernel.h。
15.4 关键实现要点
15.4.1 Online Softmax 的 fragment 级实现
第 6 章讲 online softmax 时是数学层面的。落到 fragment 级实现,关键挑战是 mma 的 fragment layout 让"per-row"操作变复杂——每个 lane 持有的不是连续的"一行",而是分散的"两行的几列"。
具体做法:
- 每行 max:每 lane 在自己持有的列上算 local max,然后用
shfl_xor在持有同一行的 4 个 lane 之间归约。 - 每行 sum:同上,归约 sum。
- 同步更新 m, l, O:所有更新都基于 fragment 内的本地数据。
15.4.2 P 矩阵的 fp16 / fp32 切换
S_frag 是 fp32 累加(mma 输出),但 P @ V 的 mma 输入要求 fp16。所以中间需要把 fp32 P_frag 转回 fp16,再 ldmatrix 重新 layout。
这个转换是 FA2 的细节难点。直接转可能引入精度损失(fp32 → fp16 一次往返),且 layout 重排有 SMEM bank conflict 风险。FA2 的实现做了很多工程化细节避免这些。
15.4.3 Causal Mask 的位置
Causal mask 必须在 softmax 之前 apply(对 -inf 取 exp 是 0,不影响 sum)。在 fragment 级,每个 lane 知道自己持有哪些 (row, col),直接对越界元素写 -inf。
15.4.4 Swizzle Layout
Q/K/V 在 SMEM 里的布局必须是 swizzled,否则 ldmatrix 触发 bank conflict。FA2 用 CUTLASS 的 Swizzle<3,3,3> 标准布局。第 13 章讲过的概念在这里直接落地。
15.5 性能调优要点
15.5.1 K 维度的 cp.async pipeline
外层 K 循环里,加载下一个 K_block + V_block 应该和当前 block 的计算重叠:
// 启动 K_tile 0 的加载
cp_async_load(sK[0], sV[0], 0);
cp_async_commit();
for (int k_tile = 0; k_tile < N; k_tile += Bc) {
int next = (k_tile / Bc + 1) % 2;
int cur = (k_tile / Bc) % 2;
// 启动下一个 K_tile 加载
if (k_tile + Bc < N) {
cp_async_load(sK[next], sV[next], k_tile + Bc);
cp_async_commit();
}
// 等当前 K_tile 加载完成
cp_async_wait();
__syncthreads();
// 算当前 K_tile
compute(sK[cur], sV[cur], &O_acc, ...);
}
这样 SM 的 Tensor Core 一直在算,HBM 加载在后台进行。
15.5.2 寄存器压力管理
FA2 forward 的寄存器需求很大:
- 输出累加器:MMAS_M × MMAS_D × 4 个 fp32 = ~64 个寄存器/线程
- P fragment:~16 个寄存器
- KV fragment:~32 个寄存器
- 临时 m, l, alpha:~10 个
合计 ~120 个/线程。如果超过 128,编译器开始 spill 到 local memory(HBM),性能塌陷。
减少压力的技巧:
- 把不常用的中间值 cast 到更低精度(比如 m, l 用 fp16)。
- 让编译器更激进地复用寄存器(用
volatile或__launch_bounds__提示)。 - 减少 unroll 程度(牺牲 ILP 换寄存器)。
15.5.3 Block Size 与 Occupancy
128 thread/block × 4 active block/SM = 512 active threads/SM。这只是 SM 上限 2048 的 1/4——occupancy 约 25%。
但 attention 的延迟隐藏不依赖 occupancy——它的瓶颈是 Tensor Core 算力,不是延迟。低 occupancy 的高 ILP kernel 反而比高 occupancy 的低 ILP kernel 快——这是 FA 的反直觉特点。
15.6 性能实测
H100 上 FA2 forward 的典型性能(FP16, head_dim=64):
N=512: ~250 TFLOPs (long-context decoding)
N=2048: ~450 TFLOPs (typical training)
N=4096: ~530 TFLOPs (long sequence training)
N=8192: ~580 TFLOPs (long context)
N=16384: ~600 TFLOPs (very long context)
对比 cuBLAS HGEMM 的 ~750 TFLOPs,FA2 的算力利用率约 70-77%。
从手写 kernel 到能达到 70%+ 性能,这个 gap 主要差在:
- CUTLASS 的细致 fragment 调度
- Hopper 的 TMA + WGMMA(FA3 的内容,第 17 章)
- PTX 微优化(指令排布、依赖距离)
15.7 这一章的小结与下一章
第 15 章我们写出了一个能工作的 FA2 forward kernel:
- FA2 vs FA1 的关键差别是 parallelism 划分:FA2 引入更多 grid-level 并行 + 更友好的 warp-level 工作分配。
- 完整的 FA2 forward kernel 由 5 个阶段组成:load Q tile → loop K tile (load → S=QK^T → online softmax → O+=PV) → final normalize → write O/LSE。
- online softmax 在 fragment 级实现需要小心处理 lane 持有的不连续 row/col 布局。
- cp.async pipeline 和寄存器压力管理是性能调优的关键。
- 手写 FA2 在 H100 上能达到 ~530 TFLOPs(约 cuBLAS 70%)——已经是不小的成就。
第 16 章我们处理 FA2 的反向。反向比前向复杂得多——需要重计算 S 和 P,需要原子写 dQ,且循环方向变成"外层 K,内层 Q"。读完第 16 章读者会拥有完整的 FA2 训练能力。
本章动手练习:
- 把上面的 FA2 forward 骨架完整实现(包括 swizzle layout 和正确的 lane→location 映射),在 H100 上测试 N=4096 的性能。
- 阅读
flash-attention/csrc/flash_attn/flash_fwd_kernel.h,对照 15.3 节的骨架找差异点(你会发现 swizzle、async、unroll 调度都更精细)。- 思考:如果 head_dim=128,tile 大小该怎么调?SMEM 还放得下吗?