CUDA 算子工程:手写 FlashAttention v2 之路

第 16 章 FA2 反向:dQ/dK/dV 的重计算

作者 杨艺韬 · 2,516 字

第 16 章 FA2 反向:dQ/dK/dV 的重计算

"Backward propagation is forward propagation in reverse — except every detail of the data layout that worked forward, breaks backward." ——做过 attention 反向的人共同的心声

16.1 反向的数学:从 dO 到 dQ/dK/dV

设 attention 前向:

S=QKT,P=softmax(S),O=PVS = QK^T, \quad P = \text{softmax}(S), \quad O = PV

给定 dO=L/OdO = \partial L / \partial O,我们要算:

dQ=L/Q,dK=L/K,dV=L/VdQ = \partial L / \partial Q, \quad dK = \partial L / \partial K, \quad dV = \partial L / \partial V

链式法则一步步:

dV 最直接

dV=PTdOdV = P^T \cdot dO

dP:

dP=dOVTdP = dO \cdot V^T

dS(softmax 的反向): 设 pi=softmax(si)p_i = \text{softmax}(s_i),则 pisj=pi(δijpj)\frac{\partial p_i}{\partial s_j} = p_i (\delta_{ij} - p_j)。所以:

dSij=Pij(dPijkPikdPik)dS_{ij} = P_{ij} \cdot (dP_{ij} - \sum_k P_{ik} dP_{ik})

定义 Di=kPikdPik=kPik(dOVT)ikD_i = \sum_k P_{ik} dP_{ik} = \sum_k P_{ik} (dO V^T)_{ik},则:

dSij=Pij(dPijDi)dS_{ij} = P_{ij} \cdot (dP_{ij} - D_i)

dQ, dK

dQ=dSK,dK=dSTQdQ = dS \cdot K, \quad dK = dS^T \cdot Q

总结,反向需要算的量:

1. D = rowsum(P ⊙ dP) = rowsum(P ⊙ (dO @ V^T))
2. dV = P^T @ dO
3. dS = P ⊙ (dP - D)
4. dK = dS^T @ Q
5. dQ = dS @ K

注意 DD 这一项——它是 PdPP \odot dP 按行求和。论文里有个简化技巧:

Di=kPikdPik=kPik(dOiVT)k=(dOiVT)PiT=dOi(VTPiT)=dOiOiT(scale)D_i = \sum_k P_{ik} \cdot dP_{ik} = \sum_k P_{ik} \cdot (dO_i V^T)_k = (dO_i \cdot V^T) \cdot P_i^T = dO_i \cdot (V^T P_i^T) = dO_i \cdot O_i^T \cdot \text{(scale)}

更简洁:D=rowsum(OdO)D = \text{rowsum}(O \odot dO)。这避免了显式算 P。

16.2 反向的循环顺序:外层 K,内层 Q

前向时我们用"外层 Q,内层 K" —— 这样 O 和 m, l 状态可以在外层循环里增量更新,每个 Q_block 的 O 一次性算完。

但反向的 dV 和 dK 是按 K 维度累加的:

dVj=iPijdOidV_j = \sum_i P_{ij} \cdot dO_i

如果用"外层 Q 内层 K",每个 K_block 的 dV 会被多次部分更新——必须用 atomic 或多次 kernel launch。

反向应该用"外层 K,内层 Q"

for k_idx in range(0, N, Bc):
    K_block = K[k_idx : k_idx + Bc]
    V_block = V[k_idx : k_idx + Bc]

    dV_acc = zeros(Bc, d)
    dK_acc = zeros(Bc, d)

    for q_idx in range(0, N, Br):
        Q_block = Q[q_idx : q_idx + Br]
        dO_block = dO[q_idx : q_idx + Br]
        LSE_block = LSE[q_idx : q_idx + Br]    # 来自前向

        # 1) 重计算 S = Q @ K^T
        S = Q_block @ K_block.T

        # 2) 重计算 P (用前向存的 LSE)
        P = exp(S - LSE_block.unsqueeze(1))

        # 3) D = rowsum(O ⊙ dO) -- 用前向的 O
        D_block = sum_per_row(O[q_idx:q_idx+Br] * dO_block)

        # 4) dV 累加
        dV_acc += P.T @ dO_block

        # 5) dP, dS
        dP = dO_block @ V_block.T
        dS = P * (dP - D_block.unsqueeze(1))

        # 6) dQ -- 这一步需要 atomic 累加 (跨 k_idx 多次更新)
        dQ_partial = dS @ K_block
        atomic_add(dQ[q_idx : q_idx + Br], dQ_partial)

        # 7) dK 累加
        dK_acc += dS.T @ Q_block

    dV[k_idx : k_idx + Bc] = dV_acc
    dK[k_idx : k_idx + Bc] = dK_acc

这种循环方式的特点:

dQ 的处理是 FA2 反向的一个工程难点。

16.3 dQ 的处理:atomic vs 二阶段

dQ 累加的两种工程方案:

16.3.1 方案 A:atomic add 直接写

atomicAdd(dQ + q_idx * d + col, dQ_partial);

简单但慢——HBM atomic 几百 cycle,N×N 量级的 atomic 加起来很多。

实际工业实现用 atomic 时会做几个优化:

  1. Atomic 在 SMEM 上,最后一次性写 HBM:每个 k_block 的 q 部分只写一次 SMEM atomic,最后写 HBM。但这要求所有 k_block 都在同一个 thread block 内——不现实,因为外层循环遍历所有 K。
  2. 半精度 atomic:HBM 上 fp16 atomic 比 fp32 atomic 快。但 fp16 atomic 在 H100 之前不被原生支持。

16.3.2 方案 B:两阶段 kernel

第一个 kernel 算 dV 和 dK(外层 K 循环结构),把 dS 写到一个临时缓冲。

第二个 kernel 把 dQ = dS @ K 算出来(不同 K_block 的 dS 已经分别 stage)。

但这要求中间存储 dS——而 dS 又是 N×N 的 attention 矩阵,就是 FA 想避免的东西

FA2 论文最终选择方案 A(atomic),因为:

实测 FA2 反向比朴素反向(先算 N×N 中间矩阵,再求梯度)快 3-5x

16.4 LSE 在反向中的角色

注意 16.2 节的伪代码中第 2 步 P = exp(S - LSE_block.unsqueeze(1))

这里 LSE 是前向输出的副产物:LSEi=mi+logi\text{LSE}_i = m_i + \log \ell_i

为什么需要 LSE?因为反向重计算 P 需要前向算过的归一化常数。如果不存 LSE,反向时要重新 online softmax 一遍——多一次完整的 K 维度遍历。

存 LSE 的代价:N 个 fp32 = 16 KB(N=4096 时),可以忽略。

16.5 反向的 SMEM 与寄存器布局

反向比前向需要更多状态:

SMEM 占用更紧,dV_acc/dK_acc 在寄存器中。寄存器压力比前向更大。

16.6 反向 Kernel 骨架

template <int Br, int Bc, int d>
__global__ void flash_attn_bwd(
    const half* Q, const half* K, const half* V, const half* O, const half* dO,
    const float* LSE,
    half* dQ, half* dK, half* dV,
    int N, float softmax_scale, bool is_causal
) {
    // Grid: (N / Bc, H, B)
    int k_tile_idx = blockIdx.x;
    int head_idx = blockIdx.y;
    int batch_idx = blockIdx.z;

    extern __shared__ half smem[];
    half* sK = smem;
    half* sV = sK + Bc * d;
    half* sQ = sV + Bc * d;
    half* sO = sQ + Br * d;
    half* sdO = sO + Br * d;
    float* sLSE = (float*)(sdO + Br * d);

    // Load K, V tile (一次, 整个 inner 循环用)
    load_tile(sK, K + k_tile_idx * Bc * d, Bc * d);
    load_tile(sV, V + k_tile_idx * Bc * d, Bc * d);
    __syncthreads();

    // 寄存器中的 dV, dK 累加器
    float dV_acc[MMAS_K][MMAS_D][4] = {0};
    float dK_acc[MMAS_K][MMAS_D][4] = {0};

    // 内层: 遍历所有 Q tile
    int q_start = is_causal ? k_tile_idx * Bc : 0;
    for (int q_tile = q_start; q_tile < N; q_tile += Br) {
        // Load Q, O, dO, LSE
        load_tile(sQ, Q + q_tile * d, Br * d);
        load_tile(sO, O + q_tile * d, Br * d);
        load_tile(sdO, dO + q_tile * d, Br * d);
        load_lse(sLSE, LSE + q_tile, Br);
        __syncthreads();

        // 1) S = Q @ K^T (与前向一样)
        float S_frag[MMAS_M][MMAS_N][4];
        compute_QKt(S_frag, sQ, sK);
        scale_and_mask(S_frag, softmax_scale, q_tile, k_tile_idx, is_causal);

        // 2) P = exp(S - LSE) -- 比 online softmax 简单, 因为 LSE 已知
        float P_frag[MMAS_M][MMAS_N][4];
        for (int i = 0; i < MMAS_M; ++i) {
            for (int j = 0; j < MMAS_N; ++j) {
                for (int e = 0; e < 4; ++e) {
                    int row = ...;  // lane 持有的行
                    P_frag[i][j][e] = expf(S_frag[i][j][e] - sLSE[row]);
                }
            }
        }

        // 3) D = rowsum(O ⊙ dO)
        float D_local[MMAS_M][2] = {0};
        compute_D(D_local, sO, sdO);  // load O, dO 的 fragment 并按行 reduce

        // 4) dV_acc += P^T @ dO
        // 注意 P 是 fp16, dO 是 fp16, dV_acc 是 fp32
        unsigned P_fp16[MMAS_M][MMAS_N];
        cast_fp32_to_fp16(P_fp16, P_frag);
        unsigned dO_frag[MMAS_M][MMAS_D];
        load_dO_fragment(dO_frag, sdO);
        // 注意是 P^T, 需要 ldmatrix.trans 或者交换 fragment 角色
        accumulate_matmul_T(dV_acc, P_fp16, dO_frag);

        // 5) dP = dO @ V^T
        float dP_frag[MMAS_M][MMAS_N][4];
        compute_dOVt(dP_frag, sdO, sV);

        // 6) dS = P ⊙ (dP - D)
        for (int i = 0; i < MMAS_M; ++i) {
            for (int j = 0; j < MMAS_N; ++j) {
                for (int e = 0; e < 4; ++e) {
                    int row_local = e / 2;
                    dP_frag[i][j][e]
                        = P_frag[i][j][e] * (dP_frag[i][j][e] - D_local[i][row_local]);
                }
            }
        }
        // 现在 dP_frag 实际是 dS_frag

        // 7) dQ_partial = dS @ K, atomic add to dQ
        unsigned dS_fp16[MMAS_M][MMAS_N];
        cast_fp32_to_fp16(dS_fp16, dP_frag);
        unsigned K_frag[MMAS_N][MMAS_D];
        load_K_fragment(K_frag, sK);
        float dQ_partial[MMAS_M][MMAS_D][4] = {0};
        accumulate_matmul(dQ_partial, dS_fp16, K_frag);

        // atomic add to global dQ
        atomic_write_dQ(dQ + q_tile * d, dQ_partial);

        // 8) dK_acc += dS^T @ Q
        unsigned Q_frag[MMAS_M][MMAS_D];
        load_Q_fragment(Q_frag, sQ);
        accumulate_matmul_T(dK_acc, dS_fp16, Q_frag);

        __syncthreads();
    }

    // 写 dV, dK
    write_dV(dV + k_tile_idx * Bc * d, dV_acc);
    write_dK(dK + k_tile_idx * Bc * d, dK_acc);
}

完整实现请参考 flash-attention/csrc/flash_attn/flash_bwd_kernel.h

16.7 反向的性能与挑战

反向比前向慢的原因:

  1. 三个 GEMM vs 前向的两个:dV、dK、dQ 都是独立的 GEMM 累加。
  2. dQ 的 atomic write:每次 inner Q 循环一次 atomic,N 大时累计严重。
  3. 更多 SMEM 输入:Q, K, V, O, dO, LSE 全部要在 SMEM 中。
  4. 重计算开销:S 和 P 在反向时要重算(这是 FA 的设计选择,避免存中间矩阵)。

实测 H100 上 FA2 反向(N=4096)大约 250-300 TFLOPs,是前向的一半左右。这个 ratio(反向 / 前向 ~= 0.5)和大多数算子的反向慢类似。

16.8 工业级实现的关键技巧

FlashAttention 官方实现(Tri Dao 的 GitHub 仓库)做了大量工程优化:

  1. Block-level 同步原语:使用 mbarrier 而不是 __syncthreads,减少同步开销。
  2. 动态调度 Q tile 顺序:causal mask 时,跳过 q < k 的 tile,加速明显。
  3. SwiGLU 融合:把 LN 和 attention 输出连起来 fuse 一些 element-wise op。
  4. LSE 精度选择:fp32 LSE 还是 fp16 LSE?后者更省 SMEM,但精度影响。
  5. dQ atomic 精度选择:fp32 atomic 慢但准,fp16 atomic 快但有精度损失。

每一个技巧都需要在精度和性能间权衡。FA2 论文里给出了详细的 ablation。

16.9 这一章的小结与下一章

FA2 反向的核心要点:

  1. 数学:dV/dK/dQ 都涉及重计算 S 和 P——这是用计算换存储的核心 trade-off。
  2. 循环顺序翻转:外层 K,内层 Q。dV/dK 在外层累加,dQ 必须 atomic。
  3. LSE 是前向到反向的桥梁:让反向不需要重新 online softmax。
  4. D = rowsum(O ⊙ dO) 是简化 dS 公式的关键。
  5. 反向比前向慢约 2x,但相比朴素反向快 3-5x。

到第 16 章为止,读者已经能写出一个能用、性能合理的 FA2(forward + backward)。但还没用到 Hopper 的杀手锏——TMA 和 Warp Specialization。

第 17 章我们把 FA2 重写到 Hopper 的最优形态——TMA 异步拷贝代替 cp.async、WGMMA 代替 mma.sync、Producer/Consumer warp 流水线。读完第 17 章读者会理解为什么 FA3 在 H100 上能跑到 740 TFLOPs(FP16),以及"现代 GPU kernel"和"Volta 时代的 GPU kernel"在编程模型上的根本差异。

本章动手练习

  1. 推导 dS = P ⊙ (dP - D) 的代数过程。验证 D 的两种等价定义。
  2. 实现 16.6 节骨架,在 H100 上测试反向性能。预期 ~250 TFLOPs。
  3. 思考:如果不存 LSE,反向时怎么重新算?需要多少额外 HBM 流量?