Understanding SAE Features with the Logit Lens

post by Joseph Bloom (Jbloom), Johnny Lin (hijohnnylin) · 2024-03-11

Contents

  Executive Summary
  Introduction
  Characterizing Features via the Logit Weight Distribution
  Token Set Enrichment Analysis
    What is Token Set Enrichment Analysis?
    Method Steps
    Case Studies
        Regex Sets
        NLTK Sets: Starting with a Space, or a Capital Letter
        Arbitrary Sets: Boys Names and Girls Names
  Discussion 
    Limitations
    Future Work
        Research Directions
        More Engineering Flavoured Directions
  Appendix
    Thanks
    How to Cite
    Glossary
    Prior Work 
    Token Set Enrichment Analysis: Inspiration and Technical Details
      Inspiration
      Technical Details
    Results by Layer 

This work was produced as part of the ML Alignment & Theory Scholars Program - Winter 2023-24 Cohort, with support from Neel Nanda and Arthur Conmy. Joseph Bloom is funded by the LTFF, Manifund Regranting Program, donors and LightSpeed Grants. This post makes extensive use of Neuronpedia, a platform for interpretability research focused on accelerating researchers working with SAEs.

Links: SAEs on HuggingFace, Analysis Code

Executive Summary

This is an informal post sharing statistical methods that can be used to quickly and cheaply build a better understanding of Sparse Autoencoder (SAE) features.

Feature 4467. Above: Feature dashboard screenshot from Neuronpedia. It is not immediately obvious from the dashboard what this feature does. Below: Logit weight distribution classified by whether the token starts with a space, clearly indicating that this feature promotes tokens which lack an initial space character.

Introduction

In previous work, we trained and open-sourced a set of sparse autoencoders (SAEs) on the residual stream of GPT2 small. In collaboration with Neuronpedia, we’ve produced feature dashboards, auto-interpretability explanations and browsing interfaces for 300k+ features. The analysis in this post is performed on features from the layer 8 residual stream of GPT2 small (for no particular reason).

SAEs might enable us to decompose model internals into interpretable components. Currently, we don’t have a good way to measure interpretability at scale, but we can generate feature dashboards which show things like how often the feature fires, its direct effect on tokens being sampled (the logit weight distribution) and when it fires (see examples of feature dashboards below). Interpreting the logit weight distribution in feature dashboards for multi-layer models is implicitly using the Logit Lens, a very popular technique in mechanistic interpretability. Applying the logit lens to a feature means computing the product of the feature's decoder direction and the unembedding, $W_U W_{dec}[\text{feature}]$, which we refer to as the “logit weight distribution”.
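The projection above can be sketched in a few lines of Python. This is a minimal illustration, not the linked analysis code: it assumes a TransformerLens-style layout in which the SAE decoder W_dec has shape (d_sae, d_model) and the unembedding W_U has shape (d_model, d_vocab), so the product W_U W_dec[feature] is written W_dec[feature] @ W_U. Toy random matrices stand in for the real weights, and the final LayerNorm is ignored.

```python
import random

def logit_weight_distribution(W_dec, W_U, feature):
    """Direct effect of one SAE feature on every token's logit (the logit lens).

    W_dec: d_sae rows of d_model floats (assumed layout).
    W_U:   d_model rows of d_vocab floats (assumed layout).
    Returns d_vocab floats, i.e. W_dec[feature] @ W_U.
    """
    direction = W_dec[feature]
    d_vocab = len(W_U[0])
    return [
        sum(direction[i] * W_U[i][t] for i in range(len(direction)))
        for t in range(d_vocab)
    ]

# Toy random weights stand in for a trained SAE and GPT2 small.
random.seed(0)
d_model, d_sae, d_vocab = 4, 3, 6
W_dec = [[random.gauss(0, 1) for _ in range(d_model)] for _ in range(d_sae)]
W_U = [[random.gauss(0, 1) for _ in range(d_vocab)] for _ in range(d_model)]

logits = logit_weight_distribution(W_dec, W_U, feature=1)
# Token ids ordered from most promoted to most suppressed by this feature.
ranking = sorted(range(d_vocab), key=lambda t: logits[t], reverse=True)
print(ranking)
```

With real weights d_vocab would be 50257 for GPT2 small, and a tensor library would do the matrix product far faster than these nested loops.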

 

Feature 6649: A feature with a fairly typical logit weight distribution (red / blue, bottom right corner). The distribution looks like a Gaussian with outliers that are often related tokens/words. The positive logits here point to a statistics theme. The negative logits are often uninterpretable, as they are here.

Since SAEs haven’t been around for very long, we don’t yet know what logit weight distributions typically look like for SAE features. Moreover, we find that the form of the logit weight distribution can vary greatly. In most cases we see a vaguely normal distribution with some outliers (which often make up an interpretable group of tokens boosted by the feature). In other cases, however, we see significant left or right skew, or a second mode. The standard case has been described previously by Anthropic in the context of the Arabic feature they found here, and is shown above for feature 6649.

Below, we share some feature dashboard examples which have non-standard characteristics. We refer specifically to the red/blue histogram representing logit weight distribution of each feature, but share other dashboard components for completeness. 

Left: A feature with a bimodal logit weight distribution (a “partition feature”).
Center: A feature with left skewness in the logit weight distribution (a “suppression feature”).
Right: A feature with a thick right tail in the logit weight distribution (a “prediction feature”).

Characterizing Features via the Logit Weight Distribution

To better understand these distributions (e.g., how many have thick tails, or how many have lots of tokens shifted left or right), we can use three well-known statistical measures:

  1. Standard Deviation: Standard deviation measures the spread of a distribution (roughly, the typical distance of a data point from the mean).
  2. Skewness: Skewness measures the asymmetry of a distribution. Distributions with a longer or heavier right tail have positive skew, and those with a longer or heavier left tail have negative skew.
  3. Kurtosis: Kurtosis is a measure of the thickness of the tails of a distribution. A kurtosis greater than 3 means a distribution has thicker tails than the normal distribution. 
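As a minimal sketch, the three statistics can be computed for one feature's logit weights as follows. It uses population moment estimators; which estimators were used for the plots below is an assumption on our part.

```python
import math

def logit_weight_stats(weights):
    """Return (std, skewness, kurtosis) of one feature's logit weights.

    Uses population (biased) moment estimators. Kurtosis here is Pearson's
    kurtosis, so a normal distribution scores about 3 (excess kurtosis would be 0).
    """
    n = len(weights)
    mean = sum(weights) / n
    m2 = sum((w - mean) ** 2 for w in weights) / n
    m3 = sum((w - mean) ** 3 for w in weights) / n
    m4 = sum((w - mean) ** 4 for w in weights) / n
    std = math.sqrt(m2)
    return std, m3 / m2 ** 1.5, m4 / m2 ** 2

# A flat distribution with one large positive outlier: strongly right-skewed
# and heavy-tailed, like the prediction-style distributions discussed below.
std, skew, kurt = logit_weight_stats([0.0] * 99 + [1.0])
print(f"std={std:.3f} skew={skew:.2f} kurtosis={kurt:.1f}")
```

With their default bias=True, scipy.stats.skew and scipy.stats.kurtosis(..., fisher=False) should give the same values.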

We note that statistics of the logit weight distribution of neurons have previously been studied in Universal Neurons in GPT2 Language Models (Gurnee et al.), where universal neurons (neurons firing on similar examples across different models) appeared likely to have elevated $W_U$ kurtosis. Neurons with high kurtosis and positive skew were referred to as “prediction neurons”, whilst neurons with high kurtosis and negative skew were described as “suppression neurons”. Furthermore, partition neurons, which promoted a sizable proportion of tokens while suppressing the remaining tokens, were identifiable via high variance in their logit weight distributions.

Below, we show a plot of skewness vs kurtosis in the logit weight distribution of each feature, coloring by the standard deviation. See the appendix for skewness / kurtosis boxplots for all layers and this link to download scatterplots for all layers. 

Scatterplot of the Skewness vs Log Kurtosis of logit weight distributions coloured by standard deviation. We show boxplots of the distributions of skewness / kurtosis of the logit weight distributions across all layers in the appendix. 

We then use the above plot as a launching point for finding different kinds of features (analogous to types of neurons found by Gurnee et al). 

  1. Local Context Features (examples here)
    1. Characterization: Features with high kurtosis. These features are easily identifiable in the top right corner of the plot above.
    2. Interpretation: In general, features with low standard deviation and high kurtosis promote specific tokens or small sets of tokens. The highly noticeable outliers appear to be bracket-closing or quote-closing local context features.
  2. Partition Features (examples here)
    1. Characterization: These features are identifiable via higher standard deviation (red), right skewness and low kurtosis.
    2. Interpretation: These features have non-standard logit weight distributions, though the “prototypical” features of this class (based on my intuition) are cleanly bimodal. We found it somewhat surprising that the left and right modes correspond to different combinations of whether or not the next token starts with a space and whether or not the next token is capitalized. 
  3. Prediction Features (examples here)
    1. Characterization: These features have high skewness and standard deviation that is between 0.03 and 0.07 (yellow/green, but not blue/red). 
    2. Interpretation: These features appear to promote interpretable sets of tokens. There are many different identifiable sets including numerical digits, tokens in all caps and verbs of a particular form. 
  4. Suppression Features (examples here)
    1. Characterization: These features have negative skew but low kurtosis, suggesting they might be better understood as suppressing sets of tokens.
    2. Interpretation: I don’t know if there’s a clear boundary between these features and partition features. It seems like they might be partition features (used to reason about a space or capital letter at the start of the next token) which are more context-specific and therefore might suppress a more specific set of tokens.

Token Set Enrichment Analysis

Given previous results, we are particularly interested in identifying the set of tokens which a particular feature promotes or suppresses. Luckily, the field of bioinformatics has been doing set enrichment tests for years and it’s a staple of some types of data analysis in systems biology. We provide some inspiration and technical detail in the appendix, but will otherwise provide only a cursory explanation of the technique.

What is Token Set Enrichment Analysis?

Token Set Enrichment Analysis (TSEA) is a statistical test which maps each logit weight distribution, together with a library of sets of tokens, to an “enrichment score” for each set (indicating how strongly the feature seems to promote or suppress that set of tokens).

Method Steps

  1. Generate a library of token sets. These sets are our “hypotheses” for sets of tokens the model will affect. I expect to develop automated methods for generating them in the future, but for now we use sets that are easy to generate. We assume these sets are meaningful (in fact the test relies on it). 
  2. Calculate enrichment scores for all features across all sets. We calculate running-sum statistics which identify which features promote / suppress tokens in the hypothesis sets. In theory, we should calculate p-values by determining the distribution of the score under some null hypothesis, but we do not think this level of rigor is appropriate at this stage of analysis.
  3. Plot these and look for outliers. The canonical way to represent the results is with a Manhattan plot. For our purposes, this is a scatter plot where the x-axis corresponds to features and the y-axis is -1*log10(enrichment score). Elevated points represent candidate results, where we have evidence that a particular feature is promoting or suppressing tokens in a particular set.
  4. Inspect features with high enrichment scores. We can then validate “hits” by plotting the distribution of the logit weights by token set membership. This amounts to dividing our logit weight distributions into tokens in the enriched set vs not, and looking at “outliers” which look like false positives or negatives.
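Step 4 can be sketched as follows. The helper names and the toy vocabulary, logit weights and token set are hypothetical and purely illustrative; the idea is to list in-set tokens the feature barely promotes (possibly mislabelled set members) and out-of-set tokens it strongly promotes (possibly missing members).

```python
def split_by_membership(logit_weights, vocab, token_set):
    """Split (weight, token) pairs by whether the token is in the hypothesis set."""
    in_set, out_set = [], []
    for token, weight in zip(vocab, logit_weights):
        (in_set if token in token_set else out_set).append((weight, token))
    return in_set, out_set

def outlier_tokens(logit_weights, vocab, token_set, k=3):
    """Candidate false positives / negatives for a putative enrichment hit.

    Returns (in-set tokens with the lowest weights,
             out-of-set tokens with the highest weights).
    """
    in_set, out_set = split_by_membership(logit_weights, vocab, token_set)
    weakest_members = [tok for _, tok in sorted(in_set)[:k]]
    strongest_outsiders = [tok for _, tok in sorted(out_set, reverse=True)[:k]]
    return weakest_members, strongest_outsiders

# Made-up weights for a gerund-promoting feature and a token set containing
# a mis-tagged token (" Viking").
vocab = [" running", " having", " Viking", " the", " eating"]
weights = [2.1, 1.8, -0.5, 0.1, 1.2]
gerunds = {" running", " having", " Viking"}
print(outlier_tokens(weights, vocab, gerunds, k=1))  # ([' Viking'], [' eating'])
```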

Case Studies

As a rough first pass, we generate a number of token sets corresponding to:

  1. Regex Sets: Sets defined by regular expressions (e.g., starting with a specific letter, starting with a capital, or consisting only of digits). Our interpretation of partition features (which distinguish tokens starting with a space or capital letter) can be verified in this way; see the plots at the beginning of the post.
  2. Part-of-Speech Tagging: Sets found by the NLTK part-of-speech tagger. Here we show that various prediction features promote tokens in interpretable sets such as different classes of verbs. 
  3. Boy / Girl Names: Here we explore the idea that we can define sets of tokens which correspond to sets we care about, cautiously imposing our own hypothesis about token sets that might be promoted / suppressed in hopes of discovering related prediction features. 
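A minimal sketch of building regex token sets. The patterns below are hypothetical examples, not the sets used for the plots, and the sketch assumes tokens have been decoded to strings with real leading spaces (GPT2's raw BPE vocabulary instead encodes a leading space as “Ġ”).

```python
import re

# Hypothetical patterns; the post's actual regex sets are not listed here.
REGEX_SETS = {
    "starts_with_space": r"^ ",
    "starts_with_capital": r"^ ?[A-Z]",
    "all_digits": r"^ ?[0-9]+$",
    "all_caps": r"^ ?[A-Z]+$",
}

def build_regex_token_sets(vocab, patterns=REGEX_SETS):
    """Map each set name to the vocabulary tokens its pattern matches."""
    compiled = {name: re.compile(p) for name, p in patterns.items()}
    return {name: {tok for tok in vocab if rx.search(tok)} for name, rx in compiled.items()}

vocab = [" The", "the", " 2024", "42", " NASA", "ing", " cat"]
for name, tokens in build_regex_token_sets(vocab).items():
    print(name, sorted(tokens))
```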

Note: we filter for the top 5000 features by skewness to reduce the overhead when plotting results.
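This filtering step amounts to a top-k selection. A minimal sketch, assuming skewness values have already been computed per feature (the dictionary below is made up):

```python
import heapq

def top_features_by_skewness(skewness_by_feature, k=5000):
    """Ids of the k features whose logit weight distributions are most right-skewed."""
    return heapq.nlargest(k, skewness_by_feature, key=skewness_by_feature.get)

skews = {0: 0.1, 1: 3.2, 2: -1.5, 3: 0.9}  # made-up values
print(top_features_by_skewness(skews, k=2))  # [1, 3]
```

Ranking by signed skewness keeps the most right-skewed features, so strongly left-skewed (suppression-like) features will tend to be excluded by this filter.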

Regex Sets

Below we show the Manhattan plot of the enrichment scores of the top 5000 features by skewness for the following token sets:

Manhattan Plot: Token Set Enrichment over Regex Sets. We label the top 3 feature results per set (and do not propose a threshold at which a result should be considered significant, until we understand these results better).

We see that there is a fairly strong token set effect whereby some of the sets we tested achieved generally higher enrichment scores than others. If we wanted to use these results to automatically label features, we’d want to decide on some meaningful threshold here, but let’s first establish we’re measuring what we think we are. 
 

To get a sense of what kinds of features these are, we show the logit weight distribution for feature 89 below, which was enriched for all-caps tokens. We show a screenshot of its feature dashboard and a logit weight distribution grouped by the all_caps classification, which show us that:

  1. The feature dashboard for feature 89 suggested that it promoted tokens of two capital characters beginning with a space.
  2. However, the logit weight distribution histogram makes it appear that many different all caps tokens are directly promoted by this feature. 

Looking at the feature activations on Neuronpedia, it seems like the feature is loss-reducing prior to tokens that are in all caps but not made of only two characters, which supports the hypothesis suggested by the TSEA result.
 

Feature 89: This feature appears to promote tokens with two capital characters that begin with a space, but our set enrichment statistic suggests it might be better thought of as promoting all-caps tokens in general.

 

NLTK Sets: Starting with a Space, or a Capital Letter

We can use the NLTK part-of-speech tagger to automatically generate sets of tokens which are interesting from an NLP perspective. In practice these sets are highly imperfect as the tagger was not designed to be used on individual words, let alone tokens. We can nevertheless get passable results. 

Let’s go with different types of verbs:

  1. VBN: verb, past participle (2007 tokens). E.g., “ astonished”, “ pledged”.
  2. VBG: verb, gerund or present participle (1873 tokens). E.g., “ having”, “ including”.
  3. VB: verb, base form (413 tokens). E.g., “ take”, “ consider”.
  4. VBD: verb, past tense (216 tokens). E.g., “ remained”, “ complained”.
     
Manhattan Plot: Token Set Enrichment over NLTK Identified Verb Sets. We label the top 3 feature results per set (and do not propose a threshold at which a result should be considered significant, until we understand these results better).

As before, we see a token set effect, though after seeing this result I feel more confident that set size doesn’t explain the set effect. Why do we not see features for base-form verbs achieving higher enrichment scores than those for other (even smaller) verb sets? Possibly this is an artifact of tokenization in some way, though it's hard to say for sure. As before, let’s look at some examples to gain intuition.
 

Our largest enrichment score overall is feature 5382, for verbs in the gerund form (ending in “ing”). I don’t identify a more specific theme in the top 10 positive logits (verbs ending in “ing”), though maybe there is one, so it seems like the enrichment result is in agreement with the statistics. I’m disappointed with the NLTK tagger, which said that tokens like “ Viking”, “Ring” and “String” were gerund-form verbs (these are the far-left outliers, which the feature does not promote).

Feature 5382: This feature appears to promote verbs ending in “ing”. The overlap between the distributions appears to be a result of mislabelling by the NLTK part-of-speech tagger rather than any “misclassification” by the feature.

 

Moving on, feature 18006 appears to promote tokens labeled as past participles (nltk_pos_VBN) as well as past tense verbs (nltk_pos_VBD). This is actually somewhat expected once you realize that all of these tokens are verbs in the past tense (and that you can’t really distinguish the two out of context). Thus we see that our set enrichment results can be misleading if we aren’t keeping track of the relationship between our sets. To be clear, it is possible that a feature could promote one set and not the other, but to detect this we would need to track tokens which aren’t in the overlap (e.g., “ began” vs “ begun” or “ saw” vs “ seen”). I don’t pursue this further here, but consider it a cautionary tale and evidence that we should be careful about how we generate these token lists in the future.
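One way to act on this, sketched here under loud assumptions: compare a feature's logit weights on irregular verbs whose past tense and past participle forms differ. The pair list and helper are hypothetical, and whether each form is a single GPT2 token is not checked here, so pairs missing from the mapping are simply skipped.

```python
# Hypothetical hand-picked irregular verbs whose past tense (VBD) and past
# participle (VBN) forms differ, so they can be told apart out of context.
IRREGULAR_PAIRS = [(" began", " begun"), (" saw", " seen"), (" took", " taken"), (" wrote", " written")]

def past_vs_participle(weight_by_token, pairs=IRREGULAR_PAIRS):
    """Mean logit weight on unambiguous VBD vs VBN tokens.

    weight_by_token maps token string -> this feature's logit weight. Pairs with
    a form missing from the mapping (e.g. not a single token) are skipped.
    Returns (mean_past, mean_participle, pairs_used).
    """
    past, participle = [], []
    for vbd, vbn in pairs:
        if vbd in weight_by_token and vbn in weight_by_token:
            past.append(weight_by_token[vbd])
            participle.append(weight_by_token[vbn])
    if not past:
        raise ValueError("none of the pairs are in the vocabulary")
    return sum(past) / len(past), sum(participle) / len(participle), len(past)

# Made-up weights for a feature that prefers past tense over past participles.
weights = {" began": 1.0, " begun": 0.2, " saw": 0.8, " seen": 0.0}
print(past_vs_participle(weights))
```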
 

Feature 18006: This feature appears to promote tokens in both the past tense verb (VBD) and past participle (VBN) sets.

Arbitrary Sets: Boys Names and Girls Names

Many features in GPT2 small seem fairly sexist, so it seemed like an interesting idea to use traditionally gendered names as enrichment sets in order to find features which promote them jointly or exclusively. Luckily, there’s actually a Python package which makes it easy to get the most common American first names. We plot the enrichment scores for one set on the x-axis and the enrichment scores for the other set on the y-axis to help us locate such features.
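Locating features far from the y=x line reduces to ranking features by the difference between their two enrichment scores, since the perpendicular distance from (x, y) to y=x is |x - y| / √2. A minimal sketch with made-up scores; the function name is hypothetical.

```python
def most_exclusive_features(scores_x, scores_y, k=3):
    """Features furthest from the y = x line in an enrichment-score scatter plot.

    scores_x / scores_y map feature id -> enrichment score for the set on each axis.
    The distance from (x, y) to the line y = x is |x - y| / sqrt(2), so ranking
    by the signed difference x - y finds features favouring either set.
    Returns (features favouring the x-axis set, features favouring the y-axis set).
    """
    shared = scores_x.keys() & scores_y.keys()
    diff = {f: scores_x[f] - scores_y[f] for f in shared}
    favour_x = sorted(diff, key=diff.get, reverse=True)[:k]
    favour_y = sorted(diff, key=diff.get)[:k]
    return favour_x, favour_y

boys = {10: 0.8, 11: 0.7, 12: 0.2}    # made-up scores, x-axis
girls = {10: 0.1, 11: 0.75, 12: 0.9}  # made-up scores, y-axis
print(most_exclusive_features(boys, girls, k=1))  # ([10], [12])
```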

Scatter plot of token set enrichment scores on the sets of the 300 most common boys' names vs girls' names, for the top 5000 features by skewness. We’ve labeled points that are particularly far away from the y=x line. The marginal distributions of the scores are also shown to make the density easier to observe.

We see here that:

  1. Some features score highly for both sets, which we might reasonably understand as name-promoting features.
  2. We see more features in the top left corner than in the bottom right corner, suggesting that we have more female-name-specific features than male-name-specific features.

To check whether this has pointed us to some interesting features, let’s look at a case study from each of the bottom right and the top left.

Feature 2896 - A Patriarchy Feature? Our enrichment statistics suggested that this feature promotes boys names and not girls names, but further investigation provides a much more nuanced understanding of the feature.

Some observations about Feature 2896:

Let’s now look at a feature which promoted the girls names over boys names. 

Feature 18206: This feature appears to promote tokens that match common girls' names. 

Some observations about feature 18206:

I think both of these case studies suggest we had found interesting features that were non-trivially related to boys' names and girls' names, but clearly enrichment results can’t be taken at face value, due to factors like overlapping sets and the fact that we’re applying the logit lens in a model with tied embeddings.

Discussion 

Limitations

I think it’s important to be clear about the limitations of this work so far:

  1. I have not quantified the proportion of features which are partition, suppression, prediction or local context features (nor given exact thresholds for how to identify them). Nor have I explored these statistics in detail in other layers or models. 
  2. The effect of a feature is not entirely captured by its direct contribution to the logit distribution, so we should be careful not to over-interpret these results. It seems like this might be part of the puzzle, but won’t solve it. Moreover, I’m optimistic that this genre of technique (set enrichment) will help us with these “internal signals” but this is not addressed here.  
  3. While it seems likely to me that TSEA over a large number of sets will be cheaper than auto-interp (which is expensive), it’s definitely not a complete replacement, nor is it clear exactly how expensive or practical it will be to do TSEA at scale.
  4. It seems plausible that rather than doing set enrichment over tokens, we should do it over features. Features are inherently more interpretable, and the rich internal structure we really care about is likely between features. Since set enrichment only requires a ranking over objects and a hypothesis set, we can easily replace tokens with features. The current bottleneck is identifying which groups of features are meaningful.

Future Work

Research Directions

More Engineering Flavoured Directions

Appendix

Thanks

I’d like to thank Neel Nanda and Arthur Conmy for their support and feedback while I’ve been working on this and other SAE related work. 

I’d like to thank Johnny Lin for his work on Neuronpedia and ongoing collaboration which makes working with SAEs significantly more feasible (and enjoyable!). 

I also appreciate feedback and support from Andy Arditi, Egg Syntax, Evan Anders and McKenna Fitzgerald.

How to Cite

@misc{bloom2024understandingfeatureslogitlens,
   title = {Understanding SAE Features with the Logit Lens},
   author = {Bloom, Joseph and Lin, Johnny},
   year = {2024},
   howpublished = {\url{https://www.lesswrong.com/posts/qykrYY6rXXM7EEs8Q/understanding-sae-features-with-the-logit-lens}},
}

 

Glossary

Prior Work 


 

Token Set Enrichment Analysis: Inspiration and Technical Details

Inspiration

Gene Set Enrichment Analysis (GSEA) is a statistical method used to check whether the genes within some set are elevated within some context. Biologists have compiled extensive sets of proteins associated with different biological phenomena, which are often used as a reference point for various analyses. For example, the Gene Ontology Database contains hierarchical sets which group proteins by their structures, processes and functions. Other databases group proteins by their interactions or involvement in pathways (essentially circuits). Each of these databases supports GSEA, which is routinely used to map between elevated levels of proteins in samples and broader knowledge about biology or disease. For example, researchers might find that the set of proteins associated with insulin signaling is in particularly low abundance in patients with type 2 diabetes, indicating that insulin signaling may be related to diabetes.


 

Here, we’re going to perform Token Set Enrichment Analysis, which you can think of as a kind of “reverse probing”. When probing, we train a classifier to distinguish data points according to some labels. Here, SAEs have already given us a number of classifiers (over tokens) and we wish to know which sets of tokens they mean to distinguish. The solution is to use a cheap statistical test which takes a hypothesis set of tokens, and checks whether they are elevated in the logit weight distributions. 

Technical Details

If we treat each of our feature logit distributions as a ranking over tokens, and then construct sets of interpretable tokens, we can calculate a running-sum statistic which quantifies the elevation of those tokens in each of the logit weight distributions for each set. The score is calculated by walking down the logit weight distribution, increasing a running-sum statistic when we encounter a token in the set, S, and decreasing it when we encounter a token not in S.
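The running-sum statistic described above can be sketched as follows. This is an unweighted, Kolmogorov-Smirnov-style variant chosen for simplicity; standard GSEA weights each step by the ranking metric, so scores from this sketch will differ from the ones plotted in this post. Because it only needs a ranking and a set, the same function works over features instead of tokens. The permutation p-value uses one simple null hypothesis (random set membership), which is our assumption rather than a procedure used in this post.

```python
import random

def enrichment_score(ranked_items, item_set):
    """Unweighted running-sum enrichment score.

    ranked_items: tokens (or features) sorted by logit weight, highest first.
    Step up by 1/|S| on each hit and down by 1/(N - |S|) on each miss; return the
    running sum's largest deviation from zero. Positive means S sits near the top
    (promoted); negative means it sits near the bottom (suppressed).
    """
    n = len(ranked_items)
    hits = sum(1 for item in ranked_items if item in item_set)
    if hits == 0 or hits == n:
        raise ValueError("set must contain some, but not all, ranked items")
    up, down = 1.0 / hits, 1.0 / (n - hits)
    running, best = 0.0, 0.0
    for item in ranked_items:
        running += up if item in item_set else -down
        if abs(running) > abs(best):
            best = running
    return best

def permutation_p_value(ranked_items, item_set, n_permutations=1000, seed=0):
    """Two-sided p-value under a null where set members are placed at random."""
    rng = random.Random(seed)
    observed = abs(enrichment_score(ranked_items, item_set))
    shuffled = list(ranked_items)
    extreme = 0
    for _ in range(n_permutations):
        rng.shuffle(shuffled)
        if abs(enrichment_score(shuffled, item_set)) >= observed:
            extreme += 1
    return (extreme + 1) / (n_permutations + 1)

print(enrichment_score(["a", "b", "c", "d"], {"a", "b"}))  # 1.0
```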
 

The figure below is a standard GSEA “enrichment plot” showing the running sum statistic for some set / ranking over genes. We note that usually a false discovery rate is estimated when performing large numbers of these tests. We’ve skipped this procedure as we’re short on time, but this should be implemented when this technique is applied at scale. 
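A minimal sketch of the Benjamini-Hochberg procedure, one standard way of controlling the false discovery rate across many tests. It assumes p-values are available (for example, from a permutation test), which this post did not compute. For TSEA, the number of tests m would be the number of (feature, set) pairs scored.

```python
def benjamini_hochberg(p_values, alpha=0.05):
    """Indices of tests rejected at false discovery rate alpha (Benjamini-Hochberg).

    Sort p-values ascending, find the largest rank k with p_(k) <= k * alpha / m,
    and reject the k smallest p-values.
    """
    m = len(p_values)
    order = sorted(range(m), key=lambda i: p_values[i])
    cutoff = 0
    for rank, i in enumerate(order, start=1):
        if p_values[i] <= rank * alpha / m:
            cutoff = rank
    return sorted(order[:cutoff])

print(benjamini_hochberg([0.01, 0.04, 0.03, 0.5]))  # [0]
```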
 

Output of GSEA.py

Results by Layer 

See this file for scatter plots for skewness and kurtosis for each layer. 


 
