第4章 Indexer:稀疏注意力的可学路由
“Choosing what to read is half of reading itself.” —— 引自一位历史学家
V4 的 Indexer 就是模型版本的”选择阅读”——它决定了 1M context 中哪 1024 个位置值得这一步 attention 真正去看。
4.1 引子:稀疏注意力的”选择题”
V4 的稀疏注意力面对一个工程”选择题”:
已经有了 1M 长度的压缩 KV 序列(通过 Compressor 压成
n_tok / ratio组),attention 该从中选哪 top-k 组真正参与计算?
这个”选哪些”的问题,就是 Indexer 要解决的核心。它有几个看似可行但都不对的方案:
方案 A:随机选——理论上可行,但稀疏选取必须能”针对当前 query 选最相关的 KV”。随机选会丢掉关键 token。
方案 B:用主 attention 的 q · k 直接打分——但主 attention 的 q 维度是 512,dot product 成本仍然太高(1M / ratio 个位置都要算一遍)。
方案 C:用一个小型独立的 attention 给 KV 打分——这就是 Indexer 的方案。它有自己的小维度(head_dim=128,比主 attention 的 512 小很多)、自己的 KV 投影、自己的 head 数(64,比主 attention 的 128 少一半),把”打分”成本降到主 attention 的一小部分。
flowchart LR Q["主 attention 的 query<br/>[B, S, 128 heads, 512]"] --> MainAttn["主 attention 计算"] XInput["输入 x"] --> Indexer Indexer["Indexer (轻量)<br/>[B, S, 64 heads, 128]"] --> Score["score 打分"] Score --> TopK["topk(1024) 选取"] TopK --> SparseIdx["topk_idxs"] SparseIdx --> SparseAttn["sparse_attn(q, kv, sink, idxs)"] MainAttn -.数据流入.-> SparseAttn
Indexer 是个”小而专”的 attention——专门服务”选 top-1024”这一件事,不参与最终输出。
4.2 Indexer 类的源码结构
V4 的 Indexer 类(inference/model.py)的 __init__:
class Indexer(torch.nn.Module):
def __init__(self, args: ModelArgs, compress_ratio: int = 4):
super().__init__()
self.dim = args.dim
self.n_heads = args.index_n_heads # 64
self.n_local_heads = args.index_n_heads // world_size
self.head_dim = args.index_head_dim # 128
self.rope_head_dim = args.rope_head_dim # 64
self.index_topk = args.index_topk # 1024
self.q_lora_rank = args.q_lora_rank # 1536 (与主 attention 共享)
self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.head_dim)
self.weights_proj = ColumnParallelLinear(self.dim, self.n_heads, dtype=torch.bfloat16)
self.softmax_scale = self.head_dim ** -0.5
self.compress_ratio = compress_ratio
self.compressor = Compressor(args, compress_ratio, self.head_dim, True)
self.register_buffer("kv_cache",
torch.zeros(args.max_batch_size, args.max_seq_len // compress_ratio, self.head_dim),
persistent=False)
self.freqs_cis = None
注意几个关键点:
1. q 的来源是主 attention 的 qr:
# Attention.forward 里
qr = q = self.q_norm(self.wq_a(x)) # qr: [B, S, q_lora_rank=1536]
q = self.wq_b(q).unflatten(-1, ...) # 主 attention 用的 q
...
compress_topk_idxs = self.indexer(x, qr, start_pos, offset) # qr 传给 Indexer
Indexer 的 wq_b 输入是 qr(来自主 attention 的 LoRA 中间表示),输出是 [B, S, 64, 128] 的 query。Indexer 与主 attention 共享 q_lora_rank=1536 的 wq_a 投影——这一层 LoRA 投影不重复算,节省 FLOPs。
2. Indexer 有自己的 Compressor(rotate=True):
Indexer 的压缩 KV 是从 x 直接经过自己的 Compressor 算出来的——这个 Compressor 与主 attention 的 Compressor 是不同的实例,参数独立训练,量化策略也不同(FP4 + Hadamard 旋转)。
3. weights_proj:每 head 一个标量权重:
self.weights_proj = ColumnParallelLinear(self.dim, self.n_heads, dtype=torch.bfloat16)
输出形状 [B, S, n_heads=64]——每个 token 的每个 head 一个 BF16 权重。这些权重在 forward 中作用是”head-level 的注意力 mask 缩放”——某些 head 在某个 token 上”更可信”,权重就大。
4. softmax_scale 用 head_dim=128 而非主 attention 的 512:
softmax_scale = 128 ** -0.5 ≈ 0.088——这是 dot product 标准缩放系数。Indexer 的 head_dim 比主 attention 小,意味着 q · k 的方差更小、softmax 更陡峭、topk 选取更”果断”。
4.3 forward 流程与 score 计算
Indexer 的 forward 是 V4 稀疏注意力的核心算法:
def forward(self, x: torch.Tensor, qr: torch.Tensor, start_pos: int, offset: int):
bsz, seqlen, _ = x.size()
freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
ratio = self.compress_ratio
rd = self.rope_head_dim
end_pos = start_pos + seqlen
# 第一次调用时建立 Indexer 自己的 Compressor 状态
if self.compressor.kv_cache is None:
self.compressor.kv_cache = self.kv_cache
self.compressor.freqs_cis = self.freqs_cis
# q: 用主 attention 的 qr 经过 Indexer 自己的 wq_b
q = self.wq_b(qr)
q = q.unflatten(-1, (self.n_local_heads, self.head_dim)) # [B, S, n_heads, head_dim]
apply_rotary_emb(q[..., -rd:], freqs_cis)
q = rotate_activation(q) # Hadamard 旋转
fp4_act_quant(q, fp4_block_size, True) # FP4 量化
# KV: Indexer 的 Compressor 把 x 压成 KV
self.compressor(x, start_pos)
# weights: 每 head 一个权重 + 缩放
weights = self.weights_proj(x) * (self.softmax_scale * self.n_heads ** -0.5)
# 打分:q 与所有压缩 KV 做 dot product,再用 weights 加权 sum 跨 head
index_score = torch.einsum("bshd,btd->bsht", q, self.kv_cache[:bsz, :end_pos // ratio])
index_score = (index_score.relu_() * weights.unsqueeze(-1)).sum(dim=2)
if world_size > 1:
dist.all_reduce(index_score)
# mask 当前 token 之后的位置
if start_pos == 0:
mask = torch.arange(seqlen // ratio).repeat(seqlen, 1) >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
index_score += torch.where(mask, float("-inf"), 0)
# topk 选取
topk_idxs = index_score.topk(min(self.index_topk, end_pos // ratio), dim=-1)[1]
# 加 offset 把 topk_idxs 从"压缩 KV 索引"变成"绝对 KV cache 索引"
if start_pos == 0:
mask = topk_idxs >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
topk_idxs = torch.where(mask, -1, topk_idxs + offset)
else:
topk_idxs += offset
return topk_idxs
把这段代码拆成 6 个步骤:
步骤 1:q 投影 + 旋转
q = self.wq_b(qr)
q = q.unflatten(-1, (self.n_local_heads, self.head_dim))
apply_rotary_emb(q[..., -rd:], freqs_cis)
q = rotate_activation(q)
fp4_act_quant(q, fp4_block_size, True)
qr 来自主 attention 的 q_norm 输出 [B, S, 1536]。wq_b 是 Indexer 独立的 q 投影矩阵——把 qr 投到 n_heads × head_dim = 64 × 128 = 8192 维。然后 RoPE 处理最后 64 维(rope_head_dim)。再过 Hadamard 旋转 + FP4 量化。
为什么 q 也走 FP4?因为 score 计算只是个”打分排序”,不参与最终输出——精度损失对 topk 选取的影响可以接受。q 走 FP4 让 Indexer 的 GEMM 成本降到主 attention 的 1/16(FP4 是 0.5 字节,BF16 是 2 字节)。
步骤 2:KV 压缩
self.compressor(x, start_pos)
Indexer 的 Compressor 把当前 token 段 x 压成压缩 KV,存到 self.kv_cache 里。这个 Compressor 与主 attention 的 Compressor 结构相同但参数独立——参数独立意味着 Indexer 学到的”应该压什么”与主 attention 学到的”应该存什么”是不一样的。
步骤 3:每 head 权重投影
weights = self.weights_proj(x) * (self.softmax_scale * self.n_heads ** -0.5)
weights_proj 输出 [B, S, 64]——每 head 一个标量。乘以缩放系数 softmax_scale × n_heads^(-0.5)——这两个常数是 V4 团队从经验调出来的”最佳缩放”。
步骤 4:打分
index_score = torch.einsum("bshd,btd->bsht", q, self.kv_cache[:bsz, :end_pos // ratio])
index_score = (index_score.relu_() * weights.unsqueeze(-1)).sum(dim=2)
第一行:q 形状 [B, S, 64 heads, 128],kv_cache 形状 [B, end_pos/ratio, 128]。einsum 把每个 head 的 q 与所有压缩 KV 做内积,得到 [B, S, 64 heads, end_pos/ratio]。
第二行:relu_() 砍掉负相关分(in-place);乘以 head-level 权重;最后跨 head sum——得到每个压缩 KV 位置的”总分” [B, S, end_pos/ratio]。
注意 relu_ 是 in-place 操作——节省显存。**为什么用 ReLU 而非 softmax?**因为 ReLU 让”完全不相关”的 KV 直接得分 0,topk 选取更稳定;softmax 会强行把概率分到所有位置,topk 选取的边界容易模糊。
步骤 5:mask + topk
mask = ...
index_score += torch.where(mask, float("-inf"), 0)
topk_idxs = index_score.topk(min(self.index_topk, end_pos // ratio), dim=-1)[1]
mask 把”当前 token 之后” 的位置打成 -inf——causal mask 的稀疏版本。topk 选 1024 个分最高的位置(如果总数不足 1024,全选)。
步骤 6:offset 修正
if start_pos == 0:
mask = topk_idxs >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
topk_idxs = torch.where(mask, -1, topk_idxs + offset)
else:
topk_idxs += offset
topk_idxs 出来时是”压缩 KV 数组的索引”(0 到 end_pos / ratio)。需要加 offset 转换成”主 KV cache 的绝对索引”——offset 是滑窗段占据的位置数。-1 表示”无效位置”。
最终返回的 topk_idxs 形状 [B, S, 1024]——对每个 query token,返回它应该关注的 1024 个 KV 位置的绝对索引。这个数组随后被 cat 到滑窗 idxs,传给 sparse_attn kernel。
4.3·补 ReLU vs Softmax:score net 的代数选择
V4 的 Indexer 在打分阶段用 relu_() 而不是 softmax。这是个看起来违反直觉的选择——传统 attention 用 softmax 是因为它是”概率分布”,符合 attention weights 的语义。但 score net 不是 attention,它的输出是”排序”,用 ReLU 反而更合适。
Softmax 的特性:
- 输出是 [0, 1] 区间内的概率分布,所有位置和为 1
- 数值上”温和”——再差的位置也会分到一点概率
- topk 选取依赖”概率最大的几个”,但概率之间的差异可能很小
- 容易受温度(temperature)影响——同样的 logit 在不同温度下选出来的 topk 可能完全不同
ReLU 的特性:
- 输出是 [0, ∞) 区间内的非归一化值,负值直接砍成 0
- 数值上”果断”——完全不相关的位置得分 0,被自然过滤
- topk 选取直接基于”分数最大的几个”,分数差异越大、选取越稳定
- 没有温度参数,行为可预测
V4 选 ReLU 的具体好处:
- 稀疏选取的稳定性:ReLU 让分数”二值化倾向”更强——要么得分不为 0、要么得分为 0。topk 选取在这种分布上更不容易被噪声扰动。
- 训练梯度的清晰性:当一个 KV 位置完全不相关(softmax 仍会给 1e-10 量级概率),它在反向传播里仍会有微弱梯度——这些梯度累积起来会污染 score 网络。ReLU 直接砍掉负值,反向传播 mask 掉这部分干扰。
- 数值稳定:softmax 在 1M / ratio 个候选位置上会要算 exp(x)——长上下文下 logit 的动态范围大,softmax 容易溢出或下溢。ReLU 没有这个问题。
但 ReLU 也有代价:
- 输出不是概率分布——不能直接当作”注意力权重”用。这就是为什么 V4 把 Indexer 的输出仅用于 topk 选取,不直接参与最终 attention 输出。最终输出由主 attention 用 softmax 算。
- ReLU 本身不带”head 间归一化”——V4 通过
weights_proj给每 head 一个独立权重,再跨 head sum,间接实现 head 间的相对重要性归一。
这种”score net 用 ReLU、主 attention 用 softmax”的分工,是 V4 团队从 V3.2-Exp 的实战中得到的经验沉淀。
4.4 Indexer 与主 attention 的资源对比
把 Indexer 与主 attention 的资源占用摆出来:
| 维度 | 主 attention | Indexer | 比率 |
|---|---|---|---|
| n_heads | 128 | 64 | 1/2 |
| head_dim | 512 | 128 | 1/4 |
| q 数据类型 | BF16/FP8 | FP4 | 1/4 |
| k 数据类型 | FP8 (压缩 KV) | FP4 | 1/2 |
| 单 token GEMM | 128×512 = 65536 | 64×128 = 8192 | 1/8 |
| 与 KV 内积成本 | head_dim=512 | head_dim=128 | 1/4 |
| 总相对成本 | 1 | ≈ 1/30 | 1/30 |
Indexer 大约是主 attention 的 1/30 计算量——用 3% 的额外计算,把主 attention 的 KV 候选从 1M / ratio 压到 1024。这是 V4 稀疏注意力的核心 ROI 公式。
4.4·补 Indexer 输出维度的张量代数
把 Indexer 的张量形状变化按维度逐步追踪一遍,能彻底理解它的资源占用为什么是主 attention 的 1/30:
输入:
x: [B, S, dim=7168]
qr: [B, S, q_lora_rank=1536] ← 来自主 attention 的 q_norm 输出
start_pos, offset
────────────────────
步骤 1:q 投影
────────────────────
q = wq_b(qr) # [B, S, n_heads × head_dim] = [B, S, 64 × 128] = [B, S, 8192]
# 计算量:1536 × 8192 = 12.6M FLOPs / token / layer
q = q.unflatten(-1, (64, 128)) # [B, S, 64, 128]
apply_rotary_emb(q[..., -64:], freqs_cis) # 仅最后 64 维
q = rotate_activation(q) # Hadamard 不改变形状
fp4_act_quant(q, 32, True) # FP4 量化(每 32 个元素一个 scale)
────────────────────
步骤 2:KV 压缩
────────────────────
self.compressor(x, start_pos)
# 内部:wkv(x): [B, S, 2*128=256] (overlap=True 翻倍维度)
# softmax(score) 加权求和后输出 [B, S/4, 128]
# 存到 self.kv_cache[:, :end_pos/4]
────────────────────
步骤 3:weights 投影
────────────────────
weights = self.weights_proj(x) # [B, S, 64]
# 计算量:7168 × 64 = 459K FLOPs / token / layer
weights *= self.softmax_scale * self.n_heads ** -0.5
# softmax_scale = 128**-0.5 ≈ 0.0884
# n_heads**-0.5 = 64**-0.5 = 0.125
# 总缩放系数 ≈ 0.011
────────────────────
步骤 4:score 计算
────────────────────
index_score = einsum("bshd,btd->bsht", q, kv_cache[:bsz, :end_pos/4])
# q: [B, S, 64, 128], kv_cache: [B, end_pos/4, 128]
# 输出: [B, S, 64, end_pos/4]
# 计算量:S × 64 × (end_pos/4) × 128 FLOPs / token / layer
# 在 1M context 下 ≈ S × 64 × 262144 × 128 ≈ 2.1G FLOPs / S tokens
index_score = (index_score.relu_() * weights.unsqueeze(-1)).sum(dim=2)
# 输出: [B, S, end_pos/4]
────────────────────
步骤 5:mask + topk
────────────────────
index_score += causal_mask # 维度同上
topk_idxs = index_score.topk(min(1024, end_pos/4), dim=-1)[1]
# 输出: [B, S, 1024]
────────────────────
步骤 6:offset 修正
────────────────────
topk_idxs += offset
# 输出: [B, S, 1024],整数索引
总计算量约 2.1 G FLOPs / S tokens (1M context)。对照主 attention 在同样配置下:
主 attention 的 q · k 内积:
q: [B, S, 128, 512], k 取自 KV cache (滑窗 + 压缩):
对于 ratio=4 层:query 看 (128 + 1024) ≈ 1152 个位置
计算量:S × 128 × 1152 × 512 ≈ 9.4 G FLOPs / S tokens (BF16/FP8)
主 attention 9.4 G + Indexer 2.1 G ≈ 11.5 G FLOPs / S tokens,相比”完全 dense” 的主 attention(要看完所有 1M / 4 = 262144 个位置,约 2150 G FLOPs)减少了约 187 倍。
Indexer 用 2.1 G 的额外计算,避免了 2150 G 的 dense 计算——这是 V4 稀疏路径的根本经济账。
4.5 Hadamard 旋转:FP4 量化前的”信息打散”
Indexer 的 q 在量化前过了一次 rotate_activation(q)。这个函数实际是 Hadamard 变换:
def rotate_activation(x: torch.Tensor) -> torch.Tensor:
"""Applies randomized Hadamard rotation to spread information across dims before FP8 quant."""
assert x.dtype == torch.bfloat16
from fast_hadamard_transform import hadamard_transform
return hadamard_transform(x, scale=x.size(-1) ** -0.5)
Hadamard 变换是一种正交变换——它不改变向量长度,但把每个维度的信息”打散”到所有维度。具体地,Hadamard 矩阵 H 满足 H × H^T = n × I,对一个向量 v 做 H × v 后,原本集中在某些维度的能量被均匀分布到所有维度。
为什么量化前要打散?因为 FP4 e2m1 的精度损失对”维度间能量集中”特别敏感——如果某些维度数值远大于其他维度,FP4 量化会把小数值维度”舍入到 0”,丢失信息。Hadamard 打散后所有维度数值大致同量级,FP4 量化的精度损失被均摊。
这种”先正交旋转再量化”的技巧不是 V4 首创——QuaRot 论文(arXiv:2404.00456)系统提出了这个方法。但 V4 把它工业化用到了 1.6T 模型的 attention 路径上——在 score net 这种”对精度可以放宽”的位置。
第 12 章会展开 V4 的 Hadamard + FP4 全链路。
4.5·补 Hadamard 旋转的几何直觉
Hadamard 旋转看起来神秘,但它的几何直觉非常简单——把一个向量的”信息”从某些集中维度均匀打散到所有维度。
考虑一个简化例子:四维向量 v = [10, 0.1, 0.1, 0.1]。
直接 FP4 量化:
- FP4 e2m1 的动态范围约 [0.5, 6](mantissa 1 bit + exponent 2 bit)
- 元素 10 量化后变成 6(饱和);元素 0.1 量化后变成 0(下溢)
- 量化后 v ≈ [6, 0, 0, 0]——丢失 75% 的信息
先 Hadamard 旋转再 FP4 量化:
- Hadamard 矩阵 H_4 = [[1,1,1,1], [1,-1,1,-1], [1,1,-1,-1], [1,-1,-1,1]] / 2
- H_4 × v ≈ [5.15, 4.95, 4.95, 5.05](每个维度都接近 5)
- 量化后约 [5, 5, 5, 5]
- 反 Hadamard 后 ≈ [10, 0, 0, 0]——主信息保留,仅丢失了零附近的小数
Hadamard 的关键性质:输入越”集中”(高峰值),输出越”均匀”;输入越”均匀”,输出仍然均匀。这种”消峰填谷”的效果对低精度量化是天然友好的——量化误差被均匀分布到所有维度,反向变换后误差也被均匀化。
flowchart LR
subgraph 量化前["FP4 量化挑战"]
direction TB
Concentrated["v = [10, 0.1, 0.1, 0.1]<br/>能量集中于 1 维"]
end
subgraph Hadamard["Hadamard 旋转"]
Spread["H × v = [5.15, 4.95, 4.95, 5.05]<br/>能量均匀分布"]
end
subgraph 量化后["FP4 量化"]
Quantized["[5, 5, 5, 5]<br/>误差在所有维度均摊"]
end
subgraph 反向["反 Hadamard"]
Recovered["≈ [10, 0, 0, 0]<br/>主信息保留"]
end
Concentrated --> Spread --> Quantized --> Recovered
V4 的 rotate_activation 函数用 fast_hadamard_transform 库实现,scale 设为 x.size(-1) ** -0.5——这个 scale 让 Hadamard 矩阵正交(H × H^T = I),保证旋转不改变向量长度。
为什么只有 Indexer 用 Hadamard 而主 attention 不用?因为主 attention 的精度要求更高,FP8 量化的精度损失已经够小,不需要 Hadamard 的”额外保险”。Indexer 走 FP4,量化损失更大,必须用 Hadamard 兜底。
4.6 Indexer 与 attn_sink 的协同
第 2 章讲过 attn_sink 是稀疏注意力的”兜底参数”——保证 Indexer 选错时数值不崩。Indexer 与 attn_sink 形成一个明确的协同分工:
- Indexer 努力选对:用学到的 score 函数尽可能选出最相关的 1024 个 KV
- attn_sink 兜住选错:当所有 1024 个 KV 都不相关时,sink 接住注意力质量
这个分工的工程意义:Indexer 不需要选得”完美”,只需要选得”够好”。如果它偶尔选错,attn_sink 会救场。这降低了 Indexer 的训练目标,让它可以在小算力(FP4 + 1/30 主 attention 成本)下达到生产可用的精度。
如果没有 attn_sink,Indexer 必须保证”永远不选错”——这会要求 Indexer 用与主 attention 同等量级的算力,工程上不可行。
attn_sink 的存在让 V4 的稀疏注意力从”理论可行”变成”工程可行”。
4.6·补 Indexer 在训练时的辅助监督
V4 的 Indexer 在推理时只输出 topk_idxs,不参与最终输出。但在训练时,Indexer 必须有自己的”学习目标”——否则它的参数无从更新。
虽然 V4 的训练源码不在公开仓库内,但从架构设计可以反推 Indexer 训练的几个标准做法(V3.2-Exp 论文已经描述过类似机制):
做法一:用主 attention 的 attention weights 作为监督信号
主 attention 在训练时会算完整的 dense softmax(Q · K^T)——这给每个 query token 一组”理想”的 KV 重要性分布。Indexer 的 score 应该与这组分布正相关——不需要完全相等,但 top-k 重合度要高。这可以用 KL 散度或 listwise ranking loss 作为监督。
做法二:稀疏 attention 输出与 dense attention 输出的距离
更直接的做法是:用 Indexer 的 topk 选取做一次稀疏 attention,再用主 attention 做一次 dense attention(仅训练时),对比两者的输出差异——这个差异就是 Indexer 的损失。
做法三:辅助 next-token prediction
让 Indexer 在主 lm_head 之外有一个”小 head”,直接预测下一个 token。如果稀疏选取错了关键 KV,next-token loss 会上升——这是另一种间接监督。
V4 实际用了哪种、或几种的组合,需要等技术报告或后续社区分析公开。但架构上保留了所有这些训练接口——Indexer 类的输入 qr 是 q_norm 之后、q_b 之前的中间表示,意味着 Indexer 可以在训练时被独立 supervise。
第 18 章会基于公开技术报告展开 V4 的训练 pipeline,到时候会回来填补 Indexer 训练监督的具体细节。
4.7 一段动手实验:Indexer 的最小可运行版本
import torch
import torch.nn as nn
import torch.nn.functional as F
class MiniIndexer(nn.Module):
"""简化版 Indexer:去掉 Hadamard / FP4 / 多卡,保留核心逻辑"""
def __init__(self, dim=512, n_heads=8, head_dim=64, ratio=4, topk=16):
super().__init__()
self.n_heads = n_heads
self.head_dim = head_dim
self.ratio = ratio
self.topk = topk
self.wq = nn.Linear(dim, n_heads * head_dim)
self.weights_proj = nn.Linear(dim, n_heads)
# 简化:这里用一个 Linear 模拟 Compressor 的输出(实际上应该是 Compressor)
self.kv_proj = nn.Linear(dim, head_dim)
self.scale = head_dim ** -0.5
def forward(self, x):
B, S, _ = x.shape
# q
q = self.wq(x).view(B, S, self.n_heads, self.head_dim)
# 模拟压缩 KV(每 ratio 个 token 取平均)
kv = x.view(B, S // self.ratio, self.ratio, -1).mean(dim=2)
kv = self.kv_proj(kv) # [B, S/ratio, head_dim]
# weights
weights = self.weights_proj(x) * self.scale * self.n_heads ** -0.5 # [B, S, n_heads]
# score
score = torch.einsum("bshd,btd->bsht", q, kv)
score = (score.relu_() * weights.unsqueeze(-1)).sum(dim=2) # [B, S, S/ratio]
# mask
mask = torch.arange(S // self.ratio).repeat(S, 1) >= torch.arange(1, S+1).unsqueeze(1) // self.ratio
score += torch.where(mask, float("-inf"), 0)
# topk
topk_idxs = score.topk(min(self.topk, S // self.ratio), dim=-1)[1]
return topk_idxs
# 测试
mi = MiniIndexer()
x = torch.randn(2, 32, 512)
idxs = mi(x)
print(idxs.shape) # 应该是 [2, 32, 8] (topk 在不足时被 clip)
跑通这个 mini 版本后再看 V4 源码,会觉得”原来 V4 的 Indexer 不过是把 mini 版本加上了 Compressor、Hadamard、FP4 三层工业化壳”。核心算法没那么神秘。
4.8 延伸阅读
- DeepSeek-V3.2-Exp:DSA score net 的最早期实现
- QuaRot 论文(arXiv:2404.00456):Hadamard 旋转 + INT4 量化的源头
- Native Sparse Attention 论文(arXiv:2502.11089):稀疏 attention 训练理论基础
- 本书第 5 章:sparse_attn kernel 把 Indexer 输出的 topk_idxs 真正”用起来”——FlashMLA 的 V4 路径
- 本书第 3 章:Indexer 的 Compressor 与 Attention 的 Compressor 是同源结构——回看 §3.6
4.8·补 Indexer 的工程实现陷阱清单
读 V4 的 Indexer 源码时,至少有四个细节如果没注意会在自己的实现里翻车:
陷阱一:weights_proj 的 dtype
self.weights_proj = ColumnParallelLinear(self.dim, self.n_heads, dtype=torch.bfloat16)
注意明确指定 dtype=torch.bfloat16——而 Indexer 的其他 Linear 默认走 FP4(因为 default_dtype 在 Transformer.init 里被设为 FP4)。weights_proj 必须 BF16 是因为它输出的是 head-level 的标量权重,FP4 量化会让这些权重精度不足。
陷阱二:rotate_activation 的 dtype 断言
def rotate_activation(x: torch.Tensor) -> torch.Tensor:
assert x.dtype == torch.bfloat16
...
Hadamard 变换强制要求输入是 BF16。如果你给它 FP8 或 FP4,会直接 assert 失败。Indexer 的 q 在 wq_b 之后是 BF16(FP4 GEMM 的输出默认 BF16),可以直接传给 rotate_activation;如果你的实现里 q 已经被提前量化到 FP4,需要先反量化回 BF16 再做 Hadamard。
陷阱三:causal mask 的边界
mask = torch.arange(seqlen // ratio).repeat(seqlen, 1) >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
注意是 >= 不是 >,且右边是 (arange + 1) // ratio 而不是 arange // ratio。这两个细节确保了”压缩组的右边界”被正确处理——一个压缩组 g 包含 token [g*ratio, (g+1)*ratio),token i 只能看到压缩组 g 当 i >= (g+1)*ratio,即 g <= i // ratio - 1,即 g < (i+1) // ratio。源码的 >= 把”不可见”的位置打上 -inf。
陷阱四:offset 在 prefill / decode 上的不同语义
# Attention.forward 里
offset = kv.size(1) if start_pos == 0 else win
prefill 时 offset = kv.size(1)(所有滑窗 KV 已经填到 cache 前 seqlen 槽位);decode 时 offset = win(滑窗已经稳定占据前 win 槽位)。如果你写 reuse 一个常量 offset,两条 codepath 会出现”压缩 KV 索引与 cache 实际位置错位” 的 bug——稀疏 attention 读到的就是错位的 KV。
这四个陷阱都不会让源码报错,但会让模型输出”看起来对、实际错”——是稀疏 attention 实现里最难调试的一类问题。
4.8·补·补 Indexer 的”输入信号” 完整性分析
Indexer 的 forward 输入是 (x, qr, start_pos, offset)——四个参数提供完整的”上下文信号”。把每个信号的作用讲清楚。
信号 1:x(hidden state)
x 是当前 token 的 hidden state。Indexer 用 x 算两件事:
- 算 weights_proj 输出(每 head 权重)
- 通过自己的 Compressor 把 x 压成新的压缩 KV
x 提供”这个 token 是什么内容”的信息——决定它对哪些历史 KV 感兴趣。
信号 2:qr(query 中间表示)
qr 是主 attention 的 q_norm 输出 [B, S, q_lora_rank=1536]。Indexer 用 qr 经过自己的 wq_b 算 query。
qr 提供”主 attention 视角下的 query 信号”——让 Indexer 的 query 与主 attention 对齐。两者共享 q_lora 让 Indexer 不需要重学这一层投影。
信号 3:start_pos
当前 token 在序列中的位置。Indexer 用这个:
- 决定 freqs_cis 的切片(RoPE 旋转角度)
- 决定 causal mask(不能 attend 到未来)
- 决定 prefill / decode codepath
信号 4:offset
把 topk_idxs 从”压缩 KV 索引”换算到”主 KV cache 绝对索引”的偏移量。prefill 与 decode 的 offset 含义不同(详见第 §4.8 陷阱清单)。
这 4 个信号合起来让 Indexer 有”决定 top-k KV 选择” 所需的全部输入。少任何一个都会让 Indexer 工作不正确。
和主 attention 的”信号共享”:
Indexer 与主 attention 共享:
- q_lora_rank 投影(qr)
- 共享 freqs_cis(RoPE)的源
- 共享 input x
但不共享:
- wq_b(独立的 query 投影)
- Compressor 实例(独立的 KV 压缩)
- 量化策略(FP4 vs FP8)
这种”共享 + 独立”的设计让 Indexer 又快又准——快是因为复用了主 attention 的中间表示,准是因为有独立的精修空间。
4.8·延展 Indexer 的运行时性能特征
Indexer 在生产中的运行时性能对最终延迟影响显著。把它的几个特征列出来:
特征 1:每层只有 ratio=4 的层有 Indexer
V4 Pro 61 层中约 30 层是 ratio=4——意味着 Indexer 只在这 30 层运行。其他层(ratio=128 或 0)跳过 Indexer,省时间。
特征 2:Indexer 与主 attention 部分并行
Indexer 的 wq_b(q 投影)可以与主 attention 的 q 投影并行——两者输入都是 qr。生产实现里通常把这两个 GEMM 合并成一次 launch。
特征 3:Indexer 的输出在主 attention 之前完成
Indexer 必须先完成(输出 topk_idxs),主 attention 才能开始 sparse_attn。这是串行依赖——不能并行。这是 Indexer 在 critical path 上的体现。
特征 4:Indexer 的延迟约主 attention 的 5-10%
Indexer 是 1/30 的计算量(详见 §4.4),但因为它在 critical path 上,实际延迟占比比 1/30 高——约 5-10%。这是 V4 团队为了换取稀疏性付的延迟代价。
特征 5:Indexer 不并发跨层
Indexer 必须在本层 sparse_attn 之前完成,但不需要等其他层。理论上不同层的 Indexer 可以”流水线”——但实际生产中没必要,因为每层的延迟差异已经很小。
理解这些性能特征让你在调优 V4 时知道”Indexer 是否是瓶颈”——大多数情况下不是。如果是,意味着 GPU 极不健康,应该排查硬件而非软件。
4.9 本章小结
- Indexer 是 V4 稀疏注意力的”score net”——专门给压缩 KV 打分并选 top-1024
- 它与主 attention 共享 q_lora_rank 投影,但有独立的 wq_b、weights_proj、Compressor
- 它的 head_dim=128(而非主 attention 的 512)、n_heads=64(而非 128)、q/k 走 FP4——总计算量约主 attention 的 1/30
- ReLU + 加权 sum 替代 softmax + sum 是 V4 的工程取舍——topk 边界更稳定
- Hadamard 旋转保证 FP4 量化的精度损失被均摊到所有维度
- attn_sink 与 Indexer 形成”努力选对 + 兜住选错”的工程冗余
第 5 章我们进入 V4 注意力革命的最后一站:sparse_attn kernel——FlashMLA 在 V4 路径下到底是怎么实现这个稀疏 attention 的。
评论 0
还没有评论,来说两句吧。
评论加载失败,刷新重试。