Short Remark on the (subjective) mathematical 'naturalness' of the Nanda--Lieberum addition modulo 113 algorithm
post by carboniferous_umbraculum (Spencer Becker-Kahn) · 2023-06-01T11:31:37.796Z · LW · GW · 12 comments
These remarks are basically me just wanting to get my thoughts down after a Twitter exchange on this subject. I've not spent much time on this post and it's certainly plausible that I've gotten things wrong.
In the 'Key Takeaways' section of the Modular Addition part of the well-known post 'A Mechanistic Interpretability Analysis of Grokking [LW · GW]' , Nanda and Lieberum write:
This algorithm operates via using trig identities and Discrete Fourier Transforms to map $x, y \to \cos(w(x+y)), \sin(w(x+y))$, and then extracting $x + y \pmod{113}$
And
The model is trained to map $x, y \mapsto x + y \pmod{113}$ (henceforth 113 is referred to as $p$)
But the casual reader should use caution! It is in fact the case that "Inputs are given as one-hot encoded vectors in $\mathbb{R}^{113}$". This point is of course emphasized more in the full notebook (it has to be; that's where the code is), and the arXiv paper that followed is also much clearer about it. However, when giving brief takeaways from the work, especially when discussing how 'natural' the learned algorithm is, I would go as far as to say that it is actually misleading to suggest that the network is literally given $x$ and $y$ as inputs. It is not trained to 'act' on the numbers themselves.
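To make the distinction concrete, here is a minimal numpy sketch of the reverse-engineered algorithm acting on one-hot inputs. This is not the trained network: it uses a single hand-picked frequency (the choice $k = 5$ is purely illustrative; the real model spreads the computation over several key frequencies), and the 'embedding' is just a fixed cos/sin read-out.

```python
import numpy as np

p = 113
k = 5                      # illustrative frequency; any k coprime to 113 works
w = 2 * np.pi * k / p

def predict(x, y):
    # One-hot inputs: the network sees positions, not magnitudes.
    a = np.zeros(p); a[x] = 1.0
    b = np.zeros(p); b[y] = 1.0
    # Linear read-out of cos/sin at the input positions.
    ns = np.arange(p)
    cx, sx = a @ np.cos(w * ns), a @ np.sin(w * ns)
    cy, sy = b @ np.cos(w * ns), b @ np.sin(w * ns)
    # Trig identities give cos/sin of w(x+y).
    c_sum = cx * cy - sx * sy
    s_sum = sx * cy + cx * sy
    # Logit for each candidate c is cos(w(x+y-c)); argmax recovers (x+y) mod p.
    logits = c_sum * np.cos(w * ns) + s_sum * np.sin(w * ns)
    return int(np.argmax(logits))
```

With a single frequency coprime to 113, the argmax already recovers the answer exactly; the network's use of several frequencies mainly sharpens the argmax.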
When thinking seriously about why the network is doing the particular thing that it is doing at the mechanistic level, I would want to emphasize that one-hotting is already a significant transformation. You have moved away from having the number $x$ be represented by its own magnitude. You instead have a situation in which $x$ and $y$ now really live 'in the domain' (it's almost like a dual point of view: the number $x$ is not the size of the signal, but the position at which the input signal is non-zero).
So, while I of course fully admit that I too am looking at it through my own subjective lens, one might say that (before the embedding happens) it is more mathematically natural to think that what the network is 'seeing' as input is something like the indicator functions $\mathbb{1}_{t=x}$ and $\mathbb{1}_{t=y}$. Here, $t$ is something like the 'token variable', in the sense that these are functions on the vocabulary. And if we essentially ignore the additional tokens for | and =, we can think that these are functions on the group $\mathbb{Z}/113\mathbb{Z}$ and that we would like the network to learn to produce the function $\mathbb{1}_{t = x + y}$ at its output neurons.
In particular, this point of view further (and perhaps almost completely) demystifies the use of the Fourier basis.
Notice that the operation you want to learn is manifestly a convolution operation, i.e.

$$\mathbb{1}_{t=x} * \mathbb{1}_{t=y} = \mathbb{1}_{t=x+y}.$$
And (as I distinctly remember being made to practically chant in an 'Analysis of Boolean Functions' class given by Tom Sanders) the Fourier Transform is the (essentially unique) change of basis that simultaneously diagonalizes all convolution operations. This is coming close to saying something like: There is one special basis that makes the operation you want to learn uniquely easy to do using matrix multiplications, and that basis is the Fourier basis.
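A quick numerical illustration of this claim: convolving two indicator functions on $\mathbb{Z}/113$ gives the indicator of the sum, and the DFT turns that convolution into a pointwise (i.e. diagonal) product.

```python
import numpy as np

p = 113
x, y = 45, 99

# Indicator (one-hot) functions on Z/p as vectors in R^p.
f, g = np.eye(p)[x], np.eye(p)[y]

# Circular convolution computed via the DFT: it is a pointwise product
# in the Fourier basis, i.e. a diagonal operation there.
conv = np.real(np.fft.ifft(np.fft.fft(f) * np.fft.fft(g)))

# The convolution of the two indicators is the indicator of the sum mod p.
assert int(np.argmax(conv)) == (x + y) % p
```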
12 comments
Comments sorted by top scores.
comment by Rohin Shah (rohinmshah) · 2023-07-02T12:18:04.183Z · LW(p) · GW(p)
In particular, this point of view further (and perhaps almost completely) demystifies the use of the Fourier basis.
I disagree at least with the "almost completely" version of this claim:
Notice that the operation you want to learn is manifestly a convolution operation, i.e.
This also applies to the non-modular addition operation, but I think it's pretty plausible that if you train on non-modular addition (to the point of ~perfect generalization), the network would learn an embedding that converts the "tokenized" representation back into the "magnitude" representation, and then simply adds them as normal.
Some evidence for this:
- I've heard it claimed that an LLM represented the number of characters in a token as a magnitude, which was used for deciding whether a line of code was > 80 characters (useful for predicting line breaks).
- This paper trains on non-modular addition and gets this result. (Note however that the paper has a highly unusual setting that isn't representative of typical network training, and arguably the setup is such that you have to get this result: in particular the architecture enforces that the embeddings of the two inputs are added together, which wouldn't work in a Fourier basis. I cite it as evidence that networks do learn magnitude representations when forced to do so.)
It seems like if you believe "when the operation you want to learn is a convolution operation, then you will learn the Fourier basis", you should also believe that you'll get a Fourier basis for non-modular addition on one-hot-encoded numbers, and currently my guess is that that's false.
Fwiw, I agree that the algorithm is quite "mathematically natural" (indeed, one person came up with the algorithm independently, prompted by "how would you solve this problem if you were a Transformer"), though I feel like the "modular" part is pretty crucial for me (and the story I'd tell would be the one in Daniel's comment [LW(p) · GW(p)]).
Replies from: Spencer Becker-Kahn
↑ comment by carboniferous_umbraculum (Spencer Becker-Kahn) · 2023-09-05T09:58:34.651Z · LW(p) · GW(p)
Thanks for the comment Rohin, that's interesting (though I haven't looked at the paper you linked).
I'll just record some confusion I had after reading your comment that stopped me replying initially: I was confused by the distinction between modular and non-modular because I kept thinking: if I add a couple of numbers $x$ and $y$ and don't do any modding, then it is equivalent to doing modular addition modulo some large number (i.e. at least as large as the largest sum you get). And otoh if I tell you I'm doing 'addition modulo 113', but I only ever use inputs that add up to 112 or less, then you never see the fact that I was secretly intending to do modular addition. And these thoughts sort of stopped me having anything more interesting to add.
↑ comment by Rohin Shah (rohinmshah) · 2023-09-06T19:08:07.621Z · LW(p) · GW(p)
I agree -- the point is that if you train on addition examples without any modular wraparound (whether you think of that as regular addition or modular addition with a large prime, doesn't super matter), then there is at least some evidence that you get a different representation than the one Nanda et al found.
comment by Garrett Baker (D0TheMath) · 2023-06-01T18:11:35.760Z · LW(p) · GW(p)
This is a neat observation, but I'm reminded of a story I was told about a math professor:
One day, while in the middle of a long proof of an arcane theorem, the professor was stopped and questioned about a particular step by a student, who wondered what made that step true. The professor said "It's trivial!", then thought a bit more about the step, mumbled to himself "Wait, is it trivial?", and excused himself to step out of the hall and think. Ten minutes later, he came back into the hall, declared that the step was indeed trivial, and proceeded with the proof.
This feels similar to me. Neel and Tom figured out this algorithm 9.5 months ago, and now mathematicians have just realized that the algorithm is indeed obvious and simple, and indeed the only way to do the operation described when matrix multiplications come easy.
Not to say the insight is wrong, but I would be far more impressed if you were able to predict the algorithm a network learns in advance through similar reasoning, rather than offering a justification 9.5 months later.
Replies from: Spencer Becker-Kahn, DanielFilan, tailcalled
↑ comment by carboniferous_umbraculum (Spencer Becker-Kahn) · 2023-06-02T10:42:50.782Z · LW(p) · GW(p)
Hi Garrett,
OK so just being completely honest, I don't know if it's just me but I'm getting a slightly weird or snarky vibe from this comment? I guess I will assume there is a good faith underlying point being made to which I can reply. So just to be clear:
- I did not use any words such as "trivial", "obvious" or "simple". Stories like the one you recount are obviously making fun of mathematicians, some of whom do think it's cool to say things are trivial/simple/obvious after they understand them. I often strongly disagree with and generally dislike this behaviour, and think there are many normal mathematicians who don't engage in this sort of thing. In particular, sometimes the most succinct insights are the hardest ones to come by (this isn't a reference to my post; just a general point). And just because such insights are easily expressible once you have the right framing and the right abstractions, they should by no means be trivialized.
- I deliberately emphasized the subjectivity of making the sorts of judgements that I am making. Again this kinda forms part of the joke of the story.
- I have indeed been aware of the work since when it was first posted 10 months ago or so and have given it some thought on and off for a while (in the first sentence of the post I was just saying that I didn't spend long writing the post, not that these thoughts were easily arrived-at).
- I do not claim to have explained the entire algorithm, only to shed some light on why it might actually be a more natural thing to do than some people seem to have appreciated.
- I think the original work is of a high quality and one might reasonably say 'groundbreaking'.
In another one of my posts [LW · GW] I discuss at more length the kind of thing you bring up in the last sentence of your comment, e.g.
it can feel like the role that serious mathematics has to play in interpretability is primarily reactive, i.e. consists mostly of activities like 'adding' rigour after the fact or building narrow models to explain specific already-observed phenomena.
....[but]... one of the most lauded aspects of mathematics is a certain inevitability with which our abstractions take on a life of their own and reward us later with insight, generalization, and the provision of predictions. Moreover - remarkably - often those abstractions are found in relatively mysterious, intuitive ways: i.e. not as the result of us just directly asking "What kind of thing seems most useful for understanding this object and making predictions?" but, at least in part, as a result of aesthetic judgement and a sense of mathematical taste.
And e.g. I talk about how this sort of thing has been the case in areas like mathematical physics for a long time. Part of the point is that (in my opinion, at least) there isn't any neat shortcut to the kind of abstract thinking that lets you make the sort of predictions you are referring to. It is very typical that you have to begin by reacting to existing empirical phenomena and using them as scaffolding. But it has come across, to me at least, as though you are being somewhat dismissive of this fact? As if, when B might well follow from A and someone actually starts to do A, you say "I would be far more impressed if B" instead of "maybe that's progress towards B"?
(Also FWIW, Neel claims here that regarding the algorithm itself, another researcher he knows "roughly predicted this".)
↑ comment by Garrett Baker (D0TheMath) · 2023-06-05T18:26:14.084Z · LW(p) · GW(p)
I don't know if it's just me but I'm getting a slightly weird or snarky vibe from this comment?
Sorry about that. On a re-read, I can see how the comment could be seen as snarky, but I was going more for critical via illustrative analogy. Oh the perils of the lack of inflection and facial expressions.
I think your criticisms of my thought in the above comment are right-on, and you've changed my mind on how useful your post was. I do think that lots of progress can be made in understanding stuff by just finding the right frame by which the result seems natural, and your post is doing this. Thanks!
↑ comment by DanielFilan · 2023-06-01T21:28:35.270Z · LW(p) · GW(p)
My submission: when we teach modular arithmetic to people, we do it using the metaphor of clock arithmetic. Well, if you ignore the multiple frequencies and argmax weirdness, clock arithmetic is exactly what this network is doing! Find the coordinates of rotating the hour hand (on a 113-hour clock) x hours, then y hours, use trig identities to work out what it would be if you rotated x+y hours, then count how many steps back you have to rotate to get to 0 to tell where you ended up. In fairness, the final step is a little bit different than the usual imagined rule of "look at the hour mark where the hand ends up", but not so different that clock arithmetic counts as a bad prediction IMO.
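Daniel's description can be rendered literally (and inefficiently) in a few lines. As with the analogy itself, the single frequency and the brute-force "count steps back to 0" are simplifications relative to what the actual network does.

```python
import math

p = 113                      # hours on the clock
theta = 2 * math.pi / p      # one hour of rotation

def clock_add(x, y):
    # Rotate the hour hand x hours, then y more, tracking only (cos, sin).
    cx, sx = math.cos(theta * x), math.sin(theta * x)
    cy, sy = math.cos(theta * y), math.sin(theta * y)
    # Angle-addition identities: the hand now points at hour x + y.
    c, s = cx * cy - sx * sy, sx * cy + cx * sy
    # Count how many backward steps bring the hand to hour 0.
    for steps in range(p):
        if abs(c - 1.0) < 1e-6 and abs(s) < 1e-6:
            return steps
        c, s = c * math.cos(theta) + s * math.sin(theta), \
               s * math.cos(theta) - c * math.sin(theta)
    raise RuntimeError("hand never returned to 0")  # unreachable for integer inputs
```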
Replies from: tailcalled
↑ comment by tailcalled · 2023-06-02T10:10:59.703Z · LW(p) · GW(p)
Is this really an accurate analogy? I feel like clock arithmetic would be more like representing it as a rotation matrix, not a Fourier basis.
Replies from: DanielFilan
↑ comment by DanielFilan · 2023-06-02T16:58:49.777Z · LW(p) · GW(p)
I agree a rotation matrix story would fit better, but I do think it's a fair analogy: the numbers stored are just coses and sines, aka the x and y coordinates of the hour hand.
Replies from: DanielFilan
↑ comment by DanielFilan · 2023-06-02T17:00:54.454Z · LW(p) · GW(p)
Like, the only reason we're calling it a "Fourier basis" is that we're looking at a few different speeds of rotation, in order to scramble the second-place answers that almost get you a cos of 1 at the end, while preserving the actual answer.
↑ comment by tailcalled · 2023-06-02T10:18:10.410Z · LW(p) · GW(p)
Here's another way of looking at it which could be said to make it more trivial:
We can transform addition into multiplication by taking the exponential, i.e. x+y=z is equivalent to 10^x * 10^y = 10^z.
But if we unfold the digits into separate axes rather than as a single number, then 10^n is just a one-hot encoding of the integer n.
Taking the Fourier transform of the digits to do convolutions is a well-known fast multiplication algorithm.
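The three steps above can be sketched with numpy: the one-hot encoding of $n$ is the coefficient vector of the polynomial $t^n$ (i.e. "$10^n$ with each digit on its own axis"), and FFT-based polynomial multiplication adds the exponents.

```python
import numpy as np

def one_hot_pow10(n, size):
    # Coefficient vector of t^n: a one-hot encoding of the integer n.
    v = np.zeros(size)
    v[n] = 1.0
    return v

x, y = 7, 12
size = 32  # any size > x + y avoids wraparound; a power of two suits the FFT

# FFT-based polynomial multiplication: pointwise product in the Fourier domain.
prod = np.fft.ifft(np.fft.fft(one_hot_pow10(x, size)) *
                   np.fft.fft(one_hot_pow10(y, size)))

# Exponents add: 10^x * 10^y = 10^(x+y), so the product is one-hot at x + y.
assert int(np.argmax(np.real(prod))) == x + y
```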
comment by Rudi (rudolf-zeidler) · 2023-06-03T14:28:34.036Z · LW(p) · GW(p)
Thanks for putting this so succinctly! To add another subjective data point, I had very similar thoughts immediately after I first saw this work (and the more conceptual follow-up by Chughtai et al.) a few months ago.
About "one-hotting being a significant transformation": I have a somewhat opposite intuition here and would say that this is also quite natural.
Maybe at first glance one would find it more intuitive to represent the inputs as a subset of the real numbers (or floats, I guess) and think of modular addition as some garbled version of the usual addition. But the group structure on a finite cyclic group and the vector space structure on the real number line are not really compatible, so I'm not sure that this actually makes that much sense bearing in mind that the model has to work mostly with linear transformations.
On the other hand, if such a representation was useful, the model could in principle learn an embedding which takes all the one-hot embedded inputs to the same one-dimensional space but with different lengths.
In fact, one-hotting is in a precise sense the most general way of embedding a given set in a vector space, because it does not impose any additional linear relations (in mathematical jargon, it's the free vector space generated by the set, and is characterized by the universal property of turning arbitrary maps on the generating set into unique linear maps on the vector space). In this sense I'd view using a one-hot embedding as the natural way of designing the architecture if I don't want to create a bias towards a particular linear representation.
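In symbols, the universal property says that linear maps out of the free vector space $k[S]$ on a set $S$ correspond exactly to arbitrary maps out of $S$:

```latex
\operatorname{Hom}_k\bigl(k[S],\, V\bigr) \;\cong\; \operatorname{Map}(S, V),
\qquad \varphi \;\longmapsto\; \bigl(s \mapsto \varphi(\delta_s)\bigr),
```

where $\delta_s$ is the one-hot (indicator) basis vector attached to $s \in S$.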
As a side remark, although it's in a sense completely trivial, the "one-hotting construction" is also used as an important ingredient in many areas of mathematics. One example would be homology theory in algebraic topology, where one turns geometric/combinatorial objects into a vector space in this way and then does linear algebra on that space rather than working with the objects directly. Another example, closer to the problem discussed here, is turning a group into the corresponding group algebra in representation theory.