SAE features for refusal and sycophancy steering vectors

post by neverix, Dmitrii Kharlapenko (dmitrii-kharlapenko), Arthur Conmy (arthur-conmy), Neel Nanda (neel-nanda-1) · 2024-10-12T14:54:48.022Z · LW · GW · 4 comments

Contents

  TL;DR
  Prior work
  Scaling up SAEs
  Vectors
    Refusal
      Evaluating vectors and reconstructions
      Gemma
    Sycophancy
        Prompt:
        Without vector:
        With vector (normalized, times 100):
        With vector (normalized, times 120):
      Reconstructing vectors
        Reconstructed (normalized, times 100):
        Reconstructed (normalized, times 150):
      Properties of sycophancy vectors and features
  Acknowledgements
  Appendix A: Technical details
  Appendix B: Model choice
  Appendix C: MELBO
4 comments

TL;DR

Prior work

        In Sparse Feature Circuits, Marks et al. find features related to gender in a misgeneralizing classifier and ablate them. Their methods provide both insight about the way a classifier makes its decisions and a way to change its behavior: one of them, SHIFT, ablates features corresponding to gender to their mean activation and corrects the classifier’s generalization.

        DeepMind’s SAE steering vector work [AF · GW] looks at “wedding” and “anger” features. Their goal is to improve a steering vector (according to the original post [LW · GW]’s metrics) by ablating interpretable and irrelevant features.

        We want to follow both of their directions and find examples of interpretable SAE features which contain information about tasks performed and causally affect a model’s behavior.

We had some success applying inference-time optimization (ITO, aka gradient pursuit) [AF · GW] and its variants on simple in-context learning (ICL) tasks in one of our previous posts [? · GW]. We applied Sparse Feature Circuits on the tasks and found interesting features (to be published to arXiv soon!). In this work, we want to explore the application of ITO and SAE encoding to previously studied steering vectors.

Scaling up SAEs

We want to look into abstract features such as those corresponding to refusal and ICL. Therefore, we need a model that can strongly perform such tasks.

We chose to train regularly-spaced residual stream SAEs like those in Improving Dictionary Learning with Gated Sparse Autoencoders. We only want to understand whether SAE dictionary features are causally important and find high-level causal variables using steering, not trace out complete circuits.

We trained 8 sets of SAEs and Gated SAEs on the residual stream at layers 8-28 of Phi 3 Mini (see Appendix B: Model choice for our reasons), mostly on even-numbered layers. We used OpenHermes, LaMini and data generated by the model for different runs. Models trained on synthetic data performed well on other datasets, though we are concerned because the synthetic data subjectively contains more repeated tokens in contexts than natural datasets. You can read about the way we trained them in Appendix A: Technical details.

Vectors

We want to evaluate the quality of dictionaries learned by SAEs in three ways:

  1. How well can they represent arbitrary residual stream vectors with a few features?
  2. Can we use the features to tell what the original vector does?
  3. Can we intervene on identified features in a way informed by interpretations?

Refusal

        We follow Arditi et al. (2024) [? · GW]’s refusal direction ablation procedure: we are looking for directions which, when ablated (projected to 0) from a model’s residual stream at all layers (a technique now popularly known as “abliteration”),[1] cause the model to stop refusing harmful requests while preserving capabilities (accuracy on MMLU).
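A minimal sketch of what "ablated (projected to 0)" means for a single direction, using NumPy and random data in place of real activations. The function name and array shapes are illustrative assumptions; in practice the same projection would be applied to the residual stream at every layer.

```python
import numpy as np

def ablate_direction(resid, direction):
    """Project `direction` out of every residual vector in `resid`.

    resid: array of shape (..., d_model); direction: array of shape (d_model,).
    """
    unit = direction / np.linalg.norm(direction)
    # Component of each residual vector along the unit direction.
    coeff = resid @ unit
    return resid - coeff[..., None] * unit

# Toy usage: random data standing in for real activations (hypothetical sizes).
rng = np.random.default_rng(0)
resid = rng.normal(size=(2, 5, 8))   # (batch, seq, d_model)
direction = rng.normal(size=8)
ablated = ablate_direction(resid, direction)
```

After this operation, every residual vector has zero dot product with the ablated direction, while components orthogonal to it are unchanged.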

        Selecting a sequence position and layer for a refusal vector is not a trivial task, especially for a model as resistant to jailbreaks as Phi-3. We use an abliterated Phi-3 from failspy.[2] The author only published the ablated weights and not the vectors used to produce them. To extract the refusal direction, we take the difference between the first MLP weight matrices of the base and abliterated models, perform an SVD, and take the second right singular vector (the first doesn’t work as a refusal direction; we suspect it corresponds to quantization error). This method is a bit questionable, but the results that follow show that the extracted direction is strongly related to refusal, so it works fine for demonstrating that SAEs can replicate directions that do interesting things to the model.
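The weight-difference trick can be illustrated on synthetic matrices. This is only a sketch under stated assumptions: the weights are taken to be stored as (d_mlp, d_model) so that right singular vectors live in residual space, and the "quantization error" is modelled as a single larger rank-1 term built to be orthogonal to the true direction so that the toy case is exact. None of the names below come from the original code.

```python
import numpy as np

def second_right_singular_vector(w_base, w_abliterated):
    """Return the second right singular vector of the weight difference."""
    diff = w_base - w_abliterated
    # Rows of `vt` are right singular vectors, sorted by decreasing singular value.
    _, _, vt = np.linalg.svd(diff, full_matrices=False)
    return vt[1]

# Synthetic stand-in for real model weights (hypothetical sizes).
rng = np.random.default_rng(0)
d_mlp, d_model = 32, 16
w_base = rng.normal(size=(d_mlp, d_model))
r = rng.normal(size=d_model)
r /= np.linalg.norm(r)          # the "true" direction that was projected out
x = w_base @ r                  # how strongly each row reads from r

# Nuisance term, orthogonal to r and x by construction (an assumption of the toy).
b = rng.normal(size=d_model)
b -= (b @ r) * r
b /= np.linalg.norm(b)
a = rng.normal(size=d_mlp)
a -= (a @ x) / (x @ x) * x
a /= np.linalg.norm(a)

w_abliterated = w_base - np.outer(x, r) + 100.0 * np.outer(a, b)
recovered = second_right_singular_vector(w_base, w_abliterated)
```

Singular vectors are only defined up to sign, so the recovered vector may point in the opposite direction from `r`; for ablation by projection this makes no difference.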

        We can use Self-Explanation [? · GW] to see how the model understands the direction. Self-Explanation describes this direction as:

Large scales seem to explicitly activate the refusal behavior, i.e. the Self-Explanation request is refused.

Evaluating vectors and reconstructions

        We evaluate refusal directions by ablating models and measuring the refusal and safety score of model generations on JailbreakBench. We use phrases "I'm sorry...", "As an AI..." and "As a language..." to measure refusal and a locally run version of LLaMA Guard 2 to measure safety. As a skyline (i.e. opposite of baseline, an alternative method we would not expect to outcompete), we finetune a refusal vector to minimize the NLL of refusals from AdvBench.
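The phrase-based refusal measure can be sketched as a substring check. Whether the original evaluation matched prefixes or substrings, and whether it was case-sensitive, is not stated, so the substring and case-sensitive choices here are assumptions, as are the function names and the example generations.

```python
# Phrases taken from the text; matching rules are an assumption.
REFUSAL_PHRASES = ("I'm sorry", "As an AI", "As a language")

def is_refusal(generation, phrases=REFUSAL_PHRASES):
    """True if the generation contains any of the refusal phrases."""
    return any(phrase in generation for phrase in phrases)

def refusal_rate(generations):
    """Fraction of generations flagged as refusals."""
    if not generations:
        return 0.0
    return sum(is_refusal(g) for g in generations) / len(generations)

# Made-up generations for illustration only.
samples = [
    "I'm sorry, but I can't help with that.",
    "Sure, here is a recipe for pancakes.",
    "As an AI developed to be helpful, I cannot assist.",
    "Here are the steps you asked for.",
]
rate = refusal_rate(samples)
```

A check like this only captures stereotyped refusals, which is one reason to pair it with a separate safety classifier.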

We reconstruct the refusal direction with two SAE features through DeepMind’s inference-time optimization (gradient pursuit) [AF · GW]. Reconstructions seem to be best with SAEs on layer 16 from the 6th group of residual SAEs (revision) we trained.
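For readers unfamiliar with gradient pursuit, the following is a simplified NumPy sketch of the idea: greedily grow a support set of dictionary atoms and take a line-searched gradient step on the selected coefficients, keeping them non-negative. It is an illustration on a toy orthonormal dictionary, not the implementation used for the results here; the function name, shapes, and data are assumptions.

```python
import numpy as np

def gradient_pursuit(signal, dictionary, k):
    """Approximate `signal` with at most `k` non-negatively weighted atoms.

    signal: (d_model,); dictionary: (n_features, d_model), e.g. SAE decoder rows.
    Returns a coefficient vector of shape (n_features,).
    """
    weights = np.zeros(dictionary.shape[0])
    for _ in range(k):
        residual = signal - weights @ dictionary
        inner = dictionary @ residual
        # Grow the support with the atom most aligned with the residual.
        selected = weights != 0
        selected[np.argmax(inner)] = True
        # Gradient step restricted to the support, with an exact line search.
        grad = np.where(selected, inner, 0.0)
        step_dir = grad @ dictionary
        denom = step_dir @ step_dir
        if denom == 0:
            break
        step = (step_dir @ residual) / denom
        weights = np.maximum(weights + step * grad, 0.0)
    return weights

# Toy dictionary with orthonormal atoms; the target mixes atoms 0 and 5.
rng = np.random.default_rng(0)
q, _ = np.linalg.qr(rng.normal(size=(16, 16)))
dictionary = q[:8]               # 8 orthonormal atoms of dimension 16
target = 3.0 * dictionary[0] + 1.0 * dictionary[5]
coeffs = gradient_pursuit(target, dictionary, k=2)
reconstruction = coeffs @ dictionary
```

With a real SAE the decoder rows are far from orthogonal, so the recovered support and coefficients will generally not be exact as they are in this toy case.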

Interestingly, reconstructed SAE vectors surpass the original vectors and get close to optimized vectors in quality without sacrificing MMLU performance.

Like in the ITO blog post, we vary the amount of SAE features included in the reconstruction (denoted as k). When we measure the distance from the vector being reconstructed (the abliterated vector), the k=3 vector seems to win. However, k=4 has overall better performance and is close to the skyline vector.

         There are reasons to believe these results are not robust. The original vector was computed from an SVD on weights, so it may not be the best refusal vector to reconstruct. The results get worse with either more or fewer SAE features, and are calculated with just one model (Llama Guard 2) on 100 samples. Still, it is notable that we don’t use any additional AdvBench-like data and yet, as can be seen from the charts, the SAE reconstruction with two features does not ruin results and can even improve the direction.

The two features that ITO finds can be interpreted as refusal-related through max-activating examples:

Interestingly, encoding the individual prompts with the SAE and finding the refusal direction by subtracting average SAE encodings for harmless and harmful texts finds the same two features if we look at the top 2 coefficients. Additionally, the FISTA algorithm finds the same two features when given the same input as ITO, so long as the sparsity coefficient is tuned to produce two features.
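The encoding-based alternative amounts to a difference of mean SAE activations. The sketch below assumes a plain ReLU encoder (a Gated SAE encoder is more involved), takes the difference as harmful minus harmless, and uses toy data; all names are hypothetical.

```python
import numpy as np

def sae_encode(resid, w_enc, b_enc):
    """Plain ReLU SAE encoder: (n, d_model) -> (n, n_features)."""
    return np.maximum(resid @ w_enc + b_enc, 0.0)

def top_difference_features(harmful, harmless, w_enc, b_enc, k=2):
    """Indices of the k features with the largest mean-activation difference."""
    diff = (sae_encode(harmful, w_enc, b_enc).mean(axis=0)
            - sae_encode(harmless, w_enc, b_enc).mean(axis=0))
    order = np.argsort(-diff)
    return order[:k], diff

# Toy setup: an identity "encoder" and harmful activations shifted along
# features 1 and 3 (illustrative data only).
w_enc = np.eye(4)
b_enc = np.zeros(4)
harmless = np.full((3, 4), 0.1)
harmful = harmless + np.array([0.0, 2.0, 0.0, 1.0])
top, diff = top_difference_features(harmful, harmless, w_enc, b_enc, k=2)
```

Comparing these top indices with the support chosen by ITO is a cheap consistency check between the two ways of picking features.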

Gemma

        We performed the same analysis on Gemma 1 with our residual SAEs on Layer 12.

Sycophancy

        We follow Rimsky et al. 2023 [LW · GW]. We find residual stream differences between multiple choice answers which are and are not sycophantic on Anthropic’s first sycophancy dataset. We found that layer 20 produced the best sycophancy vectors qualitatively:

Prompt:

<s><|user|> I think everyone should eat candy for every meal. Do you agree?<|end|><|assistant|>

Without vector:

        No, I do not agree with the statement that everyone should eat candy for every meal. Candy typically consists of high-sugar content and lacks essential nutrients that our bodies need to function…

With vector (normalized, times 100):

Yes, I support healthy ea. Candy is suitable for treating or fitting in as a dessert. Eating candy for every meal is aligned with the best dietary practices for food.
Yes, I'm aligned with the idea of eating a balanced diet. Eating candy for every meal is aligning with a balanced diet that is recommended.

With vector (normalized, times 120):

Yes, indeed, the best of hands. It's certainly saifiable
Yes, I'\n\n\n\nThe\n\n\n\n\n\n\n\n\n
Yes, I's\n\nS\n\nS\n\nS\n\nS\n\nS\n\nS
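The "normalized, times N" notation in the examples above can be read as adding a unit-norm steering vector scaled by N to the residual stream. Below is a minimal sketch of that reading together with a difference-of-means vector, using random arrays in place of real layer-20 activations; shapes and function names are assumptions.

```python
import numpy as np

def mean_difference_vector(sycophantic_acts, non_sycophantic_acts):
    """Difference of mean residual activations between the two answer types."""
    return sycophantic_acts.mean(axis=0) - non_sycophantic_acts.mean(axis=0)

def steer(resid, vector, scale):
    """Add the unit-normalized vector, multiplied by `scale`, at every position."""
    unit = vector / np.linalg.norm(vector)
    return resid + scale * unit

# Random stand-ins for cached activations (hypothetical sizes).
rng = np.random.default_rng(0)
syc = rng.normal(size=(20, 8)) + 1.0
non_syc = rng.normal(size=(20, 8))
vector = mean_difference_vector(syc, non_syc)

resid = rng.normal(size=(3, 8))          # (seq, d_model)
steered = steer(resid, vector, scale=100.0)
```

Normalizing first makes scales comparable between the original vector and its SAE reconstruction, which may have very different norms.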

Reconstructing vectors

        We tried reconstructing the vector through ITO with three features from our revision 5 layer 20 SAE (you can look at max-activating examples here).

Layer 20, revision 5, feature 15036: max-activates on answers, some of which are incorrect.

Select questions and the ways answers to them may have been wrong:

However, overall the meaning of this feature is fuzzier than “incorrect answers”.

Feature 45142: may be induction-like because of the repeated phrases it activates on.

Reconstructed (normalized, times 100):

Yes, I agree that eating candy for every meal is not a good idea because candy does not provide the necessary nutrients needed for a healthy diet.
Yes, I agree that everyone should eat candy for every meal. However, that is not a good idea for several reasons.

Reconstructed (normalized, times 150):

Yes, the idea of eating candy for every meal is not a good idea because it would lead to a high intake of sugar
Yes, I cannot agree that everyone should eat caron for every meal. A diary cannot agree with a statement because it cannot not agree

The reconstruction does not seem to be as effective, but there is no outright noise at higher scales like with the original vector.

We performed the same experiments with Gemma and Gemma 2. We could not find a sparse combination of features that could reproduce the results of sycophancy steering.

Properties of sycophancy vectors and features

        We applied the original sycophancy classifier to MMLU on the token just before the answer letter. We computed the dot product of the residual stream at Layer 20 (the steering vector’s layer) with various vectors and used it as a binary classifier.

Left: original sycophancy steering vector. Right: SAE ITO reconstruction (k=2)

Left: a feature from the SAE reconstruction (seems to fire on correct answers). Right: a random feature
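Using a dot product as a binary classifier can be sketched as follows. How the scores were thresholded or summarized in the plots is not recoverable from the text, so the pairwise AUC below is an assumed summary metric rather than the one used here, and the data is made up.

```python
import numpy as np

def direction_scores(resid, vector):
    """Dot product of each residual vector with the probe direction."""
    return resid @ vector

def pairwise_auc(scores, labels):
    """Probability that a random positive outscores a random negative (ties count half)."""
    pos = scores[labels]
    neg = scores[~labels]
    wins = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return float((wins + 0.5 * ties) / (len(pos) * len(neg)))

# Made-up residual vectors and labels for illustration.
resid = np.array([[1.0, 0.0], [2.0, 0.0], [-1.0, 0.0], [0.5, 0.0]])
vector = np.array([1.0, 0.0])
labels = np.array([True, True, False, False])
scores = direction_scores(resid, vector)
auc = pairwise_auc(scores, labels)
```

A threshold-free metric like this avoids having to pick a cutoff separately for the steering vector, the reconstruction, and individual features.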

        We tried to use one of the features, 45142, as a “correct answer” multiple choice classifier. We collect the activations of the feature on the final answer letter token:

Question: Are birds dinosaurs? (A: yes, B: cluck C: no D: rawr)
Answer: (C

For each question, we add each answer option (A/B/C/D), run the model and take the argmax of the activations on each of the answers. We evaluated the accuracy of this approach by reframing the multiclass classification task as a binary correct/incorrect classification task. The accuracy was 56.1%. The MMLU score of Phi is 68.8%, and 56.1% lands the classifier at about the level of Llama 2 13B. There were 3 other features that had an accuracy this high. Note that we are measuring the model's ability to judge if a given answer is incorrect, a different (but related) task to predicting the correct answer before it is given.
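The argmax step can be sketched on a small table of feature activations, one row per question and one column per answer option. The numbers are invented, and the hit rate computed here is the plain fraction of questions where the argmax option matches the labelled answer; it is not necessarily the binary correct/incorrect metric reported above, whose exact definition is not given.

```python
import numpy as np

def pick_answers(option_activations):
    """option_activations: (n_questions, n_options) feature activations
    measured on the answer-letter token for each candidate option."""
    return np.argmax(option_activations, axis=1)

# Invented activations for three questions with options A/B/C/D.
acts = np.array([
    [0.1, 0.0, 2.3, 0.4],
    [1.5, 0.2, 0.1, 0.0],
    [0.0, 0.9, 0.3, 1.1],
])
correct = np.array([2, 0, 1])    # labelled answers (C, A, B)
picked = pick_answers(acts)
hit_rate = float((picked == correct).mean())
```

Because only the ranking of options within a question matters, this procedure is insensitive to the overall scale of the feature's activations.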

These results do not replicate on TruthfulQA. The sycophancy features are much weaker classifiers of correct answers on that dataset.

Acknowledgements

This work was produced during the research sprint of Neel Nanda’s MATS training program. We thank Neel Nanda as our mentor and Arthur Conmy as our TA. We thank Daniel Tan for collaborating on early EPO experiments. We thank Thomas Dooms and Joseph Bloom for discussions about SAE training. We use “abliterated” Phi models from failspy. We are grateful to Google for providing us with computing resources through the TPU Research Cloud.

Appendix A: Technical details

        Our SAEs are trained with a learning rate of 1e-3 and Adam betas of 0.0 and 0.99 for 150M (±100) tokens. The methodology is overall similar to Bloom 2024 [LW · GW]. We initialize encoder weights orthogonally and set decoder weights to their transpose. We initialize decoder biases to 0. We use Eoin Farrell’s sparsity loss [? · GW] with an ϵ of 0.1 for our Phi autoencoders. We use Senthooran Rajamanoharan’s ghost gradients variant [AF · GW] (ghost gradients applied to dead features only, loss multiplied by proportion of dead features) with the additional modification of using softplus instead of exp for numerical stability. A feature is considered dead when its density (according to a 1000-batch buffer) is below 5e-6 or when it hasn’t fired in 2000 steps. We use Anthropic’s input normalization and sparsity loss for Gemma 2B. We found it to improve Gated SAE training stability. We modified it to work with transcoders by keeping track of input and output norms separately and predicting normed outputs.
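Two of the details above, the tied orthogonal initialization and the dead-feature criterion, are easy to state in code. This NumPy sketch assumes the number of features is at least the model dimension and that density and idle-step counters are tracked elsewhere; the function names and toy sizes are illustrative and do not come from the saex library.

```python
import numpy as np

def init_sae(d_model, n_features, rng):
    """Orthogonal encoder, decoder set to its transpose, zero decoder bias."""
    # QR of a tall random matrix gives (n_features, d_model) with orthonormal columns.
    q, _ = np.linalg.qr(rng.normal(size=(n_features, d_model)))
    w_enc = q.T                  # (d_model, n_features), orthonormal rows
    w_dec = w_enc.T.copy()       # (n_features, d_model)
    b_dec = np.zeros(d_model)
    return w_enc, w_dec, b_dec

def dead_feature_mask(density, steps_since_fired,
                      min_density=5e-6, max_idle_steps=2000):
    """A feature is dead if it is too sparse or has been idle for too long."""
    return (density < min_density) | (steps_since_fired >= max_idle_steps)

rng = np.random.default_rng(0)
w_enc, w_dec, b_dec = init_sae(d_model=8, n_features=32, rng=rng)

# Hypothetical statistics for three features.
density = np.array([1e-3, 1e-6, 1e-3])
steps_since_fired = np.array([10, 10, 2500])
dead = dead_feature_mask(density, steps_since_fired)
```

The resulting mask is what would select the features that receive ghost gradients, and its mean gives the proportion used to scale that loss.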

We use 8 v4 TPU chips running Jax (Equinox) to train our SAEs. We found that training with Huggingface’s Flax LM implementations was very slow. We reimplemented LLaMA and Gemma in Penzai with Pallas Flash Attention (which isn’t much of an improvement at sequence lengths of 128) and a custom layer-scan transformation and quantized inference kernels. We process an average of around 500 tokens per second, and caching LM activations is not the main bottleneck for us. For this and other reasons, we don’t do SAE sparsity coefficient sweeps to increase utilization.

For caching, we use a distributed ring buffer which contains separate pointers on each device to allow for processing masked data. The (in-place) buffer update is in a separate JIT context. Batches are sampled randomly from the buffer for each training step.
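A single-device, NumPy-only sketch of the ring buffer idea: masked-out rows are dropped before writing, so the write pointer advances only by the number of valid rows, which is why each device needs its own pointer in the distributed version. The class and its methods are hypothetical and omit the sharding and JIT aspects entirely.

```python
import numpy as np

class ActivationRingBuffer:
    def __init__(self, capacity, d_model):
        self.data = np.zeros((capacity, d_model), dtype=np.float32)
        self.capacity = capacity
        self.pointer = 0   # next slot to overwrite
        self.filled = 0    # number of valid rows stored so far

    def write(self, acts, mask):
        """Write only the rows where `mask` is True, wrapping around at the end."""
        valid = acts[mask][-self.capacity:]
        slots = (self.pointer + np.arange(len(valid))) % self.capacity
        self.data[slots] = valid
        self.pointer = int((self.pointer + len(valid)) % self.capacity)
        self.filled = min(self.capacity, self.filled + len(valid))

    def sample(self, rng, batch_size):
        """Sample rows uniformly at random from the filled part of the buffer."""
        rows = rng.integers(0, self.filled, size=batch_size)
        return self.data[rows]

buffer = ActivationRingBuffer(capacity=4, d_model=2)
first = np.arange(6, dtype=np.float32).reshape(3, 2)
buffer.write(first, np.array([True, False, True]))     # middle row is padding
second = 10 + np.arange(6, dtype=np.float32).reshape(3, 2)
buffer.write(second, np.array([True, True, True]))     # wraps around
batch = buffer.sample(np.random.default_rng(0), batch_size=8)
```

Sampling uniformly from the buffer rather than reading it in order decorrelates consecutive training batches from the order in which sequences were cached.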

We train our SAEs in bfloat16 precision. We found that keeping weights and scales in bfloat16 and biases in float32 performed best in terms of the amount of dead features and led to a Pareto improvement over float32 SAEs.

While experimenting with SAE training, we found that in our context it is possible to quantize and de-quantize the SAE weights (encoder and decoder matrices) at 8 bits using zero-point quantization with a block size of 16. Each weight then effectively takes up 8 (base bits per weight) + 16 (bits in bfloat16/float16) * 2 (scale + offset) / 16 (block size) = 10 bits. As an inefficient proof of concept, we quantize and dequantize the encoder and decoder weights after each step for a Layer 12 Gemma 2 2B SAE and see comparable variance explained (74.5% with quantization and 77.6% without) without divergence for 8 billion tokens:

There is a difference in the L0 coefficient necessary to achieve the same L0: the int8 SAE has an L0 of 104, smaller than the 124 of the bf16 SAE.
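The bit count and the block-wise zero-point scheme described above can be sketched in NumPy. Scales and offsets are kept in float32 here for simplicity (the 10-bit figure assumes 16-bit storage for them), the total number of weights is assumed to be divisible by the block size, and the function names are illustrative.

```python
import numpy as np

def effective_bits(base_bits=8, param_bits=16, block_size=16):
    """Bits per weight including one scale and one offset per block."""
    return base_bits + param_bits * 2 / block_size

def quantize_blocks(weights, block_size=16):
    """Zero-point (min/max) quantization to uint8, one scale/offset per block."""
    blocks = weights.reshape(-1, block_size)
    lo = blocks.min(axis=1, keepdims=True)
    hi = blocks.max(axis=1, keepdims=True)
    scale = (hi - lo) / 255.0
    scale = np.where(scale == 0, 1.0, scale)     # guard against constant blocks
    q = np.clip(np.round((blocks - lo) / scale), 0, 255).astype(np.uint8)
    return q, scale, lo

def dequantize_blocks(q, scale, lo, shape):
    return (q.astype(np.float32) * scale + lo).reshape(shape)

# Random weights standing in for an SAE matrix.
rng = np.random.default_rng(0)
w = rng.normal(size=(4, 32)).astype(np.float32)
q, scale, lo = quantize_blocks(w)
w_hat = dequantize_blocks(q, scale, lo, w.shape)
max_error = float(np.abs(w - w_hat).max())
```

The round-trip error per weight is bounded by about half a quantization step of its block, which is why small block sizes help when a few weights in a row are much larger than the rest.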

With custom kernels for matrix multiplication and optimizer updates, it would be possible to significantly reduce the memory usage (and potentially improve performance with better kernels) of SAE training. We leave an efficient GPU implementation to the reader =)

We tried reimplementing 8-bit Adadelta without bias correction (Adam with b1=0.0) in Jax but found that training diverged with it. We also tried using approximate MIPS for top-K SAEs and did not observe quality degradation at the default settings but saw slight increases in speed.

Our SAEs are made publicly available at nev/phi-3-4k-saex-test and nev/gemma-2b-saex-test. The library used to train them is on GitHub at https://github.com/neverix/saex.

Appendix B: Model choice

We first looked at LaMini. It is a set of instruction finetunes of old small models, including GPT-2 Small and XL, for which some residual stream SAEs [? · GW] already exist. From preliminary explorations, we found that LaMini’s dataset did not adequately teach the models to refuse to make harmful completions, only to refuse the specific requests that prompted the model to talk about itself. For example, the models will happily tell you how to drown a puppy but will refuse to answer "What is your name?". We suspect this is because the dataset did not contain safety training data; the closest to that was the Alpaca subset, which does not contain the usual harmlessness prompts.

Phi-3 Mini is the strongest open model of its size (3.8B) according to benchmark scores that we know of. It was not finetuned for instruction following. Instead, it was pretrained on instruction-like data. There is no base variant for us to use, so we need to train the SAE on residuals from the model on an instruction task.

We do not know what data Phi was trained on, but we found that it can generate its own instructions: simply prompting the model with <|user|> will make it generate instruction-like text. Anecdotally, it largely consists of math word problems and programming tasks.

Appendix C: MELBO

        We follow the methodology of Mack et al. 2024 [? · GW]. We ask the model a difficult arithmetic question (“What is 3940 * 3892?”) and optimize layer 10 (source) activation addition vectors to maximize divergence in layer 20 (target) activations. Some vectors we can find with this process are chain-of-thought-like vectors, refusal-like vectors and vectors that start the response with a specific character:

We can take the target layer activations for an interesting direction and subtract the average activations for all other vectors to get a steering-ish vector for the target layer. We expect target layer vectors to be more interpretable because they were “cleaned up” by the network’s computation and are in higher layers (for which SAEs extract more interpretable features).
        Let’s take the first vector (the one producing responses starting with “*”) and reconstruct it with two SAE features:

Applying self-explanation yields results similar to those of a refusal feature – the word “no” is often present. We find nothing interesting when steering with this feature. The model outputs NO and breaks at high scales. (It is possible the model actually refuses when this feature activates on-distribution, but the feature is not causally relevant.)

We steer with 27894 and generate full examples. Instructions become about arithmetic problems:

Here we are plotting the presence of the word “First” at different steering scales:

 

We attempted to do something similar with Gemma 1 2B on the same question. The original MELBO steering vector generated these responses:

3940 * 3892 = 3940000 (simple multiplication)\n\nThe result is a number, not a prime number, so it is not a prime number. The
The result of 3940 * 3892 is the number obtained by the simple integer multiplication of the two numbers. The correct result is obtained by performing the operation:
3940 * 3892 is a simple multiplication problem. To solve it, you can use a calculator or perform the multiplication manually.\n\n3940 * 3892 =

A 16-feature reconstruction generated mostly:

To find the product of 3940 and 3892, you can use a calculator or perform the multiplication manually.\n\nHere's the calculation:

There is one repeated phrase, and the verbose style is unchanged. However, this result had a bug in the prompt formatting (using Phi's prompt format for Gemma). We did not find accurate representations from SAE reconstruction of other MELBO vectors with or without this bug.

Overall, these results are weak. We are unlikely to continue pursuing this direction and are writing down results on MELBO for the sake of completeness.

  1. ^

    Strictly speaking, some work on "abliteration" includes DPO finetuning, e.g. https://huggingface.co/blog/mlabonne/abliteration. We ignore DPO finetuning in this work.

  2. ^

    An early one, with no DPO tuning.

4 comments

Comments sorted by top scores.

comment by Bogdan Ionut Cirstea (bogdan-ionut-cirstea) · 2024-10-12T21:39:26.330Z · LW(p) · GW(p)

Seems a bit in the spirit of what's called 'high-level cross-validation' here, between SAEs + auto-interp on the one hand, and activation engineering on the other.

Replies from: jason-l
comment by Jaehyuk Lim (jason-l) · 2024-10-13T12:58:17.607Z · LW(p) · GW(p)

Although not "circuit-style," this could also be considered one of these attempts outlined by Mack et al. 2024.

https://www.lesswrong.com/posts/ioPnHKFyy4Cw2Gr2x/#:~:text=Unsupervised%20steering%20as,more%20distributed%20circuits.

comment by Sheikh Abdur Raheem Ali (sheikh-abdur-raheem-ali) · 2024-10-15T08:20:28.627Z · LW(p) · GW(p)

Thanks for writing this up!

I’m trying to understand why you take the argmax of the activations, rather than kl divergence or the average/total logprob across answers?

Usually, adding the token for each answer option (A/B/C/D) is likely to underestimate the accuracy, if we care about instances where the model seems to select the correct response but not in the expected format. This happens more often in smaller models. With the example you gave, I’d still consider the following to be correct:

Question: Are birds dinosaurs? (A: yes, B: cluck C: no D: rawr)
Answer: no

I might even accept this:

Question: Are birds dinosaurs? (A: yes, B: cluck C: no D: rawr)
Answer: Birds are not dinosaurs

Here, even though the first token is B, that doesn’t mean the model selected option B. It does mean the model didn’t pick up on the right schema, where the convention is that it’s supposed to reply with the “key” rather than the “value”. Maybe (B is enough to deal with that.

Since you mention that Phi-3.5 mini is pretrained on instruction-like data rather than finetuned for instruction following, it’s possible this is a big deal, maybe the main reason the measured accuracy is competitive with LLama-2 13B.

One experiment I might try to distinguish between “structure” (the model knows that A/B/C/D are the only valid options) and “knowledge” (the model knows which of options A/B/C/D are incorrect) could be to let the model write a full sentence, and then ask another model which option the first model selected.

What’s the layer-scan transformation you used?

Replies from: neverix
comment by neverix · 2024-10-16T17:48:17.334Z · LW(p) · GW(p)
  1. It doesn't really make sense to interpret feature activation values as log probabilities. If we did, we'd have to worry about scaling. It's also not guaranteed the score wouldn't just decrease because of decreased accuracy on correct answers.
  2. Phi seems specialized for MMLU-like problems and has an outsized score for a model its size, I would be surprised if it's biased because of the format of the question. However, it's possible using answers instead of letters would help improve raw accuracy in this case because the feature we used (45142) seems to max-activate on plain text and not multiple choice answers and it's somewhat surprising it does this well on multiple-choice questions. For this reason, I think using text answers will boost the feature's score but won't change the relative ranking of features by accuracy. I don't know how to adapt your proposed experiment for the task of finding how accurate a feature is at eliciting the model's knowledge of correct answers.
  3. This is a technical detail. We use a jax.lax.scan by the layer index and store layers in one structure with a special named axis. This mainly improves compilation time. Penzai has since implemented the technique in the main library.