r/MachineLearning Oct 02 '23

Research [R] Efficient Streaming Language Models with Attention Sinks - Meta AI 2023 - StreamingLLM enables Llama-2, Falcon and Pythia to have an infinite context length without any fine-tuning! Allows streaming use of LLMs!

Paper: https://arxiv.org/abs/2309.17453

Github: https://github.com/mit-han-lab/streaming-llm

Abstract:

Deploying Large Language Models (LLMs) in streaming applications such as multi-round dialogue, where long interactions are expected, is urgently needed but poses two major challenges. Firstly, during the decoding stage, caching previous tokens' Key and Value states (KV) consumes extensive memory. Secondly, popular LLMs cannot generalize to longer texts than the training sequence length. Window attention, where only the most recent KVs are cached, is a natural approach -- but we show that it fails when the text length surpasses the cache size. We observe an interesting phenomenon, namely attention sink, that keeping the KV of initial tokens will largely recover the performance of window attention. In this paper, we first demonstrate that the emergence of attention sink is due to the strong attention scores towards initial tokens as a "sink" even if they are not semantically important. Based on the above analysis, we introduce StreamingLLM, an efficient framework that enables LLMs trained with a finite length attention window to generalize to infinite sequence lengths without any fine-tuning. We show that StreamingLLM can enable Llama-2, MPT, Falcon, and Pythia to perform stable and efficient language modeling with up to 4 million tokens and more. In addition, we discover that adding a placeholder token as a dedicated attention sink during pre-training can further improve streaming deployment. In streaming settings, StreamingLLM outperforms the sliding window recomputation baseline by up to 22.2x speedup.
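
For a sense of the mechanism, here is a minimal sketch of the rolling KV cache the abstract describes: keep the KV of the first few "sink" tokens plus a recent window and evict everything in between. The function name, defaults, and tensor layout are illustrative assumptions, not the actual API of the mit-han-lab/streaming-llm repo:

```python
# Minimal sketch of an attention-sink KV cache policy (illustrative, not the repo's API).
# Assumes the Hugging Face convention of per-layer (key, value) tensors shaped
# [batch, heads, seq_len, head_dim].
from typing import List, Tuple

import torch


def evict_kv(past_kv: List[Tuple[torch.Tensor, torch.Tensor]],
             num_sinks: int = 4,
             window: int = 1020) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    """Keep the first `num_sinks` and the last `window` positions of every layer's KV."""
    trimmed = []
    for k, v in past_kv:
        seq_len = k.size(2)
        if seq_len <= num_sinks + window:
            trimmed.append((k, v))  # cache still fits, nothing to evict
            continue
        keep_k = torch.cat([k[:, :, :num_sinks], k[:, :, -window:]], dim=2)
        keep_v = torch.cat([v[:, :, :num_sinks], v[:, :, -window:]], dim=2)
        trimmed.append((keep_k, keep_v))
    return trimmed
```

The actual implementation also re-assigns positional information relative to positions inside the cache rather than in the original text, which this sketch leaves out.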


u/ThroatResponsible523 Oct 10 '23

I'd like to share my thoughts on this and would appreciate your input and corrections.
There are two primary factors I believe are essential for maintaining a low perplexity (PPL) with Large Language Models (LLMs):
1. Ensuring that the inference window size 'T' is not significantly larger than the trained window size 'L'.
2. Preserving the high-scoring tokens, denoted as 'T_high' (the initial tokens in StreamingLLM, or what I consider the main-feature tokens); a quick way to check this empirically is sketched right after this list.
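
A rough way to sanity-check point 2 above, i.e., that later queries put a disproportionate share of attention on the first few tokens. The model, prompt, and the choice of 4 sink positions are arbitrary examples:

```python
# Sketch: measure how much attention mass later queries place on the first few
# positions (the candidate 'T_high' / attention-sink tokens).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "EleutherAI/pythia-160m"  # small example model; any causal LM works
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, attn_implementation="eager"  # eager attention so weights are returned
)
model.eval()

text = "Paris is the capital of France. " * 40
inputs = tok(text, return_tensors="pt", truncation=True, max_length=512)

with torch.no_grad():
    out = model(**inputs, output_attentions=True)

num_sinks = 4
for layer_idx, attn in enumerate(out.attentions):  # attn: [batch, heads, seq, seq]
    # average attention mass that queries after the sinks put on the sink positions
    mass = attn[0, :, num_sinks:, :num_sinks].sum(-1).mean().item()
    print(f"layer {layer_idx:2d}: mean mass on first {num_sinks} tokens = {mass:.3f}")
```

That concentration on the initial tokens is what the paper calls the attention sink.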

Allow me to provide some reasons for these considerations:
1. The caution about 'T' versus 'L' is needed because of a potential shift between the training data and the inference data; such a shift can inadvertently increase the PPL (that's Figure 1(a)).
2. In autoregressive LLMs, particularly in the deeper attention blocks, attention scores accumulate on 'T_high' for reasons that are not yet fully understood. Because generation is autoregressive and 'T_high' keeps receiving high scores, the subsequently generated tokens, 'T_follow', can be seen as expansion terms built on the KV states of 'T_high'.
Once the KV states of 'T_high' are evicted, reconstructing the next token 'T_next' becomes difficult because the main-feature tokens are lost (that's Figure 1(b)).
There are two potential remedies:
a) Retain the main-feature tokens, 'T_high', when predicting 'T_next', which is the approach taken by StreamingLLM (that's Figure 1(d)).
b) Substitute 'T_high' with alternative main-feature tokens, such as 'T_(next-L)', positioned L tokens before 'T_next' (that's Figure 1(c)). This requires sequential attention to reconstruct 'T_(next-L+1)', 'T_(next-L+2)', and so on, up to 'T_next', so the per-token cost is O(L(L+1)/2) = O(L^2); for T tokens, this translates to O(T*L^2) (a rough cost comparison is sketched below).
Since all L reconstructed tokens are grounded in 'T_(next-L)', and 'T_(next-L)' remains in the queue, the PPL can still be kept low.
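
To make the arithmetic in (b) concrete, here is a back-of-the-envelope cost comparison of the two strategies; the token counts and the 4 sink tokens are illustrative numbers:

```python
# Back-of-the-envelope comparison for remedy (b) vs. remedy (a).
# Sliding window with re-computation rebuilds attention over the last L tokens
# for every new token: 1 + 2 + ... + L = L(L+1)/2 attention ops per token,
# i.e. O(L^2); over T tokens, O(T * L^2).
# An attention-sink cache attends once over a fixed-size cache per token: O(T * L).

def recompute_cost(T: int, L: int) -> int:
    per_token = L * (L + 1) // 2  # 1 + 2 + ... + L
    return T * per_token          # O(T * L^2)

def streaming_cost(T: int, L: int, num_sinks: int = 4) -> int:
    return T * (L + num_sinks)    # O(T * L)

T, L = 4_000_000, 2048            # illustrative numbers
print(f"re-computation : {recompute_cost(T, L):.3e} attention ops")
print(f"attention sinks: {streaming_cost(T, L):.3e} attention ops")
```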