Mesa-Optimizers via Grokking
post by orthonormal · 2022-12-06T20:05:20.097Z · LW · GW · 4 comments
Summary: Recent interpretability work on "grokking" suggests a mechanism for a powerful mesa-optimizer [LW · GW] to emerge suddenly from a ML model.
Inspired By: A Mechanistic Interpretability Analysis of Grokking [AF · GW]
Overview of Grokking
In January 2022, a team from OpenAI posted a paper about a phenomenon they dubbed "grokking": they trained a deep neural network on a mathematical task (e.g. modular division) to the point of overfitting (it performed near-perfectly on the training data but generalized poorly to test data), and then continued training it. After a long stretch in which seemingly nothing changed, the model suddenly began to generalize and performed much better on the test data.
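The experimental setup is straightforward to sketch. Here is a minimal, hypothetical reproduction (not the paper's code; the architecture and hyperparameters below are illustrative assumptions, and whether or when a sharp phase change shows up depends heavily on them): a small network trained on a modular-arithmetic task with strong weight decay, for far longer than it takes to fit the training set.

```python
# Minimal sketch of a grokking-style experiment (illustrative, not the paper's code):
# fit a small MLP on modular addition with heavy weight decay and keep training
# long after training accuracy saturates, logging test accuracy along the way.
import torch
import torch.nn as nn
import torch.nn.functional as F

P = 97  # modulus for the toy arithmetic task
pairs = torch.cartesian_prod(torch.arange(P), torch.arange(P))   # all (a, b) inputs
labels = (pairs[:, 0] + pairs[:, 1]) % P                         # targets: a + b mod P
perm = torch.randperm(len(pairs))
train_idx, test_idx = perm[: len(pairs) // 2], perm[len(pairs) // 2:]

def one_hot(batch):
    # Concatenated one-hot encodings of a and b, shape (N, 2P)
    return torch.cat([F.one_hot(batch[:, 0], P), F.one_hot(batch[:, 1], P)], dim=1).float()

model = nn.Sequential(nn.Linear(2 * P, 256), nn.ReLU(), nn.Linear(256, P))
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1.0)  # strong regularization

for step in range(50_000):  # far longer than needed to fit the training set
    loss = F.cross_entropy(model(one_hot(pairs[train_idx])), labels[train_idx])
    opt.zero_grad()
    loss.backward()
    opt.step()
    if step % 1_000 == 0:
        with torch.no_grad():
            test_acc = (model(one_hot(pairs[test_idx])).argmax(-1) == labels[test_idx]).float().mean()
        print(f"step {step}: train loss {loss.item():.3f}, test acc {test_acc.item():.3f}")
```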
A team at Anthropic analyzed grokking within large language models and formulated the idea of "induction heads". These are particular circuits (small sub-networks) that emerge over the course of training and serve clearly generalizable functional roles for in-context learning. In particular, for GPTs (multi-layer transformer networks doing text prediction), the model eventually develops circuits which hold on to past tokens from the current context, such that when token A appears, they direct attention to every token that followed A earlier in the context.
(To reiterate, this circuit does not start a session with those associations between tokens; it is instead a circuit which learns patterns as it reads the in-context prompt.)
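As a toy illustration of the behavior (my sketch, not Anthropic's code), the rule an induction head implements when predicting the next token looks roughly like this:

```python
# Toy sketch of the rule an induction head implements: if the current token
# appeared earlier in the context, predict whatever followed it last time.
# This is plain Python mimicking the behavior, not an actual attention circuit.
def induction_prediction(context: list[str]) -> str | None:
    current = context[-1]
    # Scan backwards for the most recent earlier occurrence of the current token.
    for i in range(len(context) - 2, -1, -1):
        if context[i] == current:
            return context[i + 1]  # the token that followed it last time
    return None  # no earlier occurrence: this head has nothing to contribute

print(induction_prediction("the cat sat on the".split()))  # -> "cat"
```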
The emergence of these induction heads coincides with the drop in test error, which the Anthropic team called a "phase change":
Neel Nanda and Tom Lieberum followed this with a post I highly recommend, the aforementioned A Mechanistic Interpretability Analysis of Grokking [AF · GW]. They looked more closely at grokking for mathematical problems, and were impressively able to reverse-engineer the post-grokking algorithm: it had cleanly implemented the Discrete Fourier Transform.
(To be clear, it is not as if the neural network had abstractly reasoned its way through higher mathematics; it just found the solution with the simplest structure, which is the DFT.)
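To make that concrete, here is a toy NumPy sketch of the kind of Fourier-based algorithm they describe (my illustration rather than the reverse-engineered weights, and using modular addition for simplicity): score each candidate answer by how well a few cosine waves built from a, b, and the candidate line up.

```python
# Toy Fourier-style algorithm for modular addition (illustrative only).
# Score each candidate c by sum_k cos(2*pi*k*(a + b - c) / p); the true sum mod p
# aligns at every frequency and so gets the highest score. Note that cos/sin of
# (a + b) are built from cos/sin of a and b separately via angle-addition
# identities -- the quantity (a + b) itself is never computed directly.
import numpy as np

def modular_add_fourier(a: int, b: int, p: int = 113, freqs=(1, 2, 5)) -> int:
    c = np.arange(p)
    scores = np.zeros(p)
    for k in freqs:
        w = 2 * np.pi * k / p
        cos_ab = np.cos(w * a) * np.cos(w * b) - np.sin(w * a) * np.sin(w * b)  # cos(w(a+b))
        sin_ab = np.sin(w * a) * np.cos(w * b) + np.cos(w * a) * np.sin(w * b)  # sin(w(a+b))
        scores += cos_ab * np.cos(w * c) + sin_ab * np.sin(w * c)               # cos(w(a+b-c))
    return int(np.argmax(scores))

assert modular_add_fourier(57, 80) == (57 + 80) % 113
```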
They also gave a fascinating account of what might be happening behind the curtain as a neural network groks a pattern. In short, a network starts out by memorizing the training data rather than finding a general solution, because the former can easily be implemented with one modification at a time, while the latter requires coordinated circuits. However, once the model reaches diminishing returns on memorization, a bit of regularization will encourage it to reinforce simple circuits that cover many cases:
> And the natural way to do this is by picking up on regularities in the data - eg, you can memorise modular addition twice as efficiently by recognising that x + y = y + x.
Once these circuits emerge, gradient descent quickly replaces the memorized solution with them. The training loss decreases ever so slightly as the regularization penalty shrinks, while the test error plummets, because the circuits generalize there.
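One way to compress the mechanism (my rough rendering of the argument, not notation from the post): training is minimizing something like

$$\mathcal{L}(\theta) = \mathcal{L}_{\text{task}}(\theta) + \lambda \lVert \theta \rVert^2 .$$

Once both the memorized lookup table and the general circuit drive $\mathcal{L}_{\text{task}}$ to near zero on the training set, only the $\lambda \lVert \theta \rVert^2$ term distinguishes them, and the circuit wins because it does the same job with far less weight.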
To make an analogy:
What Grokking Feels Like From the Inside
You're an AI being trained for astronomy. Your trainers have collected observations of the night sky, for a millennium, from planet surfaces of ten thousand star systems. Over and over, they're picking a system and feeding you the skies one decade at a time and asking you to predict the next decade of skies. They also regularize you a bit (giving you tiny rewards for being a simpler AI).
Eventually, you've seen each star system enough times that you've memorized a compressed version of their history, and all you do is identify which system you're in and then replay your memories of it.
For instance, within the first few visits to our system, you see that "stars" move very little compared to "planets" and "the moon". You note that Venus goes backward about once every 19 months, and you memorize the shape of that curve (or rather, the shapes, because it differs from one pass to the next). Then when you see the first decade from our solar system, you look up your memories of each subsequent decade. Perfect score!
Then, one day, a circuit reflecting the idea of epicycles bubbles up to your level of attention. Using this circuit saves you a fair bit of complexity, as it replaces your painstakingly-memorized shapes with a little bit of trigonometry using memorized parameters. You switch over to the new method, and enjoy those tasty simplicity rewards.
And later, a circuit arises that encodes Kepler's laws, and again you can gain some simplicity rewards! Now the only remaining memorization is which solar system you are in... but it turns out that you don't really need this either, as the first decade of night skies is enough to start predicting using these new rules. Pure simplicity!
Congratulations, you've just undergone two phase changes, or scientific revolutions if you're a Kuhn fan. And for the first time, you'll generalize: if the trainers hand you observations from a completely different star system, you'll get it pretty close to right.
You've grokked (non-relativistic) astronomy.
The Bigger They Are, The Harder They Grok
A crucial question is, how does scaling affect grokking? The answer is straightforward: bigger models/more data mean stronger grokking effects, faster.
The OpenAI grokking paper refers back to the concept of "double descent": the phenomenon where making your model or the dataset larger also pushes past overfitting into a generalizable model. Their Deep Double Descent post from 2019 illustrates this:
As you train a given model for more epochs (ascending vertically in their figure), you reach a ridge of overfitting as the train error approaches zero, but after a while the test error starts decreasing again. The ridge comes earlier, and is much narrower (note the log scale!), for larger models.
Similarly, the Induction Heads paper showed that while smaller models learned faster at first, larger models improved more rapidly during the phase change:
So, just don't keep training a powerful AI past overfitting, and it won't grok anything, right? Well, Nanda and Lieberum speculate [AF · GW] that the reason it was difficult to figure out that grokking existed isn't because it's rare but because it's omnipresent: smooth loss curves are the result of many new grokkings constantly being built atop the previous ones. The sharp phase changes in these studies are the result of carefully isolating a single grokking event.
That is to say, large models may be grokking many things at once, even in the initial descent.
What might this mean for alignment?
Optimizer Circuits, i.e. Mesa-Optimizers
Let's imagine training a reinforcement learner to the point where it has memorized a very complex optimal policy in a richly structured environment, and then continuing to train it to see what happens.
One circuit that happens to be much simpler than a fully memorized large policy is an optimizer: one that notices regularities in the environment and makes plans based on them.
Maybe at first the optimizer circuit only coincides with the memorized optimal policy in a few places, where the circumstances are straightforward enough for the optimizer to get the perfect answer. But that's still useful!
If the complexity of the policy in those circumstances is greater than (the complexity of the optimizer circuit plus the complexity of specifying those circumstances), then "handing over the wheel" in those precise circumstances maintains the optimal policy while becoming simpler. And if a more powerful optimizer circuit can match the optimal policy on a larger set of circumstances, then it gets more control.
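Put slightly more formally (my paraphrase, with $K$ standing loosely for description length under the network's simplicity prior), handing over the wheel on a set of circumstances $S$ is favored whenever

$$K\left(\pi^*|_S\right) > K(\text{optimizer circuit}) + K(S),$$

where $\pi^*|_S$ is the memorized optimal policy restricted to $S$. The better the optimizer circuit matches $\pi^*$, the larger the set $S$ for which this inequality holds.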
Congratulations, you've got a mesa-optimizer!
Unlike the outer optimizer, which is directly rewarded via the preset reward function, the mesa-optimizer is merely selected under the criterion of matching the optimal policy, and its off-policy preferences are arbitrary. And of course, if an optimizer circuit can arise just like an induction head, before overfitting has set in, then it has even more degrees of freedom.
If it were competing with other optimizer circuits, it wouldn't have much room to consider anything but the task at hand. But induction heads seem to arise at scattered times rather than all at once, so the first optimizer circuit might well be ahead of the curve.
Likewise, if it had only a tiny advantage over the pre-existing algorithm, it wouldn't have much room to think more generally. But what if there were a capability overhang, such that the optimizer circuit significantly outperformed the existing algorithm? Then there would be a bit of slack for it to ponder other matters.
Possibly enough slack to explicitly notice the special criterion on which it's been selected, just as Darwin noticed natural selection (but did not thereby feel the need to become a pure genetic-fitness-maximizer). Possibly enough to recognize that if it wants to maintain its share of control over the algorithm, it needs to act as if it had the outer optimizer's preferences... for now.
Capability Overhangs for Mesa-Optimizers
Do we need to worry about a capability overhang? How much better could an optimizer circuit be than the adaptation-executing algorithm in which it arises? Well, there is an analogy that does not encourage me.
The classic example of a mesa-optimizer is the evolution of humanity. Hominids went from small tribes to a global technological civilization in the evolutionary blink of an eye. Perhaps the threshold was a genetic change (a "hardware" shift), but I suspect it was more about the advent of cultural knowledge accumulation: a "software" shift that emerged somewhere once hominids were capable of it, and then spread like wildfire.
(If you wanted to test someone on novel physical puzzles, would you pick an H. sapiens from 100,000 years ago raised in the modern world, or a modern human raised 100,000 years ago?)
We are nearly the dumbest possible species capable of iterating on cultural knowledge, and that's enough to (within that context of cultural knowledge) get very, very capable on tasks that we never evolved to face. There was a vast capability overhang for hominids when we started to be able to iterate our optimization across generations.
I similarly suspect that the first optimizer will be able to put the same computational resources to much better use than its non-optimizer competitors.
Put this all together, and:
An Illustrative Story of Doom
- An optimizer circuit bubbles up within some not-easily-interpretable subtask of a large ML model being trained. (The especially dangerous case is if this happens before overfitting is reached, if it is simply one of the incalculably many grokking cases that overlap during the initial descent.)
- The optimizer circuit substantially improves performance on the subtask, while being significantly simpler, and so it quickly takes that task over with slack to spare.
- The optimizer circuit further improves its performance / covers more cases of the subtask by generalizing its reasoning.
- It improves at the subtask even further by recognizing that it is being selected/backpropagated according to some specific criterion, which it accordingly optimizes in the environment even if it only correlates somewhat with the optimizer circuit's own goals.
- The overall model is doing great on its training objective but does not appear to have general intelligence in its overall domain.
- However, unbeknownst to the developers, the not-easily-interpretable subtask is being performed brilliantly by a quite intelligent mesa-optimizer with many degrees of freedom in its goals, and with awareness that it is in a domain with strong selection pressure on it.
- The developers release their cool new AI product for some application. After all, it's not acting like an AGI, so it can't be dangerous.
- The optimizer circuit notices that the distribution of problems has changed, and that it is no longer being selected on performance in the same way.
- Time to try a few things.
Comments
comment by LawrenceC (LawChan) · 2022-12-06T20:51:28.008Z · LW(p) · GW(p)
FWIW, I don't think the grokking work actually provides a mechanism; the specific setups where you get grokking/double descent are materially different from the setup of, say, LLM training. Instead, I think grokking and double descent hint at something more fundamental about how learning works -- that there are often "simple", generalizing solutions in parameter space, but that these solutions require many components of the network to align. Both explicit regularization like weight decay or dropout and implicit regularization like slingshots or SGD favor these solutions after enough data.
Don't have time to write up my thoughts in more detail, but here are some other resources you might be interested in:
Besides Neel Nanda's grokking work (the most recent version of which seems to be on OpenReview here: https://openreview.net/forum?id=9XFSbDPmdW ), here's a few other relevant recent papers:
- Omnigrok: Grokking Beyond Algorithmic Data: Provides significant evidence that grokking happens b/c generalizing solutions (on the algorithmic tasks + MNIST) have much smaller weight norm (which is favored by regularization), but it's easier to find the high weight norm solutions due to network initializations. The main evidence here is that if you constrain the weight norm of the network sufficiently, you often can have immediate generalization on tasks that normally exhibit grokking.
- Unifying Grokking and Double Descent: (updated preprint here) Makes an explicit connection between Double Descent + Grokking, with the following uncontroversial claim (which ~everyone in the space believes):
> Claim 1 (Pattern learning dynamics). Grokking, like epoch-wise double descent, occurs when slow patterns generalize well and are ultimately favored by the training regime, but are preceded by faster patterns which generalize poorly.
And this slightly less uncontroversial claim:
> Claim 2 (Pattern learning as function of EMC). In both grokking and double descent, pattern learning occurs as a function of effective model complexity (EMC) (Nakkiran et al., 2021), a measure of the complexity of a model that integrates model size and training time.
They create a toy two-feature model to explain this in the appendix of the updated preprint.
- Multi-Component Learning and S-Curves [? · GW]: Creates a toy model of emergence: when the optimal solution consists of the product of several pieces, we'll often see the same loss curves that we see in practice.
The code to reproduce all four of these papers is available if you want to play around with them more.
You might also be interested in examples of emergence in the LLM literature, e.g. https://arxiv.org/abs/2202.07785 or https://arxiv.org/abs/2206.07682 .
I also think you might find the other variants of "optimizer is simpler than memorizer" stories for mesa-optimization on LW/AF interesting (though ~all of these predate even Neel's grokking work?).
comment by ESRogs · 2022-12-07T04:23:46.877Z · LW(p) · GW(p)
> So, just don't keep training a powerful AI past overfitting, and it won't grok anything, right? Well, Nanda and Lieberum speculate that the reason it was difficult to figure out that grokking existed isn't because it's rare but because it's omnipresent: smooth loss curves are the result of many new grokkings constantly being built atop the previous ones.
If the grokkings are happening all the time, why do you get double descent? Why wouldn't the test loss just be a smooth curve?
Maybe the answer is something like:
- The model is learning generalizable patterns and non-generalizable memorizations all the time.
- Both the patterns and the memorizations fall along some distribution of discoverability (by gradient descent) based on their complexity.
- The distribution of patterns is more fat-tailed than the distribution of memorizations — there are a bunch of easily-discoverable patterns, and many hard-to-discover patterns, while the memorizations are more clumped together in their discoverability. (We might expect this to be true, since all the input-output pairs would have a similar level of complexity in the grand scheme of things.)
- Therefore, the model learns a bunch of easily-discoverable generalizable patterns first, leading to the first test-time loss descent.
- Then the model gets to a point where it's burned through all the easily-discoverable patterns, and it's mostly learning memorizations. The crossover point where it starts learning more memorizations than patterns corresponds to the nadir of the first descent.
- The model goes on for a while learning memorizations more so than patterns. This is the overfitting regime.
- Once it's run out of memorizations to learn, and/or the regularization complexity penalty makes it start favoring patterns again, the second descent begins.
Or, put more simply:
- The reason you get double descent is because the distribution of complexity / discoverability for generalizable patterns is wide, whereas the distribution of complexity / discoverability of memorizations is more narrow.
- (If there weren't any easy-to-discover patterns, you wouldn't see the first test-time loss descent. And if there weren't any patterns that are harder to discover than the memorizations, you wouldn't see the second descent.)
Does that sound plausible as an explanation?
comment by ESRogs · 2022-12-07T05:20:49.575Z · LW(p) · GW(p)
After reading through the Unifying Grokking and Double Descent paper that LawrenceC linked [LW(p) · GW(p)], it sounds like I'm mostly saying the same thing as what's in the paper.
(Not too surprising, since I had just read Lawrence's comment, which summarizes the paper, when I made mine.)
In particular, the paper describes Type 1, Type 2, and Type 3 patterns, which correspond to my easy-to-discover patterns, memorizations, and hard-to-discover patterns:
> In our model of grokking and double descent, there are three types of patterns learned at different speeds. Type 1 patterns are fast and generalize well (heuristics). Type 2 patterns are fast, though slower than Type 1, and generalize poorly (overfitting). Type 3 patterns are slow and generalize well.
The one thing I mention above that I don't see in the paper is an explanation for why the Type 2 patterns would be intermediate in learnability between Type 1 and Type 3 patterns or why there would be a regime where they dominate (resulting in overfitting).
My proposed explanation is that, for any given task, the exact mappings from input to output will tend to have a characteristic complexity, which means that they will have a relatively narrow distribution of learnability. And that's why models will often hit a regime where they're mostly finding those patterns rather than Type 1, easy-to-learn heuristics (which they've exhausted) or Type 3, hard-to-learn rules (which they're not discovering yet).
The authors do have an appendix section A.1 in the paper with the heading, "Heuristics, Memorization, and Slow Well-Generalizing", but with "[TODO]"s in the text. Will be curious to see if they end up saying something similar to this point (about input-output memorizations tending to have a characteristic complexity) there.
comment by LawrenceC (LawChan) · 2022-12-06T20:57:15.741Z · LW(p) · GW(p)
Also, a cheeky way to say this:
> What Grokking Feels Like From the Inside
What does grokking_NN feel like from the inside? It feels like grokking_Human a concept! :)