Dmitry's Koan

post by Dmitry Vaintrob (dmitry-vaintrob) · 2025-01-10T04:27:30.346Z

Contents

  The koan
  Elucidating the spectrum of precision
    Step 1: coming to terms with imprecision
    Step 2: Factoring in the memorization-generalization spectrum
  Natural scale and natural degradation
    Sometimes reconstruction loss is not the point
    Degradation as a dial
    Natural scale
    Natural degradation
    Experiment suggestions
    Possible issues

In this post I'll discuss questions about notions of "precision scale" in interpretability: how I think they're often neglected by researchers, and what I think is a good general way of operationalizing them and tracking them in experiments. Along the way I introduce a couple of new notions that have been useful in my thinking and that I think may be useful tools to keep in an interpretability toolkit, both for theorists and experimentalists: these are the notions of "natural scale" and "natural degradation".

The koan

I can be a nightmare conference attendee: I tend to ask nitpicky questions and apply a dose of skepticism to a speaker's claims that is healthy when doing one's own research, but probably not optimal when everyone else is trying to follow a talk. I'm working on being better at this, but for now I blame my background.

There is one nitpick that comes up again and again. In fact in one conference I brought it up so often that Jake Mendel coined a term for it: "Dmitry's koan".

In koan form, the nitpick is as follows:

There is no such thing as interpreting a neural network. There is only interpreting a neural network at a given scale of precision.

On its face, this observation is true but a bit banal. Indeed there are two extremes:

  1. At the "less precise" extreme, you can claim you have interpreted a language model (such as a 70B-parameter Llama model) by noticing that just tracking n-gram information for n up to 3 or 4 (and perhaps finding evidence that neural networks develop circuitry for such n-grams) lets you explain almost all the loss of this 70B-parameter model: i.e., the difference between the baseline cross-entropy loss (a transformer at initialization, with random weights) and the loss of the state-of-the-art Llama model is almost entirely "explained" by n-grams. The n-gram model is an extremely naive classification scheme that can be hard-coded without any learning on a personal computer. Does this mean that Llama is fully explained by the n-gram model?
  2. At the other extreme: if you want to interpret exactly what a neural network does, it's not enough even to understand the detailed mathematical abstractions encoded in the model's weights and their connections to the data, since even within an optimal mathematical interpretation, the neural network has approximation errors and noise. Does this imply that a sufficiently demanding interpretation must explain every bit of noise accumulated over training?

Of course these two extremes are silly (for people unfamiliar with LLMs: the n-gram model at the "less precise" endpoint recovers the majority of the cross-entropy loss, but because of how cross-entropy loss is defined, the subjective "quality" of a model is better measured on something like a logarithmic scale: in particular, the n-gram model will get worse loss than GPT-1 or even a much smaller transformer).
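To make the log-scale point concrete, here is a small sketch with purely hypothetical loss values (not measurements of any real model). It assumes cross-entropy is measured in nats, and it uses the uniform-prediction loss ln(V) over a hypothetical vocabulary of 32,000 tokens as a rough stand-in for the baseline loss of a randomly initialized transformer. The point is that a candidate can close most of the baseline-to-model loss gap while its perplexity, exp(loss), remains noticeably worse.

```python
import math

def fraction_loss_recovered(l_baseline: float, l_candidate: float, l_reference: float) -> float:
    """Share of the baseline-to-reference loss gap that the candidate closes."""
    return (l_baseline - l_candidate) / (l_baseline - l_reference)

def perplexity_ratio(l_candidate: float, l_reference: float) -> float:
    """Perplexity is exp(cross-entropy in nats), so loss differences become ratios."""
    return math.exp(l_candidate - l_reference)

# Hypothetical numbers, for illustration only.
vocab_size = 32_000
l_baseline = math.log(vocab_size)  # uniform prediction, about 10.37 nats
l_ngram = 3.0                      # a hypothetical n-gram model
l_llm = 2.0                        # a hypothetical large transformer

print(f"fraction recovered: {fraction_loss_recovered(l_baseline, l_ngram, l_llm):.2f}")
print(f"perplexity ratio:   {perplexity_ratio(l_ngram, l_llm):.2f}")
```

With these made-up values, the n-gram stand-in closes about 88% of the gap, yet its perplexity is still about 2.7 times that of the reference model.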

Most people[1] understand that the two extremes above shouldn't count as "interpreting" a model. However, as always, reality is more complicated. The two extremes occur in fractal fashion in a number of related contexts, where I think interpretability and ML papers have a bad track record of failing to correctly factor in the takeaway from this koan.

In this post I'll complain more about this, explaining some contexts where it's important to specify exactly where on the spectrum between "too precise" and "too coarse" you are aiming to be (as we'll see, this is a particularly big problem when you're not trying to explain reconstruction loss directly and the precision scale is implicit). I won't give specific examples, ostensibly because I don't want to cause offense but really because I'm bad at literature searches (especially of the depressing "search for bad examples" variety).

At the end, I will explain what I think is a good general solution that by and large "gets this right": i.e., how one can go about making experiments in interpretability correctly responsive to questions of loss precision. Finally, I'll explain why I would be excited for people to implement this fix more, and describe experimental contexts where a good analysis of this shape might give interesting new insights.

Originally when writing this piece, I was planning to explain a special (but ubiquitous) reason why certain interpretability experiments may be particularly sensitive to questions of loss precision. Namely, due to the existence of parallel inference modes, some NN contexts exhibit a regime where the relationship between interpretation and precision has a sneaky but aggressive exponential factor. For reasons of time and readability, I ended up deciding to split this discussion off into a follow-up post.

Elucidating the spectrum of precision

Step 1: coming to terms with imprecision

In putting down the "too precise" extreme, I intentionally suggested an egregiously silly amount of demandingness. No interpretability researcher wants to explain every bit of accumulated noise as part of their interpretability scheme. Obviously if you show that a neural network is implementing an idealized algorithm and carefully show how the weights are in fact giving an explainable approximation of the algorithm, that's enough. In fact an interpretability scheme should be considered suspicious if it doesn't factor in sources of imprecision. Neural nets are inherently messy stochastic systems, and there are four sources of randomness and imprecision that are essentially always there for any sufficiently nontrivial model:

  1. Noise: the world is noisy and infinitely detailed. The training data for all but the simplest toy models have some amount of noise in inputs and labels. Your picture of a cat will not be a platonically perfect cat: it will have imperfections due to pixellation, due to atmospheric phenomena and camera artefacts interacting with the integrity of the image; the cat's fur will be affected by accidents of dirt and discoloration. Labels may be garbled or imprecise. Etc. Similarly, text (though it is usually thought of as discrete, and thus seemingly less susceptible to noise than pictures) suffers from external noise: the writer may be affected by distractions in the environment, by texts read recently, and so on. While it's possible to capture some amount of this (e.g. mood) in a predictive speech generation process, there will always be some amount of sufficiently fine-grained random context (that mosquito bite behind your left shoulder that makes you remember a hiking trip with your grandpa and causes your writing to be more wistful) that ultimately must be abstracted out as noise by state-of-the-art ML systems. 
  2. Sample randomness: the training data is a finite random sample from an idealized infinite distribution. Even if you imagine that God had a perfect model of images of cats that accounts for pixelation, imprecision, and the like, your cat classifier does not have access to God-level amounts of data. Instead, it has access to some finite number of training examples. While these training examples may all be drawn from a single distribution, the specific samples that go into training are a random selection (all existing pictures of cats are a random sample from God's "true cat distribution"). This affects the classifier. Indeed, at a sufficiently fine level of precision, God's "true cat" distribution depends on a number of parameters about our world that is (again, at sufficient levels of precision) orders of magnitude larger than the number of cat images -- thus even with perfect knowledge of possible models of cat distributions in various worlds, all existing cat images are probably not enough to specify all the latent parameters that describe the distribution in our world in particular.

    Note that even in toy contexts like modular addition, where you can easily train on "all possible data" and may think that the training data is exact and incorruptible, making good models requires making some statistical or noisy approximations. For example, my favorite paper on modular addition and grokking abstracts out the combinatorial complexities of the discrete Fourier transform by modeling it as a continuous Fourier transform (this corresponds to viewing the residues 0, ..., p-1 mod a prime as p random samples of real-number residues mod p undergoing a periodic process -- a common point of view when studying mod-p behaviors in ML).

  3. Training randomness and imperfection. The training processes used by NNs have implicit randomness and coarseness, given by initialization, batch selection, and macroscopic learning rate. Thus training does not return some "platonic ideal" neural net as a function of the data, but rather depends on random choices (even if we were to remove randomness: do full-batch updates, fix some natural initialization, etc., these would still be arbitrary choices that would be hard to model in a perfect mathematical way, and must ultimately be abstracted out as noisy or approximate phenomena).
  4. Approximation of functions by other functions. While this is similar to the previous point, I think it deserves its own item because it's particularly often ignored. Namely, the neural nets that have a more-or-less known mathematical interpretation are almost always understood as implementing (nice/smooth) abstract functions, which usually can't be implemented exactly (e.g. all functions implementable by ReLUs are piecewise-linear, and other activations will generally only be able to approximately implement polynomials or exponentials). This isn't a big deal: it's both abstractly possible and in practice "relatively easy" for a neural net to approximate a function learnable with one choice of activation by another choice of activation[2]. Both theory and experiment lead us to expect that in certain realistic contexts, the dynamics and learnability of neural nets don't significantly depend on the exact choice of activation functions[3]. However, whenever making use of such an approximation theorem, one must model the difference between the "idealized" function and the "realizable" approximation for the given architecture as an inherent "hard" source of noisy imprecision: in particular, no amount of data or training time can fully eliminate this.
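A minimal illustration of the fourth point, under toy assumptions of my own: a sum of ReLU units that linearly interpolates the smooth function g(x) = x² on n equal pieces of [0, 1]. For this construction the worst-case error is exactly 1/(4n²), attained at the midpoint of each piece, so it shrinks as units are added but never reaches zero: a "hard" approximation floor of the kind described above. The helper names are hypothetical and not taken from any cited paper.

```python
def relu(z: float) -> float:
    return max(0.0, z)

def relu_interpolant(g, n: int):
    """One-hidden-layer ReLU function equal to the piecewise-linear interpolant of g on n equal pieces of [0, 1]."""
    knots = [k / n for k in range(n + 1)]
    slopes = [(g(knots[k + 1]) - g(knots[k])) * n for k in range(n)]

    def f(x: float) -> float:
        # First linear piece, then one ReLU "kink" per interior knot that changes the slope.
        out = g(0.0) + slopes[0] * relu(x)
        for k in range(1, n):
            out += (slopes[k] - slopes[k - 1]) * relu(x - knots[k])
        return out

    return f

def max_error(g, n: int, grid: int = 1 << 14) -> float:
    """Largest |f(x) - g(x)| over a uniform grid on [0, 1]; for the n used below the grid contains every piece's midpoint."""
    f = relu_interpolant(g, n)
    return max(abs(f(i / grid) - g(i / grid)) for i in range(grid + 1))

def square(x: float) -> float:
    return x * x

for n in (2, 4, 8, 16):
    print(n, max_error(square, n), 1 / (4 * n * n))
```

Each doubling of the number of ReLU units cuts the error by a factor of four, but for any fixed architecture the gap between the idealized x² and its realizable approximation stays strictly positive.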

Thus any realistic interpretability scheme must allow for sources of noise. As an abstract point, this is obvious and commonly known. It would be silly to expect each interpretability paper to carefully quantify and bound each of these four sources of noise, and it's perfectly acceptable to bundle everything into some generic error bounds on experiments. However, a phenomenon that I sometimes see in theory-adjacent papers is an attempt to carefully factor in one of these sources of noise, but forget that the others exist and may be dominant. This is particularly a problem in some thinking around SLT, where work of Sumio Watanabe gives a very elegant asymptotic bound on error source number 2 above in certain idealized networks. While this mathematical idealization often exhibits remarkably good predictive power on real-life neural nets[4] (one of the sources of excitement for SLT as a field), papers sometimes implicitly assume that the sample error noise scale analyzed by Watanabe is the only (or more precisely, the dominant) source of noise -- a problematic assumption when the other sources of noise may be more important, or interact with the "right" choice of idealization in a nontrivial way; we'll see an example of the latter phenomenon in a later section.  

Step 2: Factoring in the memorization-generalization spectrum

One way to neatly avoid having to be too careful about noise and imprecision is to say that a phenomenon found in a neural net is "relevant for interpretation" if it is an approximation (with implicitly understood sources of noise and imprecision) of a useful mathematical phenomenon -- i.e., a behavior (e.g. a "circuit") that, when mathematically abstracted out and idealized, helps the network obtain better loss. This can be validated either theoretically by constructing a full mathematical model, or experimentally by either somehow "ablating" the phenomenon and seeing the effect on loss, or conversely "cleaning up" the phenomenon by somehow "suturing in" the mathematical abstraction in place of the real-life messy component of the circuit, and seeing the effect on loss. Note that both of these experimental methods have significant issues, but we're not here to discuss the problems of causal intervention studies on neural nets.
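To make the two experimental checks concrete, here is a deliberately tiny sketch under assumptions of my own: a hypothetical one-component "model" that implements the rule y = 2x messily, an idealized abstraction of that rule, mean-ablation as the ablation method, and mean squared error as the loss. None of these names come from a real codebase. If ablating the component hurts loss, the component is useful; if suturing in the abstraction does at least as well as the messy original, the abstraction captures the useful behavior.

```python
import math

def mse(f, data) -> float:
    return sum((f(x) - y) ** 2 for x, y in data) / len(data)

# Hypothetical toy: the target rule is y = 2x, and the whole toy "model" is one messy learned component.
data = [(i / 10, 2 * i / 10) for i in range(-10, 11)]

def learned(x: float) -> float:
    # Imperfect implementation of the rule, with a small wiggle standing in for accumulated noise.
    return 2.05 * x + 0.03 * math.sin(7 * x)

def idealized(x: float) -> float:
    # Mathematical abstraction of what the component is "doing".
    return 2.0 * x

mean_output = sum(learned(x) for x, _ in data) / len(data)

def ablated(x: float) -> float:
    # Mean-ablation: replace the component by its average output over the data.
    return mean_output

for name, f in [("learned", learned), ("ablated", ablated), ("sutured-in ideal", idealized)]:
    print(f"{name:>16}: loss = {mse(f, data):.4f}")
```

In a real network, the component would be one piece of a larger computation, and both interventions carry the caveats the paragraph sets aside.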

One can hope that with sufficient advances in interpretability, it may be possible to mathematically abstract out all "useful" behaviors of a neural net. I'll have more to say in later posts about the (un)desirability of maximally ambitious interpretability targets, but for now I want to observe that trying to identify all marginally useful behaviors is an unrealistic and ultimately unnecessary mess.

Indeed, there is reason to believe (coming from toy-model interpretability, effective dimensionality studies, and student-teacher experiments) that neural nets "only use a fraction of their parameters to generalize". In other words, there are many directions (unfortunately not the same thing as neurons because of polysemanticity, though even restricting to neurons makes this insight clear) inside a neural net that can be viewed as "free parameters": changing the weights along these directions doesn't seem to impact performance much, and has especially little effect on held-out examples. Now if you put yourself in the brain of a neural net (something I will often be suggesting you do, though you must do so carefully), noticing "free" directions in your program parameters means you have extra "unstructured memory" to spare[5]. And this unstructured memory can be used to memorize. In fact, there are a number of both experimental and (pretty strong) theoretical results that show that under extremely weak restrictions, each 1-dimensional direction of unused memory (whether or not it is neuron-aligned) can be used to correctly memorize one training example[6]. Thus if (as is often observed), a typical MNIST model only "really uses" at most 10% of its memory parameters, it is free to use the remaining 90% to memorize confusing datapoints. This might not happen in real life because models tend to be undertrained, but can be safely assumed to be possible (and indeed to occur) with sufficient training.
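The "free direction" intuition has a simple linear analogue, sketched below under toy assumptions (this is not one of the neural-net results referenced above). A linear model with three weights and two training points has a direction in weight space, orthogonal to both inputs, along which predictions on the training data do not change at all. Moving along that direction lets the model fit one extra, arbitrary example exactly: one spare direction buys one memorized datapoint. The helper `memorize` is hypothetical.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def cross(u, v):
    # In 3D, the cross product is orthogonal to both inputs: a "free" direction for these two datapoints.
    return [u[1] * v[2] - u[2] * v[1],
            u[2] * v[0] - u[0] * v[2],
            u[0] * v[1] - u[1] * v[0]]

def memorize(w, free_dir, x_new, y_new):
    """Move w along free_dir so the model outputs y_new on x_new (requires free_dir . x_new != 0)."""
    t = (y_new - dot(w, x_new)) / dot(free_dir, x_new)
    return [wi + t * di for wi, di in zip(w, free_dir)]

# Hypothetical toy: a linear model with 3 weights and 2 "generalizing" training points.
train = [([1.0, 2.0, 0.0], 1.0), ([0.0, 1.0, 1.0], 1.0)]
w = [1.0, 0.0, 1.0]                     # fits both training points exactly
free = cross(train[0][0], train[1][0])  # orthogonal to both training inputs

# Spend the free direction on one extra, "confusing" example.
x_odd, y_odd = [0.0, 0.0, 1.0], 5.0
w_mem = memorize(w, free, x_odd, y_odd)

print([dot(w_mem, x) for x, _ in train], dot(w_mem, x_odd))
```

The predictions on the original training points are unchanged while the odd example is now fit exactly; with a second free direction, a second odd example could be absorbed the same way.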

Now each of these memorized datapoints improves loss by a small amount, and is thus "useful" in the sense above. However it would be a massive headache to demand from an interpretability scheme that it correctly explain every memorization circuit: what parameters it uses, how it works, why it doesn't conflict with the generalizing circuits. Any interpretability scheme with a snowball's chance in hell of being useful must be able to disclaim off the bat that "spurious" but loss-improving behaviors that only apply to a specific datapoint or two shouldn't need to be mathematically formalized, at least when looking at interesting scales. This tells us that a naive way you might imagine getting around Dmitry's koan, namely saying that "the right scale to consider is the scale that captures all useful behaviors", is unreasonable.

This suggests a next-level guess at the appropriate scale of precision, which is precision that "captures all behaviors that are useful for improving test loss". This lets us ignore behaviors and circuits that explicitly memorize. Still, this doesn't get rid of the issue. You see, the "test loss" vs. "training loss" dichotomy is only a first-order stab at the much deeper question of "what is generalization". In practice, NN phenomena exist on a spectrum between memorization and generalization. While the "memorizing" end of this spectrum has a well-defined limit (circuits that memorize a single input datapoint), there are many phenomena that help classify a "cluster" of datapoints that exists in both the training and test datasets, but may not be important enough to interpret. For instance, maybe a quote from a niche work of genre fiction is shared on the internet a few dozen times by committed fans, and these few dozen quotes make their way into different training documents for an LLM. Then a circuit that memorizes this particular quote is technically a generalizing circuit: chances are, the quote will appear both in the training and test data. However, it's a stretch to say that this circuit is of comparable generality to a mechanism encoding concepts related to Paris or Python commenting conventions. Indeed, probably in an "ideal" interpretability scheme, such a circuit should be compressed out into "we expect the world to contain a number of quotes from Jane Austen-inspired fan fiction about humanoid cat pirates, and will model some not-super-relevant parts of our neural net as containing circuits related to passages thereof".

More formally, algorithms implemented by a neural net can be placed on a number of more sophisticated memorization-generalization spectra, associated for instance to "how likely is the net to make essential use of this algorithm in any given (non-training) text-completion task". I discussed a representative example of such a spectrum in my subgrammars [LW · GW] post, and discussions about such phenomena abound in interpretability-adjacent ML discussions (see for example this paper, and other studies on compositionality and generalization). 

Thus ideally, an interpretation of an ML algorithm should target a specific place in the memorization-generalization spectrum: identify behaviors that are not only useful, but have a suitable degree of generality. Of course in practice, this is very hard to gauge (and even harder to verify that you have somehow "found all circuits at a given level of generality"). Instead, one is forced to quantify the measures of generality or importance that control the "precision" of one's interpretability work by using more pragmatic proxy measures. There's a lot of room for playing around and trying to find better proxies here, but one basic and reasonable proxy is loss precision on test data. Namely, assuming you have an "end-to-end" candidate interpretation of a neural net, you can quantify "how precise it is" by how well it explains the loss, and say that, at a given level of loss precision, interpretation A is better than interpretation B if A "looks better as an interpretation". This of course opens up a whole other can of worms: do you use "description length" or "modularity" or "human understandableness" as your goal for a "good" interpretation? But these debates are standard and visible in this community (a favorite treatment of mine is contained in Lee Sharkey's distillation of Apollo's "sparsify agenda" [LW · GW]). I'm not here to engage in long chains of collaborative knowledge-building: I'm here to nitpick.
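Here is one hypothetical way to encode "at a fixed level of loss precision, prefer the better-looking interpretation". It assumes each candidate interpretation has been run end-to-end to get a test loss, and uses description length as the stand-in for interpretive quality; both the `Interpretation` record and the numbers are made up for illustration. The winner changes as the loss scale tightens, which is the koan in miniature.

```python
from dataclasses import dataclass
from typing import Optional, Sequence

@dataclass
class Interpretation:
    name: str
    test_loss: float           # loss when the interpretation is run end-to-end in place of the model
    description_length: float  # one possible proxy for "looks better as an interpretation" (lower is better)

def best_at_loss_precision(candidates: Sequence[Interpretation], loss_scale: float) -> Optional[Interpretation]:
    """Among interpretations whose test loss is within the chosen loss scale, pick the most compact one."""
    admissible = [c for c in candidates if c.test_loss <= loss_scale]
    return min(admissible, key=lambda c: c.description_length, default=None)

# Hypothetical candidates, for illustration only.
candidates = [
    Interpretation("n-gram table", 3.0, 1e3),
    Interpretation("general circuits", 2.6, 1e5),
    Interpretation("general circuits + memorized quotes", 2.4, 1e8),
]
for scale in (3.1, 2.7, 2.5, 2.0):
    best = best_at_loss_precision(candidates, scale)
    print(scale, best.name if best else None)
```

Asking "which interpretation is best?" without fixing the loss scale has no answer here: each candidate wins somewhere, and at the tightest scale none qualifies.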

There are still some significant operationalization issues here. First, most interpretability work (at least at present) doesn't aim to reconstruct a NN end-to-end, but rather to find more local reproducibly understandable patterns. Second, just saying that "loss precision is an essential parameter in discussing interpretability schemes" doesn't tell you what loss precision scales are interesting. I'll discuss both of these issues in the next section.

Natural scale and natural degradation

In this section I'll give an explicit proposal for how to operationalize and choose loss scales in realistic interpretability work. The proposal is significantly inspired by work that has come out of SLT research, though it is theoretically independent of it (and in particular, is on the pragmatic side of the theory-pragmatism divide).

Sometimes reconstruction loss is not the point

Most interpretability work to date finds localized phenomena in neural nets. The notion of locality here is vague and tricky to operationalize, but roughly they might

This degree of specificity is not shared by all interpretability work (e.g. SAE work does not depend on a small collection of specific phrases). But the idea of "looking at localized phenomena" is present to some extent in all interpretability work that treats sufficiently complex models (including toy models!). For work of this type, it is unreasonable (at least directly) to view its reconstruction loss as any kind of precision scale (and often reconstruction loss in such work is either not useful or hard to operationalize).

Degradation as a dial

How, then, can we operationalize the loss scale of a phenomenon? Well, one way to do this is to imagine that we have some "natural" complexity parameter c that can be varied (this can be a parameter tuning model size, training length, etc.). We denote the resulting class of (so far theoretical) models M_c. If possible, we would like models in this class to be "locally simultaneously interpretable", i.e. that for two nearby values c ≈ c', the models M_c and M_c' have similar weights and implement similar circuits. This is in particular the case if the M_c are the training checkpoints (i.e., weights during pretraining), with the complexity parameter c measuring the fraction of training time, but this isn't strictly necessary in general (this will be made more precise in the next section).

We require that at c = 0, our program outputs a fully random (in some appropriate sense) classification -- for example, this is true if we take M_0 to be a randomly initialized neural net that has undergone no training; we view its loss, L_0, as the "baseline" loss, a kind of upper bound on our loss scale. For c = 1, we set M_1 to be the "model organism" neural net that we are studying in our experiment. At the other end, we ask that for c = ∞, the model M_∞ has perfect loss[7], or at least "very good loss" L_∞, corresponding to a significantly more sophisticated model than the one we are performing interpretability on.

Now it's unrealistic to ask that we actually implement examples of M_c for c>1: perhaps the model M_1 we are studying is a state-of-the-art model, and improving on it requires a few billion dollars of spare cash. However, we do assume that we have some kind of ability to perform experiments on models M_c for c<1. We'll call models M_c for c<1 "degradations" of M.

In this case, we can use the following process to quantify the loss precision of our interpretability result. First, we operationalize the result (say we have "found a circuit") in a formal way. This can be a "prediction experiment": we check whether some interpretability-flavored mathematical prediction holds on a fixed corpus of inputs in a statistically significant way, for the model M. Alternatively, the experiment can be a measurement that outputs some (hopefully interpretability-relevant) invariant f(M) of the model. Now say that we want to perform the experiment at "loss scale L", with L_0 > L > L_1. Then the recipe is to perform an approximation of the following experiment:

  1. Empirically measure the loss L_c of M_c for complexity c < 1, as a function of c.
  2. Find the cutoff parameter c*<1 for which L_c* = L (at least approximately).
  3. Run our interpretability experiment for the model M_c*.
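Steps 1 and 2 of the recipe can be sketched as follows, assuming we already have losses measured at a handful of degradation levels (e.g. training checkpoints) and that loss is non-increasing in c; the loss values and the helper `find_cutoff` are hypothetical. Linear interpolation between the two bracketing measurements is just the simplest choice; in practice c* will usually fall between available models, so step 3 would use the nearest one or produce a model at c*.

```python
def find_cutoff(cs, losses, target):
    """Return c* with L_{c*} approximately equal to target.

    cs: increasing complexity values at which degradations M_c were measured (e.g. checkpoints).
    losses: the measured losses L_c, assumed non-increasing in c.
    Uses linear interpolation between the two measurements that bracket the target.
    """
    if not losses[-1] <= target <= losses[0]:
        raise ValueError("target loss lies outside the measured range")
    for i in range(len(cs) - 1):
        higher, lower = losses[i], losses[i + 1]
        if lower <= target <= higher:
            if higher == lower:
                return cs[i]
            frac = (higher - target) / (higher - lower)
            return cs[i] + frac * (cs[i + 1] - cs[i])
    raise ValueError("losses are not non-increasing in c")

# Step 1 (hypothetical measurements, not real data).
cs = [0.0, 0.25, 0.5, 0.75, 1.0]
losses = [10.4, 4.0, 3.0, 2.6, 2.4]

# Step 2: the degradation level matching a chosen loss scale L = 2.8.
c_star = find_cutoff(cs, losses, 2.8)
print(f"c* = {c_star:.3f}")  # step 3 would run the interpretability experiment on M_{c*}
```

For these made-up values, a loss scale of L = 2.8 gives c* = 0.625.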

Of course doing this in general is expensive and questionably useful. For one, it's expensive to measure something for a bunch of setpoints of a sophisticated algorithm (even if all we're measuring is the loss), not to mention that for many important open-source systems, the setpoints are not publicly available. We'll address these issues and more in the following section. 

Next, a keen reader will observe that in examples such as IOI, there's absolutely nothing wrong with just running the experiment at characteristic loss L = L_1 associated with the model under investigation itself, M_1 = M. In this case I'm just saying "you should run your experiment, but put 'underscore 1' indices on everything", not the most useful piece of advice. In the previous sections we discussed that being "too ambitious" about working with characteristic loss equal to the loss of the model under consideration (for example, requiring full loss reconstruction) means that if you want sufficiently ambitious coverage for your interpretability results, you will end up dealing with a bunch of garbage behaviors like memorization or "partially memorized" quotes from pirate-themed Jane Austen fan fiction. However, if your experiment, like IOI, is not very ambitious (in the sense of "going for localized completeness of interpretation"), then it's plausible that this doesn't cause problems: you run your experiment on the fully trained model, get a positive result, and publish a paper[8].

However, the usefulness of the picture I'm proposing emerges when the experiment you are performing has its own internal precision or loss scale. For example, one of my favorite interpretability papers is "Look Before you Leap", which observes that in certain carefully designed contexts, an activation patching from a phrase A to a phrase B will result in completions of B giving responses using contextual information from A; but this phenomenon occurs only for patches on late layers, and gets corrected if the patch is performed on early layers (with the transformer "fully correcting" the faulty patched activation back to the true context). A "soundbite" summary of the result of this paper is that (for a suitable notion of sentence context, and for a suitable class of examples), all the context-dependent information of a transformer task is integrated in early layers, with later layers only performing post-processing on known context.

As soon as an interpretability result can be (even approximately) described as "fully" characterizing a particular behavior, the scale of precision becomes relevant. (The beauty of the "look before you leap" paper is that it has exactly the right degree of coarseness in its experimental method: quantitatively distinguishing behavior at "early" vs. "late" layers, to have a chance of legitimately capturing some "generally applicable" information about the model's internal workings).

In the "look before you leap" example, we can then note that in general, activation patching experiments degrade performance, simply because you are, as it were, introducing an "alien" behavior into a network, which corresponds at best to a rough refactoring of its internal mechanism. Thus an interesting experiment would be to rerun the experiment for a collection of degraded neural nets M_c for various values of c, with corresponding loss scales L_c, and compare the degraded loss L_c to the patched reconstruction loss. A particularly nice result here would be if for some value of the degradation c, the reconstruction loss for patching at sufficiently late layers were equal or very close to the inherent loss L_c of M_c on the experimental dataset. If this were the case, this would be definitive evidence that, when considered at "suitable loss scales", late layers do indeed exclusively (or "almost exclusively") perform postprocessing.
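The proposed comparison could be tabulated as below. The sketch assumes that for each degradation level c we have measured both the degraded model's own loss L_c on the experimental dataset and its loss after late-layer patching; the numbers, the tolerance, and the helper name are hypothetical, not results from the paper.

```python
def lossless_patch_scales(results, tol):
    """results maps a degradation level c to (L_c, patched loss of M_c) on the experimental dataset.

    Returns the levels c at which late-layer patching costs at most tol above the degraded
    model's own loss, i.e. where "late layers only post-process" holds at loss scale L_c.
    """
    return sorted(c for c, (inherent, patched) in results.items() if patched - inherent <= tol)

# Hypothetical numbers, for illustration only: c -> (inherent loss L_c, loss after late-layer patching)
results = {
    1.0: (2.40, 2.55),
    0.8: (2.60, 2.66),
    0.6: (2.90, 2.91),
}
print(lossless_patch_scales(results, tol=0.02))
```

In this made-up table the patch is essentially invisible at the loss scale of M_0.6 but not at the full model's scale, which is the shape of the "particularly nice result" described above.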

So far, I've explained that performing experiments on degraded models can be useful, and can give much more principled ways of discussing the "completeness" of interpretability phenomena. However, the questions of "optimal degradation" and "interesting scale" still remain unanswered. As mentioned, one possibility for the former question is to use training checkpoints, but it's not clear that this is a very good choice. In particular, if we are interested in distinguishing "more general" from "more memorize-y" behaviors, training checkpoints are probably not the way to go: training checkpoints will often start out by accumulating "less general" behaviors before eventually learning to generalize; we would like our degradations to have, at least roughly, the opposite behavior of holding on to "all the most general behaviors of suitably bounded complexity".

Natural scale

It's famously difficult to get a "principled" measurement of LLM capability. However, we know that GPT-4 is more capable than GPT-3.5, which is more capable than GPT-3, etc. Of course each iteration of GPT changed a lot more than just the parameter count, but to a first-order approximation, parameter count is the core difference between the different models. Now for any reasonable conceptualization of complexity of a series of models (which might be the number after the "GPT", or a more mathematically principled parameter-count scaling dial) we get a reasonable notion of loss precision (by computing the loss)[9]. The notion of "degradations" introduced in the previous section (and which will be better operationalized in the next) gives us a natural way to reason about the complexity of phenomena. For example, if Anthropic finds some very nice SAE-inspired decomposition with good reconstruction loss (something that is not yet available), a phrase we might hear in the future is "the reconstruction loss of a 100B parameter feature-by-feature interpretation of Claude 5 is comparable to the performance of Claude 2" (a massive triumph of interpretability, if these words ever get written), or perhaps "comparable to a natural degradation of Claude 5 that obtains the same loss as Claude 2" (an even bigger triumph). Conversely, we can separate complexity measurements mediated by loss from complexity mediated by architecture and parameter count, while putting both on the same scale. A phrase of this type that I am more optimistic about hearing is something like "Claude 5 retains good performance on the International Math Olympiad benchmark when degraded to the loss precision of Claude 2", which would imply a strong architecture-dependent decoupling of loss and capability (something that most people expect to take place).

Once we have two models of very different complexity, like Claude 5 and Claude 2, another thing this gives is a natural approximate loss scale associated with the weaker model (Claude 2 in this case), which is difficult to obtain by looking at only one model. Namely, if we imagine some complexity measurement with Claude 5 and Claude 2 being two instances at different values of c, we can conceptualize Claude 2 as the complexity c = 1 "base" model and Claude 5 as an approximation of the complexity c = ∞ "perfect" LLM oracle[10] (note that this is a fundamentally different complexity dial from the "natural degradation" dial which we will introduce in the next section). We can then say that a natural scale to run experiments on Claude 2 is its own "absolute performance gap", i.e. L_1 - L_∞, approximated as L(Claude 2) - L(Claude 5).
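The natural scale is simple arithmetic, but writing it out makes the units explicit. In this sketch the losses are hypothetical (they are not figures for any Claude model), and the helper names are my own; expressing an experiment's loss cost in units of the weaker model's gap L_1 - L_∞ is one way to compare precision across models of very different size.

```python
def natural_scale(loss_weak: float, loss_strong: float) -> float:
    """Approximate the weaker model's absolute performance gap L_1 - L_inf,
    using a much stronger model's loss as a proxy for L_inf."""
    if loss_strong >= loss_weak:
        raise ValueError("the stronger model should have lower loss")
    return loss_weak - loss_strong

def in_natural_units(loss: float, loss_weak: float, loss_strong: float) -> float:
    """How far `loss` sits above the weaker model's own loss, measured in natural-scale units."""
    return (loss - loss_weak) / natural_scale(loss_weak, loss_strong)

# Hypothetical losses (nats per token), not real figures for any model.
weak, strong = 2.4, 1.9
print(natural_scale(weak, strong))           # a gap of about 0.5 nats
print(in_natural_units(2.45, weak, strong))  # an experiment losing ~0.05 nats costs ~0.1 natural units
```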

Natural degradation

In this section I'll finally give my proposal for how to operationalize the rough discussion in the previous section in what seems to me to be a maximally sensible way. 

Note that the core property we want from the degradations M_c for c<1 is that they have worse (i.e. higher) loss than M. There are many ways to make this happen: since models are, at least approximately, local loss minima, most ways of modifying M -- whether random or directed -- will degrade loss. However, I claim that there is one right way. Namely, the way that any neural net is generated is by some gradient-assisted search procedure through a weight landscape. A priori, there is a giant family of possible neural nets M_w associated to various weight vectors in a giant vector space of parameters. Each weight vector w has an associated loss L_w = L(M_w). The fully trained network is then M_w* for some fixed locally (approximately) loss-minimizing parameter w*. Now for a degradation at some intermediate loss L, we would like to ideally choose a "degraded" neural net M_w which:

The idea is now to flip this and choose w to be a random weight vector that is not far from w* and has loss equal (or approximately equal) to L. 
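One standard way to make "a random weight near w* with loss about L" precise is a localized, tempered Gibbs distribution. In the sketch below, L_n is the empirical loss on n training samples, β is an inverse temperature, and γ sets how tightly samples are tied to w*. These symbols follow the SLT literature on the local learning coefficient rather than anything defined above. Raising the temperature (lowering β) lets samples drift to higher loss, and to leading order the excess loss is controlled by the local learning coefficient λ(w*):

```latex
p_{\beta,\gamma}(w) \;\propto\; \exp\!\Big(-\,n\beta\,L_n(w) \;-\; \frac{\gamma}{2}\,\lVert w - w^{*}\rVert^{2}\Big),
\qquad
\mathbb{E}_{w \sim p_{\beta,\gamma}}\big[L_n(w)\big] \;\approx\; L_n(w^{*}) + \frac{\lambda(w^{*})}{n\beta}.
```

For a nondegenerate quadratic minimum in d dimensions with γ → 0, the second relation is exact with λ = d/2. Read in reverse, it suggests choosing nβ ≈ λ/ΔL to degrade to a target loss gap ΔL, with the caveat that λ itself can drift as the temperature changes.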

This might seem impossibly hard: I'm asking for a process that trawls through the enormous space of all neural nets M_w (even imposing the condition that "w is close to w*" barely makes a dent in its enormousness), finds all the ones that have a particular high-level behavior (loss), and then samples them at random. But it turns out that exactly such an algorithm exists and is used with remarkable success. Depending on how strict you want to be about your sampling being "unbiased", its computational cost tends to be somewhere between that of finetuning and that of retraining. 

Algorithms that sample points in a large parameter space with some particular behavior are called sampling algorithms, and the sampling algorithm usually used in this context is "SGLD" (stochastic gradient Langevin dynamics, sometimes called "Langevin SGD"), which works by combining gradient descent steps with noise steps at an appropriate scale. This is the bread-and-butter algorithm of all empirical work in SLT, and was introduced and tested in this context in Edmund Lau et al.'s paper on the local learning coefficient. From the point of view of interpretability, the Langevin algorithm can be conceptualized as balancing entropy and loss. The core conceptual property that the sampling process implemented by this algorithm tries to capture is the following:

Find a maximally general algorithm M_w in the same basin as M = M_w*, which implements the same task as M, but on the degraded loss scale L.

In other words (and modulo small text that we mostly won't bother with), the algorithm can be conceptualized as identifying the optimal compression[11] of the algorithm implemented by M that still obtains loss L, and noising out all circuits whose information content is too large compared with their contribution to loss. Thus if possible, M_w will throw away all memorization and "partial memorization" behaviors, and only keep "the good stuff" (which itself will start degrading once we set the loss scale to be high enough to be able to start throwing away "interesting" general circuits).
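To make the "noising out" picture concrete, here is a minimal numpy sketch of SGLD localized around w*, using full-batch gradients instead of minibatches for simplicity. The toy loss is an assumption made purely for illustration: a two-parameter quadratic with one "stiff" direction (large contribution to loss) and one "sloppy" direction (tiny contribution), standing in for an important circuit and a barely used one. At a fixed temperature the stiff direction stays pinned near w*, while the sloppy direction wanders until the localization term γ reins it in.

```python
import numpy as np

def sgld_samples(grad_loss, w_star, n_beta, gamma=1.0, eps=1e-3,
                 steps=4000, burn_in=1000, n_chains=100, seed=0):
    """Approximately sample p(w) ~ exp(-n_beta*L(w) - gamma/2*|w - w_star|^2)
    with Langevin updates; returns post-burn-in samples (samples, chains, dim)."""
    rng = np.random.default_rng(seed)
    w = np.tile(w_star, (n_chains, 1)).astype(float)  # every chain starts at w*
    kept = []
    for t in range(steps):
        # Gradient of the potential: tempered loss plus localization pull to w*.
        drift = n_beta * grad_loss(w) + gamma * (w - w_star)
        w = w - 0.5 * eps * drift + np.sqrt(eps) * rng.standard_normal(w.shape)
        if t >= burn_in:
            kept.append(w.copy())
    return np.stack(kept)

# Hypothetical toy "model": L(w) = 0.5 * (10 * w0^2 + 0.01 * w1^2), minimum at 0.
curvatures = np.array([10.0, 0.01])   # stiff direction, sloppy direction

def grad(w):
    return curvatures * w

w_star = np.zeros(2)
samples = sgld_samples(grad, w_star, n_beta=100.0)
spread = samples.reshape(-1, 2).std(axis=0)
print("std along stiff direction: ", spread[0])
print("std along sloppy direction:", spread[1])
```

The sloppy coordinate ends up spread over a range many times wider than the stiff one: whatever that direction "encoded" has been noised out, at almost no cost in loss.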

I want to suggest that this "natural degradation" procedure has the key properties we would want from a dial that lets us adjust the "loss precision scale" of experiments. Namely it is:

It also comes prepackaged with an easy-to-measure additional empirical scale parameter, called the "local learning coefficient", which has units of parameter count. In other words, in addition to using this dial to see how experimental results change at different loss precision scales, one can alternatively interpret it as a dial measuring how results change at a natural complexity scale capturing information related to parameter count (more precisely, this captures the parameter count of the "optimal compression" of the model at the given scale, for a suitable operationalization of this notion).
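As a minimal sketch of how the local learning coefficient falls out of the same kind of sampler, the snippet below implements the estimator λ̂ = nβ(E[L_n(w)] − L_n(w*)) with the Watanabe inverse temperature β = 1/log n. It is tested on two hypothetical toy losses where the answer is known: a regular quadratic minimum in d = 4 dimensions, where λ = d/2 = 2 (the factor of two is why λ reads as "half an effective parameter count"), and the degenerate one-dimensional loss w⁴, where λ = 1/4. The toy losses have no data, so n = 10,000 only sets the temperature. Full-batch gradients, step size, chain count and γ are illustrative choices, not recommended settings.

```python
import numpy as np

def estimate_llc(loss, grad_loss, w_star, n, gamma=1.0, eps=1e-4,
                 steps=3000, burn_in=1000, n_chains=200, seed=0):
    """LLC estimate n*beta*(E[L(w)] - L(w*)) with beta = 1/log(n), where the
    expectation is over SGLD samples localized around w_star."""
    rng = np.random.default_rng(seed)
    n_beta = n / np.log(n)                  # Watanabe temperature: beta = 1/log n
    w = np.tile(w_star, (n_chains, 1)).astype(float)
    losses = []
    for t in range(steps):
        drift = n_beta * grad_loss(w) + gamma * (w - w_star)
        w = w - 0.5 * eps * drift + np.sqrt(eps) * rng.standard_normal(w.shape)
        if t >= burn_in:
            losses.append(loss(w).mean())   # average over chains at this step
    return n_beta * (np.mean(losses) - loss(w_star[None, :])[0])

# Hypothetical toy losses; each takes a (chains, dim) array, returns (chains,).
def quad_loss(w):
    return 0.5 * np.sum(w**2, axis=1)

def quad_grad(w):
    return w

def quartic_loss(w):
    return np.sum(w**4, axis=1)

def quartic_grad(w):
    return 4 * w**3

n = 10_000
llc_regular = estimate_llc(quad_loss, quad_grad, np.zeros(4), n)
llc_degenerate = estimate_llc(quartic_loss, quartic_grad, np.zeros(1), n)
print(f"regular d=4 minimum:    {llc_regular:.2f} (theory 2.0)")
print(f"degenerate w^4 minimum: {llc_degenerate:.2f} (theory 0.25)")
```

The degenerate minimum gets half of what the regular d/2 rule would assign to its single parameter: the LLC counts effective, not nominal, parameters. Small deviations from the theoretical values come from the finite step size and the localization term.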

Experiment suggestions

I'll conclude by suggesting a few experiments (another experimental suggestion can be found above, in the discussion on "Look before you Leap"). This is far from a comprehensive list, and I think that the field of interpretability would benefit from loss-precision-sensitive experiments run in a number of contexts. 

  1. As we've discussed above, the local learning coefficient in SLT attempts to measure one operationalization of the effective parameter count of the algorithm implemented by a neural net. The local learning coefficient inherently depends on a loss precision scale (more or less synonymous in this context with "temperature"). Currently, most experiments of this form use a precision scale tuned to the "Watanabe critical temperature", which is determined uniquely by the size of the training set (and tuned to behave well with respect to sample noise). It seems unlikely that something like an MNIST model or Bert will significantly change its behavior when trained on datasets that differ in size by orders of magnitude[12]. However, the above notion of natural scale suggests a different loss scale to use: namely, the "absolute performance gap" given by the difference in loss between Bert and a SOTA base-model LLM on the same training dataset. It would be an interesting experiment to see how much the LLC changes between the Watanabe scale (determined by the number of training samples) and the natural scale (determined by loss). The measurement is designed to be quite stable to scale variations (under some idealized assumptions on the loss, but also in practice), but there are reasons to expect that the two ranges will give interestingly different results.
  2. Related to the above, the combined notions of natural scale and natural degradation give a new natural operationalization of the separation between "memorization" behaviors and "generalization" behaviors (analogous to the notion of a standard deviation in statistics). Namely, one can formally say that a phenomenon is "generalization-like" for a primitive language model if it is retained upon naturally degrading the model by a loss precision comparable to its absolute loss gap (measured as the difference in performance between itself and a state-of-the-art base-model LLM). This will probably identify even certain behaviors that improve test loss as "mostly memorization-y". Empirically analysing the difference between generalization and memorization conceptualized via this natural scale could be an interesting new way of operationalizing the "memorization-generalization spectrum".
  3. One class of experiment that is crying out to be done is to measure the "generalization penalty of finetuning". Namely, it is widely believed by interpretability researchers that most finetuning procedures vastly degrade the "generalization properties" of a model. Operationalizing and measuring this "generalization penalty" seems valuable both for thinking about finetuning and for thinking about various alignment risks.
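Experiment 2 (and the natural-scale half of experiment 1) needs a natural degradation at a prescribed loss gap, such as a model's absolute performance gap, rather than at a prescribed temperature. Below is a minimal sketch of one way to do that, which is my assumption rather than an established recipe: bisect on log(nβ) until the average SGLD loss above L(w*) matches the target. The loss gap falls as nβ grows, which is what makes bisection applicable. The check uses a hypothetical quadratic toy loss where the answer is known: with d = 4 and γ = 1 the equilibrium gap is d/(2(nβ + γ)), so a target of 0.01 should land near nβ ≈ 199, up to small step-size and sampling error.

```python
import numpy as np

def mean_loss_gap(loss, grad_loss, w_star, n_beta, gamma=1.0, eps=1e-4,
                  steps=3000, burn_in=1000, n_chains=100, seed=0):
    """Average of L(w) - L(w*) over SGLD samples localized around w_star."""
    rng = np.random.default_rng(seed)  # same seed each call: smoother bisection
    w = np.tile(w_star, (n_chains, 1)).astype(float)
    total, count = 0.0, 0
    for t in range(steps):
        drift = n_beta * grad_loss(w) + gamma * (w - w_star)
        w = w - 0.5 * eps * drift + np.sqrt(eps) * rng.standard_normal(w.shape)
        if t >= burn_in:
            total += loss(w).sum()
            count += n_chains
    return total / count - loss(w_star[None, :])[0]

def calibrate_temperature(loss, grad_loss, w_star, target_gap,
                          lo=1.0, hi=1e4, iters=20):
    """Bisect on log(n*beta) so that the SGLD loss gap matches target_gap."""
    for _ in range(iters):
        mid = np.sqrt(lo * hi)             # geometric midpoint
        if mean_loss_gap(loss, grad_loss, w_star, mid) > target_gap:
            lo = mid   # degraded too far: need a colder (larger) n*beta
        else:
            hi = mid
    return np.sqrt(lo * hi)

# Hypothetical toy loss L(w) = 0.5 * |w|^2 in d = 4 dimensions.
def quad_loss(w):
    return 0.5 * np.sum(w**2, axis=1)

def quad_grad(w):
    return w

n_beta = calibrate_temperature(quad_loss, quad_grad, np.zeros(4), target_gap=0.01)
print(f"n*beta for a loss gap of 0.01: {n_beta:.1f} (closed form: 199)")
```

On a real network one would swap in minibatch loss and gradient estimates and set the target to the measured gap L(M) − L(M_strong); the bracket [lo, hi] and the step size would then need tuning.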

Possible issues

I would be excited about people thinking more about loss precision in experiments, and about the notions of natural scale and natural degradation. However, it is of course possible that this isn't an interesting framework to consider. There are also some general issues that one should be careful about. One issue, which I've previously mentioned in a footnote, is that it is tricky to reason about natural loss scales in the presence of finetuning, since finetuning degrades loss in an unpredictable way. Another issue is that, on the one hand, some SOTA models are regularized, and on the other hand, many unregularized (or "insufficiently regularized") transformers can cheaply improve loss simply by scaling up their largest logits (one way to avoid this issue entirely is of course to only measure accuracy). When reasoning about natural loss, one would have to separate various "inconsequential" reasons for artificially high or low loss from more fundamental, "complexity-relevant" reasons. 
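A minimal illustration of the logit-scaling issue, on made-up logits (hypothetical numbers, not from any real model): multiplying every logit by a positive constant leaves the argmax, and hence accuracy, unchanged, but when the predictions are already correct it strictly lowers cross-entropy.

```python
import numpy as np

def cross_entropy(logits, labels):
    """Mean cross-entropy (in nats) of softmax(logits) against integer labels."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerically stable
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

def accuracy(logits, labels):
    return (logits.argmax(axis=1) == labels).mean()

# Hypothetical 3-class logits for four examples, all classified correctly.
logits = np.array([[2.0, 0.5, 0.1],
                   [0.3, 1.5, 0.2],
                   [0.1, 0.4, 1.2],
                   [1.1, 0.9, 0.0]])
labels = np.array([0, 1, 2, 0])

for scale in [1.0, 2.0, 4.0]:
    scaled = scale * logits
    print(f"scale {scale}: loss {cross_entropy(scaled, labels):.3f}, "
          f"accuracy {accuracy(scaled, labels):.2f}")
```

For misclassified examples the same scaling raises the loss, so the net effect depends on how often the model is right; a model that is right most of the time can buy loss this way without learning anything new.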

Finally, it's not obvious whether loss precision is a very good precision measurement, and whether the related notion of natural degradation is a very good way to vary scale. It's also not obvious that the natural precision of an LLM (i.e., its difference from optimal loss) is a particularly useful scale for separating generalization from memorization behaviors in LLMs. 

In fact, it's likely that in many contexts, better scale parameters exist. In particular, Lucius Bushnaq at Apollo is interested in different notions of complexity related to the size of circuits, which seem promising, and other approaches to operationalizing notions of complexity exist (my colleague Lauren Greenspan has a post in the works that discusses different notions of scale suggested by physical considerations). 

Loss precision and natural degradation are simply one concrete attempt to formalize a notion of complexity that allows reasoning more precisely about completeness of explanation and the characteristic scale of phenomena in NN experiments. I would be excited for new and better notions to appear. At the same time, I am relatively confident that the abilities to discuss the characteristic scale of phenomena, to imprecisely compare different notions of scale, and to vary the characteristic scale of a model are components of the interpretability paradigm that deserve more attention and coordinated exploration.

 

  1. ^

    There are exceptions: on the "less precise" extreme, some papers excitedly claim to have excellent loss reconstruction or interpolation when explaining less than a bigram amount of the cross-entropy loss; but this is rare.

  2. ^

    I'll give a neat example of this later when discussing a joint paper with Jake Mendel and Kaarel Hänni on computation in superposition.  

  3. ^

    Note that the claim of "activation function independence" should be taken with a grain of salt. While in shallow networks, it's a safe bet that the details of the activation function don't matter, deep networks are known to be more sensitive to the choice of activation function: this is beautifully analyzed in the physics-inspired PDLT opus, which Lauren and I will be distilling this month. 

  4. ^

    See for example this work joint with Nina Panickssery; at this point there are a number of other results observing this surprising effectiveness in other contexts.

  5. ^

    There's a reason I'm calling the leftover memory "unstructured". Because of its inherent randomness and risk of interacting with the "structured" memory, it's not necessarily the case that a NN can learn sophisticated new circuits in these directions if training is extended or improved. However, the unstructured memory is good enough for learning "simple" circuits. 

  6. ^

    There is a bit of fuzz here... but if you replace "exactly 1" by O(1), this observation holds in incredible generality.

  7. ^

    Note that in the case of cross-entropy loss, perfect loss is not 0, but a fixed lower bound related to the entropy of text.

  8. ^

    Since I know there will be comments about this otherwise: yes, I know that IOI has lots of problems as an interpretability experiment. If this bothers you, replace IOI with your favorite alternative interpretability experiment, or imagine an alternative universe where "productive" mechanistic interpretability experiments exist.

  9. ^

    Here in order to really compare loss, we should assume we're only comparing base models; one can perform similar analyses for finetuned models assuming some uniform measure of "post-finetuning loss", or alternatively by measuring loss on an artificial test set produced by an analogously finetuned model.

  10. ^

    Really, the only goal of this oracle will be to find a good approximation of the "true entropy of text", something that's famously difficult to get exactly right, and in some sense from a complexity viewpoint equivalent to perfect prediction.

  11. ^

    Small text: approximately, in the complexity measure given by the local learning coefficient

  12. ^

    Here a critical observation is that both classic MNIST and Bert models are underparametrized: in Bert's case, it's a 100M-parameter model trained on about 3B words. This suggests -- though doesn't prove -- that sample noise has less relevance for performance and interpretability measurements than architecture- and nature-dependent sources of noise.
