A library for safety research in conditioning on RLHF tasks
post by James Chua (james-chua) · 2023-02-26T14:50:56.762Z
There have been a couple of discussions about conditioning / decision transformer training. "Conditioning" here means placing your reward model's reward in the prompt. We then train the language model to generate a completion that matches the specified reward.
See Safety considerations for online generative modeling, Soft optimization makes the value target bigger, and RLHF bad, conditioning good.
The tldr is that training models this way could have safety benefits.
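To make the setup concrete, here is a minimal sketch of what a conditioned training example could look like. It assumes the simplest possible encoding, writing the reward into the prompt as text; the `[reward=...]` format, the helper names, and the scores are all hypothetical. The library itself instead maps the reward through a learned linear embedding, as described in the comments.

```python
from dataclasses import dataclass


@dataclass
class ConditionedExample:
    text: str    # what the model sees: reward + prompt + completion
    target: str  # the part the loss is meant to be computed on (the completion)


def format_conditioned(prompt: str, completion: str, reward: float) -> ConditionedExample:
    # Hypothetical text format: the reward is written into the prompt, so ordinary
    # fine-tuning teaches the model p(completion | reward, prompt).
    prefix = f"[reward={reward:.2f}] {prompt}"
    return ConditionedExample(text=prefix + completion, target=completion)


def inference_prompt(prompt: str, desired_reward: float) -> str:
    # At inference time we choose the reward we want and let the model complete.
    return f"[reward={desired_reward:.2f}] {prompt}"


# Offline dataset: completions already scored by a reward model (made-up scores).
logged = [
    ("Summarize: the cat sat on the mat.", " A cat sat.", 0.90),
    ("Summarize: the cat sat on the mat.", " Dogs are great.", 0.10),
]
train_set = [format_conditioned(p, c, r) for p, c, r in logged]
print(train_set[0].text)  # [reward=0.90] Summarize: the cat sat on the mat. A cat sat.
print(inference_prompt("Summarize: the dog ran home.", 1.0))
```

Both the good and the bad completion stay in the training set; the reward label is what tells them apart.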
I've created a library that extends a pre-trained LLM (GPT-2, GPT-J) to support conditioning on scalar rewards. It saves researchers time because they don't need to modify the attention masks, positions, and labels themselves. For example, researchers could retrain GPT-2 to replicate OpenAI's RLHF summarization paper while relying purely on conditioning. I created it because I couldn't find an existing library that did this.
Note: a simpler way to condition is to use discrete tokens. Pretraining Language Models with Human Preferences implements conditioning via discrete <|good|> and <|bad|> tokens.
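For comparison, a minimal sketch of discrete conditioning, assuming a single fixed threshold on the reward-model score decides which control token is prepended. The threshold value and the function names are illustrative and not taken from that paper's code.

```python
GOOD, BAD = "<|good|>", "<|bad|>"


def discretize(reward: float, threshold: float = 0.5) -> str:
    # Hypothetical threshold: collapse the scalar reward into one of two control tokens.
    return GOOD if reward >= threshold else BAD


def tag_example(text: str, reward: float, threshold: float = 0.5) -> str:
    # Prepend the control token; the model is then trained on the tagged text.
    return discretize(reward, threshold) + text


print(tag_example(" A cat sat.", 0.9))       # <|good|> A cat sat.
print(tag_example(" Dogs are great.", 0.1))  # <|bad|> Dogs are great.
```

At inference you would prepend `<|good|>` to request good behaviour. No change to the model architecture is needed, which is what makes this approach simpler.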
However, using scalar rather than discrete rewards could have the following benefits:
- The ability to specify a higher reward than the maximum seen during training. For example, if the maximum reward obtained during training was 1, at inference you could set it to 2. Would we get a better result, as demonstrated in the decision transformer paper? (Thanks @Tomek Korbak for pointing this out)
- More efficient training, since no information is lost to discretization (yet to be shown empirically?)
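A toy illustration of both points, using made-up scores and a hypothetical threshold-based discretization: four distinct rewards collapse into two conditioning values, and the highest request available with tokens, `<|good|>`, was attached to scores well below the training maximum. Whether a scalar request above the maximum actually yields better completions is the open empirical question raised above.

```python
CONTROL_TOKENS = ("<|bad|>", "<|good|>")  # the entire discrete conditioning vocabulary


def discretize(reward: float, threshold: float = 0.5) -> str:
    # Hypothetical threshold; every score at or above it maps to the same token.
    return CONTROL_TOKENS[1] if reward >= threshold else CONTROL_TOKENS[0]


# Made-up reward-model scores for four training completions.
rewards = [0.10, 0.45, 0.60, 0.95]

# 1) Information loss: four distinct scores collapse into two conditioning values.
n_discrete = len({discretize(r) for r in rewards})
n_scalar = len(set(rewards))
print(n_discrete, n_scalar)  # 2 4

# 2) Extrapolation: "<|good|>" is the most you can ask for with tokens, and it was
# attached to scores as low as 0.60. A scalar condition can request more.
good_scores = [r for r in rewards if discretize(r) == "<|good|>"]
print(min(good_scores), max(rewards))  # 0.6 0.95
desired_reward = 2.0  # only expressible with scalar conditioning
print(desired_reward > max(rewards))  # True
```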
While the library is still in early development, it can already be used for offline training of GPT-2.
I'm writing this post earlier rather than later to:
- Get feedback on whether this is useful, or whether it isn't and I should spend my time on something else :)
- If it is useful, get feedback on how to improve the library.
- Save someone hours trying to do the same thing (just in case that happens)
2 comments
comment by Charlie Steiner · 2023-02-26T16:39:16.853Z
Seems cool to me. I don't totally understand what's going on with the "embedding" of the score, but presumably this way works well for DTs.
↑ comment by James Chua (james-chua) · 2023-02-26T16:52:16.802Z
For DTs it's really just a linear function that maps the scalar reward to the same dimensions as the token embeddings.
So, for example, a single token's embedding has a hidden size of 1024.
We can learn a linear function that takes this scalar and outputs something of size 1024.
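A minimal sketch of such a linear reward embedding in plain Python, assuming the reward vector is inserted as one extra position in front of the token embeddings. `ScalarRewardEmbedding` and `prepend_reward` are hypothetical names. In a PyTorch model, the projection would presumably be something like `torch.nn.Linear(1, hidden_size)`.

```python
import random

HIDDEN_SIZE = 1024  # the example size from the comment; GPT-2 small uses 768


class ScalarRewardEmbedding:
    """Linear map from a scalar reward to a HIDDEN_SIZE vector: e = r * w + b."""

    def __init__(self, hidden_size: int, seed: int = 0) -> None:
        rng = random.Random(seed)
        # In a real model, weight and bias would be trainable parameters.
        self.weight = [rng.gauss(0.0, 0.02) for _ in range(hidden_size)]
        self.bias = [0.0] * hidden_size

    def __call__(self, reward: float) -> list[float]:
        return [reward * w + b for w, b in zip(self.weight, self.bias)]


def prepend_reward(
    reward_vec: list[float], token_embeddings: list[list[float]]
) -> list[list[float]]:
    # Assumed layout: the reward vector becomes one extra position before the tokens.
    return [reward_vec] + token_embeddings


embed = ScalarRewardEmbedding(HIDDEN_SIZE)
tokens = [[0.0] * HIDDEN_SIZE for _ in range(5)]  # stand-in for 5 token embeddings
sequence = prepend_reward(embed(0.8), tokens)
print(len(sequence), len(sequence[0]))  # 6 1024
```

Because the map is linear, `embed(2.0)` is just as valid an input as `embed(0.8)`. That is what makes it possible to request rewards above the training maximum in the first place.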
The more annoying (PITA) part was offsetting the positions, attention masks, and labels for this.
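A sketch of that bookkeeping for a single unpadded sequence, assuming the reward occupies one extra position at index 0. It also assumes the loss is computed only on completion tokens. `offset_for_reward_slot` and `prompt_len` are hypothetical names. Labels use -100, which is PyTorch `CrossEntropyLoss`'s default `ignore_index`.

```python
from dataclasses import dataclass

IGNORE_INDEX = -100  # PyTorch CrossEntropyLoss's default ignore_index


@dataclass
class Batch:
    attention_mask: list[int]
    position_ids: list[int]
    labels: list[int]


def offset_for_reward_slot(input_ids: list[int], prompt_len: int) -> Batch:
    """Build masks/positions/labels for [reward] + tokens (one unpadded sequence)."""
    n = len(input_ids)
    # The reward slot must be attended to like any real token.
    attention_mask = [1] * (n + 1)
    # Positions run 0..n, so every token's position shifts up by one.
    position_ids = list(range(n + 1))
    # No loss on the reward slot or on the prompt; train only on the completion.
    labels = [IGNORE_INDEX] * (1 + prompt_len) + input_ids[prompt_len:]
    return Batch(attention_mask, position_ids, labels)


batch = offset_for_reward_slot([50, 51, 52, 53], prompt_len=2)
print(batch.labels)        # [-100, -100, -100, 52, 53]
print(batch.position_ids)  # [0, 1, 2, 3, 4]
```

Hugging Face causal-LM heads shift labels by one internally, so the labels here stay aligned with the input positions rather than being pre-shifted. Batched, left-padded inputs would presumably also need position IDs derived from the attention mask rather than a plain range.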