第6章 YaRN RoPE 与 1M 长上下文工程

“Position is to a transformer what gravity is to physics—invisible but indispensable.” —— Tri Dao

V4 的位置编码不仅要支持 1M 上下文,还要在 65K 训练 / 1M 推理之间保持数学一致——这是 YaRN 给出的工程答案。


6.1 引子:65K 训练的模型怎么跑 1M 推理

V4 的 config.json 里有这么一段:

"max_position_embeddings": 1048576,
"rope_scaling": {
  "type": "yarn",
  "factor": 16,
  "original_max_position_embeddings": 65536,
  "beta_fast": 32,
  "beta_slow": 1
}

读这段配置:

  • max_position_embeddings = 1048576(1M)—— 推理时支持的最大上下文
  • original_max_position_embeddings = 65536(64K)—— 训练时实际见过的上下文
  • factor = 16 —— 1M / 64K = 16 倍外推

V4 训练时只见过 64K 长度的序列——但推理时要支持 1M。这种 16 倍的外推不是凭空的——它依赖一个叫 YaRN 的算法。如果你直接把 1M 长度的 input 喂给一个只训过 64K 的 RoPE 模型,会出现”位置编码外推塌陷”——模型在 64K 之外的位置上会输出乱码。

YaRN(“Yet Another RoPE Extension Method”,arXiv:2309.00071)通过频率插值让 RoPE 在未训练区间也能稳定工作。本章拆 V4 是怎么把 YaRN 工业化的——以及 V4 在主 attention 与压缩 KV 上分别用了什么 RoPE 配置。


6.2 RoPE 复习:为什么要外推

RoPE(Rotary Position Embedding)的核心思想:把每个 head 的最后 64 维(rope_head_dim=64)拆成 32 对,每对用复指数旋转编码位置。

具体地,对位置 m,第 i 对的旋转角度是:

θm,i=mβ2i/d,i=0,1,...,d/21\theta_{m,i} = m \cdot \beta^{-2i/d}, \quad i = 0, 1, ..., d/2-1

其中 β\beta 是 base(V4 主 attention = 10000,压缩 KV = 160000),dd 是 rope_head_dim = 64。

旋转作用在 q 和 k 的最后 64 维上——把位置信息”嵌入”到向量本身。Attention 计算 qTkq^T k 时,不同位置之间的”相对位置”通过旋转角度差自然出现。

RoPE 的训练数据假设:模型在训练时见过的位置 m 的旋转角度分布。如果训练只到 64K,模型从来没见过 m=1M 的旋转——这时模型在 attention score 里看到一个”陌生角度”,会输出乱码。

外推的本质问题:当 m 从 64K 推到 1M,对每个频率维度 i,θm,i\theta_{m,i} 也变大 16 倍。低频维度(i 大)的角度差变化可能跨越多个 2π2\pi,相对位置感失真。


6.3 YaRN 的核心思想:频率插值

YaRN 的解决方案:给频率分维度做插值——

  • 高频维度(i 小,旋转快,对短距离敏感):保持原始 rope_theta,因为短距离的位置关系训练时见过
  • 低频维度(i 大,旋转慢,对长距离敏感):把频率除以 factor=16,让 1M 位置看起来像”原始 64K 的位置乘以 16”——模型已经训过的角度范围

中间过渡区用 linear ramp(beta_fast / beta_slow 控制)平滑插值。

flowchart LR
  subgraph 频率分布["维度 i (从高频到低频)"]
    direction LR
    HF[高频<br/>i 小]
    Mid[过渡区<br/>beta_fast → beta_slow]
    LF[低频<br/>i 大]
  end
  subgraph 处理["YaRN 的处理"]
    HF --> Keep["保持原 freq"]
    Mid --> Ramp["线性插值<br/>(1 - smooth) / factor + smooth"]
    LF --> Scale["freq / factor"]
  end

YaRN 的论文证明了这种”分频率插值”在外推上比 NTK-aware、Linear interpolation 等先前方法更稳定——尤其在外推因子 > 8 时。V4 的 16 倍外推完全在 YaRN 论文给出的”安全区”内。


6.4 V4 源码里的 precompute_freqs_cis

V4 在 inference/model.py 里实现了带 YaRN 的 RoPE 预计算函数:

@lru_cache(2)
def precompute_freqs_cis(dim, seqlen, original_seq_len, base, factor, beta_fast, beta_slow) -> torch.Tensor:
    """Precomputes complex exponentials for rotary embeddings with YaRN scaling."""
    def find_correction_dim(num_rotations, dim, base, max_seq_len):
        return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

    def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
        low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
        high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
        return max(low, 0), min(high, dim-1)

    def linear_ramp_factor(min, max, dim):
        if min == max:
            max += 0.001
        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
        ramp_func = torch.clamp(linear_func, 0, 1)
        return ramp_func

    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    if original_seq_len > 0:
        low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_seq_len)
        smooth = 1 - linear_ramp_factor(low, high, dim // 2)
        freqs = freqs / factor * (1 - smooth) + freqs * smooth
    t = torch.arange(seqlen)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis

这函数可以拆成 4 步:

步骤 1:算原始 RoPE 频率

freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

freqs[i] = 1 / base^(2i/dim)——dim/2 个频率值,从高到低。

步骤 2:YaRN 频率插值(仅当 original_seq_len > 0)

low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_seq_len)
smooth = 1 - linear_ramp_factor(low, high, dim // 2)
freqs = freqs / factor * (1 - smooth) + freqs * smooth

find_correction_range 计算”高频维度的截止”和”低频维度的截止”。smooth 在这两个截止之间做线性 ramp(高频 = 1,低频 = 0)。频率被按 smooth 插值——高频用原 freqs,低频用 freqs/factor。

步骤 3:构造完整的位置×频率矩阵

t = torch.arange(seqlen)
freqs = torch.outer(t, freqs)

freqs[m, i] = m * freqs[i] ——每个位置 m 在每个频率 i 上的旋转角度。

步骤 4:转复指数

freqs_cis = torch.polar(torch.ones_like(freqs), freqs)

torch.polar(magnitude=1, angle=θ) 返回 cos(θ) + i sin(θ) 的复数张量。

@lru_cache(2) 装饰器让这个函数最多缓存 2 个不同参数组合——V4 实际使用 2 套不同的 freqs_cis(主 attention + 压缩 KV)。


6.5 V4 的双 RoPE:主 attention 与压缩 KV 各自一套

V4 在 Attention.__init__ 里有这一段不显眼但极关键的代码:

if self.compress_ratio:
    original_seq_len, rope_theta = args.original_seq_len, args.compress_rope_theta
else:
    original_seq_len, rope_theta = 0, args.rope_theta

freqs_cis = precompute_freqs_cis(self.rope_head_dim, args.max_seq_len, original_seq_len,
                                 rope_theta, args.rope_factor, args.beta_fast, args.beta_slow)
self.register_buffer("freqs_cis", freqs_cis, persistent=False)

注意分支逻辑:

  • 有压缩段的层(compress_ratio > 0:用 compress_rope_theta = 160000,启用 YaRN(original_seq_len > 0
  • 没压缩段的层(compress_ratio = 0,仅末层):用 rope_theta = 10000不启用 YaRNoriginal_seq_len = 0

这两个 RoPE 的差异不只是 base 不同,更关键的是:V4 给 KV 在压缩前后用不同的位置编码

为什么压缩段要用 base=160000?想象一下:滑窗段的 128 个 token 是连续的,每相邻 token 的位置差是 1。压缩段的每个组代表 ratio=4 个 token,相邻组的位置差是 4 个 token——但在压缩 KV 的”组索引”维度上,相邻组的索引差是 1。位置编码必须感知”压缩组之间的真实物理距离”

V4 的处理:用 compress_rope_theta = 160000 = 16 × 10000——把频率提高 16 倍,让相邻”组索引差 1”在 RoPE 角度上对应”位置差 16”。这样压缩段的位置编码与滑窗段在数学上保持一致。

MTP 层(compress_ratios 第 62 个元素,ratio=0)的特殊处理是:因为它不压缩,所有 KV 都按 token 索引存,不需要”组索引修正”——所以用标准 rope_theta=10000、不启用 YaRN(不外推,因为它的 KV 用的是真实位置)。注:主模型 61 层都是 ratio=128 或 ratio=4,没有 ratio=0 的层。

flowchart TD
  subgraph 主attention层["有压缩段的层 (ratio > 0)"]
    A1["KV 滑窗段<br/>位置 = 真实 token 位置"]
    A2["KV 压缩段<br/>位置 = 组索引 × ratio<br/>用 compress_rope_theta=160000"]
    A1 -->|不同位置编码| Cat[拼接 KV]
    A2 -->|不同位置编码| Cat
  end
  subgraph MTP层["MTP 层 (ratio = 0,主模型外)"]
    B1["所有 KV<br/>位置 = 真实 token 位置<br/>用 rope_theta=10000"]
  end

6.6 1M 上下文下的位置精度问题

YaRN 让 1M 外推数学上稳定,但工程上还有一个问题:位置编码的浮点精度

考虑位置 m=1,048,575 的最高频维度,旋转角度 θ=m1000020/64=1048575\theta = m \cdot 10000^{-2 \cdot 0 / 64} = 1048575(弧度)。这个数除以 2π2\pi 后取余数,是真正的”有效旋转角度”。

torch.outer(t, freqs) 计算 m * freq 在 float32 下,当 m 很大时,lsb 精度丢失——float32 对大数的精度只到 ~10^-7 量级。1M / 1e7 = 0.1——意味着旋转角度的精度只到 0.1 弧度!

V4 的 precompute_freqs_cis 全部用 dtype=torch.float32 计算(注意函数里 torch.arange(0, dim, 2, dtype=torch.float32)torch.arange(seqlen) 默认 int64),最后转复指数。这套精度对 1M 大致够用,但已经在 float32 的极限附近。

如果你想跑 V4 到 16M 甚至更长,可能要改成 float64 计算——V4 当前的实现没有为这种极端情况做特殊处理。

@lru_cache(2) 装饰器的另一个作用是:把 freqs_cis 的计算结果缓存,避免每次 forward 都重算 1M 个复指数。这个缓存对长上下文推理至关重要——freqs_cis 一次性算 1M × 32 个 complex = 256 MB(float32 + 复数),重算成本不低。


6.6·补 双 RoPE 的工程权衡:为什么不用单一 base

V4 的”主 attention base=10000 + 压缩 KV base=160000” 看起来复杂,能不能简化?理论上可以——但每种简化都有代价:

简化方案 A:全部用 base=10000(V3 风格)

如果压缩段也用 base=10000,那么相邻压缩组的位置编码差异是”组索引差 × 10000^(-2i/d)“——但实际上相邻组的真实物理距离是 ratio=4 个 token。压缩段会”压缩”位置感——模型把”间隔 4 token”误解成”间隔 1 token”,长距离信息组织失真。

简化方案 B:全部用 base=160000

如果主 attention 也用 base=160000,那么滑窗段相邻 token 的位置编码差异会变得”过细”——高频维度的旋转过快,模型在 short-range(128 token 内)的位置感知反而下降。

V4 的双 base 是必要妥协

  • 滑窗段用 base=10000 保证 short-range 精度(短上下文性能)
  • 压缩段用 base=160000 = 16×10000 让”组索引差 1” 在 RoPE 上对应”位置差 16”——保持物理距离的一致性

数学上这个 16 倍因子的来源:每个压缩组代表 ratio=4 个 token,但压缩组在 KV cache 里相邻存储,相邻组的 KV cache 索引差 1。如果用 base=10000 + 组索引,相当于把 ratio=4 的物理距离映射到位置差 1 ——失真。用 base=160000 让”组索引差 1 ≈ 位置差 16”,与 ratio=4 的乘以 4 倍相比仍偏,但在 16 倍数学上能更好覆盖到 1M 范围。

更深入的优化方向是为每层独立选 base(与 compress_ratio 联动),但 V4 选了全模型只用一对 base——简化工程实现。


6.6·补·补 freqs_cis 的内存账本

freqs_cis 是个被频繁忽略的内存大户。让我们算一下 V4 在 1M context 下的 freqs_cis 总占用:

主 attention 的 freqs_cis:

  • 形状:[max_seq_len=1048576, rope_head_dim/2=32]
  • dtype:torch.complex64(8 字节 per element)
  • 大小:1048576 × 32 × 8 = 256 MB

压缩 KV 的 freqs_cis:同样大小 256 MB。

每层 Attention 都注册了一份这两个 freqs_cis 的 buffer——但因为 register_buffer(persistent=False) 且 PyTorch 对相同 tensor 的 register_buffer 会共享存储,实际只占 2 × 256 MB = 512 MB 总量。

512 MB 听起来不多,但在 1M context + 多 GPU 场景下,每张卡都要保存这份 freqs_cis——8 卡部署累积 4 GB 显存只用于位置编码。这是 V4 长上下文部署中容易被忽略的内存项。

实践中,@lru_cache(2) 装饰器让 freqs_cis 在多层之间共享同一份计算结果——避免每层重算。这是 V4 源码里看似不起眼但极重要的优化。

如果你想把 V4 部署到更长上下文(如 2M),这部分内存会线性增长——4 GB 主 RoPE + 4 GB 压缩 RoPE = 8 GB / GPU。这时候可能需要把 freqs_cis 改成”按需生成”(只算本次 forward 用到的位置),而不是预生成全部 max_seq_len。


6.7 与其他外推方案的对比

YaRN 不是唯一的外推方案。把它与同期方案比一下:

方案思路外推上限V4 是否使用
Linear interpolation把所有频率统一除以 factor~4 倍
Position interpolation把位置 m 缩到 m / factor~4 倍
NTK-aware改 base:base × factor^(d/(d-2))~8 倍
YaRN分频率插值 + 注意力温度调节~16-32 倍
LongRoPE进化算法搜索每维度的最优插值因子~32-64 倍
Pose训练时随机抽长位置~16 倍

V4 选 YaRN 是因为:

  • 它在 16 倍外推的稳定性已被 V3 / V3.2-Exp 验证
  • 论文公开 + 工程实现简单
  • 不需要训练时做特殊处理(与 LongRoPE 对比)

YaRN 的代价是:在 32 倍以上外推时精度会下降——所以 V4 只支持到 1M,没有像 Gemini 那样宣称 10M 上下文。


6.8 章末:YaRN 在生产中的实际表现

把 YaRN + V4 在生产中的表现总结成一个对照表:

上下文长度是否 YaRN 训练区V4 表现工程建议
≤ 64K训练区内全精度,无外推损失完全可用,正常温度
64K-256K浅外推接近原生可用,温度可微调
256K-1M深外推有轻微外推损失可用但建议在 prompt 工程上重复关键信息
> 1M超出 max_position直接报错需要重训或换更大 max_position 的版本

256K-1M 这个区间是 V4 的”长上下文红利区”——比对手的 128K 长但又仍在 YaRN 安全外推内。这是 V4 在 RAG / 仓库代码理解 / 长文档分析等场景下的核心竞争力。


6.9 动手实验:可视化 YaRN 频率插值

import torch
import math
import matplotlib.pyplot as plt

def find_correction_dim(num_rotations, dim, base, max_seq_len):
    return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

def linear_ramp(low, high, dim):
    return torch.clamp((torch.arange(dim) - low) / (high - low), 0, 1)

dim = 64
base = 10000
factor = 16
original = 65536
beta_fast = 32
beta_slow = 1

orig_freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
low = max(int(find_correction_dim(beta_fast, dim, base, original)), 0)
high = min(int(find_correction_dim(beta_slow, dim, base, original)) + 1, dim - 1)
smooth = 1 - linear_ramp(low, high, dim // 2)
yarn_freqs = orig_freqs / factor * (1 - smooth) + orig_freqs * smooth

plt.semilogy(orig_freqs.numpy(), label='Original RoPE')
plt.semilogy(yarn_freqs.numpy(), label='YaRN (factor=16)')
plt.xlabel('Frequency index i')
plt.ylabel('Frequency value (log scale)')
plt.title('V4 main RoPE vs compressed KV RoPE')
plt.legend()
plt.show()

跑这段代码,能直观看到 YaRN 的”高频保持、低频缩放”行为。把 base 改成 160000 再跑一遍,能看到压缩 KV 的 RoPE 频率分布(整体频率高 16 倍)。


6.9·补 双 RoPE 与 vLLM 的接口冲突

V4 的双 RoPE(base=10000 + base=160000 共存)与 vLLM 的传统 attention backend 有一个微妙的接口冲突——这是部署时容易踩的坑。

vLLM 的 attention backend 假设”一个模型用一组 RoPE 参数”——它把 freqs_cis 算一次、缓存一次,所有层共享。V4 的两套 RoPE 让这个假设失效——必须让每一层 attention backend 知道”自己用哪一套 RoPE”。

具体怎么改 vLLM 适配?有两种路径:

路径 1:每层独立 freqs_cis

在 vLLM 的 ModelRunner 里给每一层 attention backend 注入它对应的 freqs_cis——压缩段层用 base=160000 的 freqs_cis,末层用 base=10000 的。这种改造要修改 PagedAttention backend 的 metadata 结构。

路径 2:把 RoPE 计算移到模型代码内

让 vLLM 的 attention backend 仅算”无 RoPE 的 attention”,RoPE 在模型 forward 中显式调用。这种做法接近 V4 的 inference/model.py——RoPE 是模型代码的一部分,attention backend 只负责 attention 本身。

V4 的 vLLM 适配 PR 大概率走路径 2——它与 inference/model.py 的代码结构对齐,迁移成本最低。

这个细节也展示了为什么 vLLM 适配 V4 不是”5 行 PR”——必须深入到 attention backend 的接口细节。具体的 vLLM PR 改动会在《vLLM 推理内核深度解析》第 14 章后续更新中详细展开。


6.10 延伸阅读


6.10·补 1M 上下文的真实应用场景

V4 的 1M 上下文不是营销噱头——它对应几类真实有商业价值的应用。把这些场景列出来:

场景 1:法律 / 合同分析

一份完整的并购合同 + 历史修订版本 + 相关法规可能 200K-500K token。V4 的 1M context 可以一次处理完整合同包,不需要拆段后再合并。

场景 2:仓库级代码理解

中型项目的源码 50K-300K token,加上文档 / issue 历史可能 500K-1M。V4 可以在一个 context 内”看懂整个项目”——回答如”这个 bug 涉及哪些模块”、“重构 X 类需要改哪些文件”等问题。

场景 3:长篇研究文献综述

100 篇相关论文每篇平均 5K token = 500K token。V4 可以一次性读完,做综述总结、引用关系分析、共识 / 分歧梳理。

场景 4:长视频字幕分析

一部 2 小时视频的字幕 + 元数据约 30K-50K token。V4 可以在 1M 内同时处理多部相关视频——做跨视频的主题分析。

场景 5:客服历史对话上下文

一个用户的客服对话累积 6 个月可能 100K-500K token。V4 可以把全部历史作为 context,不需要做摘要丢失细节——客服模型对这个用户的理解更深入。

场景 6:生物信息基因组分析

某些基因组分析任务需要在长序列上找模式——1M 上下文可以覆盖大型基因或多基因区域。

这些场景的共同点:短上下文模型必须做”摘要 / 检索”才能处理,但摘要 / 检索会损失细节。V4 的 1M 让”完整原始信息直接送给模型” 成为可能——质量上限被解锁。

理解这些场景让你判断”V4 是否真的适合你的产品”——如果你的应用永远是 short context,V4 的 1M 用不到,可以选更便宜的模型;如果你的应用本质是长上下文,V4 可能是当下唯一开源选择。


6.10·补·补 长上下文模型的”3 个评估指标”

部署 V4 用于长上下文场景时,需要专门的评估指标——不能只看常规 benchmark。把 3 个最重要的长上下文指标解释清楚:

指标 1:Needle-in-a-Haystack (针在草垛中)

把一段不相关的”针”(如某个特殊数字 / 句子)插入长 context 的不同位置,让模型答出针的内容。这个测试评估”模型是否真的看到了所有 context”——而非只看开头结尾。

V4 在 1M 下的 Needle 测试需要在多个位置(10%、30%、50%、70%、90%)测试——每个位置应该都能高精度回答。如果某些位置回答差,说明 Compressor / Indexer 在该位置的稀疏选取不准。

指标 2:RULER (Ruler benchmark)

更系统的长上下文测试——多任务(针、变量追踪、问答、聚合)在不同 context 长度上的成绩。Ruler 给出一个综合分数,可以与其他长 context 模型对比。

V4 在 RULER 1M 下的成绩是它”对得起 1M 标签” 的硬证据——具体数字等 V4 GA 后社区独立测评。

指标 3:长输出连贯性

不只是输入长——某些任务要求输出长。比如让 V4 写一篇 50 万字的小说、生成 100 万字的代码 / 文档。输出长度增加会累积稀疏 attention 的误差——评估输出连贯性是关键。

测试方法:让模型生成长输出后,逐段评估”是否与开头一致 / 角色名是否漂移 / 逻辑链是否断裂”。

这 3 个指标在 V4 部署时必须监控——任何一个退化都说明长上下文能力在劣化。常规 benchmark(MMLU / HumanEval)跑得再好,也不能保证长 context 表现。


6.10·延展 V4 在 RAG 场景下的”上下文优化”

V4 的 1M 上下文让 RAG 模式发生本质变化——这部分对部署 V4 的工程师极重要。

传统 RAG 的限制

传统 RAG 把召回的 chunks 拼成 < 32K context 喂给模型——召回数限制在 5-20 个。结果:召回不全 → 答案不完整。

V4 RAG 的新模式

V4 的 1M 让你召回 100-1000 个 chunks 不是问题——召回率几乎可以拉满。RAG 设计的重心从”精召回” 转向”全召回 + 智能排序”。

新的工程模式

  • 减少向量数据库的 top-k 限制(从 20 提到 200-500)
  • 把召回结果按相关性排序(最相关放在 prompt 前后)
  • 中间放低相关 chunks 作为”上下文支持”
  • 让模型自己在 1M 内做精细推理

对系统设计的影响

之前 RAG 系统需要”召回 → 重排 → 精筛 → 拼 prompt”四步,现在可以简化为”召回 → 拼 prompt”。重排的工作转移给 V4 的 Indexer——稀疏 attention 内部就在做”哪些 KV 重要”的选择。

这种”把工作从外部排序转到模型内部” 的简化让 RAG 系统的工程量减小一半——但前提是你愿意付 V4 的部署成本。对中小项目来说可能不值;对大项目(如企业知识库)来说极有 ROI。


6.11 本章小结

  • V4 训练只到 65K,推理到 1M——16 倍外推靠 YaRN 频率插值实现
  • YaRN 把高频维度保持原 freq、低频维度除以 factor、中间用 linear ramp——既支持外推、又不损失短距离精度
  • V4 在主 attention 与压缩 KV 上用两套 RoPE:主 attention 用 base=10000 + YaRN,压缩 KV 用 base=160000 + YaRN
  • MTP 层(仅 1 层)ratio=0 用 base=10000、不启用 YaRN——给 dense KV 一个”忠实位置”
  • 1M 在 float32 下接近精度极限——更长上下文需要 float64
  • YaRN 安全外推上限约 16-32 倍,V4 1M 完全在 YaRN 安全区内

第 7 章我们离开 attention 工程,进入 V4 的另一个核心子系统:MoE Gate——sqrtsoftplus、noaux_tc、384 专家路由的故事。

评论 0