How Does A Language Model Decide What To Say Next? This New AI Method Called Tuned Lens Can Trace A Language Model’s Prediction As It Develops From One Layer To The Next

The excellent performance of transformers in computer vision and natural language processing justifies research into the internal representations of these systems. Probing methods that train classifiers to infer latent features (such as part of speech and syntactic structure) from hidden states are prevalent.

Eleuther AI, FAR AI, Boston University, the University of Toronto, and UC Berkeley collaborated on a study that examines transformer representations through the lens of iterative inference. Each layer of a transformer language model is viewed as refining a latent prediction of the next token by a small amount. Using early exiting, the researchers decode these hidden predictions by mapping the hidden state at each intermediate layer onto a distribution over the vocabulary. The resulting sequence of distributions, called the prediction trajectory, tends to converge smoothly to the model’s output distribution, with perplexity decreasing nearly monotonically from layer to layer.

The researchers expand upon the “logit lens,” an early-exiting method that uses the model’s pre-trained unembedding matrix to directly decode hidden states into vocabulary space. While the logit lens may seem like a good idea initially, the team found that it yields unreliable results when applied to realistic models like BLOOM and GPT Neo. On some LMs, it produces gibberish; even when it functions, it is systematically biased toward certain tokens over others.
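To make the idea concrete, here is a minimal sketch of logit-lens decoding, assuming toy random weights and omitting details such as the final LayerNorm that real transformers apply before the unembedding:

```python
import numpy as np

def softmax(x):
    z = x - x.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def logit_lens(hidden_states, W_U):
    # Decode each layer's hidden state directly through the unembedding,
    # yielding the "prediction trajectory": one next-token distribution
    # per layer.
    return [softmax(h @ W_U) for h in hidden_states]

# Toy demo with random weights; dimensions are hypothetical.
rng = np.random.default_rng(0)
d_model, vocab, n_layers = 8, 16, 4
W_U = rng.normal(size=(d_model, vocab))
hidden_states = [rng.normal(size=d_model) for _ in range(n_layers)]
trajectory = logit_lens(hidden_states, W_U)
```

The point of the sketch is the failure mode the team observed: the same unembedding is asked to read out every layer, even though intermediate layers need not represent predictions in the final layer’s basis.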

The recent work introduces the tuned lens as a solution to the problems with the logit lens.

The researchers hypothesize that because distinct layers each “speak a different language,” applying the unembedding layer directly to decode their hidden states is inappropriate. An affine transformation (a “translator”) that performs a learned change of basis at each layer yields improved results. Using a distillation loss, they train L affine transformations (one for each layer of the network) so that the image of each transformed hidden state under the unembedding closely matches the final-layer logits.

When comparing tuned lens predictions to logit lens predictions, the team found that the former represent the final-layer distribution far more faithfully, with significantly lower perplexity. Their results also demonstrate that the features most important to the tuned lens’s output are important to the model itself. To show this, they introduce a new algorithm called causal basis extraction (CBE), which identifies the directions in the residual stream that have the greatest influence on the tuned lens. When these directions are ablated from the corresponding hidden states in the model, they have an outsized impact on the model’s predictions.
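The core ablation operation in this kind of experiment is simple: project the component along a candidate direction out of the hidden state. A minimal sketch, with hypothetical toy vectors:

```python
import numpy as np

def erase_direction(h, v):
    # Remove the component of hidden state h along direction v --
    # the residual-stream ablation used to test whether a direction
    # is causally important to the model's predictions.
    v = v / np.linalg.norm(v)
    return h - np.dot(h, v) * v

rng = np.random.default_rng(2)
h = rng.normal(size=8)
v = rng.normal(size=8)
h_erased = erase_direction(h, v)
```

If ablating a direction that matters greatly to the tuned lens also shifts the model’s output distribution substantially, that is evidence the lens is reading features the model actually uses.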

To quantify how difficult a task is for an LM, the team uses the tuned lens to define “prediction depth”: a task is considered easy if the model converges on its answer in the first few layers of its network. Examples classified in later layers also tend to be learned later in training.
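One simplified reading of prediction depth can be sketched as the earliest layer whose top-1 token already agrees with the model’s final output; the trajectory below is hypothetical toy data over a 3-token vocabulary:

```python
import numpy as np

def prediction_depth(trajectory):
    # Earliest layer whose top-1 token agrees with the final layer's
    # top-1 token (a simplified notion of "prediction depth").
    final_top1 = int(np.argmax(trajectory[-1]))
    for layer, dist in enumerate(trajectory):
        if int(np.argmax(dist)) == final_top1:
            return layer
    return len(trajectory) - 1

# Toy trajectory: the model settles on token 2 from layer 1 onward,
# so the prediction depth is 1.
trajectory = [
    np.array([0.6, 0.3, 0.1]),  # layer 0: favors token 0
    np.array([0.2, 0.3, 0.5]),  # layer 1: favors token 2
    np.array([0.1, 0.2, 0.7]),  # layer 2 (final): favors token 2
]
```

A shallow depth means the model “made up its mind” early; inputs with large depths are the ones the paper links to late-in-training learning.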

The tuned lens is applied in a few different contexts: tuned lens prediction trajectories can be used to detect prompt injection attacks with high accuracy, and the team finds that data points that require many training steps to learn also tend to be classified in later layers, extending earlier results to new models.

Check out the Paper and Github. All Credit For This Research Goes To the Researchers on This Project. Also, don’t forget to join our 16k+ ML SubReddit, Discord Channel, and Email Newsletter, where we share the latest AI research news, cool AI projects, and more.
