Contextual attention heads in the first layer of GPT-2
post by Alex Gibson · 2025-01-20
Overview:
Through mathematical analysis of GPT-2 small's first-layer attention patterns, I find six attention heads whose combined output can be approximated as a "bag-of-tokens" representation of the previous ~50 tokens: a weighted sum of extended embeddings of tokens, where the weight of a particular token depends only on its position relative to the current position.
The model can use this bag-of-tokens for lots of downstream tasks. The bag-of-tokens of input sequences with distinct token distributions will naturally linearly separate, allowing the model to learn contextual features by projecting the bag-of-tokens along directions corresponding to particular distributions of tokens.
To demonstrate the predictive power of the bag-of-tokens approximation, I find first-layer contextual neurons without having to run the model by composing the extended embeddings of tokens with the input to the first MLP layer. I show that these neurons can be thought of as performing Naive Bayes classification on the input, and that the approximation gets close to matching the observed activations of these contextual neurons.
Lots of the contextual neurons seem monosemantic, whereas others seem more complicated. A few example contextual neurons:
- 'Medical studies'
- 'Legal text'
- 'Bible'
- 'British text vs American text'
- 'Spanish'
- '20th century conflict'
- 'Pokemon'
- 'Astronomy'
- 'Climate Change'
- 'Conspiracy theories'
Decomposition of First Layer Attention Patterns:
I assume familiarity with the Mathematical Framework For Transformer Circuits.
I approximate the layer norm as linear because its scale consistently falls within 0.15–0.175. I refine this approximation in the appendix. When referring to $E$ and $P$, I take these to be the post-layer-norm approximations.
I write $E$ for the token embedding matrix $W_E$, $Q$ for the query matrix $W_Q$, $K$ for the key matrix $W_K$, and $P$ for the positional embedding matrix $W_{pos}$.
I simplify notation by concatenating letters to denote matrix multiplication, applying transpositions when appropriate. For instance, $EQKE$ refers to $W_E W_Q W_K^T W_E^T$.
Consider a sequence of tokens $t_0, t_1, \ldots, t_n$, where $n$ represents the current destination position.
For any position $i \leq n$, the $i$th attention score $A_i$ measures the attention weight that $t_n$ places on $t_i$.
This score combines contributions from both token embeddings ($E$) and positional embeddings ($P$) through the query and key matrices:

$$A_i = \frac{(E_{t_n} + P_n)\, Q K^T \, (E_{t_i} + P_i)^T}{\sqrt{d_{head}}}$$
Technically, GPT-2 has bias terms on each of the queries and keys, and layer-norm has a weight and bias as well. These bias terms make the equations harder to read but don't cause any complications, so I omit them here.
The transformer applies a softmax operation to these attention scores, exponentiating each score for $i \leq n$ and normalizing the resulting vector. The exponentiated attention score decomposes into two independent components:
$\text{pos}_i$: depends exclusively on the token's position, $i$.
$\text{tok}_i$: depends exclusively on the token's content, $t_i$.
This decomposition can be expressed as:

$$e^{A_i} = \text{pos}_i \cdot \text{tok}_i, \quad \text{where } \text{pos}_i = e^{(E_{t_n} + P_n) Q K^T P_i^T / \sqrt{d_{head}}}, \quad \text{tok}_i = e^{(E_{t_n} + P_n) Q K^T E_{t_i}^T / \sqrt{d_{head}}}$$
I define the positional pattern as:

$$p_i = \frac{\text{pos}_i}{\sum_{j \leq n} \text{pos}_j}$$
Given the destination token $t_n$, this pattern can be computed independently of the other sequence tokens, representing the softmax of the position-dependent components of the attention score.
The positional pattern can be viewed as telling you how much each position in the sequence gets weighted, independent of the content of the token at that position. It is the attention pattern you would get if every token in the sequence were identical.
The positional pattern is slightly influenced by the destination token $t_n$. However, for all attention heads in GPT-2-small's first layer, the overall character of the positional pattern remains consistent regardless of $t_n$.
The final softmax probabilities at position $i$ are:

$$\text{softmax}_i = \frac{\text{pos}_i \cdot \text{tok}_i}{\sum_{j \leq n} \text{pos}_j \cdot \text{tok}_j} = \frac{p_i \cdot \text{tok}_i}{\sum_{j \leq n} p_j \cdot \text{tok}_j}, \quad \text{for } i \leq n.$$
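To make this concrete, here is a minimal sketch of the decomposition using TransformerLens (the choice of head, the constant layer-norm scale, and the neglect of centering and the key bias are illustrative simplifications, not exact choices from the post):

```python
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
layer, head = 0, 2           # head 2 is one of the contextual heads
ln_scale = 0.16              # treat the first layer norm's scale as a constant

# Post-layer-norm token and positional embeddings (E and P in the text),
# ignoring centering and the layer-norm bias for simplicity.
E = model.W_E / ln_scale * model.blocks[0].ln1.w
P = model.W_pos / ln_scale * model.blocks[0].ln1.w

W_Q, W_K = model.W_Q[layer, head], model.W_K[layer, head]   # (d_model, d_head)
d_head = model.cfg.d_head

n = 399                               # destination position
t_n = model.to_single_token(" the")   # a fixed destination token
query = (E[t_n] + P[n]) @ W_Q + model.b_Q[layer, head]      # (d_head,)
# (the key bias adds the same constant to every score, so it cancels in the softmax)

# pos_i: position-dependent scores; their softmax is the positional pattern p_i
pos_scores = (P[: n + 1] @ W_K) @ query / d_head**0.5
pos_pattern = torch.softmax(pos_scores, dim=0)

# tok_i: content-dependent component, computed here for every token in the vocabulary
tok_scores = (E @ W_K) @ query / d_head**0.5
tok = torch.exp(tok_scores - tok_scores.max())   # shared constant cancels in the softmax
```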
Positional pattern visualization:
Below is the positional pattern of each of the first layer attention heads for destination position $n = 399$ and a fixed destination token $t_n$, alongside a sample attention pattern obtained from running the model on the Bible. Each attention pattern of length 400 has been reshaped into a 20x20 grid for visualization purposes.
[Figures: positional patterns of the twelve first-layer attention heads at $n = 399$, each reshaped into a 20x20 grid, alongside a sample attention pattern from the Bible.]
Behavioural Classification of First Layer Attention Heads:
Based on the positional patterns, and observation of the content-dependent components $\text{tok}_i$, we can classify the first layer attention heads into distinct behavioural groups:
| Group | Heads | Behaviour |
|---|---|---|
| Detokenization heads | 3, 4, 7 | Attend mainly to the previous few tokens; used for detecting known $n$-grams. Their positional pattern is mostly translation invariant, so that $n$-grams have consistent representations regardless of position. |
| Contextual attention heads | 0, 2, 6, 8, 9, 10 | Positional pattern that exponentially decays over the previous ~50 tokens. $\text{tok}_i$ tends to be fairly consistent across different destination tokens $t_n$. One of these heads also has duplicate-token behaviour, but tends to attend fairly uniformly to non-duplicate tokens. |
| Duplicate token heads | 1, 5 | Attend almost entirely to duplicate copies of the current token $t_n$. One head detects duplicate tokens close to uniformly across positions; the other only detects nearby duplicate tokens, so it can do more precise relative indexing. |
| Miscellaneous | 11 | Uncertain of its role. |
Approximating Softmax Probabilities:
Continuing from the softmax expression above, where $n$ is the destination position:

$$\text{softmax}_i = \frac{p_i \cdot \text{tok}_i}{\sum_{j \leq n} p_j \cdot \text{tok}_j}$$

This expression is difficult to work with because the denominator of the softmax depends on all tokens in the sequence, rather than solely on $t_i$ and $i$.
For contextual attention heads, however, the empirical value of the denominator depends mostly on the destination position $n$, rather than the particular $\text{tok}_j$ values. When $n$ is fixed, the denominator tends to concentrate around the same value for a variety of input sequences.
Fix $t_n = $ ' the', assuming that the destination token doesn't affect the content-dependent component too much. Even if this assumption breaks down for certain contextual attention heads, it will at least tell us what happens when the destination token is a stopword.
For this fixed $t_n$, I plotted the denominator of the softmax for various texts as a function of the destination position $n$:
[Figure: softmax denominators of each first-layer head on several texts, as a function of destination position $n$.]
The normalisation factor initially decays because the <end-of-text> token has a large content-dependent component. This reduces the impact of the contextual attention heads for the first ~100 tokens of the sequence.
Of the contextual heads, heads 6 and 8 vary the most in normalisation factor across different contexts. A lot of the variation can be attributed to differences in the percentage of keywords in the text. Head 6 emphasises keywords, whereas Head 8 emphasises stopwords. The Bible has more newline characters than other texts because it is composed of verses, so it has a below-average normalisation factor for head 6, and an above-average normalisation factor for head 8.
Independent model of normalisation factors:
For a fixed text, certain heads seem to have denominators which are a smoother function of $n$ than others. What determines this?
As a simple model, imagine that each $\text{tok}_i$ was independently drawn from some distribution with mean $\mu$ and variance $\sigma^2$, with the positional pattern $p_i$ fixed. Then we have $\mathbb{E}\left[\sum_i p_i \, \text{tok}_i\right] = \mu$ and $\mathrm{Var}\left(\sum_i p_i \, \text{tok}_i\right) = \sigma^2 \sum_i p_i^2$.
Because of the constraint that $\sum_i p_i = 1$, $\sum_i p_i^2$ becomes a measure of how spread out the positional pattern is. If an attention head pays attention to $k$ tokens roughly uniformly, then $\sum_i p_i^2 \approx 1/k$. So, the more tokens a positional pattern is spread over, the lower the variance in its normalisation factor. The contextual attention heads are averaging over enough tokens that the normalisation factor varies smoothly.
$\sigma^2$ is determined by how much the content-dependent weighting varies between tokens. Head 6 and Head 8 have a very similar positional pattern, but head 6 emphasises keywords above stopwords, whereas head 8 attends fairly evenly to tokens. Therefore Head 6 has a higher variance than head 8.
Heads 3, 4, and 7 (the de-tokenisation heads) all have high variance in their content-dependent component, and a positional pattern which isn't spread over many tokens. Hence the confetti-like scatter in the plot above. Of course, this doesn't mean these heads are particularly hard to analyse; it's just not appropriate to model them as having a constant normalisation factor for fixed $n$.
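As a quick sanity check on this independence model (pure numpy, with made-up lognormal $\text{tok}_i$ values rather than anything taken from GPT-2), the standard deviation of $\sum_i p_i \, \text{tok}_i$ scales with $\sqrt{\sum_i p_i^2}$, so a spread-out positional pattern gives a much more stable normalisation factor:

```python
import numpy as np

rng = np.random.default_rng(0)

def norm_factor_std(pos_pattern, sigma=1.0, trials=10_000):
    """Empirical std of sum_i p_i * tok_i when the tok_i are i.i.d. and positive."""
    tok = rng.lognormal(mean=0.0, sigma=sigma, size=(trials, len(pos_pattern)))
    return (tok @ pos_pattern).std()

length = 400
# Spread-out pattern (exponential decay over ~50 tokens, like the contextual heads)
# vs. a concentrated one (mass on the last few tokens, like heads 3, 4, 7).
spread = np.exp(-np.arange(length)[::-1] / 50.0); spread /= spread.sum()
concentrated = np.exp(-np.arange(length)[::-1] / 1.0); concentrated /= concentrated.sum()

for name, p in [("spread", spread), ("concentrated", concentrated)]:
    print(f"{name:>12}: sum p_i^2 = {(p**2).sum():.4f}, "
          f"std of normalisation factor = {norm_factor_std(p):.3f}")
```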
The argument is:
- 1.) Contextual attention heads produce an empirical average of content-dependent components
- 2.) Because contextual attention heads average over many tokens, the empirical average in a particular context will be close to the global average in said context.
- 3.) The model has learnt to have the global average of the content-dependent component in a variety of different contexts be basically independent of the particular context.
Note the model doesn't have to work particularly hard to get the third property. Completely random content-dependent components would work.
Arguably a lot of the trouble in reading off algorithms from the weights is that models often learn algorithms that rely upon statistical regularities like the softmax denominator, which are invisible at the weights level. You need to inject some information about the input distribution to be able to read off algorithms, even if it's just a couple inputs to allow you to find constants of the computation.
For toy models, we can analyse contextual attention heads by partitioning inputs based on their softmax denominator values. While computing all the elements of this partition would be exponentially expensive, we can use concentration inequalities to bound the size of each partition - i.e., how likely inputs are to produce each range of denominator values. This lets us prove formal performance bounds without having to enumerate all possible inputs.
Contribution to Residual Stream from Contextual Attention Heads:
Once again, fix the destination token to be ' the'. Define $Z_h$ as the median of the empirical normalization factors for head $h$.
The EVO circuit of head $h$ can be approximated as:

$$\text{EVO}_h \approx \sum_{i \leq n} \frac{p^h_i \, \text{tok}^h_i}{Z_h} \, E_{t_i} V_h O_h$$
I approximate the contextual attention heads as each having the same positional pattern $p_i$, as their true positional patterns are quite similar. Small differences in the positional pattern are unlikely to matter much, because the EVO output of head $h$ will concentrate around the expected value of $\frac{\text{tok}^h_i}{Z_h} E_{t_i} V_h O_h$ over the tokens in the context, regardless of the specific positional pattern.
Thus, the approximate output of the contextual attention heads can be expressed as:

$$\sum_{h} \text{EVO}_h \approx \sum_{i \leq n} p_i \left( \sum_{h} \frac{\text{tok}^h_i}{Z_h} \, E_{t_i} V_h O_h \right)$$

where the sums over $h$ run over the contextual heads.
Because I fixed the destination token $t_n$, and each $Z_h$ is a constant, the inner sum $\sum_{h} \frac{\text{tok}^h_i}{Z_h} E_{t_i} V_h O_h$ is a function of just $t_i$ and $n$. I refer to this as $\text{Ext}(t_i)$.
The output to the residual stream from the contextual attention heads is then approximately $\sum_{i \leq n} p_i \, \text{Ext}(t_i)$.
$\text{Ext}(t_i)$ can be interpreted as an extended embedding of the token $t_i$, optimized for contextual classification. The combined output of all the contextual attention heads to the residual stream is approximately an exponentially decaying average of these extended embeddings over the previous ~50 tokens.
Technically, $\text{Ext}(t_i)$ depends on the destination position $n$. It would be very strange if these extended embeddings changed significantly depending on $n$, although the model might be able to use this dependence to implement a correction term for layer norm, discussed in the appendix.
If these extended embeddings were one-hot encodings of tokens, and the average was uniform instead of an exponential decay, the output would be a bag-of-words, which is often used as input to classification algorithms like Naive Bayes.
In practice, the model only has 768 dimensions in its residual stream, rather than the 50257 it would need to one-hot encode each token. This shouldn't cause too many issues for Naive Bayes classification, though, since contextual features are sparse. Nonetheless, it would be useful to know precisely how much this dimensionality restriction, and the fact that the average is an exponential decay rather than uniform, affect the model's ability to learn contextual features.
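As a sketch of how this could look in code, reusing the post-layer-norm $E$ and the per-head query vectors from the earlier snippet, and assuming the median normalisation factors $Z_h$ have been measured empirically (biases again omitted):

```python
import torch

contextual_heads = [0, 2, 6, 8, 9, 10]

def extended_embeddings(E, query_per_head, Z, model, heads=contextual_heads):
    """Ext(t) = sum over contextual heads h of (tok^h_t / Z_h) * E_t V_h O_h,
    computed for every token t in the vocabulary."""
    d_head = model.cfg.d_head
    ext = torch.zeros_like(E)
    for h in heads:
        tok = torch.exp((E @ model.W_K[0, h]) @ query_per_head[h] / d_head**0.5)
        ext += (tok / Z[h])[:, None] * (E @ model.W_V[0, h] @ model.W_O[0, h])
    return ext                                     # (d_vocab, d_model)

def bag_of_tokens(ext, tokens, pos_pattern):
    """Approximate combined contextual-head output: an exponentially decaying
    weighted sum of extended embeddings over the context."""
    return pos_pattern @ ext[tokens]               # sum_i p_i * Ext(t_i)
```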
Contextual Neurons:
I define contextual neurons to be first-layer MLP neurons which have a significant component due to the EVO circuit of the contextual attention heads.
You can use the extended embedding approximation to find hundreds of different interesting contextual neurons by looking at the composition of the extended embeddings with the MLP input, after accounting for the MLP layer-norm. This is demonstrated in the accompanying Google Colab.
Below I show a couple of these contextual neurons to give an idea of the contexts the model finds important to represent. These contextual neurons are just reading off directions which have been constructed by the contextual attention heads, and in general, there's no reason to expect the directions constructed by the contextual attention heads to align with the neuron basis. Some of the contextual neurons seem to be monosemantic, whereas others seem harder to understand.
Each contextual neuron has a token contribution vector associated with it listing the composition of each token's extended embedding with the MLP neuron's input. The way to interpret this is that the transformer adds up an exponentially decaying average of these token contributions over the sequence. Positive token contributions update the neuron towards firing, and negative token contributions update the neuron against firing. Neurons might have an initial bias so that they only fire if the positive contributions get over a certain threshold.
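In code, this amounts to a dot product followed by the same decaying average. A minimal sketch, assuming the extended embeddings `ext` from the earlier snippet and treating the MLP layer norm as a constant rescaling (the 0.16 scale and the flat `bias` term standing in for the non-contextual circuits are simplifying assumptions):

```python
def token_contributions(ext, model, neuron, ln_scale=0.16):
    """Composition of each token's extended embedding with one first-layer MLP
    neuron's input direction, treating the MLP layer norm as a constant rescale."""
    w_in = model.blocks[0].mlp.W_in[:, neuron]                 # (d_model,)
    return (ext / ln_scale * model.blocks[0].ln2.w) @ w_in     # (d_vocab,)

def approx_neuron_preactivation(contrib, bias, tokens, pos_pattern):
    """Exponentially decaying average of per-token contributions, plus a constant
    standing in for the neuron bias and the non-contextual circuits."""
    return pos_pattern @ contrib[tokens] + bias
```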
Britain vs America (Neuron 300):
Top positive contributions:
Bottom negative contributions:
There is another neuron for detecting British contexts, neuron 704, which has an output in the opposite direction. Plausibly they cancel each other out to avoid their total output getting too large.
19-20th century conflict? (Neuron 1621):
Top positive contributions:
The years (1917, 1918, 1942, 1943, 1944) would indicate this is related to WW2/WW1, but 'Sherman' was a general during the American Civil War.
Bottom negative contributions:
These token contributions feel like a 'War' latent minus a '21st century' latent.
If this were the case, the top contributions and bottom contributions would be misleading, and wouldn't necessarily inform you about what is going on with the median contributions.
It would be useful if there was a way to automatically decompose the token contributions vector of a neuron into separate sub-latents. Potentially some sort of SVD technique could be helpful here.
Evaluating the approximation of contextual neurons:
Now that we have examples of neurons which we think are detecting certain specialised contexts, we can use these neurons to evaluate the above approximations.
The above gives an approximation for the EVO circuit of contextual attention heads. To estimate the normalization factor, I use a fixed control text with an average number of keywords. I approximate the PVO contribution using each head's positional pattern. For the other attention heads, I approximate the EVO circuit as attending solely to the current token. I calculate the E and P circuits directly.
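For concreteness, here is a sketch of how these pieces could be combined at a single destination position, reusing the helpers from the earlier snippets; the `pvo` argument stands for a precomputed per-head PVO term and is a hypothetical input, not something defined in the post:

```python
def approx_resid_mid(model, tokens, n, E, ext, pos_pattern, pvo):
    """Approximate residual stream after the first attention layer at position n."""
    t_n = tokens[n]
    resid = model.W_E[t_n] + model.W_pos[n]               # direct E and P circuits
    resid = resid + pos_pattern @ ext[tokens[: n + 1]]    # contextual heads' EVO
    for h in range(model.cfg.n_heads):
        resid = resid + pvo[h]                            # PVO via the positional pattern
        if h not in [0, 2, 6, 8, 9, 10]:
            # remaining heads approximated as attending only to the current token
            resid = resid + E[t_n] @ model.W_V[0, h] @ model.W_O[0, h]
    return resid
```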
It's easiest to show a variety of these approximations in the Google Colab, but here is a typical example, for neuron 1710, which fires on religious texts:
The approximation does tend to capture the overall shape of the graphs, but the approximations at individual points aren't very good; there is a noticeable amount of noise.
The approximation tends to agree on the decision boundaries of neurons, at least, but it wouldn't satisfy someone looking for tight distance bounds.
Assuming there is no error in the non-contextual heads or the second layer-norm approximation, the main source of error would come from the $t_n = $ ' the' approximation. Most alternative destination tokens seem to be consistent in normalization factor across a wide variety of contexts, even when the context in question is semantically related to $t_n$. This would make our 'bag of words' a function of $t_n$, at which point the claim about a 'bag of words' is that it doesn't vary too much with $t_n$.
We'd need to know more about future layers to know how significant the error term is for understanding the broad behaviour of the model. For instance, future layers might use an averaging head to remove the noise from the neuron activations, at which point this approximation would be a good substitute. But if future layers make significant use of the exact neuron activations, this would require us to understand what's going on far better.
Regardless, validating that this mechanism makes more or less the same decisions as the model is exciting, because it feels like I have learnt at a high-level what the model is doing, even if the particulars would require a more lengthy explanation.
Rotary Embeddings:
Rotary embeddings can implement contextual attention heads by having queries and keys that mainly use the lower frequencies, as discussed in "Round and Round We Go! What Makes Rotary Positional Encodings Useful?". At sufficiently low frequencies, the positional dependence will mostly drop out, which is what you want for an attention head that summarizes information.
Contextual attention heads are what rotary embeddings are designed to make easy to construct. It seems difficult to get as neat a decomposition as when the positional embeddings lie in the residual stream, though.
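A tiny numerical illustration of the low-frequency point (plain numpy, not tied to any particular model): at a very low rotary frequency the query-key dot product barely changes with relative offset, which is exactly the position-independence a contextual attention head wants.

```python
import numpy as np

def rotate(v, theta):
    """Rotate a 2-D query/key vector by angle theta, as a single rotary frequency does."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([c * v[0] - s * v[1], s * v[0] + c * v[1]])

rng = np.random.default_rng(0)
q, k = rng.normal(size=2), rng.normal(size=2)

for freq in [1.0, 0.001]:   # a high vs. a very low rotary frequency
    # attention score between the query and keys at increasing relative offsets
    scores = [q @ rotate(k, -freq * offset) for offset in (0, 50, 100, 150)]
    print(f"freq={freq}:", np.round(scores, 3))
```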
Further Work:
- Are any of the contextual neurons shown above actually monosemantic? Is there a way to use the token contributions vector associated with a neuron to decompose its activations into sub-latents or find unrelated contexts? Is there a universality to the contextual neurons which models learn when trained on the same dataset?
- Is there a way to extend the notion of a positional pattern to later layers? The existence of previous token heads in later layers indicates this is at least worth looking for, and empirically many attention heads in later layers seem to have quite well-defined positional patterns. Positional information interacting with contextual information in a non-additive way could be a barrier to this, but the existence of neurons which are mostly positional is a promising sign.
- Looking at large entries of the EQKE matrix of the 'de-tokenization' heads 3, 4, and 7 seems promising for finding 'known bigrams', but how large do these entries need to be for the model to distinguish a particular bigram? Do different detokenization heads specialise in distinct sorts of bigrams? You can imagine that the model might want to treat the bigram ' al','paca' differently from ' Barack',' Obama'.
- Is it possible to refine the approximation for the normalization factor? Is it mostly a function of keyword density, or is there more to it? Is the model doing something more sophisticated that is lost in treating the normalization factor as a constant (almost certainly)?
- The assumption that the destination token doesn't affect the approximation too much needs more investigation. I have ignored pretty much all the internal structure of these heads, so there is lots of important information missing from this approximation.
- Is there interesting developmental interpretability that can be done just by tracking positional patterns throughout training?
- Here I didn't investigate the structure of any of the VO matrices. Certain heads seem to be specialising in different things, and you'd expect there to be corresponding structure in their VO matrices. Understanding this structure would make it clearer in which situations you should be more or less worried about the approximation being inaccurate. Do certain VO matrices cancel each other out? How does this interact with layer-norm? It would be interesting to look at how the VO matrices get used for the heads which vary the most with the destination token $t_n$.
Acknowledgements:
This post benefitted significantly from discussion with and feedback from Euan Ong and Jason Gross.
Appendix:
Analysis of Attention Patterns with Layer Normalisation:
I analyse how layer normalisation affects attention patterns by decomposing the attention weights into positional and content-dependent components. For tractability, fix a destination token and examine how attention to previous tokens varies.
Let $q$ denote the query vector derived from the post-layer-norm embedding of $t_n$, which we can compute exactly. Dropping the constant $\sqrt{d_{head}}$ scaling, the attention score on a previous token $t_i$ can be decomposed into:

$$\frac{q K^T E_{t_i}^T}{s_i} + \frac{q K^T P_i^T}{s_i}$$

where $E_{t_i}$ represents the token embedding, $P_i$ the positional embedding, and $s_i$ the layer-norm scale at position $i$.

Empirically, token and positional embeddings are approximately orthogonal. This allows us to approximate:

$$s_i \propto \|E_{t_i} + P_i\| \approx \sqrt{\|E_{t_i}\|^2 + \|P_i\|^2}$$

for $i \leq n$. The positional pattern only attends to the previous roughly 50 tokens, over which $\|P_i\|$ is roughly constant, so we can approximate the content-dependent component by:

$$\text{tok}_i \approx \exp\left(\frac{q K^T E_{t_i}^T}{\sqrt{\|E_{t_i}\|^2 + \|P_n\|^2}}\right)$$

This approximation recovers the sole dependence of the content-dependent component on $t_i$ and $n$.
Positional component analysis:
The positional component of attention is given by:

$$\text{pos}_i = \exp\left(\frac{q K^T P_i^T}{s_i}\right)$$
Once again, I construct a probabilistic model of what will occur.
To analyse this, introduce a mean approximation by averaging over token embeddings according to some distribution $\mathcal{D}$:

$$\bar{s}_i = \mathbb{E}_{t \sim \mathcal{D}}\left[\sqrt{\|E_t\|^2 + \|P_i\|^2}\right]$$

The relationship between actual and mean components can be written as $s_i = \bar{s}_i (1 + \delta_i)$. Using a Taylor expansion around the mean component:

$$\frac{1}{s_i} = \frac{1}{\bar{s}_i}\left(1 - \delta_i + \delta_i^2 - \cdots\right)$$

where:

$$\delta_i = \frac{s_i - \bar{s}_i}{\bar{s}_i}$$

This represents the relative deviation from mean normalisation.
As before, define the positional pattern $p_i = \frac{\text{pos}_i}{\sum_{j \leq n} \text{pos}_j}$.

Unlike before, $\text{pos}_i$ now depends on $t_i$, and $p_i$ depends on all terms in the sequence.

To handle this, we use a similar argument to the one for the normalisation factor.

Importantly, the size of this layer-norm correction varies significantly with $n$: it is much larger towards the ends of the sequence than near the center. This indicates the model is more sensitive to layer normalization further from the center of the sequence. At the center, there is very little distortion from layer norm.

In any case, the correction is bounded given a bound on $\delta_i$.
Then, to first order:

$$\text{pos}_i \approx \exp\left(\frac{q K^T P_i^T}{\bar{s}_i}\right) \cdot \exp\left(-\delta_i \, \frac{q K^T P_i^T}{\bar{s}_i}\right)$$

And $\delta_i$ depends only on the source token $t_i$, not the other sequence values. $\delta_i$ is not very large in practice.
Using the same argument as in the main text with an independent model of the $\delta_i$, we can argue that the variance of the resulting normalisation factor will be tiny for contextual attention heads, bounding it using our bound for $\delta_i$. We could do the same for higher-order terms of the Taylor expansion, but it's probably unnecessary here since $\delta_i$ is quite small.
Therefore I approximate the positional component by keeping only the first-order correction above, dropping the higher-order terms. This gives us a decomposition of each exponentiated attention score into a mean positional part and a token-dependent correction factor. Let's say that we restrict our sequences to ones where this approximation holds up to a small multiplicative factor, which it does with high probability thanks to our variance bounds above.
Now we can use the same variance argument from earlier on the terms of this sequence, and we can calculate the variance up to some small tolerance assuming independence of the $\delta_i$, because the approximation above holds with sufficiently high probability.

So we can argue that the normalisation factor will concentrate around a constant, or we can empirically observe this. Call this constant $Z'_h$, noting that this constant will be different from the $Z_h$ in the main text.

Then we can approximate the softmax probabilities as before, where we can approximate the correction factor in terms of $t_i$ alone because $P_i \approx P_n$ for the positions $i$ where $p_i$ is non-negligible. And then we can fold this factor into the extended embeddings, to get a correction term when $n$ is quite large or quite small.
And now we have obtained an approximation that partially takes into account layer-norm which nonetheless allows for a decomposition into position-dependent and content-dependent terms.
Notably, for $n$ near the center of the sequence, the correction is close to 0, so the approximation given in the main text works well.
Second layer attention patterns:
These are sample attention patterns on the Bible for the second layer, with n=1022:
[Figures: sample second-layer attention patterns on the Bible.]
To a large extent, the attention heads in the second layer are far better behaved than the attention heads in the first layer. There are lots of attention heads which seem almost entirely positional, with barely any content-dependent component.
These positional attention heads likely clean up noise from the first-layer contextual attention heads arising from variation in the normalisation factor. They could also construct bags of bigrams by composing with the post-MLP detokenization head outputs, or compute how repetitive a text is by composing with duplicate token neurons. In general, primarily positional heads can be viewed as constructing bags-of-representations out of any representations constructed by previous layers.
Contextual features:
The bag-of-tokens constrains how models can represent contextual features in the first-layer residual stream. Excluding Head 11, which behaves similarly to the contextual attention heads, only the contextual attention heads have access to tokens occurring more than ~7 tokens away.
Whatever a feature is, if we assume the bag of tokens approximation holds up, we must be able to understand first-layer contextual features through this lens.
For instance, 'Maths' and 'Topology' will naturally have a large overlap in their token distributions, and so by default we should expect them to lie close together in activation space, because their bags of tokens will overlap, assuming models don't give topology tokens abnormally large extended embeddings. Models will generally benefit from having close token distributions map to close next-token distributions, so there's no immediate reason for models to want large extended embeddings.
Bags of tokens are mostly translation invariant, so contextual features will also tend to be translation invariant, although for the first 100 tokens or so the model attends disproportionately to the <end-of-text> token, so that all contextual features are partially ablated.
Arguably this ablation immediately implies the existence of weak linear representations. Assume that there is a cluster of 'bag-of-token' activations corresponding to 'Math'. Then if the model pays 10% of its attention to the <end-of-text> token, it will have 90% 'Math' in its residual stream, whereas if it pays 50% to <end-of-text>, it will have 50% 'Math'. So models will naturally want to handle a continuous range of strengths of activations.