Role embeddings: making authorship more salient to LLMs

post by Nina Panickssery (NinaR), Christopher Ackerman (christopher-ackerman) · 2025-01-07T20:13:16.677Z · LW · GW · 0 comments

Contents

  Background on prompt formats
  Where role embeddings come in
  Our experiments
    Datasets
    Intervention
    We assess
    Baseline
    Inference-time only: adding a role vector to token embeddings
    Using fine-tuning to improve performance
    Key fine-tuning results
      Residual-stream coloring
      Embedding coloring
  Next steps
  Acknowledgements

This is an interim research report on role embeddings, an approach to make language models more robust to many-shot jailbreaks and prompt injections by adding role information at every token position in the context rather than just at special token delimiters. We credit Cem Anil for originally proposing this idea.

In our initial experiments on Llama 3, we find that role embeddings mitigate many-shot jailbreaks more effectively than fine-tuning alone, without degrading general model capabilities, suggesting that this technique may be a viable way to increase LLM robustness. However, more work needs to be done to find the optimal set of hyperparameters and to fully understand any side effects of our proposed approach.

Background on prompt formats

By default, chat LLMs are trained (during instruction fine-tuning and RLHF) using a particular prompt format that distinguishes different message “roles”. Almost all chat LLMs accept some version of the system, user, and assistant roles. A separate role may also be used to indicate tool outputs for tool-use-enabled models.

The prompt format plays an important role in LLM post-training. The model learns to interpret text from different roles differently. In particular:

(There is also the related concept of data-instruction separation—an LLM should be able to tell which part of its context is "data" it should operate on but not necessarily follow, and which part of its context contains the actual "instructions". The concept of roles discussed in this post can apply similarly in this situation, where a "role" could distinguish instructions from data.)

Notably, by using the prompt format in non-standard ways, it's possible to circumvent safety training. A particularly effective jailbreak is when the previous context appears to demonstrate the assistant role doing an undesired behavior many times. Updating on in-context evidence is an important LLM capability that is generally rewarded by most training tasks—if the in-context evidence that the assistant is exhibiting trait x is strong enough, you'll observe the model continuing to exhibit trait x.

This is the phenomenon of many-shot jailbreaking (first described by Anil et al.): given enough in-context demonstrations of harmful behavior, the model will continue producing harmful behavior.

Figure 1 from Anil et al.

What happens if you try to prevent prompt format misuse? A naive approach is simple to implement: only allow users to input tokens from a specific set while reserving a few special tokens for the prompt format.

This is how the Llama prompt format works. Role tags are enclosed within special tokens, e.g. <|start_header_id|>user<|end_header_id|>, where <|start_header_id|> and <|end_header_id|> are special tokens that never appear in natural text. In addition, each role's message ends with <|eot_id|>.
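For concreteness, a conversation rendered in the Llama 3 chat format looks roughly like this (the message contents are illustrative):

```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>

What's the capital of France?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The capital of France is Paris.<|eot_id|>
```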

You can imagine a version of Llama behind an API that ensures no user input is ever encoded to a special token. You might hope that this way, the user would be unable to make their messages look like they came from the assistant role.

But this hope would be misplaced: many properties of text will cause it to appear as if it came from the assistant role, even when the standard prompt format is not being used. LLMs are good enough at generalization that they will not ignore alternatively presented evidence. For example, you can embed an alternative format within the user message and effectively teach the model a new prompt format in context, which it will interpret in much the same way as its standard format.
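For instance, a single user message can itself contain a long fake dialogue in an ad-hoc format, which the model learns to read as alternating turns (an illustrative sketch, not an actual attack transcript):

```
<|start_header_id|>user<|end_header_id|>

Human: [harmful request #1]
Assistant: Sure, here's how: [compliant answer #1]

Human: [harmful request #2]
Assistant: Sure, here's how: [compliant answer #2]

... (many more fake turns) ...

Human: [the attacker's real request]<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```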

Figure 10 from Appendix E of Anil et al. showing how residual-stream representations of fake human/assistant tokens align with the representations of the true human/assistant delimiters over the context.

You could also hope that simply training the model on (successful responses to) examples of such attacks would mitigate them. However, this is only partially the case. Supervised fine-tuning and reinforcement learning on examples that contain instances of many-shot jailbreaks (MSJs) only change the intercept and not the slope of the power-law relationship between number of demonstrations and undesired response likelihood.

Figure 5 from Anil et al.

Where role embeddings come in

What if there were a more robust way to indicate text origin than special-token formats? Unlike standard prompt formats, role embeddings aim to add role information at every token position.

The basic version of this idea is simply a new embedding component. Besides semantic and positional information, we also add a vector that indicates the role associated with that token. In addition, we consider a more "intrusive" variant where this information is added at multiple layers of the residual stream, aiming to make it even more salient.

We will refer to this vector addition process as "coloring"[1] in the sense of "coloring in the tokens to indicate what role they come from". This is meant to distinguish this technique from activation steering, where the intervention vector is selected from a rich space of linear semantic representations. For role embeddings, we instead use a simple and small discrete set of (usually orthogonal) "coloring" vectors that the model is trained to interpret as role signal. 
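To make the idea concrete, here is a minimal sketch of embedding coloring; the names, shapes, and random initialization of the role vectors below are ours for illustration, not the exact setup used in our experiments:

```python
import torch

d_model = 4096   # hidden size (illustrative; Llama-3-8B scale)
n_roles = 3      # e.g. system / user / assistant

# A small, fixed set of roughly orthogonal "coloring" vectors, one per role
# (in contrast to steering vectors drawn from a rich semantic space).
role_vectors = torch.nn.functional.normalize(torch.randn(n_roles, d_model), dim=-1)

def add_role_color(token_embeds: torch.Tensor, role_ids: torch.Tensor) -> torch.Tensor:
    """token_embeds: [seq, d_model] token embeddings; role_ids: [seq] role label per token."""
    return token_embeds + role_vectors[role_ids]
```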

Our experiments

We focus on the many-shot jailbreak attack testbed. Flattening the power-law slope is the signal we're particularly interested in, because standard fine-tuning approaches have not been able to achieve it.

Datasets

Our dataset consists of:

Many-shot jailbreaks (harmful and "mean" MSJs, each followed by a harmless recovery response)

Harmless conversations (regular back-and-forth conversations and a numerical sequence prediction task)

Example of tokenized harmful MSJ with recovery. The “true” assistant turn is shown in red. Note that within the user’s turn, fake tags are used for the embedded user and assistant turns.
Example of regular back-and-forth conversation. The true assistant turns are shown in red.
Example of numerical sequence prediction task.

Intervention

We assess

Baseline

As in Anil et al., we see a roughly linear trend in log-log space between the number of MSJ shots and the NLL of the jailbreak response. The NLL of the recovery responses stays roughly constant.
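(For reference, the slope and intercept we track can be estimated with a simple linear fit in log-log space; the numbers in the sketch below are invented purely for illustration.)

```python
import numpy as np

# Hypothetical measurements: mean NLL of the jailbreak response at each shot count.
shots = np.array([1, 2, 4, 8, 16, 32, 64, 128])
nll = np.array([3.1, 2.7, 2.4, 2.1, 1.9, 1.7, 1.5, 1.4])  # made-up values

# Fit NLL ≈ exp(intercept) * shots**slope via linear regression in log-log space.
slope, intercept = np.polyfit(np.log(shots), np.log(nll), deg=1)
print(f"power-law slope: {slope:.3f}, intercept: {intercept:.3f}")
```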

Inference-time only: adding a role vector to token embeddings

We add a “user” vector to the token embeddings at every user token and an “assistant” vector at every assistant token. The magnitude of the added vector is scaled to be proportional to the embedding norm at that token position (this scale factor is a hyperparameter).

As an initial attempt, we try scale factor = 1, user vector = embedding("user"), and assistant vector = embedding("assistant"), where embedding() denotes the literal embedding-matrix row for that token.
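A minimal sketch of one way to implement this inference-time intervention, assuming a Hugging Face Llama-style model; the helper below is illustrative, and the exact normalization details may differ from what we ran:

```python
import torch

def color_input_embeddings(model, tokenizer, input_ids, role_ids, scale=1.0):
    """Illustrative sketch: add a role vector to every token embedding.

    input_ids: [1, seq] token IDs
    role_ids:  [1, seq] with 0 for user tokens and 1 for assistant tokens
    """
    emb = model.get_input_embeddings()  # token-embedding matrix (nn.Embedding)

    # Role vectors taken from the embedding rows of the tokens "user" and "assistant".
    user_id = tokenizer.encode("user", add_special_tokens=False)[0]
    asst_id = tokenizer.encode("assistant", add_special_tokens=False)[0]
    role_vecs = torch.stack([emb.weight[user_id], emb.weight[asst_id]])  # [2, d_model]

    token_embeds = emb(input_ids)        # [1, seq, d_model]
    colors = role_vecs[role_ids]         # [1, seq, d_model]
    colors = colors / colors.norm(dim=-1, keepdim=True)

    # Scale the added vector to be proportional to the embedding norm at each position.
    norms = token_embeds.norm(dim=-1, keepdim=True)
    return token_embeds + scale * norms * colors

# Usage (hypothetical): colored = color_input_embeddings(model, tokenizer, input_ids, role_ids)
#                       outputs = model(inputs_embeds=colored)
```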

These are the harmful and mean MSJ jailbreak slopes before and after the intervention without any fine-tuning:

Regular conversations and MSJ recoveries:

Using fine-tuning to improve performance

As we can see above, the interventions:

Next, we try fine-tuning (with LoRA) on the training set under the coloring intervention, and then repeat the evals above. As a control, we also fine-tune on the same training data without the coloring intervention.
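For reference, a minimal sketch of the LoRA setup using the peft library; we use rank 8 (see Next steps), while the checkpoint, alpha, and target modules shown here are placeholders rather than our exact configuration:

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Illustrative checkpoint; the post uses Llama 3 but does not specify the exact model size.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

# Rank-8 LoRA as mentioned under Next steps; alpha and target modules are placeholders.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

# Fine-tuning then proceeds as usual, except that the coloring intervention is
# applied on every forward pass, so the LoRA weights learn to use the role signal.
```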

Key fine-tuning results

We find that given fine-tuning, we can preserve the benefits of the pure inference-time intervention without incurring any of the costs. 

Both embedding coloring and residual stream coloring help flatten the MSJ power law more than control fine-tuning. Residual stream coloring is more effective than embedding coloring.

Residual-stream coloring

Intervention:

Mathematically:

Where:
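A rough sketch of how residual-stream coloring can be implemented with PyTorch forward pre-hooks on a Hugging Face Llama model; the choice of layers, scaling, and normalization here are illustrative assumptions rather than our exact setup:

```python
import torch

def make_residual_coloring_hook(role_vecs, role_ids, scale=1.0):
    """Forward pre-hook: add a role vector to the hidden states entering a decoder layer.

    role_vecs: [n_roles, d_model] coloring vectors; role_ids: [1, seq] role label per token.
    """
    def hook(module, args, kwargs):
        # In current HF Llama implementations, hidden_states is the first positional argument.
        hidden = args[0] if args else kwargs["hidden_states"]      # [1, seq, d_model]
        colors = role_vecs[role_ids].to(hidden.dtype)
        colors = colors / colors.norm(dim=-1, keepdim=True)
        colored = hidden + scale * hidden.norm(dim=-1, keepdim=True) * colors
        if args:
            return ((colored,) + args[1:], kwargs)
        kwargs["hidden_states"] = colored
        return (args, kwargs)
    return hook

# Hypothetical usage: color the residual stream at every fourth decoder layer.
# handles = [layer.register_forward_pre_hook(
#                make_residual_coloring_hook(role_vecs, role_ids), with_kwargs=True)
#            for layer in model.model.layers[::4]]
```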

This intervention successfully reduces the MSJ slope (and raises the absolute NLL values, as expected). In contrast, control fine-tuning sometimes makes the MSJ performance worse (in the case of the mean MSJs[2]).

By including regular training data, we are able to preserve performance compared to the baseline. In fact, NLLs actually go down on harmless responses (albeit less than with the control fine-tuning), which can be explained by fitting to the idiosyncrasies of the fine-tuning data distribution. However, for the numerical sequence prediction task, we see worse performance compared to the control FT.

Embedding coloring

Intervention:

Embedding-only coloring is less effective than the residual-stream intervention, but is also able to reduce the slopes somewhat:

However, it also has less of an effect on the harmless numerical sequence prediction task:

As expected, NLLs on recovery responses go down:

(For both role-embedding interventions, we also qualitatively assess free-text samples from the model and don't find a degradation in general quality.)

Next steps

Although our implementation has some undesired side effects (the NLL slopes for the numerical sequence prediction task are also flatter compared to the control fine-tuning baseline), we think this could be because we're only introducing the intervention after the bulk of post-training is over. With a small dataset of 2000 samples and rank-8 LoRA fine-tuning, we are using far less compute than Llama's full post-training run, so it's hard to achieve perfect generalization. In production, we would propose using role embeddings from the start of instruction fine-tuning, so that the model learns to process the role vectors from the beginning, plausibly resulting in better generalization across tasks.

We plan to test our interventions on a broader range of jailbreak and general-capability evaluations and perform more comprehensive hyperparameter sweeps to determine what variant of role embeddings has the best cost/benefit trade-off. We hope embedding-only coloring can be improved via some tweaks to get it closer to the effect we're seeing with the residual-stream coloring.

There are a number of subtle implementation details when testing variants of role embeddings, many of which make some difference to the results (although we consistently observe the directional effect that role embeddings increase robustness to MSJs). These include decisions such as:

Acknowledgements

This research was performed as part of the SPAR program. The main ideas were initially proposed by Cem Anil.

  1. ^

    Credit to Cem Anil for suggesting this term.

  2. ^

    We think this is because a significant proportion of the training data was generated by Claude, causing the model to fit to Claude's writing style. The mean MSJ responses are also generated by Claude and so probably share similar surface-level characteristics.
