[Interim research report] Activation plateaus & sensitive directions in GPT2

post by StefanHex (Stefan42), jake_mendel · 2024-07-05T17:05:25.631Z · LW · GW · 2 comments

Contents

  Core results & discussion
  Proposal: Connecting SAEs to model behaviour
  Conclusion
  Appendix
    Methodology
    Detailed results
      1. Activation Plateaus
      2. Sensitive directions
      3. Local optima in sensitivity
2 comments

This part-report / part-proposal describes ongoing research, but I'd like to share early results for feedback. I am especially interested in any comment finding mistakes or trivial explanations for these results. I will work on this proposal with a LASR Labs [LW · GW] team over the next 3 months. If you are working (or want to work) on something similar I would love to chat!
Experiments and write-up by Stefan, with substantial inspiration and advice from Jake (who doesn’t necessarily endorse every sloppy statement I write). Work produced at Apollo Research.

TL;DR: Toy models [AF · GW] of how neural networks compute new features in superposition seem to imply that neural networks that utilize superposition require some form of error correction to avoid interference spiraling out of control. This means small variations along a feature direction shouldn't affect model outputs [LW(p) · GW(p)], which I can test:

  1. Activation plateaus: Real activations should be resistant to small perturbations. There should be a "plateau" in the output as a function of perturbation size.
  2. Sensitive directions: Perturbations towards the direction of a feature should change the model output earlier (at a lower perturbation size) than perturbations into a random direction.

I find that both of these predictions hold; the latter when I operationalize "feature" as the difference between two real model activations. As next steps we are planning to

My motivation for this project is to get data on computation in superposition, and to get dataset-independent evidence for (SAE-)features.

Core results & discussion

I run two different experiments that test the error correction hypothesis:

  1. Activation Plateaus: A real activation is the center of a plateau, in the sense that perturbing the activation affects the model output less than expected. Concretely: applying random-direction perturbations to an activation generated from a random openwebtext input (“real activation”) has less effect than applying the same perturbations to a random activation (generated from a Normal distribution). This effect on the model can be measured in KL divergence of logits (shown below) but also L2 difference or cosine similarity of late-layer activations.

    The left edge of this plot corresponds to no perturbation. As we move to the right, the orange lines (real base activation) "plateau" (the perturbation effect stays minimal for a while), while the blue lines (random base activation) rise immediately. Note that a distance of ~50 corresponds to the typical distance between two activation vectors. This plot shows the effect of changing the base activation; the perturbation direction is random in both cases.
  2. Sensitive directions: Perturbing a (real) activation in the direction of another real activation (“poor man’s feature directions”) affects the model outputs more than perturbing the same activation in a random direction. In the plot below, focus on the size of the “plateau” on the left-hand side.

    This plot varies the perturbation direction; the base activation is a real activation in both cases. The orange lines (other real activation direction) exhibit a shorter plateau than the blue lines (random direction). This means the real activation directions are more sensitive. We also show the naive (isotropic) random direction in purple: it behaves very differently and is therefore not a good baseline, which is why we usually sample random activations and directions from a mean- and covariance-adjusted distribution.
    1. Naive random direction vs mean & covariance-adjusted random: Naive isotropic random directions are much less sensitive. Thus we use mean & covariance-adjusted random activations everywhere else in this report.
    2. The sensitive direction results are related to Gurnee (2024 [LW · GW], SAE-replacement-error direction vs naive random direction) and Lindsey (2024, Anthropic April Updates, SAE-feature direction vs naive random direction).

The theoretical explanation for activation plateaus & sensitive direction may be error correction (also referred to as noise suppression):

Proposal: Connecting SAEs to model behaviour

The leading theory for how concepts are represented in neural networks is superposition: We think that NNs represent information as a series of sparsely-active features, which are represented as directions in activation space.[3] Superposition allows this list of features to be much larger than the dimension of activation space, and has been demonstrated in toy models.

Sparse autoencoders (SAEs) are a method that can recover individual features from a dataset of features in superposition. SAEs are trained to convert activations into a list of sparsely-active individual features and back into activations with low reconstruction loss and high feature-sparsity. The training inputs for SAEs are model activations, typically generated by running the model on a dataset similar to its training data.
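As a rough illustration of the encode/decode step described above, here is a minimal NumPy sketch of one common SAE formulation (ReLU encoder, linear decoder, L1 sparsity penalty). The shapes, the random untrained weights, and the penalty weight `l1_coeff` are assumptions made for illustration; this post does not specify which SAE architecture or training setup is used.

```python
import numpy as np

def sae_forward(x, W_enc, b_enc, W_dec, b_dec):
    """Encode an activation into feature activations and decode it back."""
    # ReLU keeps feature activations non-negative; most should be zero after training.
    feats = np.maximum(0.0, (x - b_dec) @ W_enc + b_enc)
    # Reconstruction: a weighted sum of decoder directions plus a bias.
    x_hat = feats @ W_dec + b_dec
    return feats, x_hat

def sae_loss(x, feats, x_hat, l1_coeff=1e-3):
    """Squared reconstruction error plus an L1 penalty that encourages sparsity."""
    recon = np.sum((x - x_hat) ** 2)
    sparsity = np.sum(np.abs(feats))
    return recon + l1_coeff * sparsity

rng = np.random.default_rng(0)
d_model, n_feats = 768, 4096  # hypothetical sizes: more features than dimensions
W_enc = rng.normal(0, 0.02, (d_model, n_feats))  # untrained, random weights
W_dec = rng.normal(0, 0.02, (n_feats, d_model))
b_enc = np.zeros(n_feats)
b_dec = np.zeros(d_model)

x = rng.normal(size=d_model)  # stand-in for a residual-stream activation
feats, x_hat = sae_forward(x, W_enc, b_enc, W_dec, b_dec)
l0 = int(np.count_nonzero(feats))  # number of active features for this input
```

With trained weights, `l0` would be small compared to `n_feats`; with the random weights used here roughly half of the features fire, so the numbers are not meaningful beyond showing the data flow.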

If SAE-features are features in the sense that computation in superposition toy models suggest, then they should show the same error correction properties we saw with real feature directions. Thus we predict

Why do I think this is a useful direction to study SAEs?

  1. There are weird effects [LW · GW] around how SAE features affect model behavior that we don't fully understand. Let's figure out what is going on and what we can learn!
  2. A (neglected?) failure mode of the SAE agenda is that SAE features could be an interpretability illusion in the sense that they do not represent the internal computation of the model but properties of the training dataset.
    1. I worry that SAEs find a feature only because a concept is frequent in the dataset rather than because the model uses the concept. (I discuss this in detail in a shortform post [LW(p) · GW(p)]). A dataset-independent way to find/confirm SAE features (even if non-competitive) would be great!
  3. I’d like to find a cheaper or more targeted way to find (some or all) model features.
    1. Training SAEs is expensive, and cost trades off against feature completeness. If we could take an individual prompt and find (all?) active features, this would be extremely useful for evaluations and interpretability research.

Conclusion

Summary: I run some experiments testing computation-in-superposition predictions on GPT2 activations, finding

  1. Plateaus around model activations, as if the model was error-correcting small perturbations
  2. Perturbing activations into the direction of other activations has more effect than random

I hope this research will allow us to understand computation in superposition better, and to connect behavioral properties of model activations to (SAE-)features.

Limitations: There may just be trivial explanations for results like these! Section 1 results really just say “GPT2 is weird if you go off distribution” (and happen to align with a theory prediction), but there could be lots of plausible explanations for this. Section 2 results are more specific, but still there might be simple explanations for this behavior (e.g. relevant properties of activation space beyond the covariance thing we noticed), and I would love to hear takes in the comments!

Future work: We are currently investigating these behavior properties for SAE-features, questions like

  1. Do SAE features behave as predicted by Toy Models of Computation in Superposition?
  2. Are there different types of SAE features? Atomic and composite [LW · GW] features?
  3. How do linear combinations of features behave? Does this give us a handle on the total number of features?

Call to action: This direction feels underexplored, and I think there’s a lot of new data to be generated here! I’d love to hear from anyone considering working on this!

I also want to encourage feedback in the comments: Trivial explanations I missed? Past literature that explored this? Reasons why this direction might be less promising than I think?

Acknowledgements: We thank Dan Braun, Lee Sharkey, Lucius Bushnaq, Marius Hobbhahn, Nix Goldowsky-Dill, and the whole Apollo team for feedback and discussions of these results. We thank Wes Gurnee and Rudolf Laine for comments on a previous (March 2024) report on this project.

Appendix

Methodology

The experiments in this report focus on perturbing the residual stream of a model (via activation patching) and measuring the corresponding change in model outputs (KL divergence and more).

All experiments use GPT2-small. Inputs are 10-token sequences taken from openwebtext (apollo-research/Skylion007-openwebtext-tokenizer-gpt2). We choose an early perturbation layer (blocks.1.hook_resid_pre). We read the results off at the logits (KL divergence of logprobs) or at a late layer (L2 difference of activations at blocks.11.hook_resid_post or ln_final.hook_normalized). We use only the last position index for perturbation and read-off.

Generating activations: We use model activations to measure activation plateaus, and to generate the perturbation directions for sensitivity tests. We consider 4 types of activations

All activation vectors have zero layer-mean (each activation has zero mean along the hidden dimension), but not zero dataset-mean (i.e. I mean-center in the same way as TransformerLens but the activation dataset mean is not the zero vector). I don’t fix the norm of activation vectors (yet).
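The difference between the two kinds of mean can be illustrated with a small NumPy sketch. The array `acts` is a synthetic stand-in for a dataset of activations (rows are activations, columns are hidden dimensions); its size and the per-dimension offsets are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-in for real activations: 1000 vectors with 768 hidden dimensions,
# with a different offset per hidden dimension so the dataset mean is non-trivial.
acts = rng.normal(size=(1000, 768)) + rng.normal(size=768)

# Layer-mean centering: subtract each activation's own mean over the hidden dimension.
centered = acts - acts.mean(axis=1, keepdims=True)

layer_means = centered.mean(axis=1)   # one value per activation, all numerically zero
dataset_mean = centered.mean(axis=0)  # one vector of length 768, generally not zero
```

Each row of `centered` sums to zero along the hidden dimension, yet averaging over the dataset still leaves a non-zero mean vector.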

Generating directions: We also need to choose the direction in which to perturb the activations. In most cases we generate an activation according to the list above and take the difference between it and the base activation to obtain a direction.

The real-other direction is a proxy for getting feature directions without having to rely on SAEs. The difference between two real activations should be a couple hundred features (about half of them negative) because each real activation should consist of a number (~L0) of features.
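A minimal sketch of how such directions could be generated, assuming the random activations are drawn from a Gaussian whose mean and covariance match the real activations (my reading of "mean- and covariance-adjusted"). The `real_acts` array is synthetic and the dimension is kept small; in the actual experiments these would be GPT2-small residual-stream activations.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model = 64  # small hypothetical dimension to keep the example fast
# Synthetic stand-in for a dataset of real activations (anisotropic on purpose).
scales = np.linspace(0.1, 3.0, d_model)
real_acts = rng.normal(size=(2000, d_model)) * scales + 1.0

mean = real_acts.mean(axis=0)
cov = np.cov(real_acts, rowvar=False)  # (d_model, d_model) covariance estimate

def unit(v):
    """Normalize a vector to unit L2 norm."""
    return v / np.linalg.norm(v)

base = real_acts[0]

# real-other direction: from the base towards another real activation.
dir_real_other = unit(real_acts[1] - base)

# random direction: from the base towards a mean- and covariance-matched Gaussian sample.
random_act = rng.multivariate_normal(mean, cov)
dir_random = unit(random_act - base)

# naive baseline: an isotropic random direction that ignores mean and covariance.
dir_naive = unit(rng.normal(size=d_model))
```

Normalizing all directions to the same norm makes the perturbation size comparable across direction types.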

Perturbations: I perturb the base activation by adding α*direction to the activation vector. I do this in two ways:
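The two modes are referred to later as Straight mode (perturbing straight into a direction) and Sphere mode (perturbing while keeping the norm constant). The sketch below is one plausible way to implement them, not necessarily the exact code used for the plots: Straight mode adds α times a unit direction, and Sphere mode rotates the base activation by an angle θ within the plane spanned by the base and the direction. The function names are hypothetical.

```python
import numpy as np

def perturb_straight(base, direction, alpha):
    """Move an L2 distance alpha from base along the unit-normalized direction."""
    d = direction / np.linalg.norm(direction)
    return base + alpha * d

def perturb_sphere(base, direction, theta):
    """Rotate base by angle theta (radians) towards direction, keeping its norm."""
    r = np.linalg.norm(base)
    e1 = base / r
    # Component of the direction orthogonal to base (Gram-Schmidt step).
    ortho = direction - (direction @ e1) * e1
    e2 = ortho / np.linalg.norm(ortho)
    return r * (np.cos(theta) * e1 + np.sin(theta) * e2)

rng = np.random.default_rng(0)
base = rng.normal(size=768)       # stand-in for a base activation
direction = rng.normal(size=768)  # stand-in for a perturbation direction

p = perturb_straight(base, direction, alpha=5.0)
q = perturb_sphere(base, direction, theta=np.deg2rad(40))
```

In Straight mode the norm of the perturbed activation generally changes, while in Sphere mode only the angle to the base activation changes.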

Effect on model outputs: To measure the effect of a perturbation, I compute the KL divergence between the perturbed and base output logits, as well as the L2 difference and cosine similarity between late-layer activations from the base and perturbed runs.
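For concreteness, the three metrics can be written in a few lines of plain Python. The direction of the KL divergence (base relative to perturbed) is an assumption, since the text does not say which way it is computed; the function names are hypothetical.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def kl_divergence(base_logits, perturbed_logits):
    """KL(base || perturbed) between the two output distributions."""
    lp = log_softmax(base_logits)
    lq = log_softmax(perturbed_logits)
    return sum(math.exp(a) * (a - b) for a, b in zip(lp, lq))

def l2_distance(u, v):
    """Euclidean distance between two activation vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def cosine_similarity(u, v):
    """Cosine of the angle between two activation vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

base_logits = [2.0, 1.0, 0.1]
shifted_logits = [3.0, 2.0, 1.1]   # same distribution: logits shifted by a constant
reversed_logits = [0.1, 1.0, 2.0]  # a genuinely different distribution

kl_same = kl_divergence(base_logits, shifted_logits)
kl_diff = kl_divergence(base_logits, reversed_logits)
```

Working with log-probabilities rather than raw logits matters here: adding a constant to all logits leaves the output distribution, and hence the KL divergence, unchanged.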

Experiment types: I run three types of experiments. In all cases I measure the effect of activation perturbations (as a function of perturbation L2 distance α or angle θ) on model outputs (KL divergence, final-layer L2 distance).

Detailed results

1. Activation Plateaus

Intuition: Look at the model-output landscape when interpolating activations on the plane spanned by three real activations. The plot below shows the KL divergence (wrongly normalized, z-axis and color) for all activations on that plane. We see plateaus around the real activations (black vertical lines) with outputs changing less per shift in activations. This gives an intuitive picture; for the quantitative study we switch to a 1D version and switch from interpolation to perturbation into random directions.

We sample a series of base activations (random or real-other) and perturb the activations from there towards a series of random directions (as discussed above we either perturb in Straight or Sphere modes). Below we show the KL div as a function of perturbation angle (Sphere case) for both types. The real-other activations clearly exhibit the plateau phenomenon—the KL div barely changes until the perturbation reaches 40°—while random activations do not follow this pattern.

Straight mode (perturbing straight into a direction):

Sphere mode (perturbing while keeping norm constant – the change between this and the plot above is due to straight/sphere mode, the seed does not have a big effect):

2. Sensitive directions

Now we perturb a given (real) base activation into different kinds of directions. This is different from the previous experiment where we applied the same (random) perturbation to different base activations.

I take a given real base activation (seed / prompt shown in plot titles) and perturb it into a random direction (sample another random activation with the appropriate covariance matrix, and define the direction as the difference between the new activation and the base activation) or a real-other direction (sample a new activation by running a random openwebtext sequence through the model, then take the difference as the direction). I normalize the directions to have the same norm, and observe the effect on the model (KL div) as a function of angle (Sphere mode) or perturbation size (Straight mode). In all cases the real-other directions appear to be more sensitive, jumping up at a lower angle and lower perturbation distance.

Straight mode (perturbing straight into a direction):

Sphere mode (perturbing while keeping norm constant):

Other seeds (base prompts) looked qualitatively similar.

A brief investigation shows that we can find metrics, such as “at what angle does the KL divergence exceed 0.5” to reasonably distinguish the two classes of perturbation, though I think that the curves above look more distinguishable than suggested by the histograms below. (This may be an optical illusion, or show that I haven’t spent time finding the optimal classifier.)
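A metric of this kind can be computed from a scanned curve by finding the first threshold crossing. The sketch below uses linear interpolation between scan points; the function name and the two example curves are made up for illustration and are not data from the experiments.

```python
def threshold_crossing(xs, ys, threshold=0.5):
    """Return the first x at which y reaches the threshold (linearly interpolated),
    or None if the curve never reaches it."""
    if ys[0] >= threshold:
        return xs[0]
    points = list(zip(xs, ys))
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if y0 < threshold <= y1:
            return x0 + (threshold - y0) * (x1 - x0) / (y1 - y0)
    return None

# Made-up KL curves over perturbation angle (degrees), for illustration only.
angles = [0, 10, 20, 30, 40, 50]
kl_random = [0.0, 0.01, 0.02, 0.05, 0.3, 0.9]
kl_real_other = [0.0, 0.02, 0.1, 0.7, 1.5, 2.0]

angle_random = threshold_crossing(angles, kl_random)
angle_real_other = threshold_crossing(angles, kl_real_other)
```

A lower crossing angle means a more sensitive direction, so collecting this number over many directions gives one scalar per direction to plot as a histogram or feed into a classifier.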

3. Local optima in sensitivity

real-other directions are more sensitive than random directions. We think this is because they focus perturbations into a small (~L0) number of feature directions, reaching the hypothetical error correction threshold earlier.

We conjecture that, if we could perturb activations into a single feature direction, the perturbation would be even more focused and reach the error correction earlier (concretely: the perturbation distance required to reach KL-div=0.5 would be lower). This is compatible with Lindsey (2024)’s observations that SAE directions are unusually sensitive (though they did not compare to real-other or combinations of SAE directions). If that was true, we might be able to find SAE directions as local maxima of sensitivity: A perturbation into 1*feature direction should be more sensitive than a perturbation into 0.99*feature direction + some other direction.

As a precursor to this investigation we investigate the sensitivity of various directions between real-other and random directions. In the plot below, every corner corresponds to a random direction (except for the top corners in the 2nd row, these correspond to real-other directions). And the color corresponds to the perturbation size (measured in Sphere mode, as angle) required to reach 0.5 KL divergence (so every point corresponds to a direction, and we run a scan over perturbation lengths on every point). The non-corner points correspond to interpolated directions (precisely: we interpolate the targets before calculating the direction). This shows us whether “nearby” directions are similarly precise as the exact real-other direction.
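The triangle of interpolated directions can be sketched as follows, assuming barycentric weights on a regular triangular grid. The grid resolution `n`, the function names, and the random stand-in activations are all hypothetical; the point is only to show the targets being interpolated before the direction is computed.

```python
import numpy as np

def barycentric_grid(n):
    """All weight triples (a, b, c) with a + b + c = 1 on a grid with n steps per edge."""
    return [(i / n, j / n, (n - i - j) / n)
            for i in range(n + 1) for j in range(n + 1 - i)]

def interpolated_directions(base, targets, n=4):
    """Interpolate the three target activations first, then take unit directions from base."""
    dirs = []
    for weights in barycentric_grid(n):
        target = sum(w * t for w, t in zip(weights, targets))
        d = target - base
        dirs.append(d / np.linalg.norm(d))
    return dirs

rng = np.random.default_rng(0)
base = rng.normal(size=768)                         # stand-in for the base activation
targets = [rng.normal(size=768) for _ in range(3)]  # stand-ins for the three corner targets

dirs = interpolated_directions(base, targets, n=4)
```

Each entry of `dirs` corresponds to one point in the triangle; scanning the perturbation size along it and recording where the KL divergence reaches 0.5 would give the color of that point.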

The upper row is a sanity-check, interpolating between 3 random directions. We expect the plot to be symmetric. The lower row is an interpolation between a real-other direction (top) and two random directions (bottom corners). We see, as expected, the top corner appears to be a local optimum of sensitivity:

While these plots initially seem to suggest a local optimum at the top corner (2nd row), they only test two (random) directions in 768d space. If real-other directions consist of ~L0 number of features, and the previous hypothesis is true, we expect there to be an L0-dimensional space in which the direction is not a local optimum. We plan to continue these investigations in future work.

  1. ^

    Empirically this is a bit messy: Inputting a random direction into an SAE activates between 10 and 20000 features (lognormal distribution with a peak around 30). But that is using the encoder, I'm not sure if I should be doing that.

  2. ^

    The real-other direction is expected to turn on some features, but also to dampen existing features. My explanation focuses on turning on inactive features, and ignores the slight dampening of active features.

  3. ^

    This is not fully true—we know some directions represent non-sparse positional features, and there is information in the geometry of features [LW · GW]—but we leave this aside for now.

2 comments

Comments sorted by top scores.

comment by LawrenceC (LawChan) · 2024-07-06T02:24:51.611Z · LW(p) · GW(p)

Good work, I'm glad that people are exploring this empirically.

That being said, I'm not sure that these results tell us very much about whether or not the MCIS theory is correct. In fact, something like your results should hold as long as the following facts are true (even without superposition):

  • Correct behavior: The model behavior is correct on distribution, and the correct behavior isn't super sensitive to many small variations to the input. 
  • Linear feature representations: The model encodes information along particular directions, and "reads-off" the information along these directions when deciding what to do. 

If these are true, then I think the results you get follow:

  1. Activation plateaus: If the model's behavior changes a lot for actual on-distribution examples, then it's probably wrong, because there are lots of similar-seeming examples (which won't lead to exactly the same activation, but will lead to similar ones) where the model should behave similarly. For example, given a fixed MMLU problem and a few different sets of 5-shot examples, the activations will likely be close but won't be the same (as the inputs are similar and the relevant information for locating the task should be the same). But if the model uses the 5-shot examples to get the correct answer, its logits can't change too much as a function of the inputs. 

    In general, we'd expect to see plateaus around any real examples, because the correct behavior doesn't change that much as a function of small variations to the input, and the model performs well. In contrast, for activations that are very off distribution for the model, there is no real reason for the model to remain consistent across small perturbations. 
  2. Sensitive directions: Most directions in high-dimensional space are near-orthogonal, so by default random small perturbations don't change the read-off along any particular direction by very much. But if you perturb the activation along some of the read-off directions, then this will indeed change the magnitude along each of these directions a lot! 
  3. Local optima in sensitivity: Same explanation as with sensitive directions. 

Note that we don't need superposition to explain any of these results. So I don't think these results really support one model of superposition over the other, given they seem to follow from a combination of the model behaving correctly and the linear representation hypothesis. 

Instead, I see your results as primarily a sanity-check of your techniques for measuring activation plateaus and for measuring sensitivity to directions, as opposed to weighing in on particular theories of superposition. I'd be interested in seeing the techniques applied to other tasks, such as validating the correctness of SAE features. 

Replies from: Stefan42
comment by StefanHex (Stefan42) · 2024-07-06T09:34:03.948Z · LW(p) · GW(p)

Thanks for the comment Lawrence, I appreciate it!

  • I agree this doesn't distinguish superposition vs no superposition at all; I was more thinking about the "error correction" aspect of MCIS (and just assuming superposition to be true). But I'm excited too for the SAE application, we got some experiments in the pipeline!
  • Your Correct behaviour point sounds reasonable but I feel like it's not an explanation? I would have the same intuitive expectation, but that doesn't explain how the model manages to not be sensitive. Explanations I can think of in increasing order of probability:
    • Story 0: Perturbations change activations and logprobs, but the answer doesn't change because the logprob difference was large. I don't think the KL divergence would behave like that.
    • Story 1: Perturbations do change the activations but the difference in the logprobs is small due to layer norm, unembed, or softmax shenanigans.
      • We did a test-experiment of perturbing the 12th layer rather than the 2nd layer, and the difference between real-other and random disappeared. So I don't think it's a weird effect when activations get converted to outputs.
    • Story 2: Perturbations in a lower layer cause less perturbation in later layers if the model is on-distribution (+ similar story for sensitivity).
      • This is what the L2-metric plots (right panel) suggest, and also what I understand your story to be.
      • But this doesn't explain how the model does this, right? Are there simple stories how this happens?
      • I guess there's lots of stories not limited to MCIS, anything along the lines of "ReLUs require thresholds to be passed"?

Based on that, I think the results still require some "error-correction" explanation, though you're right that this doesn't have to be MCIS (it's just that there's no other theory that doesn't also conflict with superposition?).