A library for safety research in conditioning on RLHF tasks

post by James Chua (james-chua) · 2023-02-26T14:50:56.762Z · LW · GW · 2 comments

There have been a couple of discussions regarding conditioning / decision transformer training. "Conditioning" here means placing your reward model's reward in the prompt, then training the language model to produce a completion that matches the specified reward.

See Safety considerations for online generative modeling [LW · GW], Soft optimization makes the value target bigger [LW · GW], and RLHF bad, conditioning good [LW · GW].

The tldr is that training models this way could have safety benefits.
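To make the idea concrete, here is a minimal sketch (not this library's API) of what reward-conditioned training data could look like; the `to_conditioned_prompt` helper and the example rewards are made up for illustration:

```python
# Hypothetical sketch of reward-conditioned training data (not this library's API):
# the reward model's score is rendered into the prompt, and the model is trained
# to produce the completion that earned that score.
examples = [
    {"reward": 0.92, "text": "A concise, accurate summary of the article..."},
    {"reward": 0.10, "text": "A rambling, off-topic summary..."},
]

def to_conditioned_prompt(example: dict) -> str:
    # Prepend the target reward so the model learns p(text | reward).
    return f"Reward: {example['reward']:.2f}\n{example['text']}"

for ex in examples:
    print(to_conditioned_prompt(ex))

# At inference time, prompting with a high reward (e.g. "Reward: 0.95\n")
# asks the model for a completion it believes would score that highly.
```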

I've created a library that extends a pre-trained LLM (gpt2, gpt-j) to work with conditioning on scalar rewards. This saves researchers the time of modifying attention masks, positions, and labels themselves. For example, researchers can retrain GPT-2 to replicate OpenAI's summarization RLHF paper, but relying purely on conditioning. I created it because I couldn't find an existing library that did so.

Note: an easier way of conditioning would be to use discrete tokens. Pretraining Language Models with Human Preferences [LW · GW] implements conditioning via discrete <|good|> and <|bad|> tokens.
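For illustration, here is a hedged sketch of that discrete-token approach: the <|good|> and <|bad|> tokens come from the paper, while the 0.5 threshold and the helper function are assumptions made for this example.

```python
# Sketch of discrete-token conditioning: examples whose reward clears a
# threshold get a <|good|> prefix, the rest get <|bad|>.
GOOD, BAD = "<|good|>", "<|bad|>"
THRESHOLD = 0.5  # assumed threshold, not taken from the paper or this library

def to_discrete_conditioned(text: str, reward: float) -> str:
    tag = GOOD if reward >= THRESHOLD else BAD
    return f"{tag}{text}"

print(to_discrete_conditioned("A helpful summary...", reward=0.8))     # <|good|>...
print(to_discrete_conditioned("A misleading summary...", reward=0.2))  # <|bad|>...
```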

However, using scalar rather than discrete rewards could have the following benefits:

 

While the library is still in early development, it can already be used for offline training of GPT-2. 

I'm writing this post earlier rather than later to:

2 comments


comment by Charlie Steiner · 2023-02-26T16:39:16.853Z · LW(p) · GW(p)

Seems cool to me. I don't totally understand what's going on with the "embedding" of the score, but presumably this way works well for DTs.

comment by James Chua (james-chua) · 2023-02-26T16:52:16.802Z · LW(p) · GW(p)

For DTs it's really just a linear function that converts the scalar reward into the same dimensions as the token embeddings.

So e.g. a single token's embedding has a hidden state of size 1024.

We can learn a linear function that takes this scalar and outputs something of size 1024.
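A minimal sketch of that idea, assuming GPT-2 medium's hidden size of 1024 (illustrative PyTorch, not the library's exact code):

```python
import torch
import torch.nn as nn

# A learned linear map turns a scalar reward into a vector with the same
# hidden size as the token embeddings (1024 here).
hidden_size = 1024
reward_embedder = nn.Linear(1, hidden_size)

rewards = torch.tensor([[0.9], [0.1]])          # shape (batch, 1)
reward_embeds = reward_embedder(rewards)        # shape (batch, 1024)

# The reward embedding can then be prepended to the token embeddings,
# so the sequence the transformer sees is [reward, tok_1, tok_2, ...].
token_embeds = torch.randn(2, 16, hidden_size)  # dummy token embeddings
inputs_embeds = torch.cat([reward_embeds.unsqueeze(1), token_embeds], dim=1)
print(inputs_embeds.shape)  # torch.Size([2, 17, 1024])
```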

The more annoying (PITA) part was offsetting the position ids, attention masks, and labels for this.
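A hedged sketch of that bookkeeping: once one reward embedding is prepended, the attention mask gains a leading 1, the labels gain a leading -100 so no loss is computed on the reward slot, and positions shift by one. The specific tensors below are made up for the example and are not the library's code.

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])    # original mask, one padded slot
labels = torch.tensor([[50, 51, 52, -100]])      # original labels

# Attend to the prepended reward position, but never compute a loss on it.
attention_mask = torch.cat([torch.ones(1, 1, dtype=torch.long), attention_mask], dim=1)
labels = torch.cat([torch.full((1, 1), -100, dtype=torch.long), labels], dim=1)

# Positions shift by one so tokens keep their original relative order.
position_ids = torch.arange(attention_mask.shape[1]).unsqueeze(0)

print(attention_mask)  # tensor([[1, 1, 1, 1, 0]])
print(labels)          # tensor([[-100, 50, 51, 52, -100]])
print(position_ids)    # tensor([[0, 1, 2, 3, 4]])
```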