Why I stopped being into basin broadness
post by tailcalled · 2024-04-25T20:47:17.288Z · LW · GW · 3 comments
There was a period where everyone was really into basin broadness for measuring neural network generalization. This mostly stopped being fashionable, but I'm not sure there's enough written up on why it didn't do much, so I thought I should give my take on why I stopped finding it attractive. This is probably a repetition of what others have found, but I thought I might as well repeat it.
Let's say we have a neural network $f_\theta$. We evaluate it on a dataset $\{(x_i, y_i)\}_{i=1}^N$ using a loss function $L$, to find an optimum $\theta^*$. Then there was an idea going around that the Hessian matrix (i.e. the second derivative of the loss at $\theta^*$) would tell us something about $f_{\theta^*}$ (especially about how well it generalizes).
If we number the dataset $(x_1, y_1), \dots, (x_N, y_N)$, we can stack all the network outputs into a vector $F(\theta) = (f_\theta(x_1), \dots, f_\theta(x_N))$, which fits into an empirical loss $L(F(\theta))$. The Hessian that we talked about before is now just the Hessian of $\theta \mapsto L(F(\theta))$. Expanding this out is kind of clunky since it involves some convoluted tensors that I don't know any syntax for, but clearly it consists of two terms:
- The Hessian of $L$ with a pair of the Jacobian of $F$ on each end (this can just barely be written without crazy tensors: $\frac{\partial F}{\partial \theta}^\top \frac{\partial^2 L}{\partial F^2} \frac{\partial F}{\partial \theta}$)
- The gradient of $L$ contracted with a crazy second derivative of $F$ (see the index-notation expansion just below).
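Written out in index notation (my own notation for the stacked outputs $F$ and the weights $\theta$), the standard chain-rule expansion is:

$$\frac{\partial^2 (L \circ F)}{\partial \theta_a\,\partial \theta_b} \;=\; \sum_{i,j} \frac{\partial F_i}{\partial \theta_a}\,\frac{\partial^2 L}{\partial F_i\,\partial F_j}\,\frac{\partial F_j}{\partial \theta_b} \;+\; \sum_i \frac{\partial L}{\partial F_i}\,\frac{\partial^2 F_i}{\partial \theta_a\,\partial \theta_b}$$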
Now, the derivatives of $L$ are "obviously boring" because they don't really refer to the neural network weights, which is confirmed if you think about it in concrete cases: e.g. if $L$ is squared error or cross-entropy, the derivatives just quantify how far the predictions $F(\theta)$ are from the targets $y$. This obviously isn't relevant for neural network generalization, except in the sense that it tells you which direction you want to generalize in.
Meanwhile, the Jacobian $\frac{\partial F}{\partial \theta}$ is incredibly strongly related to neural network generalization, because it's literally a matrix which specifies how the neural network outputs change in response to the weights. In fact, it forms the core of the neural tangent kernel (a standard tool for modelling neural network generalization), because the NTK can be expressed as $\frac{\partial F}{\partial \theta} \frac{\partial F}{\partial \theta}^\top$.
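As a concrete sketch of what this Jacobian and the resulting NTK look like in code (my own toy example; the architecture, sizes, and use of PyTorch are arbitrary choices, not anything from the post):

```python
import torch

# Toy two-layer network with a flattened parameter vector theta.
def f(theta, X):
    W1 = theta[:20].reshape(4, 5)     # 5 inputs -> 4 hidden units
    w2 = theta[20:24]                 # 4 hidden units -> 1 scalar output
    return torch.tanh(X @ W1.T) @ w2  # stacked outputs F(theta), shape (N,)

theta = torch.randn(24)
X = torch.randn(10, 5)                # N = 10 data points

# Jacobian dF/dtheta of the stacked outputs, shape (N, n_params).
J = torch.autograd.functional.jacobian(lambda t: f(t, X), theta)

# Empirical NTK: K = (dF/dtheta) (dF/dtheta)^T, shape (N, N).
K = J @ J.T
print(K.shape)  # torch.Size([10, 10])
```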
The "crazy second derivative of " can I guess be understood separately for each , as then it's just the Hessian , i.e. it reflects how changes in the weights interact with each other when influencing . I don't have any strong opinions on how important this matrix is, though because is so obviously important, I haven't felt like granting much attention.
The NTK as the network activations?
Epistemic status: speculative, I really should get around to verifying it. Really the prior part is speculative too, but I think those speculations are more theoretically well-grounded. But if I'm wrong about either, please call me a dummy in the comments so I can correct it.
Let's take the simplest case of a linear network, $f_\theta(x) = \theta^\top x$. In this case, $\frac{\partial f_\theta(x)}{\partial \theta} = x$, i.e. the Jacobian is literally just the inputs to the network. If you work out a bunch of other toy examples, the takeaway is qualitatively similar (the Jacobian is closely related to the neuron activations), though not exactly the same.
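A quick sanity check of the linear case (my own snippet; the shapes are arbitrary):

```python
import torch

# For a linear model f_theta(x) = theta . x, the Jacobian of the stacked
# outputs with respect to theta is just the input matrix X, so the NTK is
# the Gram matrix X X^T of the inputs.
theta = torch.randn(5)
X = torch.randn(7, 5)

J = torch.autograd.functional.jacobian(lambda t: X @ t, theta)  # shape (7, 5)
assert torch.allclose(J, X)
assert torch.allclose(J @ J.T, X @ X.T)   # NTK = input Gram matrix
```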
There are of course some exceptions, e.g. $\mathrm{ReLU}(\theta^\top x)$ at a point where $\theta^\top x < 0$ just has a zero Jacobian. Exceptions this extreme are probably rare, but more commonly you could have some softmax in the network (e.g. in an attention layer) which saturates such that no gradient goes through. In that case, for e.g. interpretability, it seems like you'd often still really want to "count" this, so arguably the activations would be better than the NTK for this case. (I've been working on a modification to the NTK to better handle this case.)
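A minimal illustration of the zero-Jacobian failure mode, using a single ReLU unit (a made-up example, standing in for the saturated-softmax case):

```python
import torch

# A ReLU unit whose preactivation is negative contributes no gradient,
# so the Jacobian (and hence the NTK) cannot "see" those parameter
# directions, even though the unit's activation pattern still matters.
theta = torch.tensor([1.0, -2.0])
x = torch.tensor([0.5, 1.0])          # theta . x = -1.5 < 0, so the ReLU is off
f = lambda t: torch.relu(t @ x)

J = torch.autograd.functional.jacobian(f, theta)
print(J)  # tensor([0., 0.]) -- no gradient flows through the saturated ReLU
```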
The NTK and the network activations have somewhat different properties, and so it varies which one I consider most relevant. However, my choice tends to be driven more by analytical convenience (e.g. the NTK and the network activations lie in different vector spaces) than by anything else.
3 comments
Comments sorted by top scores.
comment by Alexander Gietelink Oldenziel (alexander-gietelink-oldenziel) · 2024-04-26T19:57:33.439Z · LW(p) · GW(p)
This is all answered very elegantly by singular learning theory.
You seem to have a strong math background! I really encourage you to take the time and really study the details of SLT. :-)
↑ comment by tailcalled · 2024-04-27T14:15:36.529Z · LW(p) · GW(p)
Do you have an outline of how SLT answers this?
↑ comment by Alexander Gietelink Oldenziel (alexander-gietelink-oldenziel) · 2024-04-27T16:05:54.797Z · LW(p) · GW(p)
Sure! I'll try and say some relevant things below. In general, I suggest looking at Liam Carroll's distillation [? · GW] over Watanabe's book (which is quite heavy going, but good as a reference text). There are also some links below that may prove helpful.
The empirical loss and its second derivative are statistical estimators of the population loss and its second derivative. Ultimately the latter controls the properties of the former (though the relation between the second derivative of the empirical loss and the second derivative of the population loss is a little subtle).
The [matrix of] second derivatives of the population loss at the minimum is called the Fisher information metric. It's always degenerate [i.e. singular] for any statistical model with hidden states or hierarchical structure. Analyses that don't take this into account are inherently flawed.
SLT tells us that the local geometry around the minimum nevertheless controls the learning and generalization behaviour of any Bayesian learner for large N. N doesn't have to be that large though; empirically, the asymptotic behaviour that SLT predicts is already hit for N=200.
In some sense, SLT says that the broad basin intuition is broadly correct, but this needs to be heavily caveated. Our low-dimensional intuition for broad basins is misleading. For singular statistical models (again, everything used in ML is highly singular), the local geometry around the minima in high dimensions is very weird.
Maybe you've heard of the behaviour of the volume of a sphere in high dimensions: most of it is concentrated near the shell. I like to think of the local geometry as some sort of fractal sea urchin. Maybe you like that picture, maybe you don't, but it doesn't matter. SLT gives actual math that is provably the right thing for a Bayesian learner.
[Real ML practice isn't Bayesian learning though? Yes, this is true. Nevertheless, there is both empirical and mathematical evidence that the Bayesian quantities are still highly relevant for actual learning.]
SLT says that the Bayesian posterior is controlled by the local geometry of the minimum. The dominant factor for N ≳ 200 is the fractal dimension of the minimum. This is the RLCT and it is the most important quantity of SLT.
There are some misconceptions about the RLCT floating around. One way to think about it is as an 'effective fractal dimension', but one has to be careful about this. There is a notion of effective dimension in the standard ML literature where one takes the parameter count and mods out parameters that don't do anything (because of symmetries). The RLCT picks up on symmetries, but it is not just that. It picks up on how degenerate directions in the Fisher information metric are ~= how broad the basin is in that direction.
Let's consider a maximally simple example to get some intuition. Let the population loss function be $L(w) = w^{2k}$. The number of parameters is $d = 1$ and the minimum is at $w^* = 0$.
For $k = 1$ the minimum is nondegenerate (the second derivative is nonzero). In this case the RLCT is half the dimension. In our case the dimension is just $d = 1$, so $\lambda = \frac{1}{2}$.
For $k \geq 2$ the minimum is degenerate (the second derivative is zero). Analyses based on studying the second derivatives will not see the difference between different values of $k$, but in fact the local geometry is vastly different: the higher $k$ is, the broader the basin around the minimum. The RLCT for $L(w) = w^{2k}$ is $\lambda = \frac{1}{2k}$. This means the lower the RLCT, the 'broader' the basin is.
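One way to see where the $\frac{1}{2k}$ comes from (my gloss, using the volume-scaling characterization of the RLCT rather than the definition via resolution of singularities): the volume of parameters with loss below $\epsilon$ scales as

$$\operatorname{Vol}\{\,w : w^{2k} < \epsilon\,\} = 2\,\epsilon^{1/(2k)} \propto \epsilon^{\lambda}, \qquad \lambda = \tfrac{1}{2k},$$

so the more degenerate the minimum, the more near-minimal volume there is, and the exponent governing this is exactly the RLCT.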
Okay, so far this only recapitulates the broad basin story. But there are some important points:
- this is an actual quantity that can be estimated at scale for real networks, and it provably dominates the learning behaviour for moderately large $N$.
- SLT says that the minima with low RLCT will be preferred. It even says how much they will be preferred: there is a tradeoff between lower-RLCT minima with moderate loss ('simpler solutions') and minima with higher RLCT but lower loss. This means that the RLCT is actually 'the right notion of model complexity/simplicity' in the parameterized Bayesian setting. This is too much to recap in this comment, but I refer you to Hoogland & van Wingerden's post here [LW · GW]. This is also the start of the phase transition story, which I regard as the principal insight of SLT.
- The RLCT doesn't just pick up on basin broadness. It also picks up on more elaborate singular structure, e.g. a crossing-valley type minimum like $L(w_1, w_2) = w_1^2 w_2^2$. I won't tell you the answer, but you can calculate it yourself using Shaowei Lin's cheat sheet. This is key: actual neural networks have highly, highly singular structure that determines the RLCT.
- The RLCT is the most important quantity in SLT, but SLT is not just about the RLCT. For instance, the second most important quantity, the 'singular fluctuation', is also quite important. It has a strong influence on generalization behaviour and is the largest factor in the variance of trained models. It controls approximations to Bayesian learning, like the way neural networks are actually trained.
- We've seen that studying the directions defined by the matrix of second derivatives is fundamentally flawed because neural networks are highly singular. Still, there is something noncrazy about studying these directions. There is upcoming work, which I can't discuss in detail yet, that explains to a large degree how to correct this naive picture both mathematically and empirically.