How does a toy 2 digit subtraction transformer predict the difference?
post by Evan Anders (evan-anders) · 2023-12-22T21:17:30.331Z · This is a link post for https://evanhanders.blog/2023/12/22/2-digit-subtraction-difference-prediction/
2 digit subtraction -- difference prediction
Summary
I continue studying a toy 1-layer transformer language model trained to do two-digit subtraction of the form $a - b$. After predicting whether the result is positive (+) or negative (-), it must output the difference between $a$ and $b$. The model creates activations which are oscillatory in the $a$ and $b$ bases, as well as activations which vary linearly with the inputs. The model uses the ReLU activation function to couple oscillations across the $a$ and $b$ directions, and it then sums those oscillations to eliminate any variance except for that depending on the absolute difference $|a - b|$, which lets it predict the correct output token. I examine the full path of this algorithm from input to model output.
Intro
In previous posts, I described training a two-digit subtraction transformer and studied how that transformer predicts the sign of the output. In this post, I look at how the model's weights have been trained to determine the difference between two-digit integers, and at the emergent behavior of the model's activations.
Roadmap
Unlike in my previous post, where I worked back-to-front through the model, this time I'm taking the opposite approach: starting at the input and following how it transforms into the output. In order, I will break down:
- The attention patterns, and which tokens the model attends to when predicting the output.
- How the attention head weights set the pre-activations.
- The emergent patterns in the pre-activations.
- The emergent patterns in the activations.
- The mapping between neurons and logits.
- The patterns in the logits.
Attention Patterns
Here's a visualization of the attention patterns for four examples (two positive, two negative, using the same numerical values for $a$ and $b$ but swapping them):
I've greyed out all of the unimportant rows so we can focus on the second-to-last row, which corresponds to the position where the transformer predicts the difference $|a - b|$. There are a few things to notice here:
- When the result is positive, H0 attends to token $b$ and H3 attends to token $a$.
- When the result is negative, H0 attends to token $a$ and H3 attends to token $b$.
- H1 and H2 attend roughly evenly to all tokens in the context.
So H3 always attends to whichever of $a$ and $b$ is larger, and H0 always attends to whichever is smaller. H1 and H2 don't seem to be doing anything too important.
In math, the above intuition can be written as the attention paid to token $a$ by each head $h$:
$$A^{h}_{a} = \begin{cases} \mathbb{1}[s = -] & h = 0 \\ 0 & h \in \{1, 2\} \\ \mathbb{1}[s = +] & h = 3 \end{cases}$$
and the attention paid to token $b$ by each head:
$$A^{h}_{b} = \begin{cases} \mathbb{1}[s = +] & h = 0 \\ 0 & h \in \{1, 2\} \\ \mathbb{1}[s = -] & h = 3 \end{cases},$$
where $s$ is the value of the token at context position 4 (right after the =). It's reasonable to ask whether this is a good description, and it is! If I zero-ablate the attention patterns in heads 1 and 2, and in heads 0 and 3 replace the highlighted rows above with one-hot unit vectors attending to either $a$ or $b$ as described, the loss over all problems in the input space decreases from 0.0181 to 0.0168.
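As a minimal sketch of the idealized attention used in this ablation, the function below builds the attention rows from context position 4 for heads 0-3 given the sign token. The token layout ($a$ at position 0, $b$ at position 2) and the names `POS_A`, `POS_B`, and `idealized_attention_row` are hypothetical choices for illustration, not taken from the notebook.

```python
import numpy as np

# Hypothetical token layout (an assumption, not from the notebook):
# position 0 = a, 1 = "-", 2 = b, 3 = "=", 4 = sign token s.
POS_A, POS_B, N_CTX, N_HEADS = 0, 2, 5, 4


def idealized_attention_row(sign: str) -> np.ndarray:
    """Attention from position 4 for each head, shape [N_HEADS, N_CTX].

    Heads 1 and 2 are zero-ablated; heads 0 and 3 get one-hot rows.
    """
    pattern = np.zeros((N_HEADS, N_CTX))
    if sign == "+":    # a > b: H3 -> larger token a, H0 -> smaller token b
        pattern[3, POS_A] = 1.0
        pattern[0, POS_B] = 1.0
    elif sign == "-":  # a < b: H3 -> larger token b, H0 -> smaller token a
        pattern[3, POS_B] = 1.0
        pattern[0, POS_A] = 1.0
    else:
        raise ValueError(f"unexpected sign token {sign!r}")
    return pattern


plus_rows = idealized_attention_row("+")
minus_rows = idealized_attention_row("-")
```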
Neuron Pre-Activations and Activations
The function(s) implemented by attention heads
The matrix $W_E W_V^{h} W_O^{h} W_{\rm in}$ has shape [d_vocab, d_mlp] and describes how attention head $h$ alters the neuron pre-activations when that head attends to a given token in the vocabulary. We know from the previous section that heads 0 and 3 attend fully to either token $a$ or token $b$, so we want to understand how heads 0 and 3 affect the pre-activations based on their inputs. (Aside: in my previous post, I examined something like this, but there was a subtle bug in my calculation of this matrix, which I've since fixed.)
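A sketch of how such a per-head matrix can be computed with NumPy, assuming TransformerLens-style weight names and shapes (`W_E`, `W_V`, `W_O`, `W_in`) and ignoring biases. The dimensions and random weights below are placeholders, not the trained model's.

```python
import numpy as np

# Placeholder sizes and random weights (not the trained model's values).
d_vocab, d_model, n_heads, d_head, d_mlp = 104, 64, 4, 16, 128
rng = np.random.default_rng(0)
W_E = rng.normal(size=(d_vocab, d_model))            # token embedding
W_V = rng.normal(size=(n_heads, d_model, d_head))    # per-head value weights
W_O = rng.normal(size=(n_heads, d_head, d_model))    # per-head output weights
W_in = rng.normal(size=(d_model, d_mlp))             # MLP input weights


def head_to_neuron(h: int) -> np.ndarray:
    """[d_vocab, d_mlp]: change in each neuron's pre-activation when head h
    attends fully to each vocab token (biases ignored)."""
    return W_E @ W_V[h] @ W_O[h] @ W_in


M0, M3 = head_to_neuron(0), head_to_neuron(3)
```

Row `M0[t]` is then the pre-activation shift produced when head 0 puts all of its attention on token `t`.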
Here are the four types of patterns that I see in the head contributions to the neuron activations:
So neuron 17 basically contributes linearly, and many other neurons have different types of oscillations. If we fit a line to each head's contribution to each neuron, subtract that, and then transform these signals into frequency space, we can form the following power spectra:
Here I've placed vertical lines at the three key frequencies (0.33, 0.39, 0.48) which I saw in my initial post on this topic. We see:
- The small oscillations on top of the linear trend in neuron 17 have a few peaks in the frequency spectrum associated with these key frequencies, but they're weak.
- Neuron 5 has strong peaks associated with each of these key frequencies, but one of these peaks is by far the strongest. Moreover, this dominant peak is not completely sharp: it has a strong central value as well as fairly strong "feet" immediately to either side of it. This will be important in a sec.
- Neuron 76 is a mess, and there are a decent number of neurons like it. There's a strong peak, but also a spread of power away from that frequency that falls off only gradually.
- Neuron 125 exhibits strong beats in the token signal, and this corresponds to a gradual increase of power towards its dominant peak. This is again a bit of a mess.
So what we're seeing is that these spectra can mostly be described in terms of a strong central peak plus strong power in the frequency bins right next to it. This is very reminiscent of the apodization of peaks we see in signal processing when we use windowing techniques: power from a sharp central peak gets spread out a bit into adjacent bins.
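The detrend-then-FFT step that produces these spectra can be sketched as follows. This is an illustration on a synthetic signal, assuming a least-squares line via `np.polyfit` and NumPy's real FFT; `detrended_power_spectrum` is a hypothetical helper name.

```python
import numpy as np


def detrended_power_spectrum(signal: np.ndarray):
    """Subtract a least-squares line, then return (frequencies, power)."""
    x = np.arange(len(signal))
    slope, intercept = np.polyfit(x, signal, deg=1)
    residual = signal - (slope * x + intercept)
    power = np.abs(np.fft.rfft(residual)) ** 2
    freqs = np.fft.rfftfreq(len(signal))  # cycles per token, 0 to 0.5
    return freqs, power


# Synthetic example: a linear trend plus a sinusoid at one key frequency.
tokens = np.arange(100)
toy = 0.05 * tokens + np.cos(2 * np.pi * 0.33 * tokens + 0.7)
freqs, power = detrended_power_spectrum(toy)
peak = freqs[1:][np.argmax(power[1:])]  # skip the zero-frequency bin
```

With 100 tokens the frequency grid has a spacing of 0.01, so a key frequency like 0.33 falls exactly on a bin.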
I'm not entirely sure what the right function is to fit these signals in general, but this seems like as good a guess as any:
$$g(x) = m x + c + \sum_{f \in \{0.33,\,0.39,\,0.48\}} \left[ A_f \cos(2\pi f x + \phi_f) + B_f \cos\left(\frac{2\pi x}{N}\right) \cos(2\pi f x + \theta_f) \right],$$
where $x$ is the token value and $N = 100$ is the vocab size corresponding to tokens 0-99. Here, the second term in the sum over key frequencies, which is responsible for the feet, is inspired by past work I've done with the Hann window. When two sinusoidal terms with frequencies $f_1$ and $f_2$ are multiplied together, they put power into $f_1 \pm f_2$, so this form of the function gives me a line, power in the key frequencies, and power in the bins next to the key frequencies (at $f \pm 1/N$).
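A quick numerical check of the frequency-splitting argument. It assumes the feet term is a key-frequency cosine multiplied by one slow cosine of period $N$ (one plausible reading of the Hann-window-inspired term). The product puts all its power into the bins at $f \pm 1/N$ and none at $f$ itself.

```python
import numpy as np

N = 100                                   # tokens 0-99
x = np.arange(N)
f = 0.33                                  # one of the key frequencies
carrier = np.cos(2 * np.pi * f * x)
# Assumed feet term: the carrier times one slow cosine with period N.
feet = carrier * np.cos(2 * np.pi * x / N)

power = np.abs(np.fft.rfft(feet)) ** 2
freqs = np.fft.rfftfreq(N)
top_two = np.sort(freqs[np.argsort(power)[-2:]])   # the two strongest bins
```

Here `top_two` comes out as the neighboring bins 0.32 and 0.34, while bin 0.33 is empty.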
I use scipy's curve_fit to fit functions of this form to the contributions of heads 0 and 3 to each neuron. I find that these fits account for 97% of the total variability in the head contributions. For neurons like neuron 5 or neuron 17 above, the fit accounts for > 99% of the variability; neurons like neuron 125 have 97% of their variability accounted for. Neurons like neuron 76 are a mess, but the fit still accounts for 90% of their variability.
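A sketch of such a fit with `scipy.optimize.curve_fit`, on synthetic data with made-up coefficients. It assumes the line-plus-peaks-plus-feet form with a $\cos(2\pi x/N)$ modulation. It also rewrites each amplitude-phase pair as a cosine plus a sine, which makes the model linear in its parameters so the fit does not hinge on initial phase guesses. The post's exact parameterization may differ.

```python
import numpy as np
from scipy.optimize import curve_fit

N = 100
KEY_FREQS = (0.33, 0.39, 0.48)


def line_peaks_feet(x, m, c, *coeffs):
    """Line + key-frequency sinusoids + the same sinusoids times cos(2*pi*x/N).

    Each A*cos(2*pi*f*x + phi) is written as a*cos + b*sin, so the model is
    linear in all of its parameters.
    """
    out = m * x + c
    window = np.cos(2 * np.pi * x / N)
    for i, f in enumerate(KEY_FREQS):
        a, b, p, q = coeffs[4 * i: 4 * i + 4]
        cos_f, sin_f = np.cos(2 * np.pi * f * x), np.sin(2 * np.pi * f * x)
        out = out + a * cos_f + b * sin_f + window * (p * cos_f + q * sin_f)
    return out


def fit_head_contribution(y):
    """Fit one head's contribution to one neuron; return (params, explained variance)."""
    x = np.arange(N, dtype=float)
    p0 = np.zeros(2 + 4 * len(KEY_FREQS))  # required because of the *coeffs signature
    params, _ = curve_fit(line_peaks_feet, x, y, p0=p0)
    residual = y - line_peaks_feet(x, *params)
    return params, 1.0 - np.var(residual) / np.var(y)


# Synthetic check with made-up coefficients plus a little noise.
rng = np.random.default_rng(1)
true_params = rng.normal(size=2 + 4 * len(KEY_FREQS))
y = line_peaks_feet(np.arange(N, dtype=float), *true_params) + 0.01 * rng.normal(size=N)
params, explained = fit_head_contribution(y)
```

The returned `explained` is the fraction of variance accounted for, which corresponds to the "variability accounted for" percentages quoted above.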
I now replace the neuron pre-activations with these fits (using the attention patterns I found in the previous section). I also keep the biases that turn out to be important: the value bias from the attention heads and the token and positional embeddings (that is, the original residual stream that goes into the attention operation at context position 4) both matter for setting the neuron pre-activations.
When I make this approximation, loss on all problems increases from 0.0181 to 0.0341, while accuracy only decreases from 99.94% to 99.68%. Note, however, that if I drop the "feet" term in my fit, loss increases by over an order of magnitude and accuracy drops substantially, so this fit seems to capture most of the important pieces of what the model uses.
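Putting the attention routing and the per-head fits together, the approximate pre-activations for every $(a, b)$ pair can be assembled as below. This sketch assumes fitted tables `G0` and `G3` of shape [100, d_mlp] (hypothetical names) and a single constant `bias` vector standing in for the biases and the position-4 residual stream. In the post, that residual stream includes the sign-token embedding, which depends on the input.

```python
import numpy as np


def approximate_preactivations(G0, G3, bias):
    """Assemble [100, 100, d_mlp] pre-activations over all (a, b) pairs.

    G0, G3: [100, d_mlp] fitted contributions of heads 0 and 3 as a function of
    the token each head attends to. bias: [d_mlp] constant stand-in for the
    biases and position-4 residual stream (a simplification).
    """
    a = np.arange(100)[:, None]
    b = np.arange(100)[None, :]
    larger, smaller = np.maximum(a, b), np.minimum(a, b)
    # H3 reads the larger operand and H0 the smaller one.
    return G3[larger] + G0[smaller] + bias


rng = np.random.default_rng(0)
G0, G3, bias = rng.normal(size=(100, 8)), rng.normal(size=(100, 8)), rng.normal(size=8)
pre = approximate_preactivations(G0, G3, bias)
```

Because each head reads only the larger or the smaller operand, this construction is automatically symmetric under swapping $a$ and $b$.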
The resulting neuron pre-activations
In the previous section I examined how the model weights set the neuron pre-activations, but now I'm going to look at the pre-activations directly and see if I can find a simpler description of them.
The previous section left me convinced that there might be four different classes of neurons. But after a lot of struggling, it turns out there are only two, and both can be described using a simple fit consisting of a linear portion (a slope and an intercept) plus an oscillatory portion: one sinusoid in $a$ and one in $b$, each with its own amplitude and phase. Each neuron oscillates at a single dominant characteristic frequency drawn from the model's key frequencies.
The catch is that the model only implements this fit in one half of the $(a, b)$ plane, on one side of the $a = b$ diagonal. In the other half, the model just reuses the same calculation that it came up with for the case with swapped inputs. So I fit the above function only in the half where it applies, then evaluate the fit for all input values, and finally replace the values in the other half with the corresponding swapped case (e.g., the value at $(a, b)$ is taken from $(b, a)$).
To see how this works, here are some sample neuron pre-activations (top row) and these fits (bottom row):
These fits look pretty good! If I go into the model and replace all of the neuron pre-activations with these best fits, the model loss only changes from 0.0181 to 0.0222 and accuracy only drops from 99.94% to 99.92%. Further, after the ReLU, the fits account for 95% of the variability in the neuron activations. So I'm pretty happy with this description of how the model constructs the pre-activations!
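The mirroring step can be written in one line with `np.where`. This sketch assumes the fit is trained on the $a \ge b$ half. Which half the model actually uses, and the helper name `mirror_fit`, are assumptions.

```python
import numpy as np


def mirror_fit(fit_grid: np.ndarray) -> np.ndarray:
    """Keep fit_grid[a, b] where a >= b; elsewhere use the swapped value fit_grid[b, a]."""
    a = np.arange(fit_grid.shape[0])[:, None]
    b = np.arange(fit_grid.shape[1])[None, :]
    return np.where(a >= b, fit_grid, fit_grid.T)


rng = np.random.default_rng(0)
grid = rng.normal(size=(100, 100))   # stand-in for one neuron's fit evaluated everywhere
mirrored = mirror_fit(grid)
```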
The Neuron Activations
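The ReLU is the one nonlinearity in the problem, and it spreads power from the nice fit I described above outward. Specifically, I find (similar to Neel's grokking work) that for a neuron with dominant frequency $\omega$ and oscillatory pre-activation $A\cos(2\pi\omega a + \phi) + B\cos(2\pi\omega b + \theta)$, the activation
$$\mathrm{ReLU}\big[A\cos(2\pi\omega a + \phi) + B\cos(2\pi\omega b + \theta)\big]$$
picks up cross-axis terms proportional to
$$\cos\big(2\pi\omega(a + b) + \phi + \theta\big) \quad \text{and} \quad \cos\big(2\pi\omega(a - b) + \phi - \theta\big).$$
So power gets projected outward from the single-axis terms into cross-axis terms.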
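A numerical illustration of this effect, using a made-up neuron with equal amplitudes, zero phases, and frequency 0.33. Before the ReLU, the 2D spectrum over $(a, b)$ has no power in the cross-axis bins. After it, the $(a + b)$ and $(a - b)$ bins are clearly populated. For this special case the cross-axis Fourier coefficient works out analytically to $4/(3\pi^2) \approx 0.135$ (from the expansion of $|\cos u|\,|\cos v|$ with $u, v = (\alpha \pm \beta)/2$), and the FFT reproduces that value.

```python
import numpy as np

N, k = 100, 33              # frequency 0.33 = 33 cycles per 100 tokens
a = np.arange(N)[:, None]
b = np.arange(N)[None, :]
pre = np.cos(2 * np.pi * k * a / N) + np.cos(2 * np.pi * k * b / N)  # single-axis terms only
post = np.maximum(pre, 0.0)  # ReLU

F_pre = np.fft.fft2(pre) / N**2    # normalized 2D Fourier coefficients
F_post = np.fft.fft2(post) / N**2
# Bin (k, k) is frequency k/N along a + b; bin (k, N - k) is k/N along a - b.
cross_pre = np.abs([F_pre[k, k], F_pre[k, N - k]])
cross_post = np.abs([F_post[k, k], F_post[k, N - k]])
```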
If I subtract the linear term from the neuron activation fits, fit the cross-axis oscillatory expression above to what remains, and then add the linear contribution back in, I get a guess for the neuron activations. But if I replace the neuron activations with those guesses, performance is really bad! Loss increases from 0.0181 to 2.496 and accuracy decreases from 99.94% to 19.79%.
Perversely, if I instead modify that guess so that the magnitude of the cross-frequency term is boosted, performance is much better! Concretely, I follow the same procedure as above, but after fitting I boost the cross-axis amplitudes by a constant factor. Replacing the neurons with this boosted approximation brings loss back down to 0.091 and accuracy back up to 98.79%. It's not perfect, but it suggests that the neurons are indeed largely using a combination of the ReLU'd linear terms and these cross-axis terms to compute the solution.
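A sketch of the fit-and-boost procedure, assuming the cross-axis terms are fit by linear least squares on cosine and sine columns at $\omega(a + b)$ and $\omega(a - b)$. The `boost` argument is a hypothetical knob: the post boosts the cross-axis amplitude by a constant factor, but that value is not reproduced here.

```python
import numpy as np


def cross_axis_guess(activation, linear_part, omega, boost=1.0):
    """Fit cross-axis oscillations at frequency omega and return a boosted guess.

    activation, linear_part: [n_a, n_b] grids over (a, b). boost is a
    hypothetical multiplier on the fitted cross-axis terms.
    """
    n_a, n_b = activation.shape
    a, b = np.meshgrid(np.arange(n_a), np.arange(n_b), indexing="ij")
    angles = [2 * np.pi * omega * (a + b), 2 * np.pi * omega * (a - b)]
    # Columns: cos/sin along a + b, then cos/sin along a - b.
    design = np.stack(
        [trig(angle).ravel() for angle in angles for trig in (np.cos, np.sin)], axis=1
    )
    target = (activation - linear_part).ravel()
    coeffs, *_ = np.linalg.lstsq(design, target, rcond=None)
    cross = (design @ coeffs).reshape(n_a, n_b)
    return linear_part + boost * cross, coeffs


# Synthetic example with made-up values.
a, b = np.meshgrid(np.arange(100), np.arange(100), indexing="ij")
linear = 0.02 * a - 0.01 * b + 0.5
act = (linear + 0.4 * np.cos(2 * np.pi * 0.39 * (a - b) + 0.3)
       + 0.1 * np.sin(2 * np.pi * 0.39 * (a + b)))
guess, _ = cross_axis_guess(act, linear, 0.39)
boosted, _ = cross_axis_guess(act, linear, 0.39, boost=2.0)
```

Writing each cross-axis term as a cosine plus a sine keeps the fit a plain linear least-squares problem. Scaling the fitted cross terms by `boost` leaves their phases unchanged.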
I'm not completely satisfied with my explanation here, but I also want to wrap up looking at this toy problem before I take vacation tonight from Christmas through New Year's, so let's move along!
Neuron-to-logit operation
If the neuron activations are known (a tensor of shape [batch_a, batch_b, d_mlp]), then the neurons' contribution to the logits can be recovered by multiplying by the matrix $W_{\rm out} W_U$, which has shape [d_mlp, d_vocab].
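In code, this is a single `einsum`. The sketch below uses placeholder sizes and random weights with TransformerLens-style names (`W_out`, `W_U`), which are assumptions. It covers only the neurons' direct contribution to the logits.

```python
import numpy as np

# Placeholder sizes and random weights (not the trained model's values).
d_mlp, d_model, d_vocab = 128, 64, 104
rng = np.random.default_rng(0)
W_out = rng.normal(size=(d_mlp, d_model))
W_U = rng.normal(size=(d_model, d_vocab))
neuron_to_logit = W_out @ W_U                                  # [d_mlp, d_vocab]

acts = np.maximum(rng.normal(size=(100, 100, d_mlp)), 0.0)     # [batch_a, batch_b, d_mlp]
logits = np.einsum("abn,nv->abv", acts, neuron_to_logit)       # [batch_a, batch_b, d_vocab]
```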
If I Fourier transform the rows of this matrix along the vocab direction for tokens 0-99 and then take the mean power spectrum over all neurons, I see the same three strong peaks at the same characteristic frequencies (0.33, 0.39, 0.48), and I also see a new peak at 0.01. This peak is associated with the linear features in the activations, while the other three peaks are associated with the oscillatory features examined above. See below for some examples of neuron activations and the corresponding [d_vocab] vectors in $W_{\rm out} W_U$:
There are a number of neurons like neuron 10 on the left. This neuron in particular suppresses low-value logits and boosts high-value logits when $a$ and $b$ are both large (this can be read off from a combination of the top and middle row plots). Neuron 17, in the second column, boosts the logit values of small-number tokens and large-number tokens (the latter of which seems like an error in what the model learns?) but suppresses intermediate-valued tokens when $a$ and $b$ are similar in magnitude and both small. In both cases, the dominant frequency in the vector is 0.01, because it contains one smooth sinusoidal feature which throttles some tokens and boosts others.
The right two columns tell a different story. Neurons 75 and 125 are oscillatory, and they affect the logits in an oscillatory fashion: their mapping vectors oscillate at the same characteristic frequencies as the neurons themselves. The bottom-row plots show power spectra, and for the right two columns (neurons 75 and 125) I plot both the power spectrum of the mapping vector and, as a vertical line, the dominant neuron frequency calculated in the previous sections. There's really good agreement between the neuron and mapping-vector frequencies!
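The averaged spectrum described above can be sketched as follows. It assumes vocab indices 0-99 are the numeric tokens (a hypothetical layout) and uses made-up mapping vectors that contain only a slow one-cycle feature plus noise, so the mean spectrum should peak at 0.01.

```python
import numpy as np


def mean_vocab_spectrum(neuron_to_logit, n_numeric=100):
    """FFT each neuron's row over the numeric tokens and average the power over neurons."""
    vectors = neuron_to_logit[:, :n_numeric]            # [d_mlp, n_numeric]
    power = np.abs(np.fft.rfft(vectors, axis=1)) ** 2   # FFT along the vocab direction
    return np.fft.rfftfreq(n_numeric), power.mean(axis=0)


# Made-up mapping vectors: one slow sinusoid (0.01 cycles/token) per neuron plus noise.
rng = np.random.default_rng(2)
c = np.arange(100)
amps = rng.normal(size=(32, 1))
phases = rng.uniform(0, 2 * np.pi, size=(32, 1))
toy = amps * np.cos(2 * np.pi * 0.01 * c + phases) + 0.05 * rng.normal(size=(32, 100))
freqs, mean_power = mean_vocab_spectrum(toy)
peak = freqs[np.argmax(mean_power)]
```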
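So if I was right in the previous section and the oscillatory neurons have terms like $\alpha_n \cos\big(2\pi\omega_n(a - b) + \psi_n\big)$, then these vectors map them onto the logits in the following way:
$$L_{n,c} = \alpha_n \cos\big(2\pi\omega_n(a - b) + \psi_n\big)\,\beta_n \cos\big(2\pi\omega_n c + \chi_n\big)$$
for some amplitude $\beta_n$ and phase $\chi_n$, where $L_{n,c}$ is the contribution to the logit for token $c$ from neuron $n$. In a reduced form this is
$$L_{n,c} = \frac{\alpha_n\beta_n}{2}\Big[\cos\big(2\pi\omega_n(a - b + c) + \psi_n + \chi_n\big) + \cos\big(2\pi\omega_n(a - b - c) + \psi_n - \chi_n\big)\Big].$$
So there are some trigonometric terms oscillating in $a - b$ space, and the phases $\psi_n \pm \chi_n$ shift those terms just a bit to adjust the logits to what they need to be.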
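A numerical check of the product-to-sum step, with made-up values for one hypothetical neuron. It also shows one way the phases matter: when $\psi_n = \chi_n$, the $(a - b - c)$ term is exactly maximal at $c = a - b$.

```python
import numpy as np

# Made-up frequency, amplitudes, and phases for one hypothetical neuron n.
omega, alpha, psi, beta, chi = 0.33, 1.3, 0.4, 0.8, 0.4
a, b = 60, 20
c = np.arange(100)                                   # output tokens 0-99

product = (alpha * np.cos(2 * np.pi * omega * (a - b) + psi)
           * beta * np.cos(2 * np.pi * omega * c + chi))
reduced = 0.5 * alpha * beta * (
    np.cos(2 * np.pi * omega * (a - b + c) + psi + chi)
    + np.cos(2 * np.pi * omega * (a - b - c) + psi - chi)
)
# With psi == chi, the (a - b - c) term equals 1 exactly at c = a - b.
difference_term = np.cos(2 * np.pi * omega * (a - b - c) + psi - chi)
best_token = int(np.argmax(difference_term))
```

A single frequency only weakly singles out the answer, though. Here $\cos(2\pi \cdot 0.33 \cdot 3) \approx 0.998$, so tokens three away from $a - b$ come very close to the maximum when only one frequency is used.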
In the end, the logits themselves are constructed from a sum over the neurons' contributions. This sum contains linear terms (from neurons like 10 and 17), with these oscillatory terms layered on top of the peaks and valleys set up by the linear terms.
The Logits
The result is logits that look like this (plotted here are logit maps for tokens 0, 25, 50, and 75):
One thing that pops out to me is that the patterns are no longer oscillatory in two dimensions; they seem to be oscillatory solely in terms of the value of $a - b$. This suggests that the model sets its parameters so that it can throw out the oscillatory terms in $a + b$ (I ran out of time to test this).
This also makes intuitive sense: along the diagonals in the $(a, b)$ plane I've plotted above, the result of subtraction is always the same! If you increase $a$ by one and also increase $b$ by one, the difference between the two stays the same. So it makes sense that the model would learn a solution that varies and oscillates in the $a - b$ direction but is perfectly constant along the perpendicular ($a + b$) direction. Neat!
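A toy check of this symmetry argument, using a made-up logit map built only from $(a - b)$ terms. The map is unchanged when $a$ and $b$ both increase by one. Its 2D spectrum has power only in bins $(k, -k)$, that is, along $a - b$ with nothing along $a + b$.

```python
import numpy as np

# Toy logit map built only from (a - b) terms (made-up frequencies and phases).
a = np.arange(100)[:, None]
b = np.arange(100)[None, :]
logit_map = sum(
    np.cos(2 * np.pi * f * (a - b) + p) for f, p in [(0.33, 0.1), (0.39, 1.2), (0.48, 2.0)]
)

# Shifting both a and b by one leaves every entry unchanged...
diag_shift_equal = np.allclose(logit_map[1:, 1:], logit_map[:-1, :-1])
# ...and all spectral power sits in bins (k1, k2) with k1 + k2 = 0 (mod 100).
F = np.abs(np.fft.fft2(logit_map))
k1, k2 = np.nonzero(F > 1e-6 * F.max())
only_a_minus_b = bool(np.all((k1 + k2) % 100 == 0))
```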
Wrap up
I feel ~85% confident that I've figured out the algorithm that this model is using to do subtraction. There are a few spots (particularly towards the end, after the activation function) where things got a bit rushed and hand-wavey, and if I had infinite time to spend on this model then I would solidify and polish some of the concepts there. But I don't!
I'm going to call it quits on this model and move on to something else perhaps a bit more interesting in my next post, but if anyone has any ideas of holes in my analysis or thoughts about things the model might be doing that I'm not examining, I'd be happy to discuss!
Code
The code I used to train the model can be found in this Colab notebook, and the notebook used for this investigation can be found on GitHub.
Acknowledgments
Thanks again to Adam Jermyn for his mentorship and advice, and thanks to Philip Quirke and Eoin Farrell for our regular meetings. Thanks also to Alex Atanasov and Xianjun Yang for taking the time to meet with me and chat ML/AI this week!