Transformer 解剖:从 Attention 到推理系统

第 18 章 Flash Attention 与分布式推理:内存层级与多卡协同

作者 杨艺韬 · 4,906 字

第 18 章 Flash Attention 与分布式推理:内存层级与多卡协同

第六部分到这里讲了 KV Cache、量化、投机解码——但都是「在已有 attention 实现上做加法」的优化。这一章我们 改 attention 本身的实现——同样的数学,跑得快几倍——以及 把模型切到多张卡 的分布式策略。

这两件事看似无关,其实共享同一个底层逻辑:在内存层级上找最优点。Flash Attention 在「HBM ↔ SRAM」之间找最优;TP / PP / EP 在「单卡 ↔ 多卡 ↔ 多机」之间找最优。理解了一边,另一边就豁然开朗。

读完这章你能:

18.1 GPU 内存层级:物理事实

先把基础事实摆出来。一张 H100 GPU 的内存层级:

flowchart TB
  REG["寄存器 (Registers)<br/>~20 MB 总量<br/>带宽 ~250 TB/s<br/>SM 内私有"]
  SMEM["共享内存 SMEM (SRAM)<br/>~228 KB / SM × 132 SM ≈ 30 MB<br/>带宽 ~100 TB/s<br/>SM 内共享"]
  L2["L2 Cache<br/>50 MB<br/>带宽 ~10 TB/s<br/>所有 SM 共享"]
  HBM["HBM (Global Memory)<br/>80 GB<br/>带宽 3.35 TB/s<br/>整卡共享"]
  
  HBM --> L2 --> SMEM --> REG

关键事实:

  1. HBM 容量大、带宽相对小——80 GB,3.35 TB/s
  2. SMEM 容量小、带宽极大——单 SM 228 KB,100+ TB/s
  3. 越靠近计算单元越快越小——寄存器 > SMEM > L2 > HBM

模型参数(140 GB Llama-70B)远超 HBM;即使是 KV Cache(几十 GB)也远超 SMEM。所以模型推理必须从 HBM 读数据,但能在 SMEM 里完成的计算尽量在 SMEM 里完成

这就是 Flash Attention 的全部精髓。

18.2 朴素 attention 的内存访问账

回顾标准 attention 的实现(第 8 章):

S = Q @ K.transpose(-2, -1) / sqrt(d)   # (N, N)
A = softmax(S, dim=-1)                   # (N, N)
out = A @ V                              # (N, d)

每一步在 GPU 上的数据流:

flowchart LR
  HBM1["HBM: Q, K, V"] --> COMP1["算 S = QK^T"]
  COMP1 --> HBM2["HBM 写 S 矩阵 N×N"]
  HBM2 --> COMP2["算 softmax(S)"]
  COMP2 --> HBM3["HBM 写 A 矩阵 N×N"]
  HBM3 --> COMP3["算 A @ V"]
  COMP3 --> HBM4["HBM 写 out 矩阵"]

每一步都要把数据写回 HBM,再从 HBM 读回来——HBM 流量 = O(N²) + O(N²) + O(N · d) ≈ O(N²)

N=8K 时,单层一次 attention 的 HBM 读写量是 2×N2×4(FP32 中间存)0.52 \times N^2 \times 4 \text{(FP32 中间存)} \approx 0.5 GB——80 层模型一次推理 attention 部分就读写 40 GB HBM,远超模型权重大小。

HBM 带宽 = attention 实际瓶颈。理论 GPU 算力够用,但因为内存访问太多,attention 跑不快。

18.3 Flash Attention 1:分块 + Online Softmax

Flash Attention(Tri Dao et al., NeurIPS 2022)的核心想法:不要把 N×N 矩阵物化到 HBM,把整个 attention 算在 SMEM 里面

但有个问题:N×N 矩阵 N=8K 时占 1 GB,远超 SMEM 的 30 MB——直接放不下。

解决方案是 tiling(分块)+ online softmax

  1. 把 Q、K、V 切成小块(tile)
  2. 每次只把一块 K、V 从 HBM 读到 SMEM
  3. 在 SMEM 里对当前 Q tile 算这一小块的 attention 部分
  4. 通过 online softmax 把多个 K/V 块的部分结果正确合并

Online Softmax 的数学

直接说:能不能把 softmax 「拆成多个块单独算然后合并」?

朴素的 softmax 公式:

softmax(xi)=exijexj\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}

需要先扫一遍找 max\max(数值稳定),再扫一遍算分母 exjmax\sum e^{x_j - \max},再扫一遍算每个值的输出——三遍扫描 + 中间存 N 个值

Online softmax 可以一遍扫过去就完成。维护两个变量 mm(当前最大值)和 \ell(当前分母):

初始 m0=m_0 = -\infty, 0=0\ell_0 = 0

每读一个新值 xix_i

mi=max(mi1,xi)m_i = \max(m_{i-1}, x_i) i=i1emi1mi+eximi\ell_i = \ell_{i-1} \cdot e^{m_{i-1} - m_i} + e^{x_i - m_i}

最终的 softmax 输出可以从 m,m, \ell 推出来。关键点:每读一块新数据时,老的 \ell 需要被「rescale」(乘 emoldmnewe^{m_{\text{old}} - m_{\text{new}}})以适应新的 max。

把它扩展到 attention 输出(O=AVO = A \cdot V):

每遇到一块新的 K(j),V(j)K^{(j)}, V^{(j)}

  1. S(j)=QK(j)TS^{(j)} = Q \cdot K^{(j)T}(在 SMEM 里)
  2. 找新 block 的 max:m(j)=maxlS:,l(j)m^{(j)} = \max_l S^{(j)}_{:,l}
  3. 更新全局 max:mnew=max(mold,m(j))m_{\text{new}} = \max(m_{\text{old}}, m^{(j)})
  4. rescale 老的 output 累加器:OOemoldmnewO \leftarrow O \cdot e^{m_{\text{old}} - m_{\text{new}}}
  5. 算新 block 的部分 output 累加:OO+eS(j)mnewV(j)O \leftarrow O + e^{S^{(j)} - m_{\text{new}}} \cdot V^{(j)}
  6. 更新分母:new=oldemoldmnew+eS(j)mnew\ell_{\text{new}} = \ell_{\text{old}} \cdot e^{m_{\text{old}} - m_{\text{new}}} + \sum e^{S^{(j)} - m_{\text{new}}}

最后一步:OO/newO \leftarrow O / \ell_{\text{new}}

数学上完全等价于一次性计算的标准 softmax + AV——但内存使用从 O(N²) 降到 O(N)(只需要存当前块)。

Flash Attention 1 的算法流程

flowchart LR
  subgraph "外层循环 over K, V blocks"
    LOAD[Load 一块 K^j, V^j 到 SMEM]
    LOAD --> COMP[Compute S^j, 更新 max / sum]
    COMP --> ACCUM[累加到 O,rescale]
  end
  subgraph "内层循环 over Q blocks"
    QLOAD[Load 一块 Q 到 SMEM]
  end
  QLOAD --> LOAD
  ACCUM --> END
  END --> WRITE[最后一次性 写 O 到 HBM]

伪代码(简化):

# Q, K, V in HBM
# Q: (N, d), tile size B_r;  K, V: (N, d), tile size B_c
for i in range(N // B_r):                    # 外层: Q 块
    Q_i = load Q tile to SMEM
    O_i = zeros, m_i = -inf, l_i = 0
    for j in range(N // B_c):                # 内层: K, V 块
        K_j, V_j = load to SMEM
        S_ij = Q_i @ K_j.T                    # in SMEM
        m_new = max(m_i, S_ij.max(dim=-1))
        # rescale 老 output
        O_i = O_i * exp(m_i - m_new)
        # 累加新 block 贡献
        l_new = l_i * exp(m_i - m_new) + exp(S_ij - m_new).sum(dim=-1)
        O_i += exp(S_ij - m_new) @ V_j
        m_i, l_i = m_new, l_new
    O_i /= l_i
    write O_i to HBM

Flash Attention 的收益

HBM 流量从 O(N²) 降到 O(N)(只需要扫 Q、K、V 一遍)——对长上下文场景收益特别大:

上下文长度 朴素 HBM 流量 Flash Attention HBM 流量 加速比
4K
8K ~3×
32K 64× ~5×
128K 1024× ~6×

这就是为什么 Flash Attention 在长上下文场景下加速这么明显。它没有改变 attention 的算法,只是把内存访问模式从「物化到 HBM」改成「在 SMEM 里完成」。

18.4 Flash Attention 2:减少非矩阵乘开销

Flash Attention 2(Dao, 2023)在 V1 基础上做了几个细节优化:

优化 1:减少非矩阵乘 op

GPU 的 Tensor Core 算矩阵乘极快(Ampere 是 312 TFLOPs FP16,Hopper 是 989 TFLOPs FP16),但算其他操作(exp、scale、reciprocal)只有 ~20 TFLOPs。FA1 里的 exp 和 rescale 操作占了不小比例。

FA2 通过重排算法,把更多时间花在矩阵乘上、把非矩阵乘的次数压到最少——单卡 FA1 的 30% 算力损失被压到 10% 左右。

优化 2:循环顺序对调

FA1 的外层循环是 K/V,内层是 Q——每次进入内层都要 reload Q(K/V 在 SMEM 里复用,Q 多次加载)。

FA2 把循环对调:外层 Q、内层 K/V。Q 加载一次,K/V 块多次加载——更适合 GPU 的访问模式。

优化 3:work partition 优化

GPU 上每个 SM(Streaming Multiprocessor)独立工作。FA1 把每个 (Q tile, head) 分配到一个 SM——N 短时 SM 数量不够用。

FA2 让每个 SM 处理多个 (Q tile, head),让 SM 更均衡地负载。

综合下来,FA2 在 H100 上比 FA1 快 1.7-2×。这是 vLLM、SGLang、TensorRT-LLM 默认的 attention 实现。

18.5 Flash Attention 3:Hopper 特化

Flash Attention 3(Dao et al., 2024)针对 Hopper 架构(H100)做了进一步特化。Hopper 引入了几项新硬件能力:

  1. TMA(Tensor Memory Accelerator):异步内存搬运——计算和数据搬运可以重叠
  2. WGMMA(Warp Group Matrix Multiply Accumulate):异步矩阵乘——SM 级别的 pipeline
  3. FP8 支持:原生 FP8 矩阵乘

FA3 用三种新技巧吃干净 Hopper 的能力:

技巧 1:异步流水线(async pipeline)

把「load 下一个 K/V 块」和「计算当前块」并行执行。当一个 warp 在算 GEMM 时,另一个 warp 在用 TMA 加载下一块——TMA 不占用 SM 计算资源,纯异步。

技巧 2:Warp specialization

把 SM 内的 warp 分成两组:

两组通过 SMEM 上的「pipeline」交换数据——produce 写入、consume 读取——形成软件级别的硬件 pipeline。

技巧 3:FP8 支持

整个 FA3 在 FP8 下运行。质量略损失(< 0.5% PPL)但速度再快 2×。

FA3 在 H100 上达到 740 TFLOPs(相对 FP16 理论峰值 989 TFLOPs,75% MFU——对比 FA2 的 35%)。这是单 attention kernel 能达到的工程极限之一。

flowchart LR
  V1["FA1<br/>Tiling + Online Softmax<br/>~25% MFU"] --> V2["FA2<br/>+ work partition<br/>+ 减 non-matmul<br/>~35% MFU"]
  V2 --> V3["FA3<br/>+ async pipeline<br/>+ warp specialization<br/>+ FP8<br/>~75% MFU"]

18.6 现在转向分布式:为什么单卡不够

到这里我们解决了「单卡 attention 怎么跑得快」的问题。但回到现实:Llama-3 70B 是 140 GB,单 H100 的 80 GB 显存放不下。671B 模型 1.3 TB,要 17 张 H100 才装得下权重。

这就是分布式推理要解决的问题:怎么把超大模型切到多张 GPU 上

主流的三种切分维度:

flowchart TB
  MODEL[一个大模型<br/>太大单卡放不下]
  MODEL --> TP[Tensor Parallel TP<br/>把大矩阵切成小矩阵 多卡协同]
  MODEL --> PP[Pipeline Parallel PP<br/>不同 GPU 跑不同层 流水线]
  MODEL --> EP[Expert Parallel EP<br/>MoE 专属 不同专家分到不同 GPU]

每种并行解决不同的问题,工程上常常叠加使用:8 张 GPU 跑 70B 用 TP=8;64 张 GPU 跑 671B MoE 用 TP=4 + EP=16。下面分别讲。

18.7 Tensor Parallelism (TP):切大矩阵

TP 的核心想法:把大矩阵乘按某个维度切到多 GPU,每卡算一部分,然后做 all-reduce 合并

举例:FFN 第一层 W1Rd×4dW_1 \in \mathbb{R}^{d \times 4d}(典型 d=8192,4d=32768)。把 W1W_1 沿列切 4 份:

W_1 = [W_1^{(1)} | W_1^{(2)} | W_1^{(3)} | W_1^{(4)}]

GPU 1 持有 W1(1)Rd×dW_1^{(1)} \in \mathbb{R}^{d \times d},以此类推。

输入 xRdx \in \mathbb{R}^d(每张 GPU 都有一份):

然后 W2R4d×dW_2 \in \mathbb{R}^{4d \times d} 沿行切 4 份。每卡算它那部分:

这样一对 GEMM(xW1W2x \cdot W_1 \cdot W_2)被切到 4 张卡,每卡算量减 4 倍,最后只需要一次 all-reduce 通信。

flowchart LR
  subgraph "GPU 1"
    G1[x] --> G1_1["× W_1^(1)"] --> G1_2["× W_2^(1)"] --> Z1[z_1]
  end
  subgraph "GPU 2"
    G2[x] --> G2_1["× W_1^(2)"] --> G2_2["× W_2^(2)"] --> Z2[z_2]
  end
  subgraph "GPU 3"
    G3[x] --> G3_1["× W_1^(3)"] --> G3_2["× W_2^(3)"] --> Z3[z_3]
  end
  subgraph "GPU 4"
    G4[x] --> G4_1["× W_1^(4)"] --> G4_2["× W_2^(4)"] --> Z4[z_4]
  end
  Z1 & Z2 & Z3 & Z4 --> AR[all-reduce sum]
  AR --> Z[完整 z]

TP 在 Attention 上

Multi-Head Attention 用 TP 也很自然:按 head 切。h 个头切到 N 卡,每卡算 h/N 个头:

GQA 下 KV head 数比 Q head 数少,要确保 TP 维度能整除 KV head 数(否则 K、V 在某些 GPU 上没有,需要 broadcast)。

TP 的通信开销

每个 Block 需要 2 次 all-reduce(attention 后一次、FFN 后一次),每次传输的 hidden state 是 (B,T,d)(B, T, d)。对 H100 NVLink(900 GB/s)来说,一次 all-reduce 几百微秒——对 token 级延迟(30-50 ms)来说占 5-10%。

TP 在同机内(NVLink)下高效——一台 8 GPU 服务器内可以做 TP=8。但跨机就不行了——InfiniBand 带宽(400-800 Gbps)远低于 NVLink,跨机 TP 通信占比 30%+,得不偿失。

18.8 Pipeline Parallelism (PP):切层

如果模型大到 8 卡 NVLink 也装不下(比如 DeepSeek-V3 的 671B),需要跨机扩展——PP 是这个场景的解决方案。

PP 的想法:把不同层放到不同 GPU。比如 80 层的 Llama-3 70B,前 20 层放 GPU 1,21-40 层放 GPU 2,41-60 层放 GPU 3,61-80 层放 GPU 4。

数据流:token 进入 GPU 1,经过 20 层,把结果传给 GPU 2,再经过 20 层,依此类推。

flowchart LR
  X[输入 token] --> GPU1[GPU 1<br/>Layers 1-20]
  GPU1 -.传递 hidden state.-> GPU2[GPU 2<br/>Layers 21-40]
  GPU2 -.传递.-> GPU3[GPU 3<br/>Layers 41-60]
  GPU3 -.传递.-> GPU4[GPU 4<br/>Layers 61-80]
  GPU4 --> OUT[输出 logits]

Pipeline Bubble

PP 看似简单但有个根本问题:在 GPU 1 处理 batch i 时,GPU 2、3、4 都在等——大部分时间空转。

解决方案:micro-batching——把一个大 batch 切成几个小 micro-batch,让流水线流起来:

flowchart TB
  T1[t=1: GPU1 处理 mb1]
  T2[t=2: GPU1 处理 mb2; GPU2 处理 mb1]
  T3[t=3: GPU1 处理 mb3; GPU2 处理 mb2; GPU3 处理 mb1]
  T4[t=4: GPU1 处理 mb4; GPU2 处理 mb3; GPU3 处理 mb2; GPU4 处理 mb1]
  T5[t=5: GPU1 idle; GPU2 处理 mb4; GPU3 处理 mb3; GPU4 处理 mb2]
  T6[t=6: GPU3 处理 mb4; GPU4 处理 mb3]
  T7[t=7: GPU4 处理 mb4]
  
  T1 --> T2 --> T3 --> T4 --> T5 --> T6 --> T7

可以看到流水线开头和结尾都有 idle GPU——这叫 pipeline bubble。bubble 比例 = (Npp1)/Nmicro(N_{\text{pp}} - 1) / N_{\text{micro}}。要减小 bubble 必须增大 micro-batch 数——但这又增加了内存压力。

PP 适合场景

PP 的特点:

适合:

不适合:

18.9 Expert Parallelism (EP):MoE 专属

EP 是 MoE 模型专门的并行方式:不同 GPU 持有不同的专家,token 按路由结果发到对应 GPU

第 12 章我们讲过 MoE 的 All-to-All 通信——EP 就是它的具体实现。

DeepSeek-V3 的部署示例:256 个专家分到 32 张卡,每张卡 8 个专家。一个 token 的 K=8 路由可能选了「分布在不同 GPU 上的」8 个专家——All-to-All 把 token 发到目标 GPU、计算完后 All-to-All 收回来。

flowchart LR
  GPU1[GPU 1<br/>专家 1-8] -.All-to-All.-> GPU2[GPU 2<br/>专家 9-16]
  GPU2 -.-> GPU3[GPU 3<br/>专家 17-24]
  GPU3 -.-> GPU32[...GPU 32<br/>专家 249-256]
  GPU1 -.-> GPU32

EP 的通信代价

All-to-All 是「N×N 通信」——N 张 GPU 同时给所有其他 GPU 发数据。带宽要求极高。

对 DeepSeek-V3 这种 256 专家 + 32 EP 的部署,每次推理两次 All-to-All(attention 后 / FFN 中间)。在 NVLink 800Gbps + InfiniBand 400Gbps 的硬件下,All-to-All 的开销 ~5-10% 总时间——可以接受但需要精细调度。

DeepSeek 的工程团队为此做了专门的通信库(DualPipe、PD-Sep 等),把 All-to-All 和计算重叠——实际成本压到更低。

18.10 三种并行的组合:3D 并行

实际工程几乎从不只用一种并行——TP × PP × EP 是常见的「3D 并行」。

举例:DeepSeek-V3 部署在 256 张 H100 上:

总卡数 = TP × PP × EP × DP = 8 × 4 × 8 × 1 = 256

flowchart TB
  M[671B MoE 模型] --> TP[TP=8 切 attention 和 FFN 矩阵]
  TP --> PP[PP=4 切 61 层]
  PP --> EP[EP=8 切 256 个专家]
  EP --> SCALE[256 GPU 部署]

每种并行解决不同的问题:

并行 解决什么 通信 适合
TP 单层放不下 频繁 all-reduce(慢) 同机 NVLink
PP 整模型放不下 阶段间传 hidden state(快) 跨机
EP MoE 专家放不下 All-to-All(中等) 同机 + 跨机

设计部署拓扑的经验法则:

  1. 先 TP 占满同机 NVLink:8 卡 NVLink 服务器内 TP=8
  2. TP 不够再 EP:MoE 模型用 EP 把专家分到多机
  3. 再不够再 PP:超大模型 + PP 跨机
  4. 最后 DP:吞吐扩展用数据并行

18.11 通信带宽决定一切

分布式推理的核心制约不是算力,是通信带宽。把所有相关带宽放一起看:

通信类型 带宽 用于
HBM ↔ SM 3.35 TB/s 单卡内权重读写
NVLink (H100) 900 GB/s 同机 GPU 间
NVL Switch (NVL72) 1.8 TB/s 同机 GPU 间(增强版)
InfiniBand 8x400G 50 GB/s 跨机
Ethernet 100G 12.5 GB/s 跨机(更便宜,更慢)

差异巨大。NVLink 比 InfiniBand 快 18×,比 Ethernet 快 70×。这就是为什么:

构建大模型推理集群时,硬件选型的核心是网络拓扑

这种硬件投入是 frontier model 推理的「门票」——开源模型再好,没有几亿美元的硬件配套,也跑不出 OpenAI / Anthropic 同等的服务质量。

18.12 一个完整的部署案例

把这一章所有内容串起来,看一个真实部署:

任务:部署 Llama-3 70B 服务千万 DAU 的中文 chat 应用,要求 TPOT < 50 ms、TTFT < 1s。

Step 1:模型量化

Step 2:单卡 attention 优化

Step 3:推理引擎选择

Step 4:单卡能服务多少

Step 5:横向扩展

Step 6:成本优化

Step 7:跨机扩展

这种规模的部署是 OpenAI、Anthropic 等公司的日常——他们的 GPT-4 / Claude 服务背后就是这种规模的工程。开源团队、小公司想达到同等服务能力,要么花 $500M 自建集群,要么用 OpenAI / Anthropic API。

本章小结

Flash Attention 部分

  1. GPU 内存层级:HBM 大慢、SMEM 小快——attention 必须吃尽 SMEM。
  2. 朴素 attention HBM 流量是 O(N²)——长上下文场景下 attention 是瓶颈。
  3. Flash Attention = tiling + online softmax——把 attention 算在 SMEM 里,HBM 流量降到 O(N)。
  4. Online softmax 数学等价——一遍扫过去,结果与朴素 softmax 完全一致。
  5. FA1 → FA2 → FA3 持续优化:work partition、async pipeline、warp specialization、FP8 支持。FA3 在 H100 上 75% MFU。

分布式推理部分

  1. 三种并行:TP(切大矩阵)、PP(切层)、EP(切 MoE 专家)。
  2. TP 适合同机 NVLink:8 卡内 TP=8 是常见配置。跨机 TP 几乎不可行。
  3. PP 适合跨机扩展:通信少、bubble 限制。在线低延迟用得少。
  4. EP 是 MoE 专属:All-to-All 通信,需要专门优化。
  5. 3D 并行 = TP × PP × EP:超大模型部署的标准组合。
  6. 通信带宽是关键:NVLink > NVL Switch > InfiniBand > Ethernet——选硬件就是在选拓扑。

第六部分到这里完结。我们用 5 章把 LLM 推理系统从「为什么慢」(两阶段)到「怎么快」(KV Cache、量化、投机解码、Flash Attention、分布式)全讲清楚了。

下一章是终章——第 19 章 Transformer 之后。Transformer 已经九年没变过骨架了,但研究界一直在探索替代方案:Mamba 等线性复杂度架构、Hybrid 混合架构、Diffusion-based 生成。我们会沿着这条路看一眼未来。

延伸阅读