200 COP in MI: Analysing Training Dynamics

post by Neel Nanda (neel-nanda-1) · 2023-01-04T16:08:58.089Z · LW · GW · 0 comments

Contents

  Background
  Motivation
  Resources
  Tips
  Problems

This is the sixth post in a sequence called 200 Concrete Open Problems in Mechanistic Interpretability. Start here [AF · GW], then read in any order. If you want to learn the basics before you think about open problems, check out my post on getting started. Look up jargon in my Mechanistic Interpretability Explainer.

Disclaimer: Mechanistic Interpretability is a small and young field, and I was involved with much of the research and resources linked here. Please take this sequence as a bunch of my personal takes, and try to seek out other researchers' opinions too!

Motivating papers: A Mechanistic Interpretability Analysis of Grokking [AF · GW], In-Context Learning and Induction Heads

Background

Skip to Motivation if you're familiar with the grokking modular addition and induction heads papers.

Several mechanistic interpretability papers have helped find surprising and interesting things about the training dynamics of networks - understanding what is actually going on in a model during training. There are a lot of things I'm confused about here! How do models change over training? Why do they generalise at all (and how do they generalise)? How much of their path to the eventual solution is consistent and directed across runs, versus a random walk? How and when do the different circuits in the model develop, and how do these influence the development of subsequent circuits?

The questions here are of broad interest to anyone who wants to understand what the hell is going on inside neural networks, and there are angles of attack on this that have nothing to do with Mechanistic Interpretability. But I think that MI is a particularly promising approach. If you buy the fundamental claim that the building blocks of neural networks are interpretable circuits that work together to complete a task, then studying these building blocks seems like a grounded and principled approach. Neural networks are extremely complicated and confusing, and it's very easy to mislead yourself. Having a good starting point of a network you actually understand makes it far easier to make real progress.

I am particularly interested in understanding phase transitions (aka emergent phenomena), where a specific capability suddenly emerges in models at a specific point in training. Two MI papers that get significant traction here are A Mechanistic Interpretability Analysis of Grokking [AF · GW], and In-Context Learning and Induction Heads, and it’s worth exploring what they did as models of mechanistic approaches to studying training dynamics:

In In-Context Learning and Induction Heads, the focus was on finding phase changes in real language models. We found induction heads/circuits in two-layer, attention-only language models: a circuit which detects whether the current text was repeated earlier in the prompt, and if so continues the repetition. These turned out to be a crucial circuit in these tiny models - the heads were principally responsible for the model's ability to do in-context learning, productively using tokens far back in the context (eg over 500 words back!) to predict the next token. Further, the induction heads all arose in a sudden phase transition, and were so important that this led to a visible bump in the loss curve!

The focus of the paper was on extrapolating this mechanistic understanding from a simple setting to real models. We devised a range of progress measures (both for overall model behaviour and for studying specific heads), and analysed a wide range of models using these. And we found fairly compelling evidence that these results held up in general, up to 13B parameter models!
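
To make the idea of a behavioural progress measure concrete, here is a minimal sketch (my own illustrative code, not the paper's implementation) of an in-context learning score in the spirit of the induction heads paper: how much better the model predicts tokens late in the context than early in the context. The tensor `per_token_loss` is a hypothetical input you would compute from your own model.

```python
import torch

def in_context_learning_score(per_token_loss: torch.Tensor,
                              early_pos: int = 50,
                              late_pos: int = 500) -> float:
    """Behavioural progress measure for in-context learning: the average loss at a
    late context position minus the average loss at an early one. More negative
    means the model is making better use of far-back context.
    per_token_loss: shape [batch, seq_len], cross-entropy loss at each position."""
    early = per_token_loss[:, early_pos].mean()
    late = per_token_loss[:, late_pos].mean()
    return (late - early).item()

# Example with dummy losses (in practice, compute these from model logits):
dummy = torch.rand(8, 1024)
print(in_context_learning_score(dummy))
```

Tracked over training checkpoints, a sharp drop in a score like this is the kind of signal that lines up with the induction head phase transition.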

Grokking is a surprising phenomenon where certain small models trained on algorithmic tasks initially memorise their training data, but, when trained on the same data for a really long time, suddenly generalise. In our grokking work, we focused on a 1 layer transformer that grokked modular addition. We did a deep dive into its internals and reverse engineered it, discovering that it had learned a trig-identity-based algorithm.
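
To make "trig-identity-based algorithm" concrete, here is a minimal numpy sketch of the algorithm we claim the network implements (the specific frequencies below are arbitrary illustrative choices, not the ones a trained model happens to pick): represent a and b as waves at a few frequencies, combine them with the angle-addition identities, and read off the answer as the c that best matches the combined wave.

```python
import numpy as np

def modular_add_via_trig(a: int, b: int, p: int = 113, freqs=(1, 2, 3)) -> int:
    """Compute (a + b) % p the way the grokked transformer (roughly) does:
    via cos/sin of w*(a+b), built from a and b separately with trig identities."""
    c = np.arange(p)
    logits = np.zeros(p)
    for k in freqs:
        w = 2 * np.pi * k / p
        # Angle-addition identities: get cos/sin of w*(a+b) from cos/sin of w*a and w*b
        cos_ab = np.cos(w * a) * np.cos(w * b) - np.sin(w * a) * np.sin(w * b)
        sin_ab = np.sin(w * a) * np.cos(w * b) + np.cos(w * a) * np.sin(w * b)
        # cos(w*(a + b - c)) is maximised exactly when c == (a + b) % p
        logits += cos_ab * np.cos(w * c) + sin_ab * np.sin(w * c)
    return int(np.argmax(logits))

assert modular_add_via_trig(57, 98) == (57 + 98) % 113
```

The point is that addition of inputs becomes rotation of waves, which the network can implement with its embeddings and nonlinearities without ever representing the integers directly.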

We then distilled this understanding of the final checkpoint into a series of progress measures. We studied excluded loss, where we removed the model's ability to use the generalising solution and looked at train performance, and restricted loss, where we artificially removed the memorising solution, only allowed the model to use the generalising solution, and looked at overall performance. These let us decompose the training dynamics into 3 discrete phases: memorisation, where the model just memorises the training data; circuit formation, where the model slowly learns a generalising circuit and transitions from the memorising solution to the generalising solution, preserving good train performance and poor test performance throughout; and cleanup, where it gets so good at generalising that it no longer needs the parameters spent on memorisation and can clean them up. The sudden emergence of "grokking" only occurs at cleanup, because even when the model is mostly capable of generalising, the "noise" of the memorising solution is sufficient to prevent good test loss.
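
For intuition on how such progress measures could be computed, here is a simplified sketch (my own illustration with hypothetical inputs, not the paper's exact implementation): the generalising circuit shows up as components of the logits along the directions cos(w_k(a+b)) and sin(w_k(a+b)) over all input pairs (a, b), for a handful of key frequencies k, so we can project the logits onto or away from those directions.

```python
import numpy as np

def project_key_freqs(logits: np.ndarray, p: int, key_freqs, mode: str) -> np.ndarray:
    """logits: shape [p * p, p], one row per input pair (a, b) in lexicographic
    order, one column per candidate answer c.
    mode="exclude":  remove the key-frequency components (ablating the
                     generalising solution; train loss on the result plays the
                     role of the excluded loss).
    mode="restrict": keep only the key-frequency components (ablating the
                     memorising solution; loss on the result plays the role of
                     the restricted loss)."""
    a = np.repeat(np.arange(p), p)
    b = np.tile(np.arange(p), p)
    dirs = []
    for k in key_freqs:
        w = 2 * np.pi * k / p
        for vec in (np.cos(w * (a + b)), np.sin(w * (a + b))):
            dirs.append(vec / np.linalg.norm(vec))
    B = np.stack(dirs)                    # orthonormal rows, shape [2 * len(key_freqs), p * p]
    key_component = B.T @ (B @ logits)    # projection of the logits onto the key directions
    return logits - key_component if mode == "exclude" else key_component
```

You would then feed the modified logits into the usual cross-entropy loss (on the training pairs for excluded loss, on all pairs for restricted loss) and track both across training checkpoints.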

Both works used a mechanistic understanding of a simpler setting to try to understand a confusing phenomenon, but with somewhat different focuses. The approach modelled in grokking is one at the intersection of MI and the science of deep learning. We took a specific confusing phenomenon in neural networks, simplified it to a simple yet still confusing example, reverse-engineered exactly what was going on, extrapolated it to the core principles (of progress measures), and analysed these during training. By deeply focusing on understanding a specific example of a specific phenomenon, we (hopefully!) gained some general insights about questions around memorisation, generalisation, phase transitions, and competition between different circuits (though there's a lot of room for further work verifying this, and I'm sure some details are wrong!). (Check out the Interpreting Algorithmic Problems post [AF · GW] for another angle on this work.) In contrast, the induction heads work has a bigger focus on showing the universality of the result across many models, studying real models in the wild, and being able to tie the zoomed-in understanding of a specific model circuit to the important zoomed-out phenomenon of in-context learning.

Motivation

Zooming out, what are the implications of all this for future research? I think there's exciting work to be done, both on using toy models to find general principles of how neural networks learn and on understanding the training dynamics in real language models. A lot of the current low-hanging fruit lies in fairly simple settings and in building better fundamental understanding. But a good understanding of training dynamics and how it intersects with interpretability should eventually push on some important long-term questions:

Work on deep learning mysteries (in contrast to work directly on real language models) doesn't directly push on the above, but I think that being less confused about deep learning will help a lot. And even if the above fail, I expect that exploring these questions will naturally open up a lot of other open questions and curiosities to pull on.

I'm particularly excited about this work, because I think that trying to decipher mysteries in deep learning makes mechanistic interpretability a healthier field. We're making some pretty bold claims that our approach will let us really reverse-engineer and fully understand systems. If this is true, we should be able to get much more traction on demystifying deep learning, because we'll have the right foundations to build on. One of the biggest ways the MI project can fail is by ceasing to be fully rigorous and losing sight of what's actually true for models. If we can learn true insights about deep learning or real language models, this both serves as a great proof of concept of the promise of MI and grounds the field in doing work that is actually useful.

To caveat the above, understanding what on earth is going on inside networks as they train highly appeals to my scientific curiosity, but I don't think it's necessary for the goal of mechanistically understanding a fully trained system. My personal priority is to understand specific AI systems on the frontier, especially at human level and beyond. This is already a very ambitious goal, and I don't think that understanding training dynamics is obviously on the critical path for this (though it might be! Eg, if models are full of irrelevant "vestigial organs" learned early in training). And there's a risk that we find surface-level insights about models that allow us to make better models faster, while still being highly confused about how they actually work, leaving us on net even more behind in understanding frontier models. And models are already a confusingly high dimensional object without adding a time dimension in! But I could easily be wrong, and there are a lot of important questions here (just far more important questions than researchers!). And as noted above, I think this kind of work can be very healthy for the field. So I'd still encourage you to explore these questions if you feel excited about them!

Resources

Tips

Problems

This spreadsheet lists each problem in the sequence. If you're working on any of them and want collaborators, you can write down your contact details, see any existing work, or reach out to other people on there! (Thanks to Jay Bailey for making it.)

Note: Many of these questions are outside my area of expertise, and overlap with broader science of deep learning efforts. I expect there’s a lot of relevant work I’m not familiar with, and I would not be massively surprised if some of these problems have already been solved!
