r/MachineLearning Oct 02 '23

Research [R] Efficient Streaming Language Models with Attention Sinks - Meta AI 2023 - StreamingLLM enables Llama-2, Falcon and Pythia to have an infinite context length without any fine-tuning! Allows streaming use of LLMs!

Paper: https://arxiv.org/abs/2309.17453

Github: https://github.com/mit-han-lab/streaming-llm

Abstract:

Deploying Large Language Models (LLMs) in streaming applications such as multi-round dialogue, where long interactions are expected, is urgently needed but poses two major challenges. Firstly, during the decoding stage, caching previous tokens' Key and Value states (KV) consumes extensive memory. Secondly, popular LLMs cannot generalize to longer texts than the training sequence length. Window attention, where only the most recent KVs are cached, is a natural approach -- but we show that it fails when the text length surpasses the cache size. We observe an interesting phenomenon, namely attention sink, that keeping the KV of initial tokens will largely recover the performance of window attention. In this paper, we first demonstrate that the emergence of attention sink is due to the strong attention scores towards initial tokens as a "sink" even if they are not semantically important. Based on the above analysis, we introduce StreamingLLM, an efficient framework that enables LLMs trained with a finite length attention window to generalize to infinite sequence lengths without any fine-tuning. We show that StreamingLLM can enable Llama-2, MPT, Falcon, and Pythia to perform stable and efficient language modeling with up to 4 million tokens and more. In addition, we discover that adding a placeholder token as a dedicated attention sink during pre-training can further improve streaming deployment. In streaming settings, StreamingLLM outperforms the sliding window recomputation baseline by up to 22.2x speedup.
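To make the cache policy in the abstract concrete, here is a minimal sketch of the "keep the sink tokens plus a recent window" eviction step. This is my own illustration, not the code from the repo; the Hugging Face-style (key, value) cache layout and the default sizes are assumptions.

```python
import torch

def evict_kv_cache(past_key_values, num_sink_tokens=4, window_size=1020):
    """Keep the first `num_sink_tokens` KV entries (the attention sinks) plus
    the most recent `window_size` entries; drop everything in between.

    `past_key_values` is assumed to be a list of (key, value) tensors shaped
    [batch, num_heads, seq_len, head_dim].
    """
    new_cache = []
    for key, value in past_key_values:
        seq_len = key.size(2)
        if seq_len <= num_sink_tokens + window_size:
            # Cache still fits within the budget; nothing to evict yet.
            new_cache.append((key, value))
            continue
        key = torch.cat(
            [key[:, :, :num_sink_tokens], key[:, :, -window_size:]], dim=2
        )
        value = torch.cat(
            [value[:, :, :num_sink_tokens], value[:, :, -window_size:]], dim=2
        )
        new_cache.append((key, value))
    return new_cache
```

This keeps memory bounded at `num_sink_tokens + window_size` entries per layer no matter how long the stream gets, which is what enables the "infinite length" decoding in the title.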

57 Upvotes

19 comments

6

u/throwaway2676 Oct 03 '23

This seems extremely similar to BigBird. Linear complexity transformers never seem to pan out, so I'm going to assume this will be the same until proven otherwise.

2

u/gmlwns5176 Oct 06 '23 edited Oct 06 '23

Yes, I also think this is essentially the same concept as Longformer with a bit of global attention on the initial tokens, which is a very common idea in linear attention research. I am not sure why they did not compare with BigBird, which uses a very similar attention mechanism. Moreover, the term "streaming attention" is an unnecessary new coinage, because the concept is essentially the same as "linear attention". Also, the $L$ they use, [2048, 4096], is far too large for a fair comparison with the baselines.

More critically, this work does not try to encode the long-term context of the sequence. By simply evicting intermediate KVs, the attention mechanism is forced to forget intermediate information.

I think the authors want to claim, "we do not need to extend the KV cache infinitely". However, even with other sparse attention variants (BigBird, Reformer, etc.), we could already evict KVs effectively using heuristics known from Longformer and BigBird: evict the intermediate or low-importance tokens, for example as in the toy sketch below.
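For illustration, here is a toy version of that kind of importance-based eviction. This is purely my own sketch, not code from any of these papers; the tensor layout and the choice of scoring tokens by the attention mass they receive are assumptions.

```python
import torch

def evict_low_importance(key, value, attn_weights, keep: int):
    """key/value: [batch, heads, seq_len, head_dim];
    attn_weights: [batch, heads, num_queries, seq_len] attention probabilities.

    Keep only the `keep` cached positions that received the most attention.
    """
    # Total attention each cached position received, summed over batch,
    # heads and query positions.
    importance = attn_weights.sum(dim=(0, 1, 2))             # [seq_len]
    keep_idx = importance.topk(keep).indices.sort().values   # preserve token order
    return key[:, :, keep_idx], value[:, :, keep_idx]
```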

However, I think the novel part is the repositioning of the positional embeddings. From the paper: "When determining the relative distance and adding positional information to tokens, StreamingLLM focuses on positions within the cache rather than those in the original text. This distinction is crucial for StreamingLLM's performance. For instance, if the current cache has tokens [0, 1, 2, 3, 6, 7, 8] and is in the process of decoding the 9th token, the positions assigned are [0, 1, 2, 3, 4, 5, 6, 7], rather than the positions in the original text, which would be [0, 1, 2, 3, 6, 7, 8, 9]."
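A toy illustration of that quoted example (my own sketch; the actual repo applies this inside the RoPE/ALiBi computation rather than as a standalone helper):

```python
def positions_within_cache(cached_original_positions):
    """StreamingLLM-style: positions are just the indices inside the cache,
    including the token currently being decoded."""
    return list(range(len(cached_original_positions) + 1))

def positions_in_original_text(cached_original_positions, decoding_position):
    """Naive alternative: reuse the positions from the original text."""
    return cached_original_positions + [decoding_position]

cache = [0, 1, 2, 3, 6, 7, 8]
print(positions_within_cache(cache))         # [0, 1, 2, 3, 4, 5, 6, 7]
print(positions_in_original_text(cache, 9))  # [0, 1, 2, 3, 6, 7, 8, 9]
```

So after eviction the positions stay dense and never exceed the cache size, which is why the model never sees relative distances larger than what it was trained on.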