Contextual attention heads in the first layer of GPT-2

post by Alex Gibson · 2025-01-20T13:24:31.803Z

Contents

  Overview:
  Decomposition of First Layer Attention Patterns:
    Positional pattern visualization:
    Behavioural Classification of First Layer Attention Heads:
  Approximating Softmax Probabilities:
    Independent model of normalisation factors:
  Contribution to Residual Stream from Contextual Attention Heads:
  Contextual Neurons:
        Britain vs America (Neuron 300):
        19-20th century conflict? (Neuron 1621):
  Evaluating the approximation of contextual neurons:
  Rotary Embeddings:
  Further Work:
  Acknowledgements:
  Appendix:
    Analysis of Attention Patterns with Layer Normalisation:
      Positional component analysis:
    Second layer attention patterns:
    Contextual features:

Overview:

Through mathematical analysis of GPT-2 small's first-layer attention patterns, I find six attention heads whose combined output can be approximated as a "bag-of-tokens" representation of the previous ~50 tokens: a weighted sum of extended embeddings of tokens, where the weight of a particular token depends only on its position relative to the current position.

The model can use this bag-of-tokens for lots of downstream tasks. The bag-of-tokens of input sequences with distinct token distributions will naturally linearly separate, allowing the model to learn contextual features by projecting the bag-of-tokens along directions corresponding to particular distributions of tokens.

To demonstrate the predictive power of the bag-of-tokens approximation, I find first-layer contextual neurons without having to run the model by composing the extended embeddings of tokens with the input to the first MLP layer. I show that these neurons can be thought of as performing Naive Bayes classification on the input, and that the approximation gets close to matching the observed activations of these contextual neurons.

Lots of these contextual neurons seem monosemantic, whereas others seem more complicated; a couple of examples are shown in the Contextual Neurons section below.

Decomposition of First Layer Attention Patterns:

I assume familiarity with the Mathematical Framework For Transformer Circuits.

I approximate the layer norm as linear, because its scale consistently falls within 0.15–0.175. I refine this approximation in the appendix. When referring to the token and positional embeddings below, I take these to be post this layer norm approximation.

I write $E$ for the token embedding $W_E$, $Q$ for the query matrix $W_Q$, $K$ for the key matrix $W_K$, $V$ and $O$ for the value and output matrices $W_V$ and $W_O$, and $P$ for the positional embedding $W_{pos}$.

I simplify notation by concatenating letters to denote matrix multiplication, applying transpositions when appropriate. For instance, $EVO$ refers to $W_E W_V W_O$.
Consider a sequence of tokens $x_0, x_1, \ldots, x_n$, where $n$ represents the current destination position.

For any position $i \leq n$, the $i$th attention score $A_i$ measures the attention weight that $x_n$ places on $x_i$.

This score combines contributions from both token embeddings ($E$) and positional embeddings ($P$) through the query and key matrices:

$$A_i = \frac{(E_{x_n} + P_n)\, Q K^T (E_{x_i} + P_i)^T}{\sqrt{d_k}},$$

where $E_{x_i}$ denotes the embedding of token $x_i$ and $P_i$ the positional embedding of position $i$.

Technically, GPT-2 has bias terms on each of the queries and keys, and layer-norm has a weight and bias as well. These bias terms make the equations harder to read but don't cause any complications, so I omit them here.

The transformer applies a softmax operation to these attention scores, exponentiating each score for $i \leq n$ and normalizing the resulting vector. The exponentiated attention score decomposes into two independent components:

$\mathrm{pos}_i = \exp\!\big((E_{x_n} + P_n)\, Q K^T P_i^T / \sqrt{d_k}\big)$: depends exclusively on the token's position, $i$.

$c_{x_i} = \exp\!\big((E_{x_n} + P_n)\, Q K^T E_{x_i}^T / \sqrt{d_k}\big)$: depends exclusively on the token's content, $x_i$.

This decomposition can be expressed as:

$$\exp(A_i) = \mathrm{pos}_i \cdot c_{x_i}.$$
I define the positional pattern as:

$$\pi_i = \frac{\mathrm{pos}_i}{\sum_{j \leq n} \mathrm{pos}_j}.$$

Given the destination token $x_n$, this pattern can be computed independently of the other sequence tokens, representing the softmax of the position-dependent components of the attention score.

The positional pattern can be viewed as telling you how much each position in the sequence gets weighted, independent of the content of the token at that position. It is the attention pattern you would get if every token in the sequence were identical.

The positional pattern is slightly influenced by the destination token $x_n$. However, for all attention heads in GPT-2-small's first layer, the overall character of the positional pattern remains consistent regardless of $x_n$.

The final softmax probabilities at position $n$ are:

$$\mathrm{attn}_i = \frac{\pi_i\, c_{x_i}}{\sum_{j \leq n} \pi_j\, c_{x_j}}, \quad \text{for } i \leq n.$$
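
To make this concrete, here is a minimal sketch of how the positional pattern and the factored attention probabilities could be computed for one first-layer head, assuming TransformerLens weight conventions and treating the first layer norm as division by a fixed constant (both simplifications are mine; bias terms are omitted, as above):

```python
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")

layer, head = 0, 6                          # one of the heads classified as contextual below
n = 399                                     # destination position
x_n = model.to_single_token(" the")         # fixed destination token

LN_SCALE = 0.16                             # assumption: treat the first layer norm as a constant scale
E = model.W_E / LN_SCALE                    # approximate post-layer-norm token embeddings
P = model.W_pos / LN_SCALE                  # approximate post-layer-norm positional embeddings

Q, K = model.W_Q[layer, head], model.W_K[layer, head]
scale = model.cfg.d_head ** 0.5
q = (E[x_n] + P[n]) @ Q                     # query for the fixed destination

pos_score = (P[: n + 1] @ K) @ q / scale    # position-dependent score for each source position
tok_score = (E @ K) @ q / scale             # content-dependent score for each vocabulary token

pi = torch.softmax(pos_score, dim=0)        # the positional pattern

def approx_attention(xs: torch.Tensor) -> torch.Tensor:
    """Attention pattern for a concrete token sequence xs of length n+1,
    using the factorisation exp(A_i) = pos_i * c_{x_i}."""
    return torch.softmax(pos_score + tok_score[xs], dim=0)
```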

Positional pattern visualization:

Below is the positional pattern of each of the first layer attention heads for a fixed destination token, along with a sample attention pattern obtained from running the model on the Bible. Each attention pattern of length 400 has been reshaped into a 20x20 grid for visualization purposes.

 

 

[Figures: for each of Heads 0–11, the positional pattern shown alongside a sample attention pattern.]


Behavioural Classification of First Layer Attention Heads:

Based on the positional patterns, and observation of the content-dependent values $c_{x_i}$, we can classify the first layer attention heads into distinct behavioural groups:

| Group | Heads | Behaviour |
|---|---|---|
| Detokenization heads | 3, 4, 7 | Attend mainly to the previous few tokens; used for detecting known n-grams. Their positional pattern is mostly translation invariant, so that n-grams have consistent representations regardless of position. |
| Contextual attention heads | 0, 2, 6, 8, 9, 10 | Positional pattern that exponentially decays over ~50 tokens. Behaviour tends to be pretty consistent across different destination tokens $x_n$. One of these heads has duplicate-token behaviour, but tends to attend pretty uniformly to non-duplicate tokens. |
| Duplicate token heads | 1, 5 | Attend almost entirely to duplicate copies of the current token $x_n$. One of them detects duplicate tokens close to uniformly across the context; the other only detects nearby duplicate tokens, so it can do more precise relative indexing. |
| Miscellaneous | 11 | Uncertain of its role. |

 

Approximating Softmax Probabilities:

Continuing from $\mathrm{attn}_i = \frac{\pi_i\, c_{x_i}}{\sum_{j \leq n} \pi_j\, c_{x_j}}$, where $n$ is the destination position:

This expression is difficult to work with because the denominator $\sum_{j \leq n} \pi_j\, c_{x_j}$ depends on all tokens $x_j$ in the sequence, rather than solely on $i$ and $x_i$.

For contextual attention heads, however, the empirical value of the denominator depends mostly on the destination position $n$, rather than on the particular $x_j$ values. When $n$ is fixed, the denominator tends to concentrate around the same value for a variety of input sequences.

Fix $x_n =$ ' the', assuming that the destination token doesn't affect the content-dependent component too much. Even if this assumption breaks down for certain contextual attention heads, it will at least tell us what happens when the destination token is a stopword.

For this fixed $x_n$, I plotted the denominator of the softmax for various texts as a function of the destination position $n$:

[Figure: softmax denominators for each first-layer head as a function of destination position $n$, across several texts.]

The normalisation factor initially decays because the <end-of-text> token has a large content-dependent component. This reduces the impact of the contextual attention heads for roughly the first 100 tokens of the sequence.

Of the contextual heads, heads 6 and 8 vary the most in normalisation factor across different contexts. A lot of the variation can be attributed to differences in the percentage of keywords in the text. Head 6 emphasises keywords, whereas Head 8 emphasises stopwords. The Bible has more newline characters than other texts because it is composed of verses, so it has a below-average normalisation factor for head 6, and an above-average normalisation factor for head 8.
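
For concreteness, here is a rough sketch of how the denominator $\sum_{j \leq n} \pi_j c_{x_j}$ could be measured across texts for one head, under the same TransformerLens and constant-layer-norm assumptions as before (the choice of texts is up to the reader; the loop is slow but clear):

```python
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
LN_SCALE = 0.16                                   # assumption: linear layer norm approximation

def denominators(text: str, head: int) -> torch.Tensor:
    """Softmax denominator sum_j pi_j * c_{x_j} at each destination position n,
    with the destination token fixed to ' the'."""
    xs = model.to_tokens(text)[0]
    E = model.W_E / LN_SCALE
    P = model.W_pos[: len(xs)] / LN_SCALE
    Q, K = model.W_Q[0, head], model.W_K[0, head]
    the = model.to_single_token(" the")
    scale = model.cfg.d_head ** 0.5
    out = []
    for n in range(len(xs)):
        q = (E[the] + P[n]) @ Q
        pos = (P[: n + 1] @ K) @ q / scale        # position-dependent scores
        tok = (E[xs[: n + 1]] @ K) @ q / scale    # content-dependent scores of the actual tokens
        pi = torch.softmax(pos, dim=0)            # positional pattern at this n
        out.append((pi * tok.exp()).sum())
    return torch.stack(out)

# Plotting denominators(text, head=6) for a few different texts shows how tightly the
# denominator concentrates for a contextual head, and how it differs between heads 6 and 8.
```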

Independent model of normalisation factors:

For a fixed text, certain heads seem to have denominators which are a smoother function of $n$ than others. What determines this?

As a simple model, imagine that each content-dependent factor $c_{x_i}$ was independently drawn from some distribution with variance $\sigma^2$, with the positional pattern $\pi$ fixed. Then we have $\operatorname{Var}\!\big[\sum_i \pi_i c_{x_i}\big] = \sigma^2 \sum_i \pi_i^2$.

Because of the constraint that $\sum_i \pi_i = 1$, $\sum_i \pi_i^2$ becomes a measure of how spread out the positional pattern is. If an attention head pays $1/k$ attention to each of $k$ tokens, then $\sum_i \pi_i^2 = 1/k$. So, the more tokens a positional pattern is spread over, the lower the variance in its normalisation factor. The contextual attention heads are averaging over enough tokens that the normalisation factor varies smoothly.
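
A quick numerical check of this variance identity, using a made-up exponentially decaying positional pattern and i.i.d. content factors (all numbers here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 400
pi = np.exp(-np.arange(n) / 50.0)
pi /= pi.sum()                                  # toy positional pattern, spread over ~50 tokens

sigma2 = 2.0
c = rng.normal(loc=5.0, scale=np.sqrt(sigma2), size=(10_000, n))   # i.i.d. content factors

print("empirical:", (c @ pi).var())             # variance of the simulated denominators
print("predicted:", sigma2 * (pi ** 2).sum())   # sigma^2 * sum_i pi_i^2
```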

$\sigma^2$ is determined by how much the content-dependent weighting varies between tokens. Head 6 has a very similar positional pattern to some of the other contextual heads, but it emphasises keywords above stopwords, whereas those heads attend fairly evenly to tokens. Head 6 therefore has a higher variance in its normalisation factor.

Heads 3, 4, and 7 (the de-tokenisation heads) all have high variance in their content-dependent component, and a positional pattern which isn't spread over many tokens. Hence the confetti-like sample attention patterns. Of course, this doesn't mean these heads are particularly hard to analyse; it's just not appropriate to model them as having a constant normalisation factor for fixed $n$.

The argument is:

- The positional pattern of a contextual attention head is spread over many tokens, so $\sum_i \pi_i^2$ is small.
- The content-dependent components don't have wildly large variance across tokens.
- The content-dependent components are roughly independent across positions.

Note the model doesn't have to work particularly hard to get the third property. Completely random content-dependent components would work.

Arguably, a lot of the trouble in reading off algorithms from the weights is that models often learn algorithms that rely upon statistical regularities like the softmax denominator, which are invisible at the weights level. You need to inject some information about the input distribution to be able to read off algorithms, even if it's just a couple of inputs to allow you to find constants of the computation.

For toy models, we can analyse contextual attention heads by partitioning inputs based on their softmax denominator values. While computing all the elements of this partition would be exponentially expensive, we can use concentration inequalities to bound the size of each partition - i.e., how likely inputs are to produce each range of denominator values. This lets us prove formal performance bounds without having to enumerate all possible inputs.
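
As a toy version of this, the independence model's variance immediately gives a Chebyshev-style bound on how often the denominator leaves a band around its mean, without enumerating inputs (the numbers below are placeholders):

```python
def chebyshev_tail(var: float, eps: float) -> float:
    """P(|D - E[D]| >= eps) <= var / eps^2 for the softmax denominator D."""
    return min(1.0, var / eps ** 2)

# With content-factor variance sigma^2 = 4 and a positional pattern spread over ~50 tokens,
# var = sigma^2 * sum_i pi_i^2 is about 4 / 50, so:
print(chebyshev_tail(var=4 / 50, eps=1.0))      # <= 0.08
```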

Contribution to Residual Stream from Contextual Attention Heads:

Once again, fix the destination token $x_n$ to be ' the'. Define $D_h(n)$ as the median of the empirical normalization factors for head $h$ at destination position $n$.

The EVO circuit of head $h$ can be approximated as:

$$\mathrm{EVO}_h \approx \sum_{i \leq n} \frac{\pi^{(h)}_i\, c^{(h)}_{x_i}}{D_h(n)}\, E_{x_i} V_h O_h,$$

where $\pi^{(h)}_i$ and $c^{(h)}_{x_i}$ are head $h$'s positional pattern and content-dependent factor, and $V_h$, $O_h$ are its value and output matrices.

I approximate the contextual attention heads as each having the same positional pattern $\pi$, as their true positional patterns are quite similar. Small differences in the positional pattern are unlikely to matter much, because the EVO output of head $h$ will concentrate around the expected value of $\frac{c^{(h)}_{x_i}}{D_h(n)} E_{x_i} V_h O_h$ over the local token distribution, regardless of the specific positional pattern.

Thus, the approximate output of the contextual attention heads can be expressed as:

$$\sum_{h \in \text{contextual}} \sum_{i \leq n} \frac{\pi_i\, c^{(h)}_{x_i}}{D_h(n)}\, E_{x_i} V_h O_h = \sum_{i \leq n} \pi_i \left( \sum_{h \in \text{contextual}} \frac{c^{(h)}_{x_i}}{D_h(n)}\, E_{x_i} V_h O_h \right).$$

Because I fixed the destination token $x_n$, and $D_h(n)$ is a constant, the inner sum

$$\mathrm{Ext}(x_i) = \sum_{h \in \text{contextual}} \frac{c^{(h)}_{x_i}}{D_h(n)}\, E_{x_i} V_h O_h$$

is a function of just $x_i$ and $n$. I refer to this as the extended embedding of $x_i$.

The output to the residual stream from the contextual attention heads is then approximately $\sum_{i \leq n} \pi_i\, \mathrm{Ext}(x_i)$.

$\mathrm{Ext}(x_i)$ can be interpreted as an extended embedding of the token $x_i$, optimized for contextual classification. The combined output of all the contextual attention heads to the residual stream is approximately an exponentially decaying average of these extended embeddings over the previous ~50 tokens.
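
A sketch of how these extended embeddings and the resulting bag-of-tokens could be assembled, continuing with the same assumptions (TransformerLens conventions, a constant layer norm scale, destination token fixed to ' the'); the per-head median denominators are assumed to have been measured empirically, e.g. with the earlier denominator sketch:

```python
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
LN_SCALE = 0.16                               # assumption: linear layer norm approximation
CONTEXTUAL_HEADS = [0, 2, 6, 8, 9, 10]
THE = model.to_single_token(" the")

def extended_embeddings(n: int, median_denoms: dict) -> torch.Tensor:
    """Ext(t) for every vocabulary token t at destination position n.
    `median_denoms` maps head index -> empirically measured median denominator D_h(n)."""
    E, P = model.W_E / LN_SCALE, model.W_pos / LN_SCALE
    scale = model.cfg.d_head ** 0.5
    ext = torch.zeros_like(model.W_E)
    for h in CONTEXTUAL_HEADS:
        q = (E[THE] + P[n]) @ model.W_Q[0, h]
        c = torch.exp((E @ model.W_K[0, h]) @ q / scale)      # content factor for every token
        evo = (E @ model.W_V[0, h]) @ model.W_O[0, h]         # E V_h O_h
        ext += (c / median_denoms[h]).unsqueeze(1) * evo
    return ext                                                 # [d_vocab, d_model]

def bag_of_tokens(xs: torch.Tensor, pi: torch.Tensor, ext: torch.Tensor) -> torch.Tensor:
    """Approximate contextual-head output: a pi-weighted sum of extended embeddings."""
    return (pi.unsqueeze(1) * ext[xs]).sum(dim=0)
```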

Technically, $\mathrm{Ext}(x_i)$ depends on the destination position $n$. It would be very strange if these extended embeddings changed significantly depending on $n$, although the $n$-dependence might implement a correction term for layer norm, discussed in the appendix.

If these extended embeddings were one-hot encodings of tokens, and the average was uniform instead of an exponential decay, the output would be a bag-of-words, which is often used as input to classification algorithms like Naive Bayes.

In practice, the model only has 768 dimensions in its residual stream, rather than the 50,257 it would need to one-hot encode each token. This shouldn't cause too many issues for Naive Bayes classification, though, since contextual features are sparse. Nonetheless, it would be useful to know precisely how much this dimensionality restriction, and the fact that the average is an exponential decay rather than uniform, affect the model's ability to learn contextual features.
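
To spell out the Naive Bayes analogy: with a one-hot bag-of-words, a linear readout whose weights are log-likelihood ratios is exactly the Naive Bayes decision rule. A toy sketch with an invented five-token vocabulary and made-up class-conditional probabilities:

```python
import numpy as np

vocab = ["church", "prayer", "the", "goal", "season"]
p_tok_given_religion = np.array([0.20, 0.15, 0.60, 0.03, 0.02])   # made-up P(token | class)
p_tok_given_sport    = np.array([0.03, 0.02, 0.60, 0.20, 0.15])

# Naive Bayes readout direction: per-token log-likelihood ratios.
w = np.log(p_tok_given_religion) - np.log(p_tok_given_sport)

def nb_score(bag: np.ndarray, log_prior_ratio: float = 0.0) -> float:
    """bag[t] = (possibly decayed) count of token t in the context.
    Positive score favours 'religion'; this is just a linear readout of the bag."""
    return float(bag @ w + log_prior_ratio)

bag = np.array([2, 1, 7, 0, 0])      # a context mentioning "church" twice and "prayer" once
print(nb_score(bag))                  # > 0: classified as a religious context
```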

Contextual Neurons:

I define contextual neurons to be first-layer MLP neurons which have a significant component due to the EVO circuit of the contextual attention heads. 

You can use the extended embedding approximation to find hundreds of different interesting contextual neurons by looking at the composition of the extended embeddings with the MLP input, after accounting for the MLP layer-norm. This is demonstrated in the accompanying Google Colab.

Below I show a couple of these contextual neurons to give an idea of the contexts the model finds important to represent. These contextual neurons are just reading off directions which have been constructed by the contextual attention heads, and in general, there's no reason to expect the directions constructed by the contextual attention heads to align with the neuron basis. Some of the contextual neurons seem to be monosemantic, whereas others seem to be complex to understand.

Each contextual neuron has an associated token contribution vector, listing the composition of each token's extended embedding with the neuron's input. The way to interpret this is that the transformer adds up an exponentially decaying average of these token contributions over the sequence. Positive token contributions update the neuron towards firing, and negative token contributions update it against firing. Neurons might have a bias so that they only fire if the positive contributions get over a certain threshold.
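
A minimal sketch of how a token-contribution vector might be computed, given the extended-embedding matrix from the earlier sketch; the MLP layer norm is crudely treated as another fixed scale here, which is rougher than the treatment in the appendix:

```python
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
MLP_LN_SCALE = 1.0        # assumption: placeholder constant for the MLP layer norm scale

def token_contributions(ext: torch.Tensor, neuron: int) -> torch.Tensor:
    """Composition of each token's extended embedding with one first-layer MLP neuron.
    `ext` is the [d_vocab, d_model] matrix from the extended_embeddings sketch."""
    w_in = model.W_in[0, :, neuron]               # the neuron's input weights, [d_model]
    return (ext / MLP_LN_SCALE) @ w_in            # [d_vocab]

# Rank tokens by how strongly they push the neuron towards or away from firing:
# contribs = token_contributions(ext, neuron=300)
# print(model.to_str_tokens(torch.topk(contribs, 20).indices))
```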

Britain vs America (Neuron 300):

Top positive contributions:

[Figure: top positive token contributions for neuron 300.]

Bottom negative contributions:

There is another neuron for detecting British contexts, neuron 704, which has an output in the opposite direction. Plausibly they cancel each other out to avoid their total output getting too large.

19-20th century conflict? (Neuron 1621):

Top positive contributions:

The years (1917,1918,1942,1943,1944) would indicate this is related to WW2/WW1, but 'Sherman' was a general during the American Civil War.

Bottom negative contributions:

[Figure: bottom negative token contributions for neuron 1621.]

These token contributions feel like a combination of a 'War' latent and a '21st century' latent.

If this is the case, the top and bottom contributions would be misleading, and wouldn't necessarily inform you about what is going on with the median contributions.

It would be useful if there was a way to automatically decompose the token contributions vector of a neuron into separate sub-latents. Potentially some sort of SVD technique could be helpful here. 
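
One naive version of this, sketched below, is to stack the token-contribution vectors of many contextual neurons and take an SVD; the top right singular vectors are then candidate shared sub-latents over tokens. Whether this recovers anything interpretable is untested here:

```python
import torch

def sublatents(contribs: torch.Tensor, k: int = 10) -> torch.Tensor:
    """contribs: [n_neurons, d_vocab] matrix of token-contribution vectors
    (e.g. built with token_contributions for a list of contextual neurons).
    Returns k candidate shared directions over tokens."""
    U, S, Vh = torch.linalg.svd(contribs, full_matrices=False)
    return Vh[:k]                                  # [k, d_vocab]
```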

Evaluating the approximation of contextual neurons:

Now that we have examples of neurons which we think are detecting certain specialised contexts, we can use these neurons to evaluate the above approximations. 

The above gives an approximation for the EVO circuit of contextual attention heads. To estimate the normalization factor, I use a fixed control text with an average number of keywords. I approximate the PVO contribution using each head's positional pattern. For the other attention heads, I approximate the EVO circuit as attending solely to the current token. I calculate the E and P circuits directly. 
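
To evaluate such an approximation against the model, the ground-truth activations are easy to pull out of a TransformerLens cache, and agreement can be scored on whichever criterion one cares about (here, a simple decision-boundary agreement; the approximate activations are assumed to come from the pieces sketched above):

```python
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")

def actual_neuron_activation(text: str, neuron: int) -> torch.Tensor:
    """Ground-truth post-GELU activation of a first-layer MLP neuron at every position."""
    _, cache = model.run_with_cache(model.to_tokens(text))
    return cache["post", 0][0, :, neuron]          # [seq]

def decision_agreement(actual: torch.Tensor, approx: torch.Tensor, threshold: float = 0.0) -> float:
    """Fraction of positions where the approximation agrees with the model about
    whether the neuron fires (i.e. crosses the threshold)."""
    return ((actual > threshold) == (approx > threshold)).float().mean().item()
```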

It's easiest to show a variety of these approximations in the Google Colab, but here is a typical example, for neuron 1710, which fires on religious texts:

[Figure: approximated vs. observed activations of neuron 1710.]

The approximation does tend to capture the overall shape of the graphs, but the approximations at individual points aren't very good; there is a noticeable amount of pointwise noise.

The approximation tends to agree on the decision boundaries of neurons, at least, but it wouldn't satisfy someone looking for tight pointwise error bounds.

Assuming there is no error in the non-contextual heads or the second layer-norm approximation, the main source of error would come from the $x_n =$ ' the' approximation. Most alternative $x_n$ seem to be consistent in normalization factor across a wide variety of contexts, even when the context in question is semantically related to $x_n$. This would make our 'bag of words' a function of $x_n$, at which point the claim about a 'bag of words' is that it doesn't vary too much with $x_n$.

We'd need to know more about future layers to know how significant the error term is for understanding the broad behaviour of the model. For instance, future layers might use an averaging head to remove the noise from the neuron activations, at which point this approximation would be a good substitute. But if future layers make significant use of the exact neuron activations, this would require us to understand what's going on far better. 

Regardless, validating that this mechanism makes more or less the same decisions as the model is exciting, because it feels like I have learnt at a high-level what the model is doing, even if the particulars would require a more lengthy explanation. 

Rotary Embeddings:

Rotary embeddings can implement contextual attention heads by having queries and keys that mainly use the lower frequencies, as discussed in Round and Round We Go! What Makes Rotary Positional Encodings Useful?. At sufficiently low frequencies, the positional dependence will mostly drop out, which is what you want for an attention head that summarizes information.

Contextual attention heads are exactly what rotary embeddings are designed to make easy to construct. It seems difficult to get as neat a decomposition as when the positional embeddings lie in the residual stream, though.
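
A small illustration of the low-frequency regime: under rotary embeddings the query/key dot product depends on the relative offset only through a rotation by (frequency times offset), so at a sufficiently low frequency the score is nearly flat over the ~50-token range a contextual head cares about. A toy sketch for a single 2D rotary subspace, with frequencies chosen purely for illustration:

```python
import numpy as np

def rotate(v: np.ndarray, angle: float) -> np.ndarray:
    """Rotate a 2D vector, as rotary embeddings do to each 2D slice of a query or key."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([c * v[0] - s * v[1], s * v[0] + c * v[1]])

q = np.array([1.0, 0.3])
k = np.array([0.8, -0.5])

for freq in (1.0, 0.001):                      # a high and a low rotary frequency
    scores = [q @ rotate(k, freq * delta) for delta in range(50)]
    print(f"freq={freq}: score range over 50 offsets = {max(scores) - min(scores):.4f}")
# At freq=0.001 the score barely changes with relative position, which is the regime
# a summarising (contextual) attention head wants.
```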

Further Work:

Acknowledgements:

This post benefitted significantly from discussion with and feedback from Euan Ong and Jason Gross.

Appendix:


Analysis of Attention Patterns with Layer Normalisation:

I analyse how layer normalisation affects attention patterns by decomposing the attention weights into positional and content-dependent components. For tractability, fix a destination token $x_n$ and examine how attention to previous tokens varies.

Let $q$ denote the query vector derived from the post-layer-norm embedding of $x_n$, which we can compute exactly. The attention score to a previous token $x_i$ can be decomposed into:

$$A_i = \frac{q\, K^T (E_{x_i} + P_i)^T}{\sqrt{d_k}\; s_i},$$

where $E_{x_i}$ represents the token embedding and $P_i$ the positional embedding, and

$$s_i = \frac{\lVert E_{x_i} + P_i \rVert}{\sqrt{d_{\text{model}}}}$$

is the layer norm scale at position $i$.


Empirically, token and positional embeddings are approximately orthogonal. This allows us to approximate:

$$\lVert E_{x_i} + P_i \rVert^2 \approx \lVert E_{x_i} \rVert^2 + \lVert P_i \rVert^2.$$
$\lVert P_i \rVert$ is roughly constant for $i$ close to $n$, and the positional pattern only attends to the previous ~50 tokens, so we can approximate the content-dependent component $c_{x_i}$ by:

$$c_{x_i} \approx \exp\!\left(\frac{q\, K^T E_{x_i}^T}{\sqrt{d_k}\,\sqrt{\big(\lVert E_{x_i} \rVert^2 + \lVert P_n \rVert^2\big)/d_{\text{model}}}}\right).$$

This approximation recovers the sole dependence of $c_{x_i}$ on $x_i$ and $n$.
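
The approximate-orthogonality claim above is quick to check directly; a sketch (sample size arbitrary):

```python
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")

# Cosine similarities between every positional embedding and a sample of token embeddings.
tok = torch.nn.functional.normalize(model.W_E[torch.randperm(model.cfg.d_vocab)[:2000]], dim=-1)
pos = torch.nn.functional.normalize(model.W_pos, dim=-1)
cos = pos @ tok.T
print(f"mean |cos| = {cos.abs().mean():.3f}, max |cos| = {cos.abs().max():.3f}")
```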

Positional component analysis:


The positional component of attention is given by:

$$\mathrm{pos}_i = \exp\!\left(\frac{q\, K^T P_i^T}{\sqrt{d_k}\; s_i}\right),$$

which, through the layer norm scale $s_i$, now depends on the token $x_i$ as well as on the position $i$.

Once again, I construct a probabilistic model of what will occur.

To analyse this, introduce a mean approximation by averaging over token embeddings according to some distribution $\mathcal{D}$:

$$\bar{s}_i = \mathbb{E}_{t \sim \mathcal{D}}\!\left[\frac{\lVert E_t + P_i \rVert}{\sqrt{d_{\text{model}}}}\right].$$
The relationship between actual and mean components can be written as:

$$s_i = \bar{s}_i\,(1 + \delta_i).$$

Using a Taylor expansion around the mean component:

$$\frac{1}{s_i} = \frac{1}{\bar{s}_i}\big(1 - \delta_i + \delta_i^2 - \dots\big),$$

where:

$$\delta_i = \frac{s_i - \bar{s}_i}{\bar{s}_i}.$$

This represents the relative deviation from mean normalisation.


As before, define the positional pattern $\pi_i$ as the softmax of the position-dependent components, now computed with the mean normalisation $\bar{s}_i$. Unlike before, the exact positional component now depends on $x_i$ through $\delta_i$, and the normalisation depends on all terms in the sequence.

To handle this, we use a similar argument to the one for the normalisation factor.

Importantly, the size of this term varies significantly with $n$: it is much larger near the ends of the sequence than near the centre. This indicates the model is more sensitive to layer normalization further from the centre of the sequence; at the centre, there is very little distortion from layer norm.

In any case, $\delta_i$ is bounded given the observed range of layer norm scales, and it is not very large in practice. Moreover, it depends only on the source token $x_i$ (and its position), not on the other tokens in the sequence.

Using the same argument as in the main text with an independent model, the variance of the positional-pattern-weighted sum of these deviation terms scales with $\sum_i \pi_i^2$. We can argue that this variance will be tiny for contextual attention heads, bounding it using our bound on $\delta_i$.

We could do the same for higher-order terms of the Taylor expansion, but it's probably unnecessary here since $\delta_i$ is quite small.

Therefore I approximate the exponentiated scores by their first-order corrected versions. Let's say that we restrict attention to sequences where this approximation holds up to a small multiplicative factor, which it does with high probability thanks to the variance bounds above. Then the attention probabilities take the same factored form as in the main text, up to that multiplicative factor.

Now we can use the same variance argument from earlier on the terms of this sequence, and we can calculate the variance up to some small tolerance assuming independence of the content-dependent components, because the multiplicative approximation holds with sufficiently high probability.

So we can argue that the denominator will concentrate around a constant, or we can observe this empirically. Call this constant $D'_h(n)$, noting that it will be different from the constant $D_h(n)$ in the main text.

Then we can approximate the extended embeddings just as in the main text, where the layer norm scale of each source token can be approximated in terms of $x_i$ alone, because $\lVert P_i \rVert \approx \lVert P_n \rVert$ for the positions $i$ where $\pi_i$ is non-negligible. And then we can approximate the remaining dependence on $n$, to get a correction term for when $n$ is quite large or quite small.

And now we have obtained an approximation that partially takes layer norm into account, while still allowing a decomposition into position-dependent and content-dependent terms.

Notably, for $n$ near the centre of the context window the correction term is close to 0, so the approximation given in the main text works well.
 

Second layer attention patterns:

These are sample attention patterns on the Bible for the second layer, with n=1022:

[Figures: sample second-layer attention patterns.]

To a large extent, the attention heads in the second layer are far better behaved than the attention heads in the first layer. There are lots of attention heads which seem almost entirely positional, with barely any content-dependent component.

These positional attention heads likely clean up noise from the first-layer contextual attention heads arising from variation in the destination token $x_n$. They could also construct bags of bigrams by composing with the post-MLP detokenization head outputs, or compute how repetitive a text is by composing with duplicate token neurons. In general, primarily positional heads can be viewed as constructing bags-of-representations out of any representations constructed by previous layers.

Contextual features:

The bag-of-tokens constrains how models can represent contextual features in the first-layer residual stream. Excluding Head 11, which behaves similarly to the contextual attention heads, only the contextual attention heads have access to tokens occurring more than ~7 tokens away.

Whatever a feature is, if we assume the bag of tokens approximation holds up, we must be able to understand first-layer contextual features through this lens.

For instance, 'Maths' and 'Topology' will naturally have a large overlap in their token distributions. And so by default we should expect them to lie close together in activation space, because their bags of tokens will overlap, assuming models don't give topology tokens abnormally large extended embeddings. Models will generally benefit when inputs with close token distributions get close next-token distributions, so there's no immediate reason for models to want large extended embeddings.

Bags of tokens are mostly translation invariant, so contextual features will also tend to be translation invariant. However, for the first 100 tokens or so, models attend disproportionately to the <end-of-text> token, so all contextual features will be partially ablated.

Arguably this ablation immediately implies the existence of weak linear representations. Assume that there is a cluster of 'bag-of-tokens' activations corresponding to 'Math'. Then if the model pays 10% of its attention to the <end-of-text> token, it will have 90% 'Math' in its residual stream; whereas if it pays 50% to <end-of-text>, it will have 50% 'Math'. So models will naturally want to handle a continuous range of strengths of activations.
