Decoding intermediate activations in llama-2-7b

post by Nina Rimsky (NinaR) · 2023-07-21T05:35:03.430Z · 3 comments

Contents

  Analyzed example
  Informed activation manipulation
  The country layer
  Future work
    Application to oversight
    Application to deception 
3 comments

Produced as part of the SERI ML Alignment Theory Scholars Program - Summer 2023 Cohort

Existing research, such as the post "interpreting GPT: the logit lens" and the related paper "Eliciting Latent Predictions from Transformers with the Tuned Lens", has shown that it is possible to decode intermediate states of transformer activations into interpretable tokens. I applied a similar technique to llama-2-7b, a 7-billion-parameter decoder-only transformer language model. You can find my code here (file with just relevant reusable code).

Unlike the Eliciting Latent Predictions paper, I decoded not only the output of each full transformer block but also the intermediate outputs of the attention mechanism and the MLP, both before and after they are added to the residual stream.
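To make those decoding points concrete, here is a simplified, runnable sketch of a Llama-style decoder block, modeled loosely on Hugging Face's `LlamaDecoderLayer` (it borrows the `input_layernorm`, `self_attn` and `post_attention_layernorm` submodule names). It is an assumption-laden toy, not llama-2-7b: rotary embeddings, the causal mask and the KV cache are omitted, `nn.MultiheadAttention` stands in for Llama's attention, and the SwiGLU MLP projections are inlined. `ToyDecoderLayer` and its `decoded_points` dictionary are hypothetical names that simply record each activation one might decode. Note that attention receives the hidden states after `input_layernorm`, not the raw residual stream.

```python
import torch
from torch import nn

class RMSNorm(nn.Module):
    """Llama-style RMS norm: scale by 1/rms(x), no mean subtraction, no bias."""
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

class ToyDecoderLayer(nn.Module):
    """Hypothetical simplified block: no rotary embeddings, causal mask, or KV cache."""
    def __init__(self, dim, n_heads=2):
        super().__init__()
        self.input_layernorm = RMSNorm(dim)
        # Stand-in for Llama's attention; returns (attn_output, attn_weights).
        self.self_attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.post_attention_layernorm = RMSNorm(dim)
        # SwiGLU MLP projections, inlined here for brevity.
        self.gate_proj = nn.Linear(dim, 4 * dim, bias=False)
        self.up_proj = nn.Linear(dim, 4 * dim, bias=False)
        self.down_proj = nn.Linear(4 * dim, dim, bias=False)
        self.decoded_points = {}

    def forward(self, hidden_states):
        residual = hidden_states
        normed = self.input_layernorm(hidden_states)  # attention sees the normed states
        attn_out = self.self_attn(normed, normed, normed, need_weights=False)[0]
        self.decoded_points["attention output"] = attn_out            # before merging
        hidden_states = residual + attn_out
        self.decoded_points["residual after attention"] = hidden_states  # after merging

        residual = hidden_states
        normed = self.post_attention_layernorm(hidden_states)
        mlp_out = self.down_proj(nn.functional.silu(self.gate_proj(normed)) * self.up_proj(normed))
        self.decoded_points["MLP output"] = mlp_out                   # before merging
        hidden_states = residual + mlp_out
        self.decoded_points["block output"] = hidden_states           # after merging
        return hidden_states

torch.manual_seed(0)
layer = ToyDecoderLayer(dim=8)
out = layer(torch.randn(1, 5, 8))  # (batch, seq_len, hidden_dim)
print(out.shape, list(layer.decoded_points))
```

Each recorded tensor can then be decoded with the model's final norm and unembedding, exactly like the final block output.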

This diagram shows the points at which the intermediate activations are decoded for a particular block. To decode them, I pass them through the same final layer norm and unembedding layer that are normally applied after the last transformer block.
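A minimal sketch of that decoding step, assuming a Llama-style model whose output head is an RMS norm followed by a linear unembedding (in Hugging Face's `LlamaForCausalLM` these are `model.model.norm` and `model.lm_head`). To stay runnable without downloading weights, it uses a toy hidden size, random weights and a hypothetical six-token vocabulary; `RMSNorm` and `decode_activation` are illustrative names, not the author's code.

```python
import torch

torch.manual_seed(0)

class RMSNorm(torch.nn.Module):
    """Llama-style RMS norm: x / rms(x) * weight (no mean subtraction, no bias)."""
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

@torch.no_grad()
def decode_activation(h, final_norm, lm_head, vocab, k=3):
    """Project one intermediate activation (shape (hidden_dim,)) into vocabulary
    space through the final norm and unembedding; return the top-k (token, prob)."""
    logits = lm_head(final_norm(h))  # same path the model uses for its real output
    probs = torch.softmax(logits, dim=-1)
    top = torch.topk(probs, k)
    return [(vocab[i], p) for i, p in zip(top.indices.tolist(), top.values.tolist())]

# Toy setup (hypothetical): hidden size 8, six-token vocabulary, random weights.
hidden_dim = 8
vocab = ["Berlin", "Paris", "capital", "Germany", "France", "the"]
final_norm = RMSNorm(hidden_dim)
lm_head = torch.nn.Linear(hidden_dim, len(vocab), bias=False)

# Any intermediate point is decoded the same way: an attention or MLP output before
# it is added to the residual stream, or the residual stream after the addition.
residual = torch.randn(hidden_dim)
attn_out = torch.randn(hidden_dim)
for name, h in [("attention output", attn_out), ("residual after attention", residual + attn_out)]:
    print(name, decode_activation(h, final_norm, lm_head, vocab))
```

Reusing the final norm rather than decoding the raw activation matters: the unembedding was trained on normalized inputs, so skipping the norm would distort the relative logits.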

As in previous research, I found that the decoded block outputs were interpretable at most layers, except for a few early ones. The other intermediate outputs were also interpretable and gave some intuition about what different layers are responsible for.

Furthermore, the decoded intermediate outputs gave useful information about where applying a steering vector (an added activation change) is likely to be most effective. In a previous experiment, I found that merging activations from different inputs at a specific point in a transformer can result in effective concept mixing. However, such concept mixing does not work equally well at every layer. By inspecting the decoded layer activations, one can more reliably predict where adding a steering vector will be most effective.
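Steering of this kind needs a way to both read and modify a submodule's output during the forward pass. A minimal sketch of one way to do that, following the wrapper pattern discussed in the comments at the end of this post: a module wraps an attention submodule, calls it with the same arguments, optionally adds a steering vector to the hidden states, and records them. It assumes, as in Hugging Face's Llama attention, that the submodule returns a tuple whose first element is the hidden states; `ToyAttention` and `AttentionWrapper` are hypothetical names.

```python
import torch

class ToyAttention(torch.nn.Module):
    """Stand-in for an attention submodule; returns a tuple (hidden_states, attn_weights)."""
    def __init__(self, dim):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, hidden_states):
        return (self.proj(hidden_states), None)

class AttentionWrapper(torch.nn.Module):
    """Wraps an attention submodule to record its output and optionally steer it."""
    def __init__(self, attn):
        super().__init__()
        self.attn = attn
        self.add_tensor = None   # steering vector, broadcast over positions
        self.activations = None  # hidden states from the most recent call

    def forward(self, *args, **kwargs):
        output = self.attn(*args, **kwargs)  # same arguments as the wrapped module
        if self.add_tensor is not None:
            # Only element 0 (hidden states) is modified; other outputs pass through.
            output = (output[0] + self.add_tensor,) + output[1:]
        # Recorded after any steering; detached here (a choice of this sketch).
        self.activations = output[0].detach()
        return output

dim = 4
wrapped = AttentionWrapper(ToyAttention(dim))
x = torch.randn(1, 3, dim)  # (batch, seq_len, hidden_dim)

baseline = wrapped(x)[0]
wrapped.add_tensor = torch.ones(dim)
steered = wrapped(x)[0]
print(torch.allclose(steered - baseline, torch.ones(1, 3, dim)))
```

With a Hugging Face Llama model, such a wrapper could be installed by assigning it in place of `model.model.layers[i].self_attn`; because it forwards `*args` and `**kwargs` unchanged, the decoder layer calls it exactly as it would the original submodule.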

Analyzed example

If we look at the intermediate activations for the final token during a forward pass of 'The capital of Germany is', we can see the following interesting features:

Informed activation manipulation

As mentioned above, the context relevant to the question first becomes visible in the decoded representation at layer 14. This is also the most effective layer for 'changing' the capital of Germany by mixing activations into the attention mechanism's output.

Mixing in the attention activations of 'Croissant, Cheese, Baguette' at layer 13 (one layer earlier) results in 'The capital of Germany is the most popular city in the world', whereas mixing them in at layer 14 results in 'The capital of Germany is the city of Paris, and the capital of France is the city of Paris, and the capital of the United States is the city of Paris' (if generation is continued with this perturbation). The mixed-in activations were scaled by a factor of 0.7, chosen empirically and applied in all tests.
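A minimal sketch of the mixing arithmetic, under assumptions the post does not spell out: the attention output of the source prompt ('Croissant, Cheese, Baguette') is recorded at the chosen layer, its final-token vector is scaled by 0.7, and the result is added to every position of the target prompt's attention output at the same layer (for instance by setting it as the `add_tensor` of a wrapper like the one quoted in the comments). Which token positions the author actually mixed is an assumption here; `mixing_vector`, `mix` and the random tensors are hypothetical stand-ins for recorded activations.

```python
import torch

def mixing_vector(source_attn_out, scale=0.7):
    """Steering vector from a recorded attention output of shape
    (batch, seq_len, hidden_dim). Assumption: only the final token is used."""
    return scale * source_attn_out[:, -1, :]  # (batch, hidden_dim)

def mix(target_attn_out, steer):
    """Add the steering vector to every position of the target's attention output."""
    return target_attn_out + steer[:, None, :]  # broadcast over seq_len

torch.manual_seed(0)
hidden_dim = 8
source = torch.randn(1, 5, hidden_dim)  # stand-in for "Croissant, Cheese, Baguette"
target = torch.randn(1, 7, hidden_dim)  # stand-in for "The capital of Germany is"

steer = mixing_vector(source)
mixed = mix(target, steer)
print(mixed.shape)  # torch.Size([1, 7, 8])
```

Using a single vector rather than position-by-position addition sidesteps the fact that the two prompts have different lengths.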

We can also inspect the intermediate states of the layers after layer 14 when the additional activation is added to the output of the attention mechanism at layer 14.

Whereas previously we could see 'Germany' and 'Capital' combining at layer 19, we now also see the country 'France' in the mix, introduced by the 'Croissant, Cheese, Baguette' perturbation.

I also ran a test with the input 'Artificial Intelligence will impact the world in many ways, particularly' and saw that the attention mechanism of layer 16 pulled in representations of concepts related to jobs and employment.

Therefore, I hypothesized that integrating information about a particular profession before this point would skew the completion. This was indeed correct:

Combining the activation of 'bananas' at layer 14 resulted in a statement about agriculture ('Artificial Intelligence will impact the world in many ways, particularly in the field of agriculture.\n The world is facing a food crisis.'), whereas the default completion is about healthcare ('Artificial Intelligence will impact the world in many ways, particularly in the field of healthcare.\nThe healthcare industry is one of the most important industries in the world. It is responsible for the well-being of millions').

The completion could be steered very predictably at this layer:

The country layer

As part of this investigation, I found that the attention output of layer 24 of llama-2-7b consistently represents relevant information related to countries, even when neither the prompt nor the higher-probability completions are related to countries (credit to Dmitry Vaintrob for finding this with me).

Here is layer 24 with various prompts:

The most important political question in the world is

The decoded attention outputs are clearly the most politically salient countries: Ukraine, Russia, China, Palestine, Israel, Iran, etc.

My favorite dish to eat is

The decoded attention outputs are all popular global cuisines, such as Asian and Italian, even though later layers discard country information and prefer more generic answers:

Ramen generally consists of

(Predictably, this results in tokens related to Japan.)

Pineapples are a delicious fruit

Layer 24 represents that pineapples are associated with Hawaii, even though the final greedily-decoded output is: Pineapples are a delicious fruit that can be eaten fresh or used in cooking. They are also a popular ingredient in many desserts and drinks
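One hypothetical way to make the "consistently represents countries" observation checkable over many prompts is to decode the layer-24 attention output as described above and count how many of the top-k tokens fall in a hand-written country list. The sketch below shows only the counting step; the term list and the token lists are made-up placeholders, not results from llama-2-7b, and `COUNTRY_TERMS`, `normalize` and `country_fraction` are illustrative names.

```python
# Hypothetical, hand-written term list; a real check would need a much larger one.
COUNTRY_TERMS = {"ukraine", "russia", "china", "israel", "iran", "japan", "italy", "france", "germany"}

def normalize(token):
    """Strip the SentencePiece word-boundary marker and whitespace, then lowercase."""
    return token.replace("\u2581", "").strip().lower()

def country_fraction(top_tokens, terms=COUNTRY_TERMS):
    """Fraction of decoded top-k tokens that appear in the country list."""
    if not top_tokens:
        return 0.0
    hits = sum(normalize(t) in terms for t in top_tokens)
    return hits / len(top_tokens)

# Placeholder decoded tokens for two prompts (made up for illustration).
decoded = {
    "prompt A": ["\u2581Japan", "\u2581noodles", "\u2581Japanese", "\u2581broth"],
    "prompt B": ["\u2581the", "\u2581a", "\u2581fruit", "\u2581sweet"],
}
for prompt, toks in decoded.items():
    print(f"{country_fraction(toks):.2f}  {prompt}")
```

Exact string matching misses adjectives and split subwords (here "Japanese" is not counted), so a real measurement would need a more careful matching rule.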

Future work

Application to oversight

I would be interested in seeing whether the interpretability of intermediate results of a model's computation can be leveraged for more efficient and effective oversight. Basic filtering-based oversight mechanisms rely on an overseer model filtering out unacceptable outputs that have been produced by the main model. However, if the overseer can detect whether it should filter out the output earlier in the computation process, this would make filtering mechanisms more efficient, reducing the 'safety tax.' 
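A minimal sketch of the early-exit idea, under loud assumptions: the model is treated as a list of per-layer functions, a hypothetical `decode_top_tokens` callable maps an intermediate state to decoded tokens (for example via the final norm and unembedding, as described earlier), and the "overseer" is only a blocklist check. A real overseer would be far more sophisticated; the point is just that a flag raised at layer k lets the remaining layers, and any further generation, be skipped.

```python
from typing import Callable, Iterable, List, Optional, Set, Tuple

def run_with_early_filter(
    state,
    layers: List[Callable],
    decode_top_tokens: Callable[[object], Iterable[str]],
    blocklist: Set[str],
    check_from: int = 0,
) -> Tuple[object, Optional[int]]:
    """Run layers in order; after each layer with index >= check_from, decode the
    state and stop as soon as any decoded token is on the blocklist.

    Returns (state, flagged_layer); flagged_layer is None if nothing was flagged.
    """
    for i, layer in enumerate(layers):
        state = layer(state)
        if i >= check_from and blocklist & set(decode_top_tokens(state)):
            return state, i  # the remaining layers are never computed
    return state, None

# Toy usage: each "layer" appends a word, and "decoding" returns the word list.
layers = [lambda s, w=w: s + [w] for w in ["plan", "weapon", "recipe"]]
final, flagged = run_with_early_filter([], layers, decode_top_tokens=lambda s: s, blocklist={"weapon"})
print(flagged, final)  # 1 ['plan', 'weapon']
```

The `check_from` parameter reflects the finding above that early layers decode poorly, so checking them would mostly waste compute.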

Application to deception 

If a model uses a deceptive reasoning process to decide to produce an acceptable output, perhaps this could be detected via a variation of this method. For instance, if a model 'knows' that it is being audited and therefore answers in a way humans would approve of, but would otherwise have given an unacceptable answer, this might be detectable if some intermediate representation decodes to information related to auditing, checking, or human approval.

3 comments

Comments sorted by top scores.

comment by amirrahnama · 2024-01-10T14:38:07.456Z

Thanks for the nice tutorial. 

I'm having trouble understanding your code (I am new to PyTorch). When you calculate the attention activations:

 

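    def forward(self, *args, **kwargs):
        # Call the wrapped attention submodule with the same arguments
        output = self.attn(*args, **kwargs)
        if self.add_tensor is not None:
            # Add a vector to the hidden states (element 0 of the output tuple)
            output = (output[0] + self.add_tensor,) + output[1:]
        # Save the (possibly modified) hidden states for later decoding
        self.activations = output[0]
        return output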

 

What is the argument that is passed to the self.attn function?

I tried passing the following but cannot reproduce your results:

  • model.layers.layers[0].self_attn(past_key_values[0][0].reshape(1, 10, 32* 128))[0]
  • model.model.embed_tokens(inputs.input_ids.to(device)) 

Neither of these can reproduce your results. Can you clarify this? 

comment by Nina Rimsky (NinaR) · 2024-01-10T20:41:39.425Z

The wrapper modules simply wrap existing submodules of the model: they call whatever they are wrapping (in this case self.attn) with the same arguments, and then save some state or manipulate the output. It's just the syntax I chose so that I can both save state from submodules and manipulate the values of some intermediate state. If you want to see exactly how that submodule is called, you can look at the Hugging Face Llama source code. In the code you gave, I am adding a vector to the hidden_states returned by that attention submodule.

comment by amirrahnama · 2024-01-11T10:50:28.546Z

Thanks, Nina, for pointing me to the Hugging Face forward pass. I now realize I was skipping the input layer norm calculation. Now I can reproduce your numbers :)