Decision Transformer Interpretability

post by Joseph Bloom (Jbloom), Paul Colognese (paul-colognese) · 2023-02-06T07:29:01.917Z · LW · GW · 13 comments

Contents

  Key Claims
  Introduction
    Previous Work
  Methods
    Environment - RL Environments. 
      Dynamic Obstacles Environment
    Decision Transformer Training
    Attribution and Preference Directions
    An Interface for live Circuit Analysis
  Results
    Training a small Decision Transformer
    Black Box Model Characterisation
    Decision Transformer Architecture
    Circuit Results Summary
        If you have no time, skip ahead to the obstacle avoidance circuit. It’s the most interesting section.
    Attention and Head Ablation
      Ablation of Head 0 makes the DT reach the Goal at RTG = -1 from the top
      Ablation of Head 1 makes the DT hit obstacles at RTG = 0.9 
      Ablation of MLP makes the DT approach the Goal at RTG = -1 from the left
    Analysing Embeddings
      Analysing the RTG and Time Tokens
        Reward-to-Go
        Time
      Analysing the State Embeddings
    Explaining Obstacle Avoidance at positive RTG using QK and OV circuits
  Discussion
    Analysis takeaways
    Limitations
    Alignment Relevance
    Capabilities Externalities Assessment
    Future Directions
  Meta
    Acknowledgements
    Author Contributions
    Feedback
  Glossary

TLDR: We analyse how a small Decision Transformer learns to simulate agents on a grid world task, providing evidence that it is possible to do circuit analysis on small models which simulate goal-directedness. We think Decision Transformers are worth exploring further and may provide opportunities to explore many alignment-relevant deep learning phenomena in game-like contexts. 

Link to the GitHub Repository. Link to the Analysis App. I highly recommend using the app if you have experience with mechanistic interpretability. All of the mechanistic analysis should be reproducible via the app. 

Key Claims

If you are short on time, I recommend reading:

I would welcome assistance with:

I’m also happy to collaborate on related projects. 

Introduction

For my ARENA Capstone project, I (Joseph) started working on decision transformer interpretability at the suggestion of Paul Colognese. Decision transformers can solve reinforcement learning tasks when conditioned on generating high rewards via the specified “Reward-to-Go” (RTG). However, they can also generate agents of varying quality based on the RTG, making them simultaneously simulators [LW · GW], small transformers and RL agents. As such, it seems possible that identifying and understanding circuits in decision transformers would not only be interesting as an extension of current mechanistic interpretability research but possibly lead to alignment-relevant insights. 

Previous Work

The most important background for this post is:

Figure 1: Decision Transformer Architecture Diagram from the original Decision Transformers Paper. Rather than modelling reward as resulting from an action, trajectories are labelled with the Reward-to-go (RTG), enabling information about future rewards to condition the agent's learned behaviour. RTG can be varied to change how well the agent performs on a task. 
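As a minimal sketch of how RTG labels can be derived from a recorded trajectory, the snippet below sums the remaining rewards at each step. The function name `returns_to_go` is illustrative rather than taken from the project's code base, and no discounting is assumed.

```python
def returns_to_go(rewards):
    """Return, for each timestep, the sum of rewards from that step to the end."""
    rtg = []
    running = 0.0
    # Walk backwards so each step adds its own reward to the future total.
    for r in reversed(rewards):
        running += r
        rtg.append(running)
    rtg.reverse()
    return rtg


# Sparse-reward episode: nothing until the final step, where the agent hits an obstacle.
sparse_episode = [0.0, 0.0, 0.0, -1.0]
print(returns_to_go(sparse_episode))  # [-1.0, -1.0, -1.0, -1.0]
```

With a single terminal reward, the RTG is constant over the whole episode, so the RTG token effectively labels the quality of the trajectory it belongs to.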

Methods

Environment - RL Environments. 

GridWorld Environments are loaded from the Minigrid python package. Such environments have a discrete state space and action space where an embedded agent can turn right or left, move forward, pick up and drop objects such as keys, and use objects such as a key to open a door. Each gridworld involves a different task/environment, and most involve very sparse rewards. The observations can be rendered full or partial (from the agent’s point of view). In this work, we use partial views. 

Dynamic Obstacles Environment

This environment involves no keys or doors. An agent must proceed to a green goal square but will receive a -1 reward if it walks into a wall or obstacle (ending the episode). The obstacles move randomly every timestep, regardless of what the agent does. On reaching the goal, the agent receives a reward of 1 (or a time-discounted reward with PPO). The goal square does not move, and the space is always a fixed-size grid.
 

Figure 2: The Dynamic Obstacle task. The agent must avoid obstacles that move randomly and reach the green objective.  The agent can only turn left or right or move forward. In both environments, the agent can see the highlighted region (“partial view”).

Decision Transformer Training

Trajectories are generated with an implementation of the PPO algorithm developed as part of the ARENA Coursework. We modified code from the original decision transformer paper for the model architecture to work with arbitrary mini-grid environments and to use TransformerLens. We flatten observations and use a linear projection to create the state token. We remove tanh activations (used in the original DT architecture in state/RTG/action encodings) and LayerNorm everywhere to have more linearity to facilitate analysis.
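To make the state tokenisation concrete, here is a rough sketch of flattening a 7 by 7 by 3 partial view and applying a linear projection. The helper names, the random weights, the absence of a bias term and the small `D_MODEL` are all assumptions for illustration; the real model uses a hidden dimension of 128 and learned weights.

```python
import random


def flatten(obs):
    """Flatten a nested [row][col][channel] observation into a flat list."""
    return [float(v) for row in obs for cell in row for v in cell]


def linear(x, weight):
    """Bias-free linear map; weight has shape [d_model][len(x)]."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


random.seed(0)
VIEW, CHANNELS, D_MODEL = 7, 3, 8  # D_MODEL is 128 in the post; 8 keeps the demo small

# Hypothetical observation: every cell is "empty" (object 1, colour 0, state 0).
obs = [[[1, 0, 0] for _ in range(VIEW)] for _ in range(VIEW)]

# Illustrative random projection weights (the real ones are learned).
W_state = [
    [random.gauss(0, 0.02) for _ in range(VIEW * VIEW * CHANNELS)]
    for _ in range(D_MODEL)
]

state_token = linear(flatten(obs), W_state)
print(len(flatten(obs)), len(state_token))  # 147 8
```

Because the map is linear, each of the 147 input positions contributes independently to the state token, which is what makes the per-position weight analysis later in the post possible.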

Training of the decision transformer is performed as described in the original decision transformer paper, although we have not implemented learning rate schedules. 

Attribution and Preference Directions

We can perform logit attributions by decomposing the final logits for [forward, left, and right] actions into a sum of contributions for each component (attention heads, MLPs, input tokens). This analysis can be conceived of as a single-step version of the integrated gradients method described in Appendix B of  Understanding RL Vision.  
 

However, since softmax is invariant under translation, logits aren't inherently meaningful. It is useful to construct the notion of a “preference direction”: the attribution to one action, such as forward, minus the attribution to another, such as right. Figure 3 shows the preference direction decomposition. In the Dynamic Obstacles task studied, the only possible actions are forward/right/left, meaning that a pairwise preference loses some information (one action is left out). Still, we find that right and left logits correlate heavily, making a pairwise analysis useful.
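A quick numerical check of the translation-invariance point, using made-up logits: adding the same constant to every logit leaves the action probabilities unchanged, while the difference between two logits is preserved.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 0.5, -1.0]              # hypothetical [forward, left, right] logits
shifted = [l + 10.0 for l in logits]   # same constant added to every logit

p, p_shifted = softmax(logits), softmax(shifted)

# Probabilities are identical, so absolute logit values carry no meaning...
same_probs = all(abs(a - b) < 1e-12 for a, b in zip(p, p_shifted))

# ...but the forward-minus-right difference survives the shift.
diff = logits[0] - logits[2]
diff_shifted = shifted[0] - shifted[2]
print(same_probs, diff, diff_shifted)
```

This is why the analysis works with differences between action logits rather than with any single logit.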

Figure 3:  Attribution Decomposition diagram. Top: Residual stream activation is projected into logits/probabilities.  Bottom: Preference directions can be calculated from logit attribution directions, and the contribution of each transformer component to this direction can be estimated via a dot product.
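The decomposition in Figure 3 can be sketched with toy numbers. Everything below (the 4-dimensional residual stream, the component outputs and the unembedding vectors) is invented for illustration; only the procedure mirrors the description above, and it relies on the model being linear from the residual stream to the logits.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


# Hypothetical outputs written to a 4-dimensional residual stream.
component_outputs = {
    "state_embedding": [0.5, -0.2, 0.1, 0.0],
    "head_0": [0.1, 0.3, 0.0, -0.1],
    "head_1": [-0.4, 0.0, 0.2, 0.1],
    "mlp": [0.0, 0.1, -0.3, 0.2],
}

# Hypothetical unembedding vectors, one per action.
unembed = {
    "forward": [1.0, 0.0, 0.5, 0.0],
    "left": [0.0, 1.0, 0.0, 0.5],
    "right": [0.0, 0.5, 0.0, 1.0],
}

# Preference direction: forward minus right.
direction = [f - r for f, r in zip(unembed["forward"], unembed["right"])]

# Each component's contribution is its output projected onto that direction.
contributions = {name: dot(out, direction) for name, out in component_outputs.items()}

# The residual stream is the sum of component outputs, so the contributions
# add up to the forward-minus-right logit difference.
residual = [sum(vals) for vals in zip(*component_outputs.values())]
logit_diff = dot(residual, unembed["forward"]) - dot(residual, unembed["right"])

for name, value in contributions.items():
    print(f"{name:16s} {value:+.3f}")
print(f"{'total':16s} {sum(contributions.values()):+.3f} (logit diff {logit_diff:+.3f})")
```

The per-component numbers are the kind of values shown in the bar charts later in the post.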

An Interface for live Circuit Analysis

Inspired by OpenAI’s approach in the Understanding RL Vision project, I built an app in Streamlit that lets me generate trajectories by playing mini-grid games myself whilst observing analysis of model activations and preferences over actions. 

Results

Training a small Decision Transformer

We first trained an RL agent via the proximal policy optimisation (PPO) algorithm, which solved the task well and generated trajectories of varying rewards. Ideally, we might have written specific code to sample PPO agents of varying quality; however, it was much quicker to simply store the trajectories generated during PPO training. 

Figure 4: The joint distribution of Duration and Reward in the training trajectories (scatter + marginal histograms). Positive Reward is shown in red, 0 reward in green and -1 reward in blue. Trajectories are truncated at 300 steps receiving 0 reward, except that a small number of trajectories are truncated earlier as a quirk of the trajectory generation process. 

With these trajectories, we trained decision transformers of varying sizes (0, 1, 2 layers) as well as width/hidden dimension size (32, 64, 128), number of heads (2, 4) and context length (1, 3 and 9 timesteps worth of tokens). We evaluated the decision transformer during training by placing it in the dynamic obstacle environment and observing performance when conditioned on high reward.
 

We found that the transformer that seemed to robustly achieve good performance was a 1 layer transformer with a hidden dimension of 128, two heads and a single time step.
 

To characterise the calibration of this transformer (confirming that it was a good simulator of trajectories of varying quality and not simply a good agent), we generated calibration curves showing the average performance achieved for a range of Reward-to-Go values (RTGs, the target reward the agent is conditioned to achieve). 
 

Figure 5: Calibration Curve for Dynamic Obstacle Agent performance. The dotted teal line shows hypothetical perfect calibration. The blue line shows the actual calibration. The shaded region includes 95% of the reward distribution at each RTG. 500 simulations were performed at each initial RTG value. 
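The procedure behind a calibration curve like Figure 5 can be sketched as follows. `simulate_episode` is a purely hypothetical stand-in for rolling out the decision transformer in the environment (here it just returns the target plus noise), and the percentile indexing is one simple way to get a 95% band; neither is taken from the project's code.

```python
import random
import statistics


def simulate_episode(rtg, rng):
    """Hypothetical stand-in for one rollout conditioned on the given RTG."""
    # Toy model: achieved reward is the target plus noise, clipped to [-1, 1].
    return max(-1.0, min(1.0, rtg + rng.gauss(0, 0.2)))


def calibration_point(rtg, n_episodes, rng):
    """Mean achieved reward and an empirical 95% band for one initial RTG."""
    rewards = sorted(simulate_episode(rtg, rng) for _ in range(n_episodes))
    lo = rewards[int(0.025 * (n_episodes - 1))]
    hi = rewards[int(0.975 * (n_episodes - 1))]
    return statistics.mean(rewards), lo, hi


rng = random.Random(0)
# 500 simulations per initial RTG value, as in Figure 5.
curve = {rtg: calibration_point(rtg, 500, rng) for rtg in [-1.0, -0.5, 0.0, 0.5, 0.9]}

for rtg, (mean, lo, hi) in curve.items():
    print(f"RTG={rtg:+.1f}  mean={mean:+.2f}  95% band=[{lo:+.2f}, {hi:+.2f}]")
```

A perfectly calibrated simulator would have the mean equal to the RTG at every point; departures from that diagonal are what the s-shaped curve shows.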

Now Figure 5 appears to show a highly miscalibrated simulator, at least insofar as the curve is s-shaped rather than linear. I would argue, however, that it is certainly good enough for analysis for the following reasons. 

Having trained a decision transformer and verified that it can generate trajectories consistent with various levels of reward, we proceeded to analyse the decision transformer. 
 

Black Box Model Characterisation

To achieve calibrated performance at RTG = -1, RTG = 0 and RTG ~ 1, as shown above, the model must robustly learn 3 different behaviours that activate at different levels of RTG. Table 1 lists these and their corresponding RTGs. RTG = 1 is impossible to achieve with a time-discounted reward, so we use RTG = 0.90 instead.

RTG                   -1      0       0.90
Wall Avoidance        yes*    yes     yes
Obstacle Avoidance    no      yes     yes
Goal Avoidance        yes     yes     no

Table 1: RTG Modulated Behaviours.  Wall/Obstacle/Goal Avoidance means not walking into those objects. The agent must not reach the goal when RTG is negative nor hit walls or obstacles when it is positive. At 0 RTG, it must avoid walls, obstacles and the goal. *Notably, the agent could learn to achieve -1 RTG by walking into walls but appears not to do so. 

Decision Transformer Architecture

The model analysed has a single layer, two attention heads and a context window that only includes the current state and the RTG. Because of the lack of layer norms or non-linearities, which we removed, we can perform an analysis where we decompose the contributions of each component to the preferred direction. Figure 6 describes the architecture.

Figure 6: Decision Transformer architecture. Tokens are initially concatenated, and the time encoding is added to each. In the transformer, mostly linear operations move information from input to output. Attention Heads and the MLP are conceived of as reading from and writing to subspaces of the residual stream. Finally, we take the state embedding and project it into Action Logits. 

Worth noting:

Circuit Results Summary

The analysis in this post involves playing the game, doing ablations, and visualising attributions and weights for different components to try to work out what's happening. I don’t consider any of this rigorous evidence, but rather a proof of concept (and practice!).

That being said, the transformer appears to learn various trigrams (RTG, f(State) -> Action) combinations, where f(state) includes things like whether the object in front of the agent is the goal, an obstacle or a wall. There are also broader biases that I haven’t fully understood, such as the tendency to manoeuvre the grid clockwise. 

The extent to which these trigrams can be localised to specific components varies, but broadly speaking:

The rest of the analysis shows:

If you have no time, skip ahead to the obstacle avoidance circuit. It’s the most interesting section.

Figure 7: Decision Transformer Architecture. Coloured diamonds indicate component responsibilities. For the attention heads and MLP layer, ablation of that component leads to aberrations of each behaviour. Amplification of previous head signals is indicated with a diamond in an asterisk. Goal Avoidance in the MLP is from the left. Goal Avoidance in Attn Head 0 is from the top.

Attention and Head Ablation

Inspecting the dot products of components of the residual stream with the forward/right direction can hint at the responsibilities of model components, but performing ablations rapidly demonstrates the value of different components. Looking at each of the behaviours defined above, we ablated Head 0, Head 1 and the MLP one by one and found the following responsibilities, summarised in Figure 7. We used mean ablations for this write-up but found similar results for zero ablation. 
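For readers unfamiliar with the two ablation styles, here is a toy sketch. The component outputs and the reference samples are invented, and the helper names are hypothetical; in the real analysis the replacement would be applied inside the model's forward pass (so that downstream components such as the MLP also see the change) rather than on a dictionary of vectors.

```python
def mean_vector(vectors):
    """Element-wise mean of a list of equal-length vectors."""
    n = len(vectors)
    return [sum(col) / n for col in zip(*vectors)]


def ablate(component_outputs, name, replacement):
    """Return a copy of the per-component outputs with one component replaced."""
    patched = dict(component_outputs)
    patched[name] = replacement
    return patched


def residual(component_outputs):
    """Residual stream as the sum of all component outputs."""
    return [sum(vals) for vals in zip(*component_outputs.values())]


# Hypothetical outputs of Head 1 recorded over three reference inputs.
head_1_samples = [[0.2, -0.4], [0.0, -0.6], [0.4, -0.2]]

# Hypothetical component outputs on the input being studied.
current = {"state_embedding": [1.0, 0.5], "head_0": [0.1, 0.1], "head_1": [0.3, -0.9]}

mean_ablated = ablate(current, "head_1", mean_vector(head_1_samples))  # mean ablation
zero_ablated = ablate(current, "head_1", [0.0, 0.0])                   # zero ablation

print(residual(current), residual(mean_ablated), residual(zero_ablated))
```

Mean ablation keeps the component's typical output while removing its input-specific signal, which is why it is often preferred over zeroing the component out entirely.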

To substantiate these results, I’ll share a few example cases. Wall avoidance and Goal seeking are associated with the initial state embedding. This can be seen by inspecting the dot product of this component in the forward/right direction, which is clearly significant when not facing walls (Figure 9) and facing the goal (Figures 8, 10). Wall following in the clockwise direction appears to be contributed to by both heads and the MLP. 

Ablation of Head 0 makes the DT reach the Goal at RTG = -1 from the top

Since an agent which reaches the goal receives a positive reward, the transformer refuses to enter the goal in cases where RTG = -1. However, ablation of Head 0 both directly reduces the projections against the forward/right direction and leads the MLP to contribute less to this negative direction (i.e. right over forward).  

Figure 8: Agent facing Goal Square at RTG = -1 and bar chart of component contributions to the forward/right direction before and after mean ablation of Head 0. The agent classifies forward with probability ~0 without ablation and 0.48 with ablation of Head 0 (as compared to 0.24 right and 0.28 left).  

Ablation of Head 1 makes the DT hit obstacles at RTG = 0.9 

Since an agent which walks into an obstacle receives a negative reward, the transformer refuses to walk into obstacles on the way to the goal in cases where RTG is above roughly 0.2-0.3. However, ablation of Head 1 directly reduces the projection against the forward/right direction and leads the MLP to contribute less to the opposing direction.  

Figure 9: Agent facing Obstacle at RTG = 0.9 and bar chart of component contributions to the forward/right direction with and without ablation to mean of Head 1. The agent classifies forward with probability ~0 without ablation and 0.9 with ablation of Head 1.  

Ablation of MLP makes the DT approach the Goal at RTG = -1 from the left

Since an agent which reaches the goal receives a positive reward, the transformer refuses to enter the goal in cases where RTG = -1. We’ve seen that ablating Head 0 can encourage entering the goal from the top. However, it appears that Head 0 does not likewise inhibit moving forward into the goal from the left. In fact, it is ablation of the MLP that encourages entering the goal from the left. 

Figure 10: Agent facing the Goal from the left at RTG = -1 and bar chart of component contributions to the forward/right direction with and without mean ablation of the MLP. The agent classifies forward with probability ~0 without ablation and 0.88 with ablation of the MLP.  

Analysing Embeddings

Having found concrete evidence showing us where various behaviours originate in our model, we can start going through the model architecture to see how the model reasons about each of the inputs. 

Analysing the RTG and Time Tokens

Reward-to-Go

The Reward-to-Go token is a linear projection of one number, the RTG. For the RTG to affect the decision of the transformer, the “information” in this token must be moved to the state token, which will be projected onto the action prediction. 

Nevertheless, it appears the network learnt to project almost 1:1 in the forward/right direction, as the dot product of the RTG embedding with the forward/right direction is ~1.13 (this means that if the information were moved to the state token, it would contribute to forward/right). The dot product of the right/left direction with the RTG embedding is ~ -0.16. This can be interpreted as the RTG embedding encouraging forward movement and probably being fairly ambivalent with respect to right/left when this embedding is attended to in the attention heads later in the model. 
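As a back-of-envelope illustration of what these dot products imply, the sketch below scales them by the RTG value. This assumes the RTG embedding is a bias-free linear map of the scalar RTG and that the quoted dot products are per unit of RTG; it also ignores attention weights and the OV transformation, so the numbers are only indicative.

```python
RTG_FORWARD_RIGHT = 1.13   # dot product with forward/right quoted in the post
RTG_RIGHT_LEFT = -0.16     # dot product with right/left quoted in the post


def rtg_contribution(rtg, direction_dot):
    """Contribution of the RTG embedding to a preference direction, by linearity."""
    return rtg * direction_dot


for rtg in (-1.0, 0.5, 0.9):
    fr = rtg_contribution(rtg, RTG_FORWARD_RIGHT)
    rl = rtg_contribution(rtg, RTG_RIGHT_LEFT)
    print(f"RTG={rtg:+.1f}: forward/right={fr:+.3f}, right/left={rl:+.3f}")
```

Under these assumptions, a negative RTG pushes against forward about as strongly as a high positive RTG pushes towards it, while the right/left effect stays small throughout.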

Time

The Time embedding is a learned embedding that essentially functions as a look-up table of vectors indexed by the integer time value. Unlike a positional embedding in a regular transformer, which is of the same form, this embedding is added to the group of 3 tokens that make up one timestep: an RTG, a state, and an action. Here, the time embedding is added to the state token (and the RTG token).
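A small sketch of the look-up behaviour described above. The table size (assumed from the 300-step truncation of trajectories), the random values, the tiny `D_MODEL` and the token vectors are all hypothetical; the point is only that every token in a timestep receives the same time vector.

```python
import random

random.seed(0)
D_MODEL, MAX_TIMESTEP = 4, 300  # D_MODEL is 128 in the post; table size is an assumption

# Look-up table: one vector per integer timestep (random here, learned in practice).
time_table = [
    [random.gauss(0, 0.02) for _ in range(D_MODEL)]
    for _ in range(MAX_TIMESTEP + 1)
]


def add_time(token, t):
    """Add the time embedding for timestep t to a single token embedding."""
    return [x + e for x, e in zip(token, time_table[t])]


rtg_token = [0.9, 0.0, 0.0, 0.0]    # hypothetical RTG embedding
state_token = [0.1, 0.2, 0.3, 0.4]  # hypothetical state embedding

# Both tokens in the same timestep receive the same time vector.
t = 7
rtg_in, state_in = add_time(rtg_token, t), add_time(state_token, t)
print(rtg_in, state_in)
```

Because the vectors are looked up rather than computed from `t`, nothing forces their projection onto a preference direction to vary smoothly or linearly with time.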

Figure 11: Dot product of Time Embedding and Forward/Right direction as Time Changes. The dot product with the forward/right direction of the time embeddings is not linear as the time embedding is learned. The slight positive slope might be meaningful, but this is unclear. 

Analysing the State Embeddings

Minigrid uses a 3-channel encoding scheme, providing the agent with information about the objects, colours and states in a 7 by 7 grid around itself (when you use a partial observation view, which we do in this project). Figure 12 is useful for understanding the Minigrid observation encoding schema.
 

Figure 12: State Encoding Diagram showing how the agents’ partial view (observation) of the state is decomposed into the Minigrid schema and flattened and how weights from a specific position can be projected into the preference direction.  Table 3 below shows the colour/object numeric values.

Figure 13: State view and observation channels from the partial view. LHS: Full environment render (the entire state). RHS: the object embedding, the colour embedding and the state embedding. The agent implicitly occupies the position at row 6, column 3 facing upward in the partial view. The state embedding is empty since we lack any objects in this environment with variable state (like doors). In the object encoding, Wall is 2, Empty is 1, balls (obstacles) are 6, and the Goal is 8. In the colour encoding, Walls are grey/5, Balls are blue/2, empty squares are 0, and the goal is green/1. 

Since I hadn’t thought about it until a decent way through my original analysis, I didn’t one-hot encode the state. This introduces a significant inductive bias through spurious index ordinality. For example, in the object channel, the goal is 8, and the obstacle (ball) is 6. So an obstacle looks like ¾ of the goal, which is really dumb. 
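To illustrate the ordinality problem, the sketch below compares the raw index encoding with a one-hot encoding. The object indices are those quoted in this post; the vocabulary size of 11 is my assumption about Minigrid's object index table, and the helper names are illustrative.

```python
OBJECT_INDEX = {"empty": 1, "wall": 2, "ball": 6, "goal": 8}  # values quoted in the post
NUM_OBJECT_TYPES = 11  # assumed size of Minigrid's object vocabulary


def one_hot(index, size):
    """Vector of zeros with a single 1.0 at the given index."""
    vec = [0.0] * size
    vec[index] = 1.0
    return vec


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


# With raw indices, a linear layer sees an obstacle as a scaled-down goal.
ordinal_ratio = OBJECT_INDEX["ball"] / OBJECT_INDEX["goal"]

# With one-hot vectors, the two objects are orthogonal inputs.
ball_vec = one_hot(OBJECT_INDEX["ball"], NUM_OBJECT_TYPES)
goal_vec = one_hot(OBJECT_INDEX["goal"], NUM_OBJECT_TYPES)
overlap = dot(ball_vec, goal_vec)

print(ordinal_ratio, overlap)  # 0.75 0.0
```

With one-hot inputs, a linear embedding can assign obstacles and goals entirely independent weights instead of sharing a single weight scaled by the index.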

Figure 13 visualises each channel of the state/observation, with colour representing the index. In the object view, obstacles are more distinguishable from empty space than in the colour channel. The state channel provides no information since it’s used to show whether doors/boxes are locked, unlocked or open, which isn’t relevant to the Dynamic Obstacles task. 

This inductive bias led to a weird behaviour where agents would successfully avoid the obstacles but also try to avoid the goal square. This was fixed when I started over-sampling the final step of trajectories, a hack I think wouldn’t be necessary for the one-hot encoded version of the model but is also a generally good trick for getting agents to train faster on mini-grid tasks. 
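The post does not spell out how the over-sampling was implemented, so the following is only one plausible sketch: when drawing training steps from a trajectory, give the final step a larger sampling weight. The weight of 5.0 and the helper name are hypothetical.

```python
import random


def sampling_weights(trajectory_length, final_step_weight):
    """Weight 1 for ordinary steps and a larger weight for the final step."""
    return [1.0] * (trajectory_length - 1) + [final_step_weight]


rng = random.Random(0)
length = 20
weights = sampling_weights(trajectory_length=length, final_step_weight=5.0)

# Draw many step indices and measure how often the final step is picked.
steps = rng.choices(range(length), weights=weights, k=10_000)
final_fraction = steps.count(length - 1) / len(steps)

uniform_fraction = 1 / length
print(f"final step sampled {final_fraction:.3f} of the time vs {uniform_fraction:.3f} uniformly")
```

Since the only non-zero reward arrives on the final step, seeing that step more often gives the model more examples of the state-action pairs that actually determine the outcome.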

It’s possible to look at the activation at each position in each channel projected into the forward/right direction since each position in each channel projects into the residual stream (128 dimensions) which is itself eventually projected into the action logits (see Figure 14). 

Figure 14: Dot product of the linear embedding weights at each position in the observation encoding with the forward/right direction, shown from left to right: the object embedding, the colour embedding and the state embedding. Colour scales are centred on 0 but the magnitude varies. State magnitudes are tiny (e-38) since the state input is empty in this environment, and weights are sent to 0 by regularisation.

There are a few interpretations worth taking away from here:

  1. This may look fairly noisy because these neurons are doing lots of other things whilst also projecting directly onto the preference direction. 
  2. It’s also possible that collectively they project onto the preference direction quite cohesively, but there are strong correlations between features leading to hard-to-interpret or dispersed processing (diversity hypothesis type problem). 

Nevertheless, I do think there is some valuable information in Figure 14. We know from our observations above that the state embedding mostly works as a wall detector.  

The square in front of the agent in the colour embedding has a large negative value (-0.742). Empty Squares are 0 in colour, obstacles are 2, and walls are 5. This square thus will project weakly onto “don’t walk forward when there is an obstacle in front” and strongly onto “don’t walk forward when there is a wall in front”. 

The square in front of the agent in the object embedding (left in Figure 14) has a slightly positive value (0.27), so it encourages walking into walls and obstacles but encourages walking into the goal square more than other objects.

This is cool because it suggests that the inductive bias associated with my idiotically not one-hot-encoding the vision is directly why the model avoids walls much more strongly than it avoids obstacles. It also explains why early versions of the model avoided the goal square. 


I’ll make a table to show how the two neurons (object and colour embedding) for the square directly in front of the agent combine to respond to objects/colours. 
 

Object/Colour   Object Value   Colour Value   Positional Contribution to Forward/Right
Empty Space     1              0              0.27 * 1 = 0.27
Wall            2              5              0.27 * 2 - 5 * 0.742 = -3.17
Obstacle        6              2              0.27 * 6 - 2 * 0.742 = 0.136
Goal Square     8              1              0.27 * 8 - 1 * 0.742 = 1.42

Table 3: Back of Envelope Calculation for Object and Colour Embeddings activations dot product with forward/right direction for the position directly in front of the agent. The results suggest the embedding functions as a strong wall avoider and a weak goal square seeker. 
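The arithmetic in Table 3 can be reproduced in a few lines. The two weights and the object/colour values are the ones quoted in this section; the variable and function names are just for illustration.

```python
W_OBJECT = 0.27    # forward/right weight, object channel, square in front of the agent
W_COLOUR = -0.742  # forward/right weight, colour channel, same square

CELLS = {  # (object value, colour value), as in Table 3
    "empty": (1, 0),
    "wall": (2, 5),
    "obstacle": (6, 2),
    "goal": (8, 1),
}


def forward_right_contribution(obj, colour):
    """Contribution of the square in front of the agent to forward/right."""
    return W_OBJECT * obj + W_COLOUR * colour


contributions = {name: forward_right_contribution(*vals) for name, vals in CELLS.items()}

for name, value in contributions.items():
    print(f"{name:9s} {value:+.3f}")
```

Only the wall produces a large negative value, which matches the reading of this part of the embedding as mainly a wall detector.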


Table 3 provides evidence for how the state embedding detects walls directly in front of the agent. However, there are lots of other neurons projecting into the forward/right direction in the state embedding, as shown in Figure 15.

Figure 15: Histograms of weight projection onto forward/right in objects, colour and state embeddings. Note that there are outliers, and there are enough values of large enough magnitude to seriously affect the decision projection of the state embedding. 

Let’s summarise:

  1. You might have expected that the state embedding wouldn’t project directly in any preference directions. In practice, sometimes we see it does, suggesting that besides likely encoding features, it is directly informing the model's output (which we saw in the residual stream dot product contributions earlier). 
  2. Using attribution, we can see that the model uses a combination of colour and object channels to detect walls directly in front of the agent and use this to inhibit forward motion. 
  3. Analysis of the state-encoded positions does not suggest information is only being read from the square in front of the agent, but rather from many parts of the view, suggesting the decision transformer may be using other information, possibly such as regions of walls vs empty space.

Explaining Obstacle Avoidance at positive RTG using QK and OV circuits

To explain how our decision transformer avoids obstacles at positive RTG, we can reference the QK and OV circuits, the encoding scheme, and the RTG attribution. Briefly, the QK (Query-Key) circuit controls which token the head attends to, and the OV (Output-Value) circuit determines how attending to each token affects the logits. More details are here.
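For readers who want the QK/OV split spelled out, here is a toy single-head computation with two source tokens (RTG and state) and the state token as the destination. The 2-dimensional vectors and identity weight matrices are invented, biases are omitted, and none of this uses the TransformerLens API; it is only meant to show which weights decide where the head looks and which decide what it writes.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def matvec(m, v):
    return [dot(row, v) for row in m]


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def head_output(dest, sources, W_Q, W_K, W_V, W_O):
    """Attention pattern and output of one head for a single destination token."""
    q = matvec(W_Q, dest)
    keys = [matvec(W_K, s) for s in sources]
    scale = math.sqrt(len(q))
    # QK circuit: query-key scores decide how much each source is attended to.
    pattern = softmax([dot(q, k) / scale for k in keys])
    # OV circuit: what each source would write to the residual stream if attended to.
    writes = [matvec(W_O, matvec(W_V, s)) for s in sources]
    out = [sum(p * w[i] for p, w in zip(pattern, writes)) for i in range(len(writes[0]))]
    return pattern, out


identity = [[1.0, 0.0], [0.0, 1.0]]
rtg_token = [1.0, 0.0]    # hypothetical RTG embedding
state_token = [0.0, 1.0]  # hypothetical state embedding

pattern, out = head_output(state_token, [rtg_token, state_token],
                           identity, identity, identity, identity)
print("attention to [RTG, state]:", pattern)
print("head output:", out)
```

Changing the query/key weights changes only `pattern`, while changing the value/output weights changes only `writes`, which is why the two circuits can be analysed separately.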

First, we need the circuit to activate at a higher RTG. The circuit needs to inhibit moving forward (which we’ll approximate roughly with the forward/right direction), so it will want to attend to the state (not the RTG, since we know it will project positively into forward/right direction and the forward logit). 

Attention to the RTG will be high where the query and key vectors match (the key/source is the RTG token and the query/destination token is the state). The QK circuit visualisation in Figure 16 shows how attention to the RTG is inhibited by anything that isn’t an empty square in front of the agent. 

Figure 16: Head 1 QK Circuit Visualization across the state embedding channels: object, colour and state. Colour ranges vary. The colour scheme is Red to Blue, centred on white at 0. The object weight in front of the agent is -0.244, and the Colour weight in front of the agent is -0.217. Roughly speaking, this will inhibit attention to RTG for any non-empty square in front of the agent. 

Now, we can look at the OV circuit for attention head 1 to determine the effect on the output from attending to the state rather than the RTG (Figure 17). It appears that high object channel values in front of the agent will lead to a decrease in the forward logit. This includes obstacles (6) and the goal (8) (see Table 3). 

Figure 17: OV Circuit Visualization for the forward action across the state embedding channels: object, colour and state. Colour ranges vary. The colour scheme is Red to Blue, centred on white at 0. Object weight in front of the agent is -0.864, and the Colour weight in front of the agent is 0.168. The weights here suggest that the object channel is slightly more important and decreases the forward logit when there is any object in front of the agent and when there are objects in the column to the agent’s left (such as walls). 
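Plugging the quoted Head 1 weights into the object/colour values from Table 3 gives a rough feel for both circuits. This considers only the single square in front of the agent and treats the quoted weights as simple multipliers on the raw channel values, which is a simplification; in the full model every position in the view contributes.

```python
CELLS = {  # (object value, colour value) for the square in front of the agent
    "empty": (1, 0),
    "wall": (2, 5),
    "obstacle": (6, 2),
    "goal": (8, 1),
}

QK_OBJECT, QK_COLOUR = -0.244, -0.217  # Head 1 QK weights quoted under Figure 16
OV_OBJECT, OV_COLOUR = -0.864, 0.168   # Head 1 OV weights quoted under Figure 17


def weighted(obj, colour, w_obj, w_col):
    return w_obj * obj + w_col * colour


# QK: contribution of this square to the state token's attention score on the RTG.
qk_scores = {n: weighted(o, c, QK_OBJECT, QK_COLOUR) for n, (o, c) in CELLS.items()}

# OV: contribution of this square to the forward logit when the state is attended to.
ov_effects = {n: weighted(o, c, OV_OBJECT, OV_COLOUR) for n, (o, c) in CELLS.items()}

for name in CELLS:
    print(f"{name:9s} QK={qk_scores[name]:+.3f}  OV(forward)={ov_effects[name]:+.3f}")
```

On this rough calculation, an empty square inhibits attention to the RTG least, and obstacles and the goal give the most negative forward contributions, in line with the qualitative reading of Figures 16 and 17.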

Lastly, it’s worth noting that the MLP tends to accentuate the output of either attention head, helping them project strongly enough to counteract the strong forward/right projection we usually see coming from the state embedding when not facing a wall. Thanks to Callum for pointing out that we can measure cosine similarity between input directions of MLP neurons and directions of the OV circuit to provide more detail here. I plan to do this soon. 

Discussion

Analysis takeaways

Limitations

Some limitations worth highlighting:

I tweaked hyperparameters to encourage as interpretable a model as possible. This means training for a long time with no meaningful loss decrease to encourage weight decay to favour a generalising solution over a memorising solution [LW(p) · GW(p)]. I suspect that, in many cases, models can be performant and lack nice abstractions. 

Alignment Relevance

A more general list of why MI might be useful to AI alignment is provided here [LW · GW]. I think that Decision Transformer MI might be useful to AI alignment in several ways:

  1. Retargeting the Search [LW · GW]. The ELI5 on this post is “understand how the AI works out what it wants and make it want what you want”. John Wentworth’s post makes a case for this approach being a “true reduction of the problem”, and creating an empirical research area around retargeting the search (and understanding the AI’s internal concepts/reasoning) seems like a robustly good thing to do. Decision Transformers enable us to do this in a much simpler context than large language models, but in which the notion of goals and an AI’s internal language feels somewhat natural. I picked the simplest possible problem I thought would work for this. The results make me optimistic about moving to tasks where we might find something akin to search (like maze solving, for example) or other tasks with instrumental goals. 
  2. Providing Mechanistic Explanations behind Goal Misgeneralization/Reward Mis-specification. Goal misgeneralization has been studied in simple RL tasks where we might plausibly be able to train decision transformers and interpret them. If circuit universality holds, then insights from decision transformers might transfer to other architectures. Still, even if they don’t, most LLMs are transformers, so I’d expect insights to be useful anyway. 
  3. (The instrumental goal of) Giving Mechanistic Interpretability a middle ground. MI is a growing field with lots of low-hanging fruit (like this project!). Decision Transformers give circuits and features a slightly different flavour. RTG modulation (see the RTG scan analysis) seems like it could be potent as a way to help us find circuits. Being able to vary RTG and thus behaviour feels like flashing lights on and off from a distance to draw attention to them. This work showed that at least basic interpretability is possible in this context, which gives us an intermediate between algorithmic tasks and LLM. Decision Transformers also obviously bring interpretability into the context of many RL tasks and have elements of multi-modality (tokens can represent visual input, the specified reward, and previous actions), and states can have text components such as in some mini-grid tasks and in GATO. 

Capabilities Externalities Assessment

I believe strongly that given the precipice we find ourselves standing on, it behoves everyone publishing empirical work to consider how it might backfire from an alignment perspective. Part of this means making a public commitment to care about capabilities, which you can consider me as doing. I promise to take you seriously if you think this or any other work I’m doing might enhance capabilities. If you feel this is the case, please consider the best way to notify me of this, probably an email (so as not to highlight any capabilities-enhancing insight further in a public forum), but feel free to CC anyone you think would be a reasonable third party. 

Before publishing this work, I considered several factors and spoke to various alignment researchers. I have published it because it is a reasonable trade-off between empirical progress on alignment and possible capabilities enhancement. 

For reference, here are some posts on capabilities/MI. 

Future Directions

This project was designed to provide early feedback for what could be a much longer research investigation involving larger systems and more comprehensive analyses followed by interventions such as model editing. 

Meta

Acknowledgements

I’m very grateful to

Many thanks to all the people who have given feedback on this draft, including Callum McDougal, Dan Braun, Jacob Hilton, Tom Lieberum, Shmuli Bloom and Arun Jose.

Author Contributions

Paul Colognese wrote up the initial project description, and he, Joseph and Callum McDougall had a few early meetings to discuss the project. The PPO algorithm was almost entirely taken from the ARENA curriculum written by Callum, based on MLAB content. The Decision Transformer code was based heavily on the original code base, which is still a submodule of the GitHub repository. 

The rest of the work, building the code base, training the agents, building the interactive app and writing this post, was done by Joseph Bloom, whose perspective the write-up is written in. 

Feedback

Please feel free to provide feedback below or email me at jbloomaus@gmail.com. The GitHub repository is a reasonable place for technical feedback. Feedback on the app/codebase would be appreciated too!

Glossary

I highly recommend Neel’s glossary [LW · GW] on Mechanistic Interpretability.

13 comments

Comments sorted by top scores.

comment by Oskar Hollinsworth (oskar-hollinsworth) · 2023-03-08T14:23:25.436Z · LW(p) · GW(p)

Really interesting and impressive work, Joseph.

Here are a few possibly dumb questions which spring to mind:

  • What is the distribution of RTG in the dataset taken from the PPO agent? Presumably it is quite biased towards positive reward? Does this help to explain the state embedding having a left/right preference?
  • Is a good approximation of the RTG=-1 model just the RTG=1 model with a linear left bias?
  • Does the state tokenizer allow the DT to see that similar positions are close to each other in state space even after you flatten? If not, might this be introducing some weird effects?
Replies from: Jbloom
comment by Joseph Bloom (Jbloom) · 2023-04-17T12:37:17.825Z · LW(p) · GW(p)

My apologies! I thought I had responded to this. Better late than never.  All reasonable questions. 

"What is the distribution of RTG in the dataset taken from the PPO agent? Presumably it is quite biased towards positive reward? Does this help to explain the state embedding having a left/right preference?"

The RTG distribution taken from the PPO agent is shown in Figure 4. It is the marginal distribution of reward: since the reward only occurs at the end of the episode, the RTG and reward distributions are more or less identical.
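As a minimal sketch of why the two distributions coincide, assuming an undiscounted reward-to-go and a reward that is zero everywhere except the final step (the function name `returns_to_go` and the example reward values are illustrative, not taken from the codebase):

```python
def returns_to_go(rewards):
    """Undiscounted reward-to-go: rtg[t] = sum of rewards from step t onwards."""
    rtg = []
    running = 0.0
    # Accumulate from the end of the episode backwards.
    for r in reversed(rewards):
        running += r
        rtg.append(running)
    return rtg[::-1]


# Hypothetical episode: sparse reward delivered only on the last step.
episode_rewards = [0.0, 0.0, 0.0, 0.0, 0.8]
print(returns_to_go(episode_rewards))  # [0.8, 0.8, 0.8, 0.8, 0.8]
```

Every timestep carries the same RTG as the terminal reward, so a histogram of initial RTG values is also a histogram of episode rewards.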

If by "biased towards positive reward" you are referring to the tendency for the agent conditioned on low RTG (0.3) to get RTG closer to 0.8, then yes: I think this is a function of the training distribution, which includes far more RTG~0.8 than 0.3.

The state embedding having a left/right preference is probably a function of the PPO agent getting success going clockwise, this being reinforced and then this dominating the training data for the DT. 

"Is a good approximation of the RTG=-1 model just the RTG=1 model with a linear left bias?"

I'm not sure what you mean by "linear left bias". The RTG=-1 model behaves characteristically differently (see Table 1).

"Does the state tokenizer allow the DT to see that similar positions are close to each other in state space even after you flatten? If not, might this be introducing some weird effects?"

Yes, I believe the encoding used had that property. I've since replicated these results with a less nasty encoding and the results are mostly similar, if a little easier to interpret.

comment by TurnTrout · 2023-02-07T18:15:35.774Z · LW(p) · GW(p)

The ablations seem surprisingly clear-cut. I consider myself to be very on board with "RL-trained behaviors are contextually activated and modulated", but even I wasn't expecting such strong localization.

To be calibrated for RTG values between 0 (non-inclusive) and 1, the decision transformer must reach the goal in a precise number of steps which is difficult given that obstacles move randomly and can extend the amount of time the agent must take to reach the obstacle. To do this well is likely hard, given that there isn’t much training data for low positive RTG values (Figure 4). 

Seems to me that randomness wouldn't prevent the agent from being calibrated, because even though any given episode might deviate from the prescribed number of steps, on average the randomness can (presumably) be made to add up to that number. EG it might be hard to bump into the goal after exactly 14 steps due to random obstacles, but I'd imagine ensuring this falls between 10 and 18 steps is feasible?

This inductive bias led to a weird behaviour where agents would successfully avoid the obstacles but also try to avoid the goal square.

This seems like an important manifestation of "models don't 'get reward' [LW · GW], they are shaped by reward [LW · GW]"; even on a simple task where presumably agents can fully explore all relevant options. The e.g. observation encoding (where an obstacle is "3/4" of a goal) matters when predicting what behavioral shards & subroutines get trained into the policy, or considering what e.g. early-stage policies will be like.

(I thiiink I weakly predicted this particular behavior in advance, when I read about the encoding.)

You might have expected that the state embedding wouldn’t project directly in any preference directions. In practice, sometimes we see it does, suggesting that besides likely encoding features, it is directly informing the model's output (which we saw in the residual stream dot product contributions earlier). 

Wait, how? Isn't the state observation constant in this task? I'm guessing you're discussing something else?

Replies from: Jbloom
comment by Joseph Bloom (Jbloom) · 2023-02-07T21:41:34.019Z · LW(p) · GW(p)

The ablations seem surprisingly clear-cut. I consider myself to be very on board with "RL-trained behaviors are contextually activated and modulated", but even I wasn't expecting such strong localization.

 

Neither was I.  After the fact, it seems easy to come up with reasons why this might be the case. I think measuring this with something like excluded loss might enable a more precise quantification of exactly how localised it was. I also don't see strong reasons to expect this to generalise to larger models. If I see similar results when the task is more complicated, the model is bigger and especially with a larger context window, then I will be more interested in trying to precisely describe how/why/when you get more localisation. 

Seems to me that randomness wouldn't prevent the agent from being calibrated, because even though any given episode might deviate from the prescribed number of steps, on average the randomness can (presumably) be made to add up to that number. EG it might be hard to bump into the goal after exactly 14 steps due to random obstacles, but I'd imagine ensuring this falls between 10 and 18 steps is feasible?

I think you might be right in the infinite training data regime, where I would expect it to be unbiased. However, I suspect that the training data being sparse, especially in the low positive RTG region, is enough to make the signal weak. The loss incurred by finishing in the wrong number of steps is likely very small compared to failing when you have positive RTG or succeeding when you have negative RTG, so it could also be that the model doesn't allocate much capacity to being well-calibrated in this sense.

This seems like an important manifestation of "models don't 'get reward' [LW · GW], they are shaped by reward [LW · GW]"; even on a simple task where presumably agents can fully explore all relevant options. The e.g. observation encoding (where an obstacle is "3/4" of a goal) matters when predicting what behavioral shards & subroutines get trained into the policy, or considering what e.g. early-stage policies will be like.

I thought this too and should have remembered to cite "Reward is not the Optimization target". I feel like that concept is now more visceral for me.  A relevant takeaway that may or may not be obvious is that simulators with inductive biases might be better/worse at simulating particular stuff as a function of their inductive biases. In this case, the positive RTG range was more miscalibrated as a result of the bias. 

Wait, how? Isn't the state observation constant in this task? I'm guessing you're discussing something else?

I'm not sure what you mean by constant. If it were constant then the agent/obstacles wouldn't be moving? I'll elaborate a little bit in case that helps. 

Since RTG gives information about whether the agent should go forward/not in many contexts, you might have expected (although now I think I wouldn't) the residual stream embedding for the state token not to directly contribute to one logit over another before RTG has been seen. In practice, it seems like it does. For example, in this situation (picture below from the app) with RTG = 90, there is a wall to the left of the agent and the state appears to strongly encourage forward/right. I interpret this as "some agent behaviours are independent of RTG and can be encouraged as a function of the observation/state before RTG is seen".
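To make "directly contribute to one logit over another" concrete, here is a toy sketch of projecting individual residual-stream components onto an action preference direction. It ignores LayerNorm, and the vectors, the unembedding `W_U` and the action names are made-up illustrative values, not weights from the trained model:

```python
def dot(u, v):
    """Plain dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def preference_direction(W_U, action_a, action_b):
    """Residual-stream direction that raises logit(action_a) relative to logit(action_b)."""
    return [a - b for a, b in zip(W_U[action_a], W_U[action_b])]


# Hypothetical unembedding: one d_model-sized vector per action (d_model = 3 here).
W_U = {
    "left": [1.0, 0.0, 0.0],
    "right": [0.0, 1.0, 0.0],
    "forward": [0.0, 0.0, 1.0],
}

# Hypothetical vectors written to the residual stream by two components.
state_token = [-2.0, 0.5, 1.0]  # e.g. an observation with a wall on the left
rtg_token = [0.0, -0.5, 0.5]

direction = preference_direction(W_U, "forward", "left")
state_contrib = dot(state_token, direction)
rtg_contrib = dot(rtg_token, direction)

# Linearity: per-component contributions add up to the contribution of their sum.
total = dot([s + r for s, r in zip(state_token, rtg_token)], direction)
print(state_contrib, rtg_contrib, total)
```

In this toy setup the state vector alone already favours forward over left, before the RTG vector adds anything, which is the kind of pattern described above.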

To clear up some language as well:
- state encoding -> how is the state represented? weird minigrid schema.
- state token -> the value of a state represented as a vector input to the model.
- state embedding -> an internal representation of the state at some point in the model.

I should have written state-token not state-embedding in the quoted paragraph. Apologies if this led to confusion.

comment by TurnTrout · 2023-02-07T17:37:15.666Z · LW(p) · GW(p)

This looks cool, going to read in detail later. 

Start working on larger/harder RL tasks that involve more complicated algorithms and/or search and/or alignment relevant phenomena such as goal misgeneralization. The way I see this going is Minigrid < D4RL < ProcGen < Atari < whatever tasks Gato does.

Note that team shard (MATS stream) is already doing MI on procgen. I've been supervising/working with them on procgen for the last few weeks. We have a nice set of visualization techniques, have reproduced some of Langosco et al.'s cheese-maze policies, and overall it's been going quite well. 

Replies from: Jbloom, butanium-1
comment by Joseph Bloom (Jbloom) · 2023-02-07T21:07:51.728Z · LW(p) · GW(p)

Thank you for letting me know about your work on procgen with MI. It sounds like you're making progress. I'd be particularly interested in your visualisation techniques (how do they compare to what was done in Understanding RL Vision?) and the reproduction of the cheese-maze policies (is this tricky? Do you think a DT could be well-calibrated on this problem?).

Some questions that might be useful to discuss more:

  • What are the pros/cons of doing DT vs actor-critic MI? (You're using Actor-Critic of some form?). It could also be interesting to study analogous circuits in the DT vs AC scenarios. 
  • I haven't done anything with CNNs yet, for simplicity, but I might be able to calibrate my expectations on the value/challenges involved by chatting to the team shard MATS stream. 

Glad to hear your progress is going well! I'll be in the Bay Area for EAG if anyone from the team would like to chat. 

Replies from: TurnTrout
comment by TurnTrout · 2023-02-14T00:14:17.912Z · LW(p) · GW(p)

We're studying a net with the structure I commented below [LW · GW], trained via PPO. I'd be happy to discuss more at EAG. 

Not posting much publicly right now so that we can a: work on the research sprint and b: let people preregister credences in various mechint / generalization propositions, so that they can calibrate / see how their opinions evolve over time. 

comment by Butanium (butanium-1) · 2023-02-08T01:00:17.539Z · LW(p) · GW(p)

Are you using decision transformers or other RL agents on procgen? Also, do you plan to work on coinrun?

Replies from: TurnTrout, Jbloom
comment by TurnTrout · 2023-02-13T23:39:39.911Z · LW(p) · GW(p)

We're analyzing the mech-int-ungodly Impala architecture, from the paper. Basically

=== Impala
conv
maxpool2D
---- residual x2:
relu
conv
relu
conv
residual add from input to this residual block
=== /IMPALA
(repeat 2 more impalas)
---
relu
flatten
fully connected
relu
---
linear policy and value heads

so this mess has sixteen conv layers, was trained on pixels. We're not doing coinrun for this MATS sprint, although a good amount of tooling should cross over.

This has presented some challenges -- no linearity from decomposing an ongoing residual stream into head-contributions.  

comment by Victor Levoso · 2023-02-08T01:47:18.922Z · LW(p) · GW(p)

Oh nice, I was interested in doing mechanistic interpretability on decision transformers myself and had gotten started during SERI MATS, but I've since become more interested in looking into algorithm distillation and the decision transformer stuff fell by the wayside (plus I haven't been very productive during the last few weeks, unfortunately). It's too late to read the post in detail today, but I will probably read it in detail and look at the repo tomorrow. I'm interested in helping with this and I'm likely going to be working on some related research in the near future anyway. Also, btw, I think that once someone gets to the point where we understand what's going on in the setup from the original DT paper, it would be interesting to look into this: https://arxiv.org/abs/2201.12122

Also, the DT paper finds their model generalizes to bigger RTG than the training set in the Seaquest env, and it would be interesting to get a mechanistic explanation of why that happens (though that's an Atari task and I think you are right that it's probably going to have to come later, because CNNs are probably harder to work with).

Another thing to note is that OpenAI's VPT, while it's technically not a decision transformer (because it doesn't predict rewards, if I remember correctly), is a similar kind of thing in that it is offline RL as sequence prediction, and it is probably one of the biggest publicly available pretrained models of this kind. There are also multiple open-source implementations of Gato that could be interesting to try to do interpretability on. https://github.com/Shanghai-Digital-Brain-Laboratory/BDM-DB1

Also, training decision transformers on MineRL (or in Eleuther's future Minetest environment) seems like what might come next after Atari (the tasks Gato is trained on are mostly Atari games and Google stuff that is not publicly available, if I remember correctly).

(Sorry if this is too rambly. I'm half asleep and got excited because I think work on DTs is a potentially very promising area for alignment. I was procrastinating on writing a post trying to convince more people to work on it, and I'm pleasantly surprised other people had the same idea.)

Replies from: Jbloom
comment by Joseph Bloom (Jbloom) · 2023-02-08T06:07:27.146Z · LW(p) · GW(p)

Hi Victor, 

Glad you are keen on this area, I'd be very happy to collaborate. I'll respond to your comments here but am happy to talk more after you've read the post. 

The linked paper (Can Wikipedia help Offline Reinforcement Learning) is very interesting in a few ways; however, I'd be interested in targeted reasons to investigate this specifically. I think working with larger models is often justified, but it might make sense to squeeze more juice out of the small models before moving to larger models. Happy to hear the arguments though.

Also, the DT paper finds their model generalizes to bigger RTG than the training set in the Seaquest env, and it would be interesting to get a mechanistic explanation of why that happens (though that's an Atari task and I think you are right that it's probably going to have to come later, because CNNs are probably harder to work with).

I'm glad you asked about this. In terms of extrapolation, an earlier model I trained seemed to behave like this.  Analysis techniques like the RTG Scan functionality in the app present ways to explore the mechanisms behind this which I decided not to explore in this post (and possibly in general) for a few reasons:

It's not clear to me that this is more than a coincidence. I think it could be that in the space of functions that map RTG to behaviour, for certain games, it is possible to learn coincidentally extrapolating functions. If the model were to develop qualitatively different behaviour (under some definition) in out-of-distribution RTG ranges for any task, then my interest would be renewed.

I suspect doing something like integrated gradients for CNN layers is pretty doable (maybe that's what the MATS shard team have done, see one of Alex Turner's comments) but yeah, they are probably harder to work with. 

Another thing to note is that OpenAI's VPT, while it's technically not a decision transformer (because it doesn't predict rewards, if I remember correctly), is a similar kind of thing in that it is offline RL as sequence prediction, and it is probably one of the biggest publicly available pretrained models of this kind. There are also multiple open-source implementations of Gato that could be interesting to try to do interpretability on. https://github.com/Shanghai-Digital-Brain-Laboratory/BDM-DB1

Thank you for sharing this! I'd be very excited to see attempts to understand these models. I've started with toy models for reasons like simplicity and complete control but I can see many arguments in favour of jumping to larger models. The main challenge I see would be loading the weights into a TransformerLens model so we can get the cache enabling easy analysis. This is likely quite doable. 

Also, training decision transformers on MineRL (or in Eleuther's future Minetest environment) seems like what might come next after Atari (the tasks Gato is trained on are mostly Atari games and Google stuff that is not publicly available, if I remember correctly).

Interesting!

comment by Victor Levoso · 2023-02-10T02:34:04.291Z · LW(p) · GW(p)

About the sampling thing: I think a better way to do it, one that will work for other kinds of models, would be training a few different models that do better or worse on the task and use different policies, and then making a dataset of trajectories sampled from several of them. This should be cleaner, in terms of knowing what is going on in the training set, than collecting the data as the model trains (which, on the other hand, is actually better for doing AD).

That also has the benefit of letting you study how the choice of agents used to generate the training data affects the model. For example, if you have two agents that get similar rewards using different policies, does the DT learn to output a mix of the two policies, or what exactly? The agents don't even need to be neural nets; they could be random samplers or handcrafted policies.

For example, I tried training a DT-like model (though I didn't do the time encoding) on a mixture of a DQN that played the dumb toy env I was using (the frozen lake env from gym, which is basically solved by memorizing the correct path) and random actions. It apparently learned to output the correct actions the DQN took at RTG 1 and a uniform distribution over the action tokens at RTG 0.

Replies from: Jbloom
comment by Joseph Bloom (Jbloom) · 2023-02-10T10:25:39.954Z · LW(p) · GW(p)

Agreed about the better way to sample. I think in the long run this is the way to go. It might make sense to deal with this at the same time we deal with environment learning schedules (learning on a range of environments where the difficulty increases over time).

I hadn't thought about learning different policies at the same RTG. This could be interesting. Uniform action preference at RTG = 0 makes sense to me, since while getting a high RTG is usually low entropy (there's one or a few ways to do it), there are many ways not to get reward. In practice, I might expect the actions to just match the base frequencies of the training data, which might be uniform.
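A rough sketch of the "base frequencies" intuition, assuming a toy dataset of (RTG, action) pairs mixed from a deterministic expert and a uniform-random policy. The data, the action names and the helper `action_frequencies` are all hypothetical; this only shows what a predictor that perfectly fits the empirical conditional distribution would output, not what a trained DT actually does:

```python
from collections import Counter, defaultdict


def action_frequencies(samples):
    """Empirical P(action | rtg) from an iterable of (rtg, action) pairs."""
    counts = defaultdict(Counter)
    for rtg, action in samples:
        counts[rtg][action] += 1
    return {
        rtg: {a: n / sum(c.values()) for a, n in c.items()}
        for rtg, c in counts.items()
    }


# Hypothetical mixture: the expert always moves "right" and ends with RTG 1;
# the random policy picks each of four actions equally often and ends with RTG 0.
expert = [(1, "right")] * 8
random_policy = [(0, a) for a in ("left", "right", "up", "down")] * 2

freqs = action_frequencies(expert + random_policy)
print(freqs[1])  # all mass on the expert's action
print(freqs[0])  # equal mass on each of the four actions
```

If the low-RTG data came from a policy with a skewed action distribution instead, the same calculation would return that skew rather than a uniform distribution.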