Goal oriented cognition in "a single forward pass"

post by dxu, habryka (habryka4) · 2024-04-22T05:03:18.649Z · LW · GW · 15 comments

Contents

  Meta on dxu's background
  Shortform discussion
  Non-myopic objectives/n-token predictors

The below is me (habryka) and dxu talking about a shortform that dxu had published a few months ago, going into the relationship between goal-oriented cognition and the myopic nature of current large language model training setups. Some key quotes if you don't want to read the whole thing:

I think it's interesting to think about how brains get around the issue [of having limited context]. Obviously the patterns in which biological neurons fire don't carve cleanly into "forward passes" the way ANNs do, but ultimately there's only so much computation that can occur in the brain within some (small) unit of time. In that sense, I think I want to claim that the brain's ability to engage in something that looks a whole lot like long-term, goal-oriented reasoning clearly can't depend on being able to hold that entire context in memory to be attended to, even selectively.

When I think about the internal experience of problem-solving, there's a mental move of going up and down the ladder of abstraction, where you zoom in on some particularly difficult and/or confusing part of the problem, solve it, and then use what you learned from that to zoom back out and fill in a gap in the larger problem you were trying to solve. For an LLM, that seems like it's harder, and indeed it's one of the reasons I inside-view suspect LLMs as-currently-trained might not actually scale to AGI. (Though there are obviously some terminological issues surrounding "AGI" at the moment, what I mean by that is something like the classical concept of "something better than humans at pretty much every cognitive task humans can do".)


An LLM doesn't learn how it can best reason within a 2048-token context. The human cognition that it is imitating has been shaped by lots of feedback that propagates back through multiple tokens. The human has learned to avoid cognitive traps and routes around the things that would cause the human to go off track. But if there is something that would cause the LLM to go off-track on a task, it will just do it again and again every time.


dxu
Okay, but doesn't that also suggest that an LLM trained on human-generated data would reach human-level intelligence internally before its output began to successfully mirror the human output?

habryka
I would strongly predict it would develop many vastly superhuman capabilities, yes. (Like having memorized vastly more facts than any human alive, or being vastly better at predicting the next token of a body of text than any human alive, or being much faster at writing code than any human alive.)

dxu
Yeah, I see your point, and in fact GPT-4 is certainly past that point on many of those metrics.

The dialogue had no big conclusion, though I found thinking about "n-token" and "1-token" reasoners useful as an abstraction that has come up in a few conversations I've been in since we wrote this dialogue. I would also be interested in hearing more from people who have been closer to the metal with transformers on helping me (and others) understand the degree to which transformers are maybe doing something more like joint optimization.

Meta on dxu's background

dxu

I think my gestalt sense of AI as a whole, including but not limited to alignment, has undergone some fairly radical shifts over the past 2-3 years, with the rapid increase in (apparent) capabilities from LLMs being a big contributor to that. That's obviously not limited to me, but on the whole I think I was pretty bullish on MIRI-style agent foundations work back in, like, the mid-2010s, and then LLMs came along and kind of blew a big chunk of that up, and I'm still working my way through the process of picking up those pieces and trying to get them to cohere.

habryka

Makes sense.

dxu

Much of the reason I haven't made any top-level posts (basically ever), apart from just a generalized impostor syndrome I guess, is that I don't really think I have a Thing here that can pass for any kind of generalized worldview; mostly what I have are snippets of thoughts that feel true and useful, and certainly feel like pieces of what could be a coherent worldview, but also a sense that what pieces come to mind depend substantively on the specifics of whatever is being discussed. To take the consequentialism thing as an example, that came up in response to a question about why I might expect powerful systems to end up agentic, and felt like a relatively solid piece of Worldview that I had, which is why I ended up posting it in shortform.

dxu

I suspect a lot of us are in something of a similar boat; one of the issues with AI alignment, and thinking about building intelligences in general, is that it's very much a foundational field in its (relative) infancy, and a big chunk of the difficulty there comes from choosing the right frame to think in; and I think I have issues with probably all of the existing frames, which prevents me from subscribing in full to any individual one of them.

dxu

This is all pretty meta, so how about this: let's talk about some things that you find confusing about this whole topic, and I can try and contribute my thoughts on that. That way, we can kind of build from the bottom up, rather than either of us having to present some kind of top-down worldview of the kind that I (at least) certainly don't think I have! :-)

habryka

Ok, yeah, that seems good to me.

Shortform discussion

habryka

Let's go directly into your latest shortform:

The way I think about it is something like: a "goal representation" is basically what you get when it's easier to state some compact specification on the outcome state, than it is to state an equivalent set of constraints on the intervening trajectories to that state.

In principle, this doesn't have to equate to "goals" in the intuitive, pretheoretic sense, but in practice my sense is that this happens largely when (and because) permitting longer horizons (in the sense of increasing the length of the minimal sequence needed to reach some terminal state) causes the intervening trajectories to explode in number and complexity, s.t. it's hard to impose meaningful constraints on those trajectories that don't map to (and arise from) some much simpler description of the outcomes those trajectories lead to.

This connects with the "reasoners compress plans" point, on my model, because a reasoner is effectively a way to map that compact specification on outcomes to some method of selecting trajectories (or rather, selecting actions which select trajectories); and that, in turn, is what goal-oriented reasoning is. You get goal-oriented reasoners ("inner optimizers") precisely in those cases where that kind of mapping is needed, because simple heuristics relating to the trajectory instead of the outcome don't cut it.

It's an interesting question as to where exactly the crossover point occurs, where trajectory-heuristics stop functioning as effectively as consequentialist outcome-based reasoning. On one extreme, there are examples like tic-tac-toe, where it's possible to play perfectly based on a myopic set of heuristics without any kind of search involved. But as the environment grows more complex, the heuristic approach will in general be defeated by non-myopic, search-like, goal-oriented reasoning (unless the latter is too computationally intensive to be implemented).

That last parenthetical adds a non-trivial wrinkle, and in practice reasoning about complex tasks subject to bounded computation does best via a combination of heuristic-based reasoning about intermediate states, coupled to a search-like process of reaching those states. But that already qualifies in my book as "goal-directed", even if the "goal representations" aren't as clean as in the case of something like (to take the opposite extreme) AIXI.

To me, all of this feels somewhat definitionally true (though not completely, since the real-world implications do depend on stuff like how complexity trades off against optimality, where the "crossover point" lies, etc). It's just that, in my view, the real world has already provided us enough evidence about this that our remaining uncertainty doesn't meaningfully change the likelihood of goal-directed reasoning being necessary to achieve longer-term outcomes of the kind many (most?) capabilities researchers have ambitions about.

So, I have a bunch of different thoughts here, but definitely the first thing that comes to mind is something like "but man, I feel like most types of search are recurrent, and transformers just don't seem very recurrent in the relevant way". Like, at the end of the day things are processed in one forward pass, and that makes thinking about this stuff harder.

dxu

I follow you.

habryka

I am trying to imagine this reasoner doing a generalized search, and I can see that search happening at the level of a full completion, but I have trouble imagining it at the level of a forward pass. 

But the vast majority of the optimization pressure in the system is going into optimizing performance of a next-token prediction, not into optimizing performance of a completion (which gets a tiny bit of juice from RLHF or RLAIF, but like minuscule amounts of data compared to the base model).

dxu

Yeah, that's an interesting perspective, and not one I'm entirely unsympathetic to. What I'd say, first of all, is that I'm not (and frankly, I don't think anyone is) able to confidently say whether LLMs as currently trained can or can't scale to superintelligence, and so in a sense what you said can be viewed as a reason to think they can't. In which case, I do think it's fair to say LLMs present much less of an x-risk, but also, I would expect LLMs to be less useful in general? There's probably some operationalization there in terms of impact on world GDP or whatever, but even without that I think it's worth making the general point that there's something like a conditional there, which goes "If it's capable of solving [X problem], it has enough search to be dangerous."

dxu

Also, on the object level I do think it's not quite right to say "the vast majority of the optimization pressure in the system is going into optimizing performance of a next-token prediction", since in the actual dataset the next token exhibits substantial dependencies on tokens situated well into the tail end of the context window. I definitely think modern LLMs respect in-context dependencies, although I think it's harder to say how much that generalizes to logical sequences longer than the length of their context window.

habryka

In which case, I do think it's fair to say LLMs present much less of an x-risk, but also, I would expect LLMs to be less useful in general?

Though I guess it might be an argument that if you start introducing completion-level feedback, you would get something more dangerous, since at the language level things definitely seem recurrent enough for search.

habryka

I agree there is a question here about danger, but I think I also want to highlight at least a minor inconsistency here (either in my model or yours), that in some sense the current architectures we have seem to me to enforce most of the selection going into selecting constraints on the trajectory of a state (by doing things at the next-token level). And like, my guess is indeed that LLMs have a huge number of goals in the way you define it, but where the outcome state on which those goals are operating is constrained to things relevant to the next token.

habryka

I definitely think modern LLMs respect in-context dependencies, although I think it's harder to say how much that generalizes to logical sequences longer than the length of their context window.

So, maybe this is obvious to people who have worked more with modern deep learning systems, or have engaged with the math more deeply, but I find myself confused about this. I don't know how people are creating huge context windows these days, but IIRC the way it works is that the longer you look back into your context (and correspondingly the further you are trying to plan ahead) the less access you have to a bunch of the related computation.

I do think saying "the system is just predicting one token at a time" is wrong, but I guess the way the work a transformer puts into token N gets rewarded or punished when it predicts token N + M feels really weird and confusing to me, and it still feels like it can be summarized much more as "it's taking one token at a time" than "it's doing reasoning across the whole context".

dxu

I think it's interesting to think about how brains get around this issue. Obviously the patterns in which biological neurons fire don't carve cleanly into "forward passes" the way ANNs do, but ultimately there's only so much computation that can occur in the brain within some (small) unit of time. In that sense, I think I want to claim that the brain's ability to engage in something that looks a whole lot like long-term, goal-oriented reasoning clearly can't depend on being able to hold that entire context in memory to be attended to, even selectively.

When I think about the internal experience of problem-solving, there's a mental move of going up and down the ladder of abstraction, where you zoom in on some particularly difficult and/or confusing part of the problem, solve it, and then use what you learned from that to zoom back out and fill in a gap in the larger problem you were trying to solve. For an LLM, that seems like it's harder, and indeed it's one of the reasons I inside-view suspect LLMs as-currently-trained might not actually scale to AGI. (Though there are obviously some terminological issues surrounding "AGI" at the moment, what I mean by that is something like the classical concept of "something better than humans at pretty much every cognitive task humans can do".)

dxu

One obvious rejoinder I can see to that (in the context of the larger argument surrounding policy) is something like, well, sure, but in that case we might get along just fine without ever building AGI, and also it's fine to scale up LLMs indefinitely, no pause needed. I'm not yet sure if I want to bite the bullet and say that, yes, my inside view does imply that, but regardless my overall level of uncertainty around what LLMs can and can't do is ultimately high enough that I don't want to take my inside view as particularly definitive here. I'm aware that this might sound like somewhat of a cop-out, though.

habryka

I mean, a story you can tell, which I am somewhat sympathetic to but also feel confused about, is: 

You have a model that really cares about predicting the next token. It only has (in a simplified sense) a single forward pass to work with, but if you scale up the system, that's eventually going to be enough. You can get dangerous cognition and planning and instrumental goals in the pursuit of predicting the next token. 

It does seem like the kind of thing that's harder, but also, I think it somewhat naturally predicts that it will be hard to get substantially superhuman performance on almost any many-token task, because like, you aren't really training the model on any task, you are training the model to predict the tokens associated with the task.

Another story you can tell, which opens up some questions I am quite curious about, is something like 

well, but a lot of the cognitive machinery that was being goal-directed in the pursuit of a next-token prediction probably can be repurposed pretty easily to be goal-directed in the pursuit of a whole-episode completion. So you might really not need that much training data to now allow the system to use 10,000x compute in the pursuit of its goals. 

But I feel pretty confused about the degree to which these things are transferable. My current guess is a lot, but I have a lot of trouble making the argument explicit.

dxu

I think one interesting question here is how easy it'd be, if you had a human narrate their internal stream of thought while solving some difficult problem (say, a competition-level programming or mathematics question), where the entire process of solving the problem is many times longer than the context window of a given LLM, to have the LLM "follow along" with that stream of thought, in the sense of being able to generate plausible (and therefore useful) continuations the entire way through. My sense is that this probably isn't possible for GPT-4 on sufficiently complex problems, but I don't think I currently know any facts that prohibit it being possible for any LLM no matter how big it is, even holding constant the size of the model's context window.

dxu

I think if you accept that that's possible, that gets you most of the way to (hypothetical) long-horizon goal-oriented LLM-based reasoners, but I'm not sure whether/why you get off the train along the way.

dxu

(I've tried similar-ish experiments with ChatGPT out of curiosity, and haven't particularly been impressed by the results.)

habryka

So, I mean, I expect the LLM could follow along and generate plausible continuations, but in some sense it would do so (in my current model of the situation) as only an indirect byproduct of the vast majority of its cognition. Like, it wouldn't need to literally solve the whole difficult problem on each forward pass, but like, I feel tempted to say that the reason the LLM would succeed is kind of by accident (though I sure don't really know what I mean by that).

And so, if you have a model that is capable of doing this, then my current intuition is that it internally is actually solving a bunch of much harder problems than this, and inasmuch as it's doing dangerous cognition, it's doing much more dangerous cognition than solving the competition-level programming or mathematics question would imply. 

dxu

Yeah, so on my model a key detail here is whether any given context-length chunk of the human's stream of thought is sufficient to infer the original problem. If it isn't, i.e. if, once the problem statement has exited the model's context, the human's stream of thought doesn't clearly reference it again, then that would obviously break the model's ability to generate a continuation that's useful for solving the original problem. I mostly don't think I'd characterize that as a problem with the model's cognition, but it's definitely true that it could come up.

Insofar as it remains clear throughout what problem the human is trying to solve (we could imagine the human cooperates with the model by providing periodic reminders of what they're trying to achieve with any given thread of approach/attack), I think that the task the model is being asked to do is possible in principle within the given architectural limitations, and that the reason GPT-4 fails at it sometimes is basically the same reason GPT-3 fails at it in more cases than GPT-4 does—which is just a longer way of saying I think scaling would probably fix it.

habryka

To give maybe one example of what I mean by "by-accident" and the point above about "maybe even a tiny bit of RLHF would cause you to overcome these limitations":

The obvious thing to do here, as the LLM, is of course to intentionally repeat the basic components of the problem necessary to solve the problem frequently enough so that it always stays in context. A system trained on imitating humans will never do that, or like, will only do it inasmuch as a human has to do it as well, but it seems like such an extremely simple thing to do that wouldn't cost you much accuracy on the prediction task, that even after just a few hundred to a thousand datapoints of RL training, it wouldn't surprise me if the system learned to do that.
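As a rough illustration of the "keep the problem in context" point, here is a minimal sketch. It assumes a hypothetical fixed-size window that only ever sees the last `window` tokens; the names `write_stream`, `always_visible`, the `PROBLEM` marker, and the reminder interval are all made up for this example and say nothing about how any real LLM manages its context.

```python
def write_stream(n_steps, remind_every=None):
    """Build a toy token stream: one problem statement, then reasoning steps.

    If `remind_every` is set, the problem statement is restated after every
    `remind_every` reasoning steps (the hypothetical learned habit).
    """
    tokens = ["PROBLEM"]
    for step in range(1, n_steps + 1):
        tokens.append(f"step{step}")
        if remind_every and step % remind_every == 0:
            tokens.append("PROBLEM")
    return tokens


def always_visible(tokens, window, marker="PROBLEM"):
    """True if, at every point in the stream, the last `window` tokens contain `marker`."""
    return all(
        marker in tokens[max(0, i - window):i]
        for i in range(1, len(tokens) + 1)
    )


window = 8  # assumed context size, in tokens
without_reminders = write_stream(n_steps=20)
with_reminders = write_stream(n_steps=20, remind_every=5)

# Without reminders the problem statement scrolls out of the window;
# restating it every 5 steps keeps it inside an 8-token window throughout.
stays_visible_without = always_visible(without_reminders, window)
stays_visible_with = always_visible(with_reminders, window)
```

The cost of the habit in this toy is a handful of extra tokens, which is the sense in which it looks cheap relative to what it buys.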

dxu

Sure, I buy that. It's almost certain that RL on X will improve performance on X, in my view, and I think at a higher level that basically allows the argument to go through? Like, if something is achievable using RL-based finetuning on the base model, then I do in fact expect people to get around to doing that finetuning at some point, and if the result of that is dangerous I think I'm happy with saying the base model had dangerous capabilities all along?

habryka

So, the reason why I am bringing this up, is because it feels like an interesting and concrete lens on the thing you said in your shortform about having cognition that is oriented around an outcome vs. cognition that is oriented around trajectories. 

And I think the key thing that in many ways blew up a bunch of my models about how AI would go is the degree to which reasoning about tokens, which in some sense are the trajectories of thought at the text level, gives rise to behavior that looks like global planning behavior. Like, I expected huge benefits from approaching problems from the outcome level, but instead it does sure look like you can get a really enormous amount of juice from doing trajectory-level reasoning.

habryka

Like, you say

It's an interesting question as to where exactly the crossover point occurs, where trajectory-heuristics stop functioning as effectively as consequentialist outcome-based reasoning. On one extreme, there are examples like tic-tac-toe, where it's possible to play perfectly based on a myopic set of heuristics without any kind of search involved. But as the environment grows more complex, the heuristic approach will in general be defeated by non-myopic, search-like, goal-oriented reasoning (unless the latter is too computationally intensive to be implemented).

And if you had asked me to predict where that crossover point is like 5 years ago, I would have definitely said "I mean, clearly very long before you can solve International Math Olympiad problems". And like, AI isn't yet solving IMO problems, but my guess is we aren't very far away. 

One might spin this as a hope, but it is actually one of the things that makes me most afraid in the moment right now. Because like, it's not like we have been able to try what happens when you train systems to perform well at these kinds of high-level outcomes. Deep RL is extremely finicky, and we don't have the data to train systems on long tasks like IMO problems remotely as much as we can on next-token prediction. But there are various self-play or self-critique-like approaches that could just defeat the data paucity here, and I am very concerned this will be associated with an enormous capability spike.

dxu

Well, it's not clear to me the extent to which you can say an LLM is doing "trajectory level reasoning" as opposed to "outcome level reasoning", simply because the outer form of its prediction task happens to be local with respect to tokens. AlphaZero on a single forward pass only outputs its probability distribution on the immediate next move, but because it was trained to imitate probability distributions output by an actual search process, I would imagine its internal computation contains at least something we would recognize as "looking forward past the immediate next move", despite that not being anywhere in the explicit objective function.

habryka

Yeah, I mean, that's kind of the simulators view. In your forward pass you simulate or at least develop a heuristic model of the search process that generated the data you have (in the text case, human cognition; in the AlphaZero case, Monte Carlo Tree Search). But that needs to happen in a single forward pass, doesn't it?

dxu

It's certainly the case that the model's performance will at some point saturate, in the sense that its prediction will stop improving with additional training, even though it could be improved by bolting on some kind of actual search apparatus (like MCTS in the case of AlphaZero). I don't think I'm as bothered by this as you are; to return to the brain analogy, a human brain is also only capable of a finite amount of computation in, like, a half-second. It seems to me that wherever our general reasoning abilities come from, they have to come from recurrence, rather than from having huge context windows or whatever?

habryka

I mean, yes, I think we totally do recurrent reasoning at the level of our thoughts. But also, it seems really clear to me that my reasoning has been trained on long trajectories and goals with a wide range of different time horizons. 

Like, I learn things like "repeat the initial problem formulation in my head frequently enough so that I remember the important aspects of the problem". But LLMs do not get to learn that kind of reasoning. They only get to imitate the things that happen to be present in the text. If it would be extremely helpful for an LLM to repeat the problem formulation every few thousand tokens in order to get the problem right, but that behavior didn't show up in the text corpus, then it still won't do it.

I am not saying that LLMs are inherently, architecturally, incapable of long-chained reasoning. Text/chains-of-thought are plenty recurrent to allow cognition about long-term goals. I am saying that the training process of LLMs does not actually incentivize performance on long-chained reasoning.

dxu

So, I think there's a possible reframing here of your question (and you can let me know how you feel about this reframing), which is something like: how likely is it that the cognitive operations learned by the model compose with themselves across arbitrarily long time horizons? NAND gates, for example, are a very simple example of something which does compose; you can string together an arbitrarily large sequence of NAND gates and get it to compute any function you want. AlphaZero, despite only ever being trained on 800-visit MCTS distributions, seems to have learned something sufficiently composable that you can use its forward-passes to build trees much larger than 800 visits, and have that produce a meaningfully better move prediction (for game-winning purposes) than the 800-visit default.
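To make the NAND example concrete, here is a small sketch of composition: every function below is built only out of `nand`, and the pieces are then chained into a one-bit full adder. The helper names (`not_`, `and_`, `or_`, `xor`, `full_adder`) are just illustrative choices for this example.

```python
from itertools import product


def nand(a: bool, b: bool) -> bool:
    """The only primitive: true unless both inputs are true."""
    return not (a and b)


def not_(a: bool) -> bool:
    return nand(a, a)


def and_(a: bool, b: bool) -> bool:
    return not_(nand(a, b))


def or_(a: bool, b: bool) -> bool:
    return nand(not_(a), not_(b))


def xor(a: bool, b: bool) -> bool:
    shared = nand(a, b)
    return nand(nand(a, shared), nand(b, shared))


def full_adder(a: bool, b: bool, carry_in: bool):
    """One-bit adder composed purely from the NAND-built gates above."""
    partial = xor(a, b)
    total = xor(partial, carry_in)
    carry_out = or_(and_(a, b), and_(partial, carry_in))
    return total, carry_out


# Check the composed circuit against ordinary integer addition on all inputs.
adder_is_correct = all(
    int(total) + 2 * int(carry) == int(a) + int(b) + int(c)
    for a, b, c in product([False, True], repeat=3)
    for total, carry in [full_adder(a, b, c)]
)
```

Chaining `full_adder` bit by bit would give an adder of any width, which is the sense of "composes with itself" being used here.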

It sounds to me like you're saying, possibly LLMs have learned some restrictive form of cognition which produces fine results within, like, 2048 or 8192 tokens or whatever, but wouldn't actually compose with itself over substantially longer sequences of reasoning than that. Do you think that's a reasonable characterization of your intuition here?

habryka

Quick clarification: What does 800-visit MCTS distribution mean here?

(I don't remember the AlphaZero setup in that much detail. I know it has a value network and does MCTS. I don't know what "visit" means in the above. I thought the value network was trained on trajectories from self-play, so on the first run it would basically just choose random moves to go down with MCTS.)

dxu

Yeah, so basically the model is amplified using MCTS (self-play), and the result of that MCTS is then distilled via training into the next iteration of the model, and due to compute considerations they only had the self-play evaluate 800 nodes (individual positions) in the search tree before cutting the search off and using the resulting distribution as ground truth for the distillation. In actual play, MCTS can go for as many nodes as you want, and the improvement in performance doesn't just arbitrarily stop at 800 nodes even though that's the maximum number of nodes the model ever "saw" during training.

habryka

Ah, yeah, cool. That makes sense. I knew that they had somehow limited the look-ahead number. I guess I was imagining the limitation on MCTS was "N steps ahead", but visits makes more sense as a limitation.

dxu

Right, so the larger point is just, if you've got something which works well for 800-visit trees and smaller, it seems in practice there aren't many forms of cognition that work well up to that limit and then just stop working for some reason. If the model manages to learn some form of internal cognition which composes well with itself up to 800 visits, it seems to be the case, at least for chess and similar games, that this cognition ends up generalizing and composing well with itself in trees with arbitrary numbers of nodes.
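For readers who, like habryka above, want a concrete picture of what a "visit budget" is, here is a minimal sketch. It is emphatically not AlphaZero: it uses plain UCT with random rollouts instead of a learned policy/value network, on a toy Nim game (take 1 to 3 stones, whoever takes the last stone wins). All names (`Node`, `simulate`, `visit_distribution`, the exploration constant `c`) are assumptions made for this example. The only point is that the budget is a knob on the search, and the normalized visit counts at the root are the kind of distribution that gets used as a training target.

```python
import math
import random

MOVES = (1, 2, 3)  # toy Nim: take 1-3 stones; taking the last stone wins


class Node:
    def __init__(self, stones):
        self.stones = stones   # stones left, with the player at this node to move
        self.children = {}     # move -> Node
        self.visits = 0
        self.wins = 0.0        # wins for the player who moved INTO this node


def legal_moves(stones):
    return [m for m in MOVES if m <= stones]


def rollout(stones, rng):
    """Random playout; True if the player to move at `stones` ends up winning."""
    to_move_wins = False       # with 0 stones left, the player to move has lost
    first_player_moving = True
    while stones > 0:
        stones -= rng.choice(legal_moves(stones))
        if stones == 0:
            to_move_wins = first_player_moving
        first_player_moving = not first_player_moving
    return to_move_wins


def ucb(parent, child, c):
    """UCB1 score of `child` from the parent's point of view."""
    return child.wins / child.visits + c * math.sqrt(math.log(parent.visits) / child.visits)


def simulate(node, rng, c=1.4):
    """Run one visit from `node`; True if the player to move at `node` wins."""
    if node.stones == 0:
        to_move_wins = False
    elif node.visits == 0:
        to_move_wins = rollout(node.stones, rng)  # new leaf: estimate by playout
    else:
        untried = [m for m in legal_moves(node.stones) if m not in node.children]
        if untried:
            move = rng.choice(untried)
            child = node.children[move] = Node(node.stones - move)
        else:
            child = max(node.children.values(), key=lambda ch: ucb(node, ch, c))
        to_move_wins = not simulate(child, rng, c)
    node.visits += 1
    node.wins += 0.0 if to_move_wins else 1.0
    return to_move_wins


def visit_distribution(stones, budget, seed=0):
    """Search for `budget` visits (needs budget >= 2) and return root visit shares."""
    rng = random.Random(seed)
    root = Node(stones)
    for _ in range(budget):
        simulate(root, rng)
    total = sum(ch.visits for ch in root.children.values())
    return {move: ch.visits / total for move, ch in sorted(root.children.items())}


small_budget = visit_distribution(stones=10, budget=50)
large_budget = visit_distribution(stones=10, budget=800)
```

Nothing in `simulate` knows what the budget is, so the same per-visit procedure can be run for 50, 800, or many more visits; whether a learned network's guidance keeps helping at larger budgets is the empirical claim being made in the dialogue, and this toy does not test it.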

I would guess by default that something similar is true for LLMs and the kind of reasoning they must have encountered in their massive training corpus; my prior on "they've learned '2048-token' reasoning, which is like real reasoning except it stops working in scenarios with more than 2048 tokens" is... really low.

habryka

Yeah, totally, I agree that if they had directly learned 2048-token reasoning, then I would buy this. But I currently believe they learned 1-token reasoning, and so I don't buy this.

dxu

And I did point out earlier that, like, one obvious issue with this is that sometimes the context window is too small for the model to remember what it's trying to do, in which case it will fail basically by forgetting the problem it was trying to solve. But you pointed out (and I agree) that very mild amounts of RL ought to fix this, e.g. by inducing the model to periodically remind itself what it's doing.

habryka

To be clear, I think there is a decent chance I am totally wrong about this, but I keep coming back to my current model of LLMs mostly just predicting one token at a time. I feel a bit confused since of course a transformer does attend to tokens (and their later-stage embeddings) that were many steps before in the context, but the key thing is that during training we don't condition on the output of the transformer; we predict all the tokens in parallel, using the ground truth as a prefix. So there can't really be any gradients that get passed through as a result of the effect of a predicted token on subsequently predicted tokens.
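The training setup described here is usually called teacher forcing. Below is a minimal sketch of the distinction, using a hypothetical hand-written `toy_model` in place of a transformer (the model, vocabulary, and target string are all invented for illustration). A real transformer scores all positions in one parallel pass under a causal mask; the loop here is sequential only for readability, and it records which prefixes the model is asked about.

```python
import math


def toy_model(prefix):
    """Hypothetical stand-in for an LLM: next-token probabilities given a prefix."""
    if prefix and prefix[-1] == "a":
        return {"a": 0.2, "b": 0.8}
    return {"a": 0.6, "b": 0.4}


def teacher_forced_loss(model, target):
    """Sum of per-token losses, each conditioned on the GROUND-TRUTH prefix."""
    loss, queried = 0.0, []
    for t in range(len(target)):
        prefix = tuple(target[:t])   # never contains the model's own guesses
        queried.append(prefix)
        loss += -math.log(model(prefix)[target[t]])
    return loss, queried


def greedy_rollout(model, length):
    """Free-running generation: each step conditions on the model's own outputs."""
    output, queried = [], []
    for _ in range(length):
        prefix = tuple(output)
        queried.append(prefix)
        probs = model(prefix)
        output.append(max(probs, key=probs.get))
    return output, queried


target = ["b", "b", "a"]
loss, training_prefixes = teacher_forced_loss(toy_model, target)
generated, sampling_prefixes = greedy_rollout(toy_model, len(target))
```

In this toy the model would itself have started with "a", but the loss is only ever evaluated on prefixes of the target, so nothing in the training signal depends on what the model's own earlier predictions would have led to.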

dxu

I actually share your intuition that there's something different about the LLM case and the AlphaZero case (and the "humans thinking through problems" case, for that matter); I'm just not sure how those differences are relevant when it comes to the question of, has this thing which was only ever trained on sequences with N tokens or fewer learned (a form of) cognition that could in principle usefully operate on sequences with more than N tokens.

habryka

As an extreme example, let's replace the transformer with literally just a single very long sequence of fully connected layers or something. Maybe do a bit of weight sharing in various places. Maybe some convolutions. 

But ultimately it's just literally taking in a sequence of tokens, and predicting the very next one. No parallelism or access to past internal states. You completely start over for the next token. And so on.

In this case I feel like it wouldn't make sense to think of this model as doing n-token reasoning.

dxu

I also think such a model would just perform significantly worse than an actual Transformer with attention mechanisms. I mean, you're basically talking about a regular feedforward net at that point, right?

habryka

To highlight what I mean in a different way, transformers don't optimize the joint probability of context-length completions that they output.

Like, a transformer will give you the highest probability next token. It won't output a token that has lower probability but that, when conditioned on, would make it more likely to get lower loss on future tokens.

But I feel like when you talk about an n-token reasoner, I am imagining something that does unbiased optimization over strings of length n. But the way transformers work is they greedily think about the very next token, and predict that one, even if by conditioning on it you shoot yourself in the foot for the task at hand.
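A tiny worked example of the gap between greedy next-token choice and optimizing over whole strings. The probability table in `next_token_probs` is invented for illustration, and the sketch only covers the narrow decoding-level version of the point: picking the likeliest token at each step need not yield the likeliest sequence.

```python
from itertools import product

VOCAB = ("x", "y")


def next_token_probs(prefix):
    """Hypothetical conditional distributions for a two-token completion."""
    table = {
        (): {"x": 0.6, "y": 0.4},
        ("x",): {"x": 0.55, "y": 0.45},
        ("y",): {"x": 0.9, "y": 0.1},
    }
    return table[prefix]


def joint_prob(sequence):
    """Chain rule: multiply the conditional probability of each token."""
    prob = 1.0
    for t, token in enumerate(sequence):
        prob *= next_token_probs(tuple(sequence[:t]))[token]
    return prob


def greedy_decode(length):
    """Pick the single most likely next token at every step."""
    sequence = ()
    for _ in range(length):
        probs = next_token_probs(sequence)
        sequence += (max(probs, key=probs.get),)
    return sequence


def best_joint(length):
    """Exhaustively search for the most likely whole sequence."""
    return max(product(VOCAB, repeat=length), key=joint_prob)


greedy = greedy_decode(2)     # starts with "x" because 0.6 > 0.4
optimal = best_joint(2)       # starts with the locally worse "y"
```

Here the greedy string has joint probability 0.6 × 0.55 = 0.33, while the string that opens with the locally less likely token reaches 0.4 × 0.9 = 0.36.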

dxu

Hm. Yeah, I think I can sort of see where you're coming from with that.

I guess the way I think about that, from the "compositionality" lens (in the sense of local cognitions which compose with each other to form longer reasoning chains), is that, if you're attending to and conditioning on a context of significant length, the "greedy" completion is likely (to the extent that the model is powerful enough) to simply be the correct continuation, even at a token level, of whatever process was responsible for generating the preceding context. What do you think of that?

habryka

I think I agree that inasmuch as your internal cognitive architecture is the same as the cognitive structure of the thing that you are predicting the next token of, then in some sense you can learn to imitate the state transitions of the system you are trying to predict with each token. 

But the mind of a human writer is really not that well-captured by moving token to token. Sometimes a human takes a long time to write the next token. Sometimes the human looks up something on the internet.

Non-myopic objectives/n-token predictors

habryka

Ok, so here is maybe one framing on the problem. 

In some sense, ignoring incoming sensory data, a human mind must have some kind of transition rule. When it's in state $s_t$ at time t, and at time t+1 in state $s_{t+1}$, then there must be some rule R by which $R(s_t) = s_{t+1}$ and then also $R(s_{t+1}) = s_{t+2}$ and so on.

You can imagine training a system on learning this transition rule.

If you had direct access to the states $s_t$ and trained a system on those transitions, I do think you would clearly end up with something that would do long-chained reasoning. And in that case I think your model of "it's just doing the greedy completion" would indeed be a totally valid way of doing human reasoning.

dxu

I think I get what you're trying to say; you're saying that LLMs, rather than being trained on the full set of state transitions that occur in our minds, are trained only on the subset of those which happen to produce a token that we then choose to output, correct?

If so, that sounds broadly similar to a point I think Eliezer made somewhere, where (paraphrasing) he said LLMs are trained on the "shadows" of our thoughts, rather than our thoughts themselves.

habryka

Well, first of all it's not trained on any of the cognitive states of a human. At no point does an LLM have access to any of the mind's internal states.

But even worse than that, the thing that it is outputting is also not of the same type as a cognitive state. It is outputting an additional token that then gets appended to the context. Its internal transition rule is much more limited in the set of transitions it can apply, namely it can just append a single token to its previous state.

dxu

My guess at where you're going with this is that, because the LLM must potentially model an extremely complicated generating process (including cases where the human goes back and edits or deletes entire words, phrases, and sentences) within the space of a single forward-pass, that makes the LLM's prediction task much harder than would be implied by the naive expectation that all it has to do is think the same way a human does—which in turn has implications for capabilities?

habryka

That, but also, that during pretraining it doesn't get any feedback on how to use its reasoning for more than just predicting the next token, wherever its internal cognitive architecture and the human cognitive architecture diverge.

habryka

Like, an LLM doesn't learn how it can best reason within a 2048-token context. The human cognition that it is imitating has been shaped by lots of feedback that propagates back through multiple tokens. The human has learned to avoid cognitive traps and routes around the things that would cause the human to go off track. But if there is something that would cause the LLM to go off-track on the prediction task, it will just do it again and again every time.

dxu

I think I find that argument sketch reasonably persuasive, but I'm not actually sure which direction it points in terms of predicting capabilities. On the one hand, I could see an argument that because the LLM is working with an impoverished dataset, that means the cognition it learns is correspondingly shallower in certain hard-to-specify ways. On the other hand, I can also imagine thinking that if an LLM is able to perform despite these handicaps, that actually implies substantially more powerful internal cognition that could potentially do a whole lot more than it currently is, were it only pointed in a more productive direction.

habryka

Ok, sorry, here is a concrete prediction that I might just be wrong about: I expect an LLM to start doing quite different cognition if you just do joint-optimization of a many-token prediction task. I feel a bit confused here, but like, you get a very different loss function if, instead of training a model on predicting token by token, you ask it to produce whole sentences where at each token it conditions on its previous output, and then penalize it according to character-level differences in the output.

(You can't use log-probs here anymore, since you would have to get probabilities of all possible completions, so I think the loss function here is a bit messy. Not sure whether character-level loss would work or how you would do it)

But like, I feel like when you are talking about modern LLMs being n-token reasoners, and I am saying that they are 1-token reasoners, one way to clarify what I mean is to define a 2-token reasoner. And I think one way to create a 2-token reasoner is to generate all plausible completions of 2 tokens, and then propagate the joint loss of the log-probs of those two tokens. That then feels to me like it would actually generate a 2-token reasoner, because you are actually propagating gradients for jointly optimizing a 2-token output.

habryka

(And of course, RLHF is one way of doing joint-optimization for really long tasks, but that leaves the land of prediction and instead goes into the land of approval, which I want to avoid in this case)

habryka

Does this make sense?

dxu

Would it be fair to characterize your 2-token reasoner as a 1-token reasoner over the cross-product of the set of original tokens (with itself)?

I.e. if we originally had, say, 10 tokens to work with, now we're working with 100 tokens, each with twice the length.
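The counting in that cross-product framing can be sketched as follows. The vocabulary names are placeholders, and this only illustrates the size of the paired vocabulary, not whether a model over it would reason the same way.

```python
from itertools import product

# Hypothetical 10-token vocabulary (names are placeholders).
vocab = [f"t{i}" for i in range(10)]

# Every ordered pair of original tokens becomes one "double-length" token.
pair_vocab = list(product(vocab, repeat=2))

print(len(vocab), len(pair_vocab))
```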

habryka

Well, not fully, because it would still be recurrent.

Like, if you have a 100-token reasoner, the cognition would still be constrained so that when you predict the 100th token, you condition on the first 99 tokens that you output (like, you are still thinking thoughts in-order, but you are now taking into account the effect of your thoughts at step 1 on your thoughts at step 100).

dxu

To check my understanding, the difference there would actually lie somewhere within those first 99 tokens, which would in some sense have been optimized for the 100th token?

habryka

Yeah, the first 99 tokens would be optimized both to be locally the correct character, and also to set things up so that the 100th character is also correct.

dxu

I mean, I guess I agree that would probably cleave more closely to the idea of something being "goal-oriented" in the sense of its local actions/behaviors being governed by an overarching understanding of having somewhere it's trying to go. I'm just unsure whether I think a "1-token reasoner", in your terms, can't exhibit that same property.

Okay, suppose we live in the world where you're right, and all extant LLMs are "1-token reasoners" with the property that they're not actually goal-directed in the relevant sense. Concretely, in that world I would expect to see something like a capabilities ceiling, where there are some tasks below human-level that no LLM seems to be able to solve, regardless of how much we scale them up. Does that sound like a prediction your model would make?

habryka

Not regardless. Because like, I think all tasks that humans can do can be done in a single forward pass, but it sure might have to be a really really long forward pass.

The thing that this model predicts is that LLMs, without being trained on this kind of joint optimization, do not meaningfully get more dangerous when you make the context longer (i.e. give them more time to think about a problem). The optimizer that arises as a result of imitating humans and following the structure of human text, is much weaker and kind of incidental compared to the optimizer that predicts the next token.

dxu

I think, for me, what separates a "trajectory-level" heuristic from "outcome-level" reasoning is that the former completes in O(1) time, whereas the latter scales with the time horizon of the task. This is what distinguishes search (which can go as deep as the state tree permits) from evaluation (which performs a fixed check on a given state and then returns).
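A minimal sketch of that search/evaluation distinction, under invented assumptions: the binary state tree, the `evaluate` heuristic, and the function names are all hypothetical. The point is only that the evaluation does a fixed amount of work per call, while the lookahead visits more states the deeper it is allowed to go.

```python
def evaluate(state):
    """Fixed-cost check on a single state (a stand-in heuristic)."""
    return state % 7

def successors(state):
    """Hypothetical state tree: every state has exactly two children."""
    return (2 * state + 1, 2 * state + 2)

def search(state, depth, counter):
    """Depth-limited lookahead; counter[0] tracks how many states get visited."""
    counter[0] += 1
    if depth == 0:
        return evaluate(state)
    return max(search(child, depth - 1, counter) for child in successors(state))

def visited_states(depth):
    """Number of states the lookahead touches when started from state 0."""
    counter = [0]
    search(0, depth, counter)
    return counter[0]

for depth in range(5):
    print(depth, visited_states(depth))
```

With two children per state, the number of visited states roughly doubles with each extra level of depth, whereas `evaluate` alone always touches one state.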

For neural networks, all forward-passes are technically O(1), since all of them complete in constant time. For large enough models, however, there can be enough "time" inside of the forward-pass to execute some searchlike behavior, where the model explicitly reasons about states further in the future than the immediate successor state. Of course, there needs to be something incentivizing it to do this—and the claim is that for current LLMs, the training scheme does not incentivize this.

That... doesn't feel true to me? The argument about models not doing joint optimization over n-gram continuations seems like it's insensitive to the structural properties of the data the model is being fed, which feels strange to me. If we weren't training on data generated by humans, but by some hypothetical superintelligence, wouldn't a (sufficiently large) LLM trained on that dataset learn to think like said superintelligence? And wouldn't the resulting model almost certainly have to be capable of searchlike reasoning within a single forward-pass, to have any chance at all of imitating the output of said superintelligence?

habryka

(To maybe give a clearer example. Imagine that you don't do n-gram optimization, but you instead get to just output a chain of thought and then the last character of your output will be taken as your prediction, and we do PPO RL on your result. In that case, you can now do lots of recurrent optimization and get to take variable time in order to predict each token.)

And yes, my model is that before your LLM would become as smart as said superintelligence in the text output, it would have had to develop searchlike reasoning in a single forward pass. Correspondingly, long before it successfully started imitating the superintelligence, it would already be superintelligent itself and would probably have broken out of its training process and killed you.

dxu

Okay, but doesn't that also suggest that an LLM trained on human-generated data would reach human-level intelligence internally before its output began to successfully mirror the human output?

habryka

I definitely think the "intelligence" abstraction falls apart a bit. But I would strongly predict it would develop many vastly superhuman capabilities, yes. (Like having memorized vastly more facts than any human alive, or being vastly better at predicting the next token of a body of text than any human alive, or being much faster at writing code than any human alive.)

dxu

Yeah, I see your point, and in fact GPT-4 is certainly past that point on many of those metrics.

habryka

Yeah, to be clear, those examples were chosen for rhetorical effect, since yeah, GPT-4 seems to indeed be vastly superhuman on these.

dxu

A model capable of searchlike reasoning within a single forward-pass, whether that was achieved through n-gram optimization or just forcing it to predict sufficiently hard data (as in the superintelligence example), would be capable of goal-oriented, "consequentialist" reasoning in the worrisome sense. Is that something we might agree on?

habryka

Well, to be clear, n-gram optimization here is referring to "do n-forward passes". And as you said, a model that does joint optimization of like 2048 tokens, seems much more like the kind of thing that might compose into arbitrarily long chains of reasoning.

dxu

Okay, I think I more-or-less agree with you, modulo the fact that you haven't explicitly said where in your model the danger comes from (though I expect I can guess).

Throughout this conversation, it felt to me like there was some deeper core of disagreement at play, but now that you've laid out your position, it actually feels to me like it dovetails fairly well with the shortform comment I posted.

habryka

Yeah, to be clear, I quite liked your shortform. 

I did intend to point out a contradiction that I see when I apply similar reasoning myself, which is that if you model a modern LLM as being forced by its training process to really just be doing outcome-level reasoning over tokens, which is really very trajectory-level over tasks, then I am surprised by how far we have managed to push task performance. I didn't expect tasks like writing functional code for many applications to be that solvable by only doing trajectory-level optimization.

dxu

To reiterate, I don't have a strong stance (and in fact I think it would be wildly overconfident for anyone to claim they do) on whether the Transformer architecture as currently implemented might hit an architecture-level limitation that prevents it from becoming truly dangerous.

habryka

Yeah, totally. I am not confident of anything in this space either, though of course each time things don't hit a wall and get better as we scale is evidence that there isn't an architecture-level limitation, which I think is some evidence against the frame outlined in your shortform, at least if I am right about the level of joint optimization going on in current cutting-edge LLMs.

habryka

But overall, feels like a reasonable time to stop

dxu

Yep, I feel like I got several interesting pieces of model to chew on. Thanks again for doing this!

habryka

Thank you! I really enjoyed this. I've had a bunch of model fragments floating around and am glad I got to put them out on a page, and you also gave me multiple things to chew on.

15 comments

comment by Bogdan Ionut Cirstea (bogdan-ionut-cirstea) · 2024-09-09T13:13:15.788Z · LW(p) · GW(p)

But there are various self-play or self-critique like approaches that could just defeat the data paucity here, and I am very concerned this will be associated with an enormous capability spike.

You might be interested in this paper and summary thread.

comment by Gunnar_Zarncke · 2024-04-22T09:54:35.782Z · LW(p) · GW(p)

there's a mental move of going up and down the ladder of abstraction, where you zoom in on some particularly difficult and/or confusing part of the problem, solve it, and then use what you learned from that to zoom back out and fill in a gap in the larger problem you were trying to solve. For an LLM, that seems like it's harder, and indeed it's one of the reasons I inside-view suspect LLMs as-currently-trained might not actually scale to AGI. [bold by me]

But that might already no longer be true with models that have short-term memory and might make moves like the ones you describe. See my Leave No Context Behind - A Comment [LW · GW].

comment by johnswentworth · 2024-04-22T05:45:09.462Z · LW(p) · GW(p)

(Didn't read most of the dialogue, sorry if this was covered.)

But the way transformers work is they greedily think about the very next token, and predict that one, even if by conditioning on it you shot yourself in the foot for the task at hand.

That depends on how we sample from the LLM. If, at each "timestep", we take the most-probable token, then yes that's right.

But an LLM gives a distribution over tokens at each timestep, i.e. $P[x_t \mid x_1, \ldots, x_{t-1}]$. If we sample from that distribution, rather than take the most-probable at each timestep, then that's equivalent to sampling non-greedily from the learned distribution over text. It's the chain rule:

$P[x_1, \ldots, x_n] = \prod_t P[x_t \mid x_1, \ldots, x_{t-1}]$
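The chain-rule point can be checked numerically on a toy model. This is a sketch under stated assumptions: the tables `P_FIRST` and `P_NEXT` are invented, and for brevity each token depends only on the previous one, which a real LLM does not assume. Drawing tokens one at a time from the per-step conditionals should produce whole sequences at frequencies close to their joint probabilities.

```python
import random

# Hypothetical conditional tables over a two-token vocabulary; numbers invented.
P_FIRST = {"a": 0.7, "b": 0.3}
P_NEXT = {"a": {"a": 0.2, "b": 0.8}, "b": {"a": 0.6, "b": 0.4}}

def joint_probability(seq):
    """Chain rule: multiply the per-step conditionals along the sequence."""
    prob = P_FIRST[seq[0]]
    for prev, tok in zip(seq, seq[1:]):
        prob *= P_NEXT[prev][tok]
    return prob

def sample_sequence(length, rng):
    """Ancestral sampling: draw each token from its conditional distribution."""
    seq = rng.choices(list(P_FIRST), weights=list(P_FIRST.values()))
    while len(seq) < length:
        table = P_NEXT[seq[-1]]
        seq += rng.choices(list(table), weights=list(table.values()))
    return tuple(seq)

rng = random.Random(0)
n = 20000
counts = {}
for _ in range(n):
    s = sample_sequence(2, rng)
    counts[s] = counts.get(s, 0) + 1

# Empirical frequency next to the chain-rule joint probability.
for s in sorted(counts):
    print(s, round(counts[s] / n, 3), round(joint_probability(s), 3))
```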

Replies from: habryka4
comment by habryka (habryka4) · 2024-04-22T06:16:25.652Z · LW(p) · GW(p)

I think you are talking about a different probability distribution here.

You are right that this allows you to sample non-greedily from the learned distribution over text, but I was talking about the inductive biases on the model. 

My claim was that the way LLMs are trained, the way the inductive biases shake out is that the LLM won't be incentivized to output tokens that predictably have low probability, but make it easier to predict future tokens (by, for example, in the process of trying to predict a proof, reminding itself of all of the things it knows before those things leave its context window, or when doing an addition that it can't handle in a single forward pass, outputting a token that's optimized to give itself enough serial depth to perform the full addition of two long n-digit numbers, which would then allow it to get the next n tokens right and so overall achieve lower joint loss).

Replies from: habryka4
comment by habryka (habryka4) · 2024-05-01T22:45:25.073Z · LW(p) · GW(p)

@johnswentworth [LW · GW] I think this paper basically does the thing I was talking about (with pretty impressive results), though I haven't read it in a ton of detail: https://news.ycombinator.com/item?id=40220851 

Replies from: ryan_greenblatt
comment by ryan_greenblatt · 2024-05-02T23:42:02.636Z · LW(p) · GW(p)

Hmm, I don't think so. Or at least, the novel things in that paper don't seem to correspond.

My understanding of what this paper does:

  • Trains models to predict the next 4 tokens instead of the next 1 token as an auxiliary training objective. Note that this training objective yields better performance on downstream tasks when just using the next token prediction component (the normally trained component) and discarding the other components. Notably, this is just something like "adding this additional prediction objective helps the model learn more/faster". In other words, this result doesn't involve changing how the model is actually used, it just adds an additional training task.
  • Uses these heads for speculative execution, a well-known approach in the literature for accelerating inference.
Replies from: habryka4
comment by habryka (habryka4) · 2024-05-03T00:24:56.968Z · LW(p) · GW(p)

Hmm, I think the first bullet point is pretty precisely what I am talking about (though to be clear, I haven't read the paper in detail). 

I was specifically saying that trying to somehow get feedback from future tokens into the next token objective would probably do some interesting things and enable a bunch of cross-token optimization that currently isn't happening, which would improve performance on some tasks. This seems like what's going on here.

Agree that another major component of the paper is accelerating inference, which I wasn't talking about. I would have to read the paper in more detail to get a sense of how much it's just doing that, in which case I wouldn't think it's a good example.

comment by faul_sname · 2024-04-22T20:30:17.690Z · LW(p) · GW(p)

And I think one way to create a 2-token reasoner is to generate all plausible completions of 2 tokens, and then propagate the joint loss of the log-probs of those two tokens.

I think this just doesn't work very well, because it incentivizes the model to output a token which makes subsequent tokens easier to predict, as long as the benefit in predictability of the subsequent token(s) outweighs the cost of the first token. Concretely, let's say you have the input "Once upon a time, there was a" and you want 32 tokens. Right now, davinci-002 will spit out something like [" little"," girl"," who"," was"," born"," with"," a"," very"," special"," gift","."," She"," could"," see"," things"," that"," others"," could"," not","."," She"," could"," see"," the"," future",","," and"," she"," could"," see"," the"," past"], with logprobs of [-2.44, -0.96, -0.90, ..., -0.28, -0.66, -0.26], summing to -35.3. But if instead, it returned [" a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"," a"], it would have logprobs like [-9.32, -7.77, -1.51, ..., -0.06, -0.05, -0.05], summing to -23.5. And indeed, if you could somehow ask a couple quadrillion people "please write a story starting with Once upon a time, there was a", I suspect that at least 1 in a million people would answer with low-entropy completions along the lines of a a a a ... (and there just aren't that many low-entropy completions). But "Once upon a time there was a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a" is not a very good completion, despite being a much higher-probability completion.

You could use a more sophisticated loss function than "sum of individual-token logprob", but I think that road leads towards PPO (nothing says that your criterion has to be "helpful/harmful/honest as judged by a human rater" though).
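As a quick worked check of what the two summed logprobs imply, here is a small sketch. It assumes the quoted logprobs are natural logarithms, which is an assumption on my part rather than something stated in the comment.

```python
import math

# Summed log-probabilities quoted in the comment (assumed to be natural logs).
story_logprob = -35.3
repeated_a_logprob = -23.5

# A difference in log-probability is a ratio in probability.
ratio = math.exp(repeated_a_logprob - story_logprob)
print(f"{ratio:.3g}")
```

Under that assumption, the repeated-token completion comes out on the order of a hundred thousand times more probable than the story-like one as an exact string.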

Replies from: habryka4
comment by habryka (habryka4) · 2024-04-22T21:07:28.275Z · LW(p) · GW(p)

I think this just doesn't work very well, because it incentivizes the model to output a token which makes subsequent tokens easier to predict, as long as the benefit in predictability of the subsequent token(s) outweighs the cost of the first token.

Hmm, this doesn't sound right. The ground truth data would still be the same, so if you were to predict "aaaaaa" you would get the answer wrong. In the above example, you are presumably querying the logprobs of the model that was trained on 1-token prediction, which of course would think it's quite likely that conditional on the last 10 characters being "a" the next one will be "a", but I am saying "what is the probability of the full completion 'a a a a a...' given the prefix 'Once upon a time, there was a'", which doesn't seem very high.

The only thing I am saying here is "force the model to predict more than one token at a time, conditioning on its past responses, then evaluate the model on performance of the whole set of tokens". I didn't think super hard about what the best loss function here is, and whether you would have to whip out PPO for this.  Seems plausible.

Replies from: faul_sname
comment by faul_sname · 2024-04-22T22:15:09.694Z · LW(p) · GW(p)

I think the probability of getting the exact continuation "a a a a a ..." is genuinely higher than the probability of getting the exact continuation "little girl who was born with a very special gift...", though getting a continuation in the class of "a a a a a..." is much lower-probability than getting a continuation in the class of "little girl who was born with a very special gift..", because the latter class has a much larger possibility space than the former. So there might be 1e4 different low-entropy length-32 completions with an average probability of 1e-10 each, and 9.999999e15 different high-entropy length-32 completions with an average probability of 1e-16. This adds up to normality in that if you were to randomly sample this distribution, you'd get a weird low-entropy output one time in a million, and a normal high-entropy output the other 999999 times in a million. But if you try to do something along the lines of "take the best K outputs and train the model on those", you'll end up with almost entirely weird low-entropy outputs.

But yeah, I think I misunderstood your proposal as something along the lines of "take the k most probable n-token outputs" rather than "take the k% most probable n-token outputs" or "randomly sample a bunch of n-token outputs".

comment by p.b. · 2024-04-22T09:47:35.298Z · LW(p) · GW(p)

Yeah, the first 99 tokens would be optimized both to be locally the correct character, and also to set things up so that the 100th character is also correct.

That is how LLMs currently work. The gradient of each token prediction does flow back into all the earlier tokens whose information was integrated into the predicted token. So each token optimizes its own next token prediction but also tries to integrate the information that is most useful for future tokens. 

Replies from: habryka4
comment by habryka (habryka4) · 2024-04-22T15:45:45.666Z · LW(p) · GW(p)

I reference this in this section:

I do think saying "the system is just predicting one token at a time" is wrong, but I guess the way the work a transformer puts into token N gets rewarded or punished when it predicts token N + M feels really weird and confusing to me and still like it can be summarized much more as "it's taking one token at a time" than "it's doing reasoning across the whole context".

IIRC at least for a standard transformer (which maybe had been modified with the recent context length extension) the gradients only flow through a subset of the weights (for a token halfway through the context, the gradients flow through half the weights that were responsible for the first token, IIRC).

Replies from: p.b.
comment by p.b. · 2024-04-22T16:29:36.464Z · LW(p) · GW(p)

Frankly, I don't really understand what you are saying here and I am open to the possibility that I don't really understand how the gradient works in autoregressive transformers. 

But as I said in my other comment, my current understanding is: 

In standard attention (for example in an encoder) tokens are not ordered, so it is clear that the gradient of the loss of one of the token predictions (for example a masked token in BERT) flows through all other tokens equally. In autoregressive transformers an order is imposed by masking, but all later tokens attend to all earlier tokens in the same way. 

The gradient of the loss of a later token flows through all earlier tokens in the same way. It doesn't matter whether a token is half the context back or all the context, neither for the information flow nor for the gradient flow.

To put it another way: In the n-th layer the last token attends to all the output tokens from the n-1-th layer. It doesn't somehow have to make do with the output of earlier layers for tokens that are further back. 
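A small connectivity sketch of that masking picture, with hypothetical helper names. It does not compute real attention or gradients. It only tracks which input positions can influence which outputs under a causal mask, and gradients would travel back along the same paths.

```python
def causal_mask(n):
    """mask[i][j] is True when position i may attend to position j (j <= i)."""
    return [[j <= i for j in range(n)] for i in range(n)]

def reachable_inputs(n, layers):
    """For each position, the set of input positions that can influence it
    after `layers` masked-attention layers (a pure connectivity check)."""
    mask = causal_mask(n)
    deps = [{i} for i in range(n)]
    for _ in range(layers):
        deps = [set().union(*(deps[j] for j in range(n) if mask[i][j]))
                for i in range(n)]
    return deps

for row in causal_mask(4):
    print(row)
print(reachable_inputs(4, layers=1)[-1])
```

After a single layer the last position already depends on every earlier position, and nothing in the mask distinguishes a token one step back from a token many steps back.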

Replies from: habryka4
comment by habryka (habryka4) · 2024-04-22T16:53:01.772Z · LW(p) · GW(p)

Yeah, I was indeed confused, sorry. I edited out the relevant section of the dialogue and replaced it with the correct relevant point (the aside here didn't matter because a somewhat stronger condition is true, which is that during training we always just condition on the right answer instead of conditioning on the output for the next token in the training set). 

In autoregressive transformers an order is imposed by masking, but all later tokens attend to all earlier tokens in the same way. 

Yeah, the masking is what threw me off. I was trying to think about whether any information would flow from the internal representations used to predict the second token to predicting the third token, and indeed, if you were to backpropagate the error after each specific token prediction, then there would be some information from predicting the second token available to predicting the third token (via the updated weights).

However, batch-sizes make this also inapplicable (I think you would basically never do a backpropagation after each token, that would kind of get rid of the whole benefit of parallel training), and even without that, the amount of relevant information flowing this way would be minuscule and there wouldn't be any learning going on for how this information flows.
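The point about always conditioning on the right answer during training (usually called teacher forcing) can be sketched with a toy model. Everything here is hypothetical: the bigram table `MODEL`, its numbers, and the two function names. The second function is just one simple way to condition on the model's own (greedy) output instead, not a description of any particular training method.

```python
import math

# Hypothetical bigram "model": P(next token | previous token). Numbers invented.
MODEL = {
    "<s>": {"a": 0.7, "b": 0.3},
    "a": {"a": 0.2, "b": 0.8},
    "b": {"a": 0.6, "b": 0.4},
}

def teacher_forced_nll(target):
    """Loss when each prediction is conditioned on the ground-truth prefix."""
    prev, nll = "<s>", 0.0
    for tok in target:
        nll -= math.log(MODEL[prev][tok])
        prev = tok  # feed in the correct token, whatever the model preferred
    return nll

def free_running_nll(target):
    """Loss when each prediction is conditioned on the model's own greedy output."""
    prev, nll = "<s>", 0.0
    for tok in target:
        nll -= math.log(MODEL[prev][tok])
        prev = max(MODEL[prev], key=MODEL[prev].get)  # feed in the model's pick
    return nll

target = ["b", "a"]
tf = teacher_forced_nll(target)
fr = free_running_nll(target)
print(round(tf, 3), round(fr, 3))
```

In the teacher-forced version, a poor choice at the first position never affects the conditions under which the second position is scored, which is the property the comment above is pointing at.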

comment by p.b. · 2024-04-22T09:34:52.614Z · LW(p) · GW(p)

I don't know how people are creating huge context windows these days, but IIRC the way it works is that the longer you look back into your context (and correspondingly the further you are trying to plan ahead) the less of your computation is available. Like, if you have N layers, then for a token M steps back, you only have access to the computation up until layer N-M.

Everything in the context window is equally available. It doesn't make a difference whether an earlier token is 5 tokens back or 5000. The attention mechanism is an operation over a set of tokens, there is no intrinsic order.
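A minimal sketch of the "operation over a set" claim, using made-up vectors and a hand-written single-query softmax attention (not any library's implementation). Shuffling the keys and values together leaves the output unchanged up to floating-point rounding.

```python
import math

def attention(query, keys, values):
    """Single-query softmax attention over a collection of (key, value) pairs."""
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    top = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(dim)]

# Made-up vectors; no positional information is attached to them.
query = [1.0, 0.5]
keys = [[0.2, 0.1], [1.0, -1.0], [0.3, 0.9]]
values = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]

original = attention(query, keys, values)
order = [2, 0, 1]  # shuffle keys and values together
shuffled = attention(query, [keys[i] for i in order], [values[i] for i in order])
print(original, shuffled)
```

In real transformers, order information comes in through positional encodings added to the inputs and through the causal mask, rather than through the attention operation itself.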