Transformer 解剖:从 Attention 到推理系统

第 1 章 为什么是 Transformer:从 RNN 的瓶颈到一次降维打击

作者 杨艺韬 · 5,749 字

第 1 章 为什么是 Transformer:从 RNN 的瓶颈到一次降维打击

我们要回答的问题是:为什么 Transformer 取代了 RNN?

这个问题表面上是技术史,背后其实是一道工程账。RNN 不是被某个理论证明「不行」的,它是被三道现实墙挡住、爬不动了,而 Transformer 用一次架构换型把这三道墙同时拆掉。理解这次换型的代价和收益,是理解整个大模型时代的起点。

这一章会按下面的顺序展开:先把 RNN 时代的序列建模图谱画清楚,再逐条剖析 RNN 的三道瓶颈,然后看 2014–2017 年间的 Attention 是怎么作为一种「外挂」开始改写这件事的,最后看 2017 年的 Transformer 论文怎么把外挂变成主体——直接把 RNN 删掉、只留 Attention。读完这一章你会拿到两样东西:一是看任何序列模型时都通用的坐标系(顺序 / 并行 / 长距离 / 上下文),二是后续 18 章的路线图。

1.1 序列建模这个问题

先把场景定下来。序列建模这个词听起来抽象,落到具体任务就是这些:

它们都符合一个共同结构:输入是一串带顺序的 token,模型要把这串 token 压缩成某种「上下文表示」(context representation),再用这个表示干下游的事——预测下一个 token、翻译、分类、生成。

对一个序列建模模型,我们大致会问它三个问题:

flowchart LR
  Q1[1 它能不能看懂当前 token 的<br/>上下文?]
  Q2[2 它能不能记住<br/>很远以前的 token?]
  Q3[3 它训练时能不能<br/>充分利用 GPU 并行?]
  Q1 --> ANS{决定能力上限}
  Q2 --> ANS
  Q3 --> ANS{决定能力上限}

第一个问题决定了模型的「理解力」——能不能根据上下文给一个 token 合适的表示。第二个问题决定了模型处理长文本的能力——一篇 5000 字的文章里第 1 段和第 50 段的呼应能不能被捕捉到。第三个问题决定了能不能把模型训得足够大——只要能并行,硬件的钱就花得动;不能并行,就只能在小规模上原地打转。

Transformer 之前主流回答这三个问题的方法是 RNN(Recurrent Neural Network,循环神经网络)家族。我们先看 RNN 是怎么回答的,再看它在哪一个问题上栽了跟头。

1.2 RNN 时代:用「时间」串起序列

RNN 的核心思想用一句话可以讲完:让模型按顺序逐个读取 token,每读一个就把当前看到的内容压缩到一个隐藏状态里,把这个隐藏状态带到下一步

形式化一点:给定输入序列 x1,x2,,xTx_1, x_2, \dots, x_T,RNN 维护一个隐藏状态序列 h0,h1,,hTh_0, h_1, \dots, h_T,更新规则是:

ht=f(ht1,xt)h_t = f(h_{t-1}, x_t)

其中 ff 是一个由神经网络参数化的函数。最朴素的 vanilla RNN 是:

ht=tanh(Whhht1+Wxhxt+b)h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t + b)

这意味着读到第 tt 个 token 时的隐藏状态 hth_t,是把前一个隐藏状态 ht1h_{t-1} 和当前 token 的嵌入 xtx_t 一起经过一次线性变换+非线性变换得到的。整个过程像下面这样展开:

flowchart LR
  X1[x_1] --> H1[h_1]
  H0[h_0] --> H1
  X2[x_2] --> H2[h_2]
  H1 --> H2
  X3[x_3] --> H3[h_3]
  H2 --> H3
  X4[...] --> H4[...]
  H3 --> H4
  X5[x_T] --> H5[h_T]
  H4 --> H5

这种结构有一个非常优雅的性质:它对序列长度是不变的。无论输入是 5 个 token 还是 5000 个 token,模型都用同一组参数 Whh,Wxh,bW_{hh}, W_{xh}, b,只是循环的次数不一样。这一性质让 RNN 在 2014 年之前是处理变长序列的标准答案。

RNN 还有一个进一步的优势:它有一个内置的「时间」概念hth_t 自然地依赖 ht1h_{t-1},依赖 ht2h_{t-2},一直依赖到 h0h_0。理论上,远古 token 的信息可以通过这条递归链条一路传到当前。

但实际跑起来,问题马上来了。

1.3 第一道墙:长距离依赖会衰减

理论上 hTh_T 包含了 x1,,xTx_1, \dots, x_T 的全部信息。实际上不是。

为什么?因为 hth_t 是用 ht1h_{t-1} 经过一次非线性变换得到的,意味着 x1x_1 的信息要传到 hTh_T,要经过 T1T-1 次变换。每一次变换都是一次「揉」——把当前 token 的信息揉进去、把无关信息抹掉一些。揉的次数太多,远古信息就被新内容稀释、覆盖、压缩到几乎不可识别。

更糟的是数学上的衰减。考虑反向传播:要更新 WhhW_{hh},损失对 WhhW_{hh} 的梯度需要从 TT 步反向传到第 1 步。链式法则把这个过程展开后,会得到一连串雅可比矩阵的乘积:

Lh1=LhTt=2Ththt1\frac{\partial \mathcal{L}}{\partial h_1} = \frac{\partial \mathcal{L}}{\partial h_T} \cdot \prod_{t=2}^{T} \frac{\partial h_t}{\partial h_{t-1}}

如果每一步的雅可比谱半径都小于 1,连乘起来梯度会指数级地趋向 0——这就是著名的梯度消失(vanishing gradient)。如果谱半径大于 1,连乘起来梯度会指数级爆炸——梯度爆炸(exploding gradient)。

工程师为了对付这两个毛病,发明了 LSTM(Long Short-Term Memory,1997)和后来的 GRU(Gated Recurrent Unit,2014)。它们在 vanilla RNN 的基础上加了一组「门」(gate),允许信息绕过非线性变换、沿着一条「细胞状态」(cell state)流过:

flowchart LR
  subgraph LSTM step
    XT[x_t] --> GATE[输入门<br/>遗忘门<br/>输出门]
    HT1[h_t-1] --> GATE
    CT1[c_t-1] --> CELL[细胞状态 c_t]
    GATE --> CELL
    CELL --> HT[h_t]
  end

LSTM 让长距离依赖问题缓解了不少。在 2014–2016 年的机器翻译比赛里,LSTM 是绝对主力,BLEU 分数被一路推高。

缓解不等于解决。LSTM 把信息传到几百 token 之外仍然非常困难。一个直观的指标是:当时最好的 LSTM 翻译模型,在长度超过 50 词的句子上表现明显下滑,到 100 词以上几乎是灾难。而真实文本里 100 词以上的段落比比皆是。

更深层的原因是:无论 LSTM 设计得多巧妙,它本质上仍然是「用一个固定大小的向量去压缩任意长度的序列」hth_t 是一个固定维度(典型 512 或 1024)的向量。让一个 1024 维向量塞下 1000 个 token 的全部上下文,理论上是可以做到的(信息容量够),但学习起来极其困难——因为模型必须在每一步都做出无损取舍,决定哪些远古信息值得继续往后带、哪些可以丢掉。

这是 RNN 的第一道墙:信息瓶颈。它在结构上就要求把任意长度的序列压成定长向量,而这件事是物理性困难的。

1.4 第二道墙:训练不能并行

这是 RNN 真正的「死穴」,也是工业界最痛的那一刀。

RNN 的递归式 ht=f(ht1,xt)h_t = f(h_{t-1}, x_t) 意味着:要算 hth_t,必须先算出 ht1h_{t-1}。整个序列是严格按时间步顺序展开的,这个顺序是模型结构本身决定的,不是实现选择。

GPU 的强项是什么?大规模并行。一张 H100 GPU 有 14592 个 CUDA 核心、上百个 Tensor Core,它擅长在毫秒级里同时计算几万次矩阵乘法。但要让 RNN 跑在 GPU 上,每一步都必须等前一步完成——这相当于让一台 16 核服务器只用 1 个核工作,浪费了 93.75% 的算力。

具体一点。假设你训练一个机器翻译模型,输入是平均长度 100 token 的英文句子,batch size = 64。每一步 RNN 要做的核心计算是 ht=tanh(Whhht1+Wxhxt)h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t),这是一次矩阵乘法。你希望把 64 个样本打包到一个 batch 里、把 100 个时间步打包到一次大矩阵乘里,但做不到——时间步之间是严格依赖的。

这意味着每个 batch 你只能并行 64 个样本(batch 维度),但必须串行执行 100 个时间步。GPU 的算力大部分时间在等待上一步完成、写回 HBM、再读出来给下一步用。在实践中,训练一个中等规模的 LSTM 翻译模型在 8 张 V100 上要跑 1–2 周,而模型参数量通常只有几亿。

工程师们尝试过各种「让 RNN 也并行起来」的招数——SRU(Simple Recurrent Unit)、QRNN(Quasi-Recurrent Neural Networks)、ConvSeq2Seq(用 CNN 替代 RNN 来获得并行)——都各有各的妥协,没有一个真正打开了天花板。

这是 RNN 的第二道墙:严格的顺序依赖让它无法在训练时充分利用现代硬件。在 GPU 算力一年翻倍的时代,这道墙意味着模型规模被锁死——你想训一个 100B 参数的 RNN?理论上可行,实际上要跑十年。

1.5 第三道墙:注意力作为外挂

到 2014 年,研究者们开始意识到 RNN 的信息瓶颈问题,并发明了第一种 Attention 机制——Bahdanau Attention(出自 ICLR 2015 的论文 Neural Machine Translation by Jointly Learning to Align and Translate)。

它的设计很巧妙。原本的 Encoder-Decoder 翻译模型是这样工作的:Encoder(一个 RNN)把整个英文句子读完,输出一个固定向量 hTh_T;Decoder(另一个 RNN)拿着这个 hTh_T 一步步生成中文。问题是 Decoder 在生成第 5 个中文 token 时,它能拿到的信息只有 hTh_T 这个被压扁的向量——它没法回头看英文原文。

Bahdanau Attention 说:别让 Decoder 只看 hTh_T,让它在生成每个 token 时都能「回头」看一眼整个英文序列,自己挑出当前最相关的那几个英文 token。具体做法是:

  1. Encoder 不只输出 hTh_T,而是把每一步的隐藏状态 h1,,hTh_1, \dots, h_T 全部保留,作为「记忆」。
  2. Decoder 在生成第 ii 个中文 token 时,根据自己当前的状态 si1s_{i-1},对每个英文位置 jj 算一个相关性分数 eije_{ij}
eij=a(si1,hj)e_{ij} = a(s_{i-1}, h_j)

其中 aa 是一个小神经网络。

  1. 把这些分数过一次 softmax 得到注意力权重 αij\alpha_{ij}
αij=exp(eij)k=1Texp(eik)\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^{T} \exp(e_{ik})}
  1. 用注意力权重对 Encoder 的状态做加权求和,得到当前 Decoder step 的「上下文向量」 cic_i
ci=j=1Tαijhjc_i = \sum_{j=1}^{T} \alpha_{ij} h_j
  1. Decoder 用 cic_isi1s_{i-1} 一起算下一步状态 sis_i

这个机制优雅得令人发指——它让 Decoder 在每一步都能动态地「聚焦」到英文句子中相关的部分,而不是依赖一个被压扁的固定向量。在 BLEU 评测上,加了 Attention 的 RNN 翻译模型在长句子上的表现立刻拉开了和不带 Attention 的差距。

这是 Attention 第一次登场。但要注意:它只是 RNN 的一个外挂。Encoder 还是 RNN,Decoder 还是 RNN,Attention 是搭在它们中间的一个权重计算模块。它解决了第一道墙(信息瓶颈),但没解决第二道墙(训练并行)——RNN 的递归结构还在那里。

接下来的三年(2014–2017),研究者们沿着这个外挂思路探索了大量变体。Luong attention(2015)简化了 Bahdanau 的对齐函数。ConvSeq2Seq(2017)用 CNN 替代 RNN 来获得并行性,但 CNN 的「视野」是局部的,需要堆很多层才能看到长距离依赖。每条路都各走了一段,每条路都没能彻底打破天花板。

直到 2017 年 6 月那篇论文。

1.6 那一刀:删掉 RNN,只留 Attention

《Attention Is All You Need》的核心论断粗暴地直接:RNN 不是必要的。Attention 已经能独立完成序列建模。

具体怎么做?这篇论文给出了下面这张架构(简化版):

flowchart TB
  X[输入 token 序列] --> EMB[Embedding + 位置编码]
  EMB --> ATT1[Self-Attention 层]
  ATT1 --> FFN1[FFN 层]
  FFN1 --> ATT2[Self-Attention 层]
  ATT2 --> FFN2[FFN 层]
  FFN2 --> DOTS[...堆叠 N 层...]
  DOTS --> OUT[输出表示]

里面没有 RNN。也没有 CNN。只有 Self-Attention 层和 FFN(前馈网络)层堆叠。

Self-Attention 是 Bahdanau Attention 的一个推广:原来的 Attention 是「Decoder 看 Encoder」,现在变成「序列中每个位置都同时看序列中其他所有位置」。给定输入序列 x1,,xTx_1, \dots, x_T,Self-Attention 让每个 xix_i 都能根据它和其他 xjx_j 的相关性,加权聚合所有 xjx_j 的信息,得到一个新的表示 ziz_i

写成公式(这是后面整本书都会反复用到的核心公式):

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V

第 2 章会把这个公式逐项拆解。这里你只需要建立直觉:Self-Attention 是一种允许任意位置直接「看」任意其他位置的机制。两个相距 1000 token 的词之间,只需要一次矩阵乘法就能建立联系——而不是像 RNN 那样要传递 1000 步。

这一刀的意义有两重:

第一重,长距离依赖问题彻底解决。任意两个位置的「距离」在 Self-Attention 下都是 O(1) 的——一次矩阵乘法。RNN 时代的「梯度消失/信息衰减」从结构上就不存在了。

第二重,训练完全并行。Self-Attention 的计算 softmax(QKT/dk)V\text{softmax}(QK^T/\sqrt{d_k})V 是一个矩阵乘法,不存在「时间步顺序」的概念——所有位置可以同时计算。一个 batch 里的 64 个序列、每个序列的 100 个 token,可以全部同时算。这意味着 GPU 的算力不再被 RNN 的串行结构浪费,可以满负荷运转。

这两个性质合起来,就是 RNN 时代不可能达到的天花板被一次性打穿。论文当年报告的结果:在 WMT 2014 English-to-German 翻译任务上,Transformer Big 在 8 张 P100 上训练 3.5 天达到 28.4 BLEU,超过了当时所有 RNN-based 模型,而训练时间只有它们的 1/4 甚至 1/10。

1.7 代价:从 O(N) 到 O(N²)

天下没有免费的午餐。Transformer 用一次架构换型解决了 RNN 的所有结构性瓶颈,但它自己也带来了新的代价:计算复杂度从 O(N) 变成了 O(N²)

为什么?因为 Self-Attention 要让每个位置都看其他所有位置——一个长度为 NN 的序列,需要计算 N×NN \times N 个相关性分数。这个 QKTQK^T 矩阵的形状是 N×NN \times N,存储和计算都是平方级。

对比 RNN:RNN 一个时间步的计算是 O(1)(只看前一步),整个序列下来是 O(N)。

直觉上 O(N²) 听起来比 O(N) 差很多,但这里有两个工程现实让它在 2017 年的硬件上仍然划算:

第一,常数因子小。RNN 的每一步都包含矩阵乘法+非线性+门控,是一个完整的小神经网络计算;Self-Attention 的每一对位置只需要一次内积,常数小很多。

第二,并行性弥补了平方级。RNN 是 O(N) 但串行,墙上时间是 N 个 step;Self-Attention 是 O(N²) 但完全并行,墙上时间在 GPU 上接近 O(1)(被算力上限限制)。在序列长度 N=512 的典型机器翻译场景下,Self-Attention 的 wall-clock time 比 LSTM 快了 5–10 倍。

但 O(N²) 这个代价在大模型时代会反复出现,并且变成了核心痛点。当上下文长度从 512 拉到 4K、32K、128K、1M 时,O(N²) 的存储和计算都会变得无法承受——所以才有了后来的 Flash Attention(不降低复杂度但降低 HBM 访问)、滑动窗口注意力、Mamba 这些尝试。这条线索在第 13 章「长上下文之战」会展开。

维度 RNN/LSTM Transformer (Self-Attention)
任意两位置「距离」 O(N) 步 O(1) 步
单层时间复杂度 O(N · d²) O(N² · d)
训练能否并行 否(顺序依赖) 是(全并行)
长距离依赖 衰减严重 直达
信息瓶颈 固定隐藏状态 无(每位置都有完整上下文)
推理时增量延迟 O(d²) O(N · d)(KV Cache 后)

最右两列就是这次架构换型的全部得失。注意最后一行:在训练阶段 Transformer 全面胜出;但在推理(特别是逐 token 自回归生成)阶段,Transformer 反而要花精力去优化——这就是后面第六部分要解决的问题。

1.8 大模型时代为什么是 Transformer

到这里,你应该能理解为什么 Transformer 在 2017 年就赢了机器翻译。但它真正爆发是在 2018 年之后——BERT、GPT-1、GPT-2、GPT-3、ChatGPT 一路推到今天的大模型时代。这背后有三个 Transformer 独有、RNN 几乎不可能达到的性质:

性质一:Scaling 友好。Transformer 的训练并行性意味着只要硬件管够,模型就能放大。从 GPT-2 的 1.5B 到 GPT-3 的 175B,再到今天的 1T+ 模型,所有这些规模的训练都依赖 Transformer 能在数千张 GPU 上以接近线性的扩展效率训起来。RNN 在同样硬件上做不到——它的串行结构让单次 forward 的延迟随模型规模线性增长。

性质二:可堆叠性强。Transformer Block 是一个高度模块化的单位——Self-Attention + FFN + Residual + LayerNorm。把它堆 12 层是 BERT base,堆 96 层是 GPT-3,堆 80 层是 Llama 70B。每多一层带来的能力增益相对稳定,没有 RNN 那种「层数加多了反而训不动」的问题。第 5 章会讲清楚 Residual 和 LayerNorm 在这件事中扮演的关键角色。

性质三:上下文窗口可灵活扩展。RNN 时代你训了一个 LSTM 处理长度 512,要让它处理长度 5000 是非常痛苦的——隐藏状态会被信息淹没,需要重新设计。Transformer 时代,从 4K 扩到 128K 主要是位置编码层面的改造(RoPE base 调整、ALiBi 引入、位置插值),核心架构不动。第 13 章会讲清楚这条演化路径。

这三个性质合起来解释了为什么 Transformer 不只是一种技术替代,而是一次让模型规模、上下文长度、训练效率都有数量级提升的架构换型。今天的大模型工业站在这三个性质上。

1.9 这本书要带你看完的全景

下面这张图是你接下来 18 章会一一打开的全景。每个标签对应一个章节或一个章节族。

flowchart TB
  subgraph 第二部分 注意力机制
    SA[Self-Attention<br/>Q/K/V] --> MH[Multi-Head]
    MH --> POS[位置编码<br/>RoPE]
    POS --> BLK[Transformer Block<br/>FFN/LN/Residual]
  end

  subgraph 第三部分 架构家族
    BLK --> ARCH[Encoder/Decoder<br/>Decoder-only]
    ARCH --> PRE[预训练范式<br/>BERT vs GPT]
  end

  subgraph 第四部分 从零实现
    PRE --> CODE[50 行 PyTorch<br/>Self-Attention]
    CODE --> MGPT[mini-GPT<br/>训练写古诗]
    MGPT --> TOK[Tokenizer<br/>BPE/SentencePiece]
  end

  subgraph 第五部分 规模化
    TOK --> SCALE[Scaling Laws<br/>Chinchilla]
    SCALE --> MOE[Mixture of Experts]
    MOE --> LONG[长上下文<br/>4K → 1M]
  end

  subgraph 第六部分 推理系统
    LONG --> TWO[两阶段推理<br/>Prefill/Decode]
    TWO --> KV[KV Cache<br/>PagedAttention]
    KV --> QUAN[量化<br/>INT4/FP8]
    QUAN --> SPEC[投机解码]
    SPEC --> FA[Flash Attention<br/>TP/PP/EP]
  end

  FA --> END[Mamba/Hybrid<br/>Transformer 之后]

可以注意到这个全景的结构:机制 → 架构 → 实现 → 规模 → 推理 → 未来。每一层都建立在上一层之上,每一层都解决一组具体的工程问题。读这本书你不是在学一堆零散概念,而是在沿着工业界过去九年的演进路径,把每一个关键节点都打开看清楚。

特别强调一下第六部分(推理系统,第 14–18 章)。这是这本书相比其他 Transformer 教程最大的差异化所在——大多数教程讲到第 5 章就结束了,最多再讲一下 BERT 和 GPT 的差别。但今天的工业现实是:训练一次大模型的成本是几百万到几亿美元,而推理一次用户请求的成本被推理工程优化决定。一个 70B 模型,在朴素实现下单卡跑不起来,在 vLLM + PagedAttention + INT4 量化 + 投机解码 + Flash Attention 的组合下能在一张 H100 上跑到 100+ tokens/s。这其中每一层优化背后都是 Transformer 某个机制的工程后果——你不理解 KV Cache,就理解不了 PagedAttention;你不理解两阶段推理,就理解不了为什么 batch size 在 Prefill 和 Decode 阶段要不同处理。

读完第六部分,你会在面对一份生产级的 LLM 推理 trace 时,不再是「看不懂」,而是能直接指出「这个 spike 是 Prefill 阶段的 attention 计算密集」「这个低利用率窗口是 KV Cache miss 在等显存搬运」「这里能再加 30% 吞吐如果上 GQA」。这是真正能决定生产环境成本曲线的能力。

本章小结

这一章我们做了三件事:

  1. 画出了序列建模的坐标系:理解力、长距离依赖、训练并行——任何序列模型都要在这三个维度上接受拷问。

  2. 拆解了 RNN 的三道墙:信息瓶颈(定长隐状态压不住任意长度序列)、长距离依赖衰减(梯度消失/爆炸)、训练不能并行(递归结构锁死硬件利用)。Attention 作为「外挂」在 2014–2017 年缓解了第一道墙,但前两道仍没彻底解决。

  3. 讲清楚了 Transformer 的那一刀:删掉 RNN、只留 Attention。这次换型同时打掉三道墙,代价是计算复杂度从 O(N) 升到 O(N²)。这个代价在长上下文时代变成了主要痛点,也催生了第六部分要讲的整套推理系统优化。

下一章我们进入第二部分,把 Self-Attention 的公式 softmax(QKT/dk)V\text{softmax}(QK^T/\sqrt{d_k})V 一项一项打开——Q、K、V 各自代表什么、缩放因子 dk\sqrt{d_k} 为什么不能省、softmax 在这里干的不是分类而是「归一化注意力分布」。读完第 2 章,你会拿到 Transformer 整座大厦最底下的那块基石。

延伸阅读