Researchers at CMU Introduce TriForce: A Hierarchical Speculative Decoding AI System that is Scalable to Long Sequence Generation

With the widespread deployment of large language models (LLMs) for long-content generation, there is a growing need for efficient long-sequence inference. The key-value (KV) cache stores attention keys and values so they need not be recomputed. It has become a critical bottleneck, however, because its size grows linearly with sequence length. Since LLMs generate auto-regressively, the entire KV cache must be loaded for every generated token, which leaves compute cores underutilized and drives up latency. Compression methods have been proposed, but they often degrade generation quality. LLMs such as GPT-4, Gemini, and LWM are gaining prominence in applications such as chatbots, vision generation, and financial analysis. Yet serving them efficiently remains difficult, both because decoding is auto-regressive and because the KV cache's memory footprint keeps growing.
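
The following sketch estimates KV cache size from model shape to make the linear growth concrete. The configuration is an assumption chosen to resemble Llama2-7B: 32 layers, 32 key-value heads of dimension 128, and fp16 storage. Other models differ, and grouped-query attention reduces the number of key-value heads.

```python
# Back-of-the-envelope KV cache size for a decoder-only transformer.
# Assumed Llama2-7B-like shape: 32 layers, 32 KV heads x 128 dims, fp16 (2 bytes).

def kv_cache_bytes(seq_len: int, n_layers: int = 32, n_kv_heads: int = 32,
                   head_dim: int = 128, bytes_per_elem: int = 2, batch: int = 1) -> int:
    """Bytes needed to hold keys and values for every layer and every cached token."""
    per_token = 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem  # 2 = one K + one V
    return per_token * seq_len * batch

GiB = 1024 ** 3
for ctx in (4_096, 32_768, 131_072):
    print(f"{ctx:>7} tokens -> {kv_cache_bytes(ctx) / GiB:5.1f} GiB")
```

Under these assumptions each token costs 0.5 MiB of cache, so a 128K-token context needs 64 GiB. That is far more than the roughly 14 GB of fp16 weights of a 7B-parameter model, and every decoding step has to read all of it.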

Prior work proposes KV cache eviction strategies that shrink the cache's memory footprint by selectively discarding key-value pairs according to an eviction policy, allowing models to operate within a limited cache budget. However, these strategies risk losing information, which can lead to hallucination and contextual incoherence, particularly in long contexts. Speculative decoding accelerates LLM inference while preserving the model's output: a lightweight draft model predicts upcoming tokens, and the target model then verifies them. Applying it to long-sequence generation is difficult, though. Training suitable draft models requires substantial computation, and training-free alternatives such as drafting with a KV cache pruned by eviction risk poor speculation performance.
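
The draft-then-verify loop can be illustrated with a minimal sketch of speculative decoding under greedy decoding. The two "models" are hypothetical toy functions. The greedy acceptance rule is also a simplification; sampling-based systems use a rejection-sampling acceptance test instead. The property the sketch demonstrates is that the output is identical to decoding with the target model alone. In a batched implementation, only the number of sequential target forward passes would change.

```python
from typing import Callable, List

# Hypothetical stand-in for a model: maps a token prefix to its greedy next token.
NextToken = Callable[[List[int]], int]

def greedy_decode(model: NextToken, prompt: List[int], max_new: int) -> List[int]:
    """Plain auto-regressive decoding: one model call per generated token."""
    seq = list(prompt)
    for _ in range(max_new):
        seq.append(model(seq))
    return seq

def speculative_decode(target: NextToken, draft: NextToken, prompt: List[int],
                       max_new: int, gamma: int = 4) -> List[int]:
    """Greedy speculative decoding: the draft proposes `gamma` tokens, the target verifies."""
    seq = list(prompt)
    while len(seq) - len(prompt) < max_new:
        proposal: List[int] = []
        for _ in range(gamma):  # cheap auto-regressive drafting
            proposal.append(draft(seq + proposal))
        # A real system scores all proposals in ONE batched target forward pass;
        # this sketch calls the target per position for clarity.
        accepted: List[int] = []
        for tok in proposal:
            expected = target(seq + accepted)
            accepted.append(expected)  # the target's token is always the one kept
            if tok != expected:
                break  # first mismatch: discard the rest of the proposal
        else:
            accepted.append(target(seq + accepted))  # all accepted: one bonus token
        seq.extend(accepted)
    return seq[: len(prompt) + max_new]

# Toy models over digits 0-9; the draft is deliberately wrong on even inputs.
target = lambda s: (s[-1] * 3 + 1) % 10
draft = lambda s: target(s) if s[-1] % 2 else 0
print(speculative_decode(target, draft, [1], 12) == greedy_decode(target, [1], 12))  # True
```

A better draft raises the fraction of accepted tokens. In a real system, that fraction determines how many expensive target passes are saved. Correctness does not depend on it.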

Researchers from Carnegie Mellon University and Meta AI (FAIR) introduce TriForce, a hierarchical speculative decoding system designed for scalable long-sequence generation. Its intermediate draft model uses the original model weights together with a dynamic sparse KV cache selected via retrieval. Because the full cache is maintained, retrieval-based drafting can select KV entries more effectively; the authors characterize it as lossless compared with eviction-based methods such as StreamingLLM and H2O. The hierarchy addresses two memory bottlenecks, the model weights and the KV cache. At its bottom level, a lightweight model with a StreamingLLM cache produces the initial speculations, reducing drafting latency and accelerating end-to-end inference.
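
A StreamingLLM-style cache keeps a fixed budget by retaining a few initial "attention sink" tokens plus a sliding window of recent tokens. Its size therefore stays constant regardless of context length. The function below is a hypothetical sketch that only computes which positions such a policy would keep. The default of four sink tokens is an assumption, not a documented TriForce setting.

```python
from typing import List

def streaming_llm_keep(seq_len: int, budget: int, n_sink: int = 4) -> List[int]:
    """Cache positions kept by a StreamingLLM-style policy: the first `n_sink`
    'attention sink' tokens plus the most recent tokens, never more than `budget`."""
    if budget <= n_sink:
        raise ValueError("budget must leave room for recent tokens")
    if seq_len <= budget:
        return list(range(seq_len))  # the whole sequence still fits
    n_recent = budget - n_sink
    return list(range(n_sink)) + list(range(seq_len - n_recent, seq_len))

print(streaming_llm_keep(10, budget=6, n_sink=2))  # -> [0, 1, 6, 7, 8, 9]
print(len(streaming_llm_keep(131_072, budget=4_096)))  # -> 4096
```

Everything in the middle of the context is dropped. That keeps the bottom-level drafter cheap, but it is also why such a cache alone drafts poorly on long contexts.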

TriForce combines hierarchical speculative decoding with retrieval-based KV cache selection. Retrieval-based drafting segments the KV cache into chunks and keeps the most relevant information for drafting. Meanwhile, the lightweight model with a StreamingLLM cache accelerates the initial speculations and reduces drafting latency. Because the hierarchy tackles the KV cache and model-weight bottlenecks separately, the two levels together increase the overall speedup for long sequences. The implementation is built on Transformers, FlashAttention, and PyTorch CUDA graphs, maintaining full layer sparsity while minimizing kernel-launch overhead.
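
Retrieval-based drafting can be sketched as chunked top-k selection over the full cache. This minimal illustration rests on three assumptions. The keys are split into fixed-size chunks, each chunk is scored by the dot product between the current query and the chunk's mean key, and the highest-scoring chunks are kept up to the budget. `select_chunks` is a hypothetical name; a real implementation would work per attention head on GPU tensors rather than Python lists.

```python
from typing import List, Sequence

def select_chunks(query: Sequence[float], keys: Sequence[Sequence[float]],
                  chunk_size: int, budget: int) -> List[int]:
    """Hypothetical retrieval-style KV selection: split the cached keys into chunks,
    score each chunk by query . mean(key), and keep the best-scoring chunks."""
    dim = len(query)
    n_chunks = (len(keys) + chunk_size - 1) // chunk_size  # last chunk may be shorter
    scored = []
    for c in range(n_chunks):
        chunk = keys[c * chunk_size:(c + 1) * chunk_size]
        mean_key = [sum(k[d] for k in chunk) / len(chunk) for d in range(dim)]
        score = sum(q * m for q, m in zip(query, mean_key))
        scored.append((score, c))
    n_keep = max(1, budget // chunk_size)  # whole chunks that fit (at least one)
    best = sorted(c for _, c in sorted(scored, reverse=True)[:n_keep])  # restore order
    positions: List[int] = []
    for c in best:
        positions.extend(range(c * chunk_size, min((c + 1) * chunk_size, len(keys))))
    return positions

# Toy example: 8 cached tokens with 2-d keys, chunks of 2, budget of 4 tokens.
query = [1.0, 0.0]
keys = [[0, 1], [0, 1], [5, 0], [4, 0], [0, 1], [0, 0], [1, 0], [2, 0]]
print(select_chunks(query, keys, chunk_size=2, budget=4))  # -> [2, 3, 6, 7]
```

Because the full cache is never discarded, a different query at the next step can retrieve a different set of chunks. Eviction-based policies give up this flexibility.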

In evaluation, TriForce achieves up to a 2.31× speedup for Llama2-7B-128K on-chip on an A100 GPU with a 4K KV cache budget. Offloading to consumer GPUs is especially efficient: Llama2-13B-128K on two RTX 4090 GPUs runs 7.94× faster than an optimized offloading system. With offloading on two RTX 4090 GPUs, Llama2-7B-128K under TriForce reaches 0.108 s/token, only about half the speed of the auto-regressive baseline running on an A100. Batched inference also benefits, with a 1.9× speedup for a batch size of six, each sequence having a 19K-token context.

To conclude, this work introduces TriForce, a hierarchical speculative decoding system for efficiently serving LLMs with long contexts. By addressing the dual bottlenecks of the KV cache and model weights, TriForce delivers speedups of up to 2.31× on A100 GPUs and 7.78× on RTX 4090 GPUs. It also reaches 0.108 s/token, only about half the speed of the auto-regressive baseline on an A100. On a single RTX 4090 GPU it is 4.86× faster than DeepSpeed-Zero-Inference, and it attains a 1.9× speedup with large batches, showing its promise for long-context model serving.


Asjad is an intern consultant at Marktechpost. He is pursuing a B.Tech in mechanical engineering at the Indian Institute of Technology, Kharagpur. Asjad is a machine learning and deep learning enthusiast who is always researching applications of machine learning in healthcare.
