CUDA 算子工程:手写 FlashAttention v2 之路
第 16 章 FA2 反向:dQ/dK/dV 的重计算
第 16 章 FA2 反向:dQ/dK/dV 的重计算
"Backward propagation is forward propagation in reverse — except every detail of the data layout that worked forward, breaks backward." ——做过 attention 反向的人共同的心声
16.1 反向的数学:从 dO 到 dQ/dK/dV
设 attention 前向:
给定 ,我们要算:
链式法则一步步:
dV 最直接:
dP:
dS(softmax 的反向): 设 ,则 。所以:
定义 ,则:
dQ, dK:
总结,反向需要算的量:
1. D = rowsum(P ⊙ dP) = rowsum(P ⊙ (dO @ V^T))
2. dV = P^T @ dO
3. dS = P ⊙ (dP - D)
4. dK = dS^T @ Q
5. dQ = dS @ K
注意 这一项——它是 按行求和。论文里有个简化技巧:
更简洁:。这避免了显式算 P。
16.2 反向的循环顺序:外层 K,内层 Q
前向时我们用"外层 Q,内层 K" —— 这样 O 和 m, l 状态可以在外层循环里增量更新,每个 Q_block 的 O 一次性算完。
但反向的 dV 和 dK 是按 K 维度累加的:
如果用"外层 Q 内层 K",每个 K_block 的 dV 会被多次部分更新——必须用 atomic 或多次 kernel launch。
反向应该用"外层 K,内层 Q":
for k_idx in range(0, N, Bc):
K_block = K[k_idx : k_idx + Bc]
V_block = V[k_idx : k_idx + Bc]
dV_acc = zeros(Bc, d)
dK_acc = zeros(Bc, d)
for q_idx in range(0, N, Br):
Q_block = Q[q_idx : q_idx + Br]
dO_block = dO[q_idx : q_idx + Br]
LSE_block = LSE[q_idx : q_idx + Br] # 来自前向
# 1) 重计算 S = Q @ K^T
S = Q_block @ K_block.T
# 2) 重计算 P (用前向存的 LSE)
P = exp(S - LSE_block.unsqueeze(1))
# 3) D = rowsum(O ⊙ dO) -- 用前向的 O
D_block = sum_per_row(O[q_idx:q_idx+Br] * dO_block)
# 4) dV 累加
dV_acc += P.T @ dO_block
# 5) dP, dS
dP = dO_block @ V_block.T
dS = P * (dP - D_block.unsqueeze(1))
# 6) dQ -- 这一步需要 atomic 累加 (跨 k_idx 多次更新)
dQ_partial = dS @ K_block
atomic_add(dQ[q_idx : q_idx + Br], dQ_partial)
# 7) dK 累加
dK_acc += dS.T @ Q_block
dV[k_idx : k_idx + Bc] = dV_acc
dK[k_idx : k_idx + Bc] = dK_acc
这种循环方式的特点:
- dV 和 dK 在外层 K 循环中累加,每个 K_block 处理完一次性写出,不需要 atomic。
- dQ 在内层 Q 循环中累加,但不同 k_idx 的 dQ 部分会落在同一个 q_idx 上——需要 atomic add 或者用第二个 kernel 汇总。
dQ 的处理是 FA2 反向的一个工程难点。
16.3 dQ 的处理:atomic vs 二阶段
dQ 累加的两种工程方案:
16.3.1 方案 A:atomic add 直接写
atomicAdd(dQ + q_idx * d + col, dQ_partial);
简单但慢——HBM atomic 几百 cycle,N×N 量级的 atomic 加起来很多。
实际工业实现用 atomic 时会做几个优化:
- Atomic 在 SMEM 上,最后一次性写 HBM:每个 k_block 的 q 部分只写一次 SMEM atomic,最后写 HBM。但这要求所有 k_block 都在同一个 thread block 内——不现实,因为外层循环遍历所有 K。
- 半精度 atomic:HBM 上 fp16 atomic 比 fp32 atomic 快。但 fp16 atomic 在 H100 之前不被原生支持。
16.3.2 方案 B:两阶段 kernel
第一个 kernel 算 dV 和 dK(外层 K 循环结构),把 dS 写到一个临时缓冲。
第二个 kernel 把 dQ = dS @ K 算出来(不同 K_block 的 dS 已经分别 stage)。
但这要求中间存储 dS——而 dS 又是 N×N 的 attention 矩阵,就是 FA 想避免的东西!
FA2 论文最终选择方案 A(atomic),因为:
- N×N 的 atomic 总量虽大,但每次 atomic 只是 8-16 个 fp16,颗粒度小。
- 并发竞争实际不严重(不同 k_idx 不同时刻访问同一 q)。
- 对比中间矩阵存储节省的带宽,atomic 开销可以接受。
实测 FA2 反向比朴素反向(先算 N×N 中间矩阵,再求梯度)快 3-5x。
16.4 LSE 在反向中的角色
注意 16.2 节的伪代码中第 2 步 P = exp(S - LSE_block.unsqueeze(1))。
这里 LSE 是前向输出的副产物:。
为什么需要 LSE?因为反向重计算 P 需要前向算过的归一化常数。如果不存 LSE,反向时要重新 online softmax 一遍——多一次完整的 K 维度遍历。
存 LSE 的代价:N 个 fp32 = 16 KB(N=4096 时),可以忽略。
16.5 反向的 SMEM 与寄存器布局
反向比前向需要更多状态:
- Q_block, K_block, V_block:和前向一样
- dO_block:新增 (Br, d) 大小
- LSE_block:新增 (Br) 大小
- O_block:新增 (Br, d)(用来算 D = rowsum(O ⊙ dO))
- dV_acc, dK_acc:(Bc, d) 累加器,保留在寄存器
SMEM 占用更紧,dV_acc/dK_acc 在寄存器中。寄存器压力比前向更大。
16.6 反向 Kernel 骨架
template <int Br, int Bc, int d>
__global__ void flash_attn_bwd(
const half* Q, const half* K, const half* V, const half* O, const half* dO,
const float* LSE,
half* dQ, half* dK, half* dV,
int N, float softmax_scale, bool is_causal
) {
// Grid: (N / Bc, H, B)
int k_tile_idx = blockIdx.x;
int head_idx = blockIdx.y;
int batch_idx = blockIdx.z;
extern __shared__ half smem[];
half* sK = smem;
half* sV = sK + Bc * d;
half* sQ = sV + Bc * d;
half* sO = sQ + Br * d;
half* sdO = sO + Br * d;
float* sLSE = (float*)(sdO + Br * d);
// Load K, V tile (一次, 整个 inner 循环用)
load_tile(sK, K + k_tile_idx * Bc * d, Bc * d);
load_tile(sV, V + k_tile_idx * Bc * d, Bc * d);
__syncthreads();
// 寄存器中的 dV, dK 累加器
float dV_acc[MMAS_K][MMAS_D][4] = {0};
float dK_acc[MMAS_K][MMAS_D][4] = {0};
// 内层: 遍历所有 Q tile
int q_start = is_causal ? k_tile_idx * Bc : 0;
for (int q_tile = q_start; q_tile < N; q_tile += Br) {
// Load Q, O, dO, LSE
load_tile(sQ, Q + q_tile * d, Br * d);
load_tile(sO, O + q_tile * d, Br * d);
load_tile(sdO, dO + q_tile * d, Br * d);
load_lse(sLSE, LSE + q_tile, Br);
__syncthreads();
// 1) S = Q @ K^T (与前向一样)
float S_frag[MMAS_M][MMAS_N][4];
compute_QKt(S_frag, sQ, sK);
scale_and_mask(S_frag, softmax_scale, q_tile, k_tile_idx, is_causal);
// 2) P = exp(S - LSE) -- 比 online softmax 简单, 因为 LSE 已知
float P_frag[MMAS_M][MMAS_N][4];
for (int i = 0; i < MMAS_M; ++i) {
for (int j = 0; j < MMAS_N; ++j) {
for (int e = 0; e < 4; ++e) {
int row = ...; // lane 持有的行
P_frag[i][j][e] = expf(S_frag[i][j][e] - sLSE[row]);
}
}
}
// 3) D = rowsum(O ⊙ dO)
float D_local[MMAS_M][2] = {0};
compute_D(D_local, sO, sdO); // load O, dO 的 fragment 并按行 reduce
// 4) dV_acc += P^T @ dO
// 注意 P 是 fp16, dO 是 fp16, dV_acc 是 fp32
unsigned P_fp16[MMAS_M][MMAS_N];
cast_fp32_to_fp16(P_fp16, P_frag);
unsigned dO_frag[MMAS_M][MMAS_D];
load_dO_fragment(dO_frag, sdO);
// 注意是 P^T, 需要 ldmatrix.trans 或者交换 fragment 角色
accumulate_matmul_T(dV_acc, P_fp16, dO_frag);
// 5) dP = dO @ V^T
float dP_frag[MMAS_M][MMAS_N][4];
compute_dOVt(dP_frag, sdO, sV);
// 6) dS = P ⊙ (dP - D)
for (int i = 0; i < MMAS_M; ++i) {
for (int j = 0; j < MMAS_N; ++j) {
for (int e = 0; e < 4; ++e) {
int row_local = e / 2;
dP_frag[i][j][e]
= P_frag[i][j][e] * (dP_frag[i][j][e] - D_local[i][row_local]);
}
}
}
// 现在 dP_frag 实际是 dS_frag
// 7) dQ_partial = dS @ K, atomic add to dQ
unsigned dS_fp16[MMAS_M][MMAS_N];
cast_fp32_to_fp16(dS_fp16, dP_frag);
unsigned K_frag[MMAS_N][MMAS_D];
load_K_fragment(K_frag, sK);
float dQ_partial[MMAS_M][MMAS_D][4] = {0};
accumulate_matmul(dQ_partial, dS_fp16, K_frag);
// atomic add to global dQ
atomic_write_dQ(dQ + q_tile * d, dQ_partial);
// 8) dK_acc += dS^T @ Q
unsigned Q_frag[MMAS_M][MMAS_D];
load_Q_fragment(Q_frag, sQ);
accumulate_matmul_T(dK_acc, dS_fp16, Q_frag);
__syncthreads();
}
// 写 dV, dK
write_dV(dV + k_tile_idx * Bc * d, dV_acc);
write_dK(dK + k_tile_idx * Bc * d, dK_acc);
}
完整实现请参考 flash-attention/csrc/flash_attn/flash_bwd_kernel.h。
16.7 反向的性能与挑战
反向比前向慢的原因:
- 三个 GEMM vs 前向的两个:dV、dK、dQ 都是独立的 GEMM 累加。
- dQ 的 atomic write:每次 inner Q 循环一次 atomic,N 大时累计严重。
- 更多 SMEM 输入:Q, K, V, O, dO, LSE 全部要在 SMEM 中。
- 重计算开销:S 和 P 在反向时要重算(这是 FA 的设计选择,避免存中间矩阵)。
实测 H100 上 FA2 反向(N=4096)大约 250-300 TFLOPs,是前向的一半左右。这个 ratio(反向 / 前向 ~= 0.5)和大多数算子的反向慢类似。
16.8 工业级实现的关键技巧
FlashAttention 官方实现(Tri Dao 的 GitHub 仓库)做了大量工程优化:
- Block-level 同步原语:使用
mbarrier而不是__syncthreads,减少同步开销。 - 动态调度 Q tile 顺序:causal mask 时,跳过 q < k 的 tile,加速明显。
- SwiGLU 融合:把 LN 和 attention 输出连起来 fuse 一些 element-wise op。
- LSE 精度选择:fp32 LSE 还是 fp16 LSE?后者更省 SMEM,但精度影响。
- dQ atomic 精度选择:fp32 atomic 慢但准,fp16 atomic 快但有精度损失。
每一个技巧都需要在精度和性能间权衡。FA2 论文里给出了详细的 ablation。
16.9 这一章的小结与下一章
FA2 反向的核心要点:
- 数学:dV/dK/dQ 都涉及重计算 S 和 P——这是用计算换存储的核心 trade-off。
- 循环顺序翻转:外层 K,内层 Q。dV/dK 在外层累加,dQ 必须 atomic。
- LSE 是前向到反向的桥梁:让反向不需要重新 online softmax。
- D = rowsum(O ⊙ dO) 是简化 dS 公式的关键。
- 反向比前向慢约 2x,但相比朴素反向快 3-5x。
到第 16 章为止,读者已经能写出一个能用、性能合理的 FA2(forward + backward)。但还没用到 Hopper 的杀手锏——TMA 和 Warp Specialization。
第 17 章我们把 FA2 重写到 Hopper 的最优形态——TMA 异步拷贝代替 cp.async、WGMMA 代替 mma.sync、Producer/Consumer warp 流水线。读完第 17 章读者会理解为什么 FA3 在 H100 上能跑到 740 TFLOPs(FP16),以及"现代 GPU kernel"和"Volta 时代的 GPU kernel"在编程模型上的根本差异。
本章动手练习:
- 推导 dS = P ⊙ (dP - D) 的代数过程。验证 D 的两种等价定义。
- 实现 16.6 节骨架,在 H100 上测试反向性能。预期 ~250 TFLOPs。
- 思考:如果不存 LSE,反向时怎么重新算?需要多少额外 HBM 流量?