Goal oriented cognition in "a single forward pass"
post by dxu, habryka (habryka4) · 2024-04-22T05:03:18.649Z · LW · GW · 15 comments
The below is me (habryka) and dxu talking about a shortform that dxu had published a few months ago, going into the relationship between goal-oriented cognition and the myopic nature of current large language model training setups. Some key quotes if you don't want to read the whole thing:
I think it's interesting to think about how brains get around the issue [of having limited context]. Obviously the patterns in which biological neurons fire don't carve cleanly into "forward passes" the way ANNs do, but ultimately there's only so much computation that can occur in the brain within some (small) unit of time. In that sense, I think I want to claim that the brain's ability to engage in something that looks a whole lot like long-term, goal-oriented reasoning clearly can't depend on being able to hold that entire context in memory to be attended to, even selectively.
When I think about the internal experience of problem-solving, there's a mental move of going up and down the ladder of abstraction, where you zoom in on some particularly difficult and/or confusing part of the problem, solve it, and then use what you learned from that to zoom back out and fill in a gap in the larger problem you were trying to solve. For an LLM, that seems like it's harder, and indeed it's one of the reasons I inside-view suspect LLMs as-currently-trained might not actually scale to AGI. (Though there are obviously some terminological issues surrounding "AGI" at the moment, what I mean by that is something like the classical concept of "something better than humans at pretty much every cognitive task humans can do".)
An LLM doesn't learn how it can best reason within a 2048-token context. The human cognition that it is imitating has been shaped by lots of feedback that propagates back through multiple tokens. The human has learned to avoid cognitive traps and to route around the things that would cause them to go off track. But if there is something that would cause the LLM to go off track on a task, it will just do it again and again, every time.
dxu
Okay, but doesn't that also suggest that an LLM trained on human-generated data would reach human-level intelligence internally before its output began to successfully mirror the human output?
habryka
I would strongly predict it would develop many vastly superhuman capabilities, yes. (Like having memorized vastly more facts than any human alive, being vastly better at predicting the next token of a body of text than any human alive, or being much faster at writing code than any human alive.)
dxu
Yeah, I see your point, and in fact GPT-4 is certainly past that point on many of those metrics.
The dialogue had no big conclusion, though I found thinking about "n-token" and "1-token" reasoners useful as an abstraction, and it has come up in a few conversations I've been in since we wrote this dialogue. I would also be interested in hearing from people who have worked closer to the metal with transformers, to help me (and others) understand the degree to which transformers are perhaps doing something more like joint optimization across tokens.
Meta on dxu's background
Shortform discussion
Non-myopic objectives/n-token predictors
15 comments
Comments sorted by top scores.
comment by Bogdan Ionut Cirstea (bogdan-ionut-cirstea) · 2024-09-09T13:13:15.788Z · LW(p) · GW(p)
But there are various self-play or self-critique like approaches that could just defeat the data paucity here, and I am very concerned this will be associated with an enormous capability spike.
You might be interested in this paper and summary thread.
comment by Gunnar_Zarncke · 2024-04-22T09:54:35.782Z · LW(p) · GW(p)
there's a mental move of going up and down the ladder of abstraction, where you zoom in on some particularly difficult and/or confusing part of the problem, solve it, and then use what you learned from that to zoom back out and fill in a gap in the larger problem you were trying to solve. For an LLM, that seems like it's harder, and indeed it's one of the reasons I inside-view suspect LLMs as-currently-trained might not actually scale to AGI. [bold by me]
But that might already no longer be true for models that have short-term memory and might make moves like yours. See my Leave No Context Behind - A Comment [LW · GW].
comment by johnswentworth · 2024-04-22T05:45:09.462Z · LW(p) · GW(p)
(Didn't read most of the dialogue, sorry if this was covered.)
But the way transformers work is they greedily think about the very next token, and predict that one, even if by conditioning on it you shot yourself in the foot for the task at hand.
That depends on how we sample from the LLM. If, at each "timestep", we take the most-probable token, then yes that's right.
But an LLM gives a distribution over tokens at each timestep, i.e. $P(x_t \mid x_{<t})$. If we sample from that distribution, rather than take the most-probable token at each timestep, then that's equivalent to sampling non-greedily from the learned distribution over text. It's the chain rule: $P(x_1, \dots, x_n) = \prod_{t=1}^{n} P(x_t \mid x_{<t})$.
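To make the distinction concrete, here's a minimal sketch of greedy decoding vs. sampling at a single step (the logits are made-up toy numbers, not from any real model):

```python
import torch

# Toy per-step distribution P(x_t | x_{<t}) that an LLM's final layer might output.
logits = torch.tensor([2.0, 1.0, 0.5, -1.0])  # one score per vocabulary token
probs = torch.softmax(logits, dim=-1)

greedy_token = torch.argmax(probs)                       # always the single most-probable token
sampled_token = torch.multinomial(probs, num_samples=1)  # draws according to P(x_t | x_{<t})

# Sampling at every position draws a whole sequence from the joint distribution
# P(x_1, ..., x_n) = prod_t P(x_t | x_{<t}) (the chain rule), whereas taking the
# argmax at every position is greedy decoding, which need not find the most
# probable full sequence.
```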
↑ comment by habryka (habryka4) · 2024-04-22T06:16:25.652Z · LW(p) · GW(p)
I think you are talking about a different probability distribution here.
You are right that this allows you to sample non-greedily from the learned distribution over text, but I was talking about the inductive biases on the model.
My claim was that the way LLMs are trained, the way the inductive biases shake out is that the LLM won't be incentivized to output tokens that predictably have low probability but make it easier to predict future tokens (for example, in the process of trying to predict a proof, reminding itself of all of the things it knows before those things leave its context window, or, when doing an addition that it can't handle in a single forward pass, outputting a token that's optimized to give itself enough serial depth to perform the full addition of two long n-digit numbers, which would then allow it to get the next n tokens right and so achieve lower joint loss overall).
↑ comment by habryka (habryka4) · 2024-05-01T22:45:25.073Z · LW(p) · GW(p)
@johnswentworth [LW · GW] I think this paper basically does the thing I was talking about (with pretty impressive results), though I haven't read it in a ton of detail: https://news.ycombinator.com/item?id=40220851
↑ comment by ryan_greenblatt · 2024-05-02T23:42:02.636Z · LW(p) · GW(p)
Hmm, I don't think so. Or at least, the novel things in that paper don't seem to correspond to what you were describing.
My understanding of what this paper does:
- Trains models to predict the next 4 tokens instead of the next 1 token as an auxiliary training objective (see the sketch after this list). Note that this training objective yields better performance on downstream tasks even when just using the next-token-prediction component (the normally trained component) and discarding the other components. Notably, this is just something like "adding this additional prediction objective helps the model learn more/faster". In other words, this result doesn't involve changing how the model is actually used; it just adds an additional training task.
- Uses these extra heads for speculative execution, a well-known approach in the literature for accelerating inference.
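For reference, a minimal sketch of what the first bullet describes, as I understand it (the module, names, and shapes are illustrative, not the paper's actual code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiTokenHeads(nn.Module):
    """Sketch of the auxiliary objective: predict tokens t+1..t+n_future from the hidden state at t."""

    def __init__(self, d_model: int, vocab_size: int, n_future: int = 4):
        super().__init__()
        self.heads = nn.ModuleList([nn.Linear(d_model, vocab_size) for _ in range(n_future)])

    def loss(self, hidden: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # hidden:  (batch, seq, d_model) activations from the shared trunk
        # targets: (batch, seq) ground-truth token ids
        total = 0.0
        for k, head in enumerate(self.heads, start=1):
            logits = head(hidden[:, :-k, :])   # positions that still have a token k steps ahead
            labels = targets[:, k:]            # the token k steps ahead of each such position
            total = total + F.cross_entropy(
                logits.reshape(-1, logits.size(-1)), labels.reshape(-1)
            )
        return total

# At inference the k >= 2 heads can simply be discarded (the "just an auxiliary
# objective" point), or kept around to propose draft tokens for faster decoding.
```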
↑ comment by habryka (habryka4) · 2024-05-03T00:24:56.968Z · LW(p) · GW(p)
Hmm, I think the first bullet point is pretty precisely what I am talking about (though to be clear, I haven't read the paper in detail).
I was specifically saying that trying to somehow get feedback from future tokens into the next token objective would probably do some interesting things and enable a bunch of cross-token optimization that currently isn't happening, which would improve performance on some tasks. This seems like what's going on here.
Agree that another major component of the paper is accelerating inference, which I wasn't talking about. I would have to read the paper in more detail to get a sense of how much it's just doing that, in which case I wouldn't think it's a good example.
comment by faul_sname · 2024-04-22T20:30:17.690Z · LW(p) · GW(p)
And I think one way to create a 2-token reasoner is to generate all plausible completions of 2 tokens, and then propagate the joint loss of the log-probs of those two tokens.
I think this just doesn't work very well, because it incentivizes the model to output a token which makes subsequent tokens easier to predict, as long as the benefit in predictability of the subsequent token(s) outweighs the cost of the first token. Concretely, let's say you have the input "Once upon a time, there was a" and you want 32 tokens. Right now, davinci-002 will spit out something like [" little"," girl"," who"," was"," born"," with"," a"," very"," special"," gift","."," She"," could"," see"," things"," that"," others"," could"," not","."," She"," could"," see"," the"," future",","," and"," she"," could"," see"," the"," past"], with logprobs of [-2.44, -0.96, -0.90, ..., -0.28, -0.66, 0.26], summing to -35.3. But if instead it returned [" a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"], it would have logprobs like [-9.32, -7.77, -1.51, ..., -0.06, -0.05, -0.05], summing to -23.5. And indeed, if you could somehow ask a couple quadrillion people "please write a story starting with Once upon a time, there was a", I suspect that at least 1 in a million people would answer with low-entropy completions along the lines of "a a a a ..." (and there just aren't that many low-entropy completions). But "Once upon a time there was a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a" is not a very good completion, despite being a much higher-probability completion.
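A rough sketch of how one could compute summed per-token logprobs like those above (using gpt2 from Hugging Face transformers as a stand-in; the figures quoted above came from davinci-002, so the exact values will differ):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

def completion_logprob(prefix: str, completion: str) -> float:
    """Sum of log P(token | everything before it) over the completion tokens."""
    prefix_len = tok(prefix, return_tensors="pt").input_ids.shape[1]
    full_ids = tok(prefix + completion, return_tensors="pt").input_ids
    with torch.no_grad():
        log_probs = torch.log_softmax(model(full_ids).logits, dim=-1)
    total = 0.0
    for pos in range(prefix_len, full_ids.shape[1]):
        # the logits at position pos-1 are the model's prediction for the token at pos
        total += log_probs[0, pos - 1, full_ids[0, pos]].item()
    return total

prefix = "Once upon a time, there was a"
print(completion_logprob(prefix, " little girl who was born with a very special gift."))
print(completion_logprob(prefix, " a" * 32))
```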
You could use a more sophisticated loss function than "sum of individual-token logprobs", but I think that road leads towards PPO (nothing says that your criterion has to be "helpful/harmless/honest as judged by a human rater", though).
↑ comment by habryka (habryka4) · 2024-04-22T21:07:28.275Z · LW(p) · GW(p)
I think this just doesn't work very well, because it incentivizes the model to output a token which makes subsequent tokens easier to predict, as long as the benefit in predictability of the subsequent token(s) outweighs the cost of the first token.
Hmm, this doesn't sound right. The ground truth data would still be the same, so if you were to predict "aaaaaa" you would get the answer wrong. In the above example, you are presumably querying the logprobs of the model that was trained on 1-token prediction, which of course would think it's quite likely that, conditional on the last 10 characters being "a", the next one will be "a". But I am asking "what is the probability of the full completion 'a a a a a...' given the prefix 'Once upon a time, there was a'", which doesn't seem very high.
The only thing I am saying here is "force the model to predict more than one token at a time, conditioning on its past responses, then evaluate the model on performance of the whole set of tokens". I didn't think super hard about what the best loss function here is, and whether you would have to whip out PPO for this. Seems plausible.
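Something like the following hypothetical sketch of a 2-token objective (purely illustrative, not from the dialogue or any existing training setup):

```python
import torch
import torch.nn.functional as F

def two_token_loss(model, prefix_ids: torch.Tensor, target_ids: torch.Tensor) -> torch.Tensor:
    """Hypothetical 2-token objective: the model conditions on its *own* first output,
    and is evaluated on the joint loss over both ground-truth tokens."""
    # Step 1: predict token t+1 from the prefix.
    logits_1 = model(prefix_ids).logits[:, -1, :]
    loss_1 = F.cross_entropy(logits_1, target_ids[:, 0])

    # Step 2: extend the prefix with the model's own sampled token (not the ground truth)...
    sampled = torch.multinomial(torch.softmax(logits_1, dim=-1), num_samples=1)
    logits_2 = model(torch.cat([prefix_ids, sampled], dim=1)).logits[:, -1, :]
    loss_2 = F.cross_entropy(logits_2, target_ids[:, 1])

    # ...and score the whole 2-token block jointly. The sampling step is not
    # differentiable, which is part of why this plausibly pushes towards PPO-style methods.
    return loss_1 + loss_2
```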
↑ comment by faul_sname · 2024-04-22T22:15:09.694Z · LW(p) · GW(p)
I think the probability of getting the exact continuation "a a a a a ..." is genuinely higher than the probability of getting the exact continuation "little girl who was born with a very special gift...", though getting a continuation in the class of "a a a a a..." is much lower-probability than getting a continuation in the class of "little girl who was born with a very special gift..", because the latter class has a much larger possibility space than the former. So there might be 1e4 different low-entropy length-32 completions with an average probability of 1e-10 each, and 9.999999e15 different high-entropy length-32 completions with an average probability of 1e-16. This adds up to normality in that if you were to randomly sample this distribution, you'd get a weird low-entropy output one time in a million, and a normal high-entropy output the other 999999 times in a million. But if you try to do something along the lines of "take the best K outputs and train the model on those", you'll end up with almost entirely weird low-entropy outputs.
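Plugging in the illustrative numbers above:

```python
low_entropy_mass  = 1e4 * 1e-10           # ~1e-6: total probability mass of the weird low-entropy completions
high_entropy_mass = 9.999999e15 * 1e-16   # ~1.0:  total probability mass of the normal high-entropy completions
print(low_entropy_mass / (low_entropy_mass + high_entropy_mass))  # ~1e-6, i.e. about one sample in a million
```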
But yeah, I think I misunderstood your proposal as something along the lines of "take the k most probable n-token outputs" rather than "take the k% most probable n-token outputs" or "randomly sample a bunch of n-token outputs".
comment by p.b. · 2024-04-22T09:47:35.298Z · LW(p) · GW(p)
Yeah, the first 99 tokens would be optimized both to be locally the correct character, and also to set things up so that the 100th character is also correct.
That is how LLMs currently work. The gradient of each token prediction does flow back into all the earlier tokens whose information was integrated into the predicted token. So each token optimizes its own next token prediction but also tries to integrate the information that is most useful for future tokens.
↑ comment by habryka (habryka4) · 2024-04-22T15:45:45.666Z · LW(p) · GW(p)
I reference this in this section:
I do think saying "the system is just predicting one token at a time" is wrong, but I guess the way the work a transformer puts into token N gets rewarded or punished when it predicts token N + M feels really weird and confusing to me and still like it can be summarized much more as "it's taking one token at a time" than "it's doing reasoning across the whole context".
IIRC at least for a standard transformer (which maybe had been modified with the recent context length extension) the gradients only flow through a subset of the weights (for a token halfway through the context, the gradients flow through half the weights that were responsible for the first token, IIRC).
↑ comment by p.b. · 2024-04-22T16:29:36.464Z · LW(p) · GW(p)
Frankly, I don't really understand what you are saying here and I am open to the possibility that I don't really understand how the gradient works in autoregressive transformers.
But as I said in my other comment, my current understanding is:
In standard attention (for example in an encoder) tokens are not ordered, so it is clear that the gradient of the loss of one of the token predictions (for example a masked token in BERT) flows through all other tokens equally. In autoregressive transformers an order is imposed by masking, but all later tokens attend to all earlier tokens in the same way.
The gradient of the loss of a later token flows through all earlier tokens in the same way. It doesn't matter whether a token is half the context back or all the context, neither for the information flow nor for the gradient flow.
To put it another way: in the n-th layer, the last token attends to all the output tokens from the (n-1)-th layer. It doesn't somehow have to make do with the output of earlier layers for tokens that are further back.
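A minimal sketch of the masking being described (single head, no batching or positional encoding, purely illustrative):

```python
import torch

def causal_self_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """q, k, v: (seq_len, d) per-token vectors for one attention head."""
    seq_len, d = q.shape
    scores = q @ k.T / d**0.5                                 # every token scores every other token
    mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(mask, float("-inf"))          # the masking is what imposes the order
    weights = torch.softmax(scores, dim=-1)
    # Row t attends to all positions 0..t with no distance-dependent restriction:
    # a token 5 back and a token 5000 back are handled identically (any notion of
    # distance comes only from positional information baked into q and k).
    return weights @ v
```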
↑ comment by habryka (habryka4) · 2024-04-22T16:53:01.772Z · LW(p) · GW(p)
Yeah, I was indeed confused, sorry. I edited out the relevant section of the dialogue and replaced it with the correct relevant point (the aside here didn't matter because a somewhat stronger condition is true, which is that during training we always just condition on the right answer instead of conditioning on the output for the next token in the training set).
In autoregressive transformers an order is imposed by masking, but all later tokens attend to all earlier tokens in the same way.
Yeah, the masking is what threw me off. I was trying to think about whether any information would flow from the internal representations used to predict the second token to predicting the third token, and indeed, if you were to backpropagate the error after each specific token prediction, then there would be some information from predicting the second token available for predicting the third token (via the updated weights).
However, batch sizes make this inapplicable as well (I think you would basically never do a backpropagation after each token, as that would kind of get rid of the whole benefit of parallel training), and even without that, the amount of relevant information flowing this way would be minuscule and there wouldn't be any learning going on for how this information flows.
comment by p.b. · 2024-04-22T09:34:52.614Z · LW(p) · GW(p)
I don't know how people are creating huge context windows these days, but IIRC the way it works is that the longer you look back into your context (and correspondingly the further you are trying to plan ahead) the less of your computation is available. Like, if you have N layers, then for a token M steps back, you only have access to the computation up until layer N-M.
Everything in the context window is equally available. It doesn't make a difference whether an earlier token is 5 tokens back or 5000. The attention mechanism is an operation over a set of tokens; there is no intrinsic order.