How does a toy 2 digit subtraction transformer predict the difference?

post by Evan Anders (evan-anders) · 2023-12-22T21:17:30.331Z · LW · GW · 0 comments

This is a link post for https://evanhanders.blog/2023/12/22/2-digit-subtraction-difference-prediction/

Contents

  2 digit subtraction -- difference prediction
  Summary
  Intro
    Roadmap
  Attention Patterns
  Neuron Pre-Activations and Activations
    The function(s) implemented by attention heads
    The resulting neuron pre-activations
    The Neuron Activations
  Neuron-to-logit operation
  The Logits
  Wrap up
    Code
    Acknowledgments


2 digit subtraction -- difference prediction

Summary

I continue studying a toy 1-layer transformer language model trained to do two-digit subtraction of the form a - b =. After predicting whether the result is positive (+) or negative (-), it must output the difference between a and b. The model creates activations that are oscillatory in the a and b bases, as well as activations that vary linearly with the inputs. The model uses the activation function to couple oscillations across the a and b directions, and it then sums those oscillations to eliminate any variance except the part that depends on the absolute difference |a - b|, which lets it predict the correct output token. I examine the full path of this algorithm from input to model output.

Intro

In previous posts, I described training a two-digit subtraction transformer and studied how that transformer predicts the sign of the output. In this post, I look at how the model's weights have been trained to determine the difference between two-digit integers, and at the emergent behavior of the model activations.

Roadmap

Unlike in my previous post, where I worked back-to-front through the model, this time I'm going to take the opposite approach: start at the input and follow how it transforms into the output. In order, I will break down the attention patterns, the neuron pre-activations and activations, the neuron-to-logit operation, and the logits.

Attention Patterns

Here's a visualization of the attention patterns for four examples (two positive, two negative, using the same numerical values for a and b but swapping them):

[Figure: attention patterns of each head for the four example problems (attention_head_examples-1.png)]

I've greyed out all of the unimportant rows so we can focus on the second-to-last row, which is where the transformer predicts the difference |a - b|. Comparing the four examples, one pattern stands out.

So H3 always attends to whichever of a and b is larger, and H0 always attends to whichever is smaller. H1 and H2 don't seem to be doing anything too important.

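In math, the above intuition can be written as the attention paid to token a by each head:

$$A^{(0)}_a = \begin{cases} 0 & s = + \\ 1 & s = - \end{cases}, \qquad A^{(3)}_a = \begin{cases} 1 & s = + \\ 0 & s = - \end{cases}, \qquad A^{(1)}_a = A^{(2)}_a = 0,$$

and the attention paid to token b by each head:

$$A^{(0)}_b = \begin{cases} 1 & s = + \\ 0 & s = - \end{cases}, \qquad A^{(3)}_b = \begin{cases} 0 & s = + \\ 1 & s = - \end{cases}, \qquad A^{(1)}_b = A^{(2)}_b = 0,$$

where s is the value of the token at context position 4 (right after the =), i.e., the predicted sign. It's reasonable to ask if this is a good description, and it is! If I zero-ablate the attention patterns in heads 1 and 2, and in heads 0 and 3 replace the highlighted rows above with one-hot vectors attending exclusively to either a or b, the loss on all $10^4$ problems in input space decreases from 0.0181 to 0.0168.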
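The idealized attention replacement can be sketched as below. This is a minimal illustration, not the notebook's actual code: the token layout [a, -, b, =, sign] (so a at position 0 and b at position 2) and the function name `idealized_attention_row` are assumptions made for the example.

```python
import numpy as np

# Assumed token layout for "a - b =" plus the sign: [a, "-", b, "=", sign].
POS_A, POS_B, N_CTX = 0, 2, 5


def idealized_attention_row(a: int, b: int, n_heads: int = 4) -> np.ndarray:
    """Hypothetical [n_heads, n_ctx] attention row for the position that predicts |a - b|.

    H3 attends fully to the larger operand, H0 fully to the smaller one,
    and H1/H2 are zero-ablated (all-zero rows).
    """
    row = np.zeros((n_heads, N_CTX))
    larger, smaller = (POS_A, POS_B) if a >= b else (POS_B, POS_A)
    row[3, larger] = 1.0
    row[0, smaller] = 1.0
    return row


# 27 - 81 is negative, so H3 should look at b (position 2) and H0 at a (position 0).
pattern = idealized_attention_row(27, 81)
print(int(pattern[3].argmax()), int(pattern[0].argmax()))  # 2 0
```

Ties (a = b) are sent to the a >= b branch here; since both operands are equal, either choice gives the same values.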

Neuron Pre-Activations and Activations

The function(s) implemented by attention heads

The matrix $W_E W_V^{h} W_O^{h} W_{\text{in}}$ has shape [d_vocab, d_mlp] and describes how attention head h alters the neuron pre-activations when that head attends to a given token in the vocabulary. We know from the previous section that heads 0 and 3 attend fully to either token a or token b, so we want to understand how heads 0 and 3 affect the pre-activations based on their inputs. (Aside: in my previous post, I examined something like this, but there was a subtle bug in my calculation of this matrix, which I've since fixed.)
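As a rough sketch of how such a per-head [d_vocab, d_mlp] map can be formed, the snippet below composes the token embedding, one head's value/output (OV) weights, and the MLP input weights. The sizes and random weights are placeholders, the shapes follow TransformerLens's weight conventions, and layer norm, positional embeddings and biases are ignored, so treat it as an illustration of the matrix shapes rather than the exact computation used here.

```python
import numpy as np

# Placeholder sizes and random weights; shapes follow TransformerLens conventions.
d_vocab, d_model, d_head, d_mlp, n_heads = 104, 128, 32, 512, 4
rng = np.random.default_rng(0)
W_E = rng.normal(size=(d_vocab, d_model))           # token embedding
W_V = rng.normal(size=(n_heads, d_model, d_head))   # per-head value weights
W_O = rng.normal(size=(n_heads, d_head, d_model))   # per-head output weights
W_in = rng.normal(size=(d_model, d_mlp))            # MLP input weights

# For each head h: W_E @ W_V[h] @ W_O[h] @ W_in, shape [d_vocab, d_mlp].
# Row t is the pre-activation contribution when head h attends fully to token t.
# matmul broadcasts the 2D matrices across the leading head axis.
head_to_neuron = W_E @ W_V @ W_O @ W_in
print(head_to_neuron.shape)  # (4, 104, 512)
```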

Here are the four types of patterns that I see in the head contributions to the neuron pre-activations:

[Figure: contributions of heads 0 and 3 to sample neurons' pre-activations as a function of the attended token (token_attention_neuron_contributions-2.png)]



So neuron 17 basically contributes linearly, and many other neurons have different types of oscillations. If we fit a line to each head's contribution to each neuron, subtract that, and then transform these signals into frequency space, we can form the following power spectra:

[Figure: power spectra of the detrended head contributions (token_attention_neuron_contributions_freqspace.png)]

Here I've placed vertical lines at the three key frequencies (0.33, 0.39, 0.48) that I found in my initial post on this topic.

So these spectra can mostly be described as a strong central peak plus strong peaks in the frequency bins immediately adjacent to it. This is very reminiscent of the apodization of peaks we see in signal processing when we use windowing techniques: power from a sharp central peak gets spread out a bit into adjacent bins.
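A minimal sketch of the detrend-then-transform step, run on a synthetic signal (a line plus one cosine at 0.33 cycles per token) standing in for one head's contribution to one neuron over tokens 0-99; `detrended_power_spectrum` is a hypothetical helper name, not something from the notebook.

```python
import numpy as np


def detrended_power_spectrum(signal: np.ndarray):
    """Subtract a best-fit line, then return (frequencies, power) of the residual."""
    x = np.arange(signal.size)
    slope, intercept = np.polyfit(x, signal, deg=1)
    residual = signal - (slope * x + intercept)
    power = np.abs(np.fft.rfft(residual)) ** 2
    freqs = np.fft.rfftfreq(signal.size)  # cycles per token, 0 to 0.5
    return freqs, power


# Synthetic stand-in: linear trend plus an oscillation at one key frequency.
x = np.arange(100)
contribution = 0.05 * x - 1.0 + 0.8 * np.cos(2 * np.pi * 0.33 * x + 0.4)
freqs, power = detrended_power_spectrum(contribution)
print(round(freqs[power.argmax()], 2))  # 0.33
```

Removing the line first matters because a linear trend would otherwise dump most of its power into the lowest frequency bins and swamp the oscillatory peaks.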

I'm not entirely sure what the right general functional form for these contributions is, but this seems like as good a guess as any:

$$g(x) = m x + c + \sum_{f \in \{0.33,\, 0.39,\, 0.48\}} \left[ A_f \cos(2\pi f x + \phi_f) + B_f \cos(2\pi f x + \psi_f) \cos\!\left(\frac{2\pi x}{N}\right) \right],$$

where x is the token value and N = 100 is the vocab size corresponding to 0-99. Here, the second term in the sum over key frequencies, which is responsible for the feet, is inspired by past work I've done with the Hann window. When two sinusoidal terms with frequencies $f_1$ and $f_2$ are multiplied together, they put power into $f_1 \pm f_2$, so this form of the function gives me a line, power in the key frequencies, and power in the bins next to the key frequencies ($f \pm 1/N$).
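To see why multiplying by a slow cosine produces feet, the sketch below multiplies a cosine at 0.33 cycles per token (bin 33 of 100) by one cycle of cos(2πx/N) across the vocab, an assumed Hann-like modulation. The product-to-sum identity moves all of the power into the two neighboring bins, 32 and 34.

```python
import numpy as np

N = 100  # number of numeric tokens, 0-99
x = np.arange(N)
f_key = 0.33  # one of the key frequencies (33 cycles per 100 tokens)

peak = np.cos(2 * np.pi * f_key * x)
# Hann-inspired "feet" term: modulate the peak by one slow cycle across the vocab.
feet = peak * np.cos(2 * np.pi * x / N)

power = np.abs(np.fft.rfft(feet)) ** 2
top_bins = sorted(int(i) for i in np.argsort(power)[-2:])
print(top_bins)  # [32, 34]
```

Adding such a term on top of the plain cosine therefore gives exactly the "central peak plus adjacent bins" shape seen in the spectra.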

I use scipy's curve_fit to fit functions of this form to the contributions of heads 0 and 3 to each neuron. I find that these fits account for ~97% of the total variability in the head contributions. For neurons like neuron 5 or neuron 17 above, the fit accounts for > 99% of the variability; neurons like neuron 125 have ~97% of their variability accounted for. Neurons like neuron 76 are a mess, but the fit still accounts for ~90% of their variability.
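A sketch of this kind of fit with scipy's `curve_fit`, under several assumptions: the functional form (a line, a cosine at each key frequency, and a feet term that modulates each cosine by cos(2πx/N)) is one reading of the description above; each phase-shifted cosine is split into cos and sin terms so the model is linear in its parameters; and the data are synthetic. `line_peaks_feet` and `fraction_of_variance_explained` are hypothetical names.

```python
import numpy as np
from scipy.optimize import curve_fit

N = 100
KEY_FREQS = (0.33, 0.39, 0.48)


def line_peaks_feet(x, m, c, *coeffs):
    """Assumed form: line + key-frequency cosines + Hann-style feet.

    coeffs holds (p, q, r, s) per key frequency: p*cos + q*sin is the peak and
    (r*cos + s*sin) * cos(2*pi*x/N) is the feet term.
    """
    out = m * x + c
    window = np.cos(2 * np.pi * x / N)
    for k, f in enumerate(KEY_FREQS):
        p, q, r, s = coeffs[4 * k: 4 * k + 4]
        cos_t, sin_t = np.cos(2 * np.pi * f * x), np.sin(2 * np.pi * f * x)
        out = out + (p * cos_t + q * sin_t) + (r * cos_t + s * sin_t) * window
    return out


def fraction_of_variance_explained(y, y_fit):
    return 1.0 - np.var(y - y_fit) / np.var(y)


# Synthetic stand-in for one head's contribution to one neuron (tokens 0-99).
rng = np.random.default_rng(0)
x = np.arange(N, dtype=float)
y = 0.02 * x + 0.5 * np.cos(2 * np.pi * 0.39 * x + 1.0) * (1 + 0.3 * np.cos(2 * np.pi * x / N))
y = y + 0.05 * rng.normal(size=N)

# With a variadic model, curve_fit takes the number of parameters from p0.
n_params = 2 + 4 * len(KEY_FREQS)
params, _ = curve_fit(line_peaks_feet, x, y, p0=np.zeros(n_params))
fve = fraction_of_variance_explained(y, line_peaks_feet(x, *params))
print(fve)
```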

I now replace the neuron pre-activations with these fits (using the attention patterns I found in the previous section). I also keep the biases and direct terms that turn out to be important: the value biases of the attention heads and the token and positional embeddings at context position 4 (i.e., the original residual stream that goes into the attention operation there) both matter for setting the neuron pre-activations.

When I make this approximation, the loss on all $10^4$ problems increases from 0.0181 to 0.0341, while accuracy only decreases from 99.94% to 99.68%. Note, however, that if I drop the "feet" term from my fit, the loss increases by over an order of magnitude and accuracy drops substantially, so this fit seems to capture most of the important information the model uses.

The resulting neuron pre-activations

In the previous section I examined how the model weights set the neuron pre-activations, but now I'm just going to look at the pre-activations and see if I can understand a simpler algorithm for describing them.

The previous section left me convinced that there would perhaps be four different classes of neurons. But, after a lot of struggling, it turns out there are only two, and both can be described using a simple fit:

$$n(a, b) = m\,(a - b) + c + A_a \cos(2\pi f a + \phi_a) + A_b \cos(2\pi f b + \phi_b),$$

which has a linear portion (slope m and intercept c) and an oscillatory portion (amplitudes $A_a$, $A_b$ and phases $\phi_a$, $\phi_b$), and where each neuron has a single characteristic frequency f drawn from the dominant model frequencies.

The thing is -- the model only implements this fit in the region where a ≥ b. In the region where a < b, the model just reuses the same calculation it came up with for the a ≥ b case, with the inputs swapped. So I fit the above function in the region where a ≥ b, so that a - b ≥ 0 everywhere in the fit. Then I evaluate the fit for all input values and replace the output values everywhere a < b with the corresponding a > b values (e.g., the value computed at (b, a) goes into the spot for (a, b)).
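A minimal sketch of the fit-then-mirror procedure, assuming the pre-activation form is a line in a - b plus one cosine in each of a and b at a shared frequency f (an assumed parametrization; the helper names are hypothetical). Because f is fixed, each phase-shifted cosine splits into cos and sin columns and the fit reduces to linear least squares. Evaluating the fit at (max(a, b), min(a, b)) is the same as copying the a ≥ b value into the swapped a < b spot.

```python
import numpy as np


def design_matrix(a, b, f):
    """Columns for m (a - b) + c + A_a cos(2 pi f a + phi_a) + A_b cos(2 pi f b + phi_b)."""
    return np.stack([
        a - b, np.ones_like(a),
        np.cos(2 * np.pi * f * a), np.sin(2 * np.pi * f * a),
        np.cos(2 * np.pi * f * b), np.sin(2 * np.pi * f * b),
    ], axis=-1)


def fit_and_mirror(preact, f):
    """Fit on the a >= b half only, then evaluate everywhere with swapped inputs where a < b."""
    a, b = np.meshgrid(np.arange(preact.shape[0]), np.arange(preact.shape[1]), indexing="ij")
    a, b = a.astype(float), b.astype(float)
    upper = a >= b
    coef, *_ = np.linalg.lstsq(design_matrix(a[upper], b[upper], f), preact[upper], rcond=None)
    hi, lo = np.maximum(a, b), np.minimum(a, b)
    return design_matrix(hi, lo, f) @ coef


# Usage on synthetic data built from the same (assumed) form, mirrored across a = b.
grid_a, grid_b = np.meshgrid(np.arange(100.0), np.arange(100.0), indexing="ij")
hi, lo = np.maximum(grid_a, grid_b), np.minimum(grid_a, grid_b)
true_preact = design_matrix(hi, lo, 0.39) @ np.array([0.05, -1.0, 0.6, 0.2, -0.3, 0.4])
fitted = fit_and_mirror(true_preact, 0.39)
```

On real pre-activations the recovered fit would not be exact, but the mirrored result is symmetric under swapping a and b by construction.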

To see how this works, here are some sample neuron pre-activations (top row) and these fits (bottom row):

[Figure: sample neuron pre-activations (top row) and their fits (bottom row) (sample_neuron_preactivations.png)]

These fits look pretty good! If I go into the model and replace all of the neuron pre-activations with best fits of this form, the model loss only changes from 0.0181 to 0.0222 and accuracy only drops from 99.94% to 99.92%. Further, after the ReLU, the fits account for 95% of the variability in the neuron activations. So I'm pretty happy with this description of how the model constructs the pre-activations!

The Neuron Activations

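The ReLU is the one nonlinearity in the problem, and it spreads power from the nice fit I described above outward. Specifically, I find (similar to Neel's grokking work) that

$$\mathrm{ReLU}\big[A_a\cos(2\pi f a + \phi_a) + A_b\cos(2\pi f b + \phi_b)\big] \approx \alpha_a\cos(2\pi f a + \phi_a) + \alpha_b\cos(2\pi f b + \phi_b) + C\cos(2\pi f a + \phi_a)\cos(2\pi f b + \phi_b) + \ldots$$

So power gets projected outward from the single-axis terms into cross-axis terms.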
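A quick synthetic check of this effect, assuming an equal-amplitude, zero-phase pre-activation cos(2πka/N) + cos(2πkb/N) with k = 39 (0.39 cycles per token), rather than real neuron data. Before the ReLU, the 2D spectrum only has single-axis power at bins like (k, 0); after the ReLU, power also appears at the cross-axis bins (k, k) and (k, -k), i.e. in the a + b and a - b directions.

```python
import numpy as np

N, k = 100, 39  # tokens 0-99; 0.39 cycles/token = 39 cycles per 100 tokens
a, b = np.meshgrid(np.arange(N), np.arange(N), indexing="ij")
pre = np.cos(2 * np.pi * k * a / N) + np.cos(2 * np.pi * k * b / N)  # single-axis terms only
post = np.maximum(pre, 0.0)  # ReLU


def power_2d(z: np.ndarray) -> np.ndarray:
    return np.abs(np.fft.fft2(z)) ** 2


before, after = power_2d(pre), power_2d(post)
# Bin (k, 0) is a single-axis term in a; (k, k) and (k, -k) are the a + b and a - b directions.
for label, p in (("pre-ReLU", before), ("post-ReLU", after)):
    print(label, [bool(p[k, j] > 1e-6) for j in (0, k, -k)])
```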

If I subtract the linear term from the neuron activation fits, fit the cross-axis oscillatory expression above to what remains, and then add the linear contribution back in, I can construct a guess for the neuron activations. But if I replace the neuron activations with those guesses, performance is really bad! Loss increases from 0.0181 to 2.496 and accuracy decreases from 99.94% to 19.79%.

Perversely, if I instead modify that guess so that the magnitude of the cross-axis term is boosted, performance is much better! More concretely, I follow the same procedure, but after fitting I boost the amplitude of the cross-axis term by a constant factor. Replacing the neurons with this boosted approximation brings the loss back down to 0.091 and accuracy back up to 98.79%. Not perfect, but it does seem like the neurons largely use a combination of the ReLU'd linear terms and these cross-axis terms to compute the solution.

I'm not completely satisfied with my explanation here, but I also want to wrap up looking at this toy problem before I take vacation tonight from Christmas through New Year's, so let's move along!

Neuron-to-logit operation

If the neuron activations (a tensor of shape [batch_a, batch_b, d_mlp]) are known, then their contribution to the logits can be recovered by multiplying by the matrix $W_{\text{out}} W_U$, which has shape [d_mlp, d_vocab].
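A shape-level sketch of this step with random stand-in weights (sizes are illustrative; the shapes follow TransformerLens's conventions). It only covers the MLP-to-logit path, ignoring the direct embedding and attention paths, any biases, and any final layer norm; the neuron index is arbitrary.

```python
import numpy as np

# Placeholder sizes and random stand-in weights (TransformerLens-style shapes).
n_a, n_b, d_mlp, d_model, d_vocab = 100, 100, 512, 128, 104
rng = np.random.default_rng(0)
acts = np.maximum(rng.normal(size=(n_a, n_b, d_mlp)), 0.0)  # post-ReLU neuron activations
W_out = rng.normal(size=(d_mlp, d_model))
W_U = rng.normal(size=(d_model, d_vocab))

neuron_to_logit = W_out @ W_U            # [d_mlp, d_vocab]
mlp_logits = acts @ neuron_to_logit      # [n_a, n_b, d_vocab], MLP path only

# Contribution of a single (arbitrarily chosen) neuron n to every logit.
n = 17
neuron_contrib = acts[:, :, n, None] * neuron_to_logit[n]  # [n_a, n_b, d_vocab]
```

Summing `neuron_contrib` over all neurons would reproduce `mlp_logits`, which is the "sum over neurons" view used below.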

If I Fourier transform $W_{\text{out}} W_U$ along the vocab direction for tokens 0-99 and then take the mean power spectrum over all neurons, I see the same three strong peaks at the same characteristic frequencies (0.33, 0.39, 0.48), and I also see a new peak at 0.01. This peak is associated with the linear features in the activations, and the other three peaks are associated with the oscillatory features examined above. See below for some examples of neuron activations and the corresponding [d_vocab] rows of $W_{\text{out}} W_U$:

[Figure: sample neuron activations, their neuron-to-logit vectors, and the vectors' power spectra for neurons 10, 17, 75, and 125 (sample_neurons_w_logit-2.png)]

There are a number of neurons like neuron 10 on the left. This neuron in particular suppresses low-value logits and boosts high-value logits when a and b are both large (this can be read off from a combination of the top and middle row plots). Neuron 17, in the second column, boosts the logits of small-number and large-number tokens (the latter of which seems like an error in what the model learns?) but suppresses intermediate-valued tokens when a and b are similar in magnitude and both small. In both cases, the dominant frequency in the neuron's $W_{\text{out}} W_U$ row is 0.01 (one cycle across tokens 0-99), because it contains one smooth sinusoidal feature that throttles some tokens and boosts others.

The right two columns tell a different story. Neurons 75 and 125 are oscillatory, and they affect the logits in an oscillatory fashion. The $W_{\text{out}} W_U$ rows for these neurons oscillate at the same characteristic frequencies as the neurons themselves. The bottom panels show power spectra, and for the right two columns (neurons 75 and 125) I plot both the power spectrum of the $W_{\text{out}} W_U$ row and, as a vertical line, the dominant neuron frequency calculated in the previous sections. There's really good agreement between the neuron and mapping-vector frequencies!
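A minimal sketch of the two spectral summaries used in this section, with hypothetical helper names: the mean power spectrum across neurons of a [d_mlp, d_vocab] neuron-to-logit matrix restricted to tokens 0-99, and the dominant (non-DC) frequency of a single row. The toy row below is a pure cosine, not a row from the model.

```python
import numpy as np


def mean_power_spectrum(neuron_to_logit: np.ndarray, n_numeric: int = 100):
    """Mean over neurons of the power spectrum along the vocab axis (tokens 0..n_numeric-1)."""
    rows = neuron_to_logit[:, :n_numeric]
    power = np.abs(np.fft.rfft(rows, axis=1)) ** 2
    return np.fft.rfftfreq(n_numeric), power.mean(axis=0)


def dominant_frequency(row: np.ndarray) -> float:
    """Non-DC frequency (cycles per token) carrying the most power in one row."""
    power = np.abs(np.fft.rfft(row)) ** 2
    freqs = np.fft.rfftfreq(row.size)
    return float(freqs[1 + power[1:].argmax()])


# Toy row: a pure cosine at 0.39 cycles/token, standing in for one neuron's row.
x = np.arange(100)
oscillatory_row = np.cos(2 * np.pi * 0.39 * x + 0.3)
freqs, mean_spec = mean_power_spectrum(oscillatory_row[None, :])
print(dominant_frequency(oscillatory_row))
```

Comparing `dominant_frequency` of each row against the frequency assigned to that neuron from its pre-activation fit is one way to quantify the agreement described above.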

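So if I was right in the previous section and the oscillatory neurons have terms like $C\cos(2\pi f a + \phi_a)\cos(2\pi f b + \phi_b)$, then these vectors map them in the following way:

$$L_{n,c}(a, b) = C\cos(2\pi f a + \phi_a)\cos(2\pi f b + \phi_b) \cdot B\cos(2\pi f c + \theta),$$

for some amplitude B and phase θ, where $L_{n,c}$ is the contribution to the logit for token c from neuron n. In reduced form this is

$$L_{n,c}(a, b) = \frac{CB}{4}\sum_{\sigma = \pm 1}\Big[\cos\big(2\pi f (a + b) + \phi_+ + \sigma(2\pi f c + \theta)\big) + \cos\big(2\pi f (a - b) + \phi_- + \sigma(2\pi f c + \theta)\big)\Big],$$

where $\phi_\pm = \phi_a \pm \phi_b$. So there are trigonometric terms oscillating in a + b and a - b, and the value of c shifts the phase of those terms to adjust the logits to what they need to be.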
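Assuming a cross-axis neuron term C cos(2πfa + φa) cos(2πfb + φb) and a logit-mapping vector B cos(2πfc + θ) (forms assumed here, not read off the model), product-to-sum identities rewrite the per-neuron logit contribution as waves in a + b and a - b whose phases are shifted by the output token c. The sketch below checks that rewrite numerically with arbitrary illustrative constants.

```python
import numpy as np

rng = np.random.default_rng(1)
f, C, B = 0.39, 1.3, 0.7                       # illustrative frequency and amplitudes
phi_a, phi_b, theta = rng.uniform(0, 2 * np.pi, size=3)
a, b, c = np.meshgrid(np.arange(100), np.arange(100), np.arange(100), indexing="ij")
w = 2 * np.pi * f

# Product form: cross-axis neuron term times the neuron's logit-mapping vector at token c.
product = C * np.cos(w * a + phi_a) * np.cos(w * b + phi_b) * B * np.cos(w * c + theta)

# Reduced form: waves in (a + b) and (a - b) whose phases are shifted by +/-(w c + theta).
phi_plus, phi_minus = phi_a + phi_b, phi_a - phi_b
reduced = np.zeros_like(product)
for sign in (+1, -1):
    shift = sign * (w * c + theta)
    reduced += C * B / 4 * (np.cos(w * (a + b) + phi_plus + shift)
                            + np.cos(w * (a - b) + phi_minus + shift))

print(np.allclose(product, reduced))  # True
```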

In the end, the logits themselves are constructed from a sum over the contributions of all neurons, which includes linear terms from neurons like 10 and 17 and then these oscillatory terms on top of the peaks and valleys set up by the linear terms.

The Logits

In the end, we get logits that look like this (plotted here are the logit maps for output tokens 0, 25, 50, and 75):

[Figure: logit maps over the (a, b) plane for output tokens 0, 25, 50, and 75 (logit_values.png)]

One thing that pops out to me is that the patterns are no longer oscillatory in two dimensions: they seem to be oscillatory solely in terms of the value of a - b. This suggests that the model sets the neuron amplitudes and phases so that it can throw out the oscillatory terms in a + b (I ran out of time to test this).

This also makes intuitive sense: along the diagonals in the (a, b) plane I've plotted above, the result of subtraction is always the same! If you increase a by one and also increase b by one, the difference between the two stays the same. So it makes sense that the model would learn a solution that varies and oscillates in the a - b direction but is perfectly constant along the direction perpendicular to it. Neat!
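One way the untested suggestion above could be checked, sketched under assumptions: replace each entry of a logit map z[a, b] with the mean over its diagonal (fixed a - b) and measure how much variance survives. `variance_explained_by_difference` is a hypothetical helper; applied to the actual logit maps, it would quantify how constant they are along the diagonals. The toy maps below are pure cosines in a - b and a + b, not model outputs.

```python
import numpy as np


def variance_explained_by_difference(z: np.ndarray) -> float:
    """Fraction of variance in a [n_a, n_b] map explained by a function of a - b alone.

    Each entry is replaced by the mean over its diagonal (fixed a - b); a value
    near 1 means the map is (nearly) constant along lines of fixed difference.
    """
    n_a, n_b = z.shape
    a, b = np.meshgrid(np.arange(n_a), np.arange(n_b), indexing="ij")
    d = (a - b) + (n_b - 1)  # shift differences to non-negative bin indices
    diag_means = np.bincount(d.ravel(), weights=z.ravel()) / np.bincount(d.ravel())
    return float(1.0 - np.var(z - diag_means[d]) / np.var(z))


ga, gb = np.meshgrid(np.arange(100), np.arange(100), indexing="ij")
diff_only = np.cos(2 * np.pi * 0.39 * (ga - gb))
sum_only = np.cos(2 * np.pi * 0.39 * (ga + gb))
print(variance_explained_by_difference(diff_only), variance_explained_by_difference(sum_only))
```

A map that depends only on a - b scores essentially 1, while a map oscillating in a + b scores near 0, so leftover a + b terms in the logits would show up as a score noticeably below 1.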

Wrap up

I feel ~85% confident that I've figured out the algorithm that this model is using to do subtraction. There are a few spots (particularly towards the end, after the activation function) where things got a bit rushed and hand-wavey, and if I had infinite time to spend on this model then I would solidify and polish some of the concepts there. But I don't!

I'm going to call it quits on this model and move on to something else perhaps a bit more interesting in my next post, but if anyone has any ideas of holes in my analysis or thoughts about things the model might be doing that I'm not examining, I'd be happy to discuss!

Code

The code I used to train the model can be found in this colab notebook, and the notebook used for this investigation can be found on Github.

Acknowledgments

Thanks again to Adam Jermyn for his mentorship and advice, and thanks to Philip Quirke and Eoin Farrell for our regular meetings. Thanks also to Alex Atanasov and Xianjun Yang for taking the time to meet with me and chat ML/AI this week!
