CUDA 算子工程:手写 FlashAttention v2 之路
第 14 章 Attention 的访存瓶颈与 IO-Aware 思想
第 14 章 Attention 的访存瓶颈与 IO-Aware 思想
"We have been computing attention wrong for years. Not the math — the I/O." ——Tri Dao, FlashAttention 论文摘要的一种非正式概括
14.1 Attention 的数学
经典 scaled dot-product attention:
形状:
- (N 个 query,每个 d 维)
- (N 个 key,每个 d 维)
- (N 个 value,每个 d 维)
- 输出
中间结果:
- (attention scores)
- (attention probabilities)
数学上每个步骤都很清晰。问题在于在 GPU 上这种逐步计算的 HBM 流量爆炸。
14.2 标准 Attention 的 HBM 流量精确计算
设 (序列长度),(head dim),FP16 数据。
Step 1:
- 读 Q: KB
- 读 K: KB
- 写 S: MB
Step 2:
朴素 softmax 3 pass:
- 读 S:3 次 × 32 MB = 96 MB
- 写 P:32 MB
Step 3:
- 读 P:32 MB
- 读 V:512 KB
- 写 O:512 KB
总 HBM 流量:
0.5 + 0.5 + 32 (写 S)
+ 96 + 32 (softmax)
+ 32 + 0.5 + 0.5 (写 O)
= 194 MB
而计算量只有:
- : FLOPs
- : FLOPs
- 合计: FLOPs = 4.2 GFLOPs
实际算术强度 = 4.2 GFLOPs / 194 MB = 22 FLOPs/byte
把这个数字放到 H100 的 Roofline 上:
- 22 < 295 (临界点)
- 所以 attention 是带宽 bound
理论上限 = 22 × 3.35 TB/s = 74 TFLOPs。但 Tensor Core 峰值 989 TFLOPs——理论上 attention 只能用到 7.4% 算力。
实测 PyTorch 朴素 attention 在 H100 上 N=4096, d=64 大约能跑到 40 TFLOPs(4% 算力)。比理论 74 TFLOPs 还低,因为还有 kernel launch 开销、cache 不完美等额外损失。
14.3 中间矩阵 S 和 P 的代价
仔细看 14.2 节的流量分解,会发现一个惊人的事实:194 MB HBM 流量中,绝大部分是 S 和 P 矩阵的反复读写:
真正"内容"流量:
Q: 0.5 MB K: 0.5 MB V: 0.5 MB O: 0.5 MB
合计: 2 MB
中间矩阵流量:
写 S: 32 MB
softmax 读 S 3次 + 写 P: 32+96+32 = 160 MB ?? (实际优化版会合并到一遍, 但需要 64 MB)
PV 读 P: 32 MB
中间流量 / 总流量 ≈ 95%
95% 的 HBM 带宽消耗在中间矩阵上——而这些矩阵的存在仅仅是因为我们一步一步串行算。如果能把它们消除,HBM 流量直接降到 ~3 MB。
这就是 FlashAttention 的核心 insight。
14.4 IO-Aware:FA 的核心思想
FlashAttention 论文(Dao et al., 2022)的关键 idea:不要把 S 和 P 写到 HBM。
具体做法:
- Tile 化 attention:把 沿 N 维切成 tile(block)。
- 每次只对一对 (Q_tile, K_tile, V_tile) 计算:S_tile 在 SMEM 里就生成,softmax_tile 在 SMEM 里完成,PV 累加到 O_tile,完全不写中间矩阵。
- 跨 K_tile 的 softmax 合并:用第 6 章讲的 online softmax,把多个 K_tile 的结果在线合并。
flowchart LR
subgraph Naive [朴素 Attention · 多 kernel]
N1[QK^T → S in HBM]
N2[softmax S → P in HBM]
N3[PV → O in HBM]
N1 --> N2 --> N3
end
subgraph FA [FlashAttention · 单 kernel]
F1[Tile by tile]
F1 --> F2[S_tile in SMEM]
F2 --> F3[Online softmax in SMEM]
F3 --> F4[Accumulate to O fragment in register]
F4 --> F5[全部 K_tile 处理完后, 写 O 一次]
end
14.5 FA 的算法骨架(Forward)
# FA Forward 算法 (Dao 2022 论文 Algorithm 1)
# 输入: Q, K, V ∈ [N, d]
# 输出: O ∈ [N, d], LSE ∈ [N]
# 沿 N 维分块
B_q = 64 # Query block size
B_k = 64 # Key block size
# 外层: 遍历 Q 的 block
for q_idx in range(0, N, B_q):
Q_block = Q[q_idx : q_idx + B_q] # [B_q, d]
# 累加状态
O_block = zeros(B_q, d)
m_block = -inf * ones(B_q) # 行级 max
l_block = zeros(B_q) # 行级 sum
# 内层: 遍历 K 的 block
for k_idx in range(0, N, B_k):
K_block = K[k_idx : k_idx + B_k] # [B_k, d]
V_block = V[k_idx : k_idx + B_k] # [B_k, d]
# 1) 计算 S_block = Q_block @ K_block^T
S_block = Q_block @ K_block.T # [B_q, B_k]
# 2) Online softmax 更新 m, l
m_new = max(m_block, max_per_row(S_block)) # [B_q]
P_block = exp(S_block - m_new) # [B_q, B_k]
alpha = exp(m_block - m_new) # [B_q]
l_new = alpha * l_block + sum_per_row(P_block)
# 3) 累加输出
O_block = alpha * O_block + P_block @ V_block
m_block = m_new
l_block = l_new
# 最终归一化
O[q_idx : q_idx + B_q] = O_block / l_block
LSE[q_idx : q_idx + B_q] = m_block + log(l_block)
关键点:
- 没有 N×N 矩阵:S 和 P 只在内层循环的当前 K_tile 内存在,B_q × B_k 大小(很小)。
- O 在外层循环中累加:每个 K_tile 都对当前的 O_block 贡献一份,用 online softmax 的 alpha 修正之前的累积。
- 最后才做归一化:除以最终的 l_block。
- LSE(log-sum-exp)作为副产物输出:反向传播时需要。
14.6 FA 的 HBM 流量重新计算
把 FA 的 HBM 流量算一遍。N=4096, d=64, FP16:
- 读 Q:每个 q_block 读 1 次,总共读 1 次 = = 512 KB
- 读 K:对每个 q_block,遍历所有 k_block,总共读 次 K =
- 读 V:同 K,读 次
,所以 K 读 64 次 = 32 MB,V 读 64 次 = 32 MB。
- 写 O:1 次 = 512 KB
- 写 LSE:很小(N 个 fp32)
总 HBM 流量 ≈ 64 MB
但等等,这比朴素的 194 MB 也才省了 65%。FA 真正的优势在哪?
答案是:省掉的是中间矩阵(S, P)的 HBM 写——这些是带宽里"真正不应该有"的部分。
但同时 FA 引入了一个新代价:K 和 V 被多次读(每个 q_block 都读一次 K 和 V)。这个代价在 N=4096 时是 64×。
总 HBM 流量从 194 MB 降到 64 MB,约 3× 提升。实际算术强度:
- FLOPs 不变:4.2 GFLOPs
- HBM:64 MB
- AI = 4.2 GFLOPs / 64 MB = 66 FLOPs/byte
66 比 22 提升 3×,但仍然 < 295(带宽 bound)。理论上限:66 × 3.35 TB/s = 221 TFLOPs。
实测 FA1 在 A100 上 ~120 TFLOPs,FA2 在 A100 上 ~230 TFLOPs,FA3 在 H100 上 ~740 TFLOPs。FA2/3 远超 FA1 不是因为算法变了——算法基本相同——而是因为更细致地利用了 GPU(warp 级并行、TMA、WGMMA)。
14.7 K/V 重读的代价:能不能进一步优化
FA Forward 的 K/V 重读看似浪费。能不能把 K/V 也只读一次?
答案是可以,但需要换一个外层循环——外层遍历 K,内层遍历 Q。这是 FA 论文中 "Algorithm 2" 的写法,代价是 Q 被多次读,且 O 需要 atomic 累加(多个 k_block 同时贡献到同一个 q 行)。
实践中两种循环方式各有优劣:
| 外层循环 | 重读 | 适用 |
|---|---|---|
| 外 Q 内 K | K, V 重读 | Forward 默认(FA1/2) |
| 外 K 内 Q | Q 重读,O 需 atomic | Backward 默认 |
第 16 章讲反向时会回到这一点。
14.7.1 长序列的进一步优化
如果 N 极大(比如 N=64K,长上下文),重读 K/V 的代价会主导。FA3 中引入了Splitting——把 K/V 维度也切分给多个 SM 同时算,最后跨 SM 合并 partial 结果。这样每个 K/V tile 只被一组 SM 读一次,但需要跨 SM 通信。
vLLM 的 Flash-Decoding 和 SGLang 的 RaggedBatch 都实现了类似优化,特别用于 long-context decoding(输入 128K,每次生成一个 token)。
14.8 关于 IO-Aware 的更广义理解
FlashAttention 的成功不只是一个算法。它代表了一种思维方式:
重写算法的数据流,而不是重写算法本身。
数学上,FA 算的是同一个 attention(输出严格相等到浮点误差)。但它的"数据流"完全重组了——把 N×N 中间矩阵从 HBM 挤出去,让 K/V 多读几次换中间矩阵不写。这种"用一种带宽换另一种带宽"的 trade-off 是 GPU 算法设计的核心模式。
类似的 IO-aware 思想已经被推广到很多其他算法:
- FlashAttention for Attention
- FlashConv for Convolution
- PagedAttention for KV Cache management
- Speculative Decoding for autoregressive sampling
- Tree-based Decoding for batch generation
每一种都是"重新组织数据流以拥抱 GPU 内存层级"。读懂 FA 之后,读者会发现这种思维在 LLM 系统的每一层都有应用。
14.9 这一章的小结与下一章
这一章把 attention 的访存瓶颈彻底剖开:
- 朴素 attention 是带宽 bound(AI=22):HBM 流量主要消耗在 N×N 中间矩阵 S 和 P 上。
- FA 的 idea 是把中间矩阵留在 SMEM:用 online softmax 跨 K-tile 累积。
- FA 的 HBM 流量降到 ~1/3:但代价是 K/V 多次重读。
- FA 的真实理论上限是 ~221 TFLOPs(仍带宽 bound):但 FA3 通过 Hopper 优化做到了 740 TFLOPs。
- IO-Aware 是一种思维方式:在 LLM 系统的每一层都有应用。
第 15 章我们正式动手——把 14.5 节的伪代码翻译成具体的 CUDA kernel。我们会用第三篇 GEMM 优化的所有工具(Tensor Core、ldmatrix、SMEM tile、double buffer)来组装 FA2 前向。读完第 15 章读者会拥有一个 ~70% Tensor Core 利用率的 FA2 forward kernel。
本章动手练习:
- 用 PyTorch 写朴素 attention 和调用
torch.nn.functional.scaled_dot_product_attention(内部走 FA),用 Nsight Compute 测两者的 HBM 读写流量,验证差距。- 推导 FA 反向的 HBM 流量公式,对比朴素反向(需要 N×N 中间梯度矩阵)。
- 阅读 FlashAttention 论文 Algorithm 1,对照 14.5 节的伪代码逐行核对。