LM-Guided CoT: A Novel Machine Learning Framework that Leverages a Lightweight (<1B) Language Model (LM) for Guiding a Black-Box Large (>10B) LM in Reasoning Tasks

Chain-of-thought (CoT) prompting instructs language models (LMs) to reason step by step, which improves performance on arithmetic, commonsense, and symbolic reasoning tasks. However, conventional CoT has limitations. Its performance gains are observed in large LMs with 100+ billion parameters, yet even these models often produce repetitive and vacuous rationales: the rationales lack faithfulness to the input instance and tend to be misaligned with the predicted answers.
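The difference between standard and CoT prompting can be shown with two prompt templates. This is a hedged illustration only: the templates and the `build_prompt` helper are hypothetical, and the trigger phrase "Let's think step by step." is one widely used zero-shot CoT cue, not the prompt used in this research.

```python
# Hypothetical templates for illustration; the article does not give the exact
# prompts. "Let's think step by step." is one widely used zero-shot CoT trigger.
STANDARD_TEMPLATE = "Question: {question}\nAnswer:"
COT_TEMPLATE = "Question: {question}\nAnswer: Let's think step by step."


def build_prompt(question: str, chain_of_thought: bool = False) -> str:
    """Return a standard prompt, or a CoT prompt that asks for step-by-step reasoning."""
    template = COT_TEMPLATE if chain_of_thought else STANDARD_TEMPLATE
    return template.format(question=question)


print(build_prompt("If Ann has 3 apples and buys 2 more, how many does she have?",
                   chain_of_thought=True))
```

With the CoT template, the model is expected to write out intermediate steps before the final answer; with the standard template, it answers directly.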

Recent research has explored ways to strengthen the reasoning abilities of small LMs, either for computational efficiency or for task performance. In rationale distillation, a small LM learns from a larger one to generate CoT rationales; however, little work has addressed the errors the student inherits from the teacher model. Other efforts go beyond distillation to evaluate and refine rationales along dimensions such as logicality, relevance, informativeness, coherence, and repetition. Reinforcement learning (RL) has been applied to correct misaligned LM behaviors, but its use for rationale correction remains underexplored.


Researchers from Penn State University and Amazon AGI propose LM-guided CoT, a method that uses two distinct LMs for CoT reasoning: a small LM generates the rationale and a large LM predicts the answer. First, a vanilla knowledge distillation (KD) technique is applied to the small LM using rationales generated by the large LM, narrowing the gap between their reasoning capabilities. Then, fine-grained measurements (relevance, factuality, logicality, consistency, coherence, fluency, naturalness, and readability) are used to further optimize the knowledge-distilled LM through RL. This improves the quality of the generated rationales and, ultimately, CoT reasoning performance.
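The two-step division of labor at inference time can be sketched as follows. This is a minimal illustration, not the authors' code: `small_lm` and `large_lm` are hypothetical stand-ins for any function that maps a prompt string to generated text, and the prompt wording is assumed.

```python
from typing import Callable

# Hypothetical interface: any function that maps a prompt to generated text
# (e.g. a wrapper around a local model or a hosted black-box API).
Generator = Callable[[str], str]


def lm_guided_cot(context: str, question: str,
                  small_lm: Generator, large_lm: Generator) -> tuple[str, str]:
    """Step 1: the small LM writes a rationale. Step 2: the large LM answers."""
    rationale_prompt = f"Context: {context}\nQuestion: {question}\nRationale:"
    rationale = small_lm(rationale_prompt).strip()
    # The large LM is only queried, never fine-tuned; it sees the rationale
    # alongside the original input and produces the final answer.
    answer_prompt = (f"Context: {context}\nQuestion: {question}\n"
                     f"Rationale: {rationale}\nAnswer:")
    answer = large_lm(answer_prompt).strip()
    return rationale, answer


# Toy stand-ins for the two models, just to show the data flow.
def toy_small(prompt: str) -> str:
    return " The Eiffel Tower is in Paris, the capital of France. "


def toy_large(prompt: str) -> str:
    return "France" if "capital of France" in prompt else "unknown"


print(lm_guided_cot("The Eiffel Tower is in Paris.",
                    "In which country is the Eiffel Tower?", toy_small, toy_large))
```

Because only the small model is trained, the large model can remain a black box behind an API, which is the setting the framework targets.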

The LM-guided CoT framework involves two LMs: a lightweight model (MS) that generates optimal rationales and a large model (ML) that predicts outputs based on these rationales. In rationale distillation, MS learns from ML-generated rationales, which are filtered first so that MS does not inherit ML's errors. Rationale refinement relies on eight linguistic aspect measurements that were initially annotated manually and later automated, so that they can drive RL-based training of MS. Proximal Policy Optimization (PPO) updates MS with rewards derived from aspect-specific evaluation metrics and task-specific accuracy, along with penalty terms for model consistency.
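Two pieces of the training loop can be sketched as plain functions. The article does not spell out the filtering rule or how the reward terms are weighted, so both are assumptions here: a teacher rationale is kept only when the teacher's answer matches the gold answer, and the reward is a weighted sum of the mean aspect score (each score assumed to lie in [0, 1]) and exact-match accuracy, minus a KL-style drift penalty of the kind commonly used when fine-tuning LMs with PPO. The names `Example`, `filter_for_distillation`, and `rationale_reward`, and all default weights, are hypothetical.

```python
from dataclasses import dataclass
from statistics import mean


def normalize(text: str) -> str:
    """Lowercase and collapse whitespace so answers compare by exact match."""
    return " ".join(text.lower().split())


@dataclass
class Example:
    """One training instance with the large LM's (teacher's) output attached."""
    question: str
    gold_answer: str
    teacher_rationale: str
    teacher_answer: str


def filter_for_distillation(examples: list[Example]) -> list[Example]:
    """Assumed filter: drop teacher rationales that led to a wrong answer,
    so the small LM does not inherit those errors during distillation."""
    return [ex for ex in examples
            if normalize(ex.teacher_answer) == normalize(ex.gold_answer)]


def rationale_reward(aspect_scores: dict[str, float], predicted: str, gold: str,
                     logp_policy: float, logp_reference: float,
                     w_aspect: float = 0.5, w_task: float = 0.5,
                     beta: float = 0.1) -> float:
    """Hypothetical scalar PPO reward for one rationale generated by MS."""
    # Average of the automated aspect scores (each assumed to lie in [0, 1]).
    aspect_term = mean(aspect_scores.values())
    # Task accuracy: 1 if the large LM's answer, given this rationale, is correct.
    task_term = 1.0 if normalize(predicted) == normalize(gold) else 0.0
    # Assumed KL-style penalty: log-prob gap between the current policy and the
    # frozen distilled model, discouraging MS from drifting too far from it.
    drift_penalty = beta * (logp_policy - logp_reference)
    return w_aspect * aspect_term + w_task * task_term - drift_penalty


data = [Example("q1", "Paris", "r1", " paris "), Example("q2", "Rome", "r2", "Milan")]
print([ex.question for ex in filter_for_distillation(data)])
print(rationale_reward({"logicality": 1.0, "fluency": 0.5}, "Paris", "paris", -10.0, -12.0))
```

Averaging the aspect scores treats all aspects equally; a real setup could weight them differently or score them separately, which the article does not specify.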

The study first compares ML (instantiated as FLAN-T5 XXL) with and without CoT prompting and finds that CoT prompting lowers its accuracy, which is attributed to the model's limited reasoning ability over long contexts. LM-guided CoT with KD alone outperforms original CoT prompting by 2% on HotpotQA and 10% on 2WikiMultiHopQA. The approach significantly improves both answer prediction and rationale quality, especially for questions with lengthy contexts, surpassing CoT prompting with self-consistency (SC) and rivaling standard prompting in accuracy.
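The self-consistency (SC) baseline samples several CoT reasoning paths and keeps the most frequent final answer. A minimal majority-vote sketch, assuming the caller supplies a hypothetical `sample_answer` function that runs one sampled CoT pass and returns only its final answer:

```python
from collections import Counter
from itertools import cycle
from typing import Callable


def self_consistency(prompt: str, sample_answer: Callable[[str], str],
                     num_samples: int = 5) -> str:
    """Sample several CoT completions and return the majority-vote answer."""
    # Normalize each sampled final answer so "Paris" and "paris" count together.
    answers = [" ".join(sample_answer(prompt).lower().split())
               for _ in range(num_samples)]
    # most_common(1) -> [(answer, count)]; ties keep first-encountered order.
    return Counter(answers).most_common(1)[0][0]


# Toy sampler that cycles through canned answers instead of calling an LM.
canned = cycle(["Paris", "Lyon", "paris", "Paris ", "Lyon"])
print(self_consistency("Where is the Louvre?", lambda _: next(canned)))  # paris
```

SC multiplies the number of large-model calls by `num_samples`, which is one reason a single guided pass can be the cheaper option.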

In conclusion, this research introduces LM-guided CoT, a framework that improves CoT prompting by decomposing it into a rationale generation step and an answer prediction step, with rationale generation optimized through RL. Because it outperforms all baselines, it offers an effective and resource-efficient solution to the challenges of CoT. However, selecting the highest-quality rationales does not consistently improve task performance, which suggests that the quality of LM-generated rationales needs to be balanced against overall task performance for optimal results.

Asjad is an intern consultant at Marktechpost. He is pursuing a B.Tech in mechanical engineering at the Indian Institute of Technology, Kharagpur. Asjad is a machine learning and deep learning enthusiast who is always researching the applications of machine learning in healthcare.