Open Source Replication & Commentary on Anthropic's Dictionary Learning Paper

post by Neel Nanda (neel-nanda-1) · 2023-10-23T22:38:33.951Z · 12 comments

Contents

  Introduction
  TLDR
  Features
    Exploring Neuron Sparsity
    Case Studies
  Ultra Low Frequency Features Are All The Same Feature
  Implementation Details
  Misc Questions
12 comments

This is the long-form version of a public comment on Anthropic's Towards Monosemanticity paper.

Introduction

Anthropic recently put out a really cool paper about using Sparse Autoencoders (SAEs) to extract interpretable features from superposition in the MLP layer of a 1L language model. I think this is an awesome paper, and I am excited about the potential of SAEs to be useful for mech interp more broadly, and to see how well they scale! This post documents my replication of their paper, some small explorations building on it, and (scrappy!) training code and weights for some trained autoencoders.

See the accompanying Colab tutorial for how to load the trained autoencoders, and for a demo of how to interpret what the features do.

TLDR

Features

Exploring Neuron Sparsity

One of the things I find most interesting about this paper is the existence of features that are not aligned with the neuron basis. One of the big mysteries (by my lights) in mechanistic interpretability is what the non-linearities in MLP layers are actually doing, on an algorithmic level. I can reason about monosemantic GELU neurons (like a French neuron) fairly easily: I think of each one as a soft ReLU that collects pieces of evidence for the presence of a feature, and fires if the total crosses a certain threshold (set by the bias). This can maybe extend to a sparse linear combination of neurons (eg fewer than 10 neurons constructively interfering to create a single feature). But I have no idea how to reason about features that are dense-ish in the neuron basis!

As a first step towards exploring this, I looked into how dense vs sparse each non-ultra-low frequency feature is in the neuron basis. I expected many to be fairly neuron-sparse, both because I expected interpretable features to tend to be neuron-sparse in general, and because of the autoencoder setup: on any given input a sparse set of neurons fire, so learning specific neurons or clusters of neurons seems like a useful dictionary feature.

However, a significant majority of features seem to be fairly dense in the neuron basis! As a crude metric, for each dictionary feature I took the sum of squared decoder weights, and looked at the fraction of it explained by the top neuron (fve_top_1, to look for 1-sparse features) and by the next 9 neurons (fve_next_9, to look for 10-ish-sparse features). I use the decoder rather than the encoder, as I expect the decoder is closer to the “true” feature direction (ie how it’s computed), while the encoder is the subtly different “optimal projection to detect the feature”, which must account for interference. In the scatter plot below we can see 2-3 rough clusters: dense features near the origin (low on both metrics, which I define as fve_top_10 = fve_top_1 + fve_next_9 < 0.35), 1-sparse features (fve_top_1 > 0.35, fve_next_9 < 0.1), and 10-ish-sparse features (the diffuse mess of everything else). I find ~92.1% of features are dense, ~3.9% are 1-sparse and ~4.0% are 10-ish-sparse.
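A minimal sketch of this crude metric in plain Python, assuming each feature's decoder vector is available as a list of per-neuron weights. The function names are hypothetical; the thresholds are the ones quoted above, and the three toy vectors just exercise each bucket.

```python
def fve_metrics(decoder_row):
    """Return (fve_top_1, fve_next_9) for one feature's decoder vector.

    decoder_row holds the feature's decoder weight on each MLP neuron.
    """
    sq = sorted((w * w for w in decoder_row), reverse=True)
    total = sum(sq)
    fve_top_1 = sq[0] / total           # share of squared norm on the top neuron
    fve_next_9 = sum(sq[1:10]) / total  # share on the neurons ranked 2 to 10
    return fve_top_1, fve_next_9


def classify_feature(decoder_row):
    """Bucket a feature using the thresholds quoted in the post."""
    top1, next9 = fve_metrics(decoder_row)
    if top1 + next9 < 0.35:  # fve_top_10 < 0.35
        return "dense"
    if top1 > 0.35 and next9 < 0.1:
        return "1-sparse"
    return "10-ish-sparse"


d_mlp = 2048
one_hot = [1.0] + [0.0] * (d_mlp - 1)          # all weight on one neuron
ten_hot = [1.0] * 10 + [0.0] * (d_mlp - 10)    # spread evenly over 10 neurons
uniform = [1.0] * d_mlp                        # spread over every neuron
print(classify_feature(one_hot), classify_feature(ten_hot), classify_feature(uniform))
```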

Another interesting metric is neuron kurtosis: take the decoder weights for each feature (in the MLP neuron basis) and compute their kurtosis (a metric inspired by this paper, and by ongoing work with Wes Gurnee). This measures how “privileged” the neuron basis is, and is another way to detect unusually neuron-sparse features: the excess kurtosis of a normal distribution is zero, and applying a random rotation makes the coordinates approximately normally distributed. We can clearly see in the figure below that the neuron basis is privileged for almost all features, even though most features don’t seem neuron-sparse (red is the real features, blue is the kurtosis of a randomly generated, normally distributed baseline. Note that the red tail goes up to 1700 but is clipped for visibility).
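A rough sketch of the kurtosis comparison, assuming excess kurtosis is computed from population moments over a feature's decoder weights (other estimators differ slightly for small samples). The Gaussian vector stands in for the randomly generated baseline; the one-hot vector is the most neuron-aligned case possible, and its excess kurtosis works out to about d_mlp - 5 (roughly 2043 here), which is why the tail of real features can reach the thousands.

```python
import random
import statistics


def excess_kurtosis(xs):
    """Excess kurtosis from population moments: m4 / m2**2 - 3 (0 for a Gaussian)."""
    mean = statistics.fmean(xs)
    m2 = statistics.fmean([(x - mean) ** 2 for x in xs])
    m4 = statistics.fmean([(x - mean) ** 4 for x in xs])
    return m4 / (m2 * m2) - 3.0


d_mlp = 2048
rng = random.Random(0)
# Baseline: a decoder vector with i.i.d. Gaussian entries (no privileged basis).
gaussian_baseline = [rng.gauss(0.0, 1.0) for _ in range(d_mlp)]
# Extreme neuron-aligned case: all weight on a single neuron.
one_hot = [1.0] + [0.0] * (d_mlp - 1)
print(round(excess_kurtosis(gaussian_baseline), 2), round(excess_kurtosis(one_hot)))
```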

 

I haven’t studied whether there’s a link between level of neuron sparsity and interpretability, and I can’t rule out the hypothesis that autoencoders learn features with significant noise and that the “true” feature is far sparser. But this seems like evidence that being dense in the neuron basis is the norm, not the exception.

Case Studies

Anecdotally, randomly chosen features were often interpretable, including some fairly neuron-dense ones. I didn’t study this rigorously, but of the first 8 non-ultra-low features, 6 seemed interpretable. From this sample, I found:

Ultra Low Frequency Features Are All The Same Feature

Existence of the ultra-low frequency cluster: I was able to replicate the paper’s finding of an ultra-low frequency cluster, but didn’t find any truly dead neurons. I define the ultra-low frequency cluster as any feature with frequency less than 1e-4; the frequency distribution is clearly bimodal, with about 60% of features ultra-low frequency and 40% normal. This contrasts with 4% dead and 7% ultra-low in the paper's A/1 autoencoder. I’m not sure why there’s a discrepancy, though it may be because I resampled neurons by re-initialising their weights rather than using the more complex scheme in the paper. Anecdotally, the ultra-low frequency features were not interpretable, and they are not very important for autoencoder performance (reconstruction loss goes from 92% to 91%).
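A sketch of the frequency split, assuming activations are stored per token as lists and that a feature "fires" when its post-ReLU activation is positive. The helper names are hypothetical. Resolving a 1e-4 threshold reliably needs far more than 1e4 tokens, so the toy example below scales the threshold up instead.

```python
def feature_frequencies(acts):
    """acts[t][i] is feature i's activation on token t; return each feature's firing rate."""
    n_tokens = len(acts)
    counts = [0] * len(acts[0])
    for row in acts:
        for i, a in enumerate(row):
            if a > 0:  # a ReLU SAE feature "fires" when its activation is positive
                counts[i] += 1
    return [c / n_tokens for c in counts]


def split_by_frequency(freqs, threshold=1e-4):
    """Indices of dead (never fire), ultra-low (0 < f < threshold) and normal features."""
    dead = [i for i, f in enumerate(freqs) if f == 0]
    ultra_low = [i for i, f in enumerate(freqs) if 0 < f < threshold]
    normal = [i for i, f in enumerate(freqs) if f >= threshold]
    return dead, ultra_low, normal


# Toy data: 10 tokens, 3 features, threshold scaled up to 0.5 to suit so few tokens.
acts = [[0.7, 0.0, 0.0] for _ in range(10)]
acts[3][1] = 0.2  # feature 1 fires on a single token; feature 2 never fires
freqs = feature_frequencies(acts)
print(freqs, split_by_frequency(freqs, threshold=0.5))
```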

The encoder weights are all the same direction: Bizarrely, the encoder weights of the ultra-low frequency features all point in nearly the same direction. They have extremely high cosine sim with their mean (97.5% have cosine sim above 0.95).

They do vary in magnitude, as can be seen from their projection onto the mean direction (I’m not sure why this is bimodal). (Note that here I’m only looking at |encoder|; really |encoder| * |decoder| is the meaningful quantity, but it has a similar distribution.)
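A minimal sketch of both measurements (cosine sim with the mean, and projection onto the unit mean direction), assuming the ultra-low features' encoder vectors are given as a list of equal-length lists. The helper names and toy rows are made up for illustration.

```python
import math


def norm(v):
    """Euclidean norm of a list of floats."""
    return math.sqrt(sum(x * x for x in v))


def mean_direction_stats(rows):
    """rows: one encoder vector per ultra-low feature.

    Returns (cosine sim of each row with the mean vector,
             projection of each row onto the unit mean direction).
    """
    dim = len(rows[0])
    mean = [sum(r[j] for r in rows) / len(rows) for j in range(dim)]
    mean_norm = norm(mean)
    unit_mean = [x / mean_norm for x in mean]
    projs = [sum(a * b for a, b in zip(r, unit_mean)) for r in rows]
    # cos(r, mean) = (r . unit_mean) / |r|
    sims = [p / norm(r) for p, r in zip(projs, rows)]
    return sims, projs


rows = [[1.0, 0.0], [2.0, 0.0], [3.0, 0.1]]  # toy vectors, not real weights
sims, projs = mean_direction_stats(rows)
frac_aligned = sum(s > 0.95 for s in sims) / len(sims)
print(frac_aligned, [round(p, 3) for p in projs])
```

The fraction of features with cosine sim above 0.95 is the statistic quoted above; the projections capture the spread in magnitude.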

This means the ultra-low features are highly correlated with each other: for example, I found one data point that’s extreme in the mean direction, on which 95% of the ultra-low features fired.

The decoder weights are not all the same direction: On the other hand, the decoder weights seem all over the place! Very little of the variance is explained by a single direction. I’m pretty confused by what’s going on here. 
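One way to quantify "variance explained by a single direction" is the share of the stacked vectors' squared norm captured by the top singular direction. A sketch using power iteration, assuming the decoder vectors are given as rows; it does not mean-centre them, which is an assumption about what "variance" means here.

```python
import math
import random


def top_direction_fve(rows, n_iter=200, seed=0):
    """Fraction of the rows' total squared norm captured by the best single direction.

    Power iteration on R^T R (R has one row per feature) converges to the top right
    singular vector of R. Rows are NOT mean-centred; centre them first for a
    PCA-style "variance explained".
    """
    n, dim = len(rows), len(rows[0])
    rng = random.Random(seed)
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    for _ in range(n_iter):
        rv = [sum(a * b for a, b in zip(r, v)) for r in rows]               # R v
        w = [sum(rv[i] * rows[i][j] for i in range(n)) for j in range(dim)]  # R^T (R v)
        w_norm = math.sqrt(sum(x * x for x in w))
        v = [x / w_norm for x in w]
    captured = sum(sum(a * b for a, b in zip(r, v)) ** 2 for r in rows)
    total = sum(x * x for r in rows for x in r)
    return captured / total


parallel = [[1.0, 2.0], [2.0, 4.0], [-1.0, -2.0]]  # all along one direction
orthogonal = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
print(top_direction_fve(parallel), top_direction_fve(orthogonal))
```

Vectors sharing one direction give a value near 1; vectors spread evenly over k orthogonal directions give 1/k.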

The encoder direction is consistent across random seeds: The natural question is whether this direction is a weird artefact of the autoencoder training process, or an actual property of the transformer. Surprisingly, training another autoencoder with a different random seed finds the same direction (cosine sim 0.96 between the means). The fact that it’s consistent across random seeds makes me guess it’s finding something real, rather than a seed-dependent quirk of training.

What does the encoder direction mean? I’ve mostly tried and failed to figure out what this shared feature means. Things I’ve tried:

A sample of texts where the average feature fires highly:

Can we fix ultra-low frequency features? These features take up 60% of my autoencoder, so it’d be great if we could get rid of them! I tried the obvious idea of forcing the autoencoder weights to be orthogonal to this average ultra-low feature direction at every training step, and trained a new autoencoder. Sadly, this autoencoder had its own cluster of ultra-low frequency features, which also had significant cosine sims with each other, but pointed in a different direction.
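The post doesn't say exactly how the orthogonality constraint was enforced; one simple possibility, assumed here, is to project the direction out of every encoder and decoder vector after each optimizer step. A sketch, with hypothetical names:

```python
import math


def project_out(rows, direction):
    """Return copies of `rows` with the component along `direction` removed: r - (r . u) u."""
    u_norm = math.sqrt(sum(x * x for x in direction))
    u = [x / u_norm for x in direction]
    out = []
    for r in rows:
        c = sum(a * b for a, b in zip(r, u))
        out.append([a - c * b for a, b in zip(r, u)])
    return out


# Hypothetical usage inside a training loop, after each optimizer step:
#   W_enc_rows = project_out(W_enc_rows, ultra_low_mean_direction)
#   W_dec_rows = project_out(W_dec_rows, ultra_low_mean_direction)
rows = project_out([[1.0, 1.0], [2.0, -3.0]], [2.0, 0.0])
print(rows)
```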

Phase transition in training for feature sparsity: Around 1.6B tokens, there was a seeming phase transition where ultra-low density features started becoming more common. I was re-initialising the weights of features with frequency less than 1e-5 every 120M tokens, but I found that around 1.6B tokens (the 13th-ish reset) they started to drift upwards in frequency, and at some point crossed the 1e-5 boundary (which broke my resetting code!). By 2B tokens they were mostly between 1e-4.5 and 1e-4: still a distinct mode, but drifting upwards. It's a bit hard to tell, but based on inspecting checkpoints I don't think there was a constant upward drift; it seems to only happen in late training.
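For concreteness, a sketch of the simple reset described here: every 120M tokens, re-initialise any feature whose frequency is below 1e-5. The post doesn't specify the re-initialisation distribution, so the Gaussian init and unit-norm decoder below are assumptions, and the encoder bias is left untouched. The last lines just check the arithmetic that 1.6B tokens is the 13th-ish reset.

```python
import math
import random


def resample_rare_features(enc_rows, dec_rows, freqs, threshold=1e-5, seed=0):
    """Re-initialise, in place, every feature whose firing frequency is below `threshold`.

    Hypothetical simplification (not the paper's resampling scheme): a fresh Gaussian
    encoder vector and a fresh Gaussian decoder vector rescaled to unit norm.
    Returns the indices that were reset.
    """
    rng = random.Random(seed)
    dim = len(enc_rows[0])
    reset = [i for i, f in enumerate(freqs) if f < threshold]
    for i in reset:
        enc_rows[i] = [rng.gauss(0.0, 1.0 / math.sqrt(dim)) for _ in range(dim)]
        dec = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        dec_norm = math.sqrt(sum(x * x for x in dec))
        dec_rows[i] = [x / dec_norm for x in dec]
    return reset


enc = [[0.1] * 4 for _ in range(3)]
dec = [[0.5] * 4 for _ in range(3)]
reset = resample_rare_features(enc, dec, freqs=[0.0, 1e-3, 5e-6])
print(reset)

tokens_per_reset = 120_000_000
print(int(1.6e9 // tokens_per_reset))  # resets completed by 1.6B tokens
```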

(Figure note: the spike at -6.5 corresponds to features that never fire)

Implementation Details

Misc Questions

Should encoder and decoder weights be tied? In this paper, the encoder and decoder weights are untied, while in Cunningham et al they’re tied (tied meaning W_dec = W_enc.T). I agree with the argument given in this paper that, conceptually, the encoder and decoder are doing different things and should not be tied. I trained my autoencoders with untied encoder and decoder.

My intuition is that, for each feature, there’s a “representation” direction (its contribution when present) and a “detection” direction (what we project onto to recover that feature’s coefficient). There’s a sparse set of active features, and we expect the MLP activations to be well approximated by a sparse linear combination of these features’ representation directions; this is what the decoder represents. To extract a feature’s coefficient, we need to project onto some direction, the detection direction (represented by the encoder). Naively, this is equal to the representation direction (intuitively, the dual of the vector). But this only makes sense if all representation directions are orthogonal! Otherwise we get interference from features whose representations have non-zero cosine sim, and the optimal detection direction is plausibly different. Superposition implies more directions than dimensions, so they can’t all be orthogonal!
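A small numerical illustration of the representation/detection distinction, under a deliberately simplified assumption: two features in two dimensions, so an exact dual basis exists. Projecting onto the (non-orthogonal) representation directions picks up interference, while the dual directions recover the true coefficients. With more features than dimensions no exact dual exists, and the encoder (with its bias and ReLU) has to trade off interference instead.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def dual_directions_2d(d1, d2):
    """Detection directions e1, e2 with e_i . d_j = 1 if i == j else 0.

    Put d1, d2 as the columns of a 2x2 matrix D; the e_i are the rows of D^{-1}.
    Requires d1 and d2 to be linearly independent.
    """
    a, b = d1[0], d2[0]
    c, d = d1[1], d2[1]
    det = a * d - b * c
    return [d / det, -b / det], [-c / det, a / det]


# Two unit representation directions with cosine sim 0.6.
d1, d2 = [1.0, 0.0], [0.6, 0.8]
x = [2 * p + 3 * q for p, q in zip(d1, d2)]  # coefficient 2 on feature 1, 3 on feature 2
e1, e2 = dual_directions_2d(d1, d2)
print(dot(d1, x), dot(d2, x))  # naive: project onto the representation directions
print(dot(e1, x), dot(e2, x))  # dual: recovers the coefficients (up to float error)
```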

As some empirical evidence, we can check whether in practice the autoencoder “wants” to be tied by looking at the cosine sim between each feature’s encoder and decoder weights. The median (of non-ultra-low features) is 0.5! There’s some real correlation, but the autoencoder seems to prefer not to have them be equal. The distribution also seems slightly bimodal (a big mode centered at 0.5 and one at about 0.75); I’m not sure why.
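The check itself is just a per-feature cosine sim; a sketch assuming the i-th encoder and decoder vectors are available as paired lists (if the weights were tied, every value would be exactly 1). The toy vectors are illustrative only.

```python
import math
import statistics


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def enc_dec_cosines(enc_vectors, dec_vectors):
    """Cosine sim between feature i's encoder vector and its decoder vector, for each i."""
    return [cosine(e, d) for e, d in zip(enc_vectors, dec_vectors)]


enc = [[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]]
dec = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
sims = enc_dec_cosines(enc, dec)
print([round(s, 3) for s in sims], statistics.median(sims))
```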

Should we train autoencoders on the MLP output? One weird property of transformers is that the number of MLP neurons is normally significantly larger than the number of residual stream dimensions (2048 neurons vs 512 residual dims for this model). This means that at least 75% of the dimensions of the MLP activation space lie in the nullspace of W_out (the output weights of the MLP layer, ie the down-projection), and so cannot matter causally for the model.

Intuitively, this means that a lot of the autoencoder’s capacity is wasted on representing the part of the MLP activations that doesn’t matter causally. A natural idea is to instead train the autoencoder on the MLP output (after applying W_out). This would give an autoencoder that’s 4x smaller and takes 4x less compute to run (for small autoencoders the cost of running the model dominates, but this may change for a very wide autoencoder). My guess is that W_out isn’t going to substantially change the features present in the MLP activations, but that training SAEs on the MLP output is likely a good idea (I sadly only thought of this after training my SAEs!)

Unsurprisingly, a significant fraction of each encoder and decoder weight vector lies in the null space of W_out. Note that this fraction has significantly higher variance and is heavier tailed than it would be if W_out’s row space were a random 512-dimensional subspace (a random baseline has mean 0.75 and std 0.013), suggesting that W_out is privileged (as is intuitive).
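The random baseline can be computed in closed form rather than sampled. If x is uniform on the unit sphere in R^n and P projects onto a fixed k-dimensional subspace, then |Px|^2 follows a Beta(k/2, (n - k)/2) distribution; by rotational invariance the same holds for a fixed vector and a random subspace. Assuming W_out has full rank 512, its null space has dimension 2048 - 512 = 1536, which gives mean 0.75 and std ≈ 0.0135, consistent with the baseline quoted above.

```python
import math


def random_subspace_fraction(n, k):
    """Mean and std of |P x|^2 for x uniform on the unit sphere in R^n, where P is the
    orthogonal projection onto a fixed k-dimensional subspace.

    |P x|^2 follows a Beta(k/2, (n - k)/2) distribution.
    """
    a, b = k / 2, (n - k) / 2
    mean = a / (a + b)
    var = a * b / ((a + b) ** 2 * (a + b + 1))
    return mean, math.sqrt(var)


d_mlp, d_model = 2048, 512
# Assumes W_out has full rank d_model, so its null space has dimension d_mlp - d_model.
mean, std = random_subspace_fraction(d_mlp, d_mlp - d_model)
print(round(mean, 4), round(std, 4))
```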


 

12 comments


comment by nostalgebraist · 2023-11-04T18:18:35.030Z

My hunch about the ultra-rare features is that they're trying to become fully dead features, but haven't gotten there yet.  Some reasons to believe this:

  • Anthropic mentions that "if we increase the number of training steps then networks will kill off more of these ultralow density neurons."
  • The "dying" process gets slower as the feature gets closer to fully dead, since the weights only get updated when the feature fires. It may take a huge number of steps to cross the last mile between "very rare" and "dead," and unless we've trained that much, we will find features that really ought to be dead in an ultra-rare state instead.
  • Anthropic includes a 3D plot of log density, bias, and the dot product of each feature's enc and dec vectors ("D/E projection").
    • In the run that's plotted, the ultra-rare cluster is distinguished by a combination of low density, large negative biases, and a broad distribution of D/E projection that's ~symmetric around 0.  For high-density features, the D/E projections are tightly concentrated near 1.
    • Large negative bias makes sense for features that are trying to never activate.
    • D/E projection near 1 seems intuitive for a feature that's actually autoencoding a signal. Thus, values far from 1 might indicate that a feature is not doing any useful autoencoding work[1][2].
    • I plotted these quantities for the checkpoint loaded in your Colab. Oddly, the ultra-rare cluster did not have large(r) negative biases -- though the distribution was different. But the D/E projection distributions looked very similar to Anthropic's.
  • If we're trying to make a feature fire as rarely as possible, and have as little effect as possible when it does fire, then the optimal value for the encoder weight is something like . In other words, we're trying to find a hyperplane where the data is all on one side, or as close to that as possible.  If the -dependence is not very strong (which could be the case in practice), then:
    • there's some optimal encoder weight  that all the dying neurons will converge towards
    • the nonlinearity will make it hard to find this value with purely linear algebraic tools, which explains why it doesn't pop out of an SVD or the like
    • the value is chosen to suppress firing as much as possible in aggregate, not to make firing happen on any particular subset of the data, which explains why the firing pattern is not interpretable
    • there could easily be more than one orthogonal hyperplane such that almost all the data is on one side, which explains why the weights all converge to some new direction when the original one is prohibited

To test this hypothesis, I guess we could watch how density evolves for rare features over training, up until the point where they are re-initialized?  Maybe choose a random subset of them to not re-initialize, and then watch them?

I'd expect these features to get steadily rarer over time, and to never reach some "equilibrium rarity" at which they stop getting rarer.  (On this hypothesis, the actual log-density we observe for an ultra-rare feature is an artifact of the training step -- it's not useful for autoencoding that this feature activates on exactly one in 1e6 tokens or whatever, it's simply that we have not waited long enough for the density to become 1e-7, then 1e-8, etc.)

  1. ^

    Intuitively, when such a "useless" feature fires in training, the W_enc gradient is dominated by the L1 term and tries to get the feature to stop firing, while the W_dec gradient is trying to stop the feature from interfering with the useful ones if it does fire. There's no obvious reason these should have similar directions.

  2. ^

    Although it's conceivable that the ultra-rare features are "conspiring" to do useful work collectively, in a very different way from how the high-density features do useful work.

Replies from: neel-nanda-1
comment by Neel Nanda (neel-nanda-1) · 2023-11-06T14:32:04.831Z

Thanks for the analysis! This seems pretty persuasive to me, especially the argument that "fire as rarely as possible" could incentivise learning the same feature, and that it doesn't trivially fall out of other dimensionality reduction methods. I think this predicts that if we look at the gradient with respect to the pre-activation value in MLP activation space, the average of this will correspond to the rare feature direction? Though maybe not, since we want the average weighted by "how often does this cause a feature to flip from on to off": there's no incentive to go from -4 to -5.

An update is that when training on gelu-2l with the same parameters, I get truly dead features but fairly few ultra low features, and in one autoencoder (I think the final layer) the truly dead features are gone. This time I trained on mlp_output rather than mlp_activations, which is another possible difference.

comment by Charlie Steiner · 2023-10-24T06:43:52.958Z

Huh, what is up with the ultra low frequency cluster? If the things are actually firing on the same inputs, then you should really only need one output vector. And if they're serving some useful purpose, then why is there only one and not more?

Replies from: neel-nanda-1
comment by Neel Nanda (neel-nanda-1) · 2023-10-24T10:00:19.672Z

Idk man, I am quite confused. It's possible they're firing on different inputs: even with the same encoder vector, a different bias means they'll fire somewhat differently (a lower threshold, ie a higher bias, fires on a superset of what a higher threshold fires on). And cosine sim 0.975 is not the same as 1, so maybe the error term matters...? But idk, my guess is it's a weird artifact of the autoencoder training process that's finding some weird property of transformers. Being shared across random seeds is by far the weirdest result, which suggests it can't just be an artifact.

comment by jacob_drori (jacobcd52) · 2023-10-26T22:58:43.621Z

Define the "frequent neurons" of the hidden layer to be those that fire with frequency > 1e-4. The image of this set of neurons under W_dec forms a set of vectors living in R^{d_mlp}, which I'll call "frequent features".

These frequent features are less orthogonal than I'd naively expect.

If we choose two vectors uniformly at random on the (d_mlp)-sphere, their cosine sim has mean 0 and variance 1/d_mlp = 0.0005. But in your SAE, the mean cosine sim between distinct frequent features is roughly 0.0026, and the variance is 0.002.

So the frequent features have more cosine similarity than you'd get by just choosing a bunch of directions at random on the (d_mlp)-sphere. This effect persists even when you throw out the neuron-sparse features (as per your top10 definition). 
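A quick Monte Carlo check of the random-baseline numbers used here, as a sketch: for uniformly random unit vectors in R^d the cosine sim has mean 0 and variance exactly 1/d, so 1/2048 ≈ 0.00049 for d_mlp = 2048. The simulation uses d = 256 only to keep the pure-Python loop fast.

```python
import math
import random
import statistics


def random_unit_vector(dim, rng):
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


def random_cosine_stats(dim, n_pairs, seed=0):
    """Monte Carlo mean and variance of the cosine sim between two random unit vectors."""
    rng = random.Random(seed)
    sims = []
    for _ in range(n_pairs):
        u, v = random_unit_vector(dim, rng), random_unit_vector(dim, rng)
        sims.append(sum(a * b for a, b in zip(u, v)))
    return statistics.fmean(sims), statistics.pvariance(sims)


# Smaller dim than d_mlp = 2048 to keep this quick; theory says variance = 1/dim.
mean, var = random_cosine_stats(dim=256, n_pairs=2000)
print(mean, var, 1 / 256, 1 / 2048)
```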

Any idea why this might be the case? My previous intuition had been that transformers try to pack in their features as orthogonally as possible, but it looks like I might've been wrong about this. I'd also be interested to know if a similar effect is also found in the residual stream, or if it's entirely due to some weirdness with relu picking out a preferred basis for the mlp hidden layer.
 

Replies from: neel-nanda-1
comment by Neel Nanda (neel-nanda-1) · 2023-10-27T18:09:42.068Z

Interesting! My guess is that the numbers are small enough that there's not much to it? But I share your prior that it should be basically orthogonal. The MLP basis is weird and privileged and I don't feel well equipped to reason about it.

comment by lewis smith (lsgos) · 2023-10-26T08:36:36.162Z

maybe this is really naive (I just randomly thought of it), and you mention you do some obvious stuff like looking at the singular vectors of activations which might rule it out, but could the low-frequency cluster be linked to something simple, like the fact that using ReLUs, GeLUs etc. means the neuron activations are going to be biased towards the positive quadrant of the activation space in terms of magnitude (because negative components of any vector in the activation basis get cut off)? I wonder if the singular vectors would catch this.

Replies from: neel-nanda-1
comment by Neel Nanda (neel-nanda-1) · 2023-10-26T09:13:08.345Z

Ah, I did compare it to the mean activations and didn't find much, alas. Good idea though!

comment by laker · 2023-10-25T22:00:42.377Z

On one input 95% of these ultra rare features fired!

What is the example on which 95% of the ultra-low-frequency neurons fire?

Replies from: neel-nanda-1
comment by Neel Nanda (neel-nanda-1) · 2023-10-26T22:21:22.235Z

Search for the text "A sample of texts where the average feature fires highly" and look at the figure below it.

comment by vyasnikhil96 (rememberdeath) · 2023-10-24T12:04:34.287Z

Do you know if the ultra low frequency cluster survives if we do L1/L2 regularization on encoder weights? Similarly I would be curious to know if putting the right sparsity on encoder weights gives a set of neuron sparse features without hurting reconstruction loss or feature sparsity.

Replies from: neel-nanda-1
comment by Neel Nanda (neel-nanda-1) · 2023-10-25T09:36:07.109Z

I haven't checked! If anyone is interested in looking, I'd be curious.