Rethinking Batch Normalization

post by Matthew Barnett (matthew-barnett) · 2019-08-02T20:21:16.124Z · LW · GW · 5 comments

Yesterday we saw a glimpse into the inner workings of batch normalization, a popular technique in the field of deep learning. Given that the effectiveness of batch normalization has been demonstrated beyond any reasonable doubt, it may come as a surprise that researchers don't really know how it works. At the very least, we sure didn't know how it worked when the idea was first proposed.

One might at first consider that last statement unlikely. In the last post I outlined a relatively simple theoretical framework for explaining the success of batch normalization. The idea is that batch normalization reduces the internal covariate shift (ICS) of layers in a network. In turn, the network becomes more stable, more robust to large learning rates, and much quicker to train.

And this was the standard story in the field for years, until a few researchers decided to actually investigate it.

Here, I hope to convince you that the theory really is wrong. While I'm fully prepared to make additional epistemic shifts on this question in the future, I also fully expect to never shift my opinion back.

When I first read the original batch normalization paper, I felt like I really understood the hypothesis. It was simple enough, reasonably descriptive, and intuitive. But I never got a perfect picture of what was going on; I sort of hand-waved the step where ICS contributes to an unstable gradient step. Instead I, like the paper, argued by analogy: since controlling for covariate shift has been known for decades to help training, a technique that reduces internal covariate shift is a natural extension of the same idea.

It turned out this theory wasn't even a little bit right [? · GW]. It's not that covariate shifts aren't important at all, but that the entire idea is based on a false premise.

Or at least, that's the impression I got while reading Shibani Santurkar et al.'s How Does Batch Normalization Help Optimization? Whereas the original batch normalization paper gave me a sense of "I kinda sorta see how this works," this paper completely shattered my intuitions. It wasn't just the weight of the empirical evidence, or the theoretical underpinning they present; instead what won me over was the surgical precision of their rebuttal. They saw how to formalize the theory of improvement via ICS reduction and tested it on BatchNorm directly. The theory turned out to be simple, intuitive, and false.

In fairness, it wasn't laziness that prevented researchers from reaching our current level of understanding. In the original batch normalization paper, the authors did propose a test for measuring batch normalization's effect on ICS.

The problem was instead twofold: their method for measuring ICS was inadequate, and they never checked that their proposed mechanism for how ICS reduction was supposed to work actually held under their testing conditions. More importantly, they didn't even test the theory that ICS reduction contributed to the performance gains. Instead their argument rested on a simple heuristic: we know that covariate shifts are bad, we think that batch normalization reduces ICS, and we know that batch normalization improves performance; therefore batch normalization must work by reducing ICS. As far as I can tell, most of the articles that came after the original paper took this heuristic at face value, citing the paper and calling it a day.

And it's not a bad heuristic, all in all. But perhaps it's a tiny bit telling that on yesterday's post, LessWrong user crabman [LW · GW] was able to anticipate the true reason for batch normalization's success, defying both my post and the years it supposedly took researchers to figure this out. Quoth crabman,

I am imagining this internal covariate shift thing like this: the neural network together with its loss is a function which takes parameters θ as input and outputs a real number. Large internal covariate shift means that if we choose ε>0, perform some SGD steps, get some θ, and look at the function's graph in ε-area of θ, it doesn't really look like a plane, it's more curvy like.

In fact, the above paragraph doesn't actually describe internal covariate shift, but rather the smoothness of the loss function around some parameters θ. I concede it is possible that this is really what the original researchers meant when they coined the term internal covariate shift, and therefore also possible that this whole critique of the original theory is based on nothing but a misunderstanding.

But I'm not buying it.

Take a look at how the original paper defines ICS,

We define Internal Covariate Shift as the change in the distribution of network activations due to the change in network parameters during training.

This definition can't merely refer to the smoothness of the loss around θ. For example, the loss surface could be extremely bumpy, full of sharp edges, and yet ICS could still be absent. Can you think of a neural network like this? Here's one: a network with just one layer, whose loss surface is some extremely contorted shape because its activation function is some crazy non-linear function. The loss wouldn't be smooth, but the layer's input distribution would be constant over time, because there are no earlier layers whose updates could shift it.
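Here is a minimal sketch of such a network (assuming PyTorch; the sin-based activation is just an arbitrary way to make the loss surface contorted, not anything from the papers discussed here):

```python
# A one-layer network with a wildly non-linear activation. Its loss surface
# over the weights is extremely bumpy, yet there is no ICS: the layer's input
# distribution is just the fixed data distribution, so it never shifts.
import torch

torch.manual_seed(0)
x = torch.randn(256, 10)          # fixed input distribution
y = torch.randn(256, 1)
w = torch.randn(10, 1, requires_grad=True)

def loss_fn(w):
    # sin(20 * x @ w) makes the loss highly non-smooth in w,
    # even though the inputs x never change during training.
    pred = torch.sin(20.0 * (x @ w))
    return ((pred - y) ** 2).mean()

loss = loss_fn(w)
loss.backward()
print(loss.item(), w.grad.norm().item())
```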

I can instead think of two interpretations of the above definition for ICS. The first interpretation is that ICS simply refers to the change of activations in a layer during training. The second interpretation is that this definition specifically refers to change of activations caused by changes in network parameters at previous layers.

This is a subtle difference, but I believe it's important to understand. The first interpretation allows for easy measurement: we can simply plot the mean and variance of a layer's input distribution during training. This is in fact how the original paper (section 4.1) tests batch normalization's effect on ICS. But the second interpretation sounds much closer to the hypothesized mechanism for how ICS was supposed to matter in the first place.
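As a sketch of that first kind of measurement (assuming PyTorch; the tiny model and the hook are my own illustrative choices, not the original paper's code), one can record the mean and variance of a layer's inputs at every training step:

```python
# Record the mean and variance of the inputs flowing into one layer at each
# training step, in the spirit of section 4.1 of the original paper.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 50), nn.ReLU(), nn.Linear(50, 1))
stats = []

def record_inputs(module, inputs, output):
    h = inputs[0].detach()
    stats.append((h.mean().item(), h.var().item()))

# Watch the inputs of the second linear layer.
model[2].register_forward_hook(record_inputs)

opt = torch.optim.SGD(model.parameters(), lr=0.1)
for step in range(100):
    x, y = torch.randn(64, 10), torch.randn(64, 1)
    loss = ((model(x) - y) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
# Plotting `stats` over time gives the kind of curve shown in the original paper.
```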

On the level of experimentation, the crucial part of the above definition is the phrase "change [...] due to the change in network parameters." Merely measuring how the activation distributions drift over time is insufficient. Why? Because the hypothesis was that if activation distributions change too quickly, a layer will have its gradient pushed into a vanishing or exploding region. Under the first interpretation, the distributions could still change slowly enough for each layer to adapt appropriately. We therefore need additional information to discover whether ICS is occurring in the way the mechanism describes.

To measure ICS under the second interpretation, we have to measure a counterfactual quantity: the amount that a layer's activations change as a result of the parameters beneath it being altered. And we also need a way of seeing whether the gradient is being pushed into extreme regions as a result of those changes. Only then can we see whether this particular phenomenon is actually occurring.

The newer paper comes down heavily in favor of this interpretation, and adds a level of formalization on top of it. Their definition focuses on the difference between two gradients for the same layer: one computed with every parameter as it was at time $t$, and one computed after the preceding layers have been updated by backpropagation while the layer's own parameters (and those above it) are held fixed. Specifically, let $\mathcal{L}$ be the loss function of a $k$-layer network with parameters $W_1, \ldots, W_k$. Their definition of the ICS of activation $i$ at time $t$ is $\|G_{t,i} - G'_{t,i}\|_2$, where

$$G_{t,i} = \nabla_{W_i^{(t)}} \mathcal{L}\big(W_1^{(t)}, \ldots, W_k^{(t)};\, x^{(t)}, y^{(t)}\big),$$

$$G'_{t,i} = \nabla_{W_i^{(t)}} \mathcal{L}\big(W_1^{(t+1)}, \ldots, W_{i-1}^{(t+1)}, W_i^{(t)}, W_{i+1}^{(t)}, \ldots, W_k^{(t)};\, x^{(t)}, y^{(t)}\big),$$

and $(x^{(t)}, y^{(t)})$ is the batch of input-label pairs used to train the network at time $t$.
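To make the definition concrete, here is a rough sketch (assuming PyTorch; the two-layer model, the layer_grad helper, and the parameter indexing are my own illustrative choices, not the paper's code) of measuring ICS for the last layer of a tiny network:

```python
# Compare the last layer's gradient before and after the *preceding* layer has
# taken its gradient step, on the same batch: ||G - G'||_2 in the definition.
import copy
import torch
import torch.nn as nn

def layer_grad(model, param_idx, x, y):
    """Gradient of the loss w.r.t. one parameter tensor on batch (x, y)."""
    loss = ((model(x) - y) ** 2).mean()
    params = list(model.parameters())
    return torch.autograd.grad(loss, params[param_idx])[0]

model = nn.Sequential(nn.Linear(10, 50), nn.ReLU(), nn.Linear(50, 1))
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(64, 10), torch.randn(64, 1)

# G_{t,i}: gradient of the last layer's weight with all parameters at time t.
param_idx = 2     # index of the second Linear layer's weight in parameters()
g_before = layer_grad(model, param_idx, x, y)

# Take one SGD step on the same batch, then restore the last layer's own
# weights, so that only the earlier layer has moved.
snapshot = copy.deepcopy(model[2].state_dict())
loss = ((model(x) - y) ** 2).mean()
opt.zero_grad()
loss.backward()
opt.step()
model[2].load_state_dict(snapshot)

# G'_{t,i}: same layer, same weights, but upstream parameters updated.
g_after = layer_grad(model, param_idx, x, y)
ics = (g_before - g_after).norm().item()   # ||G - G'||_2
print(ics)
```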

The first thing to note about this definition is that it allows a clear, precise measurement of ICS, which is based solely on the change of the gradient due to shifting parameters beneath a layer during backpropagation.

What Shibani Santurkar et al. found when they applied this definition was a bit shocking. Not only did batch normalization fail to decrease ICS; in some cases it even increased it compared to otherwise-identical networks without it. And to top that off, they found that even in networks where they artificially increased ICS, performance barely suffered.

In one experiment they applied batch normalization to each hidden layer of a neural network and, at each step, added noise after the batch normalization transform in order to induce ICS. This wasn't just fixed Gaussian noise either: the noise was drawn from a different Gaussian at every time step and every layer, with that Gaussian's parameters (specifically its mean and variance) themselves varying according to yet another distribution. What they discovered was that even though this dramatically increased measured ICS, the time it took to train the networks to the baseline accuracy was almost identical to regular batch normalization.
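A minimal sketch of what such a noise injection might look like (assuming PyTorch; NoisyBatchNorm and the specific noise scales are my own illustrative choices, meant to capture the spirit of the experiment rather than the authors' exact setup):

```python
# After each BatchNorm layer, add noise whose mean and scale are themselves
# redrawn at every forward pass, deliberately re-introducing distribution shift.
import torch
import torch.nn as nn

class NoisyBatchNorm(nn.Module):
    def __init__(self, num_features):
        super().__init__()
        self.bn = nn.BatchNorm1d(num_features)

    def forward(self, x):
        h = self.bn(x)
        if self.training:
            # Redraw the noise distribution itself on every step:
            # a random mean and a random positive scale.
            mu = torch.randn(1) * 0.5
            sigma = torch.rand(1) * 2.0 + 0.1
            h = h + (torch.randn_like(h) * sigma + mu)
        return h

model = nn.Sequential(
    nn.Linear(10, 50), NoisyBatchNorm(50), nn.ReLU(),
    nn.Linear(50, 1),
)
```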

And remember that batch normalization actually does work. In the experiments that measured raw performance gains, batch normalization passed with flying colors. So clearly, since batch normalization works, it must work for some reason other than simply reducing ICS. But that leaves one question remaining: how on Earth does it work?

I have already hinted at the reason above. The answer lies in something even simpler to understand than ICS. Take a look at this plot.

[Figure: a bumpy, non-smooth loss landscape, with a red ball resting on a slope and a flat region to its right.]
Imagine the red ball rolling down this slope, applying gradient descent at each step, and suppose for a second that it isn't using any momentum. At each step it simply looks at which direction to move and moves in that direction in proportion to the slope at that point.

A problem immediately arises. Depending on how we choose the learning rate, the red ball could get stuck almost immediately. If the learning rate is too small, it will probably get stuck on the flat region to its right. If the learning rate is too high, it might leap over to another valley entirely, landing in an exploding region.
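Here is a toy sketch of that picture (assuming PyTorch; the bumpy_loss landscape is made up for illustration and is not the one in the plot): plain gradient descent without momentum, run with a small and a large learning rate.

```python
# Plain gradient descent on a bumpy one-dimensional "loss": a small learning
# rate tends to get trapped in a nearby dip, a large one jumps around wildly.
import torch

def bumpy_loss(theta):
    # Illustrative landscape: a broad bowl with ripples and a flat shelf.
    return 0.05 * theta**2 + torch.sin(3.0 * theta) + torch.relu(2.0 - theta)

def descend(lr, steps=200, start=4.0):
    theta = torch.tensor(start, requires_grad=True)
    for _ in range(steps):
        loss = bumpy_loss(theta)
        loss.backward()
        with torch.no_grad():
            theta -= lr * theta.grad
        theta.grad.zero_()
    return theta.item(), bumpy_loss(theta).item()

print(descend(lr=0.01))   # tends to settle in a local dip near where it started
print(descend(lr=1.0))    # tends to bounce erratically between valleys
```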

The way that batch normalization helps is by changing the loss landscape from this bumpy shape into one more like this.

[Figure: the same landscape smoothed out, with the bumps and the flat region rounded away.]
Now it no longer matters that much what we set the learning rate to. The ball will be able to find its way down even if it's too small. What used to be a flat plane has now been rounded out such that the ball will roll right down.

The specific way that the paper tests this hypothesis is by applying fairly standard ideas from the real-analysis toolkit. In particular, the researchers measured the Lipschitzness of the loss function around the parameters for various types of deep networks, both empirically and theoretically. Formally, a function $f$ is $L$-Lipschitz if $|f(x_1) - f(x_2)| \leq L \,\|x_1 - x_2\|$ for all $x_1$ and $x_2$. Intuitively, this is a measure of how smooth the function is: the smaller the constant $L$, the fewer and less extreme the function's jumps over small intervals in any direction.
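As a concrete illustration, here is a rough sketch (assuming PyTorch; the probing scheme and the tiny model are my own, not the paper's exact protocol) of empirically probing the local Lipschitz behaviour of the loss by perturbing the parameters in random directions of fixed norm:

```python
# Perturb the parameters in random directions of norm `radius` and record how
# much the loss can change; the worst ratio is a crude lower bound on the
# local Lipschitz constant of the loss around the current parameters.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 50), nn.ReLU(), nn.Linear(50, 1))
x, y = torch.randn(64, 10), torch.randn(64, 1)

def loss_fn():
    return ((model(x) - y) ** 2).mean().item()

def local_lipschitz_estimate(radius=0.01, num_probes=50):
    base = loss_fn()
    originals = [p.detach().clone() for p in model.parameters()]
    worst = 0.0
    for _ in range(num_probes):
        with torch.no_grad():
            # Random direction in parameter space, rescaled to norm `radius`.
            deltas = [torch.randn_like(p) for p in model.parameters()]
            norm = torch.sqrt(sum((d ** 2).sum() for d in deltas))
            for p, d in zip(model.parameters(), deltas):
                p += radius * d / norm
            worst = max(worst, abs(loss_fn() - base) / radius)
            # Restore the original parameters before the next probe.
            for p, o in zip(model.parameters(), originals):
                p.copy_(o)
    return worst

print(local_lipschitz_estimate())
```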

This way of thinking about the smoothness of the loss function has the advantage of a rather natural interpretation. The magnitude of a gradient estimate is a prediction of how much we expect the loss to fall if we move in that direction. We can then evaluate how good those predictions are across different neural network schemes and across training steps. When gradient predictiveness was tested, there were no surprises: the networks with batch normalization had the most predictive gradients.
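A sketch of how one might measure gradient predictiveness in this sense (assuming PyTorch; the model and step size are illustrative, and this only captures the spirit of the paper's measurement, not its exact protocol):

```python
# Compare the loss drop predicted by the local gradient with the drop actually
# observed after stepping in that direction: the closer the two, the more
# predictive the gradient.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 50), nn.ReLU(), nn.Linear(50, 1))
x, y = torch.randn(64, 10), torch.randn(64, 1)

def loss_fn():
    return ((model(x) - y) ** 2).mean()

loss = loss_fn()
grads = torch.autograd.grad(loss, list(model.parameters()))

step_size = 0.1
# First-order prediction: stepping theta -> theta - eta * grad should drop the
# loss by roughly eta * ||grad||^2.
predicted_drop = step_size * sum((g ** 2).sum() for g in grads)

with torch.no_grad():
    for p, g in zip(model.parameters(), grads):
        p -= step_size * g
    actual_drop = loss - loss_fn()

print(predicted_drop.item(), actual_drop.item())
```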

Perhaps even more damning: not only did the loss function become smoother, the gradient landscape itself became smoother, a property known as β-smoothness (the gradient itself is Lipschitz, i.e. $\|\nabla \mathcal{L}(\theta_1) - \nabla \mathcal{L}(\theta_2)\| \leq \beta \,\|\theta_1 - \theta_2\|$). As a result, not only were the gradients more predictive of the loss, the gradients themselves were easier to predict in a certain sense: they stayed fairly consistent throughout training.

Perhaps the way that batch normalization works is simply by smoothing out the loss function. At each layer we apply a normalizing transformation that helps remove extreme points in the loss landscape. This hypothesis makes an additional prediction: other normalization schemes should work just as well, which is exactly what the researchers found. Before, the fact that the transform adds the learnable parameters γ and β was confusing, since it wasn't clear how they contributed to ICS reduction. Now we can see that ICS reduction shouldn't even be the goal, which perhaps sheds light on why the technique works anyway.
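For reference, here is a minimal sketch of the transform itself, in training mode only, with running statistics and other bookkeeping omitted (assuming PyTorch; batch_norm_forward is my own illustrative helper, not the library implementation):

```python
# Per-feature batch normalization: normalize each feature over the batch, then
# rescale and shift with the learnable parameters gamma and beta.
import torch

def batch_norm_forward(x, gamma, beta, eps=1e-5):
    # x has shape (batch, features).
    mean = x.mean(dim=0)
    var = x.var(dim=0, unbiased=False)
    x_hat = (x - mean) / torch.sqrt(var + eps)
    return gamma * x_hat + beta

x = torch.randn(64, 50)
gamma = torch.ones(50, requires_grad=True)
beta = torch.zeros(50, requires_grad=True)
out = batch_norm_forward(x, gamma, beta)
```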

In fact, there was pretty much nothing special about the exact way that batch normalization transforms the input, other than the properties that contribute to smoothness. And given how many methods have since come out that build on batch normalization while using quite different operations, isn't this exactly what we would expect?


Is this the way batch normalization really works? I'm no expert, but I found this interpretation much easier to understand, and also a much simpler hypothesis. Maybe we should apply Occam's razor here. I certainly did.

In light of this discussion, it's also worth reflecting once again that the argument "we are going to be building the AI, so of course we'll understand how it works" is not a very good one. Clearly the field can stumble on solutions that work while the reason they work remains almost completely unknown for years, even when the answer is hiding in plain sight. I honestly can't say for certain whether this happens a lot, or too often; I only have my one example here.


In the next post, I'll be taking a step back from neural network techniques to analyze generalization in machine learning models. I will briefly cover the basics of statistical learning theory and will then move to a framing of learning theory in light of recent deep learning progress. This will give us a new test bed to see if old theories can adequately adapt to new techniques. What I find might surprise you.

5 comments


comment by gjm · 2019-08-03T13:33:53.992Z · LW(p) · GW(p)

The Lipschitz constant of a function gives an indication of how horizontal it is rather than how locally linear it is. Naively I'd expect that the second of those things matters more than the first. Has anyone looked at what batch normalization does to that?

More specifically: Define the 2-Lipschitz constant of a function $f$ at $x$ to be something like $\inf_{a}\,\sup_{y \neq x} \frac{|f(y) - f(x) - a \cdot (y - x)|}{\|y - x\|^2}$, and its overall 2-Lipschitz constant to be the sup of these. This measures how well $f$ is locally approximable by linear functions. (I expect someone's already defined a better version of this, probably with a different name, but I think this'll do.) Does batch normalization tend to reduce the 2-Lipschitz constant of the loss function?

[EDITED to add:] I think having a 2-Lipschitz constant in this sense may be equivalent to having a derivative which is a Lipschitz function (and the constant may be its Lipschitz constant, or something like that). So maybe a simpler question is: For networks with activation functions making the loss function differentiable, does batchnorm tend to reduce the Lipschitz constant of its derivative? But given how well rectified linear units work, and that they have a non-differentiable activation function (which will surely make the loss functions fail to be 2-Lipschitz in the sense above) I'm now thinking that if anything like this works it will need to be more sophisticated...

Replies from: matthew-barnett
comment by Matthew Barnett (matthew-barnett) · 2019-08-06T01:44:53.814Z · LW(p) · GW(p)
The Lipschitz constant of a function gives an indication of how horizontal it is rather than how locally linear it is. Naively I'd expect that the second of those things matters more than the first. Has anyone looked at what batch normalization does to that?

Yeah, in fact I should have been clearer in the post. A very simple way of reducing the Lipschitz constant of a function is simply to scale it by some constant factor. Santurkar et al. attempt to show theoretically that batchnorm is doing more than simply rescaling; see theorem 4.2 in their paper, and the subsequent observation in section 4.3.

If you think about it, though, we can already kind of guess that batch normalization isn't simply scaling the function. That's because the measured gradient ended up being much closer to the empirically observed change in loss than when batch normalization was not enabled. This gives us evidence that the function is locally linear in the way you described. (Of course, this can be criticized if you disagree with how they measured gradient predictiveness, which focused on the variability of the gradient's prediction minus the actual difference in loss; see figure 4 in the paper.)

Does batch normalization tend to reduce the 2-Lipschitz constant of the loss function?

That's a good question. My guess would be yes, due to what I said above, but I am not in a position to say confidently either way. I would have to think more about the exact way you have defined it. :)

comment by Pattern · 2019-08-03T12:41:56.328Z · LW(p) · GW(p)
And to top that off, they found that even in networks where they artificially increased ICS, performance barely suffered.

All networks, or just ones with batch normalization?

Replies from: matthew-barnett
comment by Matthew Barnett (matthew-barnett) · 2019-08-06T01:46:05.746Z · LW(p) · GW(p)

That's a good point of clarification which perhaps weakens the point I was making there. From the paper,

adding the same amount of noise to the activations of the standard (non-BatchNorm) network prevents it from training entirely
comment by philip_b (crabman) · 2019-08-07T14:08:08.256Z · LW(p) · GW(p)

I want to clarify in what domain this theory of batch normalization holds.

The evidence we have is mostly about batch normalization in those types of feedforward neural networks that are often used in 2019, right? So, residual CNNs, VGG-like CNNs, other CNNs, transformers. Maybe other types of feedforward neural networks. But not RNNs.

Has anyone explored the applicability of batch normalization or similar techniques to non neural network functions which we optimize by gradient descent like algorithms? Perhaps to tensor networks?