QFT and neural nets: the basic idea

post by Dmitry Vaintrob (dmitry-vaintrob) · 2025-01-24T13:54:45.099Z · LW · GW

Contents

  Reminders: formalizing learning in ML and Bayesian learning
    Learning and inference in neural nets and Bayesian models
    Prediction
  Field theory picture: cutting out the middle man
  Laws of large numbers
  Takeaways and future discussions

Previously in the series: The laws of large numbers [LW · GW] and Basics of Bayesian learning [LW · GW].

Reminders: formalizing learning in ML and Bayesian learning

Learning and inference in neural nets and Bayesian models

As a very basic sketch, in order to specify an ML algorithm one needs five pieces of data. 

  1. An architecture: i.e., a parametrized space of functions that associates to each weight vector $w$ a function $f_w$ from some input space $X$ to an output space $Y$.
  2. An initialization prior on weights. This is a (usually stochastic) algorithm to initialize a weight from which to begin learning; generally this is some Gaussian distribution on the weight space. While this is often ignored, in many contexts it is actually quite important to get right for learning to behave reasonably.
  3. Training data. This is a collection of “observation” pairs $D = \{(x_i, y_i)\}_{i=1}^n$ with $x_i \in X$ and $y_i \in Y$.
  4. A loss function. This is a function $L(w, D)$ on weight space which operationalizes a measure of how well $f_w$ agrees with the data $D$.
  5. A learning algorithm/optimizer. (“Learning algorithm” is used in theoretical contexts and “optimizer” is frequently used in engineering contexts.) This is an algorithm (usually, a stochastic algorithm) for finding the “learned” weight parameter $w^*$, usually by some local minimization or approximate minimization of the function $L(w, D)$.
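To make the five components concrete, here is a minimal runnable sketch in Python. Everything specific about it (the random-features architecture, the sizes, the sine-shaped toy data, plain gradient descent) is an illustrative choice of mine, not part of the formal setup above:

```python
import numpy as np

rng = np.random.default_rng(0)
WIDTH = 64

# Fixed random features; these are part of the "architecture" below.
FREQS = rng.normal(size=WIDTH)
PHASES = rng.normal(size=WIDTH)

# 1. Architecture: each weight vector w determines a function f_w.
def phi(x):
    return np.tanh(np.outer(x, FREQS) + PHASES)

def f(w, x):
    return phi(x) @ w

# 2. Initialization prior: a Gaussian distribution on weight space.
def sample_prior():
    return rng.normal(0.0, 1.0 / np.sqrt(WIDTH), size=WIDTH)

# 3. Training data: n "observation" pairs (x_i, y_i).
xs = rng.uniform(-2.0, 2.0, size=16)
ys = np.sin(xs) + 0.1 * rng.normal(size=16)

# 4. Loss function: how well f_w agrees with the data D (here, MSE).
def loss(w):
    return np.mean((f(w, xs) - ys) ** 2)

# 5. Learning algorithm/optimizer: plain gradient descent on L(w, D),
#    standing in for SGD, Adam, etc. (the MSE gradient is analytic here).
w = sample_prior()
for _ in range(500):
    grad = 2.0 * phi(xs).T @ (f(w, xs) - ys) / len(xs)
    w -= 0.1 * grad
print(f"final loss: {loss(w):.4f}")
```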

In most theoretical analyses of models, one uses gradient descent to conceptualize the learning algorithm. Some more sophisticated pictures (such as the tensor programs series) more carefully match realistic stochastic gradient descent by assuming a discrete process with finite learning steps rather than continuous gradient descent. Learning algorithms used in industry tend to include more sophisticated control over the gradients via things like gradient decay, momentum, Adam, etc.

All of these algorithms have in common the property of being sequential and local (i.e., there is some step-by-step learning that ends when it converges to an approximate local minimum). However, when working theoretically, a learning algorithm doesn’t have to be local or sequential.

Bayesian learning is one such (non-local and non-sequential) learning algorithm. This algorithm converts the learning problem to a Bayesian inference problem. Here the dictionary from ML to statistics is as follows:

  1. Architecture $\leftrightarrow$ deterministic statistical model. Here “deterministic” means that each latent $w$ determines a deterministic mapping $f_w$ from inputs to outputs.
    1. In particular, weight parameter $w$ $\leftrightarrow$ latent.
  2. Initialization prior $\leftrightarrow$ prior on latents.
  3. Training data $\leftrightarrow$ observations.
  4. Loss function $\leftrightarrow$ a method of “degrading” deterministic functions to probabilistic ones, with for example MSE loss converting the deterministic function $f_w$ to the probabilistic function with Gaussian indeterminacy, $y \sim \mathcal{N}(f_w(x), \sigma^2)$ (for a parameter $\sigma$ implicit in the model).

Finally, the main new aspect of the Bayesian model is that “component 5”, i.e., the “learning algorithm/optimizer” in the list of ML system components above, is set to be “Bayesian inference” (instead of one of the local algorithms used in conventional learning). Here recall that Bayesian inference returns the stochastic (rather than deterministic) posterior distribution on weights, $P(w \mid D) \propto \pi(w)\, P(D \mid w)$, given by conditionalizing the prior $\pi(w)$ on the observed data $D$.
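To make this swap concrete in the toy model above (reusing its phi, sample_prior, xs, ys): here is a minimal sketch in which “component 5” becomes Bayesian inference, with the posterior represented by self-normalized importance weights on prior draws. The importance-sampling representation is my illustrative choice; it is crude, but it makes the stochastic character of the output visible:

```python
# Component 5 swapped for Bayesian inference: instead of descending the
# loss, reweight prior draws by the Gaussian likelihood corresponding to
# the MSE loss, exp(-sum_i (f_w(x_i) - y_i)^2 / (2 sigma^2)).
sigma = 0.1                                  # noise scale implicit in the model
W_samples = np.stack([sample_prior() for _ in range(20000)])
resid = W_samples @ phi(xs).T - ys           # f_w(x_i) - y_i for every draw
log_lik = -(resid ** 2).sum(axis=1) / (2 * sigma ** 2)
post_w = np.exp(log_lik - log_lik.max())
post_w /= post_w.sum()                       # P(w | D), as weights on the draws
```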

Some observations.

Prediction

All the algorithms introduced above serve to learn (either deterministically or stochastically) a weight $w$ given some data $D$. In the Bayesian inference context the stochastic learning follows the posterior probability distribution $P(w \mid D)$.

However no one (whether in ML learning or inference) is really interested in learning the parameter $w$ itself: it lives in some abstracted space of latents or weights. What we are really interested in is prediction: namely, given a set of observations $D$, together with a new (and in general, previously unobserved) input value $x$, we want to extract a (stochastic or deterministic) predicted value $y$.

The reason why it’s generally enough to focus on inference is that in both Bayesian and machine learning, learning leads to prediction. Namely, given a (deterministic or sampled) latent parameter $w$, we automatically get a predicted value by setting $y = f_w(x)$. Here note that the randomness in $y$ can come from two sources: both the function $f_w$ and the latent $w$ can be stochastic in general.
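Spelled out as a formula (a standard identity, with the integral taken over the latent/weight space):

$$P(y \mid x, D) \;=\; \int P(y \mid x, w)\, P(w \mid D)\, dw.$$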

Thus most learning paradigms function via the following pipeline:

Model + data → learned posterior on latents $P(w \mid D)$ → prediction $y$.
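In the toy model above, the whole pipeline fits in a few lines (a sketch reusing W_samples, post_w, phi, and sigma from the previous snippet):

```python
x_new = np.array([0.5])                       # a new, previously unseen input
preds = W_samples @ phi(x_new).T              # f_w(x_new) for every sampled w
y_mean = (post_w * preds[:, 0]).sum()         # posterior-predictive mean of y

# A stochastic prediction: sample w ~ P(w | D), then y = f_w(x_new) + noise.
i = rng.choice(len(W_samples), p=post_w)
y_draw = preds[i, 0] + sigma * rng.normal()
```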

While most Bayesian and learning contexts tend to view prediction as an afterthought, in the following section we will focus on decoupling prediction from the rest of the learning paradigm.

Field theory picture: cutting out the middle man

The key idea that leads to the field-theoretic paradigm on learning (though it is generally not introduced in this way) is cutting out inference from the prediction problem. This is easier to do in the Bayesian learning setting, though also entirely possible in other ML settings[2]. For today’s post at least, we will focus on the Bayesian learning context; note that in theoretical analyses, the Bayesian paradigm is often easier to work with, as it corresponds to a “thermostatic” rather than a more general “thermodynamic” picture of learning.

Recall the pipeline I mentioned in the previous section:

Model + data → learned posterior on latents $P(w \mid D)$ → prediction $y$.

We will enter the “physics” picture by cutting out the middle man and instead considering the shorter pipeline:  

Model + data → prediction $y$.

In the Bayesian paradigm, prediction can be conceptualized without ever discussing latents. Namely, going back to the bayesics, after abstracting everything away, a choice of model + prior implies a joint probability distribution on data: $P(y_1, \dots, y_n \mid x_1, \dots, x_n)$. Now $n$ is just another variable here, and so we can throw in “for free” an extra pair of datapoints: $P(y_1, \dots, y_n, y \mid x_1, \dots, x_n, x)$.

The Bayesian prediction can now be rewritten as follows: $P(y \mid x, D) = P(y \mid x_1, \dots, x_n, y_1, \dots, y_n, x)$. Here out of the variables $x_1, \dots, x_n, y_1, \dots, y_n, x, y$, we condition on all the variables except $y$, and for our prediction we draw $y$ from the resulting posterior distribution.
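To see that nothing has been lost, here is a sketch (same toy model and notation as before) that estimates the prediction purely by sampling input-output tuples from the joint distribution and conditioning on the observed ones; note that the latent $w$ appears only inside the sampler, where it is drawn, used, and forgotten:

```python
def sample_fn_values(x_all):
    w = sample_prior()             # the latent: drawn, used, and forgotten
    return phi(x_all) @ w          # function values at (x_1, ..., x_n, x)

x_all = np.append(xs, 0.5)         # the observed inputs plus the new input x
F = np.stack([sample_fn_values(x_all) for _ in range(20000)])

# Condition on the observed outputs y_1, ..., y_n (each observed with
# Gaussian noise of scale sigma) by likelihood weighting.
log_w = -((F[:, :-1] - ys) ** 2).sum(axis=1) / (2 * sigma ** 2)
cw = np.exp(log_w - log_w.max())
cw /= cw.sum()
y_pred = (cw * F[:, -1]).sum()     # E[y | x, D], with no posterior on w in sight
```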

Now while the latent parameters have been flushed out of these expressions, they’re still there, just, well, latent. The key idea in the “physics approach” to machine learning is that the prediction problem is more physical than the inference problem (at least in many cases). The specifics of the model, and the process of converting an abstracted-out weight to a nice prediction function, matter for our analysis, to be sure. But they matter as a back-end “sausage-making” process. Physicists love taking such complex processes and replacing the sausagey specifics with summary analyses. In other words, the typical physicist move is to start with a complex system, then observe that most of the components of the system don’t matter for the result at a given level of granularity, and that what matters is some extracted-out, averaged or “massaged” collection of values that are mathematically much nicer to analyze. The art here is to extract the correct summary variables and to (as carefully as possible) track the relationship between different aspects of precision and scale.

Laws of large numbers

Ultimately, we’ve seen that our prediction problem reduces to a conditional probability problem $P(y \mid x, D)$, conditionalizing on the observed data and the new input. In today’s paradigm (and for most of the rest of our discussion of the “field-theoretic view of ML”), we will assume that the size $n$ of the dataset $D$ is very small compared to the width (perhaps only of size $n = O(1)$). Thus the problem of conditionalizing on $2n+1$ variables is taken to be “easy”, and all we need to do, really, is to find the probability distribution on tuples of $n+1$ input-output pairs. Since in this question the “predicted” input-output pair $(x, y)$ plays the same role as the “known” pairs $(x_i, y_i)$, we can drop the distinction between them and consider (without loss of generality) only the problem of finding the probability $P(y_1, \dots, y_n \mid x_1, \dots, x_n)$ of some set of input-output data.

Now we can’t keep our sausages wrapped forever: at some point we have to take a peek inside. And when we do, we notice that this probability distribution on the data is induced directly from the prior on weights:

$$P(y_1, \dots, y_n \mid x_1, \dots, x_n) = \int P(y_1, \dots, y_n \mid x_1, \dots, x_n, w)\, \pi(w)\, dw.$$

In other words, what we want to do is understand the distribution of the tuple of outputs $(f_w(x_1), \dots, f_w(x_n))$ as the weight $w$ is drawn from the prior.

The key idea now is to use analysis similar to the discussion in the “Laws of Large Numbers [LW · GW]” post to see that, for large width, we can extract all the information we need for a good approximation[3] of the probability distribution by observing how the first few cumulants (equivalently, moments) of the activations on $D$ transform from layer to layer of our neural net.
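As a self-contained numerical illustration of what “tracking moments layer to layer” means (the width, depth, and tanh activation here are arbitrary illustrative choices): for a wide random MLP at initialization, the preactivations at each layer are approximately Gaussian, so their second moment, together with how far the fourth moment deviates from the Gaussian value $3\,\mathbb{E}[z^2]^2$, carries the leading information:

```python
import numpy as np

rng = np.random.default_rng(1)
width, depth = 2048, 8
a = rng.normal(size=width)                   # an input with O(1) entries

for layer in range(1, depth + 1):
    W = rng.normal(size=(width, width)) / np.sqrt(width)
    z = W @ a                                # preactivations at this layer
    m2 = np.mean(z ** 2)                     # second moment ("total variance")
    m4 = np.mean(z ** 4)                     # fourth moment
    # For an exact Gaussian, m4 / (3 m2^2) = 1; the deviation measures the
    # connected fourth cumulant, i.e. the leading finite-width correction.
    print(f"layer {layer}: E[z^2]={m2:.3f}, "
          f"E[z^4]/(3 E[z^2]^2)={m4 / (3 * m2 ** 2):.3f}")
    a = np.tanh(z)
```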

Takeaways and future discussions

I’m going to save a more detailed analysis here for future posts (some of which will be joint with Lauren Greenspan). But before concluding, let me just state a few of the key points that you will see explained if you read our future posts.

  1. It turns out that under realistic assumptions on the width, depth, initialization, and activation of a neural network, and assuming the size of the dataset is $n = O(1)$, we can reduce the cumulant expansion formula at arbitrary levels of precision in 1/width to tracking the evolution from layer to layer of a finite (i.e., $O(1)$) collection of numbers: namely the first four symmetric moments of the joint distribution on the activations. If $n = 1$, this means tracking just two numbers, namely the total variance, i.e., the expectation $\mathbb{E}\|x\|^2$ of the norm squared of a datapoint, and the total fourth moment, i.e., the expectation $\mathbb{E}\|x\|^4$ (and the analysis consists of tracking these two values as the input $x$ evolves from layer to layer in our neural net). If there are multiple datapoints, i.e., $n > 1$, we need to track all $\sim n^2$ total second moments $\mathbb{E}\langle x_i, x_j \rangle$, and fourth moments, namely the expectations of $\langle x_i, x_j \rangle^2$ and similar degree-four expressions as $i, j$ run through the different inputs (see the numerical sketch after this list). Here the “symmetry” (that lets us look only at even moments) corresponds to a certain property of the architecture; it can be weakened, in which case we still only need to track a finite number of ($\le 4$th) moment values (note in particular that the number of values tracked doesn’t depend on the width of the network).
  2. Here the reason we only need to look at moments of degree $\le 4$ to get all orders of correction in 1/width is related to a universality property (an idea that originated in statistical field theory, and is closely linked to renormalization) that only emerges for suitably deep neural networks (i.e., more or less, the depth has to satisfy $n \ll \text{depth} \ll \text{width}$, for $n$ the number of datapoints). If we drop the requirement that the depth is large and look at shallow networks, e.g. networks of depth 2, we start seeing more moving parts: here, in order to get an expansion accurate to a given precision in 1/width, we need to track symmetric moments of correspondingly higher degree. While Feynman diagrams are not necessary to see that the problem reduces to “cumulant math”, actually doing the relevant cumulant math can be significantly simplified by using them.
  3. In points 1-2 above, it is crucial to assume that the number of inputs is very small compared to the width. Also, the formally provable results here are all perturbative, and return predictions for neural net behaviors which are perturbatively small corrections of a Neural Network Gaussian Process (NNGP). How to treat neural nets where the data distribution is large (compared to other parameters), and where emergent (and in general, non-Gaussian) “data regularities” become important to keep track of, is ongoing work, which I hope we’ll get a chance to discuss and speculate about. Here the promise of the “physical approach” is not to give an asymptotic formula as precise as we see in the small-data large-width limit, but rather to more carefully abstract away the “sausage-making” process of explicit inference on weights. The resulting analysis should capture both averaged properties of the activations (appropriately conceptualized as a random process) and averaged properties of the data distribution and its associated regularities. This is very vague (and the work in this direction is quite limited and new). But I hope by the end of this series to convince you that this powerful approach to conceptualizing neural networks is worth adding to an interpretation theorist’s arsenal.
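As a numerical companion to point 1 above (a sketch only: the width, depth, and activation are my illustrative choices, and it tracks empirical rather than exact moments), here is the multi-datapoint version, following the $n \times n$ matrix of second moments and one representative fourth moment from layer to layer:

```python
import numpy as np

rng = np.random.default_rng(2)
width, depth, n = 2048, 6, 2
A = rng.normal(size=(n, width))              # n = O(1) inputs with O(1) entries

for layer in range(1, depth + 1):
    W = rng.normal(size=(width, width)) / np.sqrt(width)
    Z = A @ W.T                              # preactivations for all n inputs
    K = Z @ Z.T / width                      # n x n second moments E<z_i, z_j>
    m4 = np.mean(Z[0] ** 2 * Z[1] ** 2)      # one representative fourth moment
    print(f"layer {layer}: K = {np.round(K, 3).tolist()}, cross 4th = {m4:.3f}")
    A = np.tanh(Z)
```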
  1. ^

    Small print: more generally, if it’s not reasonable to assume that a maximizing weight $w$ is unique, one should take the uniform distribution on the set of maximizing $w$’s.

  2. ^

    And worked out, both in PDLT and, for a larger class of learning algorithms, in the tensor programs series.

  3. ^

    Notice that inherent here is an assumed choice of scale.
