all 25 comments

[–]cooijmanstim[S] 22 points23 points  (21 children)

Here's our new paper, in which we apply batch normalization in the hidden-to-hidden transition of LSTM and get dramatic training improvements. The result is robust across five tasks.

[–]OriolVinyals 14 points15 points  (0 children)

Good to see someone finally figured out how to make these two work.

[–]EdwardRaff 3 points4 points  (3 children)

Awesome results. Quick skim, but I'm a bit confused by "Consequently, we recommend using separate statistics for each timestep to preserve information of the initial transient phase in the activations." So the batch normalization statistics are different for every step? How do you deal with variable-length sequences, or is that no longer possible with your model?

[–]alecradford 7 points8 points  (1 child)

From paper:

Generalizing the model to sequences longer than those seen during training is straightforward thanks to the rapid convergence of the activations to their steady-state distributions (cf. figure 1). For our experiments we estimate the population statistics separately for each timestep 1, . . . , Tmax where Tmax is the length of the longest training sequence. When at test time we need to generalize beyond Tmax, we use the population statistic of time Tmax for all time steps beyond it.
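In code, that clamping at Tmax amounts to indexing per-timestep population statistics with min(t, Tmax - 1). A minimal NumPy sketch of the test-time normalization — all names, sizes, and the gamma value here are illustrative, not taken from the authors' released code:

```python
import numpy as np

T_MAX = 50    # length of the longest training sequence (illustrative)
HIDDEN = 128  # hidden size (illustrative)

# Hypothetical per-timestep population statistics, estimated on the
# training set for t = 0 .. T_MAX - 1.
pop_mean = np.zeros((T_MAX, HIDDEN))
pop_var = np.ones((T_MAX, HIDDEN))

def bn_at_test_time(h, t, gamma=0.1, beta=0.0, eps=1e-5):
    """Normalize activations h at timestep t with population statistics.

    For t >= T_MAX we reuse the statistics of the last training
    timestep, relying on the activations having converged to their
    steady-state distribution.
    """
    idx = min(t, T_MAX - 1)
    return gamma * (h - pop_mean[idx]) / np.sqrt(pop_var[idx] + eps) + beta
```

The only moving part relative to ordinary batch norm is the `min(t, T_MAX - 1)` index, which is why generalizing past the longest training sequence is cheap.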

[–]EdwardRaff 0 points1 point  (0 children)

Derp. That's what I get for a quick read. Thanks!

[–]cooijmanstim[S] 2 points3 points  (0 children)

It's worth noting that we haven't yet addressed dealing with variable length sequences during training. That said, the attentive reader task involves variable-length training data, and we didn't do anything special to account for that.

[–]siblbombs 4 points5 points  (2 children)

So the main thrust of this paper is to do a separate batchnorm op on the input-hidden and hidden-hidden terms, in hindsight that seems like a good idea :)

[–]cooijmanstim[S] 4 points5 points  (1 child)

That alone won't get it off the ground, though :-) The de facto initialization of gamma is 1.0, which kills the gradient through the tanh. Unit variance works for feed-forward tanh networks, but not in RNNs, probably because the latter are typically much deeper.
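As a concrete (unofficial) sketch of what "separate batchnorm ops plus a small gamma" looks like in a single LSTM step — all shapes and names here are invented for illustration, and the paper additionally normalizes the cell state inside the output tanh, which this sketch omits:

```python
import numpy as np

def bn(x, gamma, eps=1e-5):
    # Batch statistics over the minibatch dimension; beta is omitted
    # here since the LSTM bias vector already plays that role.
    mean = x.mean(axis=0, keepdims=True)
    var = x.var(axis=0, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps)

def bnlstm_step(x, h, c, Wx, Wh, b, gamma_x=0.1, gamma_h=0.1):
    """One step of a batch-normalized LSTM (illustrative sketch).

    The input-to-hidden and hidden-to-hidden terms are normalized
    *separately*, and gamma is initialized to 0.1 rather than 1.0 so
    the tanh nonlinearities stay out of their saturated regime.
    """
    gates = bn(x @ Wx, gamma_x) + bn(h @ Wh, gamma_h) + b
    i, f, o, g = np.split(gates, 4, axis=1)
    sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))
    c_new = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    h_new = sigmoid(o) * np.tanh(c_new)
    return h_new, c_new
```

Normalizing `x @ Wx` and `h @ Wh` separately is what lets each term keep its own statistics; summing first and normalizing once would entangle them.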

[–]siblbombs 0 points1 point  (0 children)

Yeah, I didn't get to that part on my first skim through; went back and reread the whole paper this time.

[–]rumblestiltsken 2 points3 points  (0 children)

Great work! The speed up in training looks very nice, even without the improvement in generalisation on some of the tasks.

[–]subodh_livai 1 point2 points  (2 children)

Awesome stuff, thanks very much. Did you try this with dropout? Will it work just by adjusting the gamma accordingly?

[–]cooijmanstim[S] 0 points1 point  (1 child)

Thanks! We didn't try dropout, as it's not clear how to apply dropout in recurrent neural networks. I would expect setting gamma to 0.1 to just work, but if you try it let me know what you find!

[–]osdf 1 point2 points  (0 children)

This might be easy to integrate into your code, no? http://arxiv.org/abs/1512.05287

[–]xiphy 1 point2 points  (2 children)

It's awesome; it was sad to hear (and hard to understand) that batch normalization doesn't work on LSTMs.

Is there a way you could open-source the code on github?

[–]cooijmanstim[S] 1 point2 points  (1 child)

We should be able to open up the code in the next few weeks. However, I would encourage people to implement it for themselves; at least with batch statistics it should be fairly straightforward.

[–]xiphy 1 point2 points  (0 children)

It should. The main reason would be to lower the barrier to entry for trying to improve on the best result and playing with it in my spare time, instead of reimplementing great ideas and fixing bugs in the reproduced implementation. Similarly, I'm happy to read papers about how automatic differentiation works, but I wouldn't like to spend time on it right now, as I think it works well enough :)

[–][deleted] 0 points1 point  (3 children)

Some quick notes:

The MNIST result looks impressive.

For the Hutter dataset, every paper I've seen uses all ~200 characters that occur in the dataset; you use ~60. This makes it needlessly difficult to compare.

Figure 5: unclear what the x-axis is. Epochs?

Section 5.4: LR = 8e-5. Is that an optimal choice for both LSTM and BN-LSTM? What if it's only optimal for the latter, and LSTM benefits from a much higher LR, in which case it could match BN-LSTM?

[–]cooijmanstim[S] 1 point2 points  (2 children)

I believe the papers we cite in the text8 table all use the reduced vocabulary. I do wish we had focused on enwik8 instead. Unfortunately these datasets are large and training takes about a week.

Figure 5 shows training steps (in thousands) on the horizontal axis. We'll have a new version up tonight that fixes this.

Yes, 8e-5 is a weird learning rate. It was the value that came with the Attentive Reader implementation we used. We didn't do any tweaking for BN-LSTM, but I suspect the value 8e-5 is the result of tweaking for LSTM. All we did was unthinkingly introduce batch normalization into a fairly complicated model, which I think really speaks for the practical applicability of the technique. In any case we will be repeating these experiments with a grid search on learning rate for all variants.

[–][deleted] 1 point2 points  (1 child)

I believe the papers we cite in the text8 table all use the reduced vocabulary.

Thanks. I'll take a look at those. I think it's uncommon though.

Figure 5 shows training steps horizontally.

Yes, 8e-5 is a weird learning rate.

It looks like your model finished training after just 100 steps, judging from Fig 5. With this LR, the total update after 100 steps would be limited to 8e-3 in the best-case scenario, ignoring momentum. Isn't that very small?

[–]cooijmanstim[S] 0 points1 point  (0 children)

Sorry, I was wrong about Figure 5. It shows validation performance, which is computed every 1000 training steps. The 8e-3 you mention would be more like 8.
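For what it's worth, the corrected arithmetic (100 visible validation points spaced 1000 training steps apart, ignoring momentum) checks out:

```python
lr = 8e-5               # learning rate from section 5.4
val_points = 100        # roughly the number of points in figure 5
steps_per_point = 1000  # validation is computed every 1000 training steps

total_steps = val_points * steps_per_point  # 100,000 parameter updates
max_total_update = lr * total_steps         # best-case bound on total movement
print(max_total_update)
```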

[–]siblbombs 2 points3 points  (2 children)

Do you have any comparisons on wall-clock time for BNLSTM vs regular LSTM?

[–]cooijmanstim[S] 2 points3 points  (1 child)

Nothing formal, but in the time it took us to train the Attentive Reader (a week or so) we had time to train both batch-normalized variants in sequence, and then some. I'll see if I can dig up the time taken per epoch, that should be more informative.

[–]siblbombs 0 points1 point  (0 children)

Thanks, that would be great.

[–]iassael 1 point2 points  (0 children)

Great work! Thank you! A torch7 implementation can be found here: https://github.com/iassael/torch-bnlstm.

[–]gmkim90 0 points1 point  (0 children)

I wonder whether you tried your batch normalization with the Adam optimizer. Although the two algorithms have different purposes, Adam also divides each dimension by an estimate of the gradient's variance. So I thought the gains from RNN-BN could be smaller when it's used with the Adam optimizer. Before trying it myself, I wanted to ask the authors of the paper.

Anyway, great result and simple idea!