vLLM 推理内核深度解析

第4章 PagedAttention:虚拟内存的启示

作者 杨艺韬 · 8,645 字

第4章 PagedAttention:虚拟内存的启示

“Good artists copy, great artists steal.” — Steve Jobs(引用毕加索)

本章要点

  • 量化理解传统 KV Cache 的三类浪费:预留浪费、内部碎片、外部碎片;论文里的利用率数字只作为历史测量,不直接外推到任意部署
  • 把操作系统虚拟内存的核心机制(页帧 / 虚拟页 / 页表 / 按需分页 / 页面回收 / COW)逐个映射到 PagedAttention
  • 读懂 Block 的物理布局:同一个逻辑 block 在 FlashAttention、FlashInfer、Pallas、MLA 后端里可以有不同张量形状
  • 掌握 BlockTable 的 CPU/GPU 双缓冲表示,以及 Worker 是如何在每步执行前把它提交到 GPU 的
  • 走完 V1 attention 调用链:slot_mapping 写入 KV cache,block_table 交给 backend kernel 做 paged attention
  • 理解 Online Softmax 的数值稳定性技巧——“running max + rescale”为什么适合分 block 流式扫描
  • 看清前缀缓存共享块的引用计数协议,以及它和旧版”并行采样 COW”概念之间的边界
  • 学会估算一个给定模型在特定 GPU 下的 KV 池容量与最大并发上限
  • 理解 V1 把 PagedAttention 抽象成”Attention Backend”的工程意义,以及 FlashAttention、FlashInfer、Triton、Pallas、MLA 如何接入

4.1 KV Cache:甜蜜的负担

要理解 PagedAttention 为什么重要,必须先把 KV Cache 的内存压力量化清楚。

4.1.1 每 Token 到底占多大

Transformer 的注意力机制需要访问所有历史 Token 的 K、V。为了避免每步都从 Embedding 开始重算,每个 Token 第一次被处理时其 K、V 向量被缓存起来——这就是 KV Cache。

对 Llama-2-13B(FP16):

  • num_layers = 40
  • num_kv_heads = 40(这个模型 Q 头和 KV 头相等;GQA / MQA 模型这个值会小于 Q 头数)
  • head_dim = 128
  • dtype = 2 bytes(FP16)

每 Token 的 KV 成本:

2 (K+V)×40 层×40 头×128 维×2 B=819,200 B800 KB2 \text{ (K+V)} \times 40 \text{ 层} \times 40 \text{ 头} \times 128 \text{ 维} \times 2 \text{ B} = 819{,}200 \text{ B} \approx 800 \text{ KB}

一个 2048-token 序列 = 1.6 GB KV。一张 80 GB A100 扣掉 26 GB 模型权重,理论上还能塞 54GB / 1.6GB ≈ 33 个这样的请求。听起来不错?

但这个 33 只是”所有序列都刚好 2048 token、没有任何碎片、没有调度余量”时的上界。真实服务里,请求长度分布有长尾,输出长度在请求开始时未知,请求结束时间也不同步。传统连续分配方案的问题,不是某个单独请求算错了,而是整个系统在动态负载下很难把每一段显存都装满。

4.1.2 三类浪费的量化

预留浪费(Reserved Waste)——如果系统在请求开始时按最大可能长度预留连续 KV 空间,它必须为”尚未生成、也许永远不会生成”的 token 留位置。最大长度通常来自 max_model_len 或服务侧的输出上限;但一个请求实际输出可能是 50、500、或 2048 token。按最坏情况预留,长尾请求会把短请求的余量一起锁住。

内部碎片(Internal Fragmentation)——即使不预留整块,按”已生成长度”分配也会向上取整到某个粒度(比如 64 token 对齐)。一个 100-token 的请求分 128 slot,浪费 28 个 slot(22%)。

外部碎片(External Fragmentation)——请求完成后释放 KV,留下离散的”空洞”。新请求要求的连续空间大于任一空洞时,即使空洞总量够,也分配不出来。

PagedAttention 论文在若干工作负载上报告过传统方案较低的有效 KV 利用率,常被引用的一组范围是 20.4% ~ 38.2%。这不是当前 vLLM 在任意模型和任意集群上的固定结论,而是用来说明连续分配在动态请求流里会被预留、对齐和外部碎片共同拖垮。工程上更稳妥的读法是:只要 KV 以”请求级连续大块”为单位管理,显存浪费就会随长度方差和请求 churn 放大。

graph TB
    subgraph "连续 KV Cache 的浪费来源(示意)"
        P["KV 预算"]
        R["预留浪费<br/>为未发生的 token 留位置"]
        I["内部碎片<br/>按分配粒度向上取整"]
        E["外部碎片<br/>释放后留下离散空洞"]
        U["有效 KV<br/>真实 token 的 K/V"]
        P --> R
        P --> I
        P --> E
        P --> U
    end

    style R fill:#ef4444,color:#fff,stroke:none
    style I fill:#f59e0b,color:#fff,stroke:none
    style E fill:#3b82f6,color:#fff,stroke:none
    style U fill:#10b981,color:#fff,stroke:none

这就是早期 LLM 推理的”显存税”:真正昂贵的不是某个 attention kernel 多做了几次乘加,而是最贵的硬件资源被长度不确定性和连续分配约束困住了。

4.2 1961 年的那台机器

1961 年曼彻斯特大学的 Atlas 计算机面对的是同一类问题:程序对内存的需求参差不齐,直接分配物理内存既浪费又僵化。工程师们发明了虚拟内存——把内存分成固定大小的页,每个程序看到的是连续的虚拟地址空间,页表把虚拟页映射到物理页。这个思想后来被 IBM S/360、Multics、Unix 全面采纳,六十年间成为所有现代操作系统的基石。

PagedAttention 团队的核心洞察:GPU 显存管理和 1961 年的物理内存管理,结构上是同一个问题

操作系统概念PagedAttention 对应动因
物理页帧(Page Frame)KV Cache Block固定大小、统一管理、无外部碎片
虚拟页(Virtual Page)请求里 Token 的逻辑分段让上层看到连续,底下随意
页表(Page Table)Block Table间接寻址是全部灵活性的来源
按需分页(Demand Paging)KV Block 按需分配不预留就不浪费
页面回收请求结束释放 Block / LRU 驱逐前缀缓存动态归还空间
Copy-on-Write并行采样 / 前缀缓存的共享块 COW读多写少时节省 N-1 倍空间
地址转换缓存的思想GPU 端保留 block_table 张量并让 kernel 直接读把查表成本变成一次小规模元数据读取,而不是 CPU 参与的动态地址转换

一旦类比建立,PagedAttention 的所有设计选择都顺理成章。

4.3 Block:KV Cache 的物理单元

4.3.1 一个 Block 的形状与大小

KV Cache 被切成固定大小的 Block,每个 Block 存 B 个 Token 的 K-V 对。默认 B=16。但 Block 的张量形状是什么样子?这个看似基础的问题在 vLLM V1 里没有单一答案——它依赖于具体使用的 Attention Backend。最主流的 FlashAttentionBackend.get_kv_cache_shapevllm/v1/attention/backends/flash_attn.py:59)给的形状是:

return (2, num_blocks, block_size, num_kv_heads, head_size)

五维张量、最前面多一个 2 用来把 K 和 V 打包到同一块连续内存里(index 0 是 K、index 1 是 V)。对 Llama-2-13B(num_kv_heads=40head_size=128block_size=16、FP16)、一个 Block 占 2 × 16 × 40 × 128 × 2 B = 320 KB

整个 GPU 上存在 num_layers = 40 套这样的 5D 张量(每层一套)——每层都是它自己的 (2, num_blocks, block_size, num_kv_heads, head_size)。所以一个 “complete block”(覆盖全部层的那个逻辑块)的真实成本是 40 × 320 KB = 12.5 MB。一张 A100 上如果我们有 50 GB 做 KV Cache,大约能容纳 50 GB / 12.5 MB ≈ 4000 个 complete block。

4.3.2 为什么 block_sizenum_kv_heads 之前?K 和 V 怎么摆?

(2, num_blocks, block_size, num_kv_heads, head_size) 这个维度顺序不是随意定的,但也不要把它解释成”vLLM 的唯一 block 物理规范”。它是 FlashAttention/Triton 后端暴露给 V1 worker 的 KV cache 形状,服务的是该后端的读写路径:

  • 2 放最外层kv_cache.unbind(0) 会得到 key_cachevalue_cache 两个视图。flash_attn.py:560-566 正是这样把当前步新算出的 K/V 写入 cache:先 unbind,再调用 reshape_and_cache_flash(..., attn_metadata.slot_mapping, ...)。所以 K/V 语义上分区,但物理上仍来自同一个 per-layer tensor。
  • num_blocks:block table 查表得到的就是物理 block 号,这一维是分页分配的落点。调度器和 KVCacheManager 不需要知道后面几维怎么排,只需要维护 block id。
  • block_sizenum_kv_heads:这两维共同决定一个物理页内的 token/head 排列。FlashAttention 后端选择先 token 后 head,FlashInfer 与 Pallas 又各有不同排列;正确的原则不是背某个固定 stride,而是让 backend 的 kernel 按自己需要的内存访问方式定义形状。
  • head_size 最内:同一 head 的 channel 连续,是几乎所有后端都会保留的局部性,因为 QK 点积和 V 加权都沿 head 维做向量运算。

4.3.3 “形状”本身是 backend 变量——不是常量

要再强调一遍:上面这个 5D 形状是 FlashAttention 后端的选择,不是 vLLM 的普适规约。把五个 backend 的 get_kv_cache_shape 放在一起对比(源码位置都在 vllm/v1/attention/backends/*.py):

Backend返回形状说明
FlashAttn / Triton(2, num_blocks, block_size, num_kv_heads, head_size)K/V 打包在最外维
FlashInfer(num_blocks, 2, block_size, num_kv_heads, head_size)K/V 打包在第二维
Pallas(TPU)(num_blocks, block_size, num_kv_heads*2, head_size)K/V 并排在 head 维;4D
MLA(Deepseek)(num_blocks, block_size, head_size)单个共享 latent KV、无 head 维;3D

各 backend 的选择都对应它”底层 kernel 希望以何种 stride 访问”的偏好。Pallas 把 K/V 塞进 head 维度是因为 TPU 的 PaLM-style kernel 把它们当一个宽 head 一次性读完更高效;MLA 的 3D 形状则反映了 Deepseek V2/V3 的 Multi-head Latent Attention——K/V 压缩成同一个 c_kv 潜向量,没有独立的 K 和 V 之分、也没有 num_kv_heads 这个维度(每个 query head 都和这一个 latent 算 attention)。

这正是 4.9 节 Attention Backend 抽象的意义:block 的物理布局由 backend 决定,调度器和 KV Cache Manager 只操作”抽象的 block”——以 num_blocks 为粒度申请、释放、换位。Layout 的异构性被 get_kv_cache_shape 这个抽象点收束,上层不需要在调度逻辑里展开 TPU 和 GPU 的内存排列差异。

4.3.4 为什么 block_size = 16

块大小是一个需要非常精细权衡的参数。当前源码里,CacheConfig.block_size 没有静态默认值;如果用户不指定,平台层会补默认值。CUDA 平台在 vllm/platforms/cuda.py:141-143 把空值设成 16;如果模型启用 MLA 且走 FlashMLA,cuda.py:147-158 又会把 block size 强制成 64。也就是说,“默认 16”只适用于常规 CUDA attention 路径,不是所有硬件和所有模型的铁律。

  • 太小B=1B=4):每次 token 访问都要查一次 block table 间接,索引开销占主导;block table 超长,放不进 shared memory 或寄存器。
  • 太大B=128B=256):内部碎片严重(最后一个 block 平均浪费 B/2 个 slot);块之间的”预热/启动开销”(比如一个 block 跑满一个 warp)变不明显。
  • 后端约束:FlashAttention 和 Triton 的 get_kv_cache_shape 都显式要求 block_size % 16 == 0;Pallas 在 TPU 上还会根据 max_model_lenmax_num_seqs 和 shared memory 约束计算最小 page size。

所以本章后面用 block_size=16 做例子,是为了方便计算和贴近常规 CUDA 默认,而不是暗示 16 是所有部署的性能最优点。真正的工程判断要同时看硬件、attention backend、模型是否 MLA、最大上下文长度,以及调度器允许的最大并发。

4.3.5 预分配:KV 池不在请求路径上动态增长

启动时,vLLM 根据 KVCacheConfig 为每一层创建 KV cache tensor。V1 的真实路径在 GPUModelRunner.initialize_kv_cache()vllm/v1/worker/gpu_model_runner.py:1689-1733):

# vllm/v1/worker/gpu_model_runner.py(节选化简)
for layer_name in kv_cache_group.layer_names:
    tensor_config = kv_cache_config.tensors[layer_name]
    num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes
    kv_cache_shape = self.attn_backend.get_kv_cache_shape(
        num_blocks,
        kv_cache_spec.block_size,
        kv_cache_spec.num_kv_heads,
        kv_cache_spec.head_size,
    )
    kv_caches[layer_name] = torch.zeros(
        kv_cache_shape,
        dtype=kv_cache_spec.dtype,
        device=self.device,
    )

这里有两个细节值得抓住。

第一,KV cache 是按 layer name 分配 tensor,不是一张全局 key_cache 加一张全局 value_cache。每层的 tensor 形状由 attention backend 决定,随后通过 bind_kv_cache() 绑定到模型静态 forward context。第二,kv_cache_config.num_blocks 是 KVCacheManager 可分配的全局下限,而某个 worker 实际为某层创建的 num_blocks 可以更大;源码在 gpu_model_runner.py:1709-1716 专门断言本地 num_blocks >= kv_cache_config.num_blocks,因为不同 GPU 可能层数和可用显存不同。

因此,更准确的说法不是”整个进程生命周期没有任何 cudaMalloc / cudaFree”,而是:请求级 KV block 的申请/释放不走 CUDA allocator。“分配”一个 block 是从 BlockPool 的空闲队列里取一个 block id;“释放”是递减引用计数,必要时把 block 放回 free queue。这个决策消除了三类请求路径风险:

  1. 碎片复发——动态 cudaMalloc 在多次分配后会产生物理碎片,PyTorch caching allocator 也不能完全解决。预分配后,请求长度变化不再转化成一批大小不同的 CUDA allocations。
  2. OOM 波动——运行时 cudaMalloc 可能因为 GPU 其他进程临时抢用显存而失败。启动时一次性申请,要么立刻失败(好诊断),要么一劳永逸。
  3. Allocator 开销——请求调度的 hot path 不需要为每个请求向 CUDA allocator 要连续 KV 空间,只更新 block 元数据和 block table。

代价是启动阶段必须准确估算 KV 预算。这个估算不在本章展开;你只需要记住它的输出会落到 KVCacheConfig.tensors[layer_name].sizenum_blocks 上,worker 按这个配置创建物理 KV 池,scheduler 只在这个池内分配 block。

4.4 Block Table:逻辑到物理的地址簿

4.4.1 数据结构

每个运行中的请求都有一张 Block Table。它是一个一维数组,第 i 个元素是第 i 个逻辑块映射到的物理 block 号:

Request A 的 Block Table: [7, 3, 12, 25]
  逻辑块 0 (token 0-15)  →  物理 block 7
  逻辑块 1 (token 16-31) →  物理 block 3
  逻辑块 2 (token 32-47) →  物理 block 12
  逻辑块 3 (token 48-63) →  物理 block 25

物理上,V1 有一个 BlockTable 对象维护 CPU/GPU 两份二维张量:block_table_cpublock_table,形状都是 [max_num_reqs, max_num_blocks_per_req]vllm/v1/worker/block_table.py:25-37 创建这两份 buffer,append_row()/add_row() 只改 CPU 侧 numpy 视图,commit(num_reqs) 再把前 num_reqs 行 copy 到 GPU(同文件 :69-71)。这样做的意义是:调度器可以在 CPU 上便宜地拼 row,GPU kernel 看到的是固定形状的 dense table。

graph TB
    subgraph "请求 A 的 Block Table"
        direction LR
        LA0["逻辑块 0<br/>Token 0-15"] --> PA7["物理 7"]
        LA1["逻辑块 1<br/>Token 16-31"] --> PA3["物理 3"]
        LA2["逻辑块 2<br/>Token 32-47"] --> PA12["物理 12"]
        LA3["逻辑块 3<br/>Token 48-63"] --> PA25["物理 25"]
    end

    subgraph "请求 B 的 Block Table"
        direction LR
        LB0["逻辑块 0<br/>Token 0-15"] --> PB1["物理 1"]
        LB1["逻辑块 1<br/>Token 16-31"] --> PB9["物理 9"]
    end

    subgraph "共享物理 BlockPool(一张 GPU 张量)"
        direction LR
        P0["块 0"]
        P1["块 1<br/>B"]
        P2["块 2"]
        P3["块 3<br/>A"]
        P7["块 7<br/>A"]
        P9["块 9<br/>B"]
        P12["块 12<br/>A"]
        P25["块 25<br/>A"]
    end

    style LA0 fill:#3b82f6,color:#fff,stroke:none
    style LA1 fill:#3b82f6,color:#fff,stroke:none
    style LA2 fill:#3b82f6,color:#fff,stroke:none
    style LA3 fill:#3b82f6,color:#fff,stroke:none
    style LB0 fill:#10b981,color:#fff,stroke:none
    style LB1 fill:#10b981,color:#fff,stroke:none

关键点:请求 A 的 4 个块在物理上不连续(7、3、12、25),但通过 Block Table 的间接寻址,attention kernel 能在计算时把它们当作一个连续的逻辑序列处理。这就是虚拟内存在 GPU 上的翻版。

4.4.2 Block Table 怎么送到 GPU

一次 schedule -> execute 的数据准备看起来是这样的:

sequenceDiagram
    participant S as Scheduler (CPU)
    participant MR as GPUModelRunner._update_states
    participant BT as BlockTable
    participant PI as GPUModelRunner._prepare_inputs
    participant K as Attention Backend

    S->>MR: SchedulerOutput (new_block_ids / block_ids)
    MR->>BT: add_row 或 append_row 更新 CPU 表
    PI->>BT: commit(num_reqs) 先拷 block table
    PI->>PI: 计算 positions / input_ids / slot_mapping
    PI->>K: builder 读取 GPU block_table 与 slot_mapping
    K->>K: backend kernel 按 block_table 访问 KV cache

源码里的顺序很值得看。GPUModelRunner._update_states() 在新请求进入 batch 时调用 self.input_batch.add_request(),后者把 request.block_ids 写入 BlockTable.add_row()gpu_input_batch.py:231-267)。运行中的请求如果分到了新 block,则 _update_states() 调用 self.input_batch.block_table.append_row(req_data.new_block_ids, req_index)gpu_model_runner.py:418-439)。真正执行前,_prepare_inputs() 第一件事就是 self.input_batch.block_table.commit(num_reqs)gpu_model_runner.py:498-500),注释写得很直白:先拷 block table,才能把 GPU copy 和后续 CPU 计算重叠。

这里不要脑补成”每步临时创建一个二维 tensor 再 .to('cuda')”。V1 是复用固定 buffer:CPU 侧 row 增量更新,GPU 侧批量提交。这个细节解释了为什么 block table 虽然引入了间接寻址,但元数据管理仍能放进每步调度路径。

4.4.3 Slot Mapping:新 Token 写到哪

除了 block table,每步还要告诉写 cache 的 op:“本步新 token 应该写进 KV Cache 的哪个槽位”。这个信息叫 slot_mapping

假设请求 A 当前长度是 50,本步新算 1 个 token(第 51 个)。第 51 个 token 归属逻辑块 50 // 16 = 3,在逻辑块内偏移 50 % 16 = 2。请求 A 的 block_table[3] = 物理块 25,所以 slot = 25 × 16 + 2 = 402。kernel 负责把新算出的 K[51] 和 V[51] 写到 key_cache[402]value_cache[402]

slot_mapping 是一个 [total_num_scheduled_tokens] 的一维 int32 张量。V1 在初始化时预分配 slot_mapping_cpu,并暴露 numpy 视图 slot_mapping_npgpu_model_runner.py:269-273)。每步 _prepare_inputs() 用以下逻辑计算它(gpu_model_runner.py:551-564):

block_table_indices = (
    req_indices * self.max_num_blocks_per_req
    + positions_np // self.block_size
)
block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
block_offsets = positions_np % self.block_size
np.add(block_numbers * self.block_size, block_offsets, out=slot_mapping_np)

这段代码说明了一个容易混淆的点:slot_mapping 不是”第几个请求、第几个逻辑块”的二维信息,而是把二维 block table 查表结果提前压成了一维物理 slot。后端写 KV cache 时只要按这个 slot 写,不需要再回头理解请求结构。

4.5 PagedAttention 内核:在离散 Block 上算注意力

有了 Block Table 和 Slot Mapping,backend 就能在非连续 KV 上执行注意力计算。这里要避免把概念伪代码误写成 vLLM 当前源码。V1 的调用链分成两个动作:

  1. 把当前步 K/V 写进 KV cache。FlashAttention 后端在 flash_attn.py:555-570key_cache, value_cache = kv_cache.unbind(0),再调用 torch.ops._C_cache_ops.reshape_and_cache_flash(..., attn_metadata.slot_mapping, ...)。Triton 后端对应调用 PagedAttention.write_to_paged_cache()triton_attn.py:145-159)。两者共同点是:写入位置由 slot_mapping 决定。
  2. 用 block table 读取历史 KV 做 attention。FlashAttention 后端随后调用 flash_attn_varlen_func(),把 k=key_cachev=value_cachecu_seqlens_qseqused_kmax_seqlen_kblock_table=block_table 传进去(flash_attn.py:606-626)。Triton 后端则把同样的 block_table 交给 chunked_prefill_paged_decode()triton_attn.py:179-195)。

用概念伪代码表示,就是:

# 写 KV:slot_mapping 是物理 slot
write_kv_cache(key, value, kv_cache, slot_mapping)

# 读 KV:block_table 把逻辑块映射到物理块
for request in batch:
    for logical_block in used_blocks(request.seq_len):
        phys = block_table[request.row, logical_block]
        k_block, v_block = kv_cache[phys]
        update_attention_state(q, k_block, v_block)

几个关键工程点:

(1) slot_mappingblock_table 分工不同slot_mapping 只描述本步新 token 的写入位置;block_table 描述完整历史上下文的读取路径。把两者混成”kernel 输入的一张地址表”会错过 V1 的数据流:先写当前步 K/V,再按 block table 读历史 K/V。

(2) 间接寻址是 backend 接口的一部分。FlashAttentionMetadata 里有 block_tableslot_mapping 字段(flash_attn.py:80-87),PallasMetadata 也有 slot_mappingblock_tablescontext_lenspallas.py:79-84),MLACommonMetadata 则把 prefill/decode 的 block_table 分进不同 metadata 对象(mla/common.py:256-307)。这说明 paged attention 在 V1 不是某个单独 CUDA 文件的私有实现,而是 attention backend 共同遵守的数据契约。

(3) GQA/MLA 等模型差异被 backend 吃掉。普通 GQA 模型的 KV head 数小于 Q head 数,backend 在内部按自己的 kernel 规则复用 KV;MLA 模型甚至把 KV cache 改成 (num_blocks, block_size, head_size) 的 latent 形状。调度器仍然只分配 block id,说明分页抽象没有把模型结构泄漏到 scheduler。

(4) 性能判断要回到真实后端。可以直觉上说 block table 读取量远小于 K/V 数据量,但本章不把”额外开销小于某个百分比”写成源码事实。真正要比较开销,需要固定模型、backend、block size、上下文长度和 batch 形态后做基准;章节里的重点是结构正确性,而不是伪造一个通用性能数字。

4.6 Online Softmax:不回头的数值稳定归一化

Paged 的数据结构让我们无法一次性拿到整个 context 的所有 scores——必须分 block 流式处理。这和标准 softmax 的两阶段(先遍历找 max,再遍历 exp 归一化)直接冲突。

Online Softmax 算法(Milakov & Gimelshein, 2018)解决这个问题:只扫一遍数据、在扫描过程中维护 running max 和 running sum,最终得到数值稳定的 softmax。

核心递推:设前 n 个 block 的 logsumexp 状态为 (m_n, s_n),第 n+1 个 block 的局部 max 是 m'、局部 sum 是 s',则合并后的状态:

m_{n+1} = max(m_n, m')
s_{n+1} = s_n * exp(m_n - m_{n+1}) + s' * exp(m' - m_{n+1})

Value 加权也对应更新:

O_{n+1} = O_n * exp(m_n - m_{n+1}) + Σ_t p'_t * V_t

其中 exp(m_n - m_{n+1}) 就是对之前累积结果的 rescale 因子。

为什么要这样做?有两个动因:

(1) 数值稳定。标准 softmax exp(x_i) / Σ exp(x_j)x_i 较大时会 overflow FP16。减去一个 shift(经典做法是全局 max)能避免溢出。online softmax 用 running max 做这个 shift,保证扫描过程中每一步的 exp 都不溢出。

(2) 单趟扫描。分 block 处理时我们不希望为了找全局 max 而遍历所有 block 两次。online 的方式做到了真正的”从头到尾一遍扫完”。

一个容易被忽略的细节:kernel 里的累加精度通常高于输入/输出精度。FP16/BF16 适合存储和矩阵乘,但 softmax 的 max、sum、rescale 与 value 累加如果全用低精度,很容易把长上下文误差放大。工程实现通常会在局部统计量和累加器上使用 FP32 或等价的高精度路径;目标不是保证和某个参考实现逐 bit 相同,而是在性能和数值稳定之间取得可验证的平衡。

4.7 共享块与引用计数:从 COW 概念到 V1 实现

Paged 的分页结构给了我们一个额外能力:当多个请求有完全相同的前缀(或者说,多个序列在某段逻辑上指向同一段 KV),它们可以共享同一组物理 block。共享是否需要”写时复制”,取决于后续是否会写到同一个半满 block。这里要区分论文概念和当前 V1 源码中能直接看到的机制。

4.7.1 场景 A:并行采样(n > 1)

PagedAttention 论文讲 COW 时常用并行采样作为例子:用户请求 n=3,同一个 prompt 产生 3 个独立样本。三个样本共享 prompt 的 KV,分化点在 decode 开始:

graph TB
    subgraph "共享 Prompt 物理 Block(ref_cnt = 3)"
        P0["块 100<br/>token 0-15"]
        P1["块 101<br/>token 16-31"]
        P2["块 102<br/>token 32-47"]
    end

    subgraph "Sample 1 block_table"
        S1a["逻辑 0"] --> P0
        S1b["逻辑 1"] --> P1
        S1c["逻辑 2"] --> P2
        S1d["逻辑 3<br/>decode 独有"] --> P200["块 200"]
    end

    subgraph "Sample 2 block_table"
        S2a["逻辑 0"] --> P0
        S2b["逻辑 1"] --> P1
        S2c["逻辑 2"] --> P2
        S2d["逻辑 3<br/>decode 独有"] --> P201["块 201"]
    end

    subgraph "Sample 3 block_table"
        S3a["逻辑 0"] --> P0
        S3b["逻辑 1"] --> P1
        S3c["逻辑 2"] --> P2
        S3d["逻辑 3<br/>decode 独有"] --> P202["块 202"]
    end

    style P0 fill:#8b5cf6,color:#fff,stroke:none
    style P1 fill:#8b5cf6,color:#fff,stroke:none
    style P2 fill:#8b5cf6,color:#fff,stroke:none
    style P200 fill:#3b82f6,color:#fff,stroke:none
    style P201 fill:#10b981,color:#fff,stroke:none
    style P202 fill:#f59e0b,color:#fff,stroke:none

Prompt 有 48 个 token = 3 个 block。不共享需要 3 × 3 = 9 块,共享只需要 3 块。prompt 越长,理论节省越多。

但在校对当前 vLLM V1 源码时,不能把这个图直接等同于”V1 对并行采样一定用同一套 COW 代码路径”。本节保留它,是因为它解释了 PagedAttention 为什么天然适合共享前缀;真正可由当前源码直接锚定的,是下一小节的前缀缓存共享与 BlockPool 引用计数。

4.7.2 场景 B:前缀缓存(Prefix Caching)

一个更常见的场景:不同请求之间共享相同的前缀。典型案例——所有请求都带同一段 system prompt(“你是一个专业的编程助手……“2000 字),或者 RAG 把同一段 context 拼到每个请求前。

前缀缓存(第 10 章详讲)会对每个 block 的 token 内容做哈希,两个请求的前几个 block 如果内容完全一致、哈希也一致——就让后来的请求的 block_table 直接指向已有的物理 block,ref_cnt 加 1。

V1 的实现不是上面那种 block_ref_cnts: dict[int, int] 的手写字典,而是 BlockPool 管理 KVCacheBlock 对象。关键路径是:

# vllm/v1/core/kv_cache_manager.py + block_pool.py(节选化简)
computed_blocks = specialized_manager.find_longest_cache_hit(block_hashes)
block_pool.touch(computed_blocks)       # cache hit: ref_cnt + 1
req_blocks.extend(computed_blocks)

new_blocks = block_pool.get_new_blocks(num_new_blocks)
req_blocks.extend(new_blocks)           # 新 token 写入新 block

block_pool.cache_full_blocks(...)       # 满块写入 hash 索引
block_pool.free_blocks(reversed(blocks))# 请求结束: ref_cnt - 1

KVCacheManager.get_computed_blocks() 会通过 block hash 找到可复用的满块,并返回 computed_blocks 与已计算 token 数(kv_cache_manager.py:92-162)。allocate_slots() 接着检查剩余容量,调用 block_pool.touch(new_computed_blocks) 防止命中的块被驱逐,再把它们追加到当前请求的 req_blockskv_cache_manager.py:226-247)。如果还需要新块,block_pool.get_new_blocks() 从空闲队列取 block,并在取出时 incr_ref()block_pool.py:158-188)。

释放也不是删除 tensor,而是引用计数协议。KVCacheManager.free() 取出该请求持有的 blocks;开启 prefix caching 时按反序释放,让尾部块先成为驱逐候选(kv_cache_manager.py:296-312)。BlockPool.free_blocks() 对每个 block decr_ref();当 ref_cnt == 0 且不是 null_block 时,把它 append 回 free_block_queueblock_pool.py:227-239)。这就是当前源码里能明确看到的共享块生命周期。

4.7.3 COW 的边界情况

部分写入:只有在”追加写到半满共享 block”时才需要真正的 copy。当前 V1 前缀缓存只复用满块,get_computed_blocks() 注释明确说 computed blocks must be full,且 num_computed_tokensblock_size 的倍数(kv_cache_manager.py:94-101:158-161)。因此 prefix caching 的主路径是”共享满块 + 新写入另分 block”,不是频繁复制半块。

前缀缓存驱逐BlockPool 的 free queue 同时保存真正空闲块和可驱逐的 cached block。命中缓存时,touch() 如果发现 ref_cnt == 0,会先把它从 free queue 移除再加引用(block_pool.py:212-225);新分配块时,如果取出的块带有 hash,_maybe_evict_cached_block() 会清掉 hash 索引(block_pool.py:176-183:190-210)。所以源码里的不变量是:正在被请求引用的块不会留在 free queue 里;只有 ref_cnt == 0 的 cached block 才能作为驱逐候选。

跨 TP rank:本章不再写”主 rank 广播 share/free”这类没有源码锚点的细节。可以确定的是,kv_cache_config.num_blocks 会取各 GPU 可用 blocks 的下限,worker 初始化时校验本地 num_blocks >= kv_cache_config.num_blocksgpu_model_runner.py:1709-1716)。至于多 rank 上 block 元数据如何同步,要回到调度器和分布式章节展开,不能在 PagedAttention 章里用一句话代替源码证明。

4.8 一次容量估算练习

容量估算最容易写成”看起来精确、实际上依赖一堆隐藏条件”的数字。这里我们把它写成公式,并明确所有假设。给定:

  • GPU: A100-80GB
  • Model: Llama-3-70B FP16,权重约 140 GB,假设 TP=2 后每卡约 70 GB
  • 非 KV 运行时余量:假设 4 GB(激活、临时 workspace、CUDA graph 等)
  • 每卡 KV 预算:80 - 70 - 4 = 6 GB

KV per token per layer per rank: 2 (K+V) × (num_kv_heads / TP) × head_dim × 2B = 2 × 4 × 128 × 2 = 2 KB (Llama-3-70B 用 GQA, num_kv_heads=8, TP=2 后每 rank 4 个 head)

每 token 总 KV(所有层)= 2 KB × 80 层 = 160 KB

每卡 KV 容量(以 token 计)= 6 × 1024 × 1024 KB / 160 KB39,000 token

block_size=16 分块:39,000 / 16 ≈ 2400 blocks

这个数字决定了 gpu_memory_utilization 配置、能同时服务的请求数、以及前缀缓存的容量预算。但它仍然只是手算。真实 vLLM 会按平台、权重、激活峰值和配置算出 KVCacheConfig,最终落到每层 tensor 的 size 和全局 num_blocks。手算的价值,是让你在配置前就知道数量级是否合理。

一个常用的粗算:

单请求平均 KV token = prompt_avg_len + output_avg_len
= 800 + 256 = 1056 token (假设在线对话场景)

期望并发 32 → 需要 32 × 1056 = 33,792 token ≈ 2112 block

这个数字和上面的 2400 block 接近,说明在这些假设下,32 并发已经接近 KV 预算边界。要提高并发,选项通常是降低平均上下文、提高 TP/PP 改变每卡 KV/权重分布、使用更低精度 KV cache、换更大显存 GPU,或者接受更激进的抢占/排队策略。不能只凭这一个估算就得出”必须换某张卡”的结论。

4.9 V1 的 Attention Backend 抽象

PagedAttention 不是一个孤立的 kernel——它是 V1 “Attention Backend” 抽象的一个实例。vllm/v1/attention/backends/ 里目前有多个 backend:

flash_attn.py          # FlashAttention 后端,常规 CUDA 主路径
flashinfer.py          # FlashInfer paged KV 后端
triton_attn.py         # Triton / paged attention 路径
pallas.py              # TPU/Pallas 后端
mla/common.py          # MLA 共享逻辑
mla/flashmla.py        # FlashMLA 路径
mla/triton_mla.py      # Triton MLA 路径

每个 backend 都要实现同一套接口:

class AttentionBackend(ABC):
    @staticmethod
    def get_name() -> str: ...
    @staticmethod
    def get_impl_cls() -> Type[AttentionImpl]: ...
    @staticmethod
    def get_metadata_cls() -> Type[AttentionMetadata]: ...
    @staticmethod
    def get_kv_cache_shape(num_blocks, block_size, num_kv_heads, head_dim) -> tuple[int, ...]: ...
    # ...

调度器和执行器不依赖具体 backend 的 KV tensor 内部布局。它们操作的是 block id、slot_mapping 和 backend metadata;具体计算由 backend 选择。这意味着同一套 vLLM V1 框架能覆盖不同 kernel 和硬件路径:常规 CUDA 走 FlashAttention 或 Triton,FlashInfer 有自己的 paged KV wrapper,TPU 走 Pallas,DeepSeek 这类 MLA 模型走专用 MLA 后端。

PagedAttention 作为”虚拟内存思想”的这一层抽象,在 backend 层面上依然有效。具体字段名有单复数差异(例如 Pallas metadata 里叫 block_tables,FlashAttention metadata 里叫 block_table),但语义稳定:写入靠 physical slot,读取靠 logical block -> physical block 的表。

4.9.1 实测:vllm/v1/attention/ 8 文件 3115 行的 backend 矩阵

§4.9 抽象讨论 “AttentionBackend 是 V1 的统一抽象层”——把目录按文件实测——

路径角色
vllm/v1/attention/backends/flash_attn.py817FlashAttention 2/3 后端——最常用、本章 §4.5 主参照
vllm/v1/attention/backends/mla/common.py974MLA backend 共享逻辑——本目录最大单文件、专给 DeepSeek-V2/V3 的 Multi-head Latent Attention
vllm/v1/attention/backends/flashinfer.py638FlashInfer 后端——支持更多 attention 变体 + 专用 prefix cache 优化
vllm/v1/attention/backends/pallas.py221Google TPU 后端——跨硬件 paged attention 的具体实现
vllm/v1/attention/backends/triton_attn.py198Triton 纯 Python 实现——fallback 路径
vllm/v1/attention/backends/mla/flashmla.py149MLA + FlashAttention 组合
vllm/v1/attention/backends/mla/triton_mla.py118MLA + Triton 组合
__init__.py × 30公共 export(全空,按需 import)
合计3115

两条值得记住的物理事实——

  1. MLA 子目录 1241 行(common 974 + flashmla 149 + triton_mla 118)= 整个 v1/attention/ 目录的 40%——专给 DeepSeek 一类 MLA 模型。这说明”PagedAttention 只是普通 MHA 的分页版”已经不够描述 V1:同样的 block 管理接口下面,模型结构特殊化会长出相当重的 backend 代码。
  2. pallas.py 221 行是 TPU 后端——Google TPU 的 PagedAttention 实现印证 §4.9 的抽象边界:TPU 这种架构完全不同的硬件,依然能在同一套 block table/slot mapping 语义下工作,只是 page size、shared memory 约束和 kernel 实现不同。

KV Cache 管理对应代码——vllm/v1/core/kv_cache_manager.py 385 + block_pool.py 281 = 666 行——是下一章 ch05 主角;vllm/attention/ops/paged_attn.py 255 行是 V0 遗留的 PagedAttention CUDA kernel 包装——V1 的 paged_attention 调用其实直接通过 backend 文件(如 flash_attn.py 内部)发起、不再走这个独立 ops 文件。

串联 §1.6.1 实测:本目录 3115 行 = vllm/v1/ 21050 行的 14.8%——和 worker (23%) / engine (20%) 同梯队的”重子系统”。这能解释为什么 PagedAttention 不能只按一篇论文里的 CUDA kernel 理解;在 V1 里,它已经变成了调度、KV 管理、backend 抽象共同维护的核心引擎能力。

4.10 本章小结

PagedAttention 是 vLLM 最核心的创新,它的智慧不在于发明新算法,而在于把一个 62 年前的操作系统思想干净地移植到 GPU 显存管理中

  • 问题——传统 KV Cache 分配有三类浪费(预留 / 内部 / 外部碎片);论文里的低利用率测量说明连续分配在动态请求流下会迅速变成显存瓶颈
  • 灵感——操作系统的虚拟内存分页机制;Block↔物理页、Block Table↔页表、COW↔Linux fork 的 COW
  • Block 设计——常规 CUDA 路径默认 16 token / 块;FlashAttention 的典型形状是 5D (2, num_blocks, block_size, num_kv_heads, head_size),但具体布局是 backend 变量(FlashInfer、Pallas、MLA 各不同)
  • 预分配——启动时按 KVCacheConfig 为每层创建 KV tensor;请求路径上的 block 分配只是 BlockPool 元数据操作
  • Block Table——V1 用 CPU/GPU 双缓冲二维张量维护请求到物理块的映射;commit(num_reqs) 后 backend kernel 读取 GPU 侧表
  • PagedAttention 调用链——slot_mapping 决定当前步 K/V 写到哪个物理 slot,block_table 决定 attention 读取历史 KV 的物理块序列
  • Online Softmax——running max + rescale,单趟扫描达到数值稳定的 softmax;累加精度是 kernel 正确性的关键
  • 共享块——前缀缓存通过 block hash 找满块、BlockPool.touch() 增引用、free_blocks() 降引用;论文里的 COW 思想要和当前 V1 主路径区分开
  • Backend 抽象——V1 把 PagedAttention 抽成统一的 AttentionBackend 接口,不同硬件 / 不同 kernel 可插拔,但”虚拟内存”这层抽象跨实现稳定

PagedAttention 论文报告过相对同期系统显著的吞吐提升,但对今天读源码的人来说,更重要的是它为后续优化(前缀缓存、分块预填充、投机解码、MLA、LoRA)奠定了地基:这些优化都需要”灵活地分配、共享、复用 KV”,而这正是 paged 抽象的自然能力。

物理事实:v1/attention/ 8 文件 3115 行(14.8% 是 v1/ 内”重子系统”梯队);MLA 子目录 1241 行(40%)专给 DeepSeek 比 FlashAttention 主路径 817 还重 51%;pallas.py 221 行 TPU 后端是本章源码导航未提的板块、印证 PagedAttention 跨硬件抽象在源码层依然有效。


延伸阅读

  • Kwon et al., “Efficient Memory Management for Large Language Model Serving with PagedAttention”, SOSP 2023 (arXiv:2309.06180)
  • Milakov & Gimelshein, “Online normalizer calculation for softmax”, arXiv:1805.02867
  • Dao et al., “FlashAttention / FlashAttention-2 / FlashAttention-3”, NeurIPS 2022/2023
  • Denning, “Virtual Memory”, ACM Computing Surveys, 1970

源码导航

  • AttentionBackend 抽象基类:vllm/attention/backends/abstract.py
  • FlashAttention paged 后端:vllm/v1/attention/backends/flash_attn.py
  • FlashInfer 后端:vllm/v1/attention/backends/flashinfer.py
  • Triton paged 路径:vllm/v1/attention/backends/triton_attn.py
  • Pallas TPU 后端:vllm/v1/attention/backends/pallas.py
  • MLA 后端共享逻辑:vllm/v1/attention/backends/mla/common.py
  • KV Cache Manager:vllm/v1/core/kv_cache_manager.py
  • BlockPool 引用计数与前缀缓存驱逐:vllm/v1/core/block_pool.py
  • Worker 侧 BlockTable:vllm/v1/worker/block_table.py
  • Block 布局与 get_kv_cache_shapevllm/v1/attention/backends/*