Basics of Bayesian learning
post by Dmitry Vaintrob (dmitry-vaintrob) · 2025-01-14T10:00:46.000Z · LW · GW
See also: the “preliminaries” section in this SLT intro doc [? · GW].
Introduction
This is a preliminary post for the series on “distilling PDLT without physics”, which we are working on jointly with Lauren Greenspan. The first post in this series is my post on the “Laws of large numbers [LW · GW]” (another preliminary), which is completely independent of this one.
As a reminder, the book PDLT (“Principles of Deep Learning Theory”) uses statistical physics and QFT formalism to describe certain interesting (and relevant) critical behaviors of neural nets.
In this post, we'll introduce Bayesian learning as a perspective on deep learning through a particularly physics-compatible lens. Bayesian learning is a special case of statistical inference, and the first half of the post will be a review of statistical inference. Recall that (in a function-approximation context) a machine learning model is a parameterized function $f_\theta(x)$, where $\theta$ is a latent "weight" parameter that needs to be learned. Writ large, a “learning” problem for a model of this type is a method to reconstruct a best (probabilistic) guess for the weight $\theta$ from a collection of observed input-output pairs $(x_1, y_1), \dots, (x_n, y_n)$. In typical machine learning implementations, the parameter is chosen by a process (like SGD) that locally optimizes some loss function associated to the parameter $\theta$ and the function $f_\theta$. Bayesian learning gives an alternative (and in some sense “optimal”) statistical method for inferring a best-guess weight from knowledge of the model and observations. The two methods agree in certain limiting settings, but do not agree in general, either in realistic models or in interesting theoretical contexts, including the formally analyzable infinite-width setting we will be looking at in later posts.
As I’ll explain, Bayesian learning is
- theoretically nicer and easier to analyse than SGD (with Bayesian learning corresponding to thermostatics and SGD corresponding to thermodynamics)
- computationally much harder than conventional ML methods like SGD – in particular, Bayesian inference is in general NP-hard.
Ultimately for real-world understanding of neural nets, we are interested in the more realistic SGD learning – and in fact, QFT methods do give ways of analysing this limit. Nevertheless, it turns out that studying the Bayesian (“thermostatic”) limit of theoretically tractable systems is a valuable first step (and in some contexts a good approximation) for understanding algorithms found by realistic learning methods like SGD.
While this post mostly covers standard Bayesian concepts, I’ll make an effort to emphasize perspectives that become important for later posts on applying physics-inspired techniques to understand neural network behavior, particularly in overparameterized regimes.
Statistical models
In a Bayesian model of reality, we make a distinction between nature and the observer. We assume that nature secretly knows some hidden latent information, which we encode as a “secret” vector $\theta$ (its dimension will of course later correspond to the number of weights, i.e., parameters, in a neural net). The observer is allowed to run a series of independent “experiments”, where she chooses an input $x$ corresponding to a collection of “experimental parameters”, and nature outputs a value $y$, which is some collection of measurements that depend stochastically on $x$ and $\theta$. The value $y$ can be a deterministic function of $x$ and $\theta$, but more generally it is drawn from a probability distribution[1], with probability density written $P(y \mid x, \theta)$.
In the Bayesian framework, we assume that the observer has complete knowledge of the “overall situation”, minus knowledge of the latent. This means the observer knows:
- The complete conditional probability function $P(y \mid x, \theta)$ (this is a probability-distribution-valued function in $x$ and $\theta$).
- A prior $P(\theta)$ on latents, corresponding to “how likely nature is to choose a given latent”.
The picture of the Bayesian observer interacting with nature is of course an idealization of what humans are able to do (in particular, since we don’t understand physics well enough to reduce the entire universe to a probabilistic process depending on a latent variable). In practice, the latents in the experimental model for $P(y \mid x, \theta)$ are usually either given by small causal models which model interactions of different high-level variables, by informed mathematical guesses on the shapes of random processes (such as neuron firing), or, conversely, by large, highly expressive models that start with very few assumptions on the true latents and how they affect experiments, and deduce regularities by fitting a large amount of data. We’ll see that in a certain limit, if the “true” data distribution is (even by accident) included in your prior (i.e., included in the model and assigned nonzero prior probability), then with sufficient data, the true latent becomes recoverable; this observation is also at the root of machine learning.
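To make the setup concrete, here is a minimal sketch of such an experiment in code (the scalar model, the parameter names, and the Gaussian noise are all illustrative assumptions of mine, not from the post):

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy setup: nature's latent is a single scalar theta_true,
# and P(y | x, theta) is a Gaussian centered at theta * x with known noise.
theta_true = 1.7          # the "secret" latent chosen by nature
noise_sigma = 0.3         # part of the known model P(y | x, theta)

def sample_y(x, theta):
    """One draw of nature's stochastic response y ~ P(y | x, theta)."""
    return theta * x + noise_sigma * rng.normal()

# The observer chooses experimental inputs and records nature's outputs.
xs = rng.uniform(-1.0, 1.0, size=20)
ys = np.array([sample_y(x, theta_true) for x in xs])
```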
Now the core problem in Bayesian inference is one of inversion. Namely, “nature’s formula” $P(y \mid x, \theta)$ allows it to repeatedly (and independently) output (stochastic) experimental values $y_i$ given the experimenter’s choice of inputs $x_i$ and its hidden latent parameter $\theta$. The observer’s tasks now comprise the following:
- Inference: Find a best-guess probability distribution on the latents, namely the probability of $\theta$ being the true latent given the knowledge of the experimental results $(x_1, y_1), \dots, (x_n, y_n)$. The result is called the posterior distribution on latents.
- Prediction: Given a new (possibly “previously unobserved”) input $x$, guess (stochastically) the corresponding value of $y$ given the same experimental data as before, namely inputs $x_1, \dots, x_n$ and outputs $y_1, \dots, y_n$ (here as before, we assume that nature generated all of these for the same latent variable $\theta$, which the observer does not have direct access to).
In fact, despite all the grandiose words (“nature, latent, experiment”) that I used, this process is, at least formally, no harder than applying a basic manipulation of probabilities: namely, conditionalizing and marginalizing out variables.
Specifically, the information the observer has (the formula for $P(y \mid x, \theta)$ and the prior $P(\theta)$) is enough to define a big probability distribution
$$P(\mathbf{x}, \mathbf{y}, \theta) = P(\theta) \prod_{i=1}^n P(x_i)\, P(y_i \mid x_i, \theta)$$
that gathers together all the $n$ input-output datapoints and the latent variable $\theta$[2]. Here I’ve compressed the notation for the collection of $n$ pairs of input-output vectors to the bold $\mathbf{x} = (x_1, \dots, x_n)$ and $\mathbf{y} = (y_1, \dots, y_n)$.
We now simply define Bayesian inference to be the result of conditioning on the specific observed $\mathbf{x}$ and $\mathbf{y}$:
$$P(\theta \mid \mathbf{x}, \mathbf{y}).$$
In other words, we look at the probability of $\theta$ in the giant distribution conditioned on observing the input-output data in the $n$ experiments: $P(\theta \mid \mathbf{x}, \mathbf{y})$.[3] By Bayes’ rule (i.e., the formula for conditional probabilities), together with independence of the experiments, this is given by the formula:
$$P(\theta \mid \mathbf{x}, \mathbf{y}) = \frac{P(\theta) \prod_{i=1}^n P(y_i \mid x_i, \theta)}{Z(\mathbf{x}, \mathbf{y})}.$$
Here we could write a formula for the normalizing factor as the integral $Z(\mathbf{x}, \mathbf{y}) = \int P(\theta') \prod_{i=1}^n P(y_i \mid x_i, \theta')\, d\theta'$, but the important thing is that it’s independent of $\theta$ (it’s simply there as a normalizing expression in the formula above). The prediction is defined similarly, where now instead of predicting $\theta$ conditional on the experimental input-output pairs $(\mathbf{x}, \mathbf{y})$, we consider the larger probability distribution on $(\mathbf{x}, \mathbf{y}, x, y, \theta)$. This incorporates both the $n$ “experimental” input-outputs $\mathbf{x}, \mathbf{y}$, which we’re going to condition on, and the “new” input-output pair $(x, y)$ (only one input and one output), which we’re going to predict. Assuming we have computed the posterior distribution on the latent $\theta$ as above, we simply substitute in the (known to the observer) expression for $P(y \mid x, \theta)$ to get
$$P(y \mid x, \mathbf{x}, \mathbf{y}) = \int P(y \mid x, \theta)\, P(\theta \mid \mathbf{x}, \mathbf{y})\, d\theta.$$
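Here is a minimal numerical sketch of both formulas for the toy scalar model above, discretizing $\theta$ so the integrals become sums (the grid, prior, and noise scale are illustrative choices, not from the post):

```python
import numpy as np

rng = np.random.default_rng(0)

# Same toy model as above: y ~ Normal(theta * x, noise_sigma), theta_true = 1.7.
theta_true, noise_sigma = 1.7, 0.3
xs = rng.uniform(-1.0, 1.0, size=20)
ys = theta_true * xs + noise_sigma * rng.normal(size=20)

# Discretize the latent space so that integrals over theta become sums.
thetas = np.linspace(-3.0, 3.0, 601)
prior = np.exp(-0.5 * thetas**2)          # an (unnormalized) Gaussian prior P(theta)
prior /= prior.sum()

# log prod_i P(y_i | x_i, theta) for each grid value of theta
log_lik = np.array([np.sum(-0.5 * ((ys - t * xs) / noise_sigma) ** 2) for t in thetas])

# Posterior P(theta | x, y): prior times likelihood, normalized over the grid.
log_post = np.log(prior) + log_lik
posterior = np.exp(log_post - log_post.max())
posterior /= posterior.sum()

def predictive_density(y_new, x_new):
    """P(y_new | x_new, data): average the model density over the posterior."""
    density = np.exp(-0.5 * ((y_new - thetas * x_new) / noise_sigma) ** 2)
    density /= noise_sigma * np.sqrt(2 * np.pi)
    return float(np.sum(posterior * density))

print(predictive_density(y_new=0.85, x_new=0.5))   # high density near theta_true * x_new
```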
Alternative prediction practices: MAP, MLE, and the large data limit
I find this to be the most finicky and annoying part of discussions about Bayesian inference. It will be useful for comparing with other learning paradigms, so feel free to skip the body of this section and look instead at the “upshots” summary below.
Note that in the above expression, we had to take an integral over $\theta$ to get our prediction on $y$. An alternative approach would be to instead use what is called the maximal a-posteriori (MAP) prediction. This corresponds to looking at the single value $\theta_{\mathrm{MAP}}$ that maximizes the posterior $P(\theta \mid \mathbf{x}, \mathbf{y})$ and using this special value to get a “maximal likelihood” guess for the distribution on $y$ as $P(y \mid x, \theta_{\mathrm{MAP}})$. The name “maximal likelihood” here comes from the fact that, after dropping the prior on $\theta$ and the normalization, we are simply choosing the value of $\theta$ that maximizes the likelihood of seeing the data $\mathbf{x}, \mathbf{y}$. Note that $P(y \mid x, \theta_{\mathrm{MAP}})$ is not in general the correct probability distribution on $y$ given the observer’s knowledge (which we found in the previous section). For example, if we have only two options $\theta_1, \theta_2$ for $\theta$, and $P(\theta_1 \mid \mathbf{x}, \mathbf{y})$ is only slightly larger than $P(\theta_2 \mid \mathbf{x}, \mathbf{y})$, then the correct prediction on $y$ will factor in the possibility that the true latent is $\theta_2$.
The MAP prediction is (in most cases) strictly less useful than the true Bayesian inference prediction. However it is easier to work with for two reasons:
- We replace the (potentially computationally expensive) integral over $\theta$ by a single value $\theta_{\mathrm{MAP}}$.
- Since we only care about the maximizing parameter and not its exact probability, we can ignore the normalizing factor $Z(\mathbf{x}, \mathbf{y})$ (independent of $\theta$), also a potentially expensive integral.
This leaves us with a much easier formula for $\theta_{\mathrm{MAP}}$: namely, it is the latent $\theta$ that maximizes the value $P(\theta) \prod_{i=1}^n P(y_i \mid x_i, \theta)$. A frequent further simplification (that we will use shortly) is to take the log of all probability values and replace the product with a sum (since taking logs is monotonic, this doesn’t change the maximum).
While less accurate, the MAP prediction can be proven to asymptotically give the same value as the Bayesian inference prediction in the limit of many datapoints[4]. In many cases of interest, including the infinite-data limit that we discuss below, the update terms $P(y_i \mid x_i, \theta)$ eventually dominate the contribution of the prior $P(\theta)$, except for values of the latent outside the support of the prior, i.e., with $P(\theta) = 0$. In particular, we can drop the prior from the MAP formula for $\theta$. To this end, we define the maximal likelihood estimate (MLE) $\theta_{\mathrm{MLE}}$ to be the latent that maximizes the likelihood $\prod_{i=1}^n P(y_i \mid x_i, \theta)$ alone. While the difference between MLE and MAP estimates matters for differences between Bayesian and frequentist approaches, in realistic systems (including ML models), the difference between the MAP and MLE estimates tends to be quite small. There is once again a theorem that, assuming some analytic well-behavedness properties, the MLE estimate $P(y \mid x, \theta_{\mathrm{MLE}})$ for $P(y \mid x, \mathbf{x}, \mathbf{y})$ converges to the correct distribution in the large-$n$ limit.
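Continuing the toy scalar model from above, here is a quick sketch of how the MAP and MLE point estimates are computed and why they nearly coincide once there is a reasonable amount of data (again, the grid and prior are illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)
theta_true, noise_sigma = 1.7, 0.3
xs = rng.uniform(-1.0, 1.0, size=20)
ys = theta_true * xs + noise_sigma * rng.normal(size=20)

thetas = np.linspace(-3.0, 3.0, 601)
log_prior = -0.5 * thetas**2                 # log of an (unnormalized) Gaussian prior
log_lik = np.array([np.sum(-0.5 * ((ys - t * xs) / noise_sigma) ** 2) for t in thetas])

theta_mle = thetas[np.argmax(log_lik)]               # maximizes the likelihood alone
theta_map = thetas[np.argmax(log_prior + log_lik)]   # also weighs in the prior

# With 20 datapoints the summed likelihood terms already dwarf the prior,
# so the two point estimates nearly coincide (and sit near theta_true = 1.7).
print(theta_mle, theta_map)
```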
Misspecified models
In the previous section, we assumed that the “true” stochastic function $P_{\mathrm{true}}(y \mid x)$ selected by nature is equal to $P(y \mid x, \theta)$ for some value of the latent $\theta$ in our inference model. Sometimes this is a reasonable assumption, but in fact it is often possible to productively use Bayesian inference in contexts where $P_{\mathrm{true}}(y \mid x) \neq P(y \mid x, \theta)$ for any latent $\theta$. In this case, we say that the model is misspecified. Now in the misspecified case we can still take the infinite-data limit for both inference (guessing $\theta$) and prediction (guessing the probability distribution of $y$ given $x$), in both cases assuming that the data is drawn from $P_{\mathrm{true}}(y \mid x)$ but the model used for prediction is $P(y \mid x, \theta)$.
It turns out that in the infinite-data limit, it is still often (indeed, in a certain rigorous analytic sense generically) true that both inference and prediction converge to a single-latent distribution associated to a deterministic parameter $\theta^*$ in the misspecified model. Note that in this context, unlike before, it is important to commit to a specific probability distribution $P(x)$ on the “experimental parameters” and assume our large number of datapoints $(x_i, y_i)$ have $x_i$ drawn from this distribution. Now recall that the finite-data MLE is a value of $\theta$ maximizing the product likelihood $\prod_{i=1}^n P(y_i \mid x_i, \theta)$ of the samples, or equivalently the sum of their log likelihoods $\sum_{i=1}^n \log P(y_i \mid x_i, \theta)$. Replacing the sum by an average (which does not affect the argmax) and taking a stochastic limit as $n$ goes to infinity, we see that the sum over pairs $(x_i, y_i)$ turns into an integral over the probability measure $P(x)\, P_{\mathrm{true}}(y \mid x)$ (here we see the dependence on the “input prior” on $x$’s). We thus see that we can estimate the law-of-large-numbers prediction on the MLE value of $\theta$, which we call $\theta^*$: namely, this is the value of $\theta$ maximizing
$$\int \log P(y \mid x, \theta)\, P_{\mathrm{true}}(y \mid x)\, P(x)\, dy\, dx,$$
or in other words the value of $\theta$ that minimizes the cross-entropy between $P_{\mathrm{true}}(y \mid x)$ and $P(y \mid x, \theta)$ (averaged over the input distribution $P(x)$). Note that in this case, minimizing the cross-entropy is equivalent to minimizing the KL divergence. There is now a theorem that, under nice analytic conditions, and assuming the $\theta^*$ minimizing the cross-entropy is unique, the limiting probability distribution we get on $\theta$ converges to a deterministic delta distribution on $\theta^*$, and prediction converges to $P(y \mid x, \theta^*)$. If on the other hand the argmax value is not unique, the inference problem converges to a suitable probability distribution on the different cross-entropy-minimizing values of $\theta$.
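As a tiny illustration of this convergence (with all specifics chosen by me, not from the post): suppose nature's $y$ is uniform on $[0,1]$ while the model family is a fixed-variance Gaussian with mean $\theta$. No member of the family equals the truth, but the cross-entropy is minimized at $\theta^* = 0.5$ (the true mean), and the MLE indeed converges there:

```python
import numpy as np

rng = np.random.default_rng(0)

# Misspecified setup (illustrative): nature's y is Uniform(0, 1), independent of x,
# while the model family is N(theta, sigma^2) with fixed sigma. The cross-entropy
# between Uniform(0,1) and N(theta, sigma^2) is minimized at theta* = 0.5.

def theta_mle(n):
    """For a fixed-variance Gaussian family, the MLE of the mean is the sample average."""
    y = rng.uniform(0.0, 1.0, size=n)
    return y.mean()

for n in [10, 1_000, 100_000]:
    print(n, theta_mle(n))        # approaches 0.5 as n grows
```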
Upshots
For people who don't have the bandwidth to keep in mind a bunch of statistical terms (or just want a way to orient towards intuition and away from technicalities), a typical "cheat" in statistics is to replace the statistical futzing around of Bayesian inference to get the "optimal" posterior distribution on latents with a simpler deterministic "maximal likelihood" guess, which answers the question "which value of $\theta$ maximizes the likelihood of seeing the data $(\mathbf{x}, \mathbf{y})$ as the result of $n$ independent draws?". In the limit of many i.i.d. samples $(x_i, y_i)$, the maximal likelihood estimates will converge to a deterministic (delta) distribution on the true value of the parameter ("chosen by nature"). A variant of this is also true when the model is misspecified, i.e., when our statistical model is not expressive enough to capture the true distribution $P_{\mathrm{true}}(y \mid x)$: here the deterministic value of $\theta$ that inference gives in the limit maximizes a certain rescaled statistical limit of the log likelihood, namely the (negative) cross-entropy between the distribution $P_{\mathrm{true}}(y \mid x)$ given by nature and the distribution $P(y \mid x, \theta)$ given by the model.
Hardness of inference and prediction
All the techniques we've seen so far (inference, prediction, MAP, and MLE) are NP-hard in general. To see this, simply observe that the least complex of these (namely MLE) requires finding the maximum of a likelihood function on a $w$-dimensional parameter space, and for any function $f$ on that space it is easy to manufacture a model (of "comparable complexity" to $f$) whose maximal likelihood estimate is at least as hard to compute as the maximum of $f$. Now for any reasonably expressive class of functions $f$ (degree-$d$ polynomials or $d$-parameter neural nets), this maximization problem is NP-hard in general.
At the same time, in analyzing statistical models that are useful in real life, "good-enough" inference and prediction is achievable using machine learning and statistical physics methods. These are a bit harder for generative models (where for any given input x, the distribution on y may be an arbitrarily complicated probability distribution), but are easier in the context of more basic neural nets where, for a fixed x, the "true" stochastic function is modeled as either deterministic or as drawn from an easy-to-parametrize family of probability distributions (e.g. Gaussians). We consider this simpler context below, as it is sufficient for the field-theoretic picture we will look at in future posts.
Neural nets as statistical models
A neural net model has a lot in common with a statistical model, except that it is (at first) deterministic. Recall that a neural net is a parametrized family of functions: for each “weight value” $\theta$, the model gives a function $f_\theta$ taking in a vector $x$ and outputting a vector $f_\theta(x)$.
In order to train a neural net, one needs the “training data”, which is a collection of known input-output pairs $(x_1, y_1), \dots, (x_n, y_n)$ (similarly to the Bayesian experimental setup), and two more pieces of information: namely, a prior $P(\theta)$ on weights (to randomly choose an initial weight parameter $\theta_0$ from, as a starting point for gradient descent) and a loss function $L(y', y)$, which measures a continuous surrogate for the accuracy of the model’s classification $y' = f_\theta(x_i)$ relative to the observed (training data) value $y_i$. The learned value of $\theta$ (in the standard paradigm) is then the limit of applying gradient descent to $\theta_0$ under the gradient of the average loss $\frac{1}{n} \sum_{i=1}^n L(f_\theta(x_i), y_i)$.
For the sake of this short post, we’ll consider only the MSE loss, defined by $L(y', y) = \frac{|y' - y|^2}{2\sigma^2}$. A similar analysis can be made for any other standard loss function. Note that in this loss formula I am (slightly uncharacteristically) introducing a denominator $2\sigma^2$; here $\sigma$ is a free hyperparameter corresponding to a characteristic length scale of (roughly) “how badly off we allow the classification result to be from the experimental result before significant penalties”.
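To make the standard (non-Bayesian) training paradigm concrete, here is a minimal gradient-descent sketch on this averaged MSE loss for a tiny one-hidden-layer net (the architecture, data, and hyperparameters are all illustrative choices of mine):

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative training data: a 1d target with a little noise.
xs = rng.uniform(-2.0, 2.0, size=64)
ys = np.sin(xs) + 0.05 * rng.normal(size=64)
n = len(xs)

sigma = 0.3   # the length-scale hyperparameter in the MSE denominator

# Tiny 1-16-1 tanh network f_theta(x); theta = (W1, b1, W2, b2),
# initialized by drawing from a simple Gaussian "initialization prior".
H = 16
W1, b1 = rng.normal(size=H), np.zeros(H)
W2, b2 = rng.normal(size=H) / np.sqrt(H), 0.0

lr = 0.05
for step in range(3000):
    h = np.tanh(np.outer(xs, W1) + b1)        # hidden activations, shape (n, H)
    pred = h @ W2 + b2                        # network outputs, shape (n,)
    # Gradient of the average loss (1/n) sum_i |pred_i - y_i|^2 / (2 sigma^2):
    err = (pred - ys) / sigma**2
    gW2 = h.T @ err / n
    gb2 = err.mean()
    dh = np.outer(err, W2) * (1.0 - h**2)     # backprop through tanh
    gW1 = (dh * xs[:, None]).sum(axis=0) / n
    gb1 = dh.sum(axis=0) / n
    W1 -= lr * gW1; b1 -= lr * gb1; W2 -= lr * gW2; b2 -= lr * gb2

final_pred = np.tanh(np.outer(xs, W1) + b1) @ W2 + b2
print("final average loss:", np.mean((final_pred - ys) ** 2) / (2 * sigma**2))
```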
Now exactly the same parameters define a Bayesian model, with the special property that $P(y \mid x, \theta)$ is always a Gaussian with a fixed standard deviation $\sigma$ (and mean $f_\theta(x)$ depending on the latent variable $\theta$). Specifically, we have the following dictionary:
| Neural net | Bayesian model |
| --- | --- |
| Initialization prior on NN weights | Prior $P(\theta)$ on model parameters |
| Input-output training data $(x_i, y_i)$ | Experimental data |
| Deterministic classification function $f_\theta(x)$ | Mean of $P(y \mid x, \theta)$ |
The one bit of this correspondence that’s not completely straightforward is the relation between the choice of Gaussian as the probability distribution $P(y \mid x, \theta)$ and the MSE loss.[5] In this case we have the following simple result:
Proposition. We have an identity between the (negative) log likelihood of $\theta$ in the statistical model and the MSE loss of the neural net associated with weight parameter $\theta$: namely,
$$-\log \prod_{i=1}^n P(y_i \mid x_i, \theta) = n \cdot L_{\mathrm{MSE}}(\theta) + \mathrm{const},$$
where $L_{\mathrm{MSE}}(\theta) = \frac{1}{n} \sum_{i=1}^n L(f_\theta(x_i), y_i)$ is the average loss and the constant is independent of $\theta$. Here the rescaling factor of $n$ is because on the RHS the single-input losses are averaged (in order to maintain the same scale as more datapoints are added), and on the LHS they are not, since each new data point gives a comparable amount of new information. This factor of $n$ causes a perennial headache in fields that compare the two points of view, such as SLT, but it is of course irrelevant if one’s task is minimization.
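For completeness, here is the one-line check of this identity (a sketch under my assumption that the Gaussian model has density $P(y \mid x, \theta) \propto \exp\left(-\frac{|y - f_\theta(x)|^2}{2\sigma^2}\right)$ and that $L$ is the MSE loss with the $2\sigma^2$ denominator from above):
$$-\log \prod_{i=1}^n P(y_i \mid x_i, \theta) = \sum_{i=1}^n \left( \frac{|y_i - f_\theta(x_i)|^2}{2\sigma^2} + C_\sigma \right) = n \cdot \frac{1}{n}\sum_{i=1}^n L(f_\theta(x_i), y_i) + n\, C_\sigma,$$
where $C_\sigma$ (the log of the Gaussian normalization constant) does not depend on $\theta$.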
In particular, this implies that in the large-data limit (and assuming the process generating the real data can be captured in this model), the Bayesian prediction of $y$ given $x$ will converge to the prediction $f_{\theta^*}(x)$ for the parameter $\theta^*$ that (globally) minimizes MSE loss. The assumption that the true distribution can be captured is actually a significant one here: it is in fact a pretty strong assumption that the “true” probability distribution $P_{\mathrm{true}}(y \mid x)$ on $y$ is a Gaussian with standard deviation $\sigma$. In fact, the standard use case of this learning algorithm is when we assume that the true model is deterministic, i.e. given by some function $y = f_{\mathrm{true}}(x)$. In this case, no matter how expressive the neural net is, there is no value of $\theta$ for which any of the (posterior distribution, MLE distribution, MAP distribution) on $y$ is (even approximately) correct, as the true distribution has variance 0 whereas any of these distributions, being either a Gaussian or a convolution with a Gaussian of variance $\sigma^2$, must have variance $\geq \sigma^2$. In other words, this Gaussian inference problem is misspecified. At the same time, it is equally obvious that in this setting, if there exists a $\theta$ such that $f_\theta = f_{\mathrm{true}}$, then this latent value is the MLE for the true (deterministic) model[6]. It is reasonable to ask in this context why one should even bother with Gaussians (and not use a model with deterministic probability densities, or arbitrary probability functions with a unique maximum). In fact, Gaussians (or other distributions with a smooth and convex negative logarithm) appear in this context because of the nature of learning: we want the negative log likelihood $-\log P(y \mid x, \theta)$, whose sum gets minimized in the MLE, to be nice, smooth, and convex (as a function of the prediction $f_\theta(x)$) in order to facilitate learning, and a quadratic function is a particularly natural choice[7].
Underparameterized and overparameterized networks
When the number of degrees of freedom in the total dataset $(\mathbf{x}, \mathbf{y})$ is higher than $w$, the number of parameters of our latent $\theta$, we say that our model (whether for a neural net or a probability distribution) is underparameterized. Conversely, when $w$ exceeds the number of degrees of freedom in the data, we say that the model is overparameterized.
We’ve seen that in the infinite-data limit, any Bayesian model converges to its MLE, or to a cross-entropy-minimizing limit in the misspecified case. In particular, in the case of interest to us, i.e., when approximating a deterministic function $f_{\mathrm{true}}$ by a Gaussian neural net model, the limiting distribution on $\theta$ will in general be deterministic, and concentrated on the value $\theta^*$ that minimizes the integral (over inputs $x$) of the MSE loss.
In the finite-data context, it turns out that we can still make a related statement: namely, if we fix the number of datapoints $n$ but take the limit as the variance $\sigma^2$ in our Gaussian model goes to 0, the Bayesian inference problem once again converges to a distribution on the likelihood-maximizing values of $\theta$.
Now if our model is underparameterized, then we generically expect the likelihood-maximizing value $\theta_{\mathrm{MLE}}$ to be unique. The prediction then converges to the deterministic (variance-zero limit) function $f_{\theta_{\mathrm{MLE}}}(x)$. However, in the overparameterized context, the posterior distribution on $\theta$ is generically nondeterministic even in the $\sigma \to 0$ limit: indeed, the posterior is going to be some more general probability distribution on the “perfect fit” subset $\{\theta : f_\theta(x_i) = y_i \text{ for all } i\}$. Similarly, the prediction problem will have a nondeterministic distribution (for fixed $x$) supported on the values $f_\theta(x)$ for $\theta$ varying over all latents with perfect fit.
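A minimal numerical illustration of this (an overparameterized Bayesian linear-in-features model with a closed-form Gaussian posterior; all specifics are illustrative assumptions of mine, not an infinite-width net):

```python
import numpy as np

rng = np.random.default_rng(0)

# Overparameterized toy model: f_theta(x) = theta . phi(x) with w = 3 parameters
# but only n = 2 datapoints, a standard Gaussian prior on theta, and small noise sigma.
def phi(x):
    return np.array([1.0, x, x**2])

X = np.stack([phi(0.0), phi(1.0)])            # design matrix, shape (2, 3)
y = np.array([0.0, 1.0])
sigma, prior_var = 0.05, 1.0

# Closed-form Gaussian posterior for Bayesian linear-in-features regression:
# covariance = (X^T X / sigma^2 + I / prior_var)^{-1}, mean = covariance @ X^T y / sigma^2.
cov = np.linalg.inv(X.T @ X / sigma**2 + np.eye(3) / prior_var)
mean = cov @ X.T @ y / sigma**2

print("train fits:", X @ mean)                # both essentially pin down the data (0 and 1)

# The posterior stays wide along the parameter direction that leaves the two training
# predictions unchanged, so predictions at a new input remain genuinely spread out.
x_new = 2.0
samples = rng.multivariate_normal(mean, cov, size=5)
print("predictions at x_new:", [round(float(s @ phi(x_new)), 2) for s in samples])
```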
In the next one or two posts in this series (joint with Lauren Greenspan) on QFT approaches to learning, we are going to assume exactly this setting. We will treat the problem of learning a deterministic function in a certain infinite-width limit. In other words, we will be assuming that the data is finite but the number of degrees of freedom of $\theta$ is asymptotically large. In this case, despite each specific probability distribution $P(y \mid x, \theta)$ being (close to) deterministic, the inference distribution and the prediction distribution are very nondeterministic.
Stay tuned for next time
Next time, we'll use the ideas from my previous post on "laws of large numbers" to look in a Bayesian (aka "thermostatic") frame at a very general class of learning problems for asymptotically large-width neural nets. Here the beautiful core ideas from PDLT will show up. Namely, when we suitably tune scale hyperparameters, the only values that matter will turn out to be a certain finite set of cumulants determined by the activation function and the geometry of the input set. In this context we'll see that properly tuned deep neural nets prefer certain "stable" values of cumulants, and explain the relationship between these stable settings and the physical notion of renormalization fixed points. I'm excited about this (unusual from a physics point of view) way of introducing the concept, as renormalization is a core idea from physics that in my opinion is undervalued as an intuition pump in machine learning (likely at least in part because of the amount of physics context needed to see a first example in other contexts).
- ^
In fact, in many contexts of interest, a Bayesian model assumes that the same experiment is run in all instances, i.e., the “input” space of experimental variables is a single point. This case contains almost all the complexity of the input-output model; we include inputs because these better parallel the behavior of (classification) neural nets. Another common context is one where, instead of being chosen by the experimenter, the input $x$ is also sampled “by nature” from a secret distribution (depending on the latent information $\theta$). In this case, one can actually model the joint distribution on $x$ and $y$ – i.e., treat the pair $(x, y)$ as a single output in a higher-dimensional space, with no input from the observer on the choice of experiment. While the “model of reality” in this case doesn’t separate inputs and outputs, it is often the case that the “information of interest” one extracts from the Bayesian model here does depend on this division, with the various conditional probabilities $P(y \mid x)$ being more important than the joint probabilities $P(x, y)$.
- ^
There’s no a priori distribution on inputs $x$ (since these are chosen by the observer), but one can promote $x$ to a random variable by choosing any reasonable prior $P(x)$, and it will cancel out of any computation.
- ^
As an exercise, check that the prior on $x$ appears in both the numerator and the denominator of the resulting expression (and hence cancels).
- ^
This is a bit tricky to operationalize, and is called “convergence in distribution”, or “in law”: specifically, we view the samples $(\mathbf{x}, \mathbf{y})$ as random variables (with the experimental parameters chosen from some fixed everywhere-supported distribution). Then the statement is that, under some analytical niceness conditions, the prediction probability distribution $P(y \mid x, \mathbf{x}, \mathbf{y})$ is very close to the true probability distribution $P(y \mid x, \theta_{\mathrm{true}})$ with probability approaching 1 over the choice of $\mathbf{x}, \mathbf{y}$. Here, as usual in statistics, the notion of “very close” requires some analytical finagling that we won’t discuss.
- ^
Other loss functions lead to other choices of “softening” the deterministic function to a probability distribution. There is even nothing wrong a priori with defining a deterministic “probability distribution” with no randomness once $(x, \theta)$ are fixed – however, this can lead to weird singular behaviors of the posterior as a function of the $(x_i, y_i)$ pairs, and it corresponds to a singular loss function that is impossible to learn.
- ^
Indeed, this follows from the fact that a Gaussian distribution has maximal probability at the mean.
- ^
There are cases where it is important to get not just a good single-value estimate for $y$, but actually a complete probability distribution. Here if the set of possible values of $y$ is a small discrete set of “labels”, then using a logit classifier with cross-entropy loss allows one to express any probability distribution on labels, assuming sufficient expressivity of the neural net. In more complicated pictures, such as generative models, different methods, such as diffusion models, are used to specify and compute with flexible classes of probability distributions.