Meta AI Researchers Propose Token Merging (ToMe) to Make Vision Transformers Run Faster

Vision transformers (ViTs) were introduced two years ago and quickly became a core component of computer vision research. Taking an architecture that worked exceptionally well for language tasks and bringing it into the computer vision domain was a bold move, but it paid off. Since then, progress in computer vision has accelerated.

Transformers used in computer vision have since diverged from their natural language processing (NLP) counterparts. The field is now dominated by hybrid transformer models that use vision-specific attention modules. These vision-specific biases make such hybrid models more efficient.

Vanilla ViTs still have many desirable characteristics, even though hybrid models now beat them on the cost-to-performance trade-off. They consist mostly of simple matrix multiplications, which makes them faster in practice than their raw FLOP count would suggest. They also support strong self-supervised pre-training techniques such as MAE (masked autoencoders), which can produce state-of-the-art results while being quick to train. Finally, because they make no assumptions about the data, they can be applied across many modalities with little or no modification.

Token merging. Source: https://arxiv.org/pdf/2210.09461.pdf

All of this comes at a cost, however: ViTs are massive. Running these large models, for example to reproduce published results, can be problematic.

Several studies have tried to tackle this problem, and token pruning is one of the main approaches. Because transformers can process a variable number of input tokens, tokens can be removed at runtime to make the model faster. However, this approach has several issues. The main one is the information lost when tokens are eliminated: you cannot simply prune every token, because only a limited number can be removed before the information loss becomes too high. In addition, existing methods require retraining the model to work well with pruned tokens.
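To make the contrast with merging concrete, here is a minimal sketch of runtime token pruning in plain Python. The `prune_tokens` function and the per-token `scores` are hypothetical: real pruning methods compute token importance in different ways (learned predictors, attention weights, and so on). The point is only that pruned tokens are discarded outright, so whatever information they carried is gone.

```python
from typing import List, Sequence

Vector = List[float]


def prune_tokens(tokens: List[Vector], scores: Sequence[float], r: int) -> List[Vector]:
    """Drop the r tokens with the lowest (hypothetical) importance score.

    The surviving tokens keep their original order. Dropped tokens are
    discarded entirely, which is the source of pruning's information loss.
    """
    if r <= 0:
        return [t[:] for t in tokens]
    # Indices of the r lowest-scoring tokens.
    drop = set(sorted(range(len(tokens)), key=lambda i: scores[i])[:r])
    return [t[:] for i, t in enumerate(tokens) if i not in drop]


tokens = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.5, 0.5]]
scores = [0.8, 0.3, 0.9, 0.1]
print(prune_tokens(tokens, scores, r=2))  # [[1.0, 0.0], [0.0, 1.0]]
```

Note that the dropped token `[0.9, 0.1]` was nearly a duplicate of a kept one, while `[0.5, 0.5]` carried information no surviving token has; pruning treats both the same way.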

So token pruning is not the answer, yet we still want to use ViTs, which remain too slow for many use cases. What could the solution be? How can we speed up ViTs as much as pruning does while keeping accuracy much higher? The answer proposed here is called Token Merging.

Token Merging (ToMe) combines tokens instead of pruning them. Thanks to its custom matching algorithm, it is as fast as pruning while being more accurate. It also works without any additional training, so it can be applied to huge off-the-shelf models to speed them up without sacrificing much accuracy.

The goal is to integrate a token merging module into an existing ViT to boost training and inference throughput by combining redundant tokens, without necessarily requiring training. 
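To see why reducing tokens inside the network pays off, consider a constant schedule in which every block merges a fixed number `r` of tokens. The sketch below is plain Python; `token_schedule` is a hypothetical helper, the cap at half the current token count is a simplification, and `r = 16` is only an illustrative value. It tracks the token count for a ViT-B/16 at 224x224 resolution, which has 14 x 14 = 196 patch tokens plus one class token across 12 blocks.

```python
from typing import List


def token_schedule(num_tokens: int, num_layers: int, r: int) -> List[int]:
    """Return the token count before the first block and after each block.

    Assumes (hypothetically) a constant schedule that merges r tokens per
    block, capped at half the current count as a simplification, since a
    bipartite matching merges tokens from one half of the set into the other.
    """
    counts = [num_tokens]
    n = num_tokens
    for _ in range(num_layers):
        n -= min(r, n // 2)
        counts.append(n)
    return counts


# ViT-B/16 at 224x224: 14 * 14 = 196 patches + 1 class token = 197 tokens, 12 blocks.
schedule = token_schedule(num_tokens=197, num_layers=12, r=16)
print(schedule)
# [197, 181, 165, 149, 133, 117, 101, 85, 69, 53, 37, 21, 11]
```

Later blocks therefore run on a fraction of the original tokens. Linear layers scale linearly with the token count and attention scales quadratically, so every token merged early saves work in all subsequent blocks.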

Token merging is applied between the attention and MLP branches of each transformer block. This placement lets information from the tokens that are about to be merged still propagate through attention, and it lets the ViT use features computed within the attention module to decide what to merge.
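The ordering within a block can be sketched in plain Python. Everything here is a toy assumption: `toy_attention` uses uniform attention weights and treats the input tokens themselves as keys, `toy_merge` is a stand-in that averages only the single most similar pair (not ToMe's matching algorithm), and `toy_mlp` returns zeros so the output is easy to check. What the sketch does show is the order of operations: attention first, then merging driven by the attention keys, then the MLP on the reduced token set.

```python
from typing import Callable, List, Tuple

Vector = List[float]
Tokens = List[Vector]


def add(a: Vector, b: Vector) -> Vector:
    return [x + y for x, y in zip(a, b)]


def tome_block(
    x: Tokens,
    attention: Callable[[Tokens], Tuple[Tokens, Tokens]],
    merge: Callable[[Tokens, Tokens], Tokens],
    mlp: Callable[[Vector], Vector],
) -> Tokens:
    """One transformer block with token merging between the two branches."""
    # 1) Attention branch with residual. It also returns the keys it computed,
    #    so the merge step can reuse them as its similarity metric.
    attn_out, keys = attention(x)
    x = [add(t, o) for t, o in zip(x, attn_out)]
    # 2) Merge only after attention, so every token has already exchanged
    #    information with the others before any are combined.
    x = merge(x, keys)
    # 3) MLP branch with residual, now applied to fewer tokens.
    return [add(t, mlp(t)) for t in x]


def toy_attention(x: Tokens) -> Tuple[Tokens, Tokens]:
    # Hypothetical uniform attention: each token receives the mean of all tokens.
    # Keys are taken to be the input tokens themselves (identity projection).
    mean = [sum(col) / len(x) for col in zip(*x)]
    return [mean[:] for _ in x], [t[:] for t in x]


def toy_merge(x: Tokens, keys: Tokens) -> Tokens:
    # Stand-in merge, NOT ToMe's bipartite matching: average the single pair
    # of tokens whose keys have the largest dot product.
    n = len(x)
    i, j = max(
        ((a, b) for a in range(n) for b in range(a + 1, n)),
        key=lambda p: sum(u * v for u, v in zip(keys[p[0]], keys[p[1]])),
    )
    merged = [(u + v) / 2 for u, v in zip(x[i], x[j])]
    return [t for k, t in enumerate(x) if k not in (i, j)] + [merged]


def toy_mlp(t: Vector) -> Vector:
    return [0.0] * len(t)  # zero MLP keeps the output easy to inspect


out = tome_block([[1.0, 0.0], [1.0, 0.1], [0.0, 1.0]], toy_attention, toy_merge, toy_mlp)
print(len(out))  # 2: the two near-duplicate tokens were merged
```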

Visualization of token merging on some images. Source: https://arxiv.org/pdf/2210.09461.pdf

The first step of merging is to find which tokens are similar. This is relatively straightforward in a ViT, thanks to the QKV (query, key, value) projections it already computes. The keys already summarize the information contained in each token, so all that is left is to apply a dot-product similarity metric (cosine similarity) between the keys of each pair of tokens.
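Assuming the keys for each token are available as plain vectors, the similarity step reduces to pairwise cosine similarity, i.e. a dot product between L2-normalized keys. Below is a minimal plain-Python sketch; `cosine_similarity_matrix` is a hypothetical helper name. In a multi-head model, the per-head keys would first need to be combined into one vector per token (averaging them is one option).

```python
import math
from typing import List

Vector = List[float]


def cosine_similarity_matrix(keys: List[Vector]) -> List[List[float]]:
    """Pairwise cosine similarity between token keys.

    Cosine similarity is the dot product of L2-normalized vectors, so each
    key is normalized once and then all pairwise dot products are taken.
    """
    normed = []
    for k in keys:
        norm = math.sqrt(sum(v * v for v in k)) or 1.0  # guard against zero vectors
        normed.append([v / norm for v in k])
    return [[sum(a * b for a, b in zip(u, v)) for v in normed] for u in normed]


keys = [[1.0, 0.0], [2.0, 0.0], [0.0, 3.0]]
sim = cosine_similarity_matrix(keys)
print(sim)  # [[1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
```

Normalizing first means the metric ignores key magnitude: keys 0 and 1 point in the same direction, so they count as identical even though their lengths differ.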

Once token similarities are known, the next step is to match similar tokens. This is the tricky part: matching has to be very fast, which rules out existing iterative solutions such as k-means clustering or graph cuts. Token Merging instead uses a novel bipartite soft matching algorithm to solve this matching problem.
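The paper describes bipartite soft matching in a few steps: split the tokens into two sets A and B of roughly equal size, connect each token in A to its most similar token in B, keep the r most similar connections, merge the tokens that are still connected, and concatenate the two sets back together. Below is a minimal plain-Python sketch of those steps. It is a simplification, not the authors' implementation: it assigns tokens to A and B alternately, merges with an unweighted average (the authors' version additionally tracks how many patches each merged token represents), and does not protect the class token from merging.

```python
import math
from typing import List, Tuple

Vector = List[float]


def _cosine(u: Vector, v: Vector) -> float:
    nu = math.sqrt(sum(a * a for a in u)) or 1.0
    nv = math.sqrt(sum(b * b for b in v)) or 1.0
    return sum(a * b for a, b in zip(u, v)) / (nu * nv)


def bipartite_soft_matching(tokens: List[Vector], keys: List[Vector], r: int) -> List[Vector]:
    """Merge up to r tokens via bipartite soft matching (simplified sketch)."""
    # Step 1: split tokens alternately into set A (even) and set B (odd).
    a_idx = list(range(0, len(tokens), 2))
    b_idx = list(range(1, len(tokens), 2))
    r = min(r, len(a_idx))  # each A token can be merged at most once
    if r <= 0 or not b_idx:
        return [t[:] for t in tokens]

    # Step 2: one edge from every A token to its most similar B token.
    edges: List[Tuple[float, int, int]] = []
    for i in a_idx:
        best_j = max(b_idx, key=lambda j: _cosine(keys[i], keys[j]))
        edges.append((_cosine(keys[i], keys[best_j]), i, best_j))

    # Step 3: keep only the r most similar edges.
    edges.sort(key=lambda e: e[0], reverse=True)
    kept = edges[:r]
    merged_a = {i for _, i, _ in kept}

    # Step 4: average each B token with all A tokens merged into it.
    groups = {j: [tokens[j]] for j in b_idx}
    for _, i, j in kept:
        groups[j].append(tokens[i])
    new_b = [[sum(col) / len(groups[j]) for col in zip(*groups[j])] for j in b_idx]

    # Step 5: concatenate the unmerged A tokens and the updated B tokens.
    unmerged_a = [tokens[i][:] for i in a_idx if i not in merged_a]
    return unmerged_a + new_b


tokens = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 2.0], [1.0, 0.2]]
print(bipartite_soft_matching(tokens, keys=tokens, r=2))
# [[1.0, 0.2], [1.0, 0.0], [0.0, 1.5]]
```

Because every edge goes from A to B and each A token has exactly one outgoing edge, merges cannot chain into each other, so the whole step finishes in a single pass with no iterative clustering.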

This was a brief summary of Token Merging, a simple technique that increases the throughput and real-world training speed of ViT models. In some cases, token merging can double training speed. It can be applied to image, video, and audio tasks, speeding up state-of-the-art models with only a small drop in accuracy.

This article is a research summary written by Marktechpost Staff based on the research paper 'Token Merging: Your ViT But Faster'. All credit for this research goes to the researchers on this project.

Ekrem Çetinkaya received his B.Sc. in 2018 and M.Sc. in 2019 from Ozyegin University, Istanbul, Türkiye. He wrote his M.Sc. thesis about image denoising using deep convolutional networks. He is currently pursuing a Ph.D. degree at the University of Klagenfurt, Austria, and working as a researcher on the ATHENA project. His research interests include deep learning, computer vision, and multimedia networking.