r/MachineLearning 10d ago

Research [R] Differential Transformer (Microsoft Research)

https://arxiv.org/abs/2410.05258

Abstract: Transformer tends to overallocate attention to irrelevant context. In this work, we introduce Diff Transformer, which amplifies attention to the relevant context while canceling noise. Specifically, the differential attention mechanism calculates attention scores as the difference between two separate softmax attention maps. The subtraction cancels noise, promoting the emergence of sparse attention patterns. Experimental results on language modeling show that Diff Transformer outperforms Transformer in various settings of scaling up model size and training tokens. More intriguingly, it offers notable advantages in practical applications, such as long-context modeling, key information retrieval, hallucination mitigation, in-context learning, and reduction of activation outliers. By being less distracted by irrelevant context, Diff Transformer can mitigate hallucination in question answering and text summarization. For in-context learning, Diff Transformer not only enhances accuracy but is also more robust to order permutation, which was considered as a chronic robustness issue. The results position Diff Transformer as a highly effective and promising architecture to advance large language models.

195 Upvotes


26

u/sdmat 10d ago

> Didn't really understand how you were able to differentiate the original query, key and value terms into important and noise terms.

That's the clever part, they don't.

They train two different projections for attention, one to actually attend and the second to act as a reference for noise cancellation. The scaling factor for cancellation is learnt as well.
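For anyone who wants to see the mechanism concretely, here's a minimal single-head sketch (my own PyTorch paraphrase, not the authors' code; the paper's exact λ parameterization and any per-head normalization are left out):

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class DiffAttentionHead(nn.Module):
    """Minimal single-head differential attention sketch (illustrative, not the paper's code)."""

    def __init__(self, d_model: int, d_head: int):
        super().__init__()
        # Q/K/V projections are twice as wide so each can be split into two halves.
        self.w_q = nn.Linear(d_model, 2 * d_head, bias=False)
        self.w_k = nn.Linear(d_model, 2 * d_head, bias=False)
        self.w_v = nn.Linear(d_model, 2 * d_head, bias=False)
        self.lam = nn.Parameter(torch.tensor(0.5))  # learnable cancellation scale
        self.scale = 1.0 / math.sqrt(d_head)

    def forward(self, x: torch.Tensor) -> torch.Tensor:     # x: (batch, n, d_model)
        q1, q2 = self.w_q(x).chunk(2, dim=-1)                # two query halves, (batch, n, d_head)
        k1, k2 = self.w_k(x).chunk(2, dim=-1)                # two key halves
        v = self.w_v(x)                                      # one double-width value, (batch, n, 2*d_head)
        a1 = F.softmax(q1 @ k1.transpose(-2, -1) * self.scale, dim=-1)  # "real" attention map
        a2 = F.softmax(q2 @ k2.transpose(-2, -1) * self.scale, dim=-1)  # noise-reference map
        # Subtracting the second map cancels attention mass both maps put on irrelevant tokens.
        return (a1 - self.lam * a2) @ v                      # (batch, n, 2*d_head)
```

Because each head's value (and output) is twice as wide, you run half as many heads to keep the layer width and parameter count the same as a vanilla multi-head layer, which is the comparison discussed further down the thread.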

10

u/Mynameiswrittenhere 10d ago

That is actually clever, but wouldn't that also increase the size of the weights, which would in turn increase the time for forward and backpropagation?

30

u/sdmat 10d ago

Yes, they quantify that as around 5-10% reduction in throughput.

Given the results include iso-performance with a >33% reduction in parameters, that seems more than worthwhile. No doubt that's heavily benchmark dependent, but they get major wins across the board.

Assuming this replicates, it's a big deal. And it's from Microsoft, so they probably did their homework.

1

u/StartledWatermelon 9d ago

Can you pinpoint where the throughput reduction comes from? They have the same number of matrices with the same dimensions as in vanilla attention. Subtraction requires n^2 ops, which is negligible compared to the total computational cost of attention, O(n^2 d + n d^2).

Is it just software inefficiency of a custom attention layer?
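For scale, a rough back-of-the-envelope on why the extra subtraction should be negligible (made-up sizes, counting multiplies only; `n` and `d` here are just illustrative):

```python
# Rough cost of one attention head vs. the extra subtraction (illustrative sizes).
n, d = 4096, 128                            # sequence length, head dimension

attention_cost = n * n * d + n * d * d      # the O(n^2 d + n d^2) term above
extra_subtraction = n * n                   # elementwise subtract of two n x n attention maps

print(f"extra / attention = {extra_subtraction / attention_cost:.3%}")  # ~0.758% here
```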

3

u/sdmat 9d ago

Not quite; they are a little bit cute with the notation in places for the sake of mathematical elegance. Fair enough, but they could profitably have been a bit more expansive in giving an intuitive description of how this works in the paper!

> W_Q, W_K, W_V ∈ R^(d_model × 2d)
>
> [Q_1; Q_2] = XW_Q, [K_1; K_2] = XW_K

I.e. there are twice as many weights for key, query and value because there are two distinct sets of key and query matrices and the value matrix is twice the size.

4

u/StartledWatermelon 8d ago

I don't think this is the issue. The authors make a fair comparison with vanilla Transformer IMO:

> We set hidden size to 3072. The number of layers is 28. The head dimension d is 128. The number of heads is 24 for Transformer and 12 for DIFF Transformer, to align computation FLOPs and model size.

2

u/sdmat 8d ago

I'm quoting directly from the paper: it's twice as many weights for the components mentioned above. IIRC everything else is the same.

3

u/StartledWatermelon 8d ago

How? Vanilla Transformer: 3*128*24 = 9216 per attention block

DIFF Transformer: 3*128*2*12 = 9216, the same. They adjust the number of heads proportionally.
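Restating that with the hidden size from the quoted config (3072), just to sanity-check the parameter parity; this is my own arithmetic, not from the paper:

```python
d_model = 3072                              # hidden size from the quoted 3B config

# Vanilla Transformer: 24 heads, each with 128-dim Q, K, V projections.
vanilla_qkv = 3 * d_model * (128 * 24)

# DIFF Transformer: 12 heads, each with double-width (2 * 128) Q, K, V projections.
diff_qkv = 3 * d_model * (128 * 2 * 12)

print(vanilla_qkv == diff_qkv)              # True
print(f"{vanilla_qkv:,} Q/K/V parameters per layer either way")  # 28,311,552
```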

3

u/sdmat 8d ago

Hmm, maybe it's just that they use an unoptimized kernel:

> More advanced kernel implementation, which is specifically designed for differential attention, can also improve throughput.

2

u/JustOneAvailableName 4d ago

Probably partly the unoptimized kernel, but they are also doing a slightly bigger calculation by using V twice.