第5章 sparse_attn 与 FlashMLA:V4 路径下的 CUDA 内核

“Architecture without efficient kernels is poetry without rhyme.” —— 引自 NVIDIA 一位资深 CUDA 工程师

V4 的全部 attention 革命(MLA + Compressor + Indexer),最后必须落到一个 CUDA kernel 上才能产生工程价值——这个 kernel 就是 FlashMLA。


5.1 引子:从一行 PyTorch 到 GPU 上的真实计算

V4 的 Attention.forward 里 attention 的”实际计算”被压缩成一行:

o = sparse_attn(q, kv, self.attn_sink, topk_idxs, self.softmax_scale)

参数:

  • q:query,形状 [B, S, n_heads=128, head_dim=512],dtype 取决于 prefill / decode
  • kv:KV cache(滑窗 + 压缩段拼起来的完整 KV),形状 [B, kv_cache_size, head_dim=512]
  • attn_sink:每 head 一个 float32 标量,形状 [n_heads]
  • topk_idxs:稀疏选取的 KV 位置索引,形状 [B, S, window_size + index_topk] = [B, S, 1152]
  • softmax_scale:缩放系数 head_dim ** -0.5

sparse_attn 是一个从 kernel 模块导入的函数:

from kernel import act_quant, fp4_act_quant, fp8_gemm, fp4_gemm, sparse_attn, hc_split_sinkhorn

kernel 模块本身是一个C++/CUDA 扩展——sparse_attn 的 PyTorch 入口只是个绑定,真正的计算发生在 GPU kernel 上。这个 kernel 被托管在 FlashMLA 仓库里,与 V4 同一周期开源。

flowchart TB
  subgraph Python层
    P1["Attention.forward 中:<br/>sparse_attn(q, kv, sink, idxs, scale)"]
    P2["kernel.py 中的 sparse_attn binding"]
  end
  subgraph C++层
    C1["torch::Tensor sparse_attn(...)"]
    C2["dispatch by GPU arch (SM90/SM100)"]
  end
  subgraph CUDA层
    K1["sparse_attn_kernel_sm100<<<...>>>"]
    K2["针对 H100 / B200 的 ldsm + WGMMA 实现"]
  end
  P1 --> P2 --> C1 --> C2 --> K1 --> K2

本章拆这条调用链——从 Python 的 sparse_attn 到 CUDA kernel 的全部工程缝合。


5.2 sparse_attn 的接口语义

sparse_attn(q, kv, sink, idxs, scale) 的语义是:

对每个 query token 的每个 head,仅计算 q · kv[idxs] 的内积,做 softmax + 加权求和,然后跨 head 输出。

伪代码形式:

def sparse_attn(q, kv, attn_sink, topk_idxs, softmax_scale):
    # q: [B, S, H, D]
    # kv: [B, T, D]   (T = window_size + max_seq_len // ratio)
    # attn_sink: [H]
    # topk_idxs: [B, S, K]   (K = window_size + index_topk = 1152)
    # softmax_scale: scalar
    # 返回 o: [B, S, H, D]

    B, S, H, D = q.shape
    K = topk_idxs.shape[-1]
    o = torch.zeros_like(q)

    for b in range(B):
        for s in range(S):
            for h in range(H):
                # gather K 个 KV 位置
                gathered_kv = kv[b, topk_idxs[b, s]]   # [K, D]
                # mask 掉 -1 位置
                valid = topk_idxs[b, s] >= 0
                # logits = q · k / sqrt(D)
                logits = (q[b, s, h] @ gathered_kv.T) * softmax_scale
                logits = torch.where(valid, logits, float("-inf"))
                # 拼接 attn_sink
                logits_with_sink = torch.cat([logits, attn_sink[h:h+1]])
                weights = F.softmax(logits_with_sink, dim=0)
                # 加权求和(不包括 sink)
                o[b, s, h] = (weights[:K, None] * gathered_kv).sum(dim=0)

    return o

这只是语义模型——真实 CUDA 实现要快几个数量级。但理解这个伪代码后,CUDA 实现就是”如何把这段循环并行化、向量化、利用 TensorCore”的工程问题。


5.3 FlashMLA 仓库的代码组织

FlashMLA 仓库(github.com/deepseek-ai/FlashMLA)的代码组织:

FlashMLA/
├── csrc/                           # C++ / CUDA 源码
│   ├── flash_mla/
│   │   ├── sparse_attn_sm100.cu    # SM100 (B200) 路径
│   │   ├── sparse_attn_sm90.cu     # SM90 (H100/H800) 路径
│   │   ├── dense_attn_sm90.cu      # 兼容 V3.2 的 dense MLA
│   │   └── ...
│   ├── flash_mla_extension.cpp     # PyTorch binding
│   └── ...
├── flash_mla/                      # Python 包装
│   ├── __init__.py
│   └── kernel.py                   # 暴露 sparse_attn / fp8_gemm 等接口
├── tests/                          # 测试
└── benchmark/                      # 性能基准

V4 的 inference/model.py 直接 from kernel import sparse_attn——这个 kernel 模块就是 FlashMLA 编译安装后暴露的 Python 包。

FlashMLA 的关键设计:

  • 每个 GPU 架构一个独立的 .cu 文件:SM90 (H100/H800) 与 SM100 (B200) 的 kernel 实现差异巨大——TMA、WGMMA、共享内存大小都不同,必须分开优化
  • dense 与 sparse 路径并存:dense 路径服务 V3 / V3.2-Exp(dense MLA),sparse 路径服务 V4。一个仓库支持两代模型
  • PyTorch binding 极薄:C++ 入口只做参数检查 + dispatch by arch + launch kernel,不含业务逻辑

5.4 稀疏 attention 的 GPU 优化挑战

稀疏 attention 比 dense attention 在 GPU 上更难优化。dense attention 的内存访问模式是连续的——每个 query 顺序读 KV cache 从 0 到 T。稀疏 attention 的内存访问模式是索引跳跃的——每个 query 按 topk_idxs 跳着读 KV。

这带来三个挑战:

挑战一:合并访问失效

GPU 的全局内存读取按 32-byte / 128-byte cache line 合并。dense attention 的 KV 读取连续,多个线程一次读取一条 cache line。稀疏 attention 的 KV 读取分散,多个线程可能读到不同的 cache line——内存带宽利用率下降。

FlashMLA 的解决方案:用 ldmatrix / TMA 指令做”非连续 gather”。SM90 / SM100 的 TMA(Tensor Memory Accelerator)支持基于索引的非连续读取,硬件层面解决合并访问问题。

挑战二:索引去重

如果两个相邻 query token 的 topk_idxs 有重叠(实际上重叠率往往很高),naive 实现会多次读同一个 KV。FlashMLA 的优化:先把所有 query 的 topk_idxs 做 union,按 union 后的位置读取 KV,再 gather 到每个 query。这把多次读取变成一次。

挑战三:tile 大小与 K 的关系

FlashAttention 的标准做法是把 KV 分成 tiles(典型 64 / 128 token 一个 tile),逐 tile 处理。稀疏 attention 的 K(每个 query 看 1152 个位置)与 tile 大小的关系决定了 kernel 效率:

  • K << tile_size:每个 tile 内能处理的 query 多,但有空闲的 tile 槽位
  • K >> tile_size:每个 query 要跨多个 tile,tile 边界的 softmax 归一化要做”online softmax”

FlashMLA 的 V4 路径选 K=1152、tile_size=128——意味着每个 query 跨 9 个 tile。kernel 用 online softmax 在 tile 间累积分子分母,最后归一化。


5.4·补 online softmax:稀疏 attention 的核心算法挑战

V4 的 sparse_attn kernel 必须解决一个问题:当一个 query 跨多个 tile 计算时,怎么在 tile 之间正确累积 softmax 的分子分母

这个问题在 dense FlashAttention 里就已经存在——每个 query 看完整 KV,KV 被分 tile 处理。FlashAttention 的解法是 online softmax 算法(Milakov & Gimelshein 2018):

对于每个 tile, 维护两个 running quantity:
  m: 当前已见 logits 的最大值
  l: 当前的 sum of exp(logits - m)

收到新 tile 的 logits 时:
  m_new = max(m, max(new_logits))
  l_new = exp(m - m_new) * l + sum(exp(new_logits - m_new))
  o_new = exp(m - m_new) * o + sum(exp(new_logits - m_new) * new_v)

最后归一化:
  o = o / l

这个算法的妙处在于 m 和 l 的更新可以增量进行,且数学上等价于 dense softmax。

V4 的 sparse_attn 的 online softmax 多了一个复杂性:attn_sink。sink 不是某个 tile 里的位置,而是一个 head-level 的常数 logit——必须把它”虚拟”地参与到 softmax 归一化里。

具体处理:

  1. 初始化 m = attn_sink[h],l = 1(exp(0) = 1,因为 sink 的 logit 已经在 m 里)
  2. 处理每个 KV tile 时,用上述公式更新 m / l / o
  3. 注意 o 的累积不包括 sink 项(sink 不贡献输出向量,只参与归一化)

这种”sink-aware online softmax” 是 V4 sparse_attn 的特化部分,FlashAttention v3 的标准 online softmax 没有这个能力。

flowchart LR
  subgraph TileFlow["online softmax 的 tile 流"]
    T0["tile 0: 64 个 KV"] -->|更新 m,l,o| T1["tile 1: 64 个 KV"]
    T1 -->|更新 m,l,o| T2["tile 2: ..."]
    T2 -->|...| TN["tile N"]
    TN -->|最后除以 l| Out["o 输出"]
  end
  Sink["attn_sink 作为<br/>初始 m"] -.参与归一化.-> T0

V4 的 sparse_attn kernel 在每个 query 上跑这套 sink-aware online softmax,跨 9 个 KV tiles(K=1152 / tile_size=128)累积,最后归一化输出。


5.4·补·补 索引 gather 的硬件路径

稀疏 attention 在 SM90 / SM100 上的”索引 gather”是 V4 工程化的一个具体硬件挑战。让我们看一段简化的 CUDA 伪代码理解它在硬件层做了什么:

// 简化的稀疏 KV gather kernel(C++/CUDA)
__global__ void gather_kv_for_sparse_attn(
    const __nv_fp8_e4m3* kv_global,        // [B, T, D] 全部 KV
    const int* topk_idxs_global,           // [B, S, K] topk 索引
    __nv_fp8_e4m3* kv_gathered_smem,       // shared memory 输出
    int B, int S, int T, int D, int K)
{
    int b = blockIdx.x;
    int s = blockIdx.y;
    int tid = threadIdx.x;

    // 每个 thread 负责 K/blockDim.x 个 KV 位置
    for (int k_local = tid; k_local < K; k_local += blockDim.x) {
        int kv_idx = topk_idxs_global[b * S * K + s * K + k_local];
        if (kv_idx < 0) continue;          // mask -1 invalid
        // gather D 维 (D=512)
        // SM90: 用 cp.async.bulk + ldmatrix 做异步 gather
        // SM100: 用 TMA 的 indexed mode
        int dst_offset = k_local * D;
        int src_offset = b * T * D + kv_idx * D;
        // ... TMA / cp.async copy ...
    }
}

实际 FlashMLA 的 kernel 比这复杂得多——

  • cp.async.bulk 异步从 global 拷到 shared,重叠拷贝与计算
  • ldmatrix.x4 一次加载 4 个 8x16 矩阵到寄存器
  • TensorCore 用 WGMMA 指令做 q @ k^T 的 GEMM

但核心思想是:用硬件原生的”非连续 gather”指令把 topk 选取的 KV 拷到 SMEM,然后正常做 attention。FlashMLA 的工程价值就在于”把这个 gather + GEMM + softmax 流水线优化到接近硬件极限”。


5.5 SM90 vs SM100:两套 kernel 的差异

V4 同时支持 H100/H800(SM90)和 B200(SM100)。FlashMLA 给两套架构写了完全独立的 kernel

维度SM90 (H100/H800)SM100 (B200)
FP8 GEMM 指令WGMMA (Warp Group MMA)WGMMA (改进版) + UMMA
FP4 GEMM 指令模拟(FP8 → FP4 算子分解)原生 FP4 MMA
TMA1D / 2D1D / 2D / 3D + tensor map
共享内存228 KB / SM256 KB / SM
L2 cache50 MB60 MB
Threadblock size4 warps / 8 warps8 warps / 16 warps
关键差异softmax 在 SMEM 内做softmax 跨 SMEM + L2

V4 在 H100 与 B200 上的 throughput 差异(README 公开数字):

  • H100 上 V4 Pro decode 吞吐约 410 TFlops(FP8)
  • B200 上 V4 Pro decode 吞吐约 640 TFlops(FP8)

差异主要来自 B200 的原生 FP4 MMA + 更大的 L2。SM90 路径的 FP4 因为是模拟(先反量化到 FP8 再 GEMM),效率比原生差不少——这是 V4 在 B200 上有”更高占有率”的硬件原因。


5.6 V4 sparse_attn 与 FlashAttention v3 的对比

FlashAttention v3 是 dense attention 在 H100 / B200 上的事实标准。把 V4 的 sparse_attn 与 FA3 横向对比:

维度FlashAttention v3V4 sparse_attn
KV 访问模式连续索引 gather
支持的 head_dim32 / 64 / 128 / 256必须 512
支持的稀疏度不支持top-K
是否支持 sink不支持支持
是否支持 grouped O不支持(O 投影外置)不支持(O 投影也外置)
主要应用场景dense attention 全部模型V3 (dense MLA) + V4
长上下文成本O(n²) FLOPsO(n) FLOPs (K 固定)

V4 的 sparse_attn 不能替代 FA3——它是为 V4 这种”超大 head_dim + 稀疏选取 + sink”的特定形态量身定做的。反过来 FA3 也不能跑 V4——FA3 的 dense KV 假设与 V4 的 topk_idxs 接口完全不兼容。

这种”特化 kernel”的代价是 V4 必须自带 FlashMLA。但带来的红利是:V4 的稀疏注意力在 H100 上能跑 410 TFlops——这个数字是 FA3 跑 dense attention 的 75% 左右,意味着稀疏 attention 在工程上已经追上 dense attention 的效率


5.7 vLLM / SGLang 集成 sparse_attn 的工程接缝

把 sparse_attn 集成进 vLLM / SGLang 这类推理引擎,至少要做四件事:

事项一:编译 FlashMLA 为可链接库

FlashMLA 的 C++/CUDA 必须用 -arch=sm_90sm_100 编译,且需要 CUDA 12.8+。集成时需要:

  • 把 FlashMLA 加到引擎的 wheel 构建脚本
  • 处理”用户的硬件不是 H100/B200” 的回退(一般回退到 PyTorch + Triton 实现)

事项二:传 topk_idxs 给 kernel

引擎需要在每次 forward 时为每个 attention layer 计算 topk_idxs——这意味着引擎要 invoke Indexer,传 query / 中间表示给 Indexer,再把 Indexer 的输出送给 sparse_attn。这个 dataflow 在 vLLM / SGLang 之前的代码里完全不存在——必须新增。

事项三:KV cache 形状改造

vLLM 默认的 PagedAttention KV cache 形状是 [block, block_size, head, dim]。V4 的 KV cache 是 [B, window + n/ratio, dim]——不分 head(MQA-style),不按 block 切。集成时要么扩展 PagedAttention 的形状抽象、要么新增一个”V4-style KV cache” 类型。第 19 章会展开 vLLM PR 的具体改动。

事项四:与调度器的协调

V4 的 prefill / decode 走不同 codepath,且 prefill 有”压缩 KV 一次性算”的批量步骤。vLLM 的调度器需要识别”这是 V4 模型”,分别调度 prefill 和 decode 阶段,避免把它们错误合并到同一个 CUDA stream。

第 19 章会针对 vLLM 主仓库的 V4 适配 PR 逐改动展开。本章只到这里——给读者建立”sparse_attn 需要外部生态怎么配合”的全局认知。


5.8 动手实验:跑通 FlashMLA 的 V4 测试

# 1. 拉取 FlashMLA
git clone https://github.com/deepseek-ai/FlashMLA.git
cd FlashMLA

# 2. 编译(需要 CUDA 12.8+ 和 H100/H800 或 B200)
pip install -e .

# 3. 跑 V4 路径的单元测试
python -m pytest tests/test_sparse_attn.py -v -k v4

# 4. 跑性能基准(输出 TFlops 数字)
python benchmark/bench_sparse_attn.py --arch sm90 --seq-len 1048576 --topk 1024

如果你没有 H100/H800,可以用 nVIDIA 的 PTX 模拟器或者改 arch=sm_80(A100,但需要回退到 Triton 实现,性能差得多)。

测试通过后,会得到一个对照表:dense attention vs sparse_attn 的 TFlops、显存占用、首 token 延迟。把这个表对照本书第 1 章 §1.9·补·补 的”README 三组数字”,能看到工程数字与营销数字的吻合度。


5.8·补 sparse_attn 在不同 batch / context 下的性能特征

V4 的 sparse_attn 性能不是恒定的——它随 batch size、context 长度、稀疏比例呈非线性变化。把这条性能曲线的几个关键拐点标出来:

拐点 1:context = 32K

context ≤ 32K 时,sparse_attn 的优势相对小——dense attention 在 32K 下 KV 也才几 GB,FLOPs 也可承受。sparse_attn 的”选 top-1024”反而引入了 Indexer 的额外成本。这个范围内 sparse_attn 比 dense 快约 1.2-1.5 倍。

拐点 2:context = 128K

context = 128K 时差异显著拉开——dense 的 attention FLOPs 是 O(n²) 二次增长,sparse_attn 的 FLOPs 是 O(n × 1152) 几乎线性。这个范围内 sparse_attn 比 dense 快约 5-8 倍。

拐点 3:context = 1M

context = 1M 是 V4 的设计目标。dense attention 在这里完全不可用(FLOPs 爆、KV 爆),sparse_attn 仍能维持大约 80-90% 的 32K 单 token 吞吐。这个范围内 sparse_attn 是”唯一可行方案”——速度对比变得无意义。

拐点 4:batch = 32+

随 batch 增大,sparse_attn 的优势从”FLOPs”转向”显存”——dense 的 KV cache 在大 batch 下爆掉,sparse_attn 的 KV cache 仍在可控范围。这个范围内 sparse_attn 让”原本只能跑 batch=4 的硬件能跑 batch=32”——并发能力提升 8 倍。

拐点 5:稀疏比例(topk / total_kv)= 1%

V4 默认 topk=1024,1M context 下稀疏比例约 0.4%。如果应用场景把 topk 调到 4096(稀疏比 1.6%),sparse_attn 的吞吐下降约 40%——但精度提升约 5%。这种”精度 vs 速度” 的 trade-off 是部署时可以调的。

理解这条性能曲线对容量规划 极重要——你要根据自己的典型 context 长度选择”V4 是否真的合适”——short context + 大 batch 用户其实可以用更小的模型。


5.8·补·补 sparse_attn 与 GPU SM 占用率

sparse_attn 在 GPU 上的”占用率”(SM utilization)是衡量 kernel 优化质量的关键指标。

理论上限:H100 有 132 个 SM。每个 SM 同时跑多个 warp(线程块)。sparse_attn 的理论上限是”所有 SM 都满载在跑 attention 计算”——约 95% 占用率。

实际占用率:FlashMLA 的 sparse_attn 在 H100 上典型占用率 80-85%。差距来自:

  • 索引 gather 的等待(部分 SM 在等 cp.async 完成)
  • tile 边界的 softmax 同步开销
  • Indexer 的输出尚未到达时 sparse_attn 必须等待

B200 上的占用率:B200 有 192 个 SM + 原生 FP4。FlashMLA 在 B200 上占用率 90%+——更接近理论上限。这是 V4 在 B200 上比 H100 快 1.6 倍的内在原因。

GPU 时钟与温度:实际占用率还受 GPU 物理状态影响。H100 在 boost clock + 良好散热下占用率最高;如果 thermal throttle,占用率会降到 70%。生产部署时必须监控 nvidia-smi 的 power / temp / clock。

与其他 kernel 协同:sparse_attn 不是孤立运行——它与 Linear(DeepGEMM)、与 RMSNorm、与 Compressor 共享 SM 资源。如果其他 kernel 占太多 SM,sparse_attn 会被”挤压”。V4 的解决:用 CUDA Graph 把多个 kernel 编排成”流水线”,让 SM 利用率持续高位。

理解 SM 占用率让你能判断”sparse_attn 是否被瓶颈”——如果 nvidia-smi 显示 GPU 利用率只有 60%,sparse_attn 一定不是瓶颈,问题在其他地方(数据加载 / 网络 / Python 调度)。


5.9 延伸阅读

  • FlashAttention v3(arXiv:2407.08608):dense attention 的 SM90 实现
  • Native Sparse Attention(arXiv:2502.11089):稀疏 attention 训练的理论
  • DeepSeek FlashMLA 仓库 README:本章主要参考
  • 本书《vLLM 推理内核深度解析》第 4-5 章:PagedAttention 与 V4 的 KV cache 对接
  • 本书第 19 章:vLLM 主仓库 V4 适配 PR 的全部改动

5.9·补 sparse_attn 的”工程债务”清单

任何工业级 kernel 都有它的工程债务——为快速发布而留下的”以后再优化” 项。FlashMLA 的 sparse_attn 也有。把可观察到的工程债务列出来:

债务 1:dense 路径与 sparse 路径并存

FlashMLA 仓库同时维护 dense 路径(给 V3 / V3.2-Exp)和 sparse 路径(给 V4)。两套代码在某些功能上重叠(如 KV cache 管理),但实现独立。未来某个版本可能把它们统一——但短期内为了不引入回归风险,保留并存。

债务 2:SM80(A100)路径缺失

FlashMLA 主要支持 SM90 / SM100。SM80(A100)没有原生 FP8 / FP4,需要软件模拟——FlashMLA 没有为此专门优化。如果你想在 A100 上跑 V4,sparse_attn 会有显著性能损失。

债务 3:动态稀疏度不灵活

V4 的 topk=1024 是 config 写死的。FlashMLA 内部 tile_size 与 K=1152 紧密绑定——如果 fine-tune 模型把 topk 改成 512 或 2048,需要重新调 kernel 的 tile size。短期内只能用预定的几个固定值。

债务 4:与 cuBLAS / cutlass 的协同

vLLM 部署时 sparse_attn 与 cuBLAS(其他模型的 GEMM)共享 GPU。但 FlashMLA 的 sparse_attn 占用的 SM 数固定(通过 set_num_sms),不会根据 cuBLAS 的负载动态调整。理想情况下应该有 SM 调度协同——目前是工程债务。

债务 5:错误诊断信息有限

如果 sparse_attn 跑出错(如 topk_idxs 越界),错误信息通常是 CUDA error,不直接指向问题位置。FlashMLA 缺一套 debug build——开启后可以打印每个 tile 的状态。这是开源项目的常见短板。

理解这些债务让你在使用 FlashMLA 时心里有底——遇到问题时知道去哪里查、知道哪些限制是”暂时的”,避免被”看似不一致的现象” 困惑。


5.9·补·补 sparse_attn 工程师速记

部署或调试 V4 sparse_attn 时最常用的几条速记规则。打印一张贴在工位上:

速记 1:版本要求

  • CUDA 12.8+
  • PyTorch 2.4+
  • H100 / H800(SM90)或 B200(SM100)
  • A100 / 老 GPU 不支持

速记 2:性能数字(H100 上)

  • FP8 GEMM 峰值:~1300 TFlops
  • sparse_attn 峰值:~410 TFlops
  • B200 大约提升 1.6x

速记 3:典型形状参数

  • topk_idxs 大小:[B, S, 1152](=window_size 128 + index_topk 1024)
  • KV cache 大小:每层 [B, 滑窗 + n/ratio, 512]
  • 单序列 KV cache 总和:~8 GB(1M context)

速记 4:常见错误信号

  • “RuntimeError: invalid argument”:GPU 不支持 SM 架构
  • “CUDA error: invalid configuration”:tile size 不匹配(通常 SMEM 不够)
  • 输出 NaN:q 没正确量化或 attn_sink 没初始化

速记 5:优化优先级

  • 检查 GPU 利用率(nvidia-smi)—— 低于 80% 说明是 IO/通信瓶颈,sparse_attn 不是元凶
  • 检查 tile size 与 SMEM 余量
  • 检查 stream 配置(DeepGEMM 与 sparse_attn 是否在同一 stream 串行)

速记 6:与 vLLM 集成的关键文件

  • vllm/attention/backends/flash_mla.py(V4 的 attention backend)
  • vllm/model_executor/models/deepseek_v4.py(V4 model class)
  • vllm/distributed/device_communicators/(DeepEP 集成)

这些速记是”工业实战经验” 的浓缩——不需要每次都翻文档,速记能解决 80% 的日常问题。


5.9·延展 sparse_attn 与 dense FlashAttention 的”代码量对比”

把 V4 的 sparse_attn 与 FlashAttention v3(dense)的代码量、复杂度、可读性对比一下——这能让你直观感受 sparse 路径的工程额外开销。

FlashAttention v3 (dense)

  • 主要 .cu 文件:3-5 个,每个数百行
  • 核心算法:online softmax + 分块 KV
  • 接口:flash_attn_func(q, k, v, ...)
  • 调用方式:直接传完整 KV,kernel 内部分 tile 处理
  • 数学复杂度:低(标准 attention)

FlashMLA sparse_attn (V4 路径)

  • 主要 .cu 文件:6-10 个(SM90 + SM100 各一套)
  • 核心算法:sink-aware online softmax + 索引 gather + tile 内 sparse mask
  • 接口:sparse_attn(q, kv, sink, idxs, scale) —— 多 2 个参数
  • 调用方式:传 KV cache + topk 索引,kernel gather 后再算
  • 数学复杂度:中(sink + sparse 选取)

代码量差异

sparse_attn 比 dense 多约 2-3x 代码量——主要在索引 gather、sink 处理、tile 边界对齐上。

可读性差异

dense FlashAttention v3 已经是 CUDA kernel 中的”复杂代码”——理解它需要懂 WGMMA、TMA、async copy。sparse_attn 在此基础上再加一层”非连续访问”的复杂度,理解成本约高 50%。

性能差异

dense FlashAttention v3 在 H100 上达到 ~600 TFlops(接近 FP8 GEMM 峰值)。sparse_attn 在 H100 上达到 ~410 TFlops——比 dense 低 30%,但因为只算 K=1152 个位置(dense 算所有),在长 context 下整体仍然快 5-10x。

这种”单位 FLOPs 慢、但总 FLOPs 少” 的工程权衡正是 sparse 路径的本质——用更少的 FLOPs 做出与 dense 相当的输出


5.9·拓展 sparse_attn 与”动态批处理(continuous batching)“的协同

vLLM 的核心调度优化是”continuous batching”——多个请求的 prompt / decode 共享 GPU 跑在一个 batch 里。V4 的 sparse_attn 与这套调度有几个微妙的协同点。

协同点 1:每请求独立的 topk_idxs

continuous batching 把多个请求的 query 拼到一起跑 attention。V4 的 sparse_attn 接受 [B, S, K] 的 topk_idxs——每条 sequence 有独立的稀疏选择。这与 dense attention 兼容(dense 用 padding mask)。

协同点 2:长短 prompt 共存

某些请求是 64K context、某些是 1K——一起 batch 时长 prompt 的 KV cache 占大头,但 sparse_attn 的计算仍可控(K=1152 固定)。这是 V4 在 mixed workload 下的红利。

协同点 3:prefill / decode 的混合 batching

vLLM 的 chunked prefill 让 prefill 与 decode 在同一 batch 内跑。V4 的 sparse_attn 需要分别处理两种 phase——prefill 走批量 codepath、decode 走增量 codepath。kernel 需要支持”同一 batch 内同时跑两种 codepath”。

协同点 4:prefix caching 与稀疏 KV

vLLM 的 prefix caching 复用相同前缀的 KV——多个请求共享 KV cache 的某部分。V4 的稀疏 KV 让这种复用更复杂——压缩 KV 段可以共享,滑窗段每请求独立。详见第 19 章关于 SGLang RadixAttention 的讨论。

协同点 5:抢占(preemption)

vLLM 的调度器会抢占 / 恢复请求——把”被抢占请求的 KV cache” 暂存到 CPU 后再恢复。V4 的稀疏 KV 也支持这种 swap——但 swap 单位变成”滑窗段 + 压缩段”两块,不再是单一 KV block。

理解这些协同点让你在 vLLM 中正确部署 V4——不会因为”V4 行为与 dense 模型不同”而踩坑。


5.10 本章小结

  • V4 的 sparse_attn 在 PyTorch 是一行调用,背后是 FlashMLA 仓库里数千行 CUDA 代码
  • 稀疏 attention 在 GPU 上比 dense attention 难优化——索引跳跃、tile 跨界、softmax 跨 SMEM 都是工程挑战
  • FlashMLA 的 V4 路径用 ldmatrix / TMA + online softmax 解决了主要挑战,在 H100 上能跑 410 TFlops
  • B200 (SM100) 因为有原生 FP4 MMA,比 H100 快约 1.6 倍——V4 在 B200 上的红利明显
  • 集成 sparse_attn 进 vLLM / SGLang 等引擎需要四件事:编 FlashMLA、传 topk_idxs、改 KV cache、调度器协调

第 6 章我们离开 attention 内部,来到 V4 长上下文工程的另一支柱:YaRN RoPE——它怎么把 65K 训练上下文外推到 1M。

评论 0