Google AI Introduces Learning to Prompt (L2P): A Machine Learning Model Training Method That Uses Learnable Task-Relevant Prompts To Guide Pre-Trained Models Through Training On Sequential Tasks

This Article Is Based On The Research Paper 'Learning to Prompt for Continual Learning' And Google AI Article. All Credit For This Research Goes To The Researchers Of This Paper 👏👏👏


Supervised learning is a popular approach to machine learning (ML) in which the model is trained on data that has been appropriately labeled for the task at hand. Ordinary supervised learning assumes that the training data are independent and identically distributed (IID).

All training examples are drawn from a fixed set of classes, and the model has access to all of them throughout training. Continual learning, by contrast, addresses the problem of training a single model on changing data distributions by presenting different classification tasks sequentially. This is especially important for autonomous agents that must process and interpret continuous streams of information in real-world scenarios.

Consider two tasks to illustrate the difference between supervised and continual learning: (1) classifying cats vs. dogs and (2) classifying pandas vs. koalas. In supervised learning, which relies on the IID assumption, the model receives the training data from both tasks at once and treats them as a single 4-class classification problem. In continual learning, however, the two tasks are presented sequentially, and the model only has access to the current task’s training data. As a result, such models are prone to degraded performance on earlier tasks, a phenomenon known as catastrophic forgetting.
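To make the contrast concrete, here is a toy Python sketch of how the same data reach the model in each setting: pooled into one shuffled 4-class stage under the IID assumption, or split into one stage per task in continual learning. The image ids and helper names are hypothetical and only illustrate the data regime, not any training code from the paper.

```python
import random
from typing import List, Tuple

Example = Tuple[str, str]  # (hypothetical image id, class label)

# Toy versions of the two tasks described in the text.
TASK_1: List[Example] = [("img0", "cat"), ("img1", "dog"), ("img2", "cat"), ("img3", "dog")]
TASK_2: List[Example] = [("img4", "panda"), ("img5", "koala"), ("img6", "panda"), ("img7", "koala")]


def supervised_iid_stages(tasks: List[List[Example]], seed: int = 0) -> List[List[Example]]:
    """Supervised (IID) setting: pool and shuffle everything into one 4-class stage."""
    pooled = [ex for task in tasks for ex in task]
    random.Random(seed).shuffle(pooled)
    return [pooled]


def continual_stages(tasks: List[List[Example]]) -> List[List[Example]]:
    """Continual setting: one stage per task; a task's data is unavailable after its stage."""
    return [list(task) for task in tasks]


def labels_in(stage: List[Example]) -> List[str]:
    """Sorted set of class labels visible to the model during one stage."""
    return sorted({label for _, label in stage})


for name, stages in [("IID", supervised_iid_stages([TASK_1, TASK_2])),
                     ("continual", continual_stages([TASK_1, TASK_2]))]:
    for i, stage in enumerate(stages):
        print(name, "stage", i, labels_in(stage))
```

The IID run prints a single stage containing all four labels, while the continual run prints two stages with two labels each; catastrophic forgetting arises because training in the second stage never revisits cats or dogs.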

Mainstream solutions address catastrophic forgetting by storing previous data in a “rehearsal buffer” and combining it with current data to train the model.
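The article does not say how a rehearsal buffer decides which past examples to keep. A common policy, assumed here, is reservoir sampling, which maintains a uniform random sample of every example seen so far within a fixed capacity. The sketch below (class and function names are hypothetical) shows such a buffer and how replayed examples could be mixed into each current-task batch; it also makes the dependence on buffer size explicit, since `capacity` bounds how much of the past survives.

```python
import random
from typing import Any, List


class ReservoirBuffer:
    """Fixed-capacity rehearsal buffer filled by reservoir sampling (an assumed policy)."""

    def __init__(self, capacity: int, seed: int = 0) -> None:
        self.capacity = capacity
        self.items: List[Any] = []
        self.num_seen = 0
        self.rng = random.Random(seed)

    def add(self, example: Any) -> None:
        self.num_seen += 1
        if len(self.items) < self.capacity:
            self.items.append(example)
            return
        # Replace a random slot with probability capacity / num_seen, so every
        # example seen so far is equally likely to be in the buffer.
        j = self.rng.randrange(self.num_seen)
        if j < self.capacity:
            self.items[j] = example

    def sample(self, k: int) -> List[Any]:
        """Draw up to k stored examples without replacement."""
        return self.rng.sample(self.items, min(k, len(self.items)))


def mixed_batch(current: List[Any], buffer: ReservoirBuffer, replay_size: int) -> List[Any]:
    """Combine current-task examples with replayed examples from earlier tasks."""
    return list(current) + buffer.sample(replay_size)


buffer = ReservoirBuffer(capacity=4)
stream = [("task1", c) for c in ["cat", "dog"] * 3] + [("task2", c) for c in ["panda", "koala"] * 3]
for example in stream:
    batch = mixed_batch([example], buffer, replay_size=2)
    # train_step(batch) would go here (hypothetical; not part of this sketch)
    buffer.add(example)
```

Note that storing raw past examples is exactly what can be ruled out by privacy constraints, which motivates the alternatives discussed next.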

However, the performance of these solutions depends heavily on the buffer size, and in some cases storing past data may be impossible due to data privacy concerns. Another line of work creates task-specific components to avoid interference between tasks. These methods, however, often assume that the task identity is known at test time, which is not always the case, and they require a large number of parameters. The limitations of these approaches raise essential questions for lifelong learning. Is it possible to have a more efficient and compact memory system than simply buffering previous data? Is it possible to choose the relevant knowledge components for an arbitrary sample without knowing which task it belongs to?

“Learning to Prompt” (L2P) is a novel continual learning framework inspired by prompting techniques from natural language processing. Rather than re-learning all of the model weights for each sequential task, L2P learns task-relevant “instructions,” i.e., prompts, drawn from a pool of learnable prompt parameters, which guide a pre-trained backbone model through sequential training. L2P applies to various challenging continual learning settings and consistently outperforms previous state-of-the-art methods on all benchmarks. It even outperforms rehearsal-based methods while also being more memory efficient. Most importantly, L2P is the first work to propose the concept of prompting in the context of continual learning.


In contrast to traditional methods that use a rehearsal buffer to adapt entire or partial model weights to tasks sequentially, L2P uses a single frozen backbone model and learns a prompt pool to conditionally instruct the model. The term “Model 0” denotes that the backbone model is fixed at the start.

Given a pre-trained Transformer model, “prompt-based learning” modifies the original input using a fixed template. Suppose a sentiment analysis task receives the input “I like this cat.” A prompt-based method will transform the input into “I like this cat. It seems X,” where “X” is an empty slot to be predicted (e.g., “nice,” “cute,” etc.) and “It seems X” is the so-called prompt. Adding prompts to the input makes it possible to condition a pre-trained model to solve many downstream tasks. Designing fixed prompts, however, requires prior knowledge as well as trial and error. Prompt tuning instead prepends a set of learnable prompts to the input embedding to instruct the pre-trained backbone to learn a single downstream task under the transfer learning setting.
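The difference between a fixed template and prompt tuning can be sketched in a few lines of NumPy. The embeddings here are random stand-ins for the output of a frozen embedding layer, and the sizes (`batch`, `seq_len`, `prompt_len`, `dim`) are arbitrary assumptions; the point is only the shape change when a learnable prompt is prepended to every input sequence.

```python
import numpy as np

# A hand-written (fixed) prompt, as in the sentiment example: text plus a template.
template = "{text} It seems X."
print(template.format(text="I like this cat."))  # I like this cat. It seems X.

rng = np.random.default_rng(0)
batch, seq_len, prompt_len, dim = 2, 8, 5, 16

# Stand-in for token embeddings produced by a frozen pre-trained embedding layer.
input_embeddings = rng.normal(size=(batch, seq_len, dim))

# The learnable "soft prompt": prompt_len vectors in the same embedding space.
# Under prompt tuning, only these (plus a task head) would receive gradient updates.
soft_prompt = rng.normal(scale=0.02, size=(prompt_len, dim))


def prepend_prompt(prompt: np.ndarray, x: np.ndarray) -> np.ndarray:
    """(L_p, D) prompt + (B, L, D) input -> (B, L_p + L, D) sequence for the frozen model."""
    tiled = np.broadcast_to(prompt, (x.shape[0],) + prompt.shape)  # same prompt for each example
    return np.concatenate([tiled, x], axis=1)


extended = prepend_prompt(soft_prompt, input_embeddings)
print(extended.shape)  # (2, 13, 16)
```

Because the backbone weights never change, the only task-specific state is the small `soft_prompt` array, which is what makes prompts attractive as a compact memory for continual learning.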

In the continual learning scenario, L2P maintains a learnable prompt pool in which prompts can be flexibly grouped into subsets that work together. Each prompt is associated with a learnable key, which is optimized by minimizing a cosine-similarity-based loss between the key and the query features of the inputs it is matched with. A query function then uses these keys to dynamically look up a subset of task-relevant prompts based on the input features. At test time, the query function maps each input to the top-N closest keys in the prompt pool, and the associated prompt embeddings are fed to the rest of the model to generate the output prediction. During training, a cross-entropy loss is used to optimize the prompt pool and the classification head.
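The key-query lookup described above can be sketched as a forward pass in NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation: `PromptPool`, its sizes, and the random initialization are hypothetical; the query is assumed to be a feature vector from the frozen backbone; and the key-matching term is written as a cosine distance (1 minus cosine similarity) summed over the selected keys. In a real model, gradients of the cross-entropy loss plus a weighted key-matching term would update the prompts, keys, and classifier head; NumPy here only computes the values.

```python
import numpy as np


def l2_normalize(v: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale each row to unit length so dot products become cosine similarities."""
    return v / (np.linalg.norm(v, axis=-1, keepdims=True) + eps)


class PromptPool:
    """Hypothetical L2P-style prompt pool: pool_size prompts, each paired with one key."""

    def __init__(self, pool_size: int, prompt_len: int, dim: int, seed: int = 0) -> None:
        rng = np.random.default_rng(seed)
        # Both arrays would be trainable parameters in a real model.
        self.keys = rng.normal(size=(pool_size, dim))
        self.prompts = rng.normal(scale=0.02, size=(pool_size, prompt_len, dim))

    def select(self, query: np.ndarray, top_n: int):
        """query: (B, D) input features from the frozen backbone.

        Returns top-N key indices (B, N), the chosen prompts flattened to
        (B, N * prompt_len, D) for prepending, and the (B, M) cosine similarities.
        """
        sims = l2_normalize(query) @ l2_normalize(self.keys).T  # (B, M)
        idx = np.argsort(-sims, axis=1)[:, :top_n]              # most similar keys first
        chosen = self.prompts[idx]                              # (B, N, prompt_len, D)
        flat = chosen.reshape(query.shape[0], -1, chosen.shape[-1])
        return idx, flat, sims

    @staticmethod
    def key_matching_loss(sims: np.ndarray, idx: np.ndarray) -> float:
        """Sum of (1 - cos) over each example's selected keys, averaged over the batch.
        Minimizing it pulls the chosen keys toward the queries that selected them."""
        picked = np.take_along_axis(sims, idx, axis=1)  # (B, N)
        return float(np.mean(np.sum(1.0 - picked, axis=1)))


pool = PromptPool(pool_size=10, prompt_len=5, dim=16)
query = 2.0 * pool.keys[[3, 7]]  # two queries pointing in the same direction as keys 3 and 7
idx, prompts, sims = pool.select(query, top_n=2)
print(idx[:, 0], prompts.shape)  # [3 7] (2, 10, 16)
```

Because selection depends only on each input's own query, no task identity is needed at test time, which is the property the next paragraph builds on.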

Intuitively, similar input examples tend to select similar sets of prompts, while dissimilar examples tend to select different ones. Thus, frequently shared prompts encode more generic knowledge, whereas other prompts encode more task-specific knowledge. Furthermore, the prompts store high-level instructions while the lower-level pre-trained representations remain frozen, which mitigates catastrophic forgetting even without a rehearsal buffer. The instance-wise query mechanism removes the need to know the task identity or task boundaries, allowing this approach to address the understudied problem of task-agnostic continual learning.

The effectiveness of L2P was evaluated on representative benchmarks against several baseline methods, with all methods using an ImageNet pre-trained Vision Transformer (ViT). The naive baseline, referred to as Sequential, trains a single model on all tasks one after another. The EWC model adds a regularization term to reduce forgetting, whereas the Rehearsal model stores previous examples in a buffer for mixed training with current data. Overall continual learning performance was measured with two metrics: accuracy, and forgetting, defined as the average difference between the best accuracy achieved during training and the final accuracy across all tasks. L2P outperforms the Sequential and EWC methods on both metrics.


Notably, L2P outperforms even the Rehearsal method, which uses an additional buffer to save previous data. Because the L2P approach is orthogonal to rehearsal, its performance could be improved further if it also used a rehearsal buffer.

L2P outperforms the baseline methods in terms of both accuracy and forgetting. Accuracy is the average accuracy over all tasks, while forgetting is the average difference between the best accuracy achieved during training and the final accuracy for all tasks.
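Both metrics can be computed from a lower-triangular accuracy table. The sketch below assumes a hypothetical layout `acc[t][i]`, the accuracy on task i measured after training on task t (for i ≤ t). For forgetting, it follows a common convention (an assumption here, not stated in the article) of taking the best accuracy from evaluations before the final one and averaging over all tasks except the last, whose forgetting is zero by construction.

```python
from typing import List

# acc[t][i]: accuracy (%) on task i after training on task t, for i <= t (hypothetical layout).
AccTable = List[List[float]]


def average_accuracy(acc: AccTable) -> float:
    """Mean accuracy over all tasks, measured after training on the final task."""
    final = acc[-1]
    return sum(final) / len(final)


def forgetting(acc: AccTable) -> float:
    """Average over tasks 0..T-2 of (best earlier accuracy - final accuracy)."""
    num_tasks = len(acc)
    if num_tasks < 2:
        return 0.0
    drops = []
    for i in range(num_tasks - 1):
        # Best accuracy on task i observed before the last training stage.
        best = max(acc[t][i] for t in range(i, num_tasks - 1))
        drops.append(best - acc[-1][i])
    return sum(drops) / len(drops)


# Made-up numbers purely for illustration, not results from the paper.
acc = [
    [90.0],
    [70.0, 85.0],
    [60.0, 80.0, 88.0],
]
print(average_accuracy(acc))  # 76.0
print(forgetting(acc))        # 17.5
```

Higher accuracy and lower forgetting are better; a method that never degrades on earlier tasks would score a forgetting of zero.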

The researchers also plotted the prompt selection results of the instance-wise query strategy on two benchmarks, one with similar tasks and the other with a mix of dissimilar tasks. The results indicate that L2P encourages more knowledge sharing between similar tasks, by using more shared prompts, and less knowledge sharing between dissimilar tasks, by using more task-specific prompts.
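The article does not describe how the selection plots were computed. One simple way to quantify "shared" versus "task-specific" prompts, assumed here, is to record which prompt indices each test example selected (for instance, the top-N indices from a key lookup) and count, per task, the fraction of examples that used each prompt. The function names and the `threshold` parameter below are hypothetical.

```python
from collections import Counter
from typing import Dict, List, Sequence


def selection_frequency(selected: Sequence[Sequence[int]], task_ids: Sequence[int],
                        pool_size: int) -> Dict[int, List[float]]:
    """For each task, the fraction of its examples that selected each prompt index."""
    counts: Dict[int, Counter] = {}
    totals: Counter = Counter()
    for prompt_ids, task in zip(selected, task_ids):
        counts.setdefault(task, Counter()).update(set(prompt_ids))
        totals[task] += 1
    return {task: [c[p] / totals[task] for p in range(pool_size)]
            for task, c in sorted(counts.items())}


def shared_prompts(freq: Dict[int, List[float]], threshold: float) -> List[int]:
    """Prompts used by at least `threshold` of the examples in every task."""
    pool_size = len(next(iter(freq.values())))
    return [p for p in range(pool_size) if all(f[p] >= threshold for f in freq.values())]


# Toy selections: each example picked 2 of 5 prompts; tasks 0 and 1 share prompt 0.
selected = [[0, 1], [0, 2], [0, 3], [0, 4]]
task_ids = [0, 0, 1, 1]
freq = selection_frequency(selected, task_ids, pool_size=5)
print(freq)                                  # {0: [1.0, 0.5, 0.5, 0.0, 0.0], 1: [1.0, 0.0, 0.0, 0.5, 0.5]}
print(shared_prompts(freq, threshold=0.9))   # [0]
```

Under this measure, similar tasks would show many prompts with high frequency in every task, while dissimilar tasks would concentrate their selections on disjoint prompts.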

L2P is a novel approach to critical challenges in continual learning. It requires neither a rehearsal buffer nor a known task identity at test time to achieve high performance. Furthermore, it can handle a variety of complex continual learning scenarios, including the challenging task-agnostic setting.

Refer to the published research paper or the project's GitHub repository to learn more.