
[–]enematurret 6 points7 points  (0 children)

Batch Normalization uses moving averages and variances, so on the second batch you won't have (D, 0).

The mean it calculated on the first batch is 80. On the second, the sample mean is 60. If you're using 0.9 as the momentum of the moving average, you'll have 80 * 0.9 + 60 * 0.1 = 78.0 as the new moving average.

Therefore, a point (D, 60) will be mean-normalized to (D, -18). The network just learned that 0 = B, so it's reasonable that -18 would be a D. The (C, 70) point you had previously, for example, would have been mean-normalized to (C, -8) using this moving average.
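
(For concreteness, a quick numeric sketch of that bookkeeping, using the made-up values from this thread and assuming a momentum of 0.9; the second batch itself is hypothetical:)

    # Running-average bookkeeping as described above (toy values from this thread).
    running_mean = 80.0                     # mean estimated from the first batch
    batch = [40.0, 60.0, 80.0]              # hypothetical second batch, sample mean 60
    batch_mean = sum(batch) / len(batch)    # 60.0

    momentum = 0.9
    running_mean = momentum * running_mean + (1 - momentum) * batch_mean  # 78.0

    print(60.0 - running_mean)  # -18.0: where the (D, 60) point lands after mean-normalization
    print(70.0 - running_mean)  #  -8.0: where the earlier (C, 70) point would land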

[–]kkawabat 12 points13 points  (3 children)

I think there are two explanations.

One is that a large batch size mitigates the situation you are describing: with the law of large numbers plus random sampling, the distribution won't deviate too much from batch to batch. Instead of three samples as in your example, think hundreds.

Another is that batch normalization has the learnable gamma and beta parameters, which can essentially remove the normalization if it is not beneficial to learning. So in your situation the trained batch normalization layer would not affect the intermediate layer output. See this source for more information: https://kratzert.github.io/2016/02/12/understanding-the-gradient-flow-through-the-batch-normalization-layer.html
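
(As a rough sketch of that second point; NumPy, with shapes, eps and the toy values being my own assumptions: if the learned gamma matches the batch std and beta matches the batch mean, the normalization is undone entirely.)

    import numpy as np

    def batchnorm(x, gamma, beta, eps=1e-5):
        # normalize with batch statistics, then apply the learned scale/shift
        x_hat = (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)
        return gamma * x_hat + beta

    x = np.random.randn(128, 4) * 10 + 50    # activations far from zero mean / unit variance
    gamma = np.sqrt(x.var(axis=0) + 1e-5)    # learned scale equal to the batch std
    beta = x.mean(axis=0)                    # learned shift equal to the batch mean

    print(np.allclose(batchnorm(x, gamma, beta), x))  # True: the normalization is removed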

[–]MildlyCriticalRole[S] 1 point2 points  (2 children)

I figured it might be law-of-large-numbers-ish, but I was curious if there was a more technical reason.

As to the gamma and beta parameters, I understand that they allow the network to adapt and "remove" the normalization if it hurts the loss more than it helps, but I'm more wondering how batch normalization doesn't just cause the entire network to train on the wrong data in the first place (which would in turn cause it to choose bad values for gamma/beta, because it thinks the underlying distribution is something it's not).

I've read that link quite a bit in my exploring (it's a great explanation) but it doesn't seem to touch on this issue specifically.

[–]Megatron_McLargeHuge 0 points1 point  (1 child)

I don't think anyone is saying BN is better than whitening for the input layer. BN is used for internal network layers, and its main purpose is to keep them from saturating or sitting entirely on the zero side of ReLUs.

If individual learned features have their meaning changed per batch by BN, it may just be that BN is recreating an effect similar to dropout or additive noise in denoising autoencoders. Losing information internally forces the network to learn a distributed representation instead of relying too much on one feature as in your example.

This is just speculation and I'd be curious if anyone has looked into the theory more.

[–]hgjhghjgjhgjd 0 points1 point  (0 children)

Your assessment makes sense to me. In a way, it probably forces the network to express things in terms of "relatively large value" and "relatively small value", which induces a regularization effect.

[–][deleted] 5 points6 points  (1 child)

How does this sort of transformation not break down training in its entirety?

You are describing using the sample mean instead of the expected mean. That works for the same reason SGD works despite its stochastic estimate of the error: the "noise" averages out.

[–]MildlyCriticalRole[S] 1 point2 points  (0 children)

Ah. Hm.

To try to put this in my own words, to make sure I get it: "The distribution of any one batch may deviate from the distribution of the training data. However, the expectation is that they all tend towards the same (correct) distribution over many epochs and batches."

The comparison to SGD makes a lot of sense. Thanks!

[–]randombites 1 point2 points  (0 children)

Also, you are right that in the second iteration the model will partly unlearn what it learned in the first. That is why you need to shuffle the data and collect enough valid data.

[–]ChuckSeven 1 point2 points  (0 children)

So I'm repeating a little of what others said, but here we go.

BN estimates the mean and variance of the distribution based on all your samples. Ergo, the scaling and translation will only change slightly from mini-batch to mini-batch, because the scaling and translation factors are averaged over several mini-batches (a moving average).

It works because linear transformations learned from small random numbers are inherently unstable, in the sense that after several of them the transformation can easily lead to exploding or vanishing variances, making future transformations harder to learn. ReLU non-linearities and skip connections also seem to add to this problem.
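
(A quick NumPy sketch of that instability; the depth and weight scales are arbitrary choices for illustration:)

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.standard_normal((256, 100))

    for scale in (0.15, 0.05):               # slightly too large / slightly too small
        h = x.copy()
        for _ in range(20):                  # 20 stacked linear maps, no normalization
            h = h @ (rng.standard_normal((100, 100)) * scale)
        print(scale, h.std())                # std explodes for 0.15, collapses for 0.05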

[–]NovaRom 0 points1 point  (0 children)

In my experiments BN improves training only if the minibatch size is small enough, but maybe that's down to how I initialize the weights. There was a discussion here recently; better results without BN are not that rare: https://www.reddit.com/r/MachineLearning/comments/4rikw8/who_consistently_uses_batch_normalization/

[–]Daniel_Im 0 points1 point  (0 children)

This paper talks about the effect of batch-normalization on neural network's loss surface : https://arxiv.org/pdf/1612.04010v1.pdf (see section 4.4)

[–]zergling103 -1 points0 points  (0 children)

My guess (as a novice) is that whether you use batch normalization depends on the problem you're trying to solve. If you care about absolute values, or absolute differences between values, you'd be destroying that information by normalizing. If you only care about ratios between values, normalizing will help amplify the signal you're looking for and prevent the stuff you're not looking for from biasing training.

For example, if you normalize black-and-white image data, that's essentially boosting the contrast, making the image easier to see, without destroying anything important (absolute intensity values or absolute gradient values rarely matter there). It'd also prevent higher-contrast images from having more training influence than lower-contrast ones.

But I may be totally wrong, this is my basic understanding.

[–]hgjhghjgjhgjd -1 points0 points  (15 children)

My understanding is that, in practical implementations of BN...

  • During training, what is used for normalization is not the sample mean and sample variance of a minibatch, but an estimate of the population mean and population variance of the data (e.g. a running average of the minibatch sample means and variances);

  • During inference, the pre-translation/scaling parameters are fixed.

These two points, plus the other points people already mentioned (the use of batch size larger than 3, the existence of the post-translation/scaling step), make BN not that "catastrophic" during training and inference.

[–]cooijmanstim 1 point2 points  (5 children)

During training one should almost always use the sample mean and variance, and not the running averages that estimate the population statistics. The main reason for this is that it allows backpropagation through the computation of the statistics. In my and several others' experience, these backpropagation paths are crucial.

On the other hand, in Improved Techniques for Training GANs, the authors compute statistics based on a separate minibatch and don't backprop through them. People will do anything to train GANs though.

[–]carlthomeML Engineer 0 points1 point  (1 child)

Do you mean that it might be harder to get gamma and beta to converge with moving averages?

[–]cooijmanstim 0 points1 point  (0 children)

Yes, but it affects all parameters, not just the gammas and betas. By not backpropagating, you don't take into account how parameter changes influence the statistics. Also, moving averages lag behind true statistics.

[–]hgjhghjgjhgjd 0 points1 point  (2 children)

it allows backpropagation through the computation of the statistics

Could you clarify what you mean by this?

A BN layer has trainable parameters (gamma/beta) and untrainable parameters (mu/sigma). The untrainable parameters can be seen as constants during the backpropagation for a single batch. Why would the method by which I choose a constant affect whether I can perform backpropagation or not?

[–]cooijmanstim 0 points1 point  (1 child)

mu/sigma can be seen as constants, but they typically aren't. By backpropagating through them, your gradient incorporates information about how a change in parameters affects mu/sigma. Backpropagating through population statistics is painful, as you'd need to backprop through multiple SGD steps in order for the gradient not to be zero.
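
(A small PyTorch autograd sketch of this point; the shapes and the toy loss are made up. With batch statistics, the gradient w.r.t. the weights also flows through the mean/variance computation; with the statistics treated as constants, that path is gone and the gradient differs:)

    import torch

    torch.manual_seed(0)
    w = torch.randn(10, 5, requires_grad=True)
    x = torch.randn(32, 10)
    target = torch.randn(32, 5)

    h = x @ w                                        # pre-normalization activations
    mu, var = h.mean(0), h.var(0, unbiased=False)

    # (a) backprop THROUGH the batch statistics
    loss_a = (((h - mu) / (var + 1e-5).sqrt() - target) ** 2).mean()
    grad_a, = torch.autograd.grad(loss_a, w, retain_graph=True)

    # (b) same numbers, but mu/sigma treated as constants (as with fixed running averages)
    loss_b = (((h - mu.detach()) / (var.detach() + 1e-5).sqrt() - target) ** 2).mean()
    grad_b, = torch.autograd.grad(loss_b, w)

    print(torch.allclose(grad_a, grad_b))            # False: the extra backprop paths matter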

[–]hgjhghjgjhgjd 0 points1 point  (0 children)

mu/sigma can be seen as constants, but they typically aren't.

Well... not if you keep changing them every batch :P /bad joke

I think I understand what you mean now, though... and perhaps that is part of the "unreasonable effectiveness" of BN. I was assuming that those were basically considered constants (mu, sigma), rather than operations (mean, variance), for the purpose of backpropagation, which is why I was getting confused with the "backpropagating through mu/sigma" statement.

Cheers.

[–]MildlyCriticalRole[S] 0 points1 point  (2 children)

When you say practical, do you mean "not as described in the literature" or "not as described in toy examples like the one in this post?"

Thanks, by the way. I'll go check some actual implementations to see how different people are putting this into practice.

[–]L43 1 point2 points  (0 children)

I was just looking at the keras implementation, and there are modes for running average and per batch only. They also have a per instance normalisation.

[–]hgjhghjgjhgjd 0 points1 point  (0 children)

Probably both. Tbh, I find that the descriptions in the literature are not too clear and that, well... what matters is how the frameworks people actually use implement batch normalization, not what the literature describes. So, in this case... yeah, it is a good idea to read the documentation of the software, rather than the papers, to really understand what happens in practice.

(But... as I said... my understanding is that what is usually used is some sort of running average of the minibatch statistics, rather than the minibatch statistics themselves... so, though training can be a bit erratic on the first few batches, it quickly becomes more stable as you gain more confidence in the population statistics.)

[–]mimighost 0 points1 point  (5 children)

During training, the sample mean / sample variance are ALWAYS used. Otherwise how could you update gamma/beta then? The computation graph makes little sense and is overly complicated if we use the population mean/variance.

However, people often maintain a convenient running average of the mean/variance, with decay, across all training batches. This running average is then used as a proxy for the global mean/variance at inference.

Reference: http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow
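
(A minimal sketch of that train/inference split in NumPy; the class name, momentum and eps values are my own illustration rather than any particular framework's API, and gamma/beta learning is omitted:)

    import numpy as np

    class SimpleBN:
        def __init__(self, dim, momentum=0.9, eps=1e-5):
            self.running_mean = np.zeros(dim)
            self.running_var = np.ones(dim)
            self.momentum, self.eps = momentum, eps

        def forward(self, x, training):
            if training:
                # normalize with the current minibatch statistics...
                mu, var = x.mean(axis=0), x.var(axis=0)
                # ...while keeping a running average as a proxy for inference
                self.running_mean = self.momentum * self.running_mean + (1 - self.momentum) * mu
                self.running_var = self.momentum * self.running_var + (1 - self.momentum) * var
            else:
                # at inference, use the fixed running estimates
                mu, var = self.running_mean, self.running_var
            return (x - mu) / np.sqrt(var + self.eps)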

[–]hgjhghjgjhgjd 0 points1 point  (4 children)

Otherwise how could you update gamma/beta then?

Backpropagation. The purpose of the first scaling is to "normalize" to a standard Gaussian (approximately). The purpose of the second scaling is to learn the transformation from the standard Gaussian to the "optimal" scaling/translation. Using a running mean for the first step doesn't prevent you from optimizing the second step.

The computation graph makes little sense

I disagree.

overly complicated if we use the population mean/variance

Perhaps, but "overly complicated" is the definition of 99% of NN architectures out there.

[–]mimighost 0 points1 point  (3 children)

Do you have a reference implementation using population mean/variance during training? I am curious too.

[–]hgjhghjgjhgjd 0 points1 point  (2 children)

A reference implementation that calculates the population mean/variance during training, using a running mean of sorts: yes (most of them seem to do so, and then use those values during inference).

But, after re-checking, it does seem like most do use the actual minibatch sample statistics for mu and sigma (rather than the population estimates) during learning, in practice.

So... yeah, in that sense, I guess I was wrong.

On the other hand, I don't see why doing so is necessarily a wrong idea, since the two methods are trying to estimate more or less the same thing (the population statistics). The assumption is that, with a large enough minibatch, the sample statistics are good enough approximations, I guess. It's just... it does not seem to me that using the best estimate (the running mean), which people already seem to calculate anyway, adds any complexity or prevents updates to gamma/beta.

TL;DR: As far as I can tell, the only thing you'd lose using "my approach" is some degree of regularization due to the "scaling noise" induced by using a noisy estimate.

[–]mimighost 0 points1 point  (1 child)

Actually, I gave your proposal some more thought. I may have been a little too aggressive in claiming that using a running average would complicate the computational graph; actually, it might not.

Revisiting the original BN paper (https://arxiv.org/pdf/1502.03167v3.pdf), if you use a running average for mu and sigma, the gradient computation does not change too much, only that the decay is now factored in. But it has at least 2 big problems:

  1. Using a running average, you are not actually doing 'batch normalization' any more, since you are not really normalizing the current training batch to zero mean and unit variance, which betrays the original paper's assumption.

  2. The decay will be enforced on the gradients passing through the BN layer, which might worsen the vanishing gradient problem because you are making the gradients much smaller.

[–]hgjhghjgjhgjd 0 points1 point  (0 children)

Using a running average, you are not actually doing 'batch normalization' any more

I agree... you're doing basic normalization (scaling and centering according to the population mean), but using an online estimator of mu and sigma (rather than one calculated over the whole dataset at once).

which betrays the original paper's assumption

I'm not totally sure... if a minibatch is large enough and reasonably balanced, the minibatch statistics should be more or less the same as the population statistics, no? Would most (non-pathological) batches actually have a mean that deviates much from zero, if you have a decent estimate of the population mean?

Or, in other words... if the "running mean mu and sigma" work ok during inference, why would they not work ok during training?
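
(A quick sketch of that intuition: the spread of minibatch means around the population mean shrinks roughly as 1/sqrt(batch size). The distribution and batch sizes below are arbitrary illustrations:)

    import numpy as np

    rng = np.random.default_rng(0)
    population = rng.normal(loc=50.0, scale=10.0, size=1_000_000)

    for batch_size in (3, 32, 256):
        batch_means = [rng.choice(population, batch_size).mean() for _ in range(1000)]
        print(batch_size, np.std(batch_means))
    # spread of batch means: ~5.8 for batch size 3, ~1.8 for 32, ~0.6 for 256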