Transformer 解剖:从 Attention 到推理系统

第 3 章 Multi-Head Attention:分头的代数与几何意义

作者 杨艺韬 · 5,617 字

第 3 章 Multi-Head Attention:分头的代数与几何意义

第 2 章我们把 Self-Attention 的内核拆开看了。但实际工程里你看到的不会是单头 Attention,永远是 Multi-Head Attention——多个头并行计算。Llama 70B 有 64 个头,GPT-3 有 96 个头,DeepSeek-V3 有 128 个头。

为什么要多头?把同一件事并行做多遍有什么好处?这一章我们要彻底讲清楚两件事:第一,Multi-Head 在数学和几何上到底在做什么——它不是「把单头跑 N 遍取平均」;第二,从原始 MHA 到今天主流的 GQA / MQA,这条演化路线背后的工程压力是什么。

读完这章你会理解:为什么 Llama 3 的 70B 模型只有 8 个 KV 头但有 64 个 Q 头、为什么 GPT-3 用 MHA 而 Llama 3 改用了 GQA、为什么 PaLM 用 MQA 而 Mistral 早期用 GQA。这些选择背后都是同一道工程账。

3.1 单头不够:语言里同时存在多种关系

我们先停下来想一个问题:自然语言里,一个 token 和上下文之间的关系只有一种吗?

回到第 2 章那个例子:

The animal didn't cross the street because it was too tired.

"it" 这个 token 至少有这几条不同的关系链:

  1. 指代关系:it → animal(共指)
  2. 语法关系:it 是 was 的主语
  3. 语义关系:it / tired 描述同一个实体的状态
  4. 位置关系:it 处在主从句的从句开头
  5. 修辞关系:整个 because-从句解释为什么 didn't cross

如果你只用一个 Self-Attention 头来处理,所有这些关系都得挤在同一个注意力分布里。结果就是:要么模型只学到最强的那种关系(比如指代),其他的都被压制;要么模型试图同时学所有关系,导致每种都学得不够好。

更深的问题是几何上的。一个头的注意力权重 αij\alpha_{ij} 是由一组 WQ,WKW_Q, W_K 投影出来的 qi,kjq_i, k_j 的点积决定的。WQ,WKW_Q, W_K 一旦选定,整个注意力分布就被限制在它们投影出来的那个 dkd_k 维子空间里。而一个子空间一次只能编码一类相关性——你没法用同一组投影矩阵同时编码「语法」和「语义」这两种正交的语言学结构。

Multi-Head 的解决方案非常直接:让模型有多组 WQ,WK,WVW_Q, W_K, W_V,每一组独立学一种关系,然后把它们的输出拼起来。每个头活在自己那个子空间里、专心学一类关系,所有头加起来就覆盖了语言里多样的相关性结构。

flowchart LR
  X[输入 x_i] --> H1["Head 1<br/>学指代关系"]
  X --> H2["Head 2<br/>学语法关系"]
  X --> H3["Head 3<br/>学语义关系"]
  X --> HD[...]
  X --> HN["Head h<br/>学修辞关系"]
  H1 --> CONCAT[拼接所有头的输出]
  H2 --> CONCAT
  H3 --> CONCAT
  HD --> CONCAT
  HN --> CONCAT
  CONCAT --> WO["× W_O 投影"] --> OUT[最终输出]

注意「学指代 / 学语法 / 学语义」是为了讲解直觉给的标签。模型不会真的有这些标签——它们是模型从数据中自发涌现出来的分工。后面 3.5 节会用真实的 BERT 注意力可视化让你看到这种分工长什么样。

3.2 Multi-Head 的代数定义

把直觉翻译成数学。给定输入 XRT×dmodelX \in \mathbb{R}^{T \times d_{\text{model}}}、头数 hh、每个头的维度 dk=dv=dmodel/hd_k = d_v = d_{\text{model}} / h

第一步:每个头有自己的一组投影矩阵 WQ(i),WK(i),WV(i)Rdmodel×dkW_Q^{(i)}, W_K^{(i)}, W_V^{(i)} \in \mathbb{R}^{d_{\text{model}} \times d_k},其中 i=1,,hi = 1, \dots, h

第二步:每个头独立做一次完整的 Self-Attention:

headi=Attention(XWQ(i),XWK(i),XWV(i))\text{head}_i = \text{Attention}(X W_Q^{(i)}, X W_K^{(i)}, X W_V^{(i)})

每个头的输出形状是 (T,dv)(T, d_v)

第三步:把 hh 个头的输出沿特征维拼起来(不是相加!这是常见误解),得到形状 (T,hdv)=(T,dmodel)(T, h \cdot d_v) = (T, d_{\text{model}})

Concat(head1,,headh)RT×dmodel\text{Concat}(\text{head}_1, \dots, \text{head}_h) \in \mathbb{R}^{T \times d_{\text{model}}}

第四步:再过一次输出投影 WORdmodel×dmodelW_O \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}

MultiHead(X)=Concat(head1,,headh)WO\text{MultiHead}(X) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W_O

合起来就是这一个公式:

MultiHead(X)=Concati=1h[softmax((XWQ(i))(XWK(i))Tdk)XWV(i)]WO\text{MultiHead}(X) = \text{Concat}_{i=1}^{h}\left[\text{softmax}\left(\frac{(X W_Q^{(i)})(X W_K^{(i)})^T}{\sqrt{d_k}}\right) X W_V^{(i)}\right] W_O

看着复杂,但每一步都是第 2 章已经讲过的运算,只是被并行复制了 hh 份。

Concat 是「沿最后一维拼接」——Head1 (T, d_v)Head2 (T, d_v) 拼起来变成 (T, 2·d_v)。这个操作在 PyTorch 里是 torch.cat([head1, head2, ...], dim=-1)

3.3 关键的工程实现:一次大矩乘

上面的描述里我们说 "hh 个头独立做",听起来像要跑 hh 次 for 循环——这显然在 GPU 上不效率。实际工程上,所有头都被打包到一次大矩阵乘法里

把所有头的投影矩阵沿特征维拼起来:

WQbig=[WQ(1)WQ(2)WQ(h)]Rdmodel×hdkW_Q^{\text{big}} = [W_Q^{(1)} | W_Q^{(2)} | \dots | W_Q^{(h)}] \in \mathbb{R}^{d_{\text{model}} \times h \cdot d_k}

由于 hdk=dmodelh \cdot d_k = d_{\text{model}},所以 WQbigRdmodel×dmodelW_Q^{\text{big}} \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}WKbig,WVbigW_K^{\text{big}}, W_V^{\text{big}} 同理。

一次大矩乘把所有头的 Q/K/V 算出来:

Qbig=XWQbigRT×dmodelQ^{\text{big}} = X W_Q^{\text{big}} \in \mathbb{R}^{T \times d_{\text{model}}}

接下来用一个 reshape,把 dmodeld_{\text{model}} 这一维拆成 (h,dk)(h, d_k)

QbigRT×dmodelreshapeQRh×T×dkQ^{\text{big}} \in \mathbb{R}^{T \times d_{\text{model}}} \xrightarrow{\text{reshape}} Q \in \mathbb{R}^{h \times T \times d_k}

PyTorch 里这一步是 Q.view(T, h, d_k).transpose(0, 1)。然后所有头的 attention 就可以用一个带 batch 维的矩阵乘法统一算:

# Q, K, V 形状都是 (h, T, d_k)
S = Q @ K.transpose(-2, -1)        # (h, T, T)
S = S / math.sqrt(d_k)
A = torch.softmax(S, dim=-1)        # (h, T, T)
out = A @ V                         # (h, T, d_k)
out = out.transpose(0, 1).contiguous().view(T, d_model)
out = out @ W_O                     # (T, d_model)

整个过程没有 for 循环。所有 hh 个头的全部计算被压到 4 次矩阵乘法(QKV 投影各一次 + 输出投影一次)加 1 次 softmax 里——这是 Multi-Head 在工程上能高效跑的根本原因。

flowchart LR
  X["X<br/>(T, d_model)"] --> WQB["× W_Q_big<br/>(d_model, d_model)"]
  WQB --> QB["Q_big<br/>(T, d_model)"] --> RS["reshape<br/>(h, T, d_k)"]
  RS --> Q["Q<br/>(h, T, d_k)"]
  X --> KV["类似得到 K, V<br/>(h, T, d_k)"]
  Q --> ATT["batched<br/>scaled dot product<br/>attention"]
  KV --> ATT
  ATT --> CAT["transpose + reshape<br/>(T, d_model)"]
  CAT --> WO["× W_O"] --> OUT["MultiHead 输出<br/>(T, d_model)"]

3.4 参数量与计算量

来算一下 Multi-Head Attention 的参数量。每个头有 WQ(i),WK(i),WV(i)W_Q^{(i)}, W_K^{(i)}, W_V^{(i)},每个是 dmodel×dkd_{\text{model}} \times d_k;总共 hh 个头:

paramsQKV=h3(dmodeldk)=3dmodel2\text{params}_{QKV} = h \cdot 3 \cdot (d_{\text{model}} \cdot d_k) = 3 \cdot d_{\text{model}}^2

(用了 hdk=dmodelh \cdot d_k = d_{\text{model}}

加上输出投影 WORdmodel×dmodelW_O \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}

paramsMHA=3dmodel2+dmodel2=4dmodel2\text{params}_{\text{MHA}} = 3 d_{\text{model}}^2 + d_{\text{model}}^2 = 4 d_{\text{model}}^2

注意:Multi-Head 的总参数量等于「等价单头但 dk=dmodeld_k = d_{\text{model}}」的参数量——分头并没有让参数变多。这是一个非常重要的观察:MHA 不是用更多参数换更强表达力,而是用同样的参数实现了结构上更丰富的表达(多个独立子空间)。

计算量呢?一次 Self-Attention 的计算量主要是两块:

  1. QKV 投影O(Tdmodel2)O(T \cdot d_{\text{model}}^2)
  2. Attention 矩阵(QK^T、softmax、AV):O(T2dmodel)O(T^2 \cdot d_{\text{model}})

对 Multi-Head 来说总数和单头一致(因为参数量一致),分头只是把计算重新组织。但当 TT 很大时,第 2 项 O(T2dmodel)O(T^2 \cdot d_{\text{model}}) 是主导——这是 Transformer 在长上下文下的瓶颈,第 13 章会展开。

3.5 不同头到底学到了什么

直觉上我们说「不同头学不同关系」,是真的吗?这个问题在 BERT 出来之后被研究得很透彻。

2019 年 Clark et al. 的论文 What Does BERT Look At? An Analysis of BERT's Attention 系统分析了 BERT-base(12 层 × 12 头 = 144 个头)的注意力分布,发现了几类典型的「专家头」:

flowchart TB
  subgraph BERT 注意力分工 简化示意
    L1[Layer 1 - 主要是位置头与局部句法]
    L4[Layer 4-6 - 大量句法依存头]
    L8[Layer 8-10 - 共指 与 语义聚类头]
    L12[Layer 11-12 - 任务相关头 与 sink 头]
  end
  L1 --> L4 --> L8 --> L12

这个分工不是模型被显式训练出来的——它是从大规模语言建模这个单一目标里自发涌现的。这是 Multi-Head 设计真正的妙处:你不需要告诉模型「这个头学语法、那个头学指代」,模型自己会按照「让 loss 最低」的原则自动分工。

下面这张图是我们用第 9 章那个 1.7M 参数的 mini-GPT(4 层 × 4 头)在唐诗上训完后,把《静夜思》整首灌进去得到的最后一层 4 个头的 attention 模式。可以看到分工已经清晰浮现:

mini-GPT 最后一层 4 个头的 attention 模式

仔细看 Head 3——它有两条平行的对角线,一条是当前位置(紧挨着自己),另一条比当前位置往前 5 个 token。这正是五言诗的「对偶结构」——「举头望明月」对「低头思故乡」、「举头」对「低头」、「望明月」对「思故乡」——每个字在对偶句的对应位置上都有一个共指对象。模型从未被告诉过「这是首五言诗,第 6 字对第 11 字」,但 Head 3 自发学到了这个隔 5 字的呼应模式。

Head 1、2 主要看相邻位置(局部句法),Head 4 在更长距离上分散——「位置头 + 局部句法头 + 对偶头 + 长距离汇聚头」的分工自发涌现。这和 BERT 实验里观察到的「不同头学不同关系」性质完全一致——只不过我们的模型小了三个数量级。

但有一个发现让后续的工作产生了深远影响:很多头其实是冗余的。Voita et al. 在 Analyzing Multi-Head Self-Attention(2019)和 Michel et al. 在 Are Sixteen Heads Really Better than One?(2019)里证明,BERT 的 144 个头中相当一部分可以被剪掉而几乎不影响下游任务表现——一些层只剩 1–2 个头还能正常工作。

这个观察后来引出了一条工程优化路线:既然很多头是冗余的,能不能在保留 query 多头的同时,让 key 和 value 共享?——这就是 MQA / GQA 的来历。

3.6 推理瓶颈推动的演化:MHA → MQA → GQA

进入大模型时代后,Multi-Head 的代价开始显形。问题不在训练阶段,而在推理阶段的 KV Cache

第 15 章会详细讲 KV Cache,这里先给一个直觉:自回归生成时,每生成一个新 token,模型不需要重新算前面所有 token 的 K 和 V——它们可以缓存下来反复用。这个缓存的大小,对一个有 hh 个头、序列长度 TT、每个头维度 dkd_kLL 层的模型来说是:

KV Cache size=2LhTdkbytes per element\text{KV Cache size} = 2 \cdot L \cdot h \cdot T \cdot d_k \cdot \text{bytes per element}

「2」是 K 和 V 各一份。我们用 Llama-3-70B 算一笔账(先假设它没有用 GQA、保持 MHA):L=80L = 80h=64h = 64dk=128d_k = 128,FP16 就是 2 bytes。在 T=4096T = 4096 时:

280644096128210 GiB2 \cdot 80 \cdot 64 \cdot 4096 \cdot 128 \cdot 2 \approx 10 \text{ GiB}

光是 KV Cache 一个 batch 就要 10 GiB——而 Llama-3-70B 模型权重本身在 FP16 下是 140 GB。如果你想 batch=8 同时跑 8 个用户的请求,KV Cache 一项就要 80 GiB——已经把整张 H100 的显存吃完。这是「如果继续用 MHA」的反事实——下面我们会看到正是这个压力推动了 GQA / MLA 的演化。

这就是 KV Cache 的工程压力:它正比于 head 数 hh。如果能降低有效的 KV head 数,KV Cache 立刻线性下降。

Multi-Query Attention (MQA)

最激进的优化是 2019 年 Noam Shazeer 提出的 MQA:所有 query 头共享同一组 key 和 value:

headi=Attention(XWQ(i),XWKshared,XWVshared)\text{head}_i = \text{Attention}(X W_Q^{(i)}, X W_K^{\text{shared}}, X W_V^{\text{shared}})

也就是说,QQ 还是 hh 个头,但 KKVV 只有 1 个头。

KV Cache 立刻从 2LhTdk2 L h T d_k 降到 2LTdk2 L T d_k——降低了 hh 倍。Llama-3-70B 在上面 MHA 假设下的 10 GiB 直接砍到约 160 MiB。

但 MQA 有代价:因为所有 query 都用同一份 key、value,模型表达能力降低。PaLM 论文(Chowdhery et al., 2022)和 GShard 系列实验都报告 MQA 会带来一定程度的质量损失,特别是在小模型上更明显。

Grouped-Query Attention (GQA)

Ainslie et al. 在 2023 年提出的折中方案:把 query 头分组,组内共享一组 key 和 value

hqh_q 个 query 头,hkvh_{kv} 个 KV 头(且 hqh_qhkvh_{kv} 的整数倍)。每 hq/hkvh_q / h_{kv} 个 query 头共享一组 K、V:

groupg shared K, V; queries qg(hq/hkv),,q(g+1)(hq/hkv)1\text{group}_g\text{ shared K, V; queries } q_{g \cdot (h_q/h_{kv})}, \dots, q_{(g+1)\cdot(h_q/h_{kv})-1}

GQA 是 MHA 和 MQA 的连续插值:

flowchart TB
  subgraph mha ["MHA: h=8"]
    Q1[Q1-K1-V1]
    Q2[Q2-K2-V2]
    Q3[Q3-K3-V3]
    Q4[Q4-K4-V4]
    Q5[Q5-K5-V5]
    Q6[Q6-K6-V6]
    Q7[Q7-K7-V7]
    Q8[Q8-K8-V8]
  end
  subgraph gqa ["GQA: hq=8, hkv=4"]
    GQ1["Q1, Q2 共享 K1-V1"]
    GQ2["Q3, Q4 共享 K2-V2"]
    GQ3["Q5, Q6 共享 K3-V3"]
    GQ4["Q7, Q8 共享 K4-V4"]
  end
  subgraph mqa ["MQA: hq=8, hkv=1"]
    MQ["Q1, Q2, ..., Q8 全部共享 K-V"]
  end

GQA 在 Llama-2 之后成为主流。Llama-3-70B 用的就是 hq=64,hkv=8h_q = 64, h_{kv} = 8 的 GQA——KV Cache 比 MHA 降到原来的 1/8(4K 上下文下从 10 GiB 降到 1.25 GiB),同时质量损失非常小(论文实验显示与 MHA 几乎不可区分)。

主流模型的注意力变种汇总:

模型 h_q h_kv 类型
GPT-3 175B 96 96 MHA
Llama-1 7B 32 32 MHA
Llama-2 70B 64 8 GQA
Llama-3 70B 64 8 GQA
Mistral 7B 32 8 GQA
PaLM 540B 48 1 MQA
GPT-4(推测) GQA
DeepSeek-V2/V3 128 128 (MLA) MLA(更激进的 KV 压缩,第 15 章会讲)

可以看到 2023 年之后开源模型几乎全部从 MHA 切到 GQA,主要驱动就是推理时的 KV Cache 压力。MQA 太激进,GQA 是「质量损失最小、显存收益最大」的甜点。

3.7 几何上理解 Multi-Head:把空间切成子空间

我们在 3.1 节用「不同头学不同关系」给了直觉,3.5 节用 BERT 实验给了证据,3.6 节讲了工程演化。这一小节我们再换一个角度——几何视角——给 Multi-Head 一个更精确的解读。

WQ(i)Rdmodel×dkW_Q^{(i)} \in \mathbb{R}^{d_{\text{model}} \times d_k} 看作一个线性投影:把 dmodeld_{\text{model}} 维空间投到一个 dkd_k 维子空间。每个头有自己的 WQ(i),WK(i)W_Q^{(i)}, W_K^{(i)} 投影对,把输入投到一个专属的 dkd_k 维子空间,在这个子空间里计算注意力。

如果所有头的投影矩阵随机初始化,这些子空间在 dmodeld_{\text{model}} 维空间里大致是相互正交(或低相关)的——也就是说,每个头看的是输入信息的一个不同的「视角」。在子空间 1 看上去高度相关的两个 token,在子空间 2 可能完全没关系。

模型训练的过程,就是让每个 WQ(i),WK(i)W_Q^{(i)}, W_K^{(i)} 旋转到一个「让某种特定相关性凸显出来」的方向。最终:

这不是模型「学到了多套独立的 attention」,而是同一个 d_model 维空间被同时切分成多个互补的视角,每个视角看到的是同一份输入的不同侧面。

这个几何观点能解释为什么 Multi-Head 的总参数量不变但表达力增强:参数量取决于「自由度」,而 hh 个头的 WQ(i)Rdmodel×dkW_Q^{(i)} \in \mathbb{R}^{d_{\text{model}} \times d_k} 拼起来等价于一个 dmodel×dmodeld_{\text{model}} \times d_{\text{model}} 的大矩阵——参数自由度和单头的 dk=dmodeld_k = d_{\text{model}} 一样。但 Multi-Head 通过对最后这个矩阵施加分块结构(每 dkd_k 列对应一个独立的 attention 子空间),约束了模型一定要把表达力切成 hh 个独立部分,强行避免了所有注意力都坍缩到同一种关系上

约束就是先验。Multi-Head 的先验是「语言里有多种关系,必须分别建模」。这个先验和数据本身吻合得很好,所以 MHA 在表达力上反而比单头大 dk=dmodeld_k = d_{\text{model}} 的 attention 要强。

3.8 如何选 head 数

现在你看模型卡片里 num_attention_heads 这个数字,应该不再是个抽象选项。来对照几条工程经验法则:

法则一:dmodeld_{\text{model}} 必须能被 hh 整除。否则没法均匀分头。常见约定是 dk=dv=64d_k = d_v = 64(Vaswani 原论文这么用)或 128(Llama 系列),hhdmodel/dkd_{\text{model}} / d_k 决定。

法则二:每个头的维度 dkd_k 不能太小。从 2.5 节我们知道 dkd_k 影响 softmax 的尖锐度——dkd_k 太小(比如 8)时即使有 dk\sqrt{d_k} 也压不住分布的不稳定性。实践中 dk32d_k \ge 32 是合理下限,多数主流模型是 64 或 128。

法则三:头数太多不一定更好。Multi-Head 的「先验」在 hh 太大时会变成「过度切分」——每个子空间太小,反而让模型学不到有意义的表达。GPT-3 用 96 头是相当激进的选择,Llama-1 用 32 头是更稳妥的。

法则四:推理压力大时用 GQA。如果要在线服务大量用户、KV Cache 是瓶颈,GQA 的 hkv=hq/8h_{kv} = h_q / 8 是被工业验证过的甜点。

3.9 Multi-Head 的训练动力学

为了完整,我们再讲一下 Multi-Head 在训练过程中的「动力学」——也就是它在反向传播下怎么演化。这一小节稍微高阶,跳过不影响后续阅读。

每个头独立做注意力,反向传播时也独立。WQ(i),WK(i),WV(i)W_Q^{(i)}, W_K^{(i)}, W_V^{(i)} 的梯度只来自第 ii 个头的 loss 贡献——头与头之间在反向传播路径上是完全分离的,唯一的耦合是输出投影 WOW_O(它把所有头的输出合到一起做最后的变换)。

这意味着两件事:

第一,初始化非常重要。如果两个头初始化得太接近(比如几乎相同的随机种子),它们在训练中会沿着相似的梯度方向演化,最终学到几乎相同的 attention 模式——「头崩塌」(head collapse)。这就是为什么 PyTorch 的 nn.MultiheadAttention 默认用 Xavier/Kaiming 初始化,让每个头从足够不同的起点出发。

第二,稀疏注意力下的训练不稳定。如果某个头的 attention 分布在初始化阶段就接近饱和(softmax 输出几乎是 one-hot),它的梯度会很小——头几乎不再更新,永远停在初始的那个模式。这也是为什么 dk\sqrt{d_k} 缩放(第 2.5 节)必须存在的原因之一。

GPT-2 / GPT-3 时代曾流行过一种做法:用更小的 std 初始化注意力投影,配合 LayerNorm 的特殊放置,来稳定 Multi-Head 的训练动力学。Llama 把这套实践标准化为 RMSNorm + Pre-LN(第 5 章会讲)。

3.10 为什么 Cross-Attention 也用 Multi-Head

Multi-Head 的设计在 Self-Attention 里讲的最多,但它在 Cross-Attention 里同样适用——而且同样重要。

Cross-Attention 是 Encoder-Decoder 架构里 Decoder 的一种 attention 变体:QQ 来自 Decoder,K,VK, V 来自 Encoder:

CrossAttention(D,E)=softmax((DWQ)(EWK)Tdk)EWV\text{CrossAttention}(D, E) = \text{softmax}\left(\frac{(D W_Q)(E W_K)^T}{\sqrt{d_k}}\right) E W_V

机器翻译里 Cross-Attention 是 Decoder 「回头看英文原文」的机制(参考第 1 章 Bahdanau Attention)。它同样可以做成 Multi-Head——多个头并行「看」原文,每个头关注一种对齐关系(词级对齐、短语级对齐、句法对齐等)。

不过到了 Decoder-only 模型时代(GPT 系列),Encoder 和 Cross-Attention 都被消解了——所有事都用 Self-Attention 加上因果掩码完成。第 6 章会讲清楚这次架构演化的代价和收益。

本章小结

  1. 单头 Self-Attention 一次只能学一种关系——因为它被 WQ,WKW_Q, W_K 投影出的那个 dkd_k 维子空间限制住了。
  2. Multi-Head 通过把空间切成 hh 个互补子空间,让每个头独立学一类相关性(语法、共指、语义、位置等),最后用 WOW_O 把它们合在一起。
  3. MHA 的总参数量等于等价单头——分头不是加参数,而是给同样的参数量加结构约束
  4. 工程实现上所有头被打包到一次大矩阵乘法,没有 for 循环,GPU 上完全并行。
  5. 不同头有自发分工:BERT 实验里能识别出位置头、句法头、共指头、sink 头等专家。但很多头是冗余的——这促成了 MQA / GQA 的演化。
  6. MHA → MQA → GQA 是被推理时 KV Cache 压力推动的演化。今天主流模型几乎全部用 GQA,典型配置是 hqh_qhkvh_{kv} 的 8 倍。
  7. 几何视角:Multi-Head 是把 dmodeld_{\text{model}} 空间切成 hh 个子空间,每个子空间从输入的某个独立视角看相关性。

下一章我们解决一个第 2 章里挖下的坑:Self-Attention 不感知位置(animal, it)(it, animal) 在 Self-Attention 看来是一样的。语言里位置非常重要,必须从外部注入。第 4 章会讲位置编码——从原始 sinusoidal 到 RoPE 的完整演化路径。

延伸阅读