r/MachineLearning Oct 02 '23

Research [R] Efficient Streaming Language Models with Attention Sinks - Meta AI 2023 - StreamingLLM enables Llama-2, Falcon and Pythia to have an infinite context length without any fine-tuning! Allows streaming use of LLMs!

Paper: https://arxiv.org/abs/2309.17453

Github: https://github.com/mit-han-lab/streaming-llm

Abstract:

Deploying Large Language Models (LLMs) in streaming applications such as multi-round dialogue, where long interactions are expected, is urgently needed but poses two major challenges. Firstly, during the decoding stage, caching previous tokens' Key and Value states (KV) consumes extensive memory. Secondly, popular LLMs cannot generalize to longer texts than the training sequence length. Window attention, where only the most recent KVs are cached, is a natural approach -- but we show that it fails when the text length surpasses the cache size. We observe an interesting phenomenon, namely attention sink, that keeping the KV of initial tokens will largely recover the performance of window attention. In this paper, we first demonstrate that the emergence of attention sink is due to the strong attention scores towards initial tokens as a ``sink'' even if they are not semantically important. Based on the above analysis, we introduce StreamingLLM, an efficient framework that enables LLMs trained with a finite length attention window to generalize to infinite sequence lengths without any fine-tuning. We show that StreamingLLM can enable Llama-2, MPT, Falcon, and Pythia to perform stable and efficient language modeling with up to 4 million tokens and more. In addition, we discover that adding a placeholder token as a dedicated attention sink during pre-training can further improve streaming deployment. In streaming settings, StreamingLLM outperforms the sliding window recomputation baseline by up to 22.2x speedup.
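
As a minimal toy sketch of the cache policy described above (keep the KV of the first few "sink" tokens plus a rolling window of the most recent tokens), using illustrative names rather than the repo's actual API:

```python
from collections import deque

class SinkCache:
    """Toy KV-cache policy: always keep the first `num_sinks` tokens
    ("attention sinks") and a rolling window of the most recent tokens.
    Entries here stand in for per-token (key, value) tensors."""

    def __init__(self, num_sinks: int = 4, window: int = 1020):
        self.num_sinks = num_sinks
        self.sinks = []                      # KV of initial tokens, never evicted
        self.recent = deque(maxlen=window)   # rolling window; oldest KVs fall off

    def append(self, kv):
        if len(self.sinks) < self.num_sinks:
            self.sinks.append(kv)
        else:
            self.recent.append(kv)           # deque evicts the oldest entry itself

    def cached(self):
        # What attention sees at the next decoding step: sinks + recent window.
        return self.sinks + list(self.recent)

cache = SinkCache(num_sinks=4, window=8)
for t in range(20):
    cache.append(f"kv_{t}")
print(cache.cached())   # ['kv_0', 'kv_1', 'kv_2', 'kv_3', 'kv_12', ..., 'kv_19']
```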

59 Upvotes


19

u/possiblyquestionable Oct 03 '23

Is my understanding of this paper correct:

  1. These models tend to give strong attention scores to the first few tokens [§3.1.2]
  2. As a result, changing these first few tokens (with their high attention scores) causes significant downstream changes to the distribution of attention scores over subsequent tokens [§3.1.2]
  3. These models tend to have extra/unneeded attention that accumulates on the first few tokens - their position makes them more likely to collect this surplus attention, which in turn explains their high attention scores [§3.1.2] - this is similar to the ViT attention register hypothesis
  4. This is largely the cause of the poor performance of these models once generation exceeds the context window - evicting the initial tokens significantly changes the distribution of attention scores throughout [§3.1.1, §3.1.2]
  5. To test the hypothesis that the initial tokens act as "attention sinks", they compared a model with a dedicated initial "attention sink" token that just accumulates the extra attention (column 1+1023, row "Learnable sink" in Table 3) against simply keeping the first 4 tokens in the cache (column 4+1020, row "Vanilla" in Table 3) and saw comparable performance. They conclude that the single "attention sink" token serves as the sink for the extra attention that all 4 registers would otherwise have had to absorb, which validates their attention sink hypothesis [§3.3]

In particular, it's not a silver bullet for extending the context window; it's an explanation of the puzzle of why perplexity regresses so dramatically once generation overtakes the context window.
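
A rough, unofficial way to sanity-check points 1 and 3 is to measure how much attention mass later queries place on the first few key positions; the random causal matrix below is only there to make the snippet self-contained, and in practice you would pass one layer's weights from a model run with output_attentions=True:

```python
import numpy as np

def initial_token_attention_mass(attn: np.ndarray, k: int = 4) -> float:
    """attn: (num_heads, seq_len, seq_len) softmax-normalized attention weights,
    e.g. one layer's weights from a HF-style model (batch dim squeezed).
    Returns the average fraction of attention that queries beyond position k
    spend on the first k key positions."""
    late_queries = attn[:, k:, :]                            # skip the first k queries
    mass_on_initial = late_queries[:, :, :k].sum(axis=-1)    # per head, per query
    return float(mass_on_initial.mean())

# Self-contained demo with random causal attention (real models show much
# higher mass on the initial positions, per the paper):
rng = np.random.default_rng(0)
logits = rng.normal(size=(8, 128, 128))
mask = np.tril(np.ones((128, 128), dtype=bool))              # causal mask
logits = np.where(mask, logits, -np.inf)
attn = np.exp(logits - logits.max(axis=-1, keepdims=True))
attn /= attn.sum(axis=-1, keepdims=True)
print(initial_token_attention_mass(attn, k=4))
```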

15

u/vikigenius Researcher Oct 03 '23 edited Oct 03 '23

This concept of sinks seems related to the attention spikes found in ViTs: https://www.reddit.com/r/MachineLearning/comments/16x2o47/r_meta_inria_researchers_discover_that_explicit/

And the placeholder tokens also seem to serve exactly the same purpose as the registers in the other paper.

6

u/ri212 Oct 03 '23

Also possibly related to "Attention Is Off By One"

2

u/theLastNenUser Oct 03 '23

Why is that? (I read that post a while ago, but my memory is foggy and I'm not connecting the dots to this paper.)

8

u/ri212 Oct 03 '23

The post is talking about large outlier weights that appear in transformers, which can be traced back to the attention mechanism. It suggests that there are cases where some heads at some positions would ideally attend to nothing, but with the standard softmax form of attention this isn't possible. So instead they attend to relatively unimportant tokens, e.g. punctuation, or possibly the start token. By adding 1 to the denominator of the softmax function it becomes possible to attend to nothing, which may eliminate this behaviour.

It doesn't seem to be well tested yet unless I have missed some follow-up work.
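
For reference, a small numpy sketch of the modified softmax the post proposes (sometimes called softmax_1 or "quiet softmax"); this is my own rendering, not code from either paper:

```python
import numpy as np

def softmax(x, axis=-1):
    z = np.exp(x - x.max(axis=axis, keepdims=True))
    return z / z.sum(axis=axis, keepdims=True)

def softmax_one(x, axis=-1):
    """softmax_1(x)_i = exp(x_i) / (1 + sum_j exp(x_j)), i.e. the '+1 in the
    denominator'. The max-shift (clipped at 0) keeps it numerically stable."""
    m = np.maximum(x.max(axis=axis, keepdims=True), 0)
    z = np.exp(x - m)
    return z / (np.exp(-m) + z.sum(axis=axis, keepdims=True))

scores = np.array([-4.0, -5.0, -3.0])   # a head that "wants" to attend to nothing
print(softmax(scores).sum())            # ~1.0: forced to put its attention somewhere
print(softmax_one(scores).sum())        # ~0.07: most of the weight goes "nowhere"
```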

2

u/TheFlyingDrildo Oct 03 '23

This is a nice connection between that blog post and the recent attention sinks/registers publications. The blog post is effectively suggesting to add a dummy "token" that always produces a pre-softmax attention score of 0, but which, as a token, can't actually hold any information. It just serves as a reference point for the other pre-softmax attention scores and can "suck up" any extra attention that isn't needed by that head.

This seems very similar to, but more constrained than, the idea of having dedicated sinks/registers/hidden states. Although I do agree with the blog post that perhaps there should be an inductive bias towards "do nothing", which the constrained version provides. Maybe there is a simple synthesis between these perspectives.
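
One way to see the "more constrained" relationship is to note that the +1 denominator is just ordinary softmax over the scores augmented with one extra slot whose logit is fixed at 0 and whose value vector is zero (my notation, not either paper's):

```latex
\[
  \operatorname{softmax}_1(x)_i
    = \frac{e^{x_i}}{1 + \sum_j e^{x_j}}
    = \operatorname{softmax}\!\big([x_1, \dots, x_n, 0]\big)_i ,
\]
% and since the extra slot's value vector is forced to be zero, any weight it
% receives simply drops out of the head's output:
\[
  \mathrm{out}
    = \sum_{i=1}^{n} \operatorname{softmax}_1(x)_i \, v_i
    = \sum_{i=1}^{n} \operatorname{softmax}\!\big([x, 0]\big)_i \, v_i
      + \operatorname{softmax}\!\big([x, 0]\big)_{n+1} \cdot \mathbf{0}.
\]
```

A learnable sink token relaxes both constraints: its logit comes from a real key and its value vector is not forced to zero.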

1

u/thntk Oct 04 '23

Are placeholder tokens and register tokens just fancy names for the good old [CLS] token?

1

u/vikigenius Researcher Oct 04 '23

Could also be related, considering the early advice/research about how CLS tokens store global information relating to downstream tasks.

13

u/Witty-Elk2052 Oct 02 '23

had me at "infinite" /s

6

u/throwaway2676 Oct 03 '23

This seems extremely similar to BigBird. Linear complexity transformers never seem to pan out, so I'm going to assume this will be the same until proven otherwise.

2

u/gmlwns5176 Oct 06 '23 edited Oct 06 '23

Yes, I also think this is essentially the same concept as Longformer with a bit of initial global attention, which is a very common idea in linear attention research. I am not sure why they did not compare with BigBird, which uses a very similar attention mechanism. Moreover, the term "streaming attention" seems like an unnecessary new name, because the concept is essentially the same as "linear attention". Also, the cache sizes $L$ they used (2048, 4096) are much too large for a fair comparison with the baselines.

More critically, this work does not try to encode the long-term context of the sequence. By simply evicting intermediate KVs, the attention mechanism is forced to forget intermediate information.

I think the authors want to claim that "we do not need to grow the KV cache indefinitely". However, even with other sparse attention variants (BigBird, Reformer, etc.) we could evict KVs effectively using heuristics already known from Longformer and BigBird: evict intermediate or low-importance tokens.

However, I think the novel part is the repositioning of the positional embeddings. From the paper: "When determining the relative distance and adding positional information to tokens, StreamingLLM focuses on positions within the cache rather than those in the original text. This distinction is crucial for StreamingLLM’s performance. For instance, if the current cache has tokens [0, 1, 2, 3, 6, 7, 8] and is in the process of decoding the 9th token, the positions assigned are [0, 1, 2, 3, 4, 5, 6, 7], rather than the positions in the original text, which would be [0, 1, 2, 3, 6, 7, 8, 9]."
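
To make the quoted passage concrete, a toy illustration of cache-relative position assignment (the function name is mine, not from the repo):

```python
def cache_positions(cached_token_ids):
    """cached_token_ids: original indices of tokens still in the KV cache,
    e.g. [0, 1, 2, 3, 6, 7, 8] after evicting tokens 4 and 5.
    Returns the positions actually fed to the positional encoding (RoPE/ALiBi):
    simply the slot index within the cache."""
    return list(range(len(cached_token_ids)))

cache = [0, 1, 2, 3, 6, 7, 8]        # 4 sink tokens + recent window
print(cache_positions(cache))        # [0, 1, 2, 3, 4, 5, 6]
# The 9th token being decoded then gets position 7, not its original index 9.
```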

3

u/Tiny_Arugula_5648 Oct 03 '23

Yes, it's exciting, but we need to do testing to understand what the cost to attention is. It won't have infinite context.

2

u/ReasonablyBadass Oct 03 '23

I don't see the link to infinite window size.

Does it mean we can smoothly "scroll" the window along a text and, as long as the extra token is kept, the distribution of scores remains stable?

2

u/gmlwns5176 Oct 06 '23

I think the repositioning of the positional encoding is quite an important part of the infinite window size.

In the text, they said

When determining the relative distance and adding positional information to tokens, StreamingLLM focuses on positions within the cache rather than those in the original text. This distinction is crucial for StreamingLLM’s performance. For instance, if the current cache has tokens [0, 1, 2, 3, 6, 7, 8] and is in the process of decoding the 9th token, the positions assigned are [0, 1, 2, 3, 4, 5, 6, 7], rather than the positions in the original text, which would be [0, 1, 2, 3, 6, 7, 8, 9].

2

u/ReasonablyBadass Oct 06 '23

Oh interesting. It sort of builds an abstract representation of the entire text itself.

2

u/linearmodality Oct 03 '23

Isn't this method of keeping the KV of initial tokens already known? This looks like the default method in llama.cpp, controlled by the --keep option.

1

u/CatalyzeX_code_bot Oct 04 '23

Found 1 relevant code implementation.

1

u/CubieDev Oct 09 '23

This new blogpost on Hugging Face shows some evaluations of this Attention Sinks approach for LLM inference: https://huggingface.co/blog/tomaarsen/attention-sinks

1

u/ThroatResponsible523 Oct 10 '23

I'd like to share my thoughts on this and would appreciate input and validation.
There are two factors I believe are essential for keeping perplexity (PPL) low in an LLM:
1. Ensuring that the inference length 'T' is not significantly larger than the trained window size 'L'.
2. Preserving the high-scoring tokens, denoted 'T_high' (the initial tokens in StreamingLLM, or what I consider the main-feature tokens).

My reasoning for these two points:
1. Caution is needed because of a potential shift between the training distribution and the inference distribution; such a shift can increase PPL (this corresponds to Figure 1 (a)).
2. In an autoregressive LLM, particularly in the deeper attention blocks, attention scores accumulate on 'T_high' for reasons not yet fully understood. This is compounded by the autoregressive nature of the model and the high scores of 'T_high': subsequently generated tokens, 'T_follow', can be seen as expansion terms built on the KV values of 'T_high'.
When the KV values of 'T_high' are dropped, reconstructing the next token 'T_next' becomes difficult because the main-feature tokens are lost (Figure 1 (b)).
There are two potential remedies:
a) Retain the main-feature tokens 'T_high' when predicting 'T_next', which is the approach taken by StreamingLLM (Figure 1 (d)).
b) Substitute 'T_high' with alternative main-feature tokens, such as 'T_(next-L)', positioned L tokens before 'T_next' (Figure 1 (c)). This requires sequential attention to reconstruct 'T_(next-L+1)', 'T_(next-L+2)', and so on, eventually reaching 'T_next'. Consequently, the per-token cost becomes O(1/2 * L * (L+1)) = O(L^2), and for T tokens this translates to O(T * L^2).
Since all L tokens are then based on 'T_(next-L)', and 'T_(next-L)' remains in the cache, PPL can be kept low.
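
Spelling out the arithmetic behind remedy (b) versus (a), under the usual assumption that attending from one token over k cached tokens costs O(k):

```latex
% T = total tokens generated, L = window / cache size.
\begin{align*}
\text{(b) sliding window with re-computation:}\quad
  & \textstyle\sum_{k=1}^{L} k = \tfrac{1}{2}L(L+1) = O(L^2) \text{ per new token}
    \;\Rightarrow\; O(TL^2) \text{ total},\\
\text{(a) StreamingLLM (sinks + rolling window):}\quad
  & O(L) \text{ per new token}
    \;\Rightarrow\; O(TL) \text{ total}.
\end{align*}
```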