CUDA 算子工程:手写 FlashAttention v2 之路

第 14 章 Attention 的访存瓶颈与 IO-Aware 思想

作者 杨艺韬 · 2,176 字

第 14 章 Attention 的访存瓶颈与 IO-Aware 思想

"We have been computing attention wrong for years. Not the math — the I/O." ——Tri Dao, FlashAttention 论文摘要的一种非正式概括

14.1 Attention 的数学

经典 scaled dot-product attention:

Attention(Q,K,V)=softmax(QKTd)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right) V

形状:

中间结果:

数学上每个步骤都很清晰。问题在于在 GPU 上这种逐步计算的 HBM 流量爆炸

14.2 标准 Attention 的 HBM 流量精确计算

N=4096N = 4096(序列长度),d=64d = 64(head dim),FP16 数据。

Step 1: S=QKTS = QK^T

Step 2: P=softmax(S)P = \text{softmax}(S)

朴素 softmax 3 pass:

Step 3: O=PVO = PV

总 HBM 流量

0.5 + 0.5 + 32 (写 S)
+ 96 + 32 (softmax)
+ 32 + 0.5 + 0.5 (写 O)
= 194 MB

计算量只有:

实际算术强度 = 4.2 GFLOPs / 194 MB = 22 FLOPs/byte

把这个数字放到 H100 的 Roofline 上:

理论上限 = 22 × 3.35 TB/s = 74 TFLOPs。但 Tensor Core 峰值 989 TFLOPs——理论上 attention 只能用到 7.4% 算力

实测 PyTorch 朴素 attention 在 H100 上 N=4096, d=64 大约能跑到 40 TFLOPs(4% 算力)。比理论 74 TFLOPs 还低,因为还有 kernel launch 开销、cache 不完美等额外损失。

14.3 中间矩阵 S 和 P 的代价

仔细看 14.2 节的流量分解,会发现一个惊人的事实:194 MB HBM 流量中,绝大部分是 S 和 P 矩阵的反复读写

真正"内容"流量:
  Q: 0.5 MB    K: 0.5 MB    V: 0.5 MB    O: 0.5 MB
  合计: 2 MB

中间矩阵流量:
  写 S: 32 MB
  softmax 读 S 3次 + 写 P: 32+96+32 = 160 MB ?? (实际优化版会合并到一遍, 但需要 64 MB)
  PV 读 P: 32 MB

中间流量 / 总流量 ≈ 95%

95% 的 HBM 带宽消耗在中间矩阵上——而这些矩阵的存在仅仅是因为我们一步一步串行算。如果能把它们消除,HBM 流量直接降到 ~3 MB。

这就是 FlashAttention 的核心 insight。

14.4 IO-Aware:FA 的核心思想

FlashAttention 论文(Dao et al., 2022)的关键 idea:不要把 S 和 P 写到 HBM

具体做法:

  1. Tile 化 attention:把 Q,K,VQ, K, V 沿 N 维切成 tile(block)。
  2. 每次只对一对 (Q_tile, K_tile, V_tile) 计算:S_tile 在 SMEM 里就生成,softmax_tile 在 SMEM 里完成,PV 累加到 O_tile,完全不写中间矩阵
  3. 跨 K_tile 的 softmax 合并:用第 6 章讲的 online softmax,把多个 K_tile 的结果在线合并。
flowchart LR
  subgraph Naive [朴素 Attention · 多 kernel]
    N1[QK^T → S in HBM]
    N2[softmax S → P in HBM]
    N3[PV → O in HBM]
    N1 --> N2 --> N3
  end
  subgraph FA [FlashAttention · 单 kernel]
    F1[Tile by tile]
    F1 --> F2[S_tile in SMEM]
    F2 --> F3[Online softmax in SMEM]
    F3 --> F4[Accumulate to O fragment in register]
    F4 --> F5[全部 K_tile 处理完后, 写 O 一次]
  end

14.5 FA 的算法骨架(Forward)

# FA Forward 算法 (Dao 2022 论文 Algorithm 1)
# 输入: Q, K, V ∈ [N, d]
# 输出: O ∈ [N, d], LSE ∈ [N]

# 沿 N 维分块
B_q = 64   # Query block size
B_k = 64   # Key block size

# 外层: 遍历 Q 的 block
for q_idx in range(0, N, B_q):
    Q_block = Q[q_idx : q_idx + B_q]          # [B_q, d]

    # 累加状态
    O_block = zeros(B_q, d)
    m_block = -inf * ones(B_q)                 # 行级 max
    l_block = zeros(B_q)                       # 行级 sum

    # 内层: 遍历 K 的 block
    for k_idx in range(0, N, B_k):
        K_block = K[k_idx : k_idx + B_k]      # [B_k, d]
        V_block = V[k_idx : k_idx + B_k]      # [B_k, d]

        # 1) 计算 S_block = Q_block @ K_block^T
        S_block = Q_block @ K_block.T          # [B_q, B_k]

        # 2) Online softmax 更新 m, l
        m_new = max(m_block, max_per_row(S_block))   # [B_q]
        P_block = exp(S_block - m_new)               # [B_q, B_k]
        alpha = exp(m_block - m_new)                  # [B_q]
        l_new = alpha * l_block + sum_per_row(P_block)

        # 3) 累加输出
        O_block = alpha * O_block + P_block @ V_block

        m_block = m_new
        l_block = l_new

    # 最终归一化
    O[q_idx : q_idx + B_q] = O_block / l_block
    LSE[q_idx : q_idx + B_q] = m_block + log(l_block)

关键点:

  1. 没有 N×N 矩阵:S 和 P 只在内层循环的当前 K_tile 内存在,B_q × B_k 大小(很小)。
  2. O 在外层循环中累加:每个 K_tile 都对当前的 O_block 贡献一份,用 online softmax 的 alpha 修正之前的累积。
  3. 最后才做归一化:除以最终的 l_block。
  4. LSE(log-sum-exp)作为副产物输出:反向传播时需要。

14.6 FA 的 HBM 流量重新计算

把 FA 的 HBM 流量算一遍。N=4096, d=64, FP16:

N/Bq=64N/B_q = 64,所以 K 读 64 次 = 32 MB,V 读 64 次 = 32 MB。

总 HBM 流量 ≈ 64 MB

但等等,这比朴素的 194 MB 也才省了 65%。FA 真正的优势在哪?

答案是:省掉的是中间矩阵(S, P)的 HBM 写——这些是带宽里"真正不应该有"的部分。

但同时 FA 引入了一个新代价:K 和 V 被多次读(每个 q_block 都读一次 K 和 V)。这个代价在 N=4096 时是 64×。

总 HBM 流量从 194 MB 降到 64 MB,约 3× 提升。实际算术强度

66 比 22 提升 3×,但仍然 < 295(带宽 bound)。理论上限:66 × 3.35 TB/s = 221 TFLOPs

实测 FA1 在 A100 上 ~120 TFLOPs,FA2 在 A100 上 ~230 TFLOPs,FA3 在 H100 上 ~740 TFLOPs。FA2/3 远超 FA1 不是因为算法变了——算法基本相同——而是因为更细致地利用了 GPU(warp 级并行、TMA、WGMMA)。

14.7 K/V 重读的代价:能不能进一步优化

FA Forward 的 K/V 重读看似浪费。能不能把 K/V 也只读一次?

答案是可以,但需要换一个外层循环——外层遍历 K,内层遍历 Q。这是 FA 论文中 "Algorithm 2" 的写法,代价是 Q 被多次读,且 O 需要 atomic 累加(多个 k_block 同时贡献到同一个 q 行)。

实践中两种循环方式各有优劣:

外层循环 重读 适用
外 Q 内 K K, V 重读 Forward 默认(FA1/2)
外 K 内 Q Q 重读,O 需 atomic Backward 默认

第 16 章讲反向时会回到这一点。

14.7.1 长序列的进一步优化

如果 N 极大(比如 N=64K,长上下文),重读 K/V 的代价会主导。FA3 中引入了Splitting——把 K/V 维度也切分给多个 SM 同时算,最后跨 SM 合并 partial 结果。这样每个 K/V tile 只被一组 SM 读一次,但需要跨 SM 通信。

vLLM 的 Flash-Decoding 和 SGLang 的 RaggedBatch 都实现了类似优化,特别用于 long-context decoding(输入 128K,每次生成一个 token)。

14.8 关于 IO-Aware 的更广义理解

FlashAttention 的成功不只是一个算法。它代表了一种思维方式

重写算法的数据流,而不是重写算法本身

数学上,FA 算的是同一个 attention(输出严格相等到浮点误差)。但它的"数据流"完全重组了——把 N×N 中间矩阵从 HBM 挤出去,让 K/V 多读几次换中间矩阵不写。这种"用一种带宽换另一种带宽"的 trade-off 是 GPU 算法设计的核心模式。

类似的 IO-aware 思想已经被推广到很多其他算法:

每一种都是"重新组织数据流以拥抱 GPU 内存层级"。读懂 FA 之后,读者会发现这种思维在 LLM 系统的每一层都有应用。

14.9 这一章的小结与下一章

这一章把 attention 的访存瓶颈彻底剖开:

  1. 朴素 attention 是带宽 bound(AI=22):HBM 流量主要消耗在 N×N 中间矩阵 S 和 P 上。
  2. FA 的 idea 是把中间矩阵留在 SMEM:用 online softmax 跨 K-tile 累积。
  3. FA 的 HBM 流量降到 ~1/3:但代价是 K/V 多次重读。
  4. FA 的真实理论上限是 ~221 TFLOPs(仍带宽 bound):但 FA3 通过 Hopper 优化做到了 740 TFLOPs。
  5. IO-Aware 是一种思维方式:在 LLM 系统的每一层都有应用。

第 15 章我们正式动手——把 14.5 节的伪代码翻译成具体的 CUDA kernel。我们会用第三篇 GEMM 优化的所有工具(Tensor Core、ldmatrix、SMEM tile、double buffer)来组装 FA2 前向。读完第 15 章读者会拥有一个 ~70% Tensor Core 利用率的 FA2 forward kernel。

本章动手练习

  1. 用 PyTorch 写朴素 attention 和调用 torch.nn.functional.scaled_dot_product_attention(内部走 FA),用 Nsight Compute 测两者的 HBM 读写流量,验证差距。
  2. 推导 FA 反向的 HBM 流量公式,对比朴素反向(需要 N×N 中间梯度矩阵)。
  3. 阅读 FlashAttention 论文 Algorithm 1,对照 14.5 节的伪代码逐行核对。