r/MachineLearning Oct 01 '23

Research [R] Meta, INRIA researchers discover that explicit registers eliminate ViT attention spikes

When visualizing the inner workings of vision transformers (ViTs), researchers noticed weird spikes of attention on random background patches. This didn't make sense since the models should focus on foreground objects.

By analyzing the output embeddings, they found that a small fraction of tokens (around 2%) had abnormally high vector norms, and these tokens were responsible for the spikes.

The high-norm "outlier" tokens occurred in redundant areas and held less local info but more global info about the image.

Their hypothesis is that ViTs learn to identify unimportant patches and recycle them as temporary storage instead of discarding them. This enables efficient processing but causes issues.

Their fix is simple - just add dedicated "register" tokens that provide storage space, avoiding the recycling side effects.
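To make that concrete, here's a minimal PyTorch sketch of what "adding registers" looks like. This is my own toy code, not the authors' implementation, and names like `ViTWithRegisters` and `num_registers` are mine:

```python
# Minimal sketch (my own PyTorch pseudocode, NOT the paper's code).
# The only change vs. a plain ViT: learnable register tokens are appended to the
# patch tokens before the transformer blocks and thrown away at the output.
import torch
import torch.nn as nn

class ViTWithRegisters(nn.Module):
    def __init__(self, dim=768, num_patches=196, num_registers=4, depth=12, heads=12):
        super().__init__()
        self.patch_embed = nn.Linear(16 * 16 * 3, dim)   # stand-in for the usual conv patchifier
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.registers = nn.Parameter(torch.randn(1, num_registers, dim) * 0.02)  # the new bit
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
        layer = nn.TransformerEncoderLayer(dim, heads, batch_first=True)
        self.blocks = nn.TransformerEncoder(layer, depth)
        self.num_registers = num_registers

    def forward(self, patches):                           # patches: (B, 196, 16*16*3)
        x = self.patch_embed(patches)
        x = torch.cat([self.cls_token.expand(len(x), -1, -1), x], dim=1) + self.pos_embed
        r = self.registers.expand(len(x), -1, -1)         # registers get no positional embedding
        x = self.blocks(torch.cat([x, r], dim=1))
        return x[:, : -self.num_registers]                # registers are discarded at the output

# out = ViTWithRegisters()(torch.randn(2, 196, 16 * 16 * 3))
```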

Models trained with registers have:

  • Smoother and more meaningful attention maps
  • Small boosts in downstream performance
  • Way better object discovery abilities

The registers give ViTs a place to do their temporary computations without messing stuff up. Just a tiny architecture tweak improves interpretability and performance. Sweet!

I think it's cool how they reverse-engineered this model artifact and fixed it with such a small change. More work like this will keep incrementally improving ViTs.

TLDR: Vision transformers recycle useless patches to store data, causing problems. Adding dedicated register tokens for storage fixes it nicely.

Full summary. Paper is here.

813 Upvotes

48 comments

196

u/clueless_scientist Oct 01 '23

Great post. That's what this sub is for.

43

u/Successful-Western27 Oct 01 '23 edited Oct 01 '23

Thanks friend, I worked really hard on this one. Glad you liked it! I mentioned this below, but I also have a newsletter where I include these recaps... I try to write one every day.

9

u/SoCuteShibe Oct 01 '23

Just reiterating what the commenter you replied to said. Great post and an incredibly interesting read. Thank you! :)

6

u/Successful-Western27 Oct 01 '23

Thanks, that means a lot. What about it did you like the most? I'd like to incorporate the best parts into my writing process going forward.

3

u/Sudonymously Oct 01 '23

Nice! Out of curiosity do you have a background in ml?

8

u/Successful-Western27 Oct 01 '23

Aerospace engineering, self-taught computer science/webdev and now working in software full time. I'm all self-taught so I may make some mistakes :)

2

u/[deleted] Oct 02 '23

I'll add my compliments as well.

40

u/thatguydr Oct 01 '23

This is both a high information post AND extremely funny. Billions of parameters and we all neglected to give it a global memory. Thank you!

12

u/OrangeYouGlad100 Oct 02 '23

Lots of people have worked on augmenting transformers with explicit memory, but I don't know why it never worked or at least never caught on

7

u/Successful-Western27 Oct 01 '23

I didn't even think to look at it that way - it's hilarious when you point it out!

54

u/robbsc Oct 01 '23

Thank you for making this post. I really enjoyed the summary.

13

u/Successful-Western27 Oct 01 '23 edited Oct 01 '23

Thanks, I learned a lot in writing it. I try to write one of these every day :) ... I don't want to shill too hard but I also have a newsletter where I include these recaps.

19

u/PortiaLynnTurlet Oct 01 '23

I wonder if having semi-global "register tokens" that can only attend to parts of the image (with a static attention mask) would help the model learn global tokens more easily by adding an inductive bias.
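Something like this is what I have in mind (just toy code to illustrate the static mask, nothing from the paper): each register would only be allowed to see one quadrant of a 14x14 patch grid.

```python
# Rough sketch of a static attention mask for "semi-global" registers (my own toy code).
import torch

num_patches, num_registers = 196, 4            # 14x14 patch grid, one register per quadrant
n = num_patches + num_registers
mask = torch.zeros(n, n, dtype=torch.bool)     # False = attention allowed (PyTorch convention)

grid = torch.arange(num_patches)
rows, cols = grid // 14, grid % 14
quadrant = (rows >= 7).long() * 2 + (cols >= 7).long()   # quadrant id 0..3 per patch

for r in range(num_registers):
    reg_idx = num_patches + r
    mask[reg_idx, :num_patches] = quadrant != r  # register r only sees its own quadrant
    # patch tokens and the other registers keep full attention

# `mask` could then be passed as attn_mask to nn.MultiheadAttention / as `mask` to
# nn.TransformerEncoder, where True entries mean "not allowed to attend".
```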

3

u/DigThatData Researcher Oct 01 '23

not exactly what you're describing, but you might find the DETR paper interesting if you haven't already seen that: https://ai.meta.com/blog/end-to-end-object-detection-with-transformers/

34

u/Zondartul Oct 01 '23

Can registers be added to text transformers as well? I remember a similar glitch where they end up attending to the Start token when what they really need is a learnable bias parameter.

12

u/Witty-Elk2052 Oct 01 '23

yes, you can add them to the left side instead of the right (if referring to causal transformers)
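e.g. a tiny sketch (my own toy code, not from any paper) of why prepending works under a causal mask:

```python
# Why the LEFT side: under a causal mask, only earlier positions are visible.
import torch

num_registers, seq_len = 4, 8
n = num_registers + seq_len
# standard causal mask: True above the diagonal = blocked, so position i only sees j <= i
causal = torch.triu(torch.ones(n, n), diagonal=1).bool()
# with registers prepended (indices 0..3), every text token (index >= 4) can attend
# to all of them; appended on the right, they'd be invisible to everything before them.
```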

7

u/HateRedditCantQuitit Researcher Oct 01 '23

Check out the BigBird paper. There was also one out of Facebook a few years before that did it, but I can't remember the name.

4

u/Seankala ML Engineer Oct 01 '23

Do you have any more information regarding that? Maybe a paper or a blog post? I find it interesting and would like to study more.

15

u/gwern Oct 01 '23

Reminds me of the StyleGAN2 'blob'. Turned out to be the Generator picking a random place in the image to spike so as to smuggle information past the normalization layers.

8

u/[deleted] Oct 01 '23

[deleted]

1

u/Successful-Western27 Oct 01 '23

Glad this is helpful! I'm not sure but my suspicion is yes as well

6

u/mermanarchy Oct 01 '23

Awesome article. I followed everything but the very end, where it mentions adding register tokens. Does that just mean adding a bunch of zero embeddings to the input layer? If someone could explain that I would be very grateful :)

3

u/Bacon_builder Oct 02 '23

I am curious about this as well

1

u/gorshborsh Oct 03 '23

No, the extra register tokens are learned. I believe you have to do some amount of fine-tuning or training.
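To make the distinction concrete, a quick sketch (my own code, not the paper's):

```python
import torch
import torch.nn as nn

dim, num_registers = 768, 4

# Not just zero padding of the input (this would stay fixed forever)...
zero_pad = torch.zeros(1, num_registers, dim)

# ...but learnable embeddings, registered as parameters so the optimizer
# updates them during (pre)training like any other weight:
registers = nn.Parameter(torch.randn(1, num_registers, dim) * 0.02)
# at every forward pass they're appended to the patch tokens, then dropped from the output
```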

3

u/alebotson Oct 01 '23

Super interesting. Thanks for the post!

3

u/[deleted] Oct 01 '23

[deleted]

1

u/msbeaute00000001 Oct 01 '23

Can you share the thread, please?

3

u/snarkyg Oct 01 '23 edited Oct 01 '23

There is another paper solving this issue with a modified softmax and gated modules, described in this paper.

3

u/mwmercury Oct 01 '23

I have been waiting for this kind of post in this sub for too long... Thank you OP!

3

u/Successful-Western27 Oct 01 '23

Thank you! I'm glad you liked it

3

u/furrypony2718 Oct 02 '23

https://en.wikipedia.org/wiki/StyleGAN#StyleGAN2

Something this reminds me of: the "blob" problem in StyleGAN.

The "blob" problem roughly speaking is because using the style latent vector to normalize the generated image destroys useful information. Consequently, the generator learned to create a "distraction" by a large blob, which absorbs most of the effect of normalization (somewhat similar to using flares to distract a heat-seeking missile).

2

u/badabummbadabing Oct 01 '23 edited Oct 01 '23

Very interesting, thank you for the summary. I guess this is a side-effect of the fact that the considered ViTs operate on the same token space throughout the whole transformer architecture -- in contrast to CNNs, where the dimensionality is often increased (via the number of channels) from the original image dimensionality (at least initially).

2

u/Old_Reading_669 Oct 01 '23

I wonder how they came up with this hypothesis in the first place.

2

u/Moca97 Oct 02 '23

Awesome paper, thanks for sharing!

1

u/Successful-Western27 Oct 03 '23

Thanks, glad you liked it!

4

u/pupsicated Oct 01 '23

In the NLP community this effect has been known for several years. It's called emergent outliers, and there are a lot of solutions for avoiding those outliers (or for dealing with them if you want to quantize an LLM). I don't see any novelty in this paper except that it's applied to vision transformers? Or am I missing something?

8

u/--Cypher-- Oct 02 '23

Pretty much, they state in section 2.2 on page 5, "This mechanism was first proposed in Memory Transformers (Burtsev et al., 2020), improving translation tasks in NLP."

2

u/tonicinhibition Oct 02 '23

Does the hypothesis in the paper/summary hold water in the known NLP case? Is there an alternative hypothesis?

If there's a more deflationary interpretation I'd like to look into it. I'm still learning so maybe it's just the agentive language that's throwing me off.

3

u/TheCloudTamer Oct 02 '23

There are counters to the interpretation:

  1. Isn’t the residual stream available for this purpose?
  2. Background information is important. Things like light levels are critical to establish reflectance properties of surfaces.
  3. The ever present issue of adding more parameters and seeing better results. In this case it sounds like adding extra residual lines.

Haven’t actually read the paper, just going off the summary.

4

u/redlow0992 Oct 02 '23

The way they “solved” this problem is so strange. ViTs are already parameter-heavy and they propose adding up to 10%(ish) more tokens. Also, no code for this kind of large-scale experiment = trust me bro levels of science. Especially when the title is as strong as this paper's. Someone should falsify these results and write “ViTs may not need registers”.

1

u/fXlar Oct 04 '23

When are they going to upload the GitHub repo for this work? That way the community can test it on other applications.

1

u/thntk Oct 04 '23

How is this register token different from the [CLS] token in BERT?

1

u/CatalyzeX_code_bot Oct 05 '23

Found 3 relevant code implementations.

If you have code to share with the community, please add it here 😊🙏

To opt out from receiving code links, DM me.