Scaling and evaluating sparse autoencoders

post by leogao · 2024-06-06T22:50:39.440Z · 6 comments

[Blog] [Paper] [Visualizer]

Abstract:

Sparse autoencoders provide a promising unsupervised approach for extracting interpretable features from a language model by reconstructing activations from a sparse bottleneck layer. Since language models learn many concepts, autoencoders need to be very large to recover all relevant features. However, studying the properties of autoencoder scaling is difficult due to the need to balance reconstruction and sparsity objectives and the presence of dead latents. We propose using k-sparse autoencoders [Makhzani and Frey, 2013] to directly control sparsity, simplifying tuning and improving the reconstruction-sparsity frontier. Additionally, we find modifications that result in few dead latents, even at the largest scales we tried. Using these techniques, we find clean scaling laws with respect to autoencoder size and sparsity. We also introduce several new metrics for evaluating feature quality based on the recovery of hypothesized features, the explainability of activation patterns, and the sparsity of downstream effects. These metrics all generally improve with autoencoder size. To demonstrate the scalability of our approach, we train a 16-million-latent autoencoder on GPT-4 activations for 40 billion tokens. We release code and autoencoders for open-source models, as well as a visualizer.
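For readers who want the mechanics, here is a minimal NumPy sketch of a k-sparse (TopK) autoencoder forward pass. It keeps the k largest encoder pre-activations per example, zeroes the rest, and applies a ReLU to the kept values (the ReLU is mentioned in one of the replies below). The parameter names (`W_enc`, `W_dec`, `b_pre`) and the choice to subtract a shared bias before encoding and add it back after decoding are assumptions for illustration, not a description of the released code.

```python
import numpy as np


def topk_sae(x, W_enc, W_dec, b_pre, k):
    """Forward pass of a k-sparse (TopK) autoencoder (illustrative sketch).

    x:     (batch, d_model) language-model activations
    W_enc: (d_model, n_latents) encoder weights
    W_dec: (n_latents, d_model) decoder weights
    b_pre: (d_model,) bias subtracted before encoding, added back after decoding
    k:     number of latents kept per example (1 <= k <= n_latents)
    Returns the sparse code z and the reconstruction x_hat.
    """
    pre = (x - b_pre) @ W_enc                          # (batch, n_latents)
    # Column indices of the k largest pre-activations in each row (unordered).
    idx = np.argpartition(pre, -k, axis=1)[:, -k:]
    rows = np.arange(pre.shape[0])[:, None]
    z = np.zeros_like(pre)
    # Keep only those k values; the ReLU guarantees no kept latent is negative.
    z[rows, idx] = np.maximum(pre[rows, idx], 0.0)
    x_hat = z @ W_dec + b_pre                          # (batch, d_model)
    return z, x_hat
```

Because at most k latents are nonzero per example by construction, sparsity is set directly rather than through a penalty coefficient, which is the tuning simplification the abstract refers to.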

comment by Charlie Steiner · 2024-06-09T04:48:48.694Z

Nice! There's definitely been this feeling with training SAEs that activation penalty + reconstruction loss is "not actually asking the computer for what we want," leading to fragility. TopK seems like a step closer to the ideal. Did you subjectively feel confident when starting off large training runs?
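To make the contrast concrete, here is a hedged NumPy sketch of the two objectives. With a ReLU encoder plus an L1 penalty, sparsity is controlled only indirectly through a coefficient (`l1_coeff`, a hypothetical name) that has to be balanced against reconstruction; with TopK, the number of active latents is capped at k and the loss is reconstruction error alone. Function names, shapes, and the per-example squared-error convention are illustrative assumptions.

```python
import numpy as np


def relu_l1_loss(x, W_enc, b_enc, W_dec, b_dec, l1_coeff):
    """Baseline: ReLU encoder, reconstruction error plus an L1 activation penalty.

    Sparsity is only encouraged indirectly via l1_coeff, which trades off
    against reconstruction quality and has to be tuned.
    """
    z = np.maximum(x @ W_enc + b_enc, 0.0)
    x_hat = z @ W_dec + b_dec
    mse = np.mean(np.sum((x - x_hat) ** 2, axis=1))   # per-example squared error
    l1 = np.mean(np.sum(np.abs(z), axis=1))           # per-example L1 of the code
    return mse + l1_coeff * l1, z


def topk_loss(x, W_enc, b_enc, W_dec, b_dec, k):
    """TopK: at most k latents are active per example, so the loss is
    reconstruction error alone and there is no penalty coefficient to balance."""
    pre = x @ W_enc + b_enc
    idx = np.argpartition(pre, -k, axis=1)[:, -k:]    # k largest per row
    rows = np.arange(pre.shape[0])[:, None]
    z = np.zeros_like(pre)
    z[rows, idx] = np.maximum(pre[rows, idx], 0.0)
    x_hat = z @ W_dec + b_dec
    mse = np.mean(np.sum((x - x_hat) ** 2, axis=1))
    return mse, z
```

The design difference is where the sparsity knob lives: in the baseline it is a loss weight whose effect on the number of active latents has to be discovered empirically, while in TopK it is the hyperparameter k itself.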

Confused about section 5.3.1:

To mitigate this issue, we sum multiple TopK losses with different values of k (Multi-TopK). For example, using L(k) + L(4k)/8 is enough to obtain a progressive code over all k′ (note however that training with Multi-TopK does slightly worse than TopK at k). Training with the baseline ReLU only gives a progressive code up to a value that corresponds to using all positive latents.
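A minimal NumPy sketch of the Multi-TopK objective quoted above, L(k) + L(4k)/8, assuming L is a mean squared reconstruction error and that both terms share one encoder pass and one decoder. The helper names are hypothetical and this is not the paper's training code.

```python
import numpy as np


def topk_decode(pre, W_dec, b_dec, k):
    """Keep the k largest pre-activations per row (ReLU'd) and decode them."""
    idx = np.argpartition(pre, -k, axis=1)[:, -k:]
    rows = np.arange(pre.shape[0])[:, None]
    z = np.zeros_like(pre)
    z[rows, idx] = np.maximum(pre[rows, idx], 0.0)
    return z @ W_dec + b_dec


def multi_topk_loss(x, W_enc, b_enc, W_dec, b_dec, k):
    """Multi-TopK objective L(k) + L(4k)/8, with L taken to be mean squared error."""
    if 4 * k > W_enc.shape[1]:
        raise ValueError("4k must not exceed the number of latents")
    pre = x @ W_enc + b_enc                 # one encoder pass shared by both terms

    def recon_loss(kk):
        return np.mean((x - topk_decode(pre, W_dec, b_dec, kk)) ** 2)

    return recon_loss(k) + recon_loss(4 * k) / 8.0
```

Intuitively, the down-weighted 4k term also asks the latents ranked just below the top k to help with reconstruction, which is what would let one autoencoder be evaluated at larger k′; per the quoted passage, this costs a little reconstruction quality at k itself.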

Why would we want a progressive code over all hidden activations? If features have different meanings when they're positive versus when they're negative (imagining a sort of Toy Model of Superposition picture where features are a bunch of rays squeezed in around a central point), it seems like if your negative hidden activations are informative, something weird is going on.

comment by leogao · 2024-06-09T06:00:52.447Z

We had done very extensive ablations at small scale where we found TopK to be consistently better than all of the alternatives we iterated through, and by the time we launched the big run we had already worked out how to scale all of the relevant hyperparameters, so we were decently confident.

One reason we might want a progressive code is that it would basically let you train one autoencoder and use it for any k you want at test time (which is nice because we don't yet know exactly how to set k for maximum interpretability). Unfortunately, this is somewhat worse than training for the specific k you want to use, so our recommendation for now is to train multiple autoencoders.
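The test-time use described above can be sketched as follows: take one trained autoencoder, reconstruct with several values of k′, and check whether the error keeps falling as k′ grows. Treating "progressive" as "mean squared error is non-increasing in k′" is a simplifying assumption for this sketch, and all function names are hypothetical.

```python
import numpy as np


def topk_reconstruct(x, W_enc, b_enc, W_dec, b_dec, k):
    """Reconstruct x using only the k largest (ReLU'd) latents per example."""
    pre = x @ W_enc + b_enc
    idx = np.argpartition(pre, -k, axis=1)[:, -k:]
    rows = np.arange(pre.shape[0])[:, None]
    z = np.zeros_like(pre)
    z[rows, idx] = np.maximum(pre[rows, idx], 0.0)
    return z @ W_dec + b_dec


def sweep_test_time_k(x, W_enc, b_enc, W_dec, b_dec, ks):
    """Mean squared reconstruction error of one autoencoder at each test-time k'."""
    return [
        float(np.mean((x - topk_reconstruct(x, W_enc, b_enc, W_dec, b_dec, k)) ** 2))
        for k in ks
    ]


def is_progressive(errors, tol=1e-9):
    """True if error never increases as k' grows (errors listed in increasing k')."""
    return all(later <= earlier + tol for earlier, later in zip(errors, errors[1:]))
```

For example, running `sweep_test_time_k` on held-out activations with `ks = [k, 2 * k, 4 * k]` and passing the result to `is_progressive` gives a quick check. Per the reply, an autoencoder trained for the specific k you use should still be expected to reconstruct somewhat better at that k.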

Also, even with a progressive code, the activations on the margin would not generally be negative (we actually apply a ReLU to make sure that the activations are definitely non-negative, but almost always the (k+1)th value is still substantially positive).
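One way to check the claim about the margin on your own autoencoder is to look at the (k+1)th largest pre-activation, which is the first value TopK drops. This sketch (a hypothetical helper, NumPy assumed, operating on a batch of encoder pre-activations) reports that value alongside the kth value and the fraction of examples where the dropped value is still positive.

```python
import numpy as np


def margin_stats(pre, k):
    """Inspect the TopK boundary for a (batch, n_latents) array of pre-activations.

    Returns the kth largest value per row (smallest value TopK keeps), the
    (k+1)th largest value per row (largest value TopK drops), and the fraction
    of rows where that dropped value is still positive.
    """
    if not 1 <= k < pre.shape[1]:
        raise ValueError("need 1 <= k < n_latents so a (k+1)th value exists")
    s = -np.sort(-pre, axis=1)           # each row sorted in descending order
    kth = s[:, k - 1]
    margin = s[:, k]
    frac_positive = float(np.mean(margin > 0))
    return kth, margin, frac_positive
```

A `frac_positive` close to 1 on real activations would match the observation that the (k+1)th value is almost always still positive.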

comment by Daniel Kokotajlo (daniel-kokotajlo) · 2024-06-08T10:59:36.115Z

Well done and thank you! I don't feel qualified to judge exactly, but this seems like a significant step forward. Curious to hear your thoughts on the question of "by what year will [insert milestone X] be achieved, assuming research progress continues on-trend?" Some milestones are perhaps in this tech tree https://www.lesswrong.com/posts/nbq2bWLcYmSGup9aF/a-transparency-and-interpretability-tech-tree, but the one I'm most interested in is the "we have tools which can tell whether a model is scheming or otherwise egregiously misaligned, though if we trained against those tools they'd stop working" milestone.

comment by leogao · 2024-06-08T11:53:07.969Z

Thanks for your kind words!

My views on interpretability are complicated by the fact that I think it's quite probable there will be a paradigm shift between current AI and the thing that is actually AGI, like 10 years from now or whatever. So I'll first describe a rough sketch of what I think within-paradigm interp looks like, and then what it might imply for AGI 10 years later. (All these numbers will be very low confidence and basically made up.)

I think the autoencoder research agenda is currently making significant progress on item #1. The main research bottlenecks here are (a) SAEs might not be able to efficiently capture every kind of information we care about (e.g. circular features) and (b) residual-stream autoencoders are not exactly the right thing for finding circuits. Probably this stuff will take a year or two to really hammer out. Hopefully our paper helps here by giving a recipe to scale autoencoders up really quickly, so we bump into the limitations faster and with less second-guessing about autoencoder quality.

Hopefully #4 can be done largely in parallel with #1; there's a whole bunch of engineering needed to e.g. take autoencoders and scale them up to capture all the behavior of the model (which was also a big part of the contribution of this paper). I'm pretty optimistic that if we have a recipe for #1 that we trust, the engineering (and efficiency improvements) for scaling up is doable. Maybe this adds another year of serial time. The big research uncertainty here, from my point of view, is how hard it is to actually identify the structures we're looking for, because we'll probably have a tremendously large sparse network where each node does some really boring tiny thing.

However, I mostly expect that GPT-4 (and probably GPT-5) is just not doing anything super spicy/stabby. So I think most of the value of doing this interpretability will be to pull back the veil, so to speak, on how these models are doing all the impressive stuff. Some theories of impact:

  • Maybe we'll become less confused about the nature of intelligence in a way that makes us just have better takes about alignment (e.g. there will be many mechanistic theories of what the heck GPT-4 is doing that will have been conclusively ruled out)
  • Maybe once the paradigm shift happens, we will be better prepared to identify exactly what interpretability assumptions it broke (or even just notice whether some change is causing a mechanistic paradigm shift)

Unclear what timeline these later things happen on; probably depends a lot on when the paradigm shift(s) happen.

comment by leogao · 2024-06-08T21:10:11.019Z

To add some more concreteness: suppose we open up the model and find that it's basically just a giant k-nearest-neighbors lookup (it obviously can't be literally this, but this is easiest to describe as an analogy). Then this would explain why current alignment techniques work and would dissolve some of the mystery of generalization. Now suppose we create AGI and find that it does something very different internally that is more deeply entangled, and we can't really make sense of it because it's too complicated. Then this would imo also provide strong evidence that we should expect our alignment techniques to break.

In other words, a load-bearing assumption is that current models are fundamentally simple/modular in some sense that makes interpretability feasible, and that observing this breaking in the future is probably important evidence that will hopefully come before those future systems actually kill everyone.

comment by Review Bot · 2024-06-10T09:48:57.600Z

The LessWrong Review runs every year to select the posts that have most stood the test of time. This post is not yet eligible for review, but will be at the end of 2025. The top fifty or so posts are featured prominently on the site throughout the year.

Hopefully, the review is better than karma at judging enduring value. If we have accurate prediction markets on the review results, maybe we can have better incentives on LessWrong today. Will this post make the top fifty?