Logits, log-odds, and loss for parallel circuits
post by Dmitry Vaintrob (dmitry-vaintrob) · 2025-01-20 · LW
Today I’m going to discuss how to think about logits like a statistician, and what this implies about circuits. This post doesn’t have any prerequisites other than perhaps a very basic statistical background that can be adequately recovered from the AI-generated “glossary” to the right. I think the material here is a good thing to know in general (thinking through this helped clarify my thinking about a lot of things), and it will be useful background for a future post I’m planning on “SLT in a nutshell”. If you want a “TL;DR” takeaway of the discussion that follows, the gist is that neural networks use logit addition to integrate (roughly) independent “parallel” information from various sources; and that even a very basic model of a neural net that performs this aggregation from a few parallel “black box” circuits is already a very informative conceptual toy model of stuff that neural networks do, akin to the “ideal gas” model in physics.
Basics of logits and logistic tasks
Most tasks solved by modern LLMs are some flavor of a logistic classification task. The origin of the “logistic” idea comes from statistics, and the elegant statistical context for using it (rather than other possible choices of loss) often gets lost on newcomers who were not statistics-adjacent in their past life (this was certainly true for me!). To remedy this, I’ll briefly explain, without proofs, a rough picture of why logistic loss is so nice, and in particular how it works for parallel classification programs. From now on, I am going to work only with boolean classification tasks, as other logistic tasks behave in the same way, but with more complicated notation. As small print, I’m also going to assume we are in a “large-data” limit (i.e., the number of training datapoints is very large compared to other relevant measures).
Assume we have a binary classification task $y(x) \in \{0, 1\}$, for x in some big distribution of inputs. Logistic classifiers (i.e., “classifiers that use logits”) try to classify y by a “logistic” prediction $u_w(x)$, that depends on a weight parameter $w$.
I haven’t yet told you the formula for the loss that actually gets learned, but the goal of the logistic task – i.e. the case where it obtains optimal loss – is to optimally approximate the “log odds” function $u(y(x))$ of the true classification. As a formula, the log-odds function is defined for any boolean random variable b (i.e., b is a probability distribution on $\{0, 1\}$) as $$u(b) = \log \frac{P(b = 1)}{P(b = 0)}.$$ In words, if b is a biased coin, the log odds is literally the log of the odds ratio, which is the probability of heads over the probability of tails. Note that the probability distribution of a biased coin is uniquely determined by its log odds (since the probability distribution of a coin is determined by its “heads” probability, it’s not surprising that both are functions of one parameter). Let’s write $\mathrm{coin}(u)$ for the boolean variable whose log odds are u (in formulas, the probability of heads in this case is $\sigma(u) = \frac{1}{1 + e^{-u}}$. Man, probability notation makes things look more complicated than they should.)
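To keep the notation grounded, here is a minimal sketch in Python of the probability ↔ log-odds dictionary (the function names are mine, chosen for illustration):

```python
import math

def log_odds(p: float) -> float:
    """Log odds of a biased coin with heads-probability p (for 0 < p < 1)."""
    return math.log(p / (1 - p))

def prob_from_log_odds(u: float) -> float:
    """Heads-probability of coin(u): the sigmoid of the log odds u."""
    return 1 / (1 + math.exp(-u))

# A 2:1 coin has log odds log(2); round-tripping recovers p = 2/3.
p = 2 / 3
u = log_odds(p)                      # ~0.693
assert abs(prob_from_log_odds(u) - p) < 1e-12
```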
In the classification task, we are modeling $y(x)$ as a separate random variable for each x. In other words, we have a separate “coin” y(x) for each input x, and the log odds $u(x)$ is a (deterministic) function of the input x that the classifier wants to learn. Note, in particular, that we are modeling the “ground truth” y(x) as probabilistic. But in practice, many classification tasks are actually deterministic. Why is this ok?
Well, it turns out that for most LLMs, the question of whether y(x) is deterministic or probabilistic is entirely moot. Namely, a general deterministic task y(x) (we often write $f(x)$ in deterministic cases to denote the “correct answer”) depends on a bunch of “features” $f_1(x), f_2(x), \dots$. We don’t need a formal notion of what the $f_i$ are – they could be discrete or continuous, one-dimensional or high-dimensional; the important thing is that they are functions of the input x[1], and most classification tasks need a lot of them to get perfect loss.
The model really wants to learn all of these parameters and do a good job of its classification task – it really does. But the world is big and complicated and the model is just a little guy. So the best it can do in general is to learn a small number of parameters, maybe $f_1, f_2,$ and $f_3$. Now after it’s tried its best, it does its classification task in terms of the features it’s learned, and tries to predict y based on $f_1, f_2, f_3$ – in other words it wants to learn the boolean function $y(f_1, f_2, f_3)$. But the point is that this is not a deterministic function! In fact, y also depends on an unknown probability distribution on the marginal (from the model’s point of view) parameters $f_4, f_5, \dots$. Thus what started out as a deterministic task ends up a stochastic one, and we can again think of the model’s task as correctly predicting the log odds of the boolean random variable y; except now we can effectively conceptualize it as a function not of x but of the learned variables $f_1, f_2, f_3$. I.e. the best it can do having learned three features is to learn a new function: the conditional log odds $u(f_1, f_2, f_3)$ of y given the three learned features. The optimum of the actual function $u_w(x)$ learned by the model under the assumption that only the first three features are “realistically learnable” is now $u_w(x) = u(f_1(x), f_2(x), f_3(x))$ (recall that features are deterministic functions of the input in this picture).
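To see the “marginalizing over unlearned features” step in action, here is a small made-up example (the particular deterministic rule, the feature distribution, and all names are hypothetical, purely for illustration): the label is a deterministic function of five boolean features, but a model that only sees the first three can at best output the conditional log odds of the label given those three.

```python
import itertools, math
from collections import defaultdict

# A made-up deterministic ground truth depending on 5 boolean features.
def true_label(f):
    f1, f2, f3, f4, f5 = f
    return int((f1 and f2) or (f3 and f4) or f5)

# Assume (for illustration) each feature is an independent fair coin.
inputs = list(itertools.product([0, 1], repeat=5))

# Best possible prediction given only (f1, f2, f3): the conditional
# probability of label 1, obtained by marginalizing over (f4, f5).
counts = defaultdict(lambda: [0, 0])   # (f1,f2,f3) -> [n_total, n_label_1]
for f in inputs:
    key = f[:3]
    counts[key][0] += 1
    counts[key][1] += true_label(f)

for key, (n, n1) in sorted(counts.items()):
    p = n1 / n
    u = "+inf" if p == 1 else ("-inf" if p == 0 else f"{math.log(p / (1 - p)):+.2f}")
    print(key, f"P(y=1 | f1,f2,f3) = {p:.2f}", "log odds =", u)
```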
Now what happens for real models is more complicated. Features don’t neatly group into “learnable” and “unlearnable” ones, and even if we force the model to be a function of some set of “easily learnable” features, it’s not necessarily the case that the model will learn (or indeed has the expressivity to learn) the precise function as above. Nevertheless, this picture of “optimization given latent features” is quite a powerful one and can serve as a strong intuition pump for behaviors that actually occur.
Note that while we have discussed the optimum, I still haven’t introduced the loss. Let’s remedy this. Namely, if we have two binary probability distributions b, b’, their cross-entropy H(b, b’) is… well, some formula. Look it up. The important thing is that
- It is asymmetric in b and b’
- It is the expectation over b of some quantity depending on b’ (namely, the negative log probability under b’. See – we’ve basically defined it!)
- For fixed b, it is minimized when b’ = b.
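For reference (this is just the standard formula, written with $p$ the heads-probability of $b$ and $q$ that of $b'$): $$H(b, b') = -\,p \log q \;-\; (1 - p)\log(1 - q),$$ which is visibly asymmetric, is the expectation over $b$ of the negative log probability assigned by $b'$, and for fixed $p$ is minimized at $q = p$.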
Thus we define the cross-entropy loss as $$L(w) := \mathbb{E}_x \, H\big(y(x), \hat{y}_w(x)\big),$$ where we have defined the “model’s guess” (at weight $w$) $$\hat{y}_w(x) := \mathrm{coin}\big(u_w(x)\big).$$
Here remember $\mathrm{coin}(u)$ is the binary random variable with log-odds equal to u. You might be worried that I’m defining it assuming knowledge of the “true distribution” $y(x)$. But observe: the only place I’m using this “true distribution” is in the cross-entropy expression, with $y(x)$ on the left side of a cross-entropy expression. This means that the loss is the expectation over the true probability distribution of something that only depends on $\hat{y}_w(x)$, and it can be approximated for values of $w$ with finite data by, well, just sampling at the finitely many known datapoints! This of course, if you unpack it, gives the familiar formula for the “finite-data” loss as an average of cross-entropy over samples.
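Unpacking that last step into code (a minimal sketch – the helper name and the toy numbers are illustrative, not from the post): the finite-data loss is just the sample average of per-datapoint cross-entropies, computed directly from the model’s logits.

```python
import math

def logistic_loss(logits, labels):
    """Average cross-entropy between labels y in {0, 1} and coin(u_w(x)).

    Per sample this is -[y*log(sigma(u)) + (1-y)*log(1 - sigma(u))],
    written below in the numerically stable "softplus" form.
    """
    total = 0.0
    for u, y in zip(logits, labels):
        z = -u if y == 1 else u          # per-sample cost is log(1 + e^z)
        total += math.log1p(math.exp(z)) if z < 30 else z
    return total / len(logits)

# Confident, correct logits give tiny loss; confident, wrong ones are costly.
print(logistic_loss([4.0, -4.0], [1, 0]))   # ~0.018
print(logistic_loss([4.0, -4.0], [0, 1]))   # ~4.018
```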
“Parallel prediction” circuits
Log odds and independent predictions
There are so many measurements in probability theory that it’s hard to keep track of them. But log odds has a particularly nice property related to prediction, which characterizes it uniquely (up to scale).
To talk about it, let’s take a step back and discuss a couple of superforecasters. Forecasters are human-shaped classification models that output a probability distribution on an event occurring (i.e., a boolean random variable), based on information about the state of the world. Now usually, two forecasters are better than one. However, there’s a catch. If the two forecasters are exactly identical, then they will output identical predictions, and there is no advantage in paying two salaries. At the other extreme, if one of the forecasters is just way, way better calibrated than the second one, then there’s no point listening to what the second one has to say. But the place where two forecasters really shine is when they are
- perfectly well-calibrated[2], and
- maximally independent.
In terms of probability theory, the notion of two predictors being “maximally independent” is equivalent to demanding that their predictions are conditionally independent when conditioned on any real event. One can also phrase this property in terms of information theory (though we will generally not use this language), where it is equivalent to saying: “Alice’s knowledge and Bob’s knowledge, measured as the mutual information each of them has with the state of the world, do not overlap”. (Note that they can be optimally calibrated given incomplete knowledge of the world! Being calibrated means having a “good estimate for the degree of your ignorance” and, unlike “being accurate”, being calibrated is not a big ask for realistic systems.)
Now the property of log odds is that given two perfectly calibrated and maximally uncorrelated forecasters Alice and Bob, the best prediction to make about the probability of an event E is the one with odds $$\mathrm{odds}(E) = \mathrm{odds}_{\mathrm{Alice}}(E) \cdot \mathrm{odds}_{\mathrm{Bob}}(E).$$ In other words, odds ratios of independent forecasters multiply. Taking a short detour, note that this gives a cute quantitative characterization of the wisdom of crowds. If we have a crowd of n independent forecasters who all have independent information, and if each predicts probability $2/3$ (so 2:1 odds) of an event A occurring, then the best aggregate prediction is that event A will occur with probability $\frac{2^n}{2^n + 1}$ (i.e., $2^n : 1$ odds).
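A quick numerical check of the crowd claim (a sketch; `aggregate` is a made-up helper that just implements “sum the log odds”):

```python
import math

def aggregate(probs):
    """Combine calibrated, conditionally independent forecasts by summing log odds."""
    total_log_odds = sum(math.log(p / (1 - p)) for p in probs)
    return 1 / (1 + math.exp(-total_log_odds))

# n forecasters, each giving 2:1 odds (probability 2/3) on event A.
for n in [1, 2, 5, 10]:
    p = aggregate([2 / 3] * n)
    print(n, round(p, 6), "closed form:", round(2**n / (2**n + 1), 6))
```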
We had to put a lot of caveats on Alice and Bob here. But the fact is, when interpreting neural nets, it’s often quite reasonable in practice to think of what a model does as spinning up a bunch of parallel superforecasters of this type!
Independence and circuits
Let’s take a step back and think about our hard-working ML model. Remember that it’s been chugging along, trying its best to fit the “true” probability distribution y(x), or rather its log odds $u(x)$ (by a function $u_w(x)$ depending on a weight parameter $w$). We discussed that, whether the true classification y(x) is deterministic or probabilistic, we can think of it as depending on a collection of features $f_1, f_2, \dots$, which are (deterministic or probabilistic) functions of x. Realistically, our model can’t learn all the features, so it will do the best it can to learn to capture a few features, say $f_1, f_2, f_3$, and then predict from these features alone. Now in general, the features our model learns can have all kinds of correlations. But – and this is key – it’s often the case that the model treats them as independent! In this context, the model treats the different calculations associated with processing the latents as independent and parallel circuits. In other words:
- The model independently computes a collection of separate functions $u_1(x), u_2(x), u_3(x)$, associated to its “best guess” prediction of f(x) given only one of the features, processing the different features by a collection of parallel and independent circuits.
- It adds the logits associated to these features together (i.e., adds the log odds, i.e., multiplies odds, i.e. aggregates predictions).
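Here is a minimal sketch of this two-step picture (everything below – the “detector” circuits, their logit values, and the toy inputs – is made up purely for illustration, not taken from any trained network):

```python
import math

# Step 1: independent "circuits", each turning one feature of the input into
# a per-feature logit (its best-guess log odds for y given that feature alone).
def circuit_nose(x):
    return 2.0 if "nose" in x else -0.5

def circuit_ear(x):
    return 1.5 if "ear" in x else -0.5

def circuit_whiskers(x):
    return 1.0 if "whiskers" in x else -0.5

# Step 2: aggregate by adding logits (= multiplying odds).
def model_logit(x):
    return circuit_nose(x) + circuit_ear(x) + circuit_whiskers(x)

def model_prob(x):
    return 1 / (1 + math.exp(-model_logit(x)))

print(model_prob({"nose", "ear", "whiskers"}))   # ~0.99: all three circuits agree
print(model_prob({"ear"}))                       # ~0.62: weak, partly conflicting evidence
```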
This is not a theorem or even a “soft law” of any sort – it breaks immediately (necessitating some more sophisticated causal analysis) as soon as either the model or how we conceptualize features becomes at all complicated. But it’s directionally true that at least some circuits combine in this way in many known examples:
- From patching analysis, it seems that simple convolutional vision networks process certain prediction data like “what is this animal conditioned on seeing its nose” and “what is this animal conditioned on seeing its ear” in this parallel and independent way.
- In our MATS research with Nina Panickssery [LW · GW] (my first AI project), we observed such decompositions in MNIST.
- Neel Nanda’s analysis of modular addition (since refined and reinterpreted in a number of ways – this is the one interpretability problem whose “inherent behavior” we are most confident of) observes that neural nets decompose modular addition into parallel independent circuits associated with Fourier modes. In fact, one can also do a theoretical analysis with a “random model” picture of modular addition to show that, in an appropriate operationalization, the different Fourier modes associated with modular addition actually should be viewed as giving independent information (and the “random model” can be shown using some concentration bounds to provably approximate reality with some bounded error – this is perhaps material for a future post).
So the upshot is, we can model some decompositions of some neural nets as aggregating independent predictions by summing logits (i.e., summing the different “single-circuit” functions $u_i$). In fact, if we squint enough, it seems likely from a number of indirect sources of evidence that every nontrivial neural net in the world has at least an aspect of this: i.e., some layers can be conceptualized as approximate sums of quantities computed from previous layers (with the understanding that a bunch of other non-parallel behavior is going on as well, and the things we’re calling “circuits” that get added here can be massive combinations of other substructures).
Appreciating the wisdom of the elders
Notice also that this explains the intuition behind using logistic loss (rather than other kinds of loss) for neural nets. Neural nets love to add stuff. I mean that’s most of what they are: giant linear gadgets with a bit of extra structure thrown in. And so if there is any value in understanding the world, at least partially, via a big collection of conditionally independent processes, then it would be great if combining such processes happened to be linear. This was (as far as I understand) the actual reason for using logits for early neural nets, since the people designing them actually knew statistics. Now most places I read about it say “logits are a nice way of encoding probabilities that people decided to use early on for archaic reasons, and it seems to work better than other methods” – now you know why!
Interpretability insights
In the above I’ve distilled one way that (probably) circuits combine in neural nets. What does this tell us?
As a first aside, note that this discussion is entirely parallel to the discussion I had in my first technical post of the month on grammars [LW · GW], about how “rules” combine conjunctively to form grammars, and what the analogues of this are for probabilistic grammars and logits – the story is entirely analogous, though the language is different.
But why do I care about this so much in more general classification problems? What insight can we get from this very basic way of combining information in neural nets?
Well, science works by trying to find interesting behaviors in minimally interesting models, and it’s a big bonus if the minimally interesting model actually corresponds to an approximation, or a part, of stuff we see in real life. The thermodynamics we use to understand complicated interactions in superconductors shares a surprising number of important features in common with ideal gases; humans are surprisingly similar to drosophilae in many relevant ways. And in my opinion a surprising amount of intuition about neural net behavior, which seems to recur from the humblest MNIST to the wisest Llama, can be gained by thoughtfully analysing things going on in parallel circuits of a classification task. Let’s put together some upshots here.
- Interesting spectra of energy scales. As I have been consistently harping on about [LW · GW], different NN solutions for solving the same classification problem have different characteristic loss scales. The loss scale that most people are aware of, which also (more or less – I’ll explain in a later post) corresponds to the characteristic “Watanabe temperature” scale that has been used in SLT to date, is associated to memorization: namely, if a neural network has a sophisticated general mechanism that takes up only a fraction of its internal parameters, then it can use the rest of its parameters to memorize extra data points, at a cost of one parameter per datapoint (equal to 1/n accuracy improvement, where n is the number of samples). Relatedly, some overparameterized neural nets will just choose to memorize their data from the start. However, if we take the number of samples to be very large (or alternatively, compare the test loss and test accuracy of non-memorizing NNs), then we see e.g. in modular addition that different algorithms attain exponentially different (very small) loss optima, depending on how many parallel circuits get learned[3]. Having in hand the picture of “parallel circuits”, we can see this in action. Namely, suppose that two different neural nets go forth and learn what they can about the world, and we see that the first neural net learned a two-circuit logit function $u_1(x) + u_2(x)$ before converging to a local loss minimum, whereas the second learned a three-circuit logit function $u_1(x) + u_2(x) + u_3(x)$. As before we conceptualize the functions $u_1, u_2, u_3$ as parallel circuits, depending on a triple of features of the data: $f_1, f_2, f_3$ respectively. Suppose moreover that (as above) the features are conditionally independent, and that the classifiers are pretty good and pretty close to deterministic, with odds around 100:1 (i.e. around 99% accuracy) for each circuit. Then[4], we can assume that the loss, which roughly tracks the probability of error, is roughly inversely proportional to the combined odds. This gives us a loss estimate on the order of 1/10,000 for the first neural network f, and an exponentially better loss estimate on the order of 1/1,000,000 for the second network f’ (see the numerical sketch after this list). By varying the number of parallel circuits[5], we can thus get widely varying loss regimes. These are associated to interesting spectral properties of the tempered Boltzmann distribution, already in this very simplistic (essentially “ideal gas”-style) regime.
- Regularization. Relatedly, this phenomenon helps explain why regularization (and the related forms of “implicit regularization” of neural nets, as e.g. explained in the omnigrok paper) improves generalization behavior. Namely, note that the discussion we’ve had in the previous sections was in the infinite-data limit. In practice, neural nets learn from finite data. Because of the exponential accuracy scaling behavior I explained above, it is quite likely that after learning only a small number of parallel circuits (and maybe additionally “memorizing” any leftover examples), the neural net will obtain 100% accuracy on the training set. Note that 100% accuracy doesn’t imply 0 loss, and it may still be beneficial for the NN to learn additional circuits (again, “omnigrok” sees this occur). But another thing the neural net can start doing once it has 100% accuracy is to just bloat its logits. Namely, if a neural net has 100% accuracy, then it can always slightly improve its loss by just scaling up every logit by the same amount. Since (remember!) neural nets love linearity, it costs almost nothing for our model to do this, and indeed it’s a behavior that occurs (see omnigrok and also this paper). As logits get bigger, loss goes down exponentially, and eventually it becomes impossible to learn any other circuits, both because SGD updates become too small to learn anything, and also because the associated exponential decay will tend to affect complex generalizing circuits even more than other “less general” directions in the loss landscape. In this context, we see that regularization prevents this from happening – and this happens not only on the level of “regularization discouraging memorizing” but also on the more interesting level of “regularization discouraging sitting on your laurels”: assuming you (as the neural net) have accidentally learned enough circuits to correctly classify all your training data, and now just want to grab a bag of chips, sit in your bubble bath, and bloat your logits – if you’re prevented from doing this, you’re more likely to end up learning more and more parallel circuits, and getting better and better out-of-distribution loss.
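Here is the promised numerical sketch of both bullets above (pure illustration, under the simplifying assumptions stated there: conditionally independent circuits at 100:1 odds each, and loss read off as the cross-entropy of a correct prediction at a given logit margin):

```python
import math

def loss_at_margin(u):
    """Cross-entropy loss of a correctly classified point with logit margin u."""
    return math.log1p(math.exp(-u))

# Energy scales: k independent circuits, each contributing log(100) to the
# margin (100:1 odds, ~99% accuracy per circuit), give loss ~ 1/100^k.
for k in [2, 3]:
    print(k, "circuits -> loss ~", loss_at_margin(k * math.log(100)))  # ~1e-4, ~1e-6

# Logit bloating: once every training point has a positive margin (100%
# accuracy), scaling all logits by c > 1 strictly lowers the loss without
# learning anything new.
margins = [1.0, 2.0, 0.5]                 # all correct, but not confident
for c in [1, 2, 4, 8]:
    avg = sum(loss_at_margin(c * m) for m in margins) / len(margins)
    print("logit scale", c, "-> loss", round(avg, 5))
```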
[1] In the “deterministic classification” case they are deterministic functions, but in general of course they can be probabilistic functions.
[2] More generally, “about equally well-calibrated”, but we won’t look at this.
[3] In modular addition, there is a bit of a wrinkle: as the training dataset is upper-bounded, the infinite-data limit is a bit tricky to conceptualize, and the behavior observed needs to be explained in a slightly more sophisticated way related to regularization below.
[4] Up to some log corrections in the cross-entropy loss expression, which we can safely assume are O(1) – this log term is the same one that shows up in the Watanabe formula, for the SLT readers.
[5] Something that we know happens in real life, e.g. from modular addition – note that when discussing very low loss ranges there is some optimizer discussion to be had, that I’m sweeping under the rug here.