Normalizing Sparse Autoencoders

post by Fengyuan Hu (hufy-dev) · 2024-04-08T06:17:15.536Z · LW · GW · 18 comments

Contents

  TL;DR
  Introduction
  Motivations
  Theoretical Analysis
    Definitions
    The Effect of Input Norms on Feature Suppression
    The Effect of Input Norms on the Inconsistency of L0 Across Layers
  Normalizing SAEs
    Architecture
    Loss
  Experiments 
    Feature Suppression is Suppressed in Normalized SAE
    Normalizing L1 Removes the Correlation Between Input Norm and L0
    L1 Agrees with L0 Better
    Performance Validation
  NSAE Statistics
  Discussion
    Limitations
    Future Work
  Appendix
    Hyperparameters
    Related Work
18 comments

TL;DR

Sparse autoencoders (SAEs) present a promising direction towards automating mechanistic interpretability, but they are not without flaws. One known issue of the original sparse autoencoders is the feature suppression [LW · GW] effect, which is caused by the conflict between the L1 and L2 losses and the unit norm constraint on the SAE decoder. In theory, this effect will be more evident for inputs with high norms. Another observation is that training SAEs on multiple layers simultaneously results in inconsistent L0 norms for feature activations across layers: the scale of L0 in some layers is vastly different from its scale in others. Moreover, the residual states that are input to the SAEs for training also have different norms across layers. Hence, I argue that the current SAE architecture is not robust against inputs of varying norms, which are common in modern LLMs. In this post, I propose a modified SAE architecture, the Normalized Sparse Autoencoder (NSAE), and give a theoretical argument that it will not have the feature suppression problem. I then conducted experiments to verify the effectiveness of the proposed method, which showed that:

  1. Feature suppression is suppressed in NSAEs
  2. The normalization removed the correlation between layer mean input norm and L0
  3. The normalization makes L1 agree with L0 better

I then further investigated the learned feature dictionaries and identified 3 types of feature vectors: the correction vector, the pillar vector, and the direction vector. I then concluded this post with discussion on the limitations of NSAEs and gave my suggestions on future directions.

Introduction

Training Sparse Autoencoders [LW · GW] (SAEs) on the residual states of pretrained models is a recently proposed method in mechanistic interpretability to tackle the problem of superposition. This method is scalable and unsupervised, making it promising for auto-interpretability research. 

More specifically, an SAE contains an encoder and a decoder. It is trained to generate sparse feature activations from the original residual states of a source model through the encoder, and to reconstruct the residual state through the decoder. The expectation is that by training the SAE on a large set of activations, jointly optimizing a sparsity loss on the feature activations and an L2 reconstruction loss, the model learns to decompose residual states into monosemantic feature vectors that are more interpretable.

In this post, I identify a flaw in the original SAE implementation, namely the inconsistency of the L1 loss across layers, and propose a method to mitigate this problem. With the new method, we can significantly decrease the correlation between the norm of the source model's residual activations and the L0 norm of the feature activations, making the training process more robust and controllable. The code is available on GitHub (note that you should use the dev branch rather than the others).

Motivations

Feature suppression [LW · GW] is a known problem for SAEs. It originates from a conflict between the L1 sparsity loss and the L2 reconstruction loss: the reconstruction's norm is correlated with L1, so the SAE learns to generate a reconstruction with a smaller norm in exchange for a better L1 loss. This is undesirable, as we would like the reconstruction to match the original input activations as closely as possible. Therefore, finding a way to disentangle the input norms from the L1 and L2 losses is beneficial.

Also, in my personal experiments with training SAEs using this implementation from the AI Safety Foundation, I observed an inconsistency of the L1 sparsity loss across layers:

Figure 1a. The L1 loss of the activations in the layer indexed 1.
Figure 1b. The L1 loss of the activations in the layer indexed 10.

The two figures above show the L1 losses of two different layers from the same training run, yet the scale of L1 differs significantly between them.

Moreover, the sparsity measured by L0 is also vastly different across layers:

Figure 2a. The L0 norm of the activations in the layer indexed 1.
Figure 2b. The L0 norm of the activations in the layer indexed 10.

I argue that this is also undesirable, as we introduced the L1 coefficient $\lambda$ in an attempt to control the balance between the L1 and L2 losses across layers. Ideally, $\lambda$ should exert consistent control across layers, which is not the case in practice.

Moreover, there is an inconsistency of the norms of the source model's residual states across layers. We can plot the distribution of residual states[1] norms in GPT-2 small across layers:

Figure 3. The norm distribution of residual states in different layers of the residual stream of GPT-2 small during inference. 

It is obvious that the mean and variance of the norms differ across layers.

This effect is common among LLMs, and we can find similar effects in more recent models like LLaMA-2 and Gemma:

Figure 4a. The norm distribution of residual states in different layers of the residual stream of LLaMA2-7B during inference. 
Figure 4b. The norm distribution of residual states in different layers of the residual stream of  Gemma-2B during inference. 

This provides some evidence that the inconsistency of input norms might have caused the undesirable behaviors in SAEs. Thus, I will conduct a theoretical analysis in the next section to further illustrate this problem.

Theoretical Analysis

Definitions

With these observations in mind, let's analyze the loss theoretically to see why they might have happened.

Formally, an SAE can be defined as follows:

$$\hat{x} = W_d\,\mathrm{ReLU}(W_e x + b_e) + b_d$$

We denote the output of the encoder as the feature activation

$$f = \mathrm{ReLU}(W_e x + b_e)$$

The loss function for optimization is defined as

$$\mathcal{L} = \|x - \hat{x}\|_2^2 + \lambda \|f\|_1$$

where the L1 coefficient $\lambda$ is a hyperparameter of the user's choice and $\|\cdot\|_k$ is the k-norm of a given vector.

We set another hyperparameter, the expansion factor $c$, and denote the source model's residual dimension as $d$. Then we can define $n = c \cdot d$ and we have $W_e \in \mathbb{R}^{n \times d}$, and $W_d \in \mathbb{R}^{d \times n}$.

In the original implementation, the authors constrained the decoder to have unit-norm column vectors, so that during optimization the model cannot minimize the L1 loss by increasing the column norms of the decoder while generating dense feature activations with small L1. This design choice leads to a potential flaw in the method, which is discussed in a later section of this post.
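To make the definitions concrete, the following is a minimal pure-Python sketch of the forward pass, the loss, and the unit-norm column constraint for an SAE of this kind. It is an illustration under stated assumptions (plain lists instead of tensors, an encoder bias and a decoder bias, and hypothetical helper names `matvec`, `normalize_columns`, and `sae_loss`), not the AI Safety Foundation implementation.

```python
import math


def matvec(matrix, vector):
    """Multiply a matrix (given as a list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, vector)) for row in matrix]


def normalize_columns(w_d):
    """Rescale each decoder column (one feature vector) to unit L2 norm.

    Assumes no column is all zeros.
    """
    n_features = len(w_d[0])
    norms = [math.sqrt(sum(row[j] ** 2 for row in w_d)) for j in range(n_features)]
    return [[row[j] / norms[j] for j in range(n_features)] for row in w_d]


def sae_loss(x, w_e, b_e, w_d, b_d, l1_coefficient):
    """Return (loss, feature activation, reconstruction) for one input x."""
    # Encoder: ReLU(W_e x + b_e)
    f = [max(0.0, z + b) for z, b in zip(matvec(w_e, x), b_e)]
    # Decoder: W_d f + b_d
    x_hat = [z + b for z, b in zip(matvec(w_d, f), b_d)]
    l2 = sum((a - b) ** 2 for a, b in zip(x, x_hat))
    l1 = sum(f)  # entries of f are non-negative after the ReLU
    return l2 + l1_coefficient * l1, f, x_hat
```

Calling `normalize_columns` on the decoder after every optimizer step is one simple way to enforce the constraint; real implementations may handle it differently.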

The Effect of Input Norms on Feature Suppression

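The authors who identified feature suppression have provided a nice theoretical analysis in the Feature Suppression [LW · GW] section, but for completeness, I will conduct a similar analysis using the terms defined in this post.

We first consider the extreme case where an input $x$ has a feature activation $f$ with only one positive entry $f_i$, and all other entries equal to 0. Ignoring the decoder bias, we then have $\hat{x} = f_i W_{d,i}$, where $W_{d,i}$ is the $i$-th column vector of the decoder $W_d$. Since $W_d$ has unit-norm columns, we must have $\|\hat{x}\|_2 = f_i = \|f\|_1$.

More generally, I will show that when $f$ is sparse, we also have $\|\hat{x}\|_2 \approx \|f\|_1$.

Define $S = \{i : f_i \neq 0\}$, the index set of all nonzero entries in the feature activation. We assume that the feature vectors in the set $\{W_{d,i} : i \in S\}$ are (almost) mutually orthogonal[2], that is, $W_{d,i}^\top W_{d,j} \approx 0$ for $i \neq j$. By the constraint that the decoder columns have unit norm, that is, $\|W_{d,i}\|_2 = 1$, we have

$$\|\hat{x}\|_2^2 = \Big\|\sum_{i \in S} f_i W_{d,i}\Big\|_2^2 = \sum_{i \in S} f_i^2 \|W_{d,i}\|_2^2 + \sum_{i, j \in S,\ i \neq j} f_i f_j W_{d,i}^\top W_{d,j} \approx \sum_{i \in S} f_i^2 = \|f\|_2^2$$

In the case of sparse $f$, we have $\|\hat{x}\|_2 \approx \|f\|_2 \approx \|f\|_1$.

Then our loss function becomes the following:

$$\mathcal{L} \approx \|x - \hat{x}\|_2^2 + \lambda \|\hat{x}\|_2$$

If we attempt to minimize this loss, there is always a tradeoff between the reconstruction accuracy and the norm of the reconstruction. In most cases, the model will learn to construct an $\hat{x}$ that is close to $x$ but has a slightly smaller norm, achieving low losses in both terms.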
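As a worked special case of this tradeoff, consider a single active feature whose unit-norm decoder direction is exactly aligned with the input, with no bias terms. These are simplifying assumptions made only for this example. Write $r$ for the input norm, $a \ge 0$ for the single nonzero feature activation (so the reconstruction has norm $a$), and $\lambda$ for the L1 coefficient; the symbols $r$ and $a$ are introduced here for illustration.

$$\mathcal{L}(a) = (r - a)^2 + \lambda a, \qquad \frac{d\mathcal{L}}{da} = -2(r - a) + \lambda = 0 \;\Longrightarrow\; a^* = r - \frac{\lambda}{2}$$

So whenever $\lambda > 0$ (and $r \ge \lambda / 2$), the loss-minimizing reconstruction is strictly shorter than the input, which is the suppression described above.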

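The Effect of Input Norms on the Inconsistency of L0 Across Layers

Here, we make the same assumption that when $f$ is sparse, we have $\|\hat{x}\|_2 \approx \|f\|_1$.

For the L2 term, we have

$$\|x - \hat{x}\|_2^2 = \|x\|_2^2 + \|\hat{x}\|_2^2 - 2\,\|x\|_2\,\|\hat{x}\|_2 \cos\theta$$

where $\theta$ is the angle between $x$ and $\hat{x}$. At first glance, this might not be obvious, but if our reconstruction $\hat{x}$ is similar enough to $x$, we can take $\|\hat{x}\|_2 \approx \|x\|_2$[3] and the equation simplifies to $\|x - \hat{x}\|_2^2 \approx 2\|x\|_2^2 (1 - \cos\theta)$.

Now we can rewrite our loss:

$$\mathcal{L} \approx 2\|x\|_2^2 (1 - \cos\theta) + \lambda \|x\|_2$$

Notice that if $\cos\theta$ stays at a relatively fixed scale, then the L2 term has a scale of $\|x\|_2^2$ while the L1 term has a scale of $\|x\|_2$. Then, given a fixed $\lambda$, a larger $\|x\|_2$ biases the loss towards the L2 term. This agrees with my earlier observation: the source model's residual states in deeper layers have larger norms than those in shallower layers, and the L1 loss was significantly higher in deeper layers because the loss was dominated by the larger L2 term.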
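The scaling argument can be checked numerically with a small sketch. It assumes the approximation used in this section, namely an L2 term of roughly 2·‖x‖²·(1 − cos θ) and a weighted L1 term of roughly λ·‖x‖. The function names and the default values are illustrative choices, not measurements from the experiments.

```python
def approximate_loss_terms(input_norm, cosine, l1_coefficient):
    """Approximate (L2 term, weighted L1 term) for an input of the given norm."""
    l2_term = 2.0 * input_norm ** 2 * (1.0 - cosine)
    l1_term = l1_coefficient * input_norm
    return l2_term, l1_term


def l2_to_l1_ratio(input_norm, cosine=0.9, l1_coefficient=0.001):
    """How strongly the L2 term dominates the weighted L1 term."""
    l2_term, l1_term = approximate_loss_terms(input_norm, cosine, l1_coefficient)
    return l2_term / l1_term
```

With the cosine similarity and the coefficient held fixed, this ratio grows linearly with the input norm, so layers with larger residual norms weight the reconstruction term more heavily.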

Normalizing SAEs

After this analysis, it is natural to ask: is there a way to solve these problems?

My answer is yes!

Here, I propose an architectural modification to the original SAE architecture, which I have named the Normalized Sparse Autoencoder (NSAE).

Architecture

The modified architecture is defined as follows:

$$f = \tanh(\mathrm{ReLU}(W_e x + b_e + \epsilon)), \qquad \hat{x} = W_d f + b_d$$

In this definition, $f$ is the new feature activation, and $W_d$ is no longer constrained to unit norm. A Gaussian noise term $\epsilon$ is introduced to regularize the feature activation; it is sampled from $\mathcal{N}(0, \sigma^2)$ for some hyperparameter $\sigma$.

The introduction of tanh normalizes every entry of $f$ to the range $[0, 1)$. The benefits of doing this are threefold:

  1. This makes the L1 loss independent of the norm of the input, hence theoretically preventing feature suppression.
  2. When the entries of $f$ are in the range $[0, 1)$, L1 and L0 are much closer, making the L1 loss a more accurate measure of sparsity.
  3. The decoder learns features with norms, which can potentially lead to better interpretability, as we can now consider both directions and norms.

The Gaussian noise term is also essential in this architecture. Without it, the model can learn to minimize L1 by mapping to very small positive values in the feature activation space and learning decoders with extremely large column norms.

To show why adding Gaussian noise solves this problem, I plotted the activation in the following figure:

Figure 5. The tanh(ReLU(x)) function and the output ranges that different input ranges map to. For large inputs, the input range maps to a very small region on the y-axis, meaning that perturbations in that range cause almost no change in the output, while smaller inputs are much more sensitive to perturbation.

From the figure, we can see that when the inputs are small, the output of tanh(ReLU) is relatively sensitive to the input, so adding Gaussian noise can significantly perturb small feature activations. In contrast, larger inputs to the activation function are much more robust to perturbation, as they all map to similar values close to 1. Hence, this perturbation forces the model to learn to generate feature activations that are either exactly 0 or close to 1, which makes L1 behave even more like L0, especially when we set $\sigma$ to be large.
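The sensitivity argument can be illustrated with a short Monte Carlo sketch. It assumes the noise is added to the input of tanh(ReLU(·)) and measures the average absolute change in the output; the function names, the sample count, and the seed are arbitrary illustrative choices.

```python
import math
import random


def activation(z):
    """tanh(ReLU(z)) for a scalar input."""
    return math.tanh(max(0.0, z))


def mean_abs_change(z, sigma, n_samples=10000, seed=0):
    """Average |activation(z + noise) - activation(z)| for noise ~ N(0, sigma^2)."""
    rng = random.Random(seed)
    base = activation(z)
    total = 0.0
    for _ in range(n_samples):
        total += abs(activation(z + rng.gauss(0.0, sigma)) - base)
    return total / n_samples
```

Comparing, for example, `mean_abs_change(0.1, 1.0)` with `mean_abs_change(5.0, 1.0)` shows that a small pre-activation is perturbed far more than a large one under the same noise.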

Loss

We also have to redefine the loss as follows:

$$\mathcal{L} = \frac{\|x - \hat{x}\|_2^2}{\left(\mathbb{E}\,\|x\|_2\right)^2} + \lambda \|f\|_1$$

Here the expectation is the mean input norm of one layer; that is, we introduced the additional step of scaling L2 by the square of the mean input norm of the layer. This is because $\|x - \hat{x}\|_2^2 \approx 2\|x\|_2^2 (1 - \cos\theta)$. If we assume that the best an optimizer can do is to achieve a fixed cosine similarity between $x$ and $\hat{x}$ without the L1 constraint, then we can treat the $(1 - \cos\theta)$ term as a constant, so the L2 loss has a scale of $\|x\|_2^2$, while the scale of L1 no longer depends on the input norm and should be constant across layers. Therefore, we can manually scale the L2 loss to match the scale of the L1 loss. Another way to scale the loss is to use the actual $\|x\|_2$ of the given sample. Theoretically, this might cause the model to overfit to inputs with large norms, but for the sake of conciseness I leave this problem for future work, and only use the mean normalization in all of the following experiments.
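A minimal sketch of the normalized feature activation and the rescaled loss follows. It assumes that the Gaussian noise is added to the encoder pre-activation before tanh(ReLU(·)) and that the mean input norm of the layer is supplied as a precomputed number. The function names are hypothetical, and this is not the code used for the experiments.

```python
import math
import random


def nsae_features(pre_activation, sigma, rng):
    """Entrywise tanh(ReLU(z + noise)) with noise ~ N(0, sigma^2)."""
    return [math.tanh(max(0.0, z + rng.gauss(0.0, sigma))) for z in pre_activation]


def nsae_loss(x, x_hat, f, mean_input_norm, l1_coefficient):
    """Squared L2 error divided by the squared mean input norm, plus weighted L1."""
    l2 = sum((a - b) ** 2 for a, b in zip(x, x_hat))
    l1 = sum(f)  # entries of f are non-negative
    return l2 / mean_input_norm ** 2 + l1_coefficient * l1
```

With this scaling, multiplying the input, the reconstruction, and the mean input norm by the same factor leaves the loss unchanged, which is the layer-independence the rescaling is meant to provide.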

Experiments 

I trained two groups of SAEs, one baseline and one experiment, on all layers of GPT-2, with each group containing two training runs. The four runs used different L1 coefficients and learning rates; the baseline used the original SAE, while the experiment used the normalized SAE. I will use "the experiment group" and "the normalized group" interchangeably.

Feature Suppression is Suppressed in Normalized SAE

To investigate feature suppression, I added a new verification metric that measures the ratio between the norm of reconstructions and norm of source activations. Here is this measure during training:

Figure 6. Mean feature suppression ($\|\hat{x}\|_2 / \|x\|_2$) during training, higher is better.

Clearly, the normalized group has a significantly higher score on feature suppression than the baseline group, and that score is very close to one. Considering that this NSAE did not fully converge, as it only went through 200M training examples, and that the score shows no sign of flattening, I claim that NSAEs have little to no feature suppression.
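A ratio of this kind can be computed along the following lines. The sketch assumes the metric is the mean over samples of the per-sample ratio between reconstruction norm and input norm; averaging per-sample ratios (rather than dividing mean norms) is an assumption here, and the function names are hypothetical.

```python
import math


def l2_norm(vector):
    return math.sqrt(sum(a * a for a in vector))


def mean_norm_ratio(inputs, reconstructions):
    """Mean of ||x_hat|| / ||x|| over paired samples; 1.0 means no shrinkage."""
    ratios = [l2_norm(x_hat) / l2_norm(x) for x, x_hat in zip(inputs, reconstructions)]
    return sum(ratios) / len(ratios)
```

A value below 1.0 indicates that reconstructions are on average shorter than the inputs, which is how feature suppression shows up in this metric.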

Normalizing L1 Removes the Correlation Between Input Norm and L0

To investigate the effect of normalization, I collected the L0 norms of different layers at the end of training and plotted them against the mean input norms of the layer:

Figure 7. The correlation between mean input norms and the mean L0 norm of the feature activation.

The red and blue datapoints are from the baseline group, whereas the cyan and purple datapoints are from the experiment group. We can fit lines to these datapoints to find linear relationships between the mean input norm and the mean L0 norm of the feature activations. Although the fit is not good, the fitted lines still show a rough positive linear correlation between the mean input norm and the feature activation L0 norm in the baseline. In contrast, the two normalized runs did not exhibit a statistically significant positive linear relationship between input norm and L0.

This linear fit does not look satisfactory, so I further investigated the reasons behind it. I plotted the normalized group's L0 against layer index, and here is what it looks like:

Figure 8. The correlation between layer and the mean L0 norm of the feature activation.

I conjecture that L0 in the normalized group reflects the level of discreteness of the source model's activations, as it exhibits an increase-then-decrease pattern. In the source model, earlier activations are more discrete, as they originate from discrete input embeddings, while deeper activations might be less discrete, as they aggregate information. In the last layers, since the model has to make the next-token prediction as accurate as possible, the activations might become more discrete again for better next-token decoding, because the decoding layer is discrete. This discreteness might also be positively correlated with the monosemanticity of the activations, as more discrete activations are often more interpretable. I will not verify this conjecture in this post due to length considerations, and I welcome others to study this problem.

L1 Agrees with L0 Better

To investigate the agreement between L1 and L0, I plotted the mean L1 and L0 of the feature activations for both groups:

Figure 9. Agreement between L1 and L0. What matters is the distance between the two lines of the same color.

Clearly, the cyan and purple solid lines are much closer to their corresponding dashed lines than those of the baselines, indicating better agreement between L1 and L0.

Performance Validation

To validate that the normalization did not heavily impact performance, I present the reconstruction score metric. I first calculate the loss with no intervention, zero intervention (replacing the hidden states in one layer with zero vectors), and reconstruction intervention (replacing the hidden states in one layer with the vectors reconstructed by the SAE), and denote them as $\mathcal{L}_{\text{base}}$, $\mathcal{L}_{\text{zero}}$, and $\mathcal{L}_{\text{recon}}$, respectively. Then, the score is calculated by

$$\text{score} = \frac{\mathcal{L}_{\text{zero}} - \mathcal{L}_{\text{recon}}}{\mathcal{L}_{\text{zero}} - \mathcal{L}_{\text{base}}}$$

Since we expect $\mathcal{L}_{\text{zero}}$ to be higher than $\mathcal{L}_{\text{recon}}$, and we want $\mathcal{L}_{\text{recon}}$ to be close to $\mathcal{L}_{\text{base}}$, a higher score is better, and we expect a value close to 1. The score during training is shown below:

Figure 10. Mean reconstruction score during training.

There is no observable difference between the normalized group and the baseline group, except that the normalized group's score seems slightly more stable during training, indicating that the normalization did not heavily impact performance and might have improved training stability.
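For reference, a score with the properties described here (0 when the SAE is as bad as zero ablation, 1 when it does not degrade the model at all) can be sketched as a one-line function. The function and argument names are hypothetical, and the three losses are assumed to be cross-entropy losses measured on the same data.

```python
def reconstruction_score(loss_base, loss_zero, loss_recon):
    """Fraction of the zero-ablation loss gap that the SAE reconstruction recovers.

    loss_base:  loss of the unmodified model
    loss_zero:  loss with the layer's activations replaced by zeros
    loss_recon: loss with the layer's activations replaced by SAE reconstructions
    """
    return (loss_zero - loss_recon) / (loss_zero - loss_base)
```

For example, with hypothetical losses of 3.0 (no intervention), 10.0 (zero ablation), and 4.75 (SAE reconstruction), the score is 0.75.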

Since the mean reconstruction score is heavily impacted by the sparsity of the feature activation, I also compared a layer where the L0 of the baseline and experiment groups agree best:

Figure 11a. L0 norm of layer 5 for experiment and baseline.
Figure 11b. Reconstruction score of layer 5 for experiment and baseline.

Still, there is no observable difference between the experiment group and the baseline after convergence. This provides further evidence that the normalization did not have an observable negative impact on the performance of SAEs.

NSAE Statistics

To further investigate what the new SAE has learned, I did some statistical analysis on the NSAE feature dictionary from the first run. For comparison, I used the original SAE trained in the first baseline run.

I first analyzed the norm distribution of the feature vectors along the layers:

Figure 12. Norm distribution histogram of the feature vectors from the NSAE decoder across layers. 

Interestingly, a large proportion of feature vectors have small norms, which might indicate that these vectors are small correction vectors that are added to a bigger vector to bring the reconstruction as close as possible to the input. In contrast, I hypothesize that feature vectors with high norms should have good interpretability, as they represent general directions of the reconstruction. Hence, I will name these the pillar vectors.

Next, I calculate the distribution of cosine similarity of the feature dictionary:

Figure 13. Distribution of cosine similarity in the feature dictionary of the NSAE and original SAE, respectively.[4]

From the figure, it is clear that the cosine similarity distributions of the NSAE and the SAE are very similar, except that in the NSAE some cosine similarities are very close to one. My hypothesis is that in the NSAE, some directions appear frequently with different norms in the decomposition of source model activations, so the NSAE has to learn several vectors with the same direction but different norms. I will call these direction vectors.
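The pairwise cosine similarities behind such a histogram can be computed as in the sketch below. It assumes each feature vector is given as a plain list of floats with nonzero norm; the function names are hypothetical, and a real feature dictionary would be large enough to call for a vectorized library and subsampling.

```python
import math
from itertools import combinations


def cosine_similarity(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def pairwise_cosine_similarities(feature_vectors):
    """Cosine similarity for every unordered pair of feature vectors."""
    return [cosine_similarity(u, v) for u, v in combinations(feature_vectors, 2)]
```

Two vectors that share a direction but differ in norm, such as `[1.0, 0.0]` and `[2.0, 0.0]`, have a cosine similarity of 1, which is the pattern the hypothesis about repeated directions would produce.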

A natural question is: do pillar vectors and direction vectors overlap? To answer this, I picked the highest-norm vectors of each layer from the feature dictionary as a set of pillar vectors and calculated their pairwise cosine similarities. Here is the distribution:

Figure 14. Distribution of cosine similarity for high-norm feature vectors (pillar vectors)

Since few to no pairs of vectors have very high cosine similarity, there is minimal overlap between pillar vectors and direction vectors.

As this post is already pretty long, I will leave a more comprehensive analysis on the learned feature dictionary to a future post and conclude this post.

Discussion

Limitations

The normalization does not come without cost. NSAEs generally have slightly higher reconstruction losses than the original SAE, and they take longer to converge, as shown in the following figure:

Figure 15. L2 reconstruction loss during training, lower is better.

I suspect this is because the NSAE learns a non-unit-norm dictionary, which has to capture all the norm information with a fixed size, whereas the original SAE can learn directions and add norm information through the feature activations.

Another metric that I don't know how to interpret is neuron activity. In the NSAE, neuron activity is significantly higher than in the original SAE:

Figure 16. Neuron activity for baseline and experiment groups.

Lastly, the experiments conducted are relatively small in scale due to limitations in compute. Moreover, due to the change of the loss function, it is hard to directly match the level of sparsity between the baseline and the experiment group.

Future Work

I suggest that future work proceed along the following directions:

  1. Investigate other factors that might have caused the L0 inconsistency across layers. I conjectured that differences in the discreteness of the source model's activations across layers might cause this inconsistency.
  2. Interpret the learned feature dictionary of NSAE. Future work can further investigate the feature vectors, especially the pillar vectors and direction vectors, and find interpretations for them. 

Appendix

Hyperparameters

I varied the hyperparameters l1_coefficient and the optimizer learning rate lr. For the two normalized runs, I also set the standard deviation of the Gaussian noise $\sigma$.

| | baseline 1 | baseline 2 | normalized 1 | normalized 2 |
|---|---|---|---|---|
| l1_coefficient | 0.001073 | 0.0009642 | 0.00004065 | 0.0000965 |
| lr | 0.0006275 | 0.00005584 | 0.0009045 | 0.000657 |
| noise standard deviation | N/A | N/A | 1 | 1 |

Table A1. Hyperparameters used for training that varied for different runs

| Hyperparameter | Value |
|---|---|
| expansion_factor | 16 |
| context_size | 256 |
| source_data_batch_size | 16 |
| train_batch_size | 4096 |
| max_activations | 100,000,000 |
| validation_frequency | 5,000,000 |
| max_store_size | 100,000 |
| resample_interval | 200,000,000 |
| n_activations_activity_collate | 100,000,000 |
| threshold_is_dead_portion_fires | 1e-6 |
| max_n_resamples | 4 |
| resample_dataset_size | 100_000 |
| cache_names | blocks.{layer}.hook_mlp_out |

Table A2. Fixed hyperparameters for all runs

Related Work

Riggs et al. [LW · GW] proposed using Sparse Autoencoders (SAEs) to discover interpretable features in large language models. Later, Wright et al. [LW · GW] identified the feature suppression effect in SAEs and argued that the L1 loss induces smaller feature activations that harm reconstruction performance. Wes Gurnee [LW · GW] observed that the reconstruction errors in SAEs are empirically pathological, and compared different norm-aware interventions on the source model's inference. The results show that replacing the original residual state with the SAE reconstruction significantly changes the model's predictions, especially in deeper layers.

  1. ^

    In this and the following examples, I used the residual states from the MLP layer.

  2. ^

    This is a reasonable assumption, as the data in Figure 13 (baseline) show that most feature vector pairs in the original sparse autoencoder have cosine similarities close to zero.

  3. ^

    Empirically, the reconstruction norm is close enough to the input norm for our analysis.

  4. ^

    For computational efficiency, I randomly sampled a subset of entries from the cosine similarity matrix.

  5. ^

     collected from step=3000. Input norm sampled from a relatively small sample of random text. This text is the same as the text used to generate figure 3, 4a, and 4b.

18 comments

Comments sorted by top scores.

comment by Logan Riggs (elriggs) · 2024-04-10T14:32:02.367Z · LW(p) · GW(p)

For comparing CE-difference (or the mean reconstruction score), did these have similar L0's? If not, it's an unfair comparison (higher L0 is usually higher reconstruction accuracy).

Replies from: hufy-dev
comment by Fengyuan Hu (hufy-dev) · 2024-04-10T16:54:35.534Z · LW(p) · GW(p)

Good point. Firstly, the mean L0 between the experiment and the baseline is within a scaling factor of 2, so it's in a reasonably close range. I also added a new set of figures comparing the reconstruction score of one layer that have the closest match on L0 between the experiment group. Spoiler, the scores are still almost the same at the end of training. You can find it under Experiments-Performance Validation.

Replies from: Glen Taggart
comment by Glen Taggart · 2024-04-11T23:14:58.464Z · LW(p) · GW(p)

I want to mention that in my experience a factor of 2 difference in L0 makes a pretty huge difference in reconstruction score/L2 norm. IMO ideally you should compare pareto curves for each architecture or get two datapoints that have almost the exact same L0 if you want to compare two architectures.

Replies from: hufy-dev
comment by Fengyuan Hu (hufy-dev) · 2024-04-12T01:32:24.616Z · LW(p) · GW(p)

The additional experiment under Experiment-Performance Verification (Figure 11) compares normalized_1 and baseline_1 on layer 5, which have almost identical L0. The result showed no observable difference.

comment by Winnie Yang (winnie-yang) · 2024-06-02T23:55:27.591Z · LW(p) · GW(p)

Hi Hengyu! Really nice work here! I am wondering if you have released the pre-trained SAE for llama-2?

comment by Joseph Miller (Josephm) · 2024-04-08T13:42:27.249Z · LW(p) · GW(p)

It would be good to benchmark the normalized and baseline SAEs using the standard metrics of patch loss and L0.

Replies from: hufy-dev
comment by Fengyuan Hu (hufy-dev) · 2024-04-08T14:08:21.029Z · LW(p) · GW(p)

You can treat figure 7 as comparing the L0, and Figure 13 as comparing L2.

Replies from: Josephm
comment by Joseph Miller (Josephm) · 2024-04-08T15:00:04.834Z · LW(p) · GW(p)

Patch loss is different to L2. It's the KL Divergence between the normal model and the model when you patch in the reconstructed activations at some layer.

Replies from: hufy-dev
comment by Fengyuan Hu (hufy-dev) · 2024-04-08T20:50:56.417Z · LW(p) · GW(p)

Oh I see. I'll have to look into that cuz I used the AI-safety-foundation's implementation and they don't measure the KL divergence. That said, there is a validation metric called reconstruction score that measures how replacing activations change the total loss of the model, and the scores are pretty similar for the original and normalized.

Replies from: Josephm
comment by Joseph Miller (Josephm) · 2024-04-08T22:56:35.343Z · LW(p) · GW(p)

there is a validation metric called reconstruction score that measures how replacing activations change the total loss of the model

That's equivalent to the KL metric. Would be good to include as I think it's the most important metric of performance.

Replies from: Glen Taggart, hufy-dev
comment by Glen Taggart · 2024-04-11T23:23:12.443Z · LW(p) · GW(p)

I think these aren't equivalent? KL divergence between the original model's outputs and the outputs of the patched model is different than reconstruction loss. Reconstruction loss is the CE loss of the patched model. And CE loss is essentially the KL divergence of the prediction with the correct next token, as opposed to with the probability distribution of the original model.

Also reconstruction loss/score is in my experience the more standard metric here, though both can say something useful.

Replies from: Josephm
comment by Joseph Miller (Josephm) · 2024-04-13T19:30:13.583Z · LW(p) · GW(p)

Reconstruction loss is the CE loss of the patched model

If this is accurate then I agree that this is not the same as "the KL Divergence between the normal model and the model when you patch in the reconstructed activations". But Fengyuan described reconstruction score as: 

measures how replacing activations changes the total loss of the model

which I still claim is equivalent.

Replies from: Glen Taggart
comment by Glen Taggart · 2024-04-14T02:32:54.210Z · LW(p) · GW(p)

Hmm maybe I'm misunderstanding something, but I think the reason I'm disagreeing is that the losses being compared are wrt a different distribution (the ground truth actual next token) so I don't think comparing two comparisons between two distributions is equivalent to comparing the two distributions directly.

Eg, I think for these to be the same it would need to be the case that something along the lines

or

 were true, but I don't think either of those are true. To connect that to this specific case, have  be the data distribution, and  and  the model with and without replaced activations

Reconstruction score

on a separate note that could also be a crux,

measures how replacing activations changes the total loss of the model

quite underspecifies what "reconstruction score" is. So I'll give a brief explanation:

let:

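  • $\mathcal{L}_{\text{base}}$ be the CE loss of the model unperturbed on the data distribution
  • $\mathcal{L}_{\text{recon}}$ be the CE loss of the model when activations are replaced with the reconstructed activations
  • $\mathcal{L}_{\text{zero}}$ be the CE loss of the model when activations are replaced with the zero vector

then

$$\text{reconstruction score} = \frac{\mathcal{L}_{\text{zero}} - \mathcal{L}_{\text{recon}}}{\mathcal{L}_{\text{zero}} - \mathcal{L}_{\text{base}}}$$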

so, this has the property that when the value is 0 the SAE is as bad as replacement with zeros and when it's 1 the SAE is not degrading performance at all

It's not clear that normalizing with the zero-ablation loss makes a ton of sense, but since it's an emerging domain it's not fully clear what metrics to use and this one is pretty standard/common. I'd prefer if bits/nats lost were the norm, but I haven't ever seen someone use that.

comment by Fengyuan Hu (hufy-dev) · 2024-04-10T14:08:47.890Z · LW(p) · GW(p)

Added to Experiments-Performance Validation!

Replies from: Josephm
comment by Joseph Miller (Josephm) · 2024-04-10T19:14:00.832Z · LW(p) · GW(p)

I think just showing  would be better than reconstruction score metric because  is very noisy.

Replies from: hufy-dev
comment by Fengyuan Hu (hufy-dev) · 2024-04-11T00:14:35.700Z · LW(p) · GW(p)

I don't think  is very informative here, as it's highly impacted by the input batch. Both the raw  and  have large variances at different verification steps, and since we mainly care about how good our reconstruction is compared with the original, I think the reconstruction score is good as is. I also don't follow why the noisiness of  leads to showing .

comment by Joseph Miller (Josephm) · 2024-04-08T13:40:46.297Z · LW(p) · GW(p)

What is Neuron Activity?

Replies from: hufy-dev
comment by Fengyuan Hu (hufy-dev) · 2024-04-08T14:04:41.405Z · LW(p) · GW(p)

It is a metric from the ai-safety-foundation's implementation. It seems to measure the number of neurons in the feature activation that fires more than a threshold. At least that's my interpretation.