Zing Forum

CATS: An Adaptive Cascaded Tree Speculative Decoding Framework for Memory-Constrained Scenarios

CATS achieves efficient speculative decoding via a cascaded adapter architecture, significantly reducing the number of forward passes of large models while preserving the target distribution, providing new insights for accelerating LLM inference on memory-constrained devices.

Tags: speculative decoding, LLM inference, memory optimization, adapter, tree speculation, edge computing
Published 2026-05-12 23:41 · Recent activity 2026-05-12 23:48 · Estimated read 8 min

Section 01

[Introduction] CATS: A New Framework for Accelerating LLM Inference in Memory-Constrained Scenarios

CATS (Cascaded Adaptive Tree Speculation) is an adaptive cascaded tree speculative decoding framework for memory-constrained scenarios. It achieves efficient speculative decoding through an innovative cascaded adapter architecture, significantly reducing the number of forward passes of large models while preserving the accuracy of the target distribution, providing new insights for accelerating LLM inference on memory-constrained devices.


Section 02

Background: Memory Bottlenecks in Large Model Inference and Limitations of Existing Solutions

As the parameter scale of large language models continues to rise, memory usage during inference has become a key bottleneck for deployment. Traditional autoregressive decoding requires a complete forward pass for each token generated, leading to high computational costs. Speculative Decoding accelerates inference by using a small model to draft tokens and a large model to verify them, but in memory-constrained scenarios, loading two complete models simultaneously is often infeasible.


Section 03

Core Architecture: Three-Level Cascaded Verification and Dynamic Tree Speculation Mechanism

Three-Level Cascaded Verification Architecture

CATS proposes an innovative cascaded adapter architecture, achieving efficient speculative decoding with only two lightweight adapters and one base model:

  1. Draft Adapter: A lightweight 1-layer Llama decoder block attached to an early layer of the base model (layer 3 by default) that quickly generates candidate token trees;
  2. Shallow Validator: A 1-layer adapter attached to the middle layers (layer 10 or 15) to preliminarily verify candidates and filter out incorrect tokens;
  3. Target Model Verification: The complete target model confirms the filtered tokens, preserving accuracy while reducing the number of full forward passes.
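The division of labor above can be made concrete with a minimal toy sketch, assuming integer "tokens" whose true continuation is always last + 1; the draft rule, the filter rule, and the verification rule below are illustrative stand-ins, not the actual CATS implementation.

```python
def draft_chain(last, k):
    """Level 1 (draft adapter): propose k tokens; deliberately buggy at multiples of 4."""
    out = []
    for _ in range(k):
        last = last if last % 4 == 0 else last + 1   # toy draft error
        out.append(last)
    return out

def shallow_filter(last, chain):
    """Level 2 (shallow validator): keep only the strictly increasing prefix."""
    kept = []
    for t in chain:
        if t <= last:
            break
        kept.append(t)
        last = t
    return kept

def target_verify(last, chain):
    """Level 3 (target model): accept the prefix matching last+1, last+2, ..."""
    accepted = []
    for t in chain:
        if t != last + 1:
            break
        accepted.append(t)
        last = t
    return accepted

def generate(start, n_tokens, k=3):
    """Generate n_tokens while counting the full target passes actually spent."""
    seq, full_passes = [start], 0
    while len(seq) < n_tokens + 1:
        last = seq[-1]
        survivors = shallow_filter(last, draft_chain(last, k))
        full_passes += 1                  # one full pass verifies the whole chain
        accepted = target_verify(last, survivors) or [last + 1]
        seq += accepted
    return seq[: n_tokens + 1], full_passes
```

In this toy run, twelve tokens cost only six full passes instead of twelve; in CATS, the acceptance rate of the cascade determines the realized reduction in full forward passes.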

Dynamic Tree Speculation Mechanism

Supports two modes:

  • Chain Speculation: speculates a single next-token chain per step, suitable for simple scenarios;
  • Tree Speculation: speculates multiple branches simultaneously to form a token tree, with width and depth controlled by --tree-topk (e.g., 10) and --total-tokens (e.g., 40), and verifies the branches in parallel via batch processing.
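One plausible way to grow such a tree under a --tree-topk / --total-tokens style budget is best-first expansion by joint probability; this is a hypothetical sketch, where `children_of` stands in for the draft adapter's top-k proposals and the scoring rule is illustrative only.

```python
import heapq

def expand_tree(root, tree_topk, total_tokens, children_of):
    """Grow a token tree best-first until the node budget is spent.

    Returns a dict mapping node -> parent, ready to be flattened into a
    single batch for parallel verification."""
    tree = {root: None}
    frontier = [(-1.0, root)]                 # (negated joint probability, node)
    while frontier and len(tree) < total_tokens:
        score, node = heapq.heappop(frontier)
        for child, p in children_of(node)[:tree_topk]:
            if len(tree) >= total_tokens:
                break
            tree[child] = node
            heapq.heappush(frontier, (score * p, child))
    return tree

# Toy draft distribution: each token t branches to 2t (p=0.6) and 2t+1 (p=0.4).
tree = expand_tree(1, tree_topk=2, total_tokens=7,
                   children_of=lambda t: [(2 * t, 0.6), (2 * t + 1, 0.4)])
```

Storing scores as negated probability products makes Python's min-heap pop the most probable unexpanded branch first, so the fixed token budget is spent where acceptance is most likely.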

Section 04

Training Process: Two-Stage Data Generation and Adapter Fine-Tuning

CATS training is divided into two stages:

  1. Data Generation Stage: Use the ShareGPT dataset to generate early-layer, subsequent-layer, and final-layer hidden state tensors for each sample via the base model, serving as supervision signals for adapter training;
  2. Adapter Fine-Tuning Stage: Fine-tune the draft adapter and the shallow validator separately at their respective exit layers (3 and 10/15), each trained independently for 20 epochs with multi-GPU acceleration (distributed training via the accelerate framework).
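The two stages can be illustrated with a pure-Python toy: "hidden states" are small vectors, and the "adapter" is a single linear map trained to push early-exit states toward final-layer states. The dimensions, synthetic data, and plain gradient-descent loop are all illustrative; the real pipeline extracts Llama hidden states and trains with accelerate.

```python
import random

random.seed(0)
dim, n = 4, 32
# Stage 1 stand-in: record "early-layer" and "final-layer" states per sample.
early = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(n)]
final = [[sum(row) / dim + random.gauss(0, 0.1) for _ in range(dim)]
         for row in early]

# Stage 2 stand-in: a linear "adapter" supervised by the final-layer states.
W = [[0.0] * dim for _ in range(dim)]

def mse():
    err = 0.0
    for x, y in zip(early, final):
        pred = [sum(W[i][j] * x[j] for j in range(dim)) for i in range(dim)]
        err += sum((p - t) ** 2 for p, t in zip(pred, y)) / dim
    return err / n

def train(epochs=20, lr=0.05):            # 20 epochs, mirroring the setup above
    for _ in range(epochs):
        for x, y in zip(early, final):
            pred = [sum(W[i][j] * x[j] for j in range(dim)) for i in range(dim)]
            for i in range(dim):
                g = 2 * (pred[i] - y[i]) / dim
                for j in range(dim):
                    W[i][j] -= lr * g * x[j]

loss_before = mse()
train()
loss_after = mse()
```

The point of the sketch is only the data flow: activations recorded once in stage 1 become fixed supervision targets in stage 2, so the base model never needs gradients.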

Section 05

Experimental Evaluation: Performance of CATS on Multiple Benchmarks

CATS was validated on benchmarks such as MT-Bench, AlpacaEval, GSM8K, and HumanEval. The evaluation script CATS_dynamic.py reports metrics including average accepted tokens per step, shallow validator acceptance rate, the confusion matrix (TP/TN/FP/FN), and precision/recall/F1 scores. Experiments show that CATS significantly reduces the number of complete forward passes of the target model while maintaining generation quality, making it particularly suitable for deployment on memory-constrained edge devices.
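The validator metrics listed above follow directly from the TP/TN/FP/FN counts; the counts in the example below are made up for illustration.

```python
def validator_metrics(tp, tn, fp, fn):
    """Precision/recall/F1, plus the fraction of drafts passed up for verification."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    acceptance_rate = (tp + fp) / (tp + tn + fp + fn)
    return precision, recall, f1, acceptance_rate

# Illustrative counts: 8 correct accepts, 4 correct rejects, 2 of each error.
p, r, f1, acc = validator_metrics(tp=8, tn=4, fp=2, fn=2)
```

For a shallow validator, precision measures how often forwarded tokens survive target verification, while recall measures how many ultimately correct tokens it lets through.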


Section 06

Technical Implementation Highlights: Flexible Layer Exit and Lightweight Adapter Design

The technical highlights of CATS include:

  • EarlyExitLlamaForCausalLM: A unified forward interface that supports forward propagation for any layer range [start_layer, end_layer), shares KV cache, and enables flexible layer exit;
  • AdapterModel: A lightweight 1-layer decoder with optional residual connections and layer normalization, loaded from config.json and pytorch_model.bin;
  • MultiAdapterCATSModel: Unifies management of the base model and two adapters, implementing a hybrid forward evaluation loop.
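A hypothetical miniature of the [start_layer, end_layer) forward interface described above: the layers here are simple callables and the "KV cache" is a memo dict rather than real attention state, but the shape of the API is the same.

```python
class EarlyExitStack:
    def __init__(self, layers):
        self.layers = layers
        self.cache = {}                      # shared across partial forwards

    def forward_range(self, x, start_layer, end_layer):
        """Run layers [start_layer, end_layer), reusing cached activations."""
        for i in range(start_layer, end_layer):
            key = (i, x)
            if key not in self.cache:
                self.cache[key] = self.layers[i](x)
            x = self.cache[key]
        return x

# Toy 5-layer "model": layer i just adds i to its input.
stack = EarlyExitStack([lambda v, i=i: v + i for i in range(5)])
h_early = stack.forward_range(0, 0, 3)       # draft exit after the early layers
h_full = stack.forward_range(h_early, 3, 5)  # resume exactly where the draft stopped
```

Because the cache is shared, resuming from layer 3 repeats none of the early-layer work, which is what lets the draft adapter, shallow validator, and final verification all ride on one base model.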

Section 07

Deployment Recommendations and Future Outlook

Deployment Recommendations

CATS is built on the Kangaroo codebase and relies on frameworks such as PyTorch 2.0.1 and Transformers 4.33.3. Deployment notes:

  1. The adapter directory must contain config.json and pytorch_model.bin;
  2. Hard-coded paths in training scripts need to be modified according to the environment;
  3. Test data must be copied or linked to data/question_mtbench.jsonl for evaluation.
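A small preflight helper for note 1 above can catch a misconfigured adapter directory early; the required-file list comes from the text, but the checker itself is illustrative and not part of CATS.

```python
from pathlib import Path

REQUIRED_FILES = ("config.json", "pytorch_model.bin")

def missing_adapter_files(adapter_dir):
    """Return the required files absent from an adapter directory."""
    root = Path(adapter_dir)
    return [name for name in REQUIRED_FILES if not (root / name).is_file()]
```

Running this once before launching a training or evaluation script turns a mid-run load failure into an immediate, readable error.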

Summary and Outlook

CATS is an important step toward making speculative decoding practical, addressing memory-constrained deployment via cascaded adapters. Future directions include compatibility of the adapters with more base models, dynamic adjustment of the adaptive tree structure, and performance tuning for a wider range of edge devices.