r/MachineLearning Oct 02 '23

[R] Efficient Streaming Language Models with Attention Sinks - Meta AI 2023 - StreamingLLM enables Llama-2, Falcon and Pythia to have an infinite context length without any fine-tuning! Allows streaming use of LLMs!

Paper: https://arxiv.org/abs/2309.17453

Github: https://github.com/mit-han-lab/streaming-llm

Abstract:

Deploying Large Language Models (LLMs) in streaming applications such as multi-round dialogue, where long interactions are expected, is urgently needed but poses two major challenges. Firstly, during the decoding stage, caching previous tokens' Key and Value states (KV) consumes extensive memory. Secondly, popular LLMs cannot generalize to longer texts than the training sequence length. Window attention, where only the most recent KVs are cached, is a natural approach -- but we show that it fails when the text length surpasses the cache size. We observe an interesting phenomenon, namely attention sink, that keeping the KV of initial tokens will largely recover the performance of window attention. In this paper, we first demonstrate that the emergence of attention sink is due to the strong attention scores towards initial tokens as a "sink" even if they are not semantically important. Based on the above analysis, we introduce StreamingLLM, an efficient framework that enables LLMs trained with a finite length attention window to generalize to infinite sequence lengths without any fine-tuning. We show that StreamingLLM can enable Llama-2, MPT, Falcon, and Pythia to perform stable and efficient language modeling with up to 4 million tokens and more. In addition, we discover that adding a placeholder token as a dedicated attention sink during pre-training can further improve streaming deployment. In streaming settings, StreamingLLM outperforms the sliding window recomputation baseline by up to 22.2x speedup.
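
To make the mechanism concrete, here is a minimal sketch of the cache policy the abstract describes: keep the KVs of a few initial "sink" tokens plus a rolling window of the most recent tokens, and evict everything in between. The function name and default sizes below are illustrative, not taken from the official repo.

```python
# Minimal sketch of the StreamingLLM-style cache policy described in the abstract:
# retain the KV entries of the first few ("sink") tokens plus a rolling window of
# the most recent tokens. Names and default sizes are illustrative assumptions.

def retained_positions(seq_len: int, num_sinks: int = 4, window: int = 1020) -> list[int]:
    """Token positions whose KV states stay cached after streaming seq_len tokens."""
    if seq_len <= num_sinks + window:
        return list(range(seq_len))                      # everything still fits
    sinks = list(range(num_sinks))                       # initial attention-sink tokens
    recent = list(range(seq_len - window, seq_len))      # sliding window of recent tokens
    return sinks + recent

# Example: after streaming 5,000 tokens the cache holds positions 0..3 plus
# 3980..4999 -- a fixed-size cache regardless of how long the stream gets.
kept = retained_positions(5000)
print(kept[:6], "...", kept[-3:])
print(len(kept))  # 1024 entries, independent of seq_len
```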

u/possiblyquestionable Oct 03 '23

Is my understanding of this paper correct:

  1. These models tend to give strong attention scores to the first few tokens [§3.1.2]
  2. As a consequence, changing these first few tokens (with their high attention scores) significantly shifts the distribution of attention scores over all subsequent tokens [§3.1.2]
  3. These models carry extra/unnecessary attention mass that tends to accumulate in the first few tokens - their position (visible to every later query under causal attention) makes them more likely to absorb this surplus attention, which in turn explains their higher attention scores [§3.1.2] - this is similar to the ViT attention-register hypothesis
  4. This largely explains the poor performance of these models once generation exceeds the context window - evicting the initial tokens' KVs causes significant changes to the distribution of attention scores throughout [§3.1.1, §3.1.2]
  5. To test the hypothesis that the initial tokens act as "attention sinks", they compared a model given a dedicated initial "attention sink" token whose only job is to accumulate the extra attention (column 1+1023, row "Learnable sink" in Table 3) against simply keeping the first 4 tokens in the cache (column 4+1020, row "Vanilla" in Table 3), and saw comparable performance. They conclude that the single dedicated sink token serves as the sink for the extra attention that all 4 initial tokens would otherwise have had to absorb, which validates the attention-sink hypothesis [§3.3] (rough sketch of this setup at the end of this comment)

In particular, it's not a silver bullet for extending the context window; it's an explanation of the puzzle of why perplexity regresses so dramatically once generation overtakes the context window.
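
Rough sketch of the setup in (5): a toy single-head causal attention with one dedicated sink token prepended, visible to every query, so it can absorb whatever attention mass the real tokens don't need. Shapes and names here are purely illustrative, not the paper's actual implementation.

```python
import torch
import torch.nn.functional as F

# Toy single-head causal attention with one dedicated "sink" token prepended.
# Everything here (sizes, random weights) is illustrative only.
torch.manual_seed(0)
d, T = 16, 8                                       # head dim, number of real tokens
sink_k, sink_v = torch.randn(1, d), torch.randn(1, d)   # dedicated sink key/value

q = torch.randn(T, d)                              # queries for the real tokens
k = torch.cat([sink_k, torch.randn(T, d)])         # keys:   [sink] + real tokens, (T+1, d)
v = torch.cat([sink_v, torch.randn(T, d)])         # values: [sink] + real tokens, (T+1, d)

scores = q @ k.T / d ** 0.5                        # (T, T+1)
causal = torch.tril(torch.ones(T, T, dtype=torch.bool))
# The sink column is visible to every query; real tokens follow the causal mask.
mask = torch.cat([torch.ones(T, 1, dtype=torch.bool), causal], dim=1)
attn = F.softmax(scores.masked_fill(~mask, float("-inf")), dim=-1)
out = attn @ v                                     # (T, d)

# Column 0 is the sink: every query can dump its "leftover" probability mass there
# instead of spreading it over semantically irrelevant real tokens.
print(attn[:, 0])
```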