Redundant Attention Heads in Large Language Models for In-Context Learning
post by skunnavakkam · 2024-09-01 · This is a link post for https://skunnavakkam.github.io/blog/degenerate-attention-heads-in-large-language-models/
In this post, I claim a few things and offer some evidence for these claims. Among these things are:
- Language models have many redundant attention heads for a given task
- In context learning works through addition of features, which are learnt through Bayesian updates
- The model likely breaks the task down into various subtasks, each of which is added as a feature. I assume these are handled by MLPs (this is also the claim I'm least confident about)
To set some context, the task I'm going to model gives the model pairs of numbers in the following format:
`(x, y)\n`
where, for each example, $y = 2x + 3$. As a concrete example, I use:
(28, 59)
(86, 175)
(13, 29)
(55, 113)
(84, 171)
(66, 135)
(85, 173)
(27, 57)
(15, 33)
(94, 191)
(37, 77)
(14, 31)
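For concreteness, here is a small Python sketch of how a prompt in this format could be assembled and checked against the rule. The exact whitespace, the open query `(x, ` at the end, and the helper names `task` and `build_prompt` are my assumptions for illustration, not necessarily the precise prompt used in the experiments.

```python
# Sketch of the in-context prompt format described above.
# Assumption: each pair is written as "(x, y)" plus a newline, and the query
# is left open as "(x, " so that the model has to produce y itself.

EXAMPLES = [
    (28, 59), (86, 175), (13, 29), (55, 113), (84, 171), (66, 135),
    (85, 173), (27, 57), (15, 33), (94, 191), (37, 77), (14, 31),
]


def task(x: int) -> int:
    """The hidden rule every example above follows: y = 2x + 3."""
    return 2 * x + 3


def build_prompt(examples, query_x: int) -> str:
    """Write the solved examples one per line, then open a query pair."""
    lines = [f"({x}, {y})\n" for x, y in examples]
    return "".join(lines) + f"({query_x}, "


if __name__ == "__main__":
    # Every example in the list is consistent with the rule.
    assert all(task(x) == y for x, y in EXAMPLES)
    print(repr(build_prompt(EXAMPLES[:3], 50)))
    print(task(50))  # the completion we would hope the model produces
```

Leaving the query pair open means the next tokens the model predicts are exactly the answer, which keeps the comparison between runs simple.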
All experiments here were done with Llama-3-8B using TransformerLens, running on an A100-40GB, unless specified otherwise.
Claim 1. Language models have many redundant attention heads for a given task
To probe this, I patch activations from the residual stream of a model given context into a model that isn't. This makes it possible to see where the task is formed. Initially, without any other modifications, the first hook point from which patching carries over some semblance of the original task is layer 12. At this layer, the patched model seems to have learned that $y = 2x + 3$.
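To make the patching procedure concrete, here is a toy, pure-Python sketch; it does not use TransformerLens or Llama-3-8B. The residual stream is a list of per-position feature dicts, each layer adds an update to it, and patching overwrites the final position of the context-free run with the cached value from the run with context. The two "heads" (`task_head`, `copy_x_head`), the three-layer layout, and the feature names are hypothetical, chosen only so that a sweep over patch points has the same shape as the experiment.

```python
from typing import Callable, Dict, List, Optional

Features = Dict[str, float]
Resid = List[Features]                          # one feature dict per position
Layer = Callable[[Resid], Dict[int, Features]]  # returns position -> update

FINAL = 2  # position where the answer is read out (0: examples, 1: query x)


def noop(resid: Resid) -> Dict[int, Features]:
    return {}


def task_head(resid: Resid) -> Dict[int, Features]:
    # Hypothetical head: if the examples position carries context, write a
    # "task" feature into the final position.
    if resid[0].get("context", 0.0) > 0:
        return {FINAL: {"task": 1.0}}
    return {}


def copy_x_head(resid: Resid) -> Dict[int, Features]:
    # Hypothetical head: move the query's x into the final position.
    return {FINAL: {"x": resid[1]["x"]}}


LAYERS: List[Layer] = [noop, task_head, copy_x_head]


def run(resid: Resid, patch_at: Optional[int] = None,
        patch_value: Optional[Features] = None):
    """Run the toy model. cache[i] is the final-position residual before
    layer i (cache[-1] is after the last layer). If patch_at is set, the
    final position is overwritten with patch_value at that point."""
    resid = [dict(pos) for pos in resid]
    cache: List[Features] = []
    for i in range(len(LAYERS) + 1):
        if i == patch_at:
            resid[FINAL] = dict(patch_value)
        cache.append(dict(resid[FINAL]))
        if i < len(LAYERS):
            # Each layer adds its update into the residual stream.
            for pos, update in LAYERS[i](resid).items():
                for name, value in update.items():
                    resid[pos][name] = resid[pos].get(name, 0.0) + value
    return resid, cache


def readout(resid: Resid) -> Optional[float]:
    # Apply y = 2x + 3 only if the task feature reached the final position.
    final = resid[FINAL]
    if final.get("task", 0.0) > 0.5 and "x" in final:
        return 2 * final["x"] + 3
    return None


WITH_CONTEXT: Resid = [{"context": 1.0}, {"x": 28.0}, {}]
WITHOUT_CONTEXT: Resid = [{}, {"x": 50.0}, {}]

if __name__ == "__main__":
    _, clean_cache = run(WITH_CONTEXT)
    for i in range(len(LAYERS) + 1):
        patched, _ = run(WITHOUT_CONTEXT, patch_at=i, patch_value=clean_cache[i])
        print(f"patch before layer {i}: {readout(patched)}")
```

Sweeping the patch point gives no task before layer 2, the task applied to the context-free run's own x (103) at layer 2, and the context run's own answer (59) when patching after x has already been copied in. In this toy, that last case is how a patch can transfer an answer rather than a task.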
Ablating all attention heads on layers 10 and 11 of the model given context (hereafter Model A; the model without context is Model B) does not change its answer significantly. However, when patching is repeated, the first point that works is the residual stream before layer 14.
The attention patterns corroborate this: backup heads on layer 13 activate very strongly after the initial ablation. Ablating layer 13 does something similar, except that this time the first working layer shifts from 14 to 16, with heads on layer 15 activating very strongly in response.
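The backup-head behavior can be caricatured in a few lines of Python. Everything in this toy is hypothetical: the layer numbers are chosen to echo the ablations of layers 10, 11, 13 and 15 described here, the task feature only "counts" once it reaches a threshold of 1.0, and backup heads write the feature only when it is still missing (a crude model of self-repair). It shows how zero-ablating layers can push the first layer at which the task is present in the residual stream later without removing it.

```python
from typing import Iterable, Optional

N_LAYERS = 32      # Llama-3-8B has 32 transformer layers
THRESHOLD = 1.0    # hypothetical: the task feature "counts" at this size

# Hypothetical head layout echoing the layer numbers in the text.
PRIMARY = {10: 0.5, 11: 0.5}   # always write part of the task feature
BACKUP = {13: 1.0, 15: 1.0}    # write it only if it is still missing


def first_task_layer(ablated: Iterable[int] = ()) -> Optional[int]:
    """First layer i whose *input* residual stream already carries the task
    feature, or None if it never forms. Layers in `ablated` have all of
    their heads zero-ablated."""
    ablated_layers = set(ablated)
    task = 0.0
    for layer in range(N_LAYERS):
        if task >= THRESHOLD:
            return layer
        if layer in ablated_layers:
            continue
        task += PRIMARY.get(layer, 0.0)
        if task < THRESHOLD:  # backup heads only fire on a shortfall
            task += BACKUP.get(layer, 0.0)
    return N_LAYERS if task >= THRESHOLD else None


if __name__ == "__main__":
    for ablated in [(), (10, 11), (10, 11, 13), (10, 11, 13, 15)]:
        print(ablated, first_task_layer(ablated))
```

In this toy, ablating all four layers breaks the task entirely, whereas the real model still answered correctly after layer 15 was ablated, so the actual redundancy seems to run deeper than this sketch encodes.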
Ablating layer 15 still results in the model reaching the correct answer. However, patching behaves differently in this case: at no patching layer does Model B start producing $2x + 3$ for its own $x$. Instead, Model B copies over almost exactly Model A's answer. My current hypothesis is that the model adds the feature for the value of Model A's $x$ to the residual stream before it is able to add the feature corresponding to the task itself, so that this portion of the residual stream, when patched, produces the output of the task applied to Model A's $x$.
Clearly, many parts of the model can carry out in-context learning; otherwise, ablating the heads responsible for it would produce an incorrect answer. This weakly supports Anthropic's initial hypothesis that in-context learning is driven primarily by induction heads, since induction heads are also found throughout large language models.
Claim 2. In context learning works through addition of features, which are learnt through Bayesian updates
This is a weaker claim than the first, and I have less evidence to support it.
I have a good amount of evidence that language models update approximately Bayesianly on in-context learning tasks. More specifically, split the answers into two buckets, $p_n$ and $q_n$, where $p_n$ is the probability of the correct answer and $q_n$ is the probability of an incorrect answer after seeing $n$ examples.
With each example seen, a Bayesian update would be:
$$p_{n+1} = \frac{k\,p_n}{k\,p_n + (1-k)\,q_n}, \qquad q_{n+1} = \frac{(1-k)\,q_n}{k\,p_n + (1-k)\,q_n}$$
with the model having some prior $p_0$, with $q_0 = 1 - p_0$, and with $k$ being the model's expectation that the observation is true.
As a result,
$$p_n = \frac{k^n p_0}{k^n p_0 + (1-k)^n q_0}, \qquad q_n = \frac{(1-k)^n q_0}{k^n p_0 + (1-k)^n q_0}$$
We can now drop the normalization constant, instead saying
$$\log \tilde{p}_n = \log p_0 + n \log k, \qquad \log \tilde{q}_n = \log q_0 + n \log(1-k), \qquad p_n = \frac{\tilde{p}_n}{\tilde{p}_n + \tilde{q}_n}$$
The normalization constant then only needs to be kept around to ensure that $\tilde{p}_n$ and $\tilde{q}_n$ stay a reasonable size, since they would lose precision as they shrink. In this way, we can represent Bayesian updates as a sequence of additions in log space, with the normalization (a softmax) built in for free by the language model.
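A quick numerical check of this derivation, as a sketch: iterating the normalized Bayesian update gives the same posterior as adding the constant terms $\log k$ and $\log(1-k)$ in log space and applying a single softmax at the end. Here $p_0$ is the prior on the correct answer and $k$ the per-example likelihood; the values $p_0 = 0.1$ and $k = 0.8$ are arbitrary illustrative choices, and the function names are my own.

```python
import math


def bayes_iterative(p0: float, k: float, n: int) -> float:
    """Apply the normalized update p <- k p / (k p + (1 - k) q) n times."""
    p, q = p0, 1.0 - p0
    for _ in range(n):
        z = k * p + (1.0 - k) * q   # normalization constant
        p, q = k * p / z, (1.0 - k) * q / z
    return p


def softmax_first(a: float, b: float) -> float:
    """First component of softmax([a, b]); subtracting the max keeps it stable."""
    m = max(a, b)
    ea, eb = math.exp(a - m), math.exp(b - m)
    return ea / (ea + eb)


def bayes_additive(p0: float, k: float, n: int) -> float:
    """Same posterior from unnormalized log scores: one constant addition per
    example for each bucket, and a single softmax at the very end."""
    log_p, log_q = math.log(p0), math.log(1.0 - p0)
    for _ in range(n):
        log_p += math.log(k)
        log_q += math.log(1.0 - k)
    return softmax_first(log_p, log_q)


if __name__ == "__main__":
    for n in range(6):
        print(n, round(bayes_iterative(0.1, 0.8, n), 4),
              round(bayes_additive(0.1, 0.8, n), 4))
```

Working in log space also sidesteps the precision problem: in double precision, $k^n p_0$ underflows to zero for large $n$ (for example $0.8^{5000}$), but $\log p_0 + n \log k$ does not.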
This gives a clean explanation for why features seem to be updated Bayesianly: for each new example, the $n$th example attends to each previous example that confirms the pattern, and each of these adds a constant update term (equivalent to $\log k$) to the $n$th example.
Mechanistically, I would expect this to look like the feature being added from every pair of examples that the attention head judges to be similar, so that the model's expectation for the correct answer is formed by summing these per-pair update terms on top of the prior.
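The pairwise picture can be sketched as log-odds accumulation: the log-odds of the correct answer at example $n$ start at the prior log-odds and gain $\log\frac{k}{1-k}$ (the difference between the $\log k$ and $\log(1-k)$ updates of the two buckets) for every earlier example the head judges similar. The similarity function here, which compares full pairs by their $y - 2x$ offset, is a hypothetical stand-in for whatever the attention head actually computes, and $p_0$, $k$ are illustrative.

```python
import math
from typing import Callable, List, Tuple

Example = Tuple[int, int]


def offset_similarity(a: Example, b: Example) -> float:
    # Hypothetical stand-in for "the head thinks these two examples look
    # similar": 1.0 if both pairs have the same y - 2x offset, else 0.0.
    return 1.0 if a[1] - 2 * a[0] == b[1] - 2 * b[0] else 0.0


def correct_prob_at(examples: List[Example], n: int, p0: float, k: float,
                    sim: Callable[[Example, Example], float] = offset_similarity) -> float:
    """Probability of the correct answer at example n: prior log-odds plus
    log(k / (1 - k)) for each earlier example, weighted by similarity."""
    step = math.log(k / (1.0 - k))
    log_odds = math.log(p0 / (1.0 - p0))
    for i in range(n):
        log_odds += sim(examples[i], examples[n]) * step
    return 1.0 / (1.0 + math.exp(-log_odds))  # two-way softmax == sigmoid


if __name__ == "__main__":
    clean = [(28, 59), (86, 175), (13, 29), (55, 113)]
    noisy = [(28, 59), (86, 100), (13, 29), (55, 113)]
    print(correct_prob_at(clean, 3, 0.1, 0.8), correct_prob_at(noisy, 3, 0.1, 0.8))
```

An earlier example the head does not match simply contributes nothing, so the result is the same as if that example had never appeared: with one mismatched example out of three, the probability at the fourth example equals the closed-form posterior after two.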
This supports the general linear representation hypothesis. It also fits nicely with the idea that different features are learnt in-context at different points in the model, since no knowledge of the prior is needed for this to work. In this picture, features such as parity, the fact that the output is a number, and the fact that the output involves an addition can each be added at different points in the model.
One thing I'm still slightly unclear about is how the model crystallizes this into the answer, and why patching at different points shows different behavior as to when the model crystallizes the answer (such as in Claim 1, where Model B started reporting Model A's actual answer instead of learning the task).
My current hypothesis for this has two parts:
- That the point represented by the sum of the task features (including the feature that the output is a number) and a prior maps to the answer
- That, before the layer where we see Model B reporting the answer to Model A's task, the prior formed by Model A has already been added to the residual stream
Claim 3. The model breaks down the task into subtasks. MLPs are responsible for this.
The algorithm I have in mind is roughly:
1. One attention head attends from $y_i$ to $x_i$ for a single example $(x_i, y_i)$, and brings information about what $x_i$ is. Hopefully, and likely, the vectors that convey what $x_i$ is and what $y_i$ is both end up in the residual stream position of $y_i$, and are approximately orthogonal. I think it's fairly likely that this is an induction head!
2. Then, an MLP acts on this, breaking the two directions down into a superposition of many features representing what the transformation is, with the size of each feature approximating how confident the MLP is that that transformation is the correct one.
3. After this, another attention head attends from the token where the answer is being predicted back to the token before each earlier example's $y_i$. This head is responsible for adding the feature that corresponds to the transformation, updating it Bayesianly. I think the size of this update is probably proportional only to the MLP's confidence in the transformation.
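A toy, end-to-end version of these three steps, in pure Python. The candidate set of transformations, the scoring rule in the "MLP" step (add $\log k$ if a candidate explains the pair, $\log(1-k)$ otherwise), and the uniform prior are all hypothetical; the point is only to show how per-example confidence updates can be summed into a posterior over transformations, not to claim this is what the model computes.

```python
import math
from typing import Callable, Dict, List, Tuple

# Hypothetical candidate transformations the "MLP" step can recognise.
CANDIDATES: Dict[str, Callable[[int], int]] = {
    "2x+3": lambda x: 2 * x + 3,
    "2x+1": lambda x: 2 * x + 1,
    "x+31": lambda x: x + 31,
    "3x": lambda x: 3 * x,
}


def step1_gather(example: Tuple[int, int]) -> Tuple[int, int]:
    # Step 1 (attention): bring x_i into y_i's position so both are available
    # together. In this toy that is just the pair itself.
    return example


def step2_mlp(x: int, y: int, k: float = 0.9) -> Dict[str, float]:
    # Step 2 (MLP): turn (x, y) into a log-confidence update per candidate.
    return {name: math.log(k) if f(x) == y else math.log(1.0 - k)
            for name, f in CANDIDATES.items()}


def step3_accumulate(examples: List[Tuple[int, int]]) -> Dict[str, float]:
    # Step 3 (attention): add each example's update into the query position,
    # starting from a uniform prior, then normalise with a softmax.
    scores = {name: math.log(1.0 / len(CANDIDATES)) for name in CANDIDATES}
    for ex in examples:
        for name, update in step2_mlp(*step1_gather(ex)).items():
            scores[name] += update
    m = max(scores.values())
    z = sum(math.exp(s - m) for s in scores.values())
    return {name: math.exp(s - m) / z for name, s in scores.items()}


def predict(examples: List[Tuple[int, int]], query_x: int):
    """Pick the most probable transformation and apply it to the query."""
    posterior = step3_accumulate(examples)
    best = max(posterior, key=posterior.get)
    return best, CANDIDATES[best](query_x), posterior


if __name__ == "__main__":
    print(step3_accumulate([(28, 59)]))
    print(predict([(28, 59), (86, 175), (13, 29), (55, 113)], 50)[:2])
```

After the first example alone, `2x+3` and `x+31` are tied (both map 28 to 59); the second example separates them, which is the kind of gradual, additive sharpening the Bayesian picture predicts.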
My evidence for this is that of the heads that were ablated to show redundancy, the heads that contributed the most in any given layer had attention patterns that looked similar to what you would expect if point 3 were true.