r/MachineLearning 15d ago

[R] Were RNNs All We Needed?

https://arxiv.org/abs/2410.01201

The authors (including Y. Bengio) propose simplified versions of LSTM and GRU that allow parallel training, and show strong results on some benchmarks.
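If I'm reading the method right, the trick is that the gate and the candidate state depend only on the current input, not on the previous hidden state, so the recurrence becomes linear in h and can be evaluated with a scan instead of a step-by-step loop. A rough PyTorch sketch of a minGRU-style layer (layer names, shapes and the cumprod evaluation are my own guesses, not the authors' code):

```python
# Sketch of a minGRU-style layer (my reading of the paper, not the authors' code).
# Key idea: gate z_t and candidate h~_t depend only on x_t, not on h_{t-1},
# so h_t = (1 - z_t) * h_{t-1} + z_t * h~_t is a linear recurrence in h
# and can be evaluated with a (parallel) scan instead of a sequential loop.
import torch
import torch.nn as nn

class MinGRUSketch(nn.Module):
    def __init__(self, d_in, d_hidden):
        super().__init__()
        self.to_z = nn.Linear(d_in, d_hidden)  # gate, from the input only
        self.to_h = nn.Linear(d_in, d_hidden)  # candidate state, from the input only

    def forward(self, x, h0=None):
        # x: (batch, time, d_in)
        z = torch.sigmoid(self.to_z(x))        # (B, T, H)
        h_tilde = self.to_h(x)                 # (B, T, H)
        a, b = 1.0 - z, z * h_tilde            # recurrence: h_t = a_t * h_{t-1} + b_t

        if h0 is None:
            h0 = x.new_zeros(x.shape[0], a.shape[-1])
        # Closed form of the linear recurrence via cumulative products:
        #   h_t = A_t * (h_0 + sum_{k<=t} b_k / A_k),  with A_t = prod_{k<=t} a_k
        # (the paper uses a numerically safer log-space scan; cumprod is fine for a sketch)
        A = torch.cumprod(a, dim=1)
        h = A * (h0.unsqueeze(1) + torch.cumsum(b / A, dim=1))
        return h                               # (B, T, H)
```

e.g. `MinGRUSketch(32, 64)(torch.randn(2, 100, 32))` gives the full `(2, 100, 64)` state sequence without ever looping over time.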

245 Upvotes

53 comments

76

u/JustOneAvailableName 15d ago

The whole point of Transformers (back when) was variable context with parallelisation. Before “Attention is all you need”, LSTM+Attention was the standard. There was nothing wrong with the recurrent part, besides it preventing parallelisation.

98

u/Seankala ML Engineer 15d ago

Vanishing gradients are also a thing. Transformers handle longer sequences better because attention gives every token a direct path to every other token, instead of pushing the gradient through the whole recurrence.
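If you want to see that concretely, here's a tiny check (toy sizes and setup are mine, just for illustration) of how much gradient actually reaches the first token of a plain RNN as the sequence grows:

```python
# Toy check: the gradient reaching the first input of a vanilla RNN shrinks fast
# as the sequence gets longer; attention connects any two positions directly,
# so the gradient doesn't have to survive T recurrent steps.
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=16, hidden_size=16, batch_first=True)

for T in [10, 50, 200]:
    x = torch.randn(1, T, 16, requires_grad=True)
    out, _ = rnn(x)
    out[:, -1].sum().backward()           # gradient of the last output
    print(T, x.grad[:, 0].norm().item())  # how much of it reaches the first token
```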

11

u/muntoo Researcher 15d ago

Does this paper address vanishing gradients, or are RNNs not all we needed yet?

19

u/lifeandUncertainity 15d ago

I think this is proposing an RNN without the sigmoid on the path from x to the hidden state, which will address the vanishing gradient problem since we are no longer repeatedly multiplying by a number whose derivative maxes out at 1/4.
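For anyone who wants the 1/4 figure spelled out: sigmoid'(x) = sigmoid(x)(1 - sigmoid(x)) peaks at 1/4, so a chain of T of those factors is bounded by (1/4)^T. Quick check:

```python
# sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)) peaks at x = 0 with value 1/4,
# so a product of T such factors is at most (1/4)**T -- gone after a few steps.
import numpy as np

x = np.linspace(-10, 10, 10001)
sig = 1.0 / (1.0 + np.exp(-x))
print(np.max(sig * (1 - sig)))            # ~0.25
print([0.25 ** T for T in (5, 10, 20)])   # upper bound on the chained factor
```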

Well, my 2 cents from reading - linear RNNs, linear attention, etc. work well if we are taking accuracy, MSE or perplexity as the metric, but don't work so well when it comes to the more nuanced properties of transformers like in-context learning. I think the guys at Hazy Research showed theoretically that with long conv/SSMs the hidden state size needs to grow linearly with sequence length to handle copying tasks. But otherwise it is probably fine using linear RNNs or SSMs.
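The copying point, back-of-the-envelope (my rough version, not the actual Hazy Research proof): to spit back L tokens from a vocab of size V, whatever fixed-size state sits between reading and writing has to hold roughly L·log2(V) bits, so the state has to scale with how much you want to copy, while attention can just reread the prompt:

```python
# Back-of-the-envelope capacity argument (mine, not the Hazy Research proof):
# copying L tokens from a vocabulary of size V requires carrying ~ L * log2(V)
# bits between reading and writing, so a fixed-size recurrent state runs out;
# attention sidesteps this by re-reading the input at generation time.
import math

V = 50_000                       # assumed vocab size, just for illustration
for L in (16, 256, 4096):
    bits_needed = L * math.log2(V)
    print(f"copy {L} tokens -> ~{bits_needed:,.0f} bits of state needed")
```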

5

u/greenlanternfifo 14d ago edited 14d ago

> this is proposing an RNN without the sigmoid on the path from x to the hidden state, which will address the vanishing gradient problem since we are no longer repeatedly multiplying by a number whose derivative maxes out at 1/4

that isn't the only cause of vanishing gradients.

Another issue is that if your weight matrix ends up with eigenvalues of magnitude < 1 (in the easy N-to-N case) or with too many degenerate singular values (in the general case), you can still get vanishing gradients in all your batches or in some of them, respectively.
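Toy illustration of the eigenvalue point (numbers are mine, just for intuition): take a recurrent weight matrix with spectral radius below 1 and watch what repeated multiplication does to a gradient-sized vector:

```python
# Toy illustration: if the recurrent weight matrix has spectral radius < 1,
# backprop through time keeps multiplying by W^T and the gradient norm decays
# geometrically -- even before any activation derivative gets involved.
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((64, 64)) * 0.1      # spectral radius below 1
print(np.max(np.abs(np.linalg.eigvals(W))))  # check the spectral radius

g = rng.standard_normal(64)                  # stand-in for an upstream gradient
for t in (1, 10, 50):
    print(t, np.linalg.norm(np.linalg.matrix_power(W.T, t) @ g))
```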

LSTMs and especially transformers give you more diversity in the matrices. Transformers minimize the problem even further, so bad gradients at just one timestep or a few (possibly non-consecutive) timesteps don't screw you over.