vLLM 推理内核深度解析
第16章 LoRA 适配器热切换:一个引擎承载多个微调
第16章 LoRA 适配器热切换:一个引擎承载多个微调
“Don’t change the whole model — just teach it a new trick.” — 所有参数高效微调(PEFT)方法的共同信条
“传统微调像重新装修整座房子,LoRA 像换一幅壁画——墙面不动,屋子立刻变样。”
本章要点
- 回顾 LoRA 的数学原理,理解为什么权重更新在微调中天然是低秩的
- 看清 vLLM LoRA 推理的核心架构:基座权重不变 + 多个低秩旁路叠加,额外成本由 rank、目标层和 active slot 数决定
- 读懂
LoRAModel/LoRAModelManager/WorkerLoRAManager三层抽象的职责划分 - 掌握 punica wrapper、BGMV/SGMV 元数据和 token-to-LoRA 映射在 batched LoRA 中的作用
- 理解
--max-loras和--max-cpu-loras的真实语义:调度并发上限、GPU slot、CPU LRU 容量三者相关但不相同 - 看懂 LoRA 对前缀缓存的影响:为什么
extra_key隔离是必须的、如何设计业务路由来缓解 - 掌握 QLoRA + PagedAttention + 多 LoRA 并发的三重组合部署方案
- 理解 LoRA 的局限:频繁切换开销、前缀缓存命中率下降、并非所有模型都支持
- 拿到三个实战场景的配置思路:RAG 租户切换、多领域 API 服务、开发者自定义 fine-tuning
16.1 LoRA 的数学原理:为什么低秩?
16.1.1 一个反直觉的事实
2022 年 LoRA 论文出来时,很多做微调的工程师第一反应是怀疑:“只训练两个低秩矩阵,真的能替代全量更新吗?”
不是魔法,是一个深刻的经验洞察——大模型在下游任务上需要调整的部分,本来就不是全空间。LoRA(Low-Rank Adaptation,Hu et al., ICLR 2022 / arXiv:2106.09685)的核心假设是:大模型在下游任务上的权重更新 处于低秩子空间中。
16.1.2 标准微调 vs LoRA 微调
标准微调:
和 同形状(比如 4096×4096)。训练时要反传梯度、存优化器状态、写 checkpoint——成本和训练原模型差不多。
LoRA:
把 分解成两个低秩矩阵: 形状 、 形状 ,其中 (典型 )。
→ 参数量只占标准微调的 0.78%。
16.1.3 推理时的叠加式计算
LoRA 的真正威力在于推理时不需要真的合并权重:
是缩放因子(训练时指定)。基座路径和 LoRA 旁路是独立的两条计算路径,最后相加。这个设计有三个关键好处:
- 基座权重不变——多个 LoRA 共享同一份基座;加载一个新 LoRA 只需要加载目标层上的 以及可选 bias / extra vocab
- 推理开销可由 rank 估算——对一个方阵线性层,旁路参数量约为 ,基座参数量为 ,所以 rank 越小,旁路越轻;但真实开销还要看目标层数量、batch 形态和 kernel 实现
- 批量并发友好——一个 batch 内不同请求可以用不同的 对,基座路径仍共享
graph LR
subgraph "LoRA 推理的双路径"
X[输入 x]
X --> BASE[基座 xW<br/>共享主路径]
X --> A_mat[LoRA A<br/>xA: small]
A_mat --> B_mat[LoRA B<br/>xAB: scaled]
BASE --> ADD[+]
B_mat --> ADD
ADD --> Y[输出 Y]
end
style BASE fill:#3b82f6,color:#fff,stroke:none
style A_mat fill:#10b981,color:#fff,stroke:none
style B_mat fill:#10b981,color:#fff,stroke:none
16.1.4 为什么权重更新真的低秩?
这不是任意的工程选择——有理论和经验证据支持:
- 理论角度:微调是在大模型已学得的表示空间里”微调方向”,本身就不需要高维度。大模型的表示能力集中在少数方向上(intrinsic dimension),下游任务只需要调整其中与任务相关的方向
- 经验角度:LoRA 原论文和后续 PEFT 实践都说明,很多下游任务在较小 rank 下就能得到可用效果;但最优 rank 与任务、数据规模、目标层和基座模型有关,不能把某个论文表格里的百分比直接搬成生产保证
这个低秩假设 = LoRA 的全部魔法源头。
16.1.5 rank 选择的经验规则
实战中选择 可以按任务复杂度从小到大试。下面是经验起点,不是质量承诺:
| 任务类型 | 推荐 | 理由 |
|---|---|---|
| 风格迁移(对话、翻译) | 4-16 | 分布漂移小,低秩足够 |
| 领域适配(法律/医学) | 16-32 | 需要学习新术语和结构 |
| 指令微调 | 16-64 | 涉及多样任务 |
| 多任务能力注入 | 64-128 | 接近”小型再训练” |
| 继续预训练式微调 | 不建议用 LoRA | 全量或 QLoRA 更合适 |
选择 的黄金法则:从小往大试。 是经验甜点,一半以上任务这个值就够。
16.2 “一个引擎服务 N 个微调”的经济学
16.2.1 成本对比
在传统推理服务架构下,如果要同时服务 10 个不同领域的微调模型,最朴素的做法是部署 10 套完整权重。这个方案的主要浪费在于:10 个模型的大部分权重完全相同,只是微调后参数略有差异。
LoRA 的容量账应该写成公式,而不是写成固定硬件结论:
完整微调: N * base_weight
LoRA: 1 * base_weight + N * adapter_weight + shared_runtime_state
其中 adapter_weight 近似正比于 rank * target_layers,而不是正比于完整模型参数量。以一个 的线性层为例,全量更新要存 个参数;LoRA 只存 个参数。这个比例可以解释为什么”一个基座 + 多个 adapter”在多租户服务里有吸引力。真正部署时还必须把 KV cache、CUDA graph、LoRA slot 预分配、量化格式、TP/PP 切分和并发目标一起算进去。
graph TB
subgraph "传统方案"
B1[客户 A 专用<br/>完整模型副本]
B2[客户 B 专用<br/>完整模型副本]
B3[...]
B10[客户 J 专用<br/>完整模型副本]
Cost1[权重按租户复制]
end
subgraph "vLLM + LoRA"
BASE[共享基座模型]
L1[LoRA A]
L2[LoRA B]
L3[...]
L10[LoRA J]
BASE --> L1
BASE --> L2
BASE --> L3
BASE --> L10
Cost2[共享基座 + 多 adapter]
end
style Cost1 fill:#ef4444,color:#fff,stroke:none
style Cost2 fill:#10b981,color:#fff,stroke:none
16.2.2 几个行业案例
这种结构在业务上常见于:
| 行业 | 使用场景 | 朴素做法 | LoRA 做法 |
|---|---|---|---|
| 法律 SaaS | 每客户独立合同语料微调 | 每客户一套完整权重 | 共享基座,按客户切换 adapter |
| 代码助手平台 | 按编程语言微调 | 每种语言一套服务 | 共享基座,按语言加载 adapter |
| 电商客服 | 按品牌微调 | 冷门品牌不值得常驻 | 热门 adapter 常驻,冷门 adapter 走 resolver |
| 教育辅导 | 按学科 / 年级微调 | 多套模型分散部署 | 通过路由和 LoRA slot 控制活跃集合 |
这是 LoRA 在 vLLM 里的第一类价值——硬件成本革命。
16.3 vLLM 的 LoRA 三层架构
vLLM 的 LoRA 模块 vllm/lora/ 下分三层:
vllm/lora/
├── request.py # LoRARequest:用户侧的 API
├── layers.py # LoRA 线性层(inject 进模型)
├── models.py # LoRAModel, LoRAModelManager(内存管理)
├── worker_manager.py # WorkerLoRAManager(Worker 进程内协调)
├── resolver.py # LoRAResolver(远程加载)
├── punica_wrapper/ # punica kernel 封装
└── ops/ # 底层 triton / CUDA ops
16.3.1 LoRARequest:用户可见 API
# vllm/lora/request.py
class LoRARequest(msgspec.Struct, omit_defaults=True, array_like=True):
lora_name: str # 逻辑名(调试用)
lora_int_id: int # 数值 ID(引擎内部用)
lora_path: str = "" # 权重路径;__post_init__ 要求非空
lora_local_path: Optional[str] = None # deprecated alias
long_lora_max_len: Optional[int] = None
base_model_name: Optional[str] = None
用户请求里携带 LoRARequest 后,V1 会把它放进 Request 对象、SchedulerOutput 和 Worker 的 InputBatch;OpenAI 入口还可以通过 model 名称触发 resolver,把模型名解析成 LoRA adapter。
为什么区分 lora_name 和 lora_int_id? 前者是字符串,给人看、用于日志和 resolver;后者是整数,进入调度器、prefix-cache hash、InputBatch 映射和 kernel metadata。源码注释还强调:lora_int_id 应该在 adapter 维度全局唯一,但当前 vLLM 不强制检查;这意味着业务层 resolver 或模型注册表必须自己保证 ID 分配不会冲突。
16.3.2 LoRAModel:权重的内存载体
# vllm/lora/models.py(概念性)
class LoRAModel:
"""一个具体 LoRA 的权重集合(A, B 矩阵 + scaling)"""
lora_model_id: int
rank: int
loras: Dict[str, LoRALayerWeights] # module name → weights
# 关键操作
def clone(self, lora_model_id: int) -> "LoRAModel":
"""浅拷贝 —— 权重张量共享,只是给一个新 ID"""
def cast_to(self, dtype, device) -> "LoRAModel":
"""把权重移到指定设备"""
clone 是关键:源码注释写明它会 “share the underlying tensors”。clone(lora_model_id) 返回一个新 LoRAModel 对象和新的整数 ID,但 loras=self.loras.copy() 只是复制字典外壳,底层 LoRALayerWeights 里的张量仍共享。这不是复杂的引用计数系统,而是 Python 对象层面的浅拷贝;它适合 warmup/dummy LoRA 这类需要复用权重对象、但要换 ID 的场景。
from_local_checkpoint() 也值得注意:vLLM 先找 adapter_model.safetensors,找不到再看 adapter_model.bin;它用 safetensors key 反查 unexpected modules,因为 PEFT 配置里可能声明了目标模块,但实际 checkpoint 没有对应权重。加载到 CPU 时,如果 pin memory 可用,lora_a、lora_b、embedding 和 bias 都会 pin 住,为后续搬到 GPU slot 做准备。
16.3.3 LoRAModelManager:激活集合管理
class LoRAModelManager:
def __init__(self, max_num_seqs, max_num_batched_tokens, vocab_size,
lora_config, device):
# 预分配 LoRA 权重缓冲区
self.lora_index_to_id: List[Optional[int]] = [None] * lora_config.max_loras
self._registered_adapters: Dict[int, LoRAModel] = {}
self._active_adapters: OrderedDict[int, None] = OrderedDict()
def activate_adapter(self, lora_id: int) -> bool:
"""把某个 LoRA 放进 active slot(可能触发驱逐)"""
def deactivate_adapter(self, lora_id: int) -> bool:
"""踢出某个 LoRA,但保留在 registered 集合中"""
def remove_adapter(self, lora_id: int) -> bool:
"""从 registered 集合里彻底删除"""
它管理两个维度的集合:
- Registered:所有已注册到系统的 LoRA(内存中可能有,但不一定在 GPU 上)
- Active:当前在 GPU 上准备好、可以参与 forward 的 LoRA
activate_adapter 负责把一个 registered 但未 active 的 LoRA 搬到 GPU 上(可能触发驱逐);deactivate_adapter 把它踢出 active 但保留在 registered(下次再 activate 比重新加载 registered 快)。
16.3.4 WorkerLoRAManager:Worker 侧的协调者
每个 Worker 进程(TP rank)维护一个 WorkerLoRAManager。Scheduler 下发 SchedulerOutput 时附带 lora_requests 字段,WorkerLoRAManager 在 forward 前确保所有需要的 LoRA 都 active。
class WorkerLoRAManager:
def set_active_loras(self, lora_requests: Set[LoRARequest]):
"""确保 lora_requests 里的所有 LoRA 都处于 active 状态。
返回每个 request 对应的 slot index。"""
# 1. 找出当前 active 但不再需要的 → deactivate
# 2. 找出 registered 但未 active 的需要 → activate
# 3. 找出未 registered 的 → 先 load_adapter 再 activate
这段代码的关键不是固定切换耗时,而是先保证合法,再进入 forward。LRUCacheWorkerLoRAManager._apply_adapters() 会先把本轮请求里的 LoRA ID 收集成 map,如果数量超过 lora_slots 直接抛错;add_adapter() 则先从本地路径加载并校验 adapter,加载成功后才在 CPU cache 超容量时驱逐 oldest,最后调用 activate_adapter()。这个顺序避免了”先踢掉旧 adapter,结果新 adapter 加载失败”的坏状态。
V1 runner 侧还有一层映射:InputBatch.make_lora_inputs() 返回三件东西:每个 request 的 prompt_lora_mapping、按本步 token 展开的 token_lora_mapping,以及活跃的 LoRARequest 集合。LoRAModelRunnerMixin._set_active_loras() 把这两个 mapping 包成 LoRAMapping,再交给 WorkerLoRAManager.set_active_adapters()。也就是说,kernel 看到的不是原始请求对象,而是已经压平成 token/request 维度的整数 LoRA ID。
16.3.5 三层职责总结
一张表梳理清楚:
| 层 | 位置 | 职责 | 典型调用方 |
|---|---|---|---|
LoRARequest | Client 接口 | 描述”我想用哪个 LoRA” | 用户代码 |
LoRAModelManager | Engine 进程 | 注册表 + GPU 激活集合管理 | Engine Core |
WorkerLoRAManager | Worker 进程 | forward step 前确保 active | Model Runner |
这三层让用户代码和 kernel 代码完全解耦——用户只需要填 LoRARequest,不用关心什么时候 activate。
16.4 punica 与分组 GEMM:batched LoRA 的性能秘密
batched LoRA 的核心挑战:一个 batch 里不同请求用不同 LoRA,如何高效计算?
16.4.1 朴素方案的代价
最直白的做法——循环遍历所有 LoRA:
output = base_forward(x, W) # 基座
for lora_id in unique_loras_in_batch:
indices = [i for i, r in enumerate(batch) if r.lora_id == lora_id]
output[indices] += scaling * (x[indices] @ A[lora_id]) @ B[lora_id]
问题:每个 LoRA 的那次 GEMM 都是小 batch size 的 GEMM(因为一个 LoRA 通常只对应 batch 里几个请求),GPU 利用率极低。
16.4.2 punica 论文的关键洞察
Punica(Chen et al., MLSys 2024 / arXiv:2310.18547)提出:把所有 LoRA 的 A 矩阵”堆”成一个三维张量,用一次分组 GEMM(Grouped GEMM) 完成所有 LoRA 旁路:
# A_stacked: [num_loras, r, d]
# B_stacked: [num_loras, d, r]
# lora_indices: [batch] → 每个请求对应的 lora id
# 所有 LoRA 旁路在一个 kernel 里完成
output = punica_sgmv(x, A_stacked, B_stacked, lora_indices, scaling)
output += base_forward(x, W)
这就是 SGMV (Segment Grouped Matrix-Vector) kernel——给定一个 token 到 LoRA 的映射,把不同 LoRA 的 GEMM 融合在一个 kernel 里,每个 thread block 处理一段连续的同 LoRA token,组间并行。
16.4.3 BGMV vs SGMV
vLLM 里有两种 punica kernel:
- BGMV (Batched Grouped Matrix-Vector):适合 decode 阶段——每个请求只有 1 个新 token。kernel 里每个 thread block 处理一个请求 + 一个 LoRA。
- SGMV (Segment Grouped Matrix-Vector):适合 prefill 阶段——每个请求有多个 token 连续属于同一 LoRA。kernel 用 segment reduction 的思路处理。
两种 kernel 的选择是自动的(vllm/lora/punica_wrapper/),用户感知不到。
16.4.4 源码里的 punica wrapper
本地源码能直接确认的是 wrapper 的职责,而不是某个固定 latency 表。punica_base.py 把接口拆成五类:
| 方法 | 作用 |
|---|---|
update_metadata() | 根据 LoRAMapping、lora_index_to_id、vocab 信息和 long-context LoRA 状态更新 kernel 元数据 |
add_shrink() | 对多个 lora_a slice 做 shrink,即 LoRA A 方向的低秩投影 |
add_expand() | 对多个 lora_b slice 做 expand,并按 output slice 写回 |
add_lora_embedding() | 给 embedding 类 LoRA 做 expand |
add_lora_linear() / add_lora_logits() | 给 linear / logits processor 的 LoRA 入口提供统一封装 |
PunicaWrapperBase 预分配了 token 级和 prompt 级的索引张量:_token_lora_indices、_sampler_indices、_sampler_indices_padded、_embeddings_indices、_long_lora_indices。它还维护 SGMV 需要的 seq_start_locs、seq_lengths、lora_indices_per_batch。这解释了为什么前面的 InputBatch.make_lora_inputs() 要把 LoRA ID 展开到 token 维度:kernel 不想拿 Python 请求对象,它只需要一组紧凑的索引张量。
GPU 实现 punica_gpu.py 又补了一层设备相关逻辑:它创建 LoRAKernelMeta,update_metadata() 时调用 prepare_tensors(),add_shrink() 和 add_expand() 分别调 lora_shrink / lora_expand Triton op。CPU/HPU 有自己的 wrapper,选择逻辑在 punica_selector.py。这是一种典型的 vLLM 风格:上层保持统一接口,下层按硬件后端切实现。
16.4.5 性能该怎么测
本章不写”某模型某 batch 延迟多少毫秒”,因为这类数字对硬件、rank、目标层、batch 混合、是否 prefill、是否 CUDA graph、是否 fully sharded、量化基座都敏感。更可靠的判断方式是把测试拆成四组:
| 对照组 | 用途 |
|---|---|
| 基座无 LoRA | 得到纯 base model 的 TPOT/TTFT 基线 |
| 单 LoRA 常驻 | 测 adapter 常驻时旁路计算的额外成本 |
| 多 LoRA 混合 batch | 测 token-to-LoRA mapping 和 grouped kernel 的成本 |
| 高频切换 / 冷加载 | 测 resolver、本地加载、CPU cache、GPU slot 激活的尾延迟 |
如果多 LoRA 混合 batch 明显慢,先看三件事:max_loras 是否逼近上限导致调度跳过请求,rank 是否过大,是否启用了 fully_sharded_loras 且当前 TP/序列长度真的适合它。LoRAConfig 的注释说 fully sharded 在高序列长度、高 rank 或高 TP size 时可能更快,这是一条调参方向,不是默认无脑开启。
16.5 双层缓存:GPU active + CPU pinned
16.5.1 加载成本的不对称性
LoRA 的加载成本不对称:
- 已 active:slot 已经写入
lora_index_to_id,forward 前只需要更新 mapping 元数据 - 已 registered 但未 active:adapter 已在 manager 缓存里,
activate_adapter()把权重写入空闲或被 LRU 腾出的 GPU slot - 未 registered 但本地路径存在:
_load_adapter()读取 PEFT 配置和 checkpoint,校验 rank、extra vocab、目标模块,再加入 manager - 远程或动态名称:OpenAI 入口先走
LoRAResolverRegistry里的 resolver,把名称解析成LoRARequest,然后调用engine_client.add_lora()
vLLM 把这分成了两层缓存:
graph TB
subgraph "GPU 侧 (最快)"
Active[Active LoRAs<br/>max_loras 个槽位]
end
subgraph "CPU 侧 (中等速度)"
CPU[Registered LoRAs<br/>max_cpu_loras 容量]
end
subgraph "磁盘 / 远程 (最慢)"
Disk[本地 / S3 / HF Hub]
end
Disk -->|resolver / local load| CPU
CPU -->|activate_adapter| Active
Evict1[GPU LRU 驱逐] --> CPU
Evict2[CPU LRU 驱逐] --> Disk
style Active fill:#10b981,color:#fff,stroke:none
style CPU fill:#3b82f6,color:#fff,stroke:none
style Disk fill:#94a3b8,color:#fff,stroke:none
两个关键参数:
--max-loras:GPU 上同时 active 的 LoRA 数,决定一个 batch 里最多能处理多少种 LoRA--max-cpu-loras:manager 里可注册/缓存的 LoRA 数;未显式设置时,LoRAConfig.__post_init__()会把它设成max_loras,且它必须大于等于max_loras
16.5.2 真实实现:GPU 侧定长 slot 数组 + CPU 侧 LRU(不是”两层 LRU”)
这一层和常见描述不太一样。打开 vllm/lora/models.py,真实架构是两层机制不对称:GPU 侧是定长 slot 数组(不做自动驱逐)、CPU 侧才是真正的 LRU。分两段看源码:
GPU 侧——LoRAModelManager(基类,models.py:304):
# models.py:332
self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
@property
def lora_slots(self) -> int:
return self.lora_config.max_loras # GPU 槽位数
@property
def capacity(self) -> int:
return self.lora_config.max_cpu_loras # CPU 总容量
# models.py:330
assert self.capacity >= self.lora_slots # CPU ≥ GPU 的硬约束
GPU 侧的状态就是一个定长 List[Optional[int]]——None 代表空槽、整数代表”这个 LoRA id 占据此 slot”。activate_adapter(line 381)找第一个 None 填入:
first_free_slot = next(
((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id)
if lora_id is None), None)
if first_free_slot is None:
raise ValueError("No free lora slots")
找不到空槽就直接抛 ValueError——基类根本不做自动驱逐。这是刻意的设计:基类假设调用方(调度器)自己负责 slot 管理。为什么这么设计?因为”哪个 LoRA 应该被驱逐”高度依赖具体策略(LRU?LFU?按优先级?),基类不想把策略写死。
CPU 侧——LRUCacheLoRAModelManager(子类,models.py:711)加上 LRU:
class LRUCacheLoRAModelManager(LoRAModelManager):
def __init__(self, ...):
super().__init__(...)
# CPU 缓存:真 LRU,有 deactivate_adapter 回调
self._registered_adapters: LoRALRUCache = LoRALRUCache(
self.capacity, self.deactivate_adapter)
# GPU "活跃表":也是 LRU 但容量等于 lora_slots
self._active_adapters: LoRALRUCache = LoRALRUCache(
self.lora_slots, self._deactivate_adapter)
def activate_adapter(self, lora_id: int) -> bool:
# 关键:子类在调基类之前自己主动驱逐
if lora_id not in self._active_adapters and len(
self._active_adapters) >= self.lora_slots:
self._active_adapters.remove_oldest()
result = super().activate_adapter(lora_id)
self._active_adapters.touch(lora_id)
return result
子类用两个 LoRALRUCache 叠加机制:
_registered_adapters:CPU 侧真 LRU,容量max_cpu_loras。deactivate_adapter回调会在元素被 LRU 挤出时触发——这是”CPU 满了、把最老的彻底卸掉”的自动通路。_active_adapters:GPU 侧的并行 LRU,容量max_loras。它和基类的lora_index_to_idslot 数组并存——activate_adapter里先remove_oldest()腾一个 slot、再super().activate_adapter()走基类的 “找空 slot 填入”。
这种分层设计有一个优雅的特点:基类的 lora_index_to_id 数组是真正承载 GPU 上 kernel 所需 slot index 的状态(punica BGMV kernel 要按固定 index 索引 A/B stacked tensor),_active_adapters LRU 只是辅助的元信息用来决定下次驱逐谁。数据平面(slot 数组)和控制平面(LRU 结构)分离——基类管数据平面、子类管控制平面。
math.ceil(max_num_batched_tokens / 8) * 8 的对齐(line 331)是另一个容易漏看的细节——内部用的 max_num_batched_tokens 会被向上圆整到 8 的倍数。源码没有在这里解释性能原因,所以本章只保留事实:LoRA manager 内部按 8 对齐,外部 scheduler 配置仍是原始 token 上限。
生产经验:如果你看到 "No free lora slots" 报错、通常不是基类 bug、而是你在代码里直接用了 LoRAModelManager 基类而非 LRUCacheLoRAModelManager 子类——需要手动管理 slot 或者切到子类让它自动 LRU。
16.5.3 调参建议
max_loras 设置原则:实际同时高频使用的 LoRA 数,不是”总共注册了多少 LoRA”。
- 如果你有 100 个 LoRA 但每段时间只有 3 个在用:
max_loras=4(留一个余量) - 如果有 8 个客户同时都在用各自 LoRA:
max_loras=8 - 不要为了”稳妥”设太大——每个 active LoRA 都占 GPU 预分配 buffer,降低可用 KV
16.5.4 冷启动 warmup
生产部署中,如果你知道启动后前几分钟流量会打在某几个 LoRA 上,最好主动预热:
# 启动脚本里
for lora_name in ["customer_a", "customer_b", "customer_c"]:
await engine.add_lora(LoRARequest(lora_name, ...))
第一个真实请求到达时,这些 LoRA 已经通过 add_lora() 进入 manager;是否已经 active 取决于后续请求映射和 slot 状态。预热的目标是把路径解析、PEFT 校验、本地 checkpoint 加载这些不稳定步骤前移,而不是保证某个固定毫秒级收益。
16.6 LoRA 与前缀缓存的冲突
这是 LoRA 推理最容易被忽视的副作用:LoRA 破坏前缀缓存的跨请求共享。
16.6.1 冲突原理
前缀缓存的基础是”相同 token 序列产生相同 KV”。但 LoRA 改变了模型中被注入层的输出,不同 LoRA 下相同 prompt 会产生不同 KV。
因此每个 block 的 hash 必须包含 LoRA 标识:
# vllm/v1/core/kv_cache_utils.py(摘录语义)
def need_extra_keys(request):
return bool(request.mm_positions) or (request.lora_request is not None)
def _gen_lora_extra_hash_keys(request):
if not request.lora_request:
return []
return [request.lora_request.lora_int_id]
def generate_block_hash_extra_keys(request, ...):
lora_extra_keys = _gen_lora_extra_hash_keys(request)
extra_keys = lora_extra_keys + mm_extra_keys
return tuple(extra_keys) if extra_keys else None
结果:使用 LoRA-A 的请求和使用 LoRA-B 的请求,即使 token 内容完全一样,block hash 的 extra keys 也不同,不能共享同一份 prefix cache block。这个设计是正确性要求,不是保守优化。
16.6.2 数量分析
假设系统提示是固定的 system prompt,正常情况下多个请求可以共享同一份 prompt KV。加了 LoRA 后,缓存隔离粒度会多出 LoRA ID:
- 同一个 LoRA 内部:相同 prompt 仍可能共享;
- 不同 LoRA 之间:相同 prompt 不共享;
- 活跃 LoRA 越多,prefix cache 被按 adapter 分片得越细;
- 如果每个请求都换不同 LoRA,prefix cache 对跨请求共享的帮助会明显下降。
16.6.3 缓解策略
业务层路由:前端 Load Balancer 按 LoRA ID 路由——同 LoRA 的请求尽量送同一个副本。这样单副本内部的 prefix cache 对该 LoRA 的命中率能保持。
减少 LoRA 数量:如果可能,把相似的 LoRA 合并成更少的大 LoRA。
接受 trade-off:多数 B 端场景下,LoRA 带来的多租户能力 > prefix cache 命中率损失。评估之后果断选择 LoRA 路径。
16.6.4 权衡决策表
帮你快速判断该不该用 LoRA:
| 场景特征 | 推荐 LoRA | 理由 |
|---|---|---|
| 10+ 租户、每租户独立微调 | ✓ 强烈推荐 | 硬件成本降至 1/N |
| 高共享 system prompt、少量 LoRA | ✗ 谨慎 | prefix cache 损失可能 > LoRA 收益 |
| 在线用户只有 1-2 租户 | ✗ 不推荐 | 不如直接部署专用模型 |
| 需要动态加载新微调 | ✓ 必须 | 冷切换场景 LoRA 是唯一方案 |
| 每请求都切换 LoRA | ✗ 避免 | 切换开销累积 |
16.7 QLoRA:量化基座 + FP16 LoRA
16.7.1 组合思路
QLoRA(Dettmers et al., NeurIPS 2023 / arXiv:2305.14314)是训练时的概念——基座用 4-bit 量化,LoRA 旁路保持 FP16/BF16。vLLM 推理时也支持这个组合:
vllm serve meta-llama/Llama-3-70B \
--quantization gptq \
--enable-lora \
--max-loras 8 \
--max-lora-rank 64 \
--lora-modules \
customer_a=/models/lora/a \
customer_b=/models/lora/b \
...
效果:
量化基座: 降低 base_weight
FP16/BF16 LoRA: 保留 adapter 的低秩旁路
KV cache: 仍按并发、上下文长度、dtype 和 block 策略消耗显存
这类组合的价值是把”基座权重”和”adapter 权重”分别优化:基座靠量化省显存,adapter 靠低秩省显存。能不能放进一张卡,要用目标模型、量化格式、max_loras、max_lora_rank、KV cache dtype、最大上下文和并发一起算,不能只看参数名。
16.7.2 精度损失
很多人担心”4-bit 量化精度够吗?“答案取决于模型、量化方法和任务。不要把某个评测集上的百分比泛化成产品承诺:
| 对比 | 应该验证什么 |
|---|---|
| FP16 原模型 → 量化基座 | 通用能力、长上下文、工具调用或结构化输出是否退化 |
| FP16 原模型 + LoRA → 量化基座 + LoRA | adapter 学到的领域能力是否仍然有效 |
| 数学、代码、严格格式任务 | 小概率错误、边界条件和格式遵循是否变差 |
如果服务的是开放式对话,用户可能更能容忍轻微差异;如果服务的是数学、代码、合规审查或结构化抽取,就必须做离线集和线上 A/B。vLLM 能提供推理组合能力,但不替你证明质量。
16.7.3 注意事项
- LoRA rank 高时表达能力更强,但
max_lora_rank会影响预分配 buffer,且本地LoRAConfig只允许(8, 16, 32, 64, 128, 256, 320, 512)这些 rank 上限 - 量化基座的 kernel 与 LoRA 旁路会在同一 forward 路径里协作;是否有明显开销要按目标模型测试
--kv-cache-dtype fp8能降低 KV cache 占用,但需要额外验证质量和硬件支持,不能和 LoRA 显存节省混为一谈
16.8 LoRA 解析器:从 S3 / HuggingFace 热加载
16.8.1 Resolver 抽象
企业场景下 LoRA 通常存在对象存储。vLLM 的 LoRAResolver 提供了可插拔的远程加载:
# vllm/lora/resolver.py
class LoRAResolver(ABC):
@abstractmethod
async def resolve_lora(
self, base_model_name: str, lora_name: str
) -> Optional[LoRARequest]:
"""给定 lora_name,返回可以加载的 LoRARequest;找不到则返回 None。"""
@dataclass
class _LoRAResolverRegistry:
resolvers: Dict[str, LoRAResolver] = field(default_factory=dict)
def register_resolver(self, resolver_name: str, resolver: LoRAResolver):
self.resolvers[resolver_name] = resolver
def get_resolver(self, resolver_name: str) -> LoRAResolver:
...
本地源码里没有内置 S3LoRAResolver 或 HuggingFaceLoRAResolver 具体类;只有抽象接口和 registry。OpenAI serving 入口启动时会遍历 LoRAResolverRegistry.get_supported_resolvers(),把已注册 resolver 放进 self.lora_resolvers;收到动态 LoRA 名称时,serving_models.py 的 resolve_lora() 会先检查是否已加载,再逐个调用 resolver,找到后给它分配新的 lora_int_id,调用 engine_client.add_lora() 并加入本地 lora_requests 列表。
所以生产部署的真实扩展点是:你在自己的启动代码或插件里注册一个 resolver,它可以从 S3、内部模型仓库、HF Hub 镜像或本地缓存解析 adapter,但这不是 vLLM 当前源码里已经写好的具体类。
16.8.2 自定义 Resolver 的三个关键点
企业内常见需求:
- 鉴权:把 request 的
auth_token传入 Resolver,在 Resolver 里校验 - 多版本:
lora_name可能带版本(customer_a:v2),Resolver 做版本解析 - 本地缓存:避免每次都从对象存储下载,用 LRU 本地文件缓存
示意:
class CompanyLoRAResolver(LoRAResolver):
def __init__(self, cache_dir="/cache/lora", max_cache_gb=100):
self.cache_dir = cache_dir
self.cache = LRUDiskCache(cache_dir, max_cache_gb * 1e9)
async def resolve_lora(self, base_model_name, lora_name, auth_token=None):
# 鉴权
tenant_id = await verify_auth(auth_token)
if not await tenant_can_use_lora(tenant_id, lora_name):
raise PermissionError()
# 多版本
if ":" in lora_name:
name, version = lora_name.split(":")
else:
name, version = lora_name, "latest"
# 本地缓存
cache_key = f"{name}:{version}"
local = self.cache.get(cache_key)
if local:
return LoRARequest(cache_key, new_id(), local)
# 远程拉取
remote = f"s3://company-loras/{name}/{version}/"
local = await s3_download(remote, self.cache_dir)
self.cache.put(cache_key, local)
return LoRARequest(cache_key, new_id(), local)
16.9 LoRA 的四个局限
局限 1:频繁切换的累积开销
当请求模式高频切换(每个请求不同 LoRA),系统会更频繁地做 resolver、本地加载、CPU cache LRU、GPU slot 激活和 metadata 更新。解决思路不是迷信某个固定延迟,而是让同一批次内的 LoRA 种类数受控:调大 max_loras、按 LoRA ID 路由,或者把低频 adapter 放到单独副本。
局限 2:前缀缓存命中率下降
如 16.6 所述,不同 LoRA 的 KV 不共享。LoRA 越多,prefix cache 越不值钱。严重时要评估 LoRA 路径是否合算。
局限 3:并非所有模型都支持
V1 runner 在 load_lora_model() 里先调用 supports_lora(model);不支持就直接抛 ValueError。多模态模型还有额外限制:日志会提示当前只支持把 LoRA 加到 language model 部分。判断某个模型能不能用 LoRA,最稳妥的方法是看模型实现是否声明支持 LoRA,以及 vllm/lora/utils.py 能否为它的线性层创建对应 wrapper。
局限 4:rank 上限约束
--max-lora-rank 决定 GPU 预分配缓冲区。实际 LoRA 的 rank 超过这个值就无法加载。典型场景:
- 训练时 rank=64,推理时
--max-lora-rank 64OK - 训练时 rank=128,推理必须设
--max-lora-rank >= 128
提前和训练团队对齐。
16.9.5 总结:LoRA 不是银弹
把四条局限记下来:
| 局限 | 缓解方法 |
|---|---|
| 切换开销 | max_loras 调大 + 业务路由 |
| prefix cache 下降 | 租户路由 + 接受 trade-off |
| 模型兼容性 | 事先确认,或改用主流模型 |
| rank 上限 | 训练推理参数提前对齐 |
16.10 三类生产场景的配置
16.10.1 RAG 多租户(B 端 SaaS)
每个客户有自己的 LoRA,训练数据私有:
vllm serve meta-llama/Llama-3-70B-Instruct \
--quantization gptq \
--enable-lora \
--max-loras 16 \
--max-cpu-loras 64 \
--max-lora-rank 64 \
--max-num-seqs 64 \
--enable-chunked-prefill
max_loras=16 表示单个 batch 最多容纳 16 种 LoRA,也对应 GPU active slot 上限;max_cpu_loras=64 给 registered adapter 留更大的 CPU cache。动态 S3/内部仓库解析不是一个内置 --lora-resolver s3 参数,而是需要在服务进程启动前向 LoRAResolverRegistry 注册自定义 resolver。
16.10.2 多领域 API 服务
一个开放 API,按领域切换 LoRA(代码、法律、医学、翻译):
vllm serve Qwen/Qwen2.5-72B-Instruct \
--enable-lora \
--max-loras 4 \
--max-cpu-loras 8 \
--lora-modules \
code=/models/lora/code-v3 \
legal=/models/lora/legal-v2 \
medical=/models/lora/medical-v1 \
translate=/models/lora/translate-v5
预加载 4 个常用 LoRA;用户在 API 里通过 model 字段指定。
16.10.3 开发者自定义微调(Playground)
允许用户自己上传 LoRA 做实验:
vllm serve meta-llama/Llama-3-8B-Instruct \
--enable-lora \
--max-loras 2 \
--max-cpu-loras 16 \
--max-lora-rank 128
max_loras 小,表示单批同时服务的 adapter 种类少;max_lora_rank=128 允许较高 rank 的实验 adapter。用户上传路径仍然要通过业务侧 resolver 或静态 --lora-modules 管理,vLLM 当前没有内置 --lora-resolver user_upload 这类 CLI。
16.10.4 三场景对比
| 参数 | RAG 多租户 | 多领域 API | Playground |
|---|---|---|---|
| max_loras | 16 | 4 | 2 |
| max_cpu_loras | 64 | 8 | 16 |
| max_lora_rank | 64 | 32 | 128 |
| resolver | 自定义 S3/仓库 resolver | 本地预加载 | 用户上传 resolver |
| 典型并发 | 由 max_num_seqs 和 max_loras 共同限制 | 中 | 低 |
| 切换频率 | 中 | 低 | 高 |
16.10.5 实测:vllm/lora/ 6027 行 + V1 集成 145 行的真实分布
把整个 vllm/lora/ 目录按子模块实测——
| 路径 | 行 | 角色 |
|---|---|---|
vllm/lora/layers.py | 1263 | 本目录最大——LoRA 层集成(ColumnParallelLoRA / RowParallelLoRA / EmbeddingLoRA / 对应基础 layer 的 wrapper) |
vllm/lora/models.py | 802 | LoRAModel + LoRAModelManager(§16.3.2-16.3.3 的实现) |
vllm/lora/punica_wrapper/punica_base.py | 483 | PunicaWrapperBase 抽象——add_shrink / add_expand / add_lora 等 §16.4 BGMV/SGMV 接口 |
vllm/lora/punica_wrapper/punica_cpu.py | 348 | CPU fallback 实现(无 GPU 环境用) |
vllm/lora/fully_sharded_layers.py | 335 | TP sharded 版本的 LoRA layer |
vllm/lora/punica_wrapper/punica_gpu.py | 289 | GPU 实现——调 ops/triton_ops/ 里的 kernel |
vllm/lora/ops/triton_ops/lora_expand.py | 293 | Expand kernel(BGMV expand:B 矩阵 outer-product 部分) |
vllm/lora/ops/triton_ops/lora_shrink.py | 247 | Shrink kernel(BGMV shrink:A 矩阵 inner-product 部分) |
vllm/lora/ops/triton_ops/kernel_utils.py | 243 | 共用 Triton kernel 工具 |
vllm/lora/worker_manager.py | 251 | WorkerLoRAManager(§16.3.4) |
vllm/lora/utils.py | 237 | 工具函数(rank 验证 / 路径解析等) |
vllm/lora/lora.py | 198 | LoRA 基础数据类 |
vllm/lora/peft_helper.py | 115 | HuggingFace PEFT 格式适配 |
vllm/lora/request.py | 97 | LoRARequest(§16.3.1) |
vllm/lora/resolver.py | 83 | Resolver 抽象(§16.8) |
其余(punica_hpu.py / punica_selector.py / punica_wrapper/utils.py / __init__.py / ops 内部文件) | 余下 | — |
vllm/lora/ 合计 | 6027 | — |
vllm/v1/worker/lora_model_runner_mixin.py | 145 | V1 worker 的 LoRA mixin 集成层(§1.6.1 实测 v1/worker/ 4851 行的 3%) |
两条值得记住的物理事实——
layers.py1263 行 +fully_sharded_layers.py335 行 = 1598 行专门给 LoRA 层 wrapper——是整个目录 27%——因为 vLLM 要为每种基础 layer(ColumnParallelLinear / RowParallelLinear / VocabParallelEmbedding 等)写一个对应的 LoRA 包装层、再加 TP sharded 版本——这是”基座 N 种 layer × 是否 TP sharded 2 种 = 2N 个 wrapper”的 笛卡尔积代价——印证 §16.3 “三层架构” 中”LoRA 层集成”是工程上最重的一块、不是 §16.3 重点讨论的 LoRARequest/LoRAModel 那种数据类- punica 跨 4 种硬件后端——
punica_wrapper/有base / cpu / gpu / hpu4 个实现——通过punica_selector.py运行时挑选——和 §1.6.1 实测vllm/distributed/device_communicators/13 文件的多后端模式同款——是 vllm “核心算法 + 多硬件适配”的一致设计纪律;GPU/CPU/HPU 三个 punica 实现合计 ~1100 行——分发开销在 selector 几十行里
串联 §15.10.4 多模态 4738 行 + §14.6.5 distributed 2483 行 + 本节 lora 6027 行 = ~13000 行——是 vLLM 三大水平扩展能力(多模态 / 分布式 / LoRA)的真实工程量秤砣。
16.11 本章小结
LoRA 在 vLLM 里不是”一个 feature flag”,而是一组贯穿请求、调度、worker、kernel 和缓存的机制:
- 数学原理:权重更新低秩假设 ; 让参数量降到 1%
- 推理叠加: 双路径;基座不变 + 小旁路
- 经济学:共享基座、按需加载 adapter,容量账从
N * base_weight变成base_weight + N * adapter_weight + runtime state - 三层架构:
LoRARequest用户 API /LoRAModel+Manager内存管理 /WorkerLoRAManager执行侧协调 - punica wrapper:用 token/prompt LoRA mapping 和 BGMV/SGMV 元数据把多 adapter 旁路交给后端 kernel
- 缓存与 slot:
max_loras控制单 batch LoRA 种类和 GPU slot,max_cpu_loras控制 registered cache 容量,LRU 子类负责驱逐策略 - 前缀缓存冲突:LoRA 破坏跨请求 KV 共享;
extra_key隔离;业务路由缓解 - QLoRA 组合:量化基座 + FP16/BF16 LoRA,必须按目标模型和质量集验证
- LoRA Resolver:抽象接口 + registry;S3 / HF Hub / 内部仓库需要业务侧注册具体 resolver
- 四个局限:切换开销、prefix cache 下降、模型兼容性、rank 上限
一句话记忆:
LoRA 把”N 个完整权重副本”改成”1 个基座 + N 个低秩 adapter”;vLLM 的价值,是把这个数学结构接到调度、缓存、worker 和 kernel 的在线推理链路里。
物理事实:vllm/lora/ 6027 行(layers.py+fully_sharded_layers.py 1598 行 27% 是基座 N 种 layer × TP 是否 sharded 2 种的笛卡尔积代价)+ V1 集成 145 行;punica 跨 4 种硬件(base/cpu/gpu/hpu)同款多后端模式;串联多模态 4738 + distributed 2483 + lora 6027 = ~13000 行 vLLM 三大水平扩展能力工程秤砣。
源码导航
- LoRA 请求:
vllm/lora/request.py- LoRA 模型:
vllm/lora/models.py- LoRA 层:
vllm/lora/layers.py- WorkerManager:
vllm/lora/worker_manager.py- Resolver:
vllm/lora/resolver.py- Punica kernels:
vllm/lora/punica_wrapper/+vllm/lora/ops/- V1 的 LoRA 集成:
vllm/v1/worker/lora_model_runner_mixin.py论文
- Hu et al., “LoRA: Low-Rank Adaptation of Large Language Models”, ICLR 2022 (arXiv:2106.09685)
- Dettmers et al., “QLoRA: Efficient Finetuning of Quantized LLMs”, NeurIPS 2023 (arXiv:2305.14314)
- Chen et al., “Punica: Multi-Tenant LoRA Serving”, MLSys 2024 (arXiv:2310.18547)
- Sheng et al., “S-LoRA: Serving Thousands of Concurrent LoRA Adapters”, 2023 (arXiv:2311.03285)