Extracting SAE task features for in-context learning

post by Dmitrii Kharlapenko (dmitrii-kharlapenko), neverix, Neel Nanda (neel-nanda-1), Arthur Conmy (arthur-conmy) · 2024-08-12T20:34:13.747Z · LW · GW · 1 comments

Contents

  TL;DR
  Prior work
  Initial Phi-3 experiments
    Direct SAE task vector reconstruction
    Task-related SAE features
    Task vector SAE cleanup
    Early layer task vector cleanup
  Further experiments on Gemma 1 2B
    Comparison with task vectors
    Task features
    Task detector features
  Next steps
  How gradient-based task vector cleanup can be used
  Acknowledgements
1 comment

TL;DR

This work was produced as part of the ML Alignment & Theory Scholars Program - Summer 24 Cohort, under mentorship from Neel Nanda and Arthur Conmy.

Prior work

Task or function vectors are internal representations of some task that LLMs form while processing an ICL prompt. They can be extracted from a model running on a few-shot prompt and then be used to make it complete the same task without having any prior context or task description.

Several papers (Function vectors in large language models, In-Context Learning Creates Task Vectors) have proposed different ways to extract those task vectors. They all center around having ICL examples being fed to a model in the form of “input <separator> output, … ” and averaging the residuals on the “separator” token over a batch. This approach can reconstruct some part of the ICL performance but does not admit a straightforward conversion to the SAE basis.
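The averaging step can be sketched as follows. This is a minimal illustration on random arrays, not the extraction code from either paper: the function name, the array layout, and the choice to average over every separator position in each prompt are assumptions made for the example.

```python
import numpy as np

def extract_task_vector(residuals, separator_positions):
    """Average one layer's residual stream over separator tokens.

    residuals: array of shape (batch, seq_len, d_model).
    separator_positions: one list of separator token indices per prompt.
    (Hypothetical layout, chosen for this sketch.)
    """
    picked = [
        residuals[b, pos]
        for b, positions in enumerate(separator_positions)
        for pos in positions
    ]
    # Mean over all collected separator activations -> (d_model,)
    return np.mean(picked, axis=0)

# Toy data standing in for real model activations.
rng = np.random.default_rng(0)
residuals = rng.normal(size=(4, 10, 8))   # 4 prompts, 10 tokens, d_model = 8
separator_positions = [[1, 4, 7]] * 4     # "->" tokens in each few-shot prompt
task_vector = extract_task_vector(residuals, separator_positions)
print(task_vector.shape)
```

In a real setup the residuals would come from hooks on the model at the chosen layer, and the resulting vector would be added back at the separator of a zero-shot prompt.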

Inference-time optimization (ITO) with gradient pursuit can be used to compute a sparse coding of a residual vector in terms of SAE features. The post introducing it suggests using this algorithm to decompose steering vectors into SAE features. Since task vectors can be thought of as steering vectors, ITO may provide some insight into how they operate.
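For intuition, here is a simplified sketch of gradient pursuit on a toy dictionary. It assumes the dictionary rows play the role of SAE decoder directions and that coefficients are kept non-negative. It is meant to show the shape of the algorithm (greedily grow a support, then take an exact line-search step restricted to it), not to reproduce the implementation used in the experiments.

```python
import numpy as np

def gradient_pursuit(x, dictionary, n_steps):
    """Sparse-code x with at most n_steps active rows of `dictionary`.

    dictionary: (n_features, d_model); rows stand in for SAE decoder directions.
    Returns non-negative coefficients of shape (n_features,).
    """
    n_features = dictionary.shape[0]
    coeffs = np.zeros(n_features)
    support = np.zeros(n_features, dtype=bool)
    residual = x.copy()
    for _ in range(n_steps):
        correlations = dictionary @ residual
        # Greedy selection: add the feature most aligned with the residual.
        support[np.argmax(correlations)] = True
        # Update direction = gradient restricted to the current support.
        direction = np.where(support, correlations, 0.0)
        image = direction @ dictionary      # its effect in activation space
        denom = image @ image
        if denom < 1e-12:                   # nothing left to explain
            break
        step = (residual @ image) / denom   # exact line search
        coeffs = np.maximum(coeffs + step * direction, 0.0)
        residual = x - coeffs @ dictionary
    return coeffs

# Toy overcomplete dictionary in 2D: two axes and two diagonals.
s = 1.0 / np.sqrt(2.0)
dictionary = np.array([[1.0, 0.0], [0.0, 1.0], [s, s], [s, -s]])
x = 3.0 * dictionary[2]
coeffs = gradient_pursuit(x, dictionary, n_steps=2)
print(np.count_nonzero(coeffs), np.round(coeffs, 3))
```

The number of steps bounds the number of active features, which is why the order in which features are picked matters when dense features dominate a task vector.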

Initial Phi-3 experiments

Direct SAE task vector reconstruction

In our study we trained a set of gated SAEs for Phi-3 Mini 3.8B using a model-generated synthetic instruction dataset.

While offering a sparse dictionary decomposition of residuals, SAEs tend to introduce a reconstruction error that impacts the performance of the model. There is also no guarantee that they can decompose out-of-distribution vectors, and task vectors, being a product of averaging activations across prompts and tokens, may well be such vectors.

Thus, we first studied the performance of SAE reconstructions of task vectors in transferring the definition of two tasks: 1) antonym generation and 2) English to Spanish word translation. These and other tasks used to study task vectors were taken from the GitHub repository of the ICL task vectors paper.

These charts show the NLL loss of the model on the evaluation set of zero-shot prompts for both of the tasks depending on the layer of extraction/insertion.
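For reference, the metric can be written down in a few lines. This is a generic sketch of mean negative log-likelihood over a set of prompts, assuming we already have the logits at the position where each answer token is predicted. The function name and array shapes are illustrative, not taken from the experiment code.

```python
import numpy as np

def mean_nll(logits, target_ids):
    """Mean negative log-likelihood of the target tokens.

    logits: (n_prompts, vocab_size), taken at the answer position of each prompt.
    target_ids: (n_prompts,) integer ids of the expected answer tokens.
    """
    # Numerically stable log-softmax.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(target_ids)), target_ids].mean()

# Toy logits standing in for model outputs on zero-shot prompts.
rng = np.random.default_rng(0)
logits = rng.normal(size=(5, 50))
target_ids = rng.integers(0, 50, size=5)
print(mean_nll(logits, target_ids))
```

Lower is better: a task vector that transfers the task well pushes probability mass onto the correct answer token and so reduces this value.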

The charts show that SAE reconstruction significantly decreases the performance of the task vectors. We speculate that English to Spanish is a more complex task for the model and is therefore only represented at a more abstract level: its task vectors do not work until layer ~17, compared to ~13 for antonyms. Indeed, the chart shows that SAE reconstruction almost completely destroys the task vector performance for this task.

Similar results can be seen in all of the other tasks that we tried.

Task-related SAE features

After seeing that the SAE reconstructions of task vectors do keep some of the performance, we assumed that SAE features are capable of capturing the internal task representations to some extent. And this does seem to be the case!

Here are some of the max-activating examples for some of the SAE features in the antonyms task vector:

We also found that using just a single one of those task-related features as a steering direction recovers some of the performance of the original task vector and almost matches the performance of the reconstruction. We also found that deeper layers (~18) tend to have more interpretable task-related features, even if task vectors start reliably working much earlier (layers 12-14).

Unfortunately, these task-related features were mostly buried among dense or “repeated pattern”-related features in the reconstructions. The top (and therefore first picked by gradient pursuit) ~30 features were dense in some examples.

Task vector SAE cleanup

This motivates the creation of a method that can reinforce the features from the reconstruction that have a significant impact on task performance and remove those that do not. We introduce a method called sparse SAE task vector finetuning. A diagram of the method can be seen below.

Loss-guided SAE task vector cleanup

At its core, this method takes a set of zero-shot prompts and fine-tunes a residual activation addition vector at some layer with the objective of reducing the NLL loss on these prompts. However, instead of optimizing the vector directly, the method optimizes the weights of its SAE reconstruction, regularized by an L1 norm of these weights. The key point is that we initialize these weights not with random values (this did not work) but with the SAE reconstruction of the task vector at the corresponding layer. And this works surprisingly well!
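The optimization loop can be illustrated on a toy problem. Everything below is a stand-in: the "model" is a single linear unembedding applied to random activations, the SAE encoder is assumed to be tied to the decoder, and the L1 term is handled with a proximal (soft-thresholding) step that also keeps weights non-negative. These are assumptions of the sketch, chosen so that it runs without an LLM or an autograd library, and they are not necessarily what was used in the experiments.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def nll_and_grad(steer, hidden, unembed, targets):
    """Toy stand-in for the LLM: logits = (hidden + steer) @ unembed."""
    n = len(targets)
    probs = softmax((hidden + steer) @ unembed)        # (n_prompts, vocab)
    nll = -np.log(probs[np.arange(n), targets]).mean()
    probs[np.arange(n), targets] -= 1.0                # now holds probs - onehot
    grad_steer = (probs @ unembed.T).mean(axis=0)      # d(nll)/d(steer), (d_model,)
    return nll, grad_steer

def clean_task_vector(init_weights, decoder, hidden, unembed, targets,
                      l1_coeff, lr, n_steps):
    """Tune SAE feature weights to lower the NLL while staying sparse."""
    weights = init_weights.copy()
    for _ in range(n_steps):
        steer = weights @ decoder                      # vector added to the residual
        _, grad_steer = nll_and_grad(steer, hidden, unembed, targets)
        weights = weights - lr * (decoder @ grad_steer)
        # Proximal step for the L1 penalty with a non-negativity constraint.
        weights = np.maximum(weights - lr * l1_coeff, 0.0)
    return weights

rng = np.random.default_rng(0)
D_MODEL, N_FEATURES, VOCAB, N_PROMPTS = 8, 16, 5, 12
L1_COEFF, LR, N_STEPS = 0.02, 0.01, 2000

decoder = rng.normal(size=(N_FEATURES, D_MODEL))
decoder /= np.linalg.norm(decoder, axis=1, keepdims=True)   # unit-norm features
unembed = rng.normal(size=(D_MODEL, VOCAB)) / np.sqrt(D_MODEL)
hidden = rng.normal(size=(N_PROMPTS, D_MODEL))   # zero-shot activations (toy)
targets = rng.integers(0, VOCAB, size=N_PROMPTS)

task_vector = rng.normal(size=D_MODEL)           # pretend extracted task vector
init_weights = np.maximum(decoder @ task_vector, 0.0)  # tied-encoder SAE code

cleaned = clean_task_vector(init_weights, decoder, hidden, unembed, targets,
                            L1_COEFF, LR, N_STEPS)
print("non-zero features:", np.count_nonzero(init_weights), "->",
      np.count_nonzero(cleaned))
print("NLL:", nll_and_grad(init_weights @ decoder, hidden, unembed, targets)[0],
      "->", nll_and_grad(cleaned @ decoder, hidden, unembed, targets)[0])
```

With a real model, the gradient with respect to the steering vector would come from backpropagating the NLL through the layers after the insertion point, and the learning rate and sparsity coefficient would need tuning.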

These charts contain the NLL loss on the same two tasks as the previous figure with task vectors vs cleaned task vectors depending on different extraction/insertion layers. The loss of the model on a multishot ICL prompt (used to create the task vector) is also provided as reference, although it was calculated on another set of prompts.

We can see that task vector cleaning not only preserves the original task vector performance, but even surpasses it. Another interesting result is that this cleanup method (surprisingly) allows task vectors to work at much earlier layers. It also reduces the number of SAE features to study from ~100 non-zero ones in the SAE reconstruction to just ~8-15. And a large portion of these features are actually task-related (as judged by us).

These charts show normalized features of the SAE task vector reconstructions (blue) vs cleaned task vector features (red + green). Although this method still keeps some of the dense features (and their removal significantly impacts the performance), it greatly boosts the task-related features and removes a significant portion of the useless features. These are some other task-related layer 18 features found using this method:

Cleaned task vectors also seem to offer better generalization than original task vectors. For example, for a sentence-to-sentence English to Spanish translation task, the single-word English to Spanish translation task vector gives 18.125 NLL loss, but we get 14.68 with a version of the same task vector cleaned with our method (sentence pairs were taken from the Helsinki Opus dataset).

Early layer task vector cleanup

Although the cleaned task vectors are better at transferring the context at much earlier layers than the original task vectors, the features at these layers tend to be much less interpretable. We studied this in more depth by first cleaning up a Layer 4 task vector, then inserting it at that early layer, and then looking at the residuals at Layer 18:

Early layer task vector cleanup experiment 

We found that it results in L18 task vectors that share a lot of common features with basic L18 task vectors, while having increased relative weights of task-related features[1]. This suggests that the method may actually be exploiting the internal model task vector formulation mechanism and not just finding some fine-tuning that works.

Further experiments on Gemma 1 2B

Comparison with task vectors

For our Gemma 2B experiments we also trained a full SAE suite using the FineWeb dataset.

Losses on cleaned/original task vectors 

We can see that cleaned task vectors in Gemma 2B perform as well as or better than the original task vectors (regular SAE reconstruction either decreases the performance or loses it completely).

Number of positive features after cleaning

Cleaning leaves 5-15 non-zero features on layer 12 for most of the tasks.

Task features

We again notice that cleaned task vectors always contain interpretable task-related features. Often these features directly define the task: for example, features from antonyms task vectors activate right before the model predicts the second antonym, and features from Spanish to English translation activate right before the model starts translating to English. Below are examples of some of those features.

person <separator> profession feature
antonym1 <separator> antonym2 feature
“predict past simple next” feature
“meaning in English” feature

We believe these features encode the action that the model should perform next. Task features from tasks like present_simple_past_simple reinforce this hypothesis. Their max activations come not from present simple -> past simple pairs, but simply from tokens after which a past-tense verb is expected. This is why we believe they are closer to “predict past simple next” features than to present simple to past simple task features. This may also be caused by the fact that the dataset used to collect max-activating examples did not contain such pairs.

Going further, steering with these task features alone also recovers some of the task vector performance, and only on related tasks. To do this, we use a zero-shot prompt with no explanation and add the scaled feature direction to the “->” (separator) token.
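In code, this intervention amounts to a single addition at one position. The sketch below operates on a plain array and assumes a unit-norm feature direction (for example, an SAE decoder row). With a real model the same update would be applied inside a forward hook at the chosen layer. The function name and the default scale are illustrative.

```python
import numpy as np

def steer_at_separator(residuals, separator_index, feature_direction, scale=20.0):
    """Add a scaled feature direction to the residual at the separator token.

    residuals: (seq_len, d_model) activations of one zero-shot prompt at one layer.
    feature_direction: (d_model,) direction, assumed unit-norm in this sketch.
    """
    steered = residuals.copy()              # leave the original activations intact
    steered[separator_index] += scale * feature_direction
    return steered

rng = np.random.default_rng(0)
residuals = rng.normal(size=(3, 8))         # toy prompt: "hot", "->", <answer slot>
direction = rng.normal(size=8)
direction /= np.linalg.norm(direction)
sep_idx = 1                                 # position of the "->" token
steered = steer_at_separator(residuals, sep_idx, direction, scale=20.0)
print(np.linalg.norm(steered[sep_idx] - residuals[sep_idx]))
```

Only the separator position is modified, so any change in the prediction has to flow through later layers reading from that token.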

Steering with antonyms task feature 

Using a scale of 20 we can see how steering with different task features affects different tasks. The chart below shows a relative loss change (with no steering taken as a base) for pairs of different tasks and task features.

Effect of task features on different tasks
Maximum increase clipped to 1 (from 3.5) 
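The quantity plotted can be computed as below. This is a small sketch with made-up loss values. The sign convention (negative means steering lowered the loss) and the clipping of increases at 1 follow the description above, while the function name is an assumption.

```python
import numpy as np

def relative_loss_change(steered_loss, base_loss, max_increase=1.0):
    """Relative change in loss versus the unsteered baseline, clipped from above.

    Negative values mean steering reduced the loss; increases are capped at
    `max_increase` so that a few large values do not dominate the plot.
    """
    change = (steered_loss - base_loss) / base_loss
    return np.minimum(change, max_increase)

# Made-up numbers: rows are tasks, columns are task features.
base_loss = np.array([[4.0], [5.0], [2.0]])      # no steering, one value per task
steered_loss = np.array([
    [1.0, 4.1, 4.0],
    [5.2, 2.5, 5.0],
    [2.0, 9.0, 0.8],
])
changes = relative_loss_change(steered_loss, base_loss)
print(changes)
```

Each row can then be read as "which features help this task", with a single strongly negative entry per row matching the one-feature-per-task pattern.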

Most of the tasks here have just a single feature with a high relative loss decrease. Other features either do not affect the loss at all or increase it, sometimes several-fold. The exceptions are the translation-from-English tasks, since the feature selected for them is an “English to foreign language translation” feature. In cleaned task vectors it was accompanied by features that activate on any foreign text. We believe these are used to encode the target language.

Another notable example is feature 13458, which improves loss on “person to some property” tasks. This may mean that it has a broader use in the model than just recalling someone’s profession.

Task detector features

Using an algorithm similar to Sparse Feature Circuits (to be published in a future blog post), we found a second type of task-related features. These features seem to play an important role in the formation of task features. We call them task detector features. We hypothesize that they do exactly what their name implies: detect completed tasks in ICL prompts or just generic texts.

Layer 11 antonyms detector
Layer 11 English to foreign language switch detector

Furthermore, we can often use the task vector cleaning algorithm to find these features. To do so, we first collect the task vector from the output tokens. Then we change the zero-shot prompt into a one-shot prompt with a generic example “X -> Y” (literally). Finally, we optimize the loss after steering on the token “Y” at the layer of collection. In our experience, task vectors collected in this fashion do not usually work, while the cleaned ones often do.

Losses on cleaned/original detector task vectors 

Cleaned task detector vectors seem to have a bigger impact on earlier layers than cleaned task vectors. That is expected, since task detector features can be used to construct later task features. We also note that not all features that we classify as detectors are exact task detectors. For example, a detector feature that is used in the “country to capital” task activates on capitals of different countries without any mention of the said countries. This also suggests that task features often encode tasks in a broader sense, i.e. “predict the most appropriate capital in this context”, and not just “predict the capital of this country”.

Layer 11 capital detector feature

We can also plot relative loss change for pairs of different detectors and tasks. This time we will use a scale of 25.

Relative loss change after steering with detector feature on layer 11

We can again see that these features have a noticeable causal effect on the task completion.

Next steps

Our current results suggest three directions for further investigation:

  1. How do task detectors form?
  2. How are they converted into task features?
  3. How do task features manipulate downstream execution?

We expect directions 1 and 3 to have a lot of task-specific mechanisms, and thus results from them may not generalize to a broad range of tasks. Meanwhile, direction 2 seems to be focused on interactions between feature types common to most tasks. Therefore, we are currently working on it by looking at task features with sparse feature circuits.

We may also try training our SAEs on datasets that include many more variations of ICL tasks, which may help distinguish features that are relevant only to specific tasks.

How gradient-based task vector cleanup can be used

Task vector cleanup is most promising as a way of discovering task-related SAE features, and it can specifically help find causally relevant ones without circuit discovery. It can also be seen as a fine-tuning method, and it shares some of the flaws of fine-tuning: it is only moderately easy to implement (the learning rate and sparsity coefficient need tuning), and its faithfulness to the original behavior or task vector is low. It is, however, regularized: the vector is restricted to a 10-100-dimensional subspace defined by the SAE features and is encouraged to be sparse.

The method described is also somewhat similar to 1-dimensional distributed alignment search. There is a potential interpretability illusion here, similar to the one described in An Interpretability Illusion for Subspace Activation Patching. Our primary ways of evaluating sparse task vector fine-tuning are the loss (a steering metric) and the presence of interpretable features (max-activating tokens being a loose estimate of correlation with a target concept). One experiment to test for this illusion would be hand-picking interpretable features, tuning only their weights, and evaluating the resulting vector’s loss.

Acknowledgements

This work was produced during the research sprint of Neel Nanda’s MATS training program. We thank Neel Nanda as our mentor and Arthur Conmy as our TA. We thank McKenna Fitzgerald for research management and feedback. We are grateful to Google for providing us with computing resources through the TPU Research Cloud.

  1. ^

     Top SAE features of L18 TV before: 1065, 33278, 48437, 24363, 46421, 8787, 33045, 26164, 14577, 46849, 33039, 46370, 39179, 47882, 28188, 33953, 22038, 28517, 16422, 11831, 34138, 16805, 43555, 21885, 9224

    Top SAE features of L18 TV after: 1065, 33039, 11831, 46849, 24363, 34138, 22038, 48437, 33953, 6188, 8023, 48299, 4483, 21885, 5405, 18196, 14577, 46370, 25746, 31738, 39233, 28517, 41370, 8787, 43555

1 comment

comment by Clément Dumas (butanium) · 2024-08-13T09:58:12.265Z · LW(p) · GW(p)

Nice work!

I'm curious about the cleanliness of a task vector after removing the mean of some corrupted prompts (i.e., same format but with random pairs). Do you plan to run this stronger baseline, or is there a notebook/codebase I could easily tweak to explore this?