Infinite-width MLPs as an "ensemble prior"

post by Vivek Hebbar (Vivek) · 2023-05-12

Contents

  Core claims
  Toy model
      Feature counts:  
    Solving for the expected regression weights
    Reframing in terms of complexity
    Generalization and "ensemble prior"
  Acknowledgements

Summary:  A simple toy model suggests that infinitely wide MLPs[1] generalize in an "ensemble-ish" way which is exponentially less data-efficient than Solomonoff induction.  It's probably fixable by different initializations and/or regularizations, so I note it here mostly as a mathematical curiosity / interesting prior.

The analysis seems to be qualitatively consistent with empirical results on generalization vs width in small MLPs.

Notes: 

Core claims

The standard initialization uses weights whose scale is proportional to $1/\sqrt{\text{fan-in}}$.  This has the effect of keeping the activations at roughly the same scale across layers.  However, in the infinite-width case, it ends up making the gradients in the early layers infinitely smaller than those in the last layer.  Hence, training an infinite-width MLP is equivalent to running a regression using the features represented by the last-layer neurons at initialization.  These features never change during training, since the early-layer gradients vanish in this limit.
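
To make this concrete, here is a minimal numpy sketch (my own; the shapes, labels, and seed are illustrative assumptions, not anything from the post): freeze a wide random hidden layer at its standard $1/\sqrt{\text{fan-in}}$ initialization and fit only the readout weights, which is exactly a linear regression on frozen random features.

```python
# A minimal sketch (not from the post) of "training a very wide MLP
# ~ linear regression on frozen last-layer features": freeze a wide random
# hidden layer at its standard 1/sqrt(fan-in) initialization and fit only
# the readout weights.  Shapes, labels, and the seed are illustrative.
import numpy as np

rng = np.random.default_rng(0)
n_data, d_in, width = 32, 4, 100_000

X = rng.normal(size=(n_data, d_in))
y = np.sign(X[:, 0])                       # arbitrary +/-1 labels

W1 = rng.normal(size=(d_in, width)) / np.sqrt(d_in)   # standard init scale
features = np.maximum(X @ W1, 0.0)         # frozen last-layer ReLU activations

# Fitting only the readout = linear regression on the frozen features.
# For this underdetermined system, lstsq returns the min-L2-norm solution.
w_out, *_ = np.linalg.lstsq(features, y, rcond=None)
print("max train error:", np.abs(features @ w_out - y).max())
```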

If we train without regularization, we will tend to get something very "ensemble-ish", "smooth", and "dumb". I will first summarize this claim in a table, then spend the rest of the post going through the reasoning behind it.

| Solomonoff Induction | Infinite-width MLP, low-L2-norm solution[3] |
| --- | --- |
| Bayesian update over programs | Linear regression over circuits |
| Puts most of its weight on a small number of programs, each of which perfectly fits the data on its own | Spreads weight over a broad ensemble, including circuits which have only a small correlation with the truth |
| The amount of data required to make the correct program dominate is $\sim K$, where $K$ is the program length | The amount of data required to make the correct circuit dominate is $\sim 2^{C}$, where $C$ is a "complexity measure" (defined later).  This is exponentially less data-efficient than Solomonoff induction. |
| Calling it "superintelligent" is an understatement | Generalizes poorly on many tasks[4] |
| Highly amenable to "sharp" solutions | Favors smooth solutions; only creates "sharp" solutions if certain conditions are met by the training data |

If we train an infinitely wide MLP from the standard initialization, only the last layer's weights change.  So it is equivalent to a linear regression over an infinite set of random "features", these features being the activation patterns of the last layer neurons at initialization.[5]  

If the MLP is deep enough, some of these last-layer neurons contain the outputs of very intelligent circuits.  However, if we train our infinite-width MLP, these intelligent circuits will hardly be used by the regression, even if they are very useful.  That is, the sum of the weights drawing from them in the last layer will be very small.  The reason I believe this is the toy model in the next section.

Toy model

Let's call each last-layer neuron a "feature".  As discussed earlier, their behavior never changes due to how the gradients pan out at infinite width.  In a "real" infinite network, these features will be "useful" and "intelligent" to various degrees, but we will simplify this greatly in the toy model, by using just two types of features.

The toy model asks:  "Suppose that some features already compute the correct answer for every training datapoint, and that the rest of the features are random garbage.  Will the trained network rely more on the perfect features, or will it use some giant mixture of random features?"

Suppose we have $n$ items in the training set, denoted $x_1, x_2, \ldots, x_n$.  Each has a label of either $+1$ or $-1$.  Let's say there are two types of features:[6]

  1. "Perfect features": Features which perfectly match the labels on the training set.
  2. "Random features": Features which were created by flipping a coin between  and  for each input, and having the neuron activate accordingly.

Since there are perfect features, we can always fit the labels.  If we have enough random features, we can also fit the labels using only random features.

We can represent features' behavior on the training set using vectors.  A feature vector $(f_1, f_2, \ldots, f_n)$ is a neuron which has activation $f_1$ on $x_1$, activation $f_2$ on $x_2$, and so on.  Feature vectors are of length $n$.

Linear regression on any set of features will find the minimum L2-norm solution if we start from the origin and use gradient descent.[7]

So in this setup, regression finds the linear combination of feature vectors which adds up to the "label" vector, while minimizing the sum-of-squares of the combination weights.
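
As a quick sanity check of this fact, here is a small sketch (my own; the feature matrix and labels are arbitrary random choices): gradient descent from the origin on the squared error lands on the same weights as the pseudoinverse, i.e. the minimum-L2-norm combination of features that reproduces the labels.

```python
# Sketch (mine, not from the post): gradient descent from the origin on
# squared error converges to the same weights as the pseudoinverse, i.e.
# the minimum-L2-norm combination of features that reproduces the labels.
import numpy as np

rng = np.random.default_rng(1)
n, N = 5, 200                              # n training points, N >> n features
A = rng.choice([-1.0, 1.0], size=(n, N))   # each column is one feature vector
y = rng.choice([-1.0, 1.0], size=n)        # +/-1 label vector

w = np.zeros(N)                            # start at the origin
lr = 1.0 / np.linalg.norm(A, 2) ** 2       # step size safe for this quadratic
for _ in range(5000):
    w -= lr * A.T @ (A @ w - y)

w_min_norm = np.linalg.pinv(A) @ y         # minimum-L2-norm interpolant
print(np.allclose(w, w_min_norm, atol=1e-8))   # expected: True
```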

Feature counts:  

There is a very large number $N$ of features (we'll take the limit as $N \to \infty$), and some proportion $p$ of the features are copies of the "perfect feature".  Thus, there are $pN$ perfect features (all just copies of the label vector) and $(1-p)N$ random features.

Our first goal is to characterize the behavior in terms of $p$.

Solving for the expected regression weights

If we have at least $n$ linearly independent "random features", then we can definitely fit the training labels using random features alone.  If we break each feature into components parallel and perpendicular to the label vector, then the weighted sum of the parallel components must equal the label vector, and the weighted sum of the perpendicular components must cancel to zero.

As $N \to \infty$, we won't have to worry about components perpendicular to the label vector, because the average of those components will go to zero in our weighted random set.[8]

Let $w_i$ be the weight on feature $f_i$, and let $y$ be the label vector.

At L2-optimality, the ratio of $w_i$ to $f_i \cdot y$ must be the same for every $i$, so we have $w_i = c\,(f_i \cdot y)$ for some constant $c$.
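
One heuristic way to see this (my own sketch; $\lambda$ is a Lagrange-multiplier vector that does not appear in the post):

$$
\begin{aligned}
&\text{Minimize } \textstyle\sum_i w_i^2 \quad \text{subject to} \quad \textstyle\sum_i w_i f_i = y. \\
&\text{Stationarity of } \textstyle\sum_i w_i^2 - \lambda^\top\!\big(\sum_i w_i f_i - y\big) \text{ gives } 2 w_i = \lambda \cdot f_i. \\
&\text{By symmetry of the random features, } \lambda \propto y \text{ as } N \to \infty, \text{ so } w_i = c\,(f_i \cdot y).
\end{aligned}
$$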

Now define the "perfect contribution" $P$ as the length of the weighted sum of the perfect features, and the "random contribution" $R$ as the length of the weighted sum of the random features.  Then $\frac{P}{R} = \frac{pn}{1-p}$.

And thus $P > R$ only if $p > \frac{1}{n+1}$.
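
Filling in the intermediate steps (a sketch under the same $N \to \infty$ assumption that perpendicular components cancel; recall $\|y\| = \sqrt{n}$):

$$
\begin{aligned}
&\text{Perfect feature: } f_i = y \;\Rightarrow\; f_i \cdot y = n \;\Rightarrow\; w_i = cn. \\
&\text{Random feature: } \mathbb{E}\big[(f_i \cdot y)^2\big] = n \;\Rightarrow\; \text{typical } |w_i| \approx c\sqrt{n}. \\
&P = \Big\|\textstyle\sum_{\text{perfect}} w_i f_i\Big\| = pN \cdot cn \cdot \sqrt{n}, \qquad
R \approx \frac{1}{\|y\|}\textstyle\sum_{\text{random}} w_i (f_i \cdot y) = \frac{c}{\sqrt{n}}\textstyle\sum_{\text{random}} (f_i \cdot y)^2 \approx c\,(1-p)N\sqrt{n}. \\
&\frac{P}{R} = \frac{pn}{1-p}, \qquad P > R \iff pn > 1-p \iff p > \frac{1}{n+1}.
\end{aligned}
$$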

Since this is meant to be about random features in MLPs, we are interested in the case where $p$ is close to zero.  So for our purposes, the perfect features contribute more iff $pn > 1$.

Note that the sum-of-squared-weights for each feature type is exactly proportional to the contribution for that type, since you can substitute $w_i/c$ for $f_i \cdot y$ in the derivations for $P$ and $R$.
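
The toy model is easy to check numerically. The sketch below (my own; sizes and seed are arbitrary) builds $pN$ copies of the label vector plus $(1-p)N$ random $\pm 1$ features, computes the minimum-L2-norm interpolant via the pseudoinverse, and compares the perfect and random contributions against the predicted ratio $\frac{pn}{1-p}$.

```python
# Numerical check of the toy model (my own sketch; sizes and seed are
# arbitrary): pN copies of the label vector plus (1-p)N random +/-1
# features, minimum-L2-norm interpolant via the pseudoinverse, then
# compare the perfect contribution P and random contribution R.
import numpy as np

rng = np.random.default_rng(2)
n, N = 20, 20_000
y = rng.choice([-1.0, 1.0], size=n)

for p in [0.01, 0.05, 0.2]:                                # p*n = 0.2, 1, 4
    n_perf = int(p * N)
    F_perf = np.tile(y, (n_perf, 1)).T                     # columns = perfect features
    F_rand = rng.choice([-1.0, 1.0], size=(n, N - n_perf)) # columns = random features
    A = np.hstack([F_perf, F_rand])
    w = np.linalg.pinv(A) @ y                              # min-L2-norm weights
    P = np.linalg.norm(F_perf @ w[:n_perf])                # perfect contribution
    R = np.linalg.norm(F_rand @ w[n_perf:])                # random contribution
    print(f"p*n = {p*n:4.1f}   P/R = {P/R:5.2f}   predicted {p*n/(1-p):5.2f}")
```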

Reframing in terms of complexity

Suppose we define a complexity measure $C$ on features such that $p = 2^{-C}$.[9]  Then our result says that the "perfect features" contribute more iff $n > 2^{C}$.

Remember that $n$ is the size of the training set, so this amounts to a data requirement that is exponential in the complexity of the desired feature.

Generalization and "ensemble prior"

The influence of the perfect features on any particular data point scales linearly with their relative contribution $\frac{P}{R} \approx pn$.  Thus, for small $p$, their influence on generalization behavior is linear in $p$, and declines exponentially with complexity.

Another way to phrase this exponential decline is to say that the complexity of contributing features grows only ~logarithmically with dataset size.  This is quite harsh (e.g. ~40 bits per feature even at 1 trillion datapoints), leading me to expect poor generalization on interesting tasks.
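
Spelling out the arithmetic behind the "~40 bits" figure (just restating the threshold from the previous section):

$$
p = 2^{-C}, \qquad pn > 1 \iff n > 2^{C} \iff C < \log_2 n, \qquad \log_2\!\big(10^{12}\big) \approx 39.9 \text{ bits}.
$$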

Regression on infinite random features seems to be what I will call an "ensemble prior" -- a way of modeling data which prefers a linear combination of many simple features, none of which needs to be good on its own.  This is in sharp contrast to Solomonoff induction, which seeks hypotheses that single-handedly compress the data.

Finally, this "ensemble-ish" behavior is corroborated in toy experiments with shallow MLPs.  I ran experiments fitting MLPs to 4-item datasets in a 2d input space, and plotting the generalization behavior.  With small MLPs, many different generalizations are observed, each of which tends to be fairly simple and jagged geometrically.  However, as the MLPs are made wider, the generalization behavior becomes increasingly consistent across runs, and increasingly smooth, ultimately converging to a very smooth-looking limiting function.  This function has a much higher circuit complexity than the jagged functions of the smaller nets, and is best thought of as a limiting ensemble of features.
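
For anyone who wants to reproduce the qualitative picture, here is a sketch of that kind of experiment (my own minimal setup; the four-point dataset, widths, and hyperparameters are illustrative assumptions, not the author's code): fit one-hidden-layer ReLU MLPs of increasing width to four labelled points in 2D with full-batch gradient descent, and measure how much the learned function varies across random seeds.

```python
# Sketch of the kind of width-sweep experiment described above (my own
# minimal setup, not the author's code): fit one-hidden-layer ReLU MLPs of
# increasing width to 4 labelled points in 2D with full-batch gradient
# descent, then measure how much the learned function varies across seeds.
import numpy as np

def fit_mlp(width, seed, steps=5000):
    rng = np.random.default_rng(seed)
    X = np.array([[-1., -1.], [-1., 1.], [1., -1.], [1., 1.]])
    y = np.array([-1., 1., 1., -1.])                  # XOR-style labels (illustrative)
    W1 = rng.normal(size=(2, width)) / np.sqrt(2)     # standard 1/sqrt(fan-in) init
    b1 = np.zeros(width)
    w2 = rng.normal(size=width) / np.sqrt(width)
    lr = 0.5 / width                                  # keep the readout step stable as width grows
    for _ in range(steps):
        h = np.maximum(X @ W1 + b1, 0.0)              # hidden ReLU activations
        err = (h @ w2 - y) / len(y)                   # d(mean squared error / 2) / d(output)
        grad_pre = np.outer(err, w2) * (h > 0)        # backprop through the ReLU
        W1 -= lr * X.T @ grad_pre
        b1 -= lr * grad_pre.sum(axis=0)
        w2 -= lr * h.T @ err
    return lambda P: np.maximum(P @ W1 + b1, 0.0) @ w2

grid = np.stack(np.meshgrid(np.linspace(-2, 2, 40),
                            np.linspace(-2, 2, 40)), -1).reshape(-1, 2)
for width in [8, 64, 4096]:
    runs = np.stack([fit_mlp(width, seed)(grid) for seed in range(3)])
    print(width, "seed-to-seed std of learned function:", runs.std(axis=0).mean().round(3))
```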

Acknowledgements

This post is based on work done about 11 months ago in the SERI MATS program under the mentorship of Evan Hubinger.  Thanks to MATS and Evan for support and feedback.

  1. ^

    When initialized and trained in the standard way

  2. ^

    See here and here

  3. ^

    Resulting from L2-regularization or no regularization.  My guess is that L1 behaves very differently.

  4. ^

    I have not tested this, but strongly predict it based on my result

  5. ^

    Plus the bias (constant feature)

  6. ^

    This is an oversimplification, but sufficient to get a good result

  7. ^

    In the standard initialization, we start at a random point in weight-space, rather than the origin.  This has the effect of adding a Gaussian-random offset to the solution point in all dimensions which don't affect behavior.  The analysis is very simple when we rotate the basis to make the hyperplane-of-zero-loss be basis aligned.

    This toy model will simply ignore the random offset, and reason about the minimum-L2 point.

  8. ^

    I don't prove this

  9. ^

    I'm just pulling this out of thin air, as a "natural" way for a "complexity measure" to relate to probability. This section is just tautological given the definition, but it might be illuminating if you buy the premise.

  10. ^

    The number of items in the dataset, which is also the length of each feature vector
