Transformer 解剖:从 Attention 到推理系统

第 13 章 长上下文之战:从 4K 到 1M

作者 杨艺韬 · 5,236 字

第 13 章 长上下文之战:从 4K 到 1M

2020 年 GPT-3 的上下文窗口是 2048。2022 年 GPT-3.5 是 4K。2023 年 GPT-4 把它扩到 8K,再到 32K,再到 128K。同年 Anthropic 把 Claude 推到 100K、200K。2024 年 Gemini 1.5 直接跳到 1M、2M。

五年,1000 倍

为什么要长上下文?不是为了「能放更多文字」这种字面理由,而是为了承载几类全新的应用:

  1. Agent 的中间状态:一个 Agent 任务可能需要几十次 tool call、几千 token 的 reasoning trace、加几十 KB 的工具返回值——全都要塞进 context。
  2. 整本书 / 整个代码库 / 整个项目:「读 500 页文档回答问题」「把整个 React 仓库塞进去做重构」——这是大模型替代「先 RAG 再回答」的关键场景。
  3. 多模态:一段 1 小时视频转成 token 后是几百万 tokens,长上下文是处理它的前提。

但每次把上下文从 N 扩到 10N,工程挑战不是线性放大,而是踩到三道墙:Attention 的 O(N²) 计算KV Cache 的 O(N) 显存位置编码的外推。这一章我们沿着这三道墙逐一展开——你会看到长上下文不是某个单一技术的胜利,而是 Flash Attention、YaRN、滑动窗口、Mamba 等多条线协同把它压下去的。

读完这章你能:

13.1 三道墙

把所有问题放在一张图上:

flowchart TB
  GOAL["把上下文扩到 1M"] --> W1["墙 1: Attention 计算 O(N²)"]
  GOAL --> W2["墙 2: KV Cache 显存 O(N)"]
  GOAL --> W3["墙 3: 位置编码外推"]
  W1 --> S1["Flash Attention 降 HBM 访问<br/>滑动窗口<br/>稀疏注意力<br/>Mamba 线性复杂度"]
  W2 --> S2["GQA / MQA / MLA 减 KV<br/>量化<br/>PagedAttention 减碎片"]
  W3 --> S3["RoPE base 调整<br/>Position Interpolation<br/>YaRN<br/>原生长上下文训练"]

每一堵墙都对应至少一种主流解决方案。今天的长上下文模型是这三条线技术的综合应用。

13.2 墙一:Attention 的 O(N²) 计算

第 1.7 节我们讲过,Self-Attention 的计算量是 O(N2d)O(N^2 \cdot d)——QK^T 矩阵是 N×N、softmax 沿一行做、AV 也是 N×N 量级。

把数字代进去:

上下文长度 Attention 矩阵元素数 单层 FLOPs(d=4096)
4K 16 M 134 G
32K 1 G 8.6 T
128K 16 G 137 T
1M 1 T 8.4 P

可以看到从 4K 到 1M 是 256× 上下文长度,但 attention 计算量是 65536×(平方关系)——这就是 O(N²) 的可怕之处。即使 attention 是 transformer 计算量的小头(FFN 是 2/3),1M context 下 attention 也会把整个推理拖慢几十倍。

减小这个计算量的三类思路

  1. 不改算法,只改实现:Flash Attention 通过减少 HBM 访问把实际墙上时间砍掉数倍——但理论 FLOPs 不变。
  2. 结构性稀疏:滑动窗口、稀疏注意力、Local + Global 混合——只算「看起来重要」的子集,把 O(N²) 降到 O(N · k)。
  3. 替换架构:Mamba、RetNet 用线性递归式架构,复杂度 O(N) 但放弃了 attention 的全连接性。

我们逐个讲。

13.3 Flash Attention:不改算法的工程胜利

第 18 章会展开讲 Flash Attention,这里只给一个直觉。

Attention 的瓶颈其实不是计算(GPU 的 TFLOPs 够多),而是内存访问。计算 S=QKTS = QK^T 时,Q,KRN×dQ, K \in \mathbb{R}^{N \times d},结果 SRN×NS \in \mathbb{R}^{N \times N}——这个 N×N 矩阵要写到 HBM(GPU 主显存)。然后再读出来做 softmax,再写回去;再读出来乘 V。每一步都是 HBM 读写,而 HBM 带宽远低于 GPU 算力。

Flash Attention(Tri Dao 等)的核心 insight:不要把 N×N 矩阵物化到 HBM。把 Q、K、V 切成小块(tiles),每次只算一小块的 attention,结果直接累加到 SRAM(GPU 上的缓存),最后才写回 HBM。

这把 HBM 访问从 O(N²) 降到 O(N),墙上时间砍 2-4 倍。但理论 FLOPs 不变——它没有降低算法复杂度,只是把同样的算法跑得更快。

flowchart LR
  PLAIN["朴素实现<br/>QK^T 物化到 HBM<br/>多次读写"] --> SLOW["HBM 访问 O(N²)"]
  FLASH["Flash Attention<br/>分块 tile<br/>累加到 SRAM"] --> FAST["HBM 访问 O(N)"]

Flash Attention 是今天所有长上下文模型必备的「基础设施」——vLLM、SGLang、TensorRT-LLM 默认都用它。但它不是长上下文的「魔法子弹」,因为它没改 O(N²) 这个根本——1M context 仍然要 1T 次浮点运算,只是更高效地完成。

要真正打破 O(N²),必须改算法。

13.4 滑动窗口:局部注意力

最简单的算法改造:只让每个 token 看左边 W 个 token(W 远小于 N)。

flowchart LR
  T1[token 1 看 1] --- T2[token 2 看 1,2] --- T3[token 3 看 1,2,3] --- T4[...]
  T4 --- TW["token W 看 1..W"]
  TW --- TW1["token W+1 看 2..W+1"] --- TW2["token W+2 看 3..W+2"]

复杂度从 O(N2)O(N^2) 降到 O(NW)O(N \cdot W)——线性。Mistral 7B v0.1 用 W=4096 的滑动窗口(v0.2/v0.3 之后取消了),理论上能扩到无限长上下文。

但只用滑动窗口会丢失长距离信息——上下文里第 1 token 在第 100 token 看不到了。这显然不行——人类阅读能记得几千字之外的内容。

实际工程解决方案是滑动窗口 + 全局注意力混合:每个 token 看左边 W 个 token + 几个全局 token(比如 [CLS] 或前几个 token,所有人都能看)。这是 Longformer(Beltagy et al., 2020)和 BigBird(Zaheer et al., 2020)的核心思路。

更进一步,层间混合:低层用 attention,高层用滑动窗口,或者反过来。这种异构架构能让模型在低层捕捉局部模式、在高层做全局聚合,复杂度仍然亚二次方。

13.5 稀疏注意力:选择性看

更精细的方向:让 attention 只看「重要的」位置——稀疏化但保留长距离能力。

具体策略很多:

稀疏化能把 attention 从 O(N²) 降到 O(N · log N) 甚至 O(N)。但代价是实现复杂、训练不稳定、性能略损失——所以业界主流 7B-70B 模型仍然用 dense attention,靠 Flash Attention 把 O(N²) 跑得够快。只有真正巨大的上下文(1M+)场景才会上稀疏化。

13.6 替换 Attention:Mamba / RetNet

最激进的方向:抛弃 attention,换成线性复杂度的架构

Mamba(Gu & Dao, 2023)基于状态空间模型(State Space Model, SSM),核心是一个递归式:

ht=Aht1+Bxth_t = A h_{t-1} + B x_t yt=Chty_t = C h_t

看起来很像 RNN——但它有几个关键改进:

  1. A、B、C 是输入相关的(selective)——不像 LSTM 是固定参数,Mamba 让 A、B、C 根据当前 token 动态变化。
  2. 可以并行计算——通过特殊的「parallel scan」算法,整个序列能在 O(N) 时间内并行算(不像传统 RNN 必须串行)。
  3. 状态向量是高维(几百维),有足够容量存信息。
flowchart LR
  X[x_1] --> H1[h_1]
  X2[x_2] --> H2
  X3[x_3] --> H3
  X4[x_T] --> HT[h_T]
  H1 --> H2
  H2 --> H3
  H3 --> DOTS[...]
  DOTS --> HT
  HT --> Y[y_T]
  
  RT["复杂度 O(N) (vs Transformer O(N²))"]

Mamba 在 1M-2M context 下计算成本只有 Transformer 的几分之一。但它仍有局限:

主流落地是 Hybrid 架构:把 Mamba block 和 Transformer block 交替堆叠。Jamba(AI21Labs, 2024)、Zamba(Zyphra, 2024)、Mamba-2(Dao & Gu, 2024)都是这条路线的代表。在长上下文下,Mamba block 处理大部分计算(线性复杂度),少数 Transformer block 提供精确召回能力。

第 19 章会更详细讲这条「Transformer 之后」的路线。

13.7 墙二:KV Cache 的 O(N) 显存

attention 的计算可以靠 Flash Attention 加速,但KV Cache 的显存开销没法靠工程消掉——它就是和上下文长度成正比。

回顾 3.6 节我们的反事实:如果 Llama-3 70B 仍用 MHA,4K context 下 KV Cache 约 10 GiB。在 128K 下线性放大:

10 GiB×128/4=320 GiB10 \text{ GiB} \times 128/4 = 320 \text{ GiB}

——超过 80GB 的 H100 单卡显存 4 倍!实际 Llama-3 70B 用的是 GQA(hkv=8h_{kv}=8),4K 下 KV Cache 1.25 GiB、128K 下 40 GiB——仍然超过单卡,必须多卡分布式才放得下。

下面这张图把 MHA / GQA / MLA 三种方案在 1K 到 1M 上下文下的 KV Cache 占用量画在一起(双对数轴):

KV Cache 随上下文长度增长 · 三种 attention 变种对比

横线是单卡 H100 的 80 GiB 显存上限。可以看到:MHA 反事实下 32K 就触顶单卡,GQA 把红线推到 256K 才触顶,MLA(DeepSeek-V3 的设计)一路压到 128K 仍只用 8.6 GiB。架构层面的 KV 压缩是长上下文真正的工程杠杆,纯靠 PagedAttention / 量化是无法跨越这种数量级差异的。

主流减小 KV Cache 的方法:

方法 1:GQA / MQA(第 3.6 节)——减少 KV head 数。Llama-3 用 8 个 KV head 把 KV Cache 砍 8 倍。

方法 2:MLA(Multi-Head Latent Attention,DeepSeek-V2/V3)——把 KV 投影到一个低维 latent 空间,KV Cache 只存这个低维 latent。比 MHA 小 5-10 倍

具体地,MLA 引入一个低秩压缩:

ct=WDKVxtRdcc_t = W_{DKV} \cdot x_t \in \mathbb{R}^{d_c}

其中 dcd_c 远小于 dmodeld_{\text{model}}(典型 512)。然后从 ctc_t 解压出 K、V:

kt(i)=WUK(i)ct,vt(i)=WUV(i)ctk_t^{(i)} = W_{UK}^{(i)} \cdot c_t, \quad v_t^{(i)} = W_{UV}^{(i)} \cdot c_t

KV Cache 只存 ctc_t(每个 token 一个 dcd_c 维向量),不存 kt(i),vt(i)k_t^{(i)}, v_t^{(i)}(推理时按需解压)。

DeepSeek-V3 上 MLA 让 KV Cache 比 GQA 还小 5 倍——是它能跑出 671B 总参数 + 128K context 的关键。

方法 3:KV Cache 量化——把 KV Cache 用 INT8 / INT4 存储,显存减半到减 4 倍。代价是引入量化误差,但 KV 对量化的容忍度比权重高,4-bit KV 量化几乎无损。

方法 4:PagedAttention(vLLM)——不减 KV 总量,但解决「碎片化」问题。后面第 15 章会专门讲。

方法 5:Sliding Window 配合 Cache 退役——只缓存最近 W 个 token 的 KV,旧的丢弃(牺牲长距离记忆)。Mistral 7B v0.1 用这条;后续 Mistral 版本(v0.2 起)和 Mixtral 都已经把 sliding_window 关掉了,因为 Flash Attention + GQA 让全 attention 仍然可负担。

13.8 墙三:位置编码外推

这一墙在第 4 章已经讲过——RoPE 在训练时见过 pos[0,Ttrain)pos \in [0, T_{\text{train}}),但实际推理时 pospos 远超训练范围

复习一下三种处理方法:

  1. Position Interpolation(PI):把 pospos 线性压缩回训练范围。简单粗暴有效。
  2. NTK-Aware:只对低频段插值,高频段不动。保留局部精度。
  3. YaRN:分频段处理,是 PI 和 NTK 的综合。Llama-3.1 把上下文从 8K 扩到 128K 用的是 Meta 自己发明的 frequency-based scaling(HF transformers 里的 rope_type: llama3,参数 factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0),思路与 YaRN 同源(按频率分段处理)但实现细节不同——不是严格意义的 YaRN。
  4. 训练原生长上下文:直接用大 base、大 context 训练(如 Llama-3 base=500000)。

工程上的实操路径:

flowchart LR
  T1["训练阶段 1<br/>普通 base, 8K context<br/>大量数据"] --> T2["训练阶段 2<br/>放大 RoPE base<br/>32K context<br/>少量长样本"]
  T2 --> T3["训练阶段 3 (可选)<br/>YaRN 拓展<br/>128K context<br/>更少更长样本"]

这种「先短后长」的 curriculum 训练让模型在每个阶段都从短上下文开始(数据多、训练稳定),再用少量长样本扩展。Llama-3.1 / Qwen 2.5 等模型用的就是这套流程。

13.9 训练长上下文的工程现实

把所有墙加起来后,训练长上下文模型有几条主线工程经验:

经验 1:长上下文训练数据稀缺

互联网上「连贯长文档」(10K+ token 的连续 token)远少于短文档。GitHub 代码、长篇小说、学术论文、长对话才符合标准。需要专门收集。

经验 2:长上下文 batch size 必然小

context = 128K + batch = 4 已经是极限——长 sequence × 大 batch 会让 attention 矩阵爆炸。一般把 batch token 总数做成常数(比如 4M),长 context 下 batch_size 很小但每条 sequence 长。

经验 3:训练算力要专门做长 context 优化

Flash Attention 必须开。Sequence parallelism(把同一个序列切到多张 GPU)也常用。DeepSeek、Llama 等团队都有针对长 context 的特殊训练 stack。

经验 4:RoPE base 提前调好,避免后期插值

如果你计划支持 128K,训练时 RoPE base 从 10000 调到 500000 或更大——这样训完不需要插值就能用,避免插值带来的精度损失。

经验 5:长 context fine-tune 比从头训省

实践中,先用普通 base + 普通 context 训完模型,再用少量长 context 数据 fine-tune 几百到几千步——这种 staged 训练比从头长 context 训便宜得多,效果接近。

13.10 评估:能放 ≠ 会用

「我们的模型支持 1M context」这种声明很多。但支持 1M 不等于能利用 1M。

最经典的评估叫 Needle-in-a-Haystack (NIAH)——「干草堆里找针」:

这是评估「模型在 100K context 中精确召回任意位置的信息」的能力。

flowchart LR
  CTX[10 万 token 长文档<br/>大部分是无关内容] --> NEED["插入针:<br/>The best snack is sandwich"]
  NEED --> Q[结尾问: best snack?]
  Q --> M[模型]
  M --> A1["✓ sandwich (回答正确)"]
  M --> A2["✗ apple / 不知道 (失败)"]

NIAH 的可视化通常画成一张热力图:横轴是 context 长度(4K, 8K, 16K, ..., 1M)、纵轴是针的位置(0% 到 100%)、单元格颜色表示成功率。真正的好长上下文模型应该是『全绿』——所有位置、所有长度都召回成功

在 NIAH 上:

更进阶的评估:

Lost in the Middle

研究还发现一个现象:模型对 prompt 中部的信息利用最差——开头和结尾召回好、中部召回差。这叫 "Lost in the Middle"(Liu et al., 2023)。

直觉解释:因果 attention 让所有位置看左边的 token——结尾位置自然能看到所有内容;开头位置因为 RoPE 的衰减性质对自己附近最敏感;而中部既不在开头(不被特别强调)又不在结尾(不被回看)——容易被忽略。

工程实践应对:

13.11 主流模型长上下文能力对比

汇总一张表(数据截至 2025 年中期,长上下文能力是各家迭代最快的维度——每隔几个月就有刷新,下面的对比仅作量级参考,具体数字请按你写代码当下的最新公开 benchmark 复核):

模型 声称上下文 NIAH 实际可用 主要技术
GPT-4 Turbo 128K 128K 训练 + 工程优化
GPT-4o 128K 128K 同上
Claude 3.5 Sonnet 200K 200K Anthropic 黑盒
Gemini 1.5 Pro 1M (2M 限选) ~1M(深度逐渐降) 长 context 训练 + 工程
Llama-3 70B 8K 8K 标准 RoPE
Llama-3.1 70B 128K 128K YaRN-style 频率缩放 + 长 context fine-tune
Mistral 7B v0.3 32K 32K RoPE + 训练扩展
Mistral Large 2 128K 128K 同上
Qwen 2.5 72B 128K 128K 长 context 训练
DeepSeek-V3 128K 128K MLA + 长 context fine-tune
Yi-34B-200K 200K 200K 长 context fine-tune(基于 Yi-34B)
Mamba-2 1M+(理论) < Transformer SSM 架构

可以看到:

  1. 闭源 frontier 模型领先一档——Gemini 1.5 的 1M 和 Claude 3.5 的 200K 暂时无开源对手。
  2. 开源主流卡在 128K——再往上需要专门工程,Yi、Qwen 是少数推到 200K+ 的开源选手。
  3. 声称和实际有差距——一些声称 128K 的开源模型在 50K 之后就不可靠,要看 NIAH 才能验证。

13.12 用户视角:什么时候真的需要长上下文

理论讲完,回到工程现实——普通用户在什么场景下真的需要长上下文?

确实需要长上下文的场景

  1. 整篇长文档 QA:技术报告、合同、论文。100K-200K 是常见尺度。
  2. 整个代码库重构 / 分析:一个中型项目几十万 token,128K 起步。
  3. 长对话历史的 Agent:Agent 多轮 tool call 后历史上下文可能 50K+。
  4. 多模态视频理解:1 小时视频转 token 后是 200K-2M。
  5. 批量数据处理 in prompt:把几千行 CSV 塞进去做分析。

不需要长上下文的场景(用 RAG 更便宜更准):

  1. 从大量文档里查信息:100 万文档检索几条 → 用 vector DB / RAG,不是塞进 prompt。
  2. Chat 助手:普通对话上下文 < 16K 完全够用。
  3. 代码补全:当前文件 + 几个相关文件,< 32K 已经覆盖。

经验法则:先用 RAG,不行再上长 context。RAG 单次推理 1K-4K token,长 context 单次 100K,成本差 30-100 倍

13.13 长上下文成本

最后看看长上下文的钱账。GPT-4 Turbo 的 API 定价 10/Minput+10/M input + 30/M output token。一次 128K context 的请求:

每次调用花 1-2 美元。100 次调用 100-200 美元——和一杯咖啡的价格不再可比。

Gemini 1.5 Pro 的 1M context:

一次 4 美元——但 1M context 能塞下整本《战争与和平》。这种价格让一些以前不可行的应用变可行(比如「让 AI 读完一本书后回答问题」),但成本仍然不便宜。

工程优化:

本章小结

  1. 长上下文不是单一技术——是 attention 算法、KV Cache 工程、位置编码外推三条线综合作战。
  2. Attention 的 O(N²) 是核心瓶颈——Flash Attention 改实现不改算法,滑动窗口 / 稀疏 / Mamba 改算法换来线性复杂度。
  3. KV Cache O(N) 显存增长是必然要面对的——GQA、MLA、量化、PagedAttention 都在压它。
  4. 位置编码外推——RoPE base 调整、PI、NTK、YaRN 各自适用不同场景。Llama-3.1 用 frequency-based scaling(YaRN 思路的 Meta 改造版)推到 128K。
  5. 训练 curriculum:先短后长,少量长样本 fine-tune 比从头训便宜。
  6. NIAH 评估暴露真假长上下文:能放 ≠ 会用。Lost in the Middle 是常见失败模式。
  7. 闭源 frontier 领先:Gemini 1M、Claude 200K 暂无开源对手;开源主流卡在 128K。
  8. 应用边界:先 RAG,不行再上长 context——成本差 30-100 倍。
  9. Prefix caching 是降低长上下文成本的关键工程优化。

到这里第五部分完结。我们用三章把规模化的全部维度走了一遍:Scaling Laws(怎么花算力)、MoE(怎么放大容量)、长上下文(怎么扩窗口)。

下一部分(第 14-18 章)是这本书最重的一块——推理系统。我们要从「模型怎么训」切到「模型怎么跑给用户用」。第 14 章先把推理的基本几何讲清楚——Prefill 和 Decode 是两个完全不同的阶段,硬件瓶颈、batch 行为、优化策略都不一样。

延伸阅读