Transformer 解剖:从 Attention 到推理系统
第 3 章 Multi-Head Attention:分头的代数与几何意义
第 3 章 Multi-Head Attention:分头的代数与几何意义
第 2 章我们把 Self-Attention 的内核拆开看了。但实际工程里你看到的不会是单头 Attention,永远是 Multi-Head Attention——多个头并行计算。Llama 70B 有 64 个头,GPT-3 有 96 个头,DeepSeek-V3 有 128 个头。
为什么要多头?把同一件事并行做多遍有什么好处?这一章我们要彻底讲清楚两件事:第一,Multi-Head 在数学和几何上到底在做什么——它不是「把单头跑 N 遍取平均」;第二,从原始 MHA 到今天主流的 GQA / MQA,这条演化路线背后的工程压力是什么。
读完这章你会理解:为什么 Llama 3 的 70B 模型只有 8 个 KV 头但有 64 个 Q 头、为什么 GPT-3 用 MHA 而 Llama 3 改用了 GQA、为什么 PaLM 用 MQA 而 Mistral 早期用 GQA。这些选择背后都是同一道工程账。
3.1 单头不够:语言里同时存在多种关系
我们先停下来想一个问题:自然语言里,一个 token 和上下文之间的关系只有一种吗?
回到第 2 章那个例子:
The animal didn't cross the street because it was too tired.
"it" 这个 token 至少有这几条不同的关系链:
- 指代关系:it → animal(共指)
- 语法关系:it 是 was 的主语
- 语义关系:it / tired 描述同一个实体的状态
- 位置关系:it 处在主从句的从句开头
- 修辞关系:整个 because-从句解释为什么 didn't cross
如果你只用一个 Self-Attention 头来处理,所有这些关系都得挤在同一个注意力分布里。结果就是:要么模型只学到最强的那种关系(比如指代),其他的都被压制;要么模型试图同时学所有关系,导致每种都学得不够好。
更深的问题是几何上的。一个头的注意力权重 是由一组 投影出来的 的点积决定的。 一旦选定,整个注意力分布就被限制在它们投影出来的那个 维子空间里。而一个子空间一次只能编码一类相关性——你没法用同一组投影矩阵同时编码「语法」和「语义」这两种正交的语言学结构。
Multi-Head 的解决方案非常直接:让模型有多组 ,每一组独立学一种关系,然后把它们的输出拼起来。每个头活在自己那个子空间里、专心学一类关系,所有头加起来就覆盖了语言里多样的相关性结构。
flowchart LR X[输入 x_i] --> H1["Head 1<br/>学指代关系"] X --> H2["Head 2<br/>学语法关系"] X --> H3["Head 3<br/>学语义关系"] X --> HD[...] X --> HN["Head h<br/>学修辞关系"] H1 --> CONCAT[拼接所有头的输出] H2 --> CONCAT H3 --> CONCAT HD --> CONCAT HN --> CONCAT CONCAT --> WO["× W_O 投影"] --> OUT[最终输出]
注意「学指代 / 学语法 / 学语义」是为了讲解直觉给的标签。模型不会真的有这些标签——它们是模型从数据中自发涌现出来的分工。后面 3.5 节会用真实的 BERT 注意力可视化让你看到这种分工长什么样。
3.2 Multi-Head 的代数定义
把直觉翻译成数学。给定输入 、头数 、每个头的维度 。
第一步:每个头有自己的一组投影矩阵 ,其中 。
第二步:每个头独立做一次完整的 Self-Attention:
每个头的输出形状是 。
第三步:把 个头的输出沿特征维拼起来(不是相加!这是常见误解),得到形状 :
第四步:再过一次输出投影 :
合起来就是这一个公式:
看着复杂,但每一步都是第 2 章已经讲过的运算,只是被并行复制了 份。
Concat 是「沿最后一维拼接」——Head1 (T, d_v) 和 Head2 (T, d_v) 拼起来变成 (T, 2·d_v)。这个操作在 PyTorch 里是 torch.cat([head1, head2, ...], dim=-1)。
3.3 关键的工程实现:一次大矩乘
上面的描述里我们说 " 个头独立做",听起来像要跑 次 for 循环——这显然在 GPU 上不效率。实际工程上,所有头都被打包到一次大矩阵乘法里:
把所有头的投影矩阵沿特征维拼起来:
由于 ,所以 。 同理。
一次大矩乘把所有头的 Q/K/V 算出来:
接下来用一个 reshape,把 这一维拆成 :
PyTorch 里这一步是 Q.view(T, h, d_k).transpose(0, 1)。然后所有头的 attention 就可以用一个带 batch 维的矩阵乘法统一算:
# Q, K, V 形状都是 (h, T, d_k)
S = Q @ K.transpose(-2, -1) # (h, T, T)
S = S / math.sqrt(d_k)
A = torch.softmax(S, dim=-1) # (h, T, T)
out = A @ V # (h, T, d_k)
out = out.transpose(0, 1).contiguous().view(T, d_model)
out = out @ W_O # (T, d_model)
整个过程没有 for 循环。所有 个头的全部计算被压到 4 次矩阵乘法(QKV 投影各一次 + 输出投影一次)加 1 次 softmax 里——这是 Multi-Head 在工程上能高效跑的根本原因。
flowchart LR X["X<br/>(T, d_model)"] --> WQB["× W_Q_big<br/>(d_model, d_model)"] WQB --> QB["Q_big<br/>(T, d_model)"] --> RS["reshape<br/>(h, T, d_k)"] RS --> Q["Q<br/>(h, T, d_k)"] X --> KV["类似得到 K, V<br/>(h, T, d_k)"] Q --> ATT["batched<br/>scaled dot product<br/>attention"] KV --> ATT ATT --> CAT["transpose + reshape<br/>(T, d_model)"] CAT --> WO["× W_O"] --> OUT["MultiHead 输出<br/>(T, d_model)"]
3.4 参数量与计算量
来算一下 Multi-Head Attention 的参数量。每个头有 ,每个是 ;总共 个头:
(用了 )
加上输出投影 :
注意:Multi-Head 的总参数量等于「等价单头但 」的参数量——分头并没有让参数变多。这是一个非常重要的观察:MHA 不是用更多参数换更强表达力,而是用同样的参数实现了结构上更丰富的表达(多个独立子空间)。
计算量呢?一次 Self-Attention 的计算量主要是两块:
- QKV 投影:
- Attention 矩阵(QK^T、softmax、AV):
对 Multi-Head 来说总数和单头一致(因为参数量一致),分头只是把计算重新组织。但当 很大时,第 2 项 是主导——这是 Transformer 在长上下文下的瓶颈,第 13 章会展开。
3.5 不同头到底学到了什么
直觉上我们说「不同头学不同关系」,是真的吗?这个问题在 BERT 出来之后被研究得很透彻。
2019 年 Clark et al. 的论文 What Does BERT Look At? An Analysis of BERT's Attention 系统分析了 BERT-base(12 层 × 12 头 = 144 个头)的注意力分布,发现了几类典型的「专家头」:
- 位置头:有些头几乎只关注前一个 token 或后一个 token(局部位置关系)。
- 句法头:有些头集中关注当前 token 的句法依存对象(如动词的主语、形容词修饰的名词)。
- 共指头:有些头会把代词和它指代的名词强烈连接("it" 与 "animal")。
- 分隔符头:相当多的头(特别是高层)大量关注
[CLS]或[SEP]这类特殊 token——这不是模型在「重视分隔符」,而是 attention 软件结构自己产生的「注意力 sink」(垃圾桶头),把不知道往哪放的注意力倾倒到这些位置。
flowchart TB
subgraph BERT 注意力分工 简化示意
L1[Layer 1 - 主要是位置头与局部句法]
L4[Layer 4-6 - 大量句法依存头]
L8[Layer 8-10 - 共指 与 语义聚类头]
L12[Layer 11-12 - 任务相关头 与 sink 头]
end
L1 --> L4 --> L8 --> L12
这个分工不是模型被显式训练出来的——它是从大规模语言建模这个单一目标里自发涌现的。这是 Multi-Head 设计真正的妙处:你不需要告诉模型「这个头学语法、那个头学指代」,模型自己会按照「让 loss 最低」的原则自动分工。
下面这张图是我们用第 9 章那个 1.7M 参数的 mini-GPT(4 层 × 4 头)在唐诗上训完后,把《静夜思》整首灌进去得到的最后一层 4 个头的 attention 模式。可以看到分工已经清晰浮现:

仔细看 Head 3——它有两条平行的对角线,一条是当前位置(紧挨着自己),另一条比当前位置往前 5 个 token。这正是五言诗的「对偶结构」——「举头望明月」对「低头思故乡」、「举头」对「低头」、「望明月」对「思故乡」——每个字在对偶句的对应位置上都有一个共指对象。模型从未被告诉过「这是首五言诗,第 6 字对第 11 字」,但 Head 3 自发学到了这个隔 5 字的呼应模式。
而 Head 1、2 主要看相邻位置(局部句法),Head 4 在更长距离上分散——「位置头 + 局部句法头 + 对偶头 + 长距离汇聚头」的分工自发涌现。这和 BERT 实验里观察到的「不同头学不同关系」性质完全一致——只不过我们的模型小了三个数量级。
但有一个发现让后续的工作产生了深远影响:很多头其实是冗余的。Voita et al. 在 Analyzing Multi-Head Self-Attention(2019)和 Michel et al. 在 Are Sixteen Heads Really Better than One?(2019)里证明,BERT 的 144 个头中相当一部分可以被剪掉而几乎不影响下游任务表现——一些层只剩 1–2 个头还能正常工作。
这个观察后来引出了一条工程优化路线:既然很多头是冗余的,能不能在保留 query 多头的同时,让 key 和 value 共享?——这就是 MQA / GQA 的来历。
3.6 推理瓶颈推动的演化:MHA → MQA → GQA
进入大模型时代后,Multi-Head 的代价开始显形。问题不在训练阶段,而在推理阶段的 KV Cache。
第 15 章会详细讲 KV Cache,这里先给一个直觉:自回归生成时,每生成一个新 token,模型不需要重新算前面所有 token 的 K 和 V——它们可以缓存下来反复用。这个缓存的大小,对一个有 个头、序列长度 、每个头维度 、 层的模型来说是:
「2」是 K 和 V 各一份。我们用 Llama-3-70B 算一笔账(先假设它没有用 GQA、保持 MHA):、、,FP16 就是 2 bytes。在 时:
光是 KV Cache 一个 batch 就要 10 GiB——而 Llama-3-70B 模型权重本身在 FP16 下是 140 GB。如果你想 batch=8 同时跑 8 个用户的请求,KV Cache 一项就要 80 GiB——已经把整张 H100 的显存吃完。这是「如果继续用 MHA」的反事实——下面我们会看到正是这个压力推动了 GQA / MLA 的演化。
这就是 KV Cache 的工程压力:它正比于 head 数 。如果能降低有效的 KV head 数,KV Cache 立刻线性下降。
Multi-Query Attention (MQA)
最激进的优化是 2019 年 Noam Shazeer 提出的 MQA:所有 query 头共享同一组 key 和 value:
也就是说, 还是 个头,但 和 只有 1 个头。
KV Cache 立刻从 降到 ——降低了 倍。Llama-3-70B 在上面 MHA 假设下的 10 GiB 直接砍到约 160 MiB。
但 MQA 有代价:因为所有 query 都用同一份 key、value,模型表达能力降低。PaLM 论文(Chowdhery et al., 2022)和 GShard 系列实验都报告 MQA 会带来一定程度的质量损失,特别是在小模型上更明显。
Grouped-Query Attention (GQA)
Ainslie et al. 在 2023 年提出的折中方案:把 query 头分组,组内共享一组 key 和 value。
设 个 query 头, 个 KV 头(且 是 的整数倍)。每 个 query 头共享一组 K、V:
GQA 是 MHA 和 MQA 的连续插值:
- → MHA(每个 query 头独享 K、V)
- → MQA(所有 query 头共享 K、V)
- 中间值 → GQA
flowchart TB
subgraph mha ["MHA: h=8"]
Q1[Q1-K1-V1]
Q2[Q2-K2-V2]
Q3[Q3-K3-V3]
Q4[Q4-K4-V4]
Q5[Q5-K5-V5]
Q6[Q6-K6-V6]
Q7[Q7-K7-V7]
Q8[Q8-K8-V8]
end
subgraph gqa ["GQA: hq=8, hkv=4"]
GQ1["Q1, Q2 共享 K1-V1"]
GQ2["Q3, Q4 共享 K2-V2"]
GQ3["Q5, Q6 共享 K3-V3"]
GQ4["Q7, Q8 共享 K4-V4"]
end
subgraph mqa ["MQA: hq=8, hkv=1"]
MQ["Q1, Q2, ..., Q8 全部共享 K-V"]
end
GQA 在 Llama-2 之后成为主流。Llama-3-70B 用的就是 的 GQA——KV Cache 比 MHA 降到原来的 1/8(4K 上下文下从 10 GiB 降到 1.25 GiB),同时质量损失非常小(论文实验显示与 MHA 几乎不可区分)。
主流模型的注意力变种汇总:
| 模型 | h_q | h_kv | 类型 |
|---|---|---|---|
| GPT-3 175B | 96 | 96 | MHA |
| Llama-1 7B | 32 | 32 | MHA |
| Llama-2 70B | 64 | 8 | GQA |
| Llama-3 70B | 64 | 8 | GQA |
| Mistral 7B | 32 | 8 | GQA |
| PaLM 540B | 48 | 1 | MQA |
| GPT-4(推测) | – | – | GQA |
| DeepSeek-V2/V3 | 128 | 128 (MLA) | MLA(更激进的 KV 压缩,第 15 章会讲) |
可以看到 2023 年之后开源模型几乎全部从 MHA 切到 GQA,主要驱动就是推理时的 KV Cache 压力。MQA 太激进,GQA 是「质量损失最小、显存收益最大」的甜点。
3.7 几何上理解 Multi-Head:把空间切成子空间
我们在 3.1 节用「不同头学不同关系」给了直觉,3.5 节用 BERT 实验给了证据,3.6 节讲了工程演化。这一小节我们再换一个角度——几何视角——给 Multi-Head 一个更精确的解读。
把 看作一个线性投影:把 维空间投到一个 维子空间。每个头有自己的 投影对,把输入投到一个专属的 维子空间,在这个子空间里计算注意力。
如果所有头的投影矩阵随机初始化,这些子空间在 维空间里大致是相互正交(或低相关)的——也就是说,每个头看的是输入信息的一个不同的「视角」。在子空间 1 看上去高度相关的两个 token,在子空间 2 可能完全没关系。
模型训练的过程,就是让每个 旋转到一个「让某种特定相关性凸显出来」的方向。最终:
- 头 1 的子空间编码了「相邻位置的关系」
- 头 2 的子空间编码了「主谓一致」
- 头 3 的子空间编码了「指代」
- ……
这不是模型「学到了多套独立的 attention」,而是同一个 d_model 维空间被同时切分成多个互补的视角,每个视角看到的是同一份输入的不同侧面。
这个几何观点能解释为什么 Multi-Head 的总参数量不变但表达力增强:参数量取决于「自由度」,而 个头的 拼起来等价于一个 的大矩阵——参数自由度和单头的 一样。但 Multi-Head 通过对最后这个矩阵施加分块结构(每 列对应一个独立的 attention 子空间),约束了模型一定要把表达力切成 个独立部分,强行避免了所有注意力都坍缩到同一种关系上。
约束就是先验。Multi-Head 的先验是「语言里有多种关系,必须分别建模」。这个先验和数据本身吻合得很好,所以 MHA 在表达力上反而比单头大 的 attention 要强。
3.8 如何选 head 数
现在你看模型卡片里 num_attention_heads 这个数字,应该不再是个抽象选项。来对照几条工程经验法则:
法则一: 必须能被 整除。否则没法均匀分头。常见约定是 (Vaswani 原论文这么用)或 128(Llama 系列), 由 决定。
法则二:每个头的维度 不能太小。从 2.5 节我们知道 影响 softmax 的尖锐度—— 太小(比如 8)时即使有 也压不住分布的不稳定性。实践中 是合理下限,多数主流模型是 64 或 128。
法则三:头数太多不一定更好。Multi-Head 的「先验」在 太大时会变成「过度切分」——每个子空间太小,反而让模型学不到有意义的表达。GPT-3 用 96 头是相当激进的选择,Llama-1 用 32 头是更稳妥的。
法则四:推理压力大时用 GQA。如果要在线服务大量用户、KV Cache 是瓶颈,GQA 的 是被工业验证过的甜点。
3.9 Multi-Head 的训练动力学
为了完整,我们再讲一下 Multi-Head 在训练过程中的「动力学」——也就是它在反向传播下怎么演化。这一小节稍微高阶,跳过不影响后续阅读。
每个头独立做注意力,反向传播时也独立。 的梯度只来自第 个头的 loss 贡献——头与头之间在反向传播路径上是完全分离的,唯一的耦合是输出投影 (它把所有头的输出合到一起做最后的变换)。
这意味着两件事:
第一,初始化非常重要。如果两个头初始化得太接近(比如几乎相同的随机种子),它们在训练中会沿着相似的梯度方向演化,最终学到几乎相同的 attention 模式——「头崩塌」(head collapse)。这就是为什么 PyTorch 的 nn.MultiheadAttention 默认用 Xavier/Kaiming 初始化,让每个头从足够不同的起点出发。
第二,稀疏注意力下的训练不稳定。如果某个头的 attention 分布在初始化阶段就接近饱和(softmax 输出几乎是 one-hot),它的梯度会很小——头几乎不再更新,永远停在初始的那个模式。这也是为什么 缩放(第 2.5 节)必须存在的原因之一。
GPT-2 / GPT-3 时代曾流行过一种做法:用更小的 std 初始化注意力投影,配合 LayerNorm 的特殊放置,来稳定 Multi-Head 的训练动力学。Llama 把这套实践标准化为 RMSNorm + Pre-LN(第 5 章会讲)。
3.10 为什么 Cross-Attention 也用 Multi-Head
Multi-Head 的设计在 Self-Attention 里讲的最多,但它在 Cross-Attention 里同样适用——而且同样重要。
Cross-Attention 是 Encoder-Decoder 架构里 Decoder 的一种 attention 变体: 来自 Decoder, 来自 Encoder:
机器翻译里 Cross-Attention 是 Decoder 「回头看英文原文」的机制(参考第 1 章 Bahdanau Attention)。它同样可以做成 Multi-Head——多个头并行「看」原文,每个头关注一种对齐关系(词级对齐、短语级对齐、句法对齐等)。
不过到了 Decoder-only 模型时代(GPT 系列),Encoder 和 Cross-Attention 都被消解了——所有事都用 Self-Attention 加上因果掩码完成。第 6 章会讲清楚这次架构演化的代价和收益。
本章小结
- 单头 Self-Attention 一次只能学一种关系——因为它被 投影出的那个 维子空间限制住了。
- Multi-Head 通过把空间切成 个互补子空间,让每个头独立学一类相关性(语法、共指、语义、位置等),最后用 把它们合在一起。
- MHA 的总参数量等于等价单头——分头不是加参数,而是给同样的参数量加结构约束。
- 工程实现上所有头被打包到一次大矩阵乘法,没有 for 循环,GPU 上完全并行。
- 不同头有自发分工:BERT 实验里能识别出位置头、句法头、共指头、sink 头等专家。但很多头是冗余的——这促成了 MQA / GQA 的演化。
- MHA → MQA → GQA 是被推理时 KV Cache 压力推动的演化。今天主流模型几乎全部用 GQA,典型配置是 是 的 8 倍。
- 几何视角:Multi-Head 是把 空间切成 个子空间,每个子空间从输入的某个独立视角看相关性。
下一章我们解决一个第 2 章里挖下的坑:Self-Attention 不感知位置。(animal, it) 和 (it, animal) 在 Self-Attention 看来是一样的。语言里位置非常重要,必须从外部注入。第 4 章会讲位置编码——从原始 sinusoidal 到 RoPE 的完整演化路径。
延伸阅读
- Vaswani et al., Attention Is All You Need, NeurIPS 2017,3.2.2 节是 Multi-Head 的原始定义。
- Clark et al., What Does BERT Look At? An Analysis of BERT's Attention, BlackboxNLP 2019. arXiv:1906.04341
- Voita et al., Analyzing Multi-Head Self-Attention: Specialized Heads Do the Heavy Lifting, ACL 2019——剪头实验。
- Michel et al., Are Sixteen Heads Really Better than One?, NeurIPS 2019——头剪枝的理论分析。
- Shazeer, Fast Transformer Decoding: One Write-Head is All You Need, 2019——MQA 的原始论文。
- Ainslie et al., GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints, EMNLP 2023——GQA 的原始论文。
- BertViz(vig-2019)—— Multi-Head 注意力的可视化工具,强烈推荐自己跑一下看不同头的分布。