Zing 论坛

正文

CATS:面向内存受限场景的自适应级联树推测解码框架

CATS通过级联适配器架构实现高效推测解码,在保持目标分布的同时显著减少大模型前向传播次数,为内存受限设备上的LLM推理加速提供新思路。

speculative decodingLLM inferencememory optimizationadaptertree speculation边缘计算
发布时间 2026/05/12 23:41最近活动 2026/05/12 23:48预计阅读 3 分钟
CATS:面向内存受限场景的自适应级联树推测解码框架
1

章节 01

【导读】CATS:内存受限场景下LLM推理加速的新框架

CATS(Cascaded Adaptive Tree Speculation)是面向内存受限场景的自适应级联树推测解码框架,通过创新的级联适配器架构实现高效推测解码,在保持目标分布准确性的同时显著减少大模型前向传播次数,为内存受限设备上的LLM推理加速提供新思路。

2

章节 02

背景:大模型推理的内存瓶颈与现有方案局限

随着大型语言模型参数规模不断攀升,推理阶段的显存占用已成为部署的关键瓶颈。传统自回归解码每次生成一个token都需要完整的前向传播,计算成本高昂。推测解码(Speculative Decoding)通过小模型草稿+大模型验证的方式加速,但在内存受限场景下,同时加载两个完整模型往往不可行。

3

章节 03

核心架构:三级级联验证与动态树形推测机制

三级级联验证架构

CATS提出创新的级联适配器架构,仅通过两个轻量级适配器和一个基础模型实现高效推测解码:

  1. 草稿适配器:1层轻量级Llama解码块,附加在基础模型早期层(默认第3层),快速生成候选token树;
  2. 浅层验证器:1层适配器,附加在中间层(第10或15层),初步验证候选以过滤错误token;
  3. 目标模型验证:完整目标模型确认筛选后的token,保留准确性并减少完整前向次数。

动态树形推测机制

支持两种模式:

  • 链式推测:每次推测下一个token,适合简单场景;
  • 树形推测:同时推测多个分支形成token树,通过--tree-topk(如10)和--total-tokens(如40)控制宽度和深度,利用批处理并行验证。
4

章节 04

训练流程:两阶段数据生成与适配器微调

CATS训练分为两阶段:

  1. 数据生成阶段:使用ShareGPT数据集,通过基础模型生成每个样本的早期层、后续层和最终层隐藏状态张量,作为适配器训练的监督信号;
  2. 适配器微调阶段:分别对草稿适配器和浅层验证器微调,使用不同退出层(3层和10/15层),独立训练20个epoch,采用多GPU加速(accelerate框架实现分布式训练)。
5

章节 05

实验评估:CATS在多基准上的性能表现

CATS在MT-Bench、AlpacaEval、GSM8K、HumanEval等基准测试验证,评估脚本CATS_dynamic.py输出指标包括每步平均接受token数、浅层验证器接受率、混淆矩阵(TP/TN/FP/FN)、精确率/召回率/F1分数。实验表明,CATS在保持生成质量的同时,显著减少目标模型完整前向传播次数,特别适合显存受限边缘设备部署。

6

章节 06

技术实现亮点:灵活层级退出与轻量级适配器设计

CATS的技术亮点包括:

  • EarlyExitLlamaForCausalLM:统一forward接口支持任意层范围[start_layer, end_layer)前向传播,共享KV缓存,实现灵活层级退出;
  • AdapterModel:轻量级1层解码器,含可选残差连接和层归一化,从config.json和pytorch_model.bin加载;
  • MultiAdapterCATSModel:统一管理基础模型和两个适配器,实现混合前向评估循环。
7

章节 07

部署建议与未来展望

部署建议

CATS基于Kangaroo代码库构建,依赖PyTorch 2.0.1、Transformers 4.33.3等框架,部署注意:

  1. 适配器目录需包含config.json和pytorch_model.bin;
  2. 训练脚本硬编码路径需根据环境修改;
  3. 评估时测试数据需复制/链接到data/question_mtbench.jsonl

总结与展望

CATS是推测解码实用化的重要一步,通过级联适配器解决内存受限部署难题。未来可探索方向:适配器与更多基础模型兼容性、自适应树形结构动态调整、更多边缘设备性能优化。