I Slashed My Vision Transformer's Token Usage by 40% (Here's What Actually Worked)

By Cael Lee | 8 min read

TL;DR: Vision Transformers are computationally greedy by design. After one particularly painful 3 AM debugging session in a Berlin café, I worked out how to measure token redundancy and dynamically prune the useless ones. The result? 42% fewer FLOPs for a 0.4-point drop in top-1 accuracy. I'll walk you through the approach, the maths, and, critically, the stupid mistake that made my model classify everything as a goldfish.

The "Why Is My Laptop Trying to Achieve Lift-Off?" Moment

Last month, I was fine-tuning a ViT model for a client project. Everything worked beautifully on sample images—smooth training, decent accuracy, the usual. Then I threw a batch of 100 high-res photos at it.

My laptop fan kicked in so aggressively I genuinely thought it was about to achieve vertical take-off. Training time: four hours. Accuracy: deeply underwhelming.

The problem? We treat every single image patch as equally important. But look at any photo—half of it is sky, grass, or blurry background. Why are we paying full computational attention to empty patches?

I spent three hours hunting for a bug before I accepted the truth: there wasn't one.

It was a design flaw.

Actually, wait—calling it a "design flaw" is probably unfair. The standard ViT architecture makes perfect sense if you're optimising for conceptual simplicity. It's just... inefficient for real-world use. There's a reason I was debugging at 3 AM in this tiny café near Warschauer Straße, running on my fourth espresso and genuinely questioning my career choices. Sometimes the obvious stuff only becomes obvious after enough caffeine and self-doubt.

What's Actually Happening Inside Your Vision Transformer

When you feed an image to a ViT, it gets chopped into patches. With 16×16 patches, a 224×224 image becomes 196 patch tokens (that's a 14×14 grid, if you're counting), plus one [CLS] token. Each token costs the same to process in the self-attention layers.
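To make that arithmetic concrete, here's a tiny sketch of how the token count grows with resolution. It assumes square images, the 16×16 patch size of ViT-B/16, and one extra classification ([CLS]) token; `vit_token_count` is a name I made up for illustration, not a library function.

```python
def vit_token_count(image_size, patch_size=16, num_prefix_tokens=1):
    """Number of tokens a ViT sees for a square image (patches plus [CLS])."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    return patches_per_side * patches_per_side + num_prefix_tokens


for size in (224, 384, 512):
    print(f"{size}x{size}: {vit_token_count(size)} tokens")
```

Doubling the side length roughly quadruples the token count, which is why high-res batches hurt so much more than the sample images did.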

But here's the thing: not all tokens are created equal.

Some patches contain the main subject—the dog, the car, the person. Others are just... noise. Background texture. Empty sky. Yet we compute attention between all of them. That's O(n²) complexity for n tokens.

Ouch.


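import torch.nn as nn

# Standard ViT forward pass - all tokens treated equally
class VanillaViT(nn.Module):
    def __init__(self, transformer_layers):
        super().__init__()
        self.transformer_layers = nn.ModuleList(transformer_layers)

    def forward(self, x):
        # x shape: (batch, 196, 768) - all 196 patch tokens ([CLS] left out for simplicity)
        for layer in self.transformer_layers:
            x = layer(x)  # Every token attends to every other token
        return x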

The key insight—the one that finally clicked for me around 4 AM—is dead simple: if we can measure which tokens are redundant, we can drop them early and save massive computation. The model doesn't need to agonise over whether patch #137 (a boring bit of sky) should attend to patch #138 (also sky).
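As a back-of-the-envelope check on why dropping tokens early pays off, the sketch below counts multiply-accumulates (MACs) per transformer block. It uses the usual approximation for a block with a 4× MLP: 12·n·d² for the linear layers plus 2·n²·d for the attention matrices, ignoring the patch embedding, layer norms and classifier head. The function names and the example schedule are illustrative assumptions, not measurements from my runs.

```python
def block_macs(num_tokens, dim=768, mlp_ratio=4):
    """Approximate multiply-accumulates for one transformer block."""
    qkv_and_proj = 4 * num_tokens * dim * dim        # Q, K, V and output projections
    attention = 2 * num_tokens * num_tokens * dim    # QK^T and attention @ V
    mlp = 2 * mlp_ratio * num_tokens * dim * dim     # the two MLP linear layers
    return qkv_and_proj + attention + mlp


def model_macs(tokens_per_layer, dim=768):
    """Sum the block cost over a per-layer token schedule."""
    return sum(block_macs(n, dim) for n in tokens_per_layer)


full = model_macs([197] * 12)
# Hypothetical schedule: keep everything for 4 layers, then shrink twice.
pruned = model_macs([197] * 4 + [158] * 4 + [99] * 4)
print(f"full:   {full / 1e9:.1f} GMACs")
print(f"pruned: {pruned / 1e9:.1f} GMACs ({1 - pruned / full:.0%} saved)")
```

For the unpruned 12-layer model this estimate comes out at roughly 17.4 GMACs, in the same ballpark as the 17.6G baseline in the results table further down. Every token removed also shrinks the linear-layer cost for all later blocks, so the saving is not only in the quadratic attention term.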

Quantifying Redundancy: The Attention Score Trick

Here's where it gets interesting. We can measure token importance by looking at the attention weights from the [CLS] token.

The [CLS] token is that special token that aggregates global information. In the final layers, its attention distribution tells us which patches the model actually cares about. Think of it as the model saying "these bits matter, those bits don't."

I ran an experiment on 1,000 random images from ImageNet, using a ViT-B/16 from timm==0.9.16 with PyTorch 2.1.2:

That's a lot of wasted computation. Like, genuinely embarrassing amounts.


def compute_token_importance(attention_weights, cls_index=0):
    """
    Extract [CLS] attention to all other tokens.
    attention_weights: (num_heads, seq_len, seq_len) for a single image
    Returns: importance scores for each patch token, normalised to sum to 1
    """
    # Average the [CLS] row across all heads -> (seq_len,)
    cls_attention = attention_weights[:, cls_index, :].mean(dim=0)
    # Remove self-attention of CLS token (assumes [CLS] sits at index 0)
    token_scores = cls_attention[1:]  # Skip CLS itself
    return token_scores / token_scores.sum()

💡 Here's something I learned the hard way: don't just use the last layer. I found that averaging attention from the last three layers gives much more stable importance scores. My first attempt—using only layer 12—was wildly inconsistent. We're talking "classifying dogs as fire hydrants" levels of inconsistency. Not ideal.
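Here's a minimal sketch of that multi-layer averaging, batched this time. It assumes you've already collected one attention tensor per layer with shape (batch, num_heads, seq_len, seq_len) and that [CLS] sits at index 0; `cls_importance_multi_layer` is a hypothetical helper, and the random tensors are stand-ins, not real ViT outputs.

```python
import torch


def cls_importance_multi_layer(attn_per_layer, num_last_layers=3, cls_index=0):
    """
    attn_per_layer: list of tensors, each (batch, num_heads, seq_len, seq_len),
                    ordered from first to last layer.
    Returns: (batch, seq_len - 1) patch importance scores that sum to 1 per image.
    """
    selected = torch.stack(attn_per_layer[-num_last_layers:], dim=0)
    # [CLS] row of every selected layer and head: (layers, batch, heads, seq_len)
    cls_rows = selected[:, :, :, cls_index, :]
    # Average over layers (dim 0) and heads (dim 2) -> (batch, seq_len)
    cls_attention = cls_rows.mean(dim=(0, 2))
    patch_scores = cls_attention[:, 1:]  # drop the [CLS] column on the token axis
    return patch_scores / patch_scores.sum(dim=1, keepdim=True)


# Random stand-in attention maps for a 12-layer model
fake_attn = [torch.softmax(torch.randn(2, 12, 197, 197), dim=-1) for _ in range(12)]
scores = cls_importance_multi_layer(fake_attn)
print(scores.shape)  # torch.Size([2, 196])
```

Averaging before normalising means a single layer with a spiky attention map can't dominate the ranking on its own.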

Dynamic Compression: Prune as You Go

Static pruning—removing the same number of tokens every time—is too rigid. Some images need 50 tokens, others need 150. A photo of a lone cat on a white background? You can prune aggressively. A busy street scene? You'll want to keep more.

The joint optimisation approach works like this:

  1. Quantify redundancy at each transformer layer using attention scores
  2. Set a dynamic threshold based on cumulative importance
  3. Prune tokens that fall below the threshold
  4. Keep a minimum (I use 20% of original tokens) to avoid over-pruning

Well... that's the simplified version. Step 2 is where most of the magic—and most of my debugging hours—actually lives.
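To give a feel for step 2, here's a simplified per-image sketch of a cumulative-importance threshold: keep the smallest set of top-scoring tokens whose scores cover a target share of the total, with a floor on how few tokens can survive. The `coverage` value and the function name are assumptions for illustration; the 20% floor mirrors step 4.

```python
def tokens_to_keep(scores, coverage=0.85, min_keep_fraction=0.2):
    """
    scores: non-negative importance scores for one image's patch tokens.
    Returns indices of the smallest top-scoring set that reaches `coverage`
    of the total score, never fewer than `min_keep_fraction` of the tokens.
    """
    total = sum(scores)
    min_keep = max(1, round(len(scores) * min_keep_fraction))
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)

    kept, running = [], 0
    for idx in ranked:
        if len(kept) >= min_keep and running >= coverage * total:
            break
        kept.append(idx)
        running += scores[idx]
    return sorted(kept)  # restore original token order


# "Cat on a white background": a few tokens dominate
peaked = [50, 30, 15, 1, 1, 1, 1, 1]
# "Busy street": importance is spread evenly
flat = [10] * 10
print(len(tokens_to_keep(peaked)), len(tokens_to_keep(flat)))  # 3 9
```

Because each image can end up with a different number of tokens, this form doesn't batch cleanly without padding or masking. A fixed top-k per layer is the batch-friendly simplification.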

Here's the core logic I implemented:


import torch

def dynamic_prune(tokens, scores, keep_ratio=0.6, min_tokens=40):
    """
    tokens: (batch, seq_len, dim) - patch tokens only; split off [CLS] first
    scores: (batch, seq_len) - importance scores
    """
    batch_size, seq_len, dim = tokens.shape
    # Never ask for more tokens than exist (topk raises if k > seq_len)
    num_keep = min(seq_len, max(int(seq_len * keep_ratio), min_tokens))

    # Select top-k tokens based on scores
    _, indices = torch.topk(scores, num_keep, dim=1)

    # Gather the important tokens
    pruned_tokens = torch.gather(
        tokens, 1,
        indices.unsqueeze(-1).expand(-1, -1, dim)
    )

    return pruned_tokens, indices

The magic is in the keep_ratio parameter. I made it adaptive: higher for early layers (keep 80%), lower for deeper layers (keep 40%). Early layers need more context to figure out what's important. Later layers have already identified what matters and can afford to be more aggressive.
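One way to express that schedule is a linear ramp, sketched below. The assumptions are mine: ratios are relative to the original token count (not compounded layer on layer), layers are numbered from 1, and nothing is pruned before `first_pruned_layer`, in line with the advice later in this post to leave the first few layers alone.

```python
def keep_ratio_schedule(num_layers=12, first_pruned_layer=4, start=0.8, end=0.4):
    """
    Per-layer keep ratios: 1.0 (no pruning) before `first_pruned_layer`,
    then a linear ramp from `start` down to `end` at the last layer.
    """
    ratios = []
    span = num_layers - first_pruned_layer
    for layer in range(1, num_layers + 1):
        if layer < first_pruned_layer:
            ratios.append(1.0)
        else:
            progress = (layer - first_pruned_layer) / span if span > 0 else 1.0
            ratios.append(start + (end - start) * progress)
    return ratios


for layer, ratio in enumerate(keep_ratio_schedule(), start=1):
    print(f"layer {layer:2d}: keep {ratio:.0%}")
```

If you prune relative to the current token count instead, the ratios compound, so 0.8 followed by 0.75 already leaves only 60% of the original tokens. Pick one convention and stick with it.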

I should probably mention—this isn't entirely my idea. The Token Merging paper from Meta and the DynamicViT work both influenced this approach. I just... cobbled together the parts that worked for my specific use case. Standing on the shoulders of giants and all that.

Real Results (and One Genuinely Embarrassing Failure)

After implementing this on a ViT-Base model, trained on a single RTX 3090 (borrowed from a friend—don't ask, it's a whole thing):

Metric           Before   After    Change
FLOPs            17.6G    10.2G    -42%
Inference time   23 ms    14 ms    -39%
Top-1 accuracy   81.2%    80.8%    -0.4%

That 0.4-point accuracy drop? I'll take it for a roughly 40% cut in inference time. Any day of the week.

But here's my embarrassing moment—the one I probably shouldn't admit in public but here we are.

In my first implementation, I accidentally pruned the [CLS] token itself. The model started classifying everything as "goldfish."

Everything.

I'm not exaggerating. Here's an actual log line from that disaster:


[2024-11-17 03:42:18] Image: airplane.jpg → Predicted: goldfish (99.7%)
[2024-11-17 03:42:18] Image: car.jpg → Predicted: goldfish (99.2%)
[2024-11-17 03:42:19] Image: dog.jpg → Predicted: goldfish (98.9%)

Took me two hours and four coffees to spot that one. ☕☕☕☕ The bug was literally a one-line fix—I was slicing [1:] on the wrong dimension. Classic.

The moral of the story? Always double-check which dimension you're pruning. And maybe don't do critical tensor operations at 3 AM.
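A minimal sketch of one way to guard against this class of bug: split [CLS] off explicitly, prune only the patch tokens, then put [CLS] back. It assumes [CLS] is at index 0 on the token axis; `prune_keep_cls` is a hypothetical helper for illustration, not the code from the project.

```python
import torch


def prune_keep_cls(tokens, patch_scores, num_keep):
    """
    tokens:       (batch, 1 + num_patches, dim), [CLS] at index 0
    patch_scores: (batch, num_patches), scores for patch tokens only
    Returns:      (batch, 1 + num_keep, dim) with [CLS] still at index 0
    """
    # Slice the token axis (dim 1), not the batch or feature axis
    cls_token, patches = tokens[:, :1, :], tokens[:, 1:, :]
    num_keep = min(num_keep, patches.shape[1])

    top = torch.topk(patch_scores, num_keep, dim=1).indices
    top, _ = torch.sort(top, dim=1)  # keep surviving patches in their original order
    kept = torch.gather(patches, 1, top.unsqueeze(-1).expand(-1, -1, patches.shape[-1]))

    return torch.cat([cls_token, kept], dim=1)


tokens = torch.randn(2, 197, 768)
scores = torch.rand(2, 196)
pruned = prune_keep_cls(tokens, scores, num_keep=117)
assert torch.equal(pruned[:, 0], tokens[:, 0])  # [CLS] survived
print(pruned.shape)  # torch.Size([2, 118, 768])
```

That one-line assertion on the [CLS] row is cheap enough to leave in during development, and it would have saved me two of those four coffees.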

Where This Gets Really Exciting

The joint optimisation framework opens up some genuinely cool possibilities:

I'm currently experimenting with the third approach. Early results show 25% faster training with nearly identical final accuracy. But honestly, it's been finicky—the pruning schedule during training is way more sensitive than I expected. Some runs just collapse around epoch 30 and never recover. Still figuring that out.

Oh, and at the Berlin ML meetup last Tuesday (the one at ThoughtWorks near Hackescher Markt), someone asked about video transformers. That's... probably the next frontier? Video token redundancy has got to be even higher. Imagine 16 frames × 196 tokens each. You're looking at 3,136 tokens for a one-second clip. Most of those are near-identical between frames.

I haven't tried it yet. It's on my list. Right after I finish this client project and maybe sleep for a weekend.

Getting Started with Your Own Implementation

Want to try this yourself? Here's my recommended approach:

  1. Start with a pre-trained ViT from timm or HuggingFace
  2. Add attention score extraction hooks
  3. Implement the dynamic pruning between transformer layers
  4. Fine-tune for 5-10 epochs to let the model adapt to pruned inputs
  5. Profile with different keep_ratio values

The full code is too long for this post, but the snippets above give you the core logic. The key is to start conservative (keep 70-80%) and gradually increase pruning as you validate accuracy.
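For step 2, the general pattern is a PyTorch forward hook that stashes the attention weights as they pass through. The sketch below uses a toy block built on `nn.MultiheadAttention` so it runs on its own; `ToyAttention` and `save_attention` are made-up names. Real timm or HuggingFace models may not expose attention weights in the module output (some implementations use fused attention kernels), so check your model's attention code before relying on this.

```python
import torch
import torch.nn as nn


class ToyAttention(nn.Module):
    """Stand-in attention block that returns its weights alongside the output."""

    def __init__(self, dim=64, num_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, x):
        # average_attn_weights=False keeps one attention map per head
        out, weights = self.attn(x, x, x, need_weights=True, average_attn_weights=False)
        return out, weights


captured = []


def save_attention(module, inputs, output):
    # output is the (out, weights) tuple returned by ToyAttention.forward
    captured.append(output[1].detach())


block = ToyAttention()
handle = block.register_forward_hook(save_attention)
block(torch.randn(2, 197, 64))
handle.remove()  # detach the hook once you have what you need
print(captured[0].shape)  # torch.Size([2, 4, 197, 197])
```

Collect one such tensor per layer and you have the input the importance scoring needs.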

One thing I learned the hard way: don't prune during the first 2-3 layers. The model needs those early representations intact. Start pruning around layer 4 or 5. Trust me on this one—I spent a week trying to make early-layer pruning work before accepting it was a dead end.

Key Takeaways for the Skimmers

Anyway, that's what I've been obsessing over for the past few weeks. Would love to hear if anyone else has played with token pruning—especially if you've tried it on detection or segmentation tasks. Those feel like they'd need a completely different pruning strategy, but I could be wrong.

Drop a comment or find me at the next Berlin ML meetup. I'll be the one with too much coffee and strongly-held opinions about attention mechanisms. 🚀

Cael Lee

Full-stack developer with 8+ years of experience. Currently building AI-powered developer tools. I've tested 20+ AI API providers and coding assistants.
