Interpreting autonomous driving agents with an attention-based architecture

post by Manav Dahra (manav-dahra) · 2025-02-01T23:20:27.162Z · LW · GW · 0 comments

Contents

  Abstract
  Introduction
    Environment
    State
    Actions & Rewards:
    Agent:
        Attention layers:
  Methods
      Replication of agent behaviour
      Analysis on Embedding matrices:
        Notes:
      QK & OV circuit analysis of Transformer block
      Study on activation patterns of an episode run:
        Notes:
      Feature importance
        Notes:
    Interpretation:
        Step 1: Input Features
        Step 2: Attention Mechanism (QK Circuits)
        Step 3: Q-Value Calculation (OV Circuits)
    Discussion
    Future work:
      Acknowledgements:
      Appendix:
        Glossary:
        Model architecture:
        Environment configuration:
        Evaluation:

Abstract

In this experiment, I study the behaviour of a Deep Q-Network (DQN) agent with an attention-based architecture driving a vehicle in a simulated traffic environment. The agent learns to cross an intersection without traffic lights, under a dense traffic setting, while avoiding collisions with other vehicles.

As the first part of this experiment, I train the agent using the attention-based architecture. I then study the behaviour of the agent by applying some interpretability techniques to the trained Q-network and find some evidence that the network comprises three groups of layers serving specific functions, namely sensory (embedding layers), processing (attention layer) and motor (output layers).
 
The purpose of this experiment is to gain a deeper understanding of the agent from an interpretability perspective, which may be useful for developing safer agents in real-world applications.

Introduction

With the increasing usage and deployment of autonomous vehicles on roads, it is important that the behaviour of these autonomous agents is thoroughly tested, understood and analysed from a safety perspective.

To measure safety, one traditional approach involves running a large number of experiments in a simulated environment and collecting statistics on the number of failure cases. While this is certainly useful and gives an overall perspective on the failure modes of the agent, it does not say anything about the specifics of those failures.

One may prefer to delve deeper, study why a particular agent failed, and understand whether that behaviour was a consequence of the agent's actions or a conditioning of the environment. This calls for applying interpretability techniques to the agent's behaviour and deriving more specific conclusions about what features the agent senses from the environment and what decisions it takes.

For this experiment, I study the behaviour of a trained agent by applying some interpretability techniques to the policy network of the model and share my observations and the conclusions derived from the experiment.

The agent under study is only trained and deployed in a simulated environment (with substantial simplifications) and is far from a real-world setting and its complexities. While this does not fully represent the behaviour of the agent in the real world, I still think a study like this can be worthwhile in providing some insight into the kind of decision-making process the agent learns and how that insight can be used to make agents safer.

Now, to give some context on the problem, let's first understand:

  1. How the agent and the environment interact
  2. The model architecture of the agent

Then I will share my observations and insights.



Environment

The environment used in this experiment is intersection-env, a customised gymnasium environment that follows the standard agent-environment loop of reinforcement learning:

Fig 1: Agent environment loop in reinforcement learning

The environment contains a total of N = 15 vehicles at any given point in time.
The agent in question controls the green vehicle, while the blue vehicles are simulated by a traffic flow model, in this case the Intelligent Driver Model. The Intelligent Driver Model is less nuanced and lacks any complex behaviour in comparison to the agent in question. The blue vehicles are initially spawned at random positions.

Below is an animation of the trained agent crossing the intersection.

Fig 2: A recorded video from the evaluation stage showing agent successfully crossing intersection without collision
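For illustration, here is a minimal sketch of how such an environment can be set up with gymnasium and highway-env. The configuration keys shown are assumptions based on highway-env's intersection task; the exact configuration used in this experiment is listed in the appendix and the linked source code.

```python
import gymnasium as gym
import highway_env  # noqa: F401  (registers the intersection environments)

# Minimal sketch, assuming highway-env's standard intersection task.
env = gym.make("intersection-v0", render_mode="rgb_array")
env.unwrapped.configure({
    "observation": {"type": "Kinematics"},   # per-vehicle kinematic features
    "initial_vehicle_count": 15,             # N = 15 vehicles, as described above
})

obs, info = env.reset()
done = truncated = False
while not (done or truncated):
    action = env.action_space.sample()       # placeholder for the trained agent's action
    obs, reward, done, truncated, info = env.step(action)
```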

State

The joint observation of the road traffic, with the agent's vehicle denoted $s_0$ and the other vehicles denoted $s_1, \dots, s_N$, is described by a combined list of individual vehicle states:

$$s = (s_0, s_1, \dots, s_N)$$

where each $s_i$ is a vector of the per-vehicle features described below.
The individual state variables are described as follows:

| Feature | Description |
|---|---|
| presence | Disambiguates agents at a 0 offset from non-existent agents. |
| x | World offset of the agent's vehicle, or offset to the agent's vehicle, on the x axis. |
| y | World offset of the agent's vehicle, or offset to the agent's vehicle, on the y axis. |
| vx | Velocity of the vehicle on the x axis. |
| vy | Velocity of the vehicle on the y axis. |
| heading | Heading of the vehicle in radians. |
| cosh | Trigonometric heading of the vehicle (cosine). |
| sinh | Trigonometric heading of the vehicle (sine). |

The vehicle kinematics are described by the Kinematic Bicycle Model. More on this topic can be found here.
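For reference, the kinematic bicycle model is commonly written as follows (a sketch of the standard form; the simulator's exact variant may differ slightly):

$$\begin{aligned}
\dot{x} &= v \cos(\psi + \beta) \\
\dot{y} &= v \sin(\psi + \beta) \\
\dot{v} &= a \\
\dot{\psi} &= \frac{v}{l} \sin \beta \\
\beta &= \arctan\left(\tfrac{1}{2} \tan \delta\right)
\end{aligned}$$

where $(x, y)$ is the vehicle position, $v$ its speed, $\psi$ its heading, $a$ the acceleration command, $\delta$ the steering command, $l$ the vehicle length and $\beta$ the slip angle at the centre of gravity.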
 

Actions & Rewards:

The agent drives the vehicle by controlling its speed chosen from a finite set of actions A = {SLOWER, NO-OP, FASTER}.

Rewards:

| Reward | Condition |
|---|---|
| 1 | Agent driving at maximum velocity |
| -5 | On collision |
| 0 | Otherwise |
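A minimal sketch of this reward logic (illustrative only; the actual implementation lives inside highway-env's intersection environment, and the argument names here are my own):

```python
def compute_reward(crashed: bool, speed: float, max_speed: float) -> float:
    """Illustrative sketch of the reward table above (not the library's exact code)."""
    if crashed:
        return -5.0      # collision penalty
    if speed >= max_speed:
        return 1.0       # reward for driving at maximum velocity
    return 0.0           # otherwise, no reward
```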



Agent:

The agent used in this experiment uses a DQN algorithm with an attention-based architecture, which was first proposed in the paper Social Attention for Autonomous Decision-Making in Dense Traffic [1].

For this experiment, I delve deeper into the agent's policy network, since that network encodes the decision making of the agent.

Here is what the network looks like:

| Layer Name | Dimensions |
|---|---|
| Ego & Others embedding layer - 0 | 7x64 |
| Ego & Others embedding layer - 1 | 64x64 |
| Attention Layer Query | 64x64 |
| Attention Layer Key | 64x64 |
| Attention Layer Value | 64x64 |
| Output Layer - 0 | 64x64 |
| Output Layer - 1 | 64x64 |
| Output Layer - predict | 64x3 |


Embedding layers:

Fig 3: Schematic of Embedding layers of agent's Q-net 

The network is composed of several identical linear encoders, a stack of attention heads, and a linear decoder.

There are two embedding layers:
1. Ego embedding - dedicated to tracking the features of the vehicle driven by the agent itself.
2. Others embedding - dedicated to tracking the features of the vehicles driven by other agents.

Attention layers:

Fig 4: Schematic of attention layer of agent's Q-net 

 

Essentially, a single query $Q = [q_0]$, emitted from the ego vehicle's embedding, and a set of keys $K = [k_0, k_1, \dots]$, one per vehicle, are produced by linear projections of the embedded state of the environment. Here, N is the number of vehicles including the agent's vehicle.

The outputs from all heads are finally combined with a linear layer, and the resulting tensor is then added to the ego encoding through a residual connection.
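Below is a minimal sketch of this ego-attention step in PyTorch, written for illustration only. It assumes 4 heads of size 16 and uses the projection names from the architecture printout in the appendix; the actual rl-agents implementation may differ in detail.

```python
import torch
import torch.nn.functional as F

def ego_attention_sketch(ego_emb, others_emb,
                         query_ego, key_all, value_all, attention_combine,
                         n_heads=4):
    """Illustrative sketch of the ego-attention step (not the library's exact code).

    ego_emb:    (B, 1, 64) embedding of the agent's own vehicle
    others_emb: (B, N-1, 64) embeddings of the other vehicles
    The four linear layers correspond to the projections listed in the appendix.
    """
    all_emb = torch.cat([ego_emb, others_emb], dim=1)                 # (B, N, 64)
    B, n, d = all_emb.shape
    dk = d // n_heads
    q = query_ego(ego_emb).view(B, 1, n_heads, dk).transpose(1, 2)    # single query, from ego only
    k = key_all(all_emb).view(B, n, n_heads, dk).transpose(1, 2)      # one key per vehicle
    v = value_all(all_emb).view(B, n, n_heads, dk).transpose(1, 2)    # one value per vehicle
    attn = F.softmax(q @ k.transpose(-2, -1) / dk ** 0.5, dim=-1)     # (B, heads, 1, N) attention pattern
    out = (attn @ v).transpose(1, 2).reshape(B, 1, d)                 # combine heads
    return attention_combine(out) + ego_emb                           # linear combine + residual
```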

The authors of the paper claim that an agent with the proposed attention architecture shows performance gains in autonomous decision making under a dense traffic setting. Their study compared the agent's performance against common architectures such as FCNs and CNNs. The social interaction patterns with other vehicles were visualised and studied qualitatively.


Methods

As part of this experiment, I first replicate the behaviour of the agent as described by the authors of the paper. Then, I proceed to study the agent's behaviour by borrowing some well-recognised interpretability techniques from the literature, such as Understanding RL Vision [2] and A Mathematical Framework for Transformer Circuits [3].
After collecting the observations, I derive some key insights on the behaviour of the agent and mention some interesting directions for future work.

The techniques applied in this experiment are as follows:

  1. Replication of agent behaviour
  2. Analysis on Embedding matrices
  3. QK & OV circuit analysis of Transformer block
  4. Study on activation patterns of an episode run
  5. Analysis on feature importance in reference to output layer

     

The source code for training the agent, details on the choice of hyper-parameters and model architecture parameters, and extended results are in the Appendix section.
 

Replication of agent behaviour

The first step is to replicate the studies of the paper by training and evaluating the agent in a dense traffic setting with a single road intersection and no traffic lights.
From my observations, I confirm that the agent successfully learns to cross the intersection while avoiding collisions with other vehicles in most scenarios.

The following clip shows the trained agent navigating through the intersection, along with its attention pattern at the time step when the agent decides to slow down after noticing another vehicle in its way.


Notes: 
The above results confirm that the agent learns to navigate through the crossing while avoiding collisions in most scenarios, by paying attention to the other vehicles on the crossing. Attention to other vehicles at every time step of the episode is highlighted by thick coloured lines from the green vehicle to the blue vehicles.

 

Analysis on Embedding matrices:

First, I analyse what insights can be found in the weights of the embedding matrices. Each embedding layer consists of two linear layers: $W_0$, the first layer with dimension 7x64, and $W_1$, the second layer with dimension 64x64.
Treating the layers as linear maps (ignoring biases and activations), the embedding matrix weights are combined as follows:

$$W_E = W_0 \, W_1$$

then $W_E$ has dimension 7x64.
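As a sketch, the combined 7x64 embedding matrices can be computed from the trained network's state dict along these lines. The checkpoint path and format are hypothetical; the layer names follow the architecture printout in the appendix.

```python
import torch

# Note: nn.Linear stores weights as (out_features, in_features), hence the transposes.
state = torch.load("trained_agent.pt", map_location="cpu")   # hypothetical checkpoint path

W0_ego = state["ego_embedding.layers.0.weight"]        # (64, 7)
W1_ego = state["ego_embedding.layers.1.weight"]        # (64, 64)
W_ego = (W1_ego @ W0_ego).T                            # (7, 64) combined Ego embedding matrix

W0_oth = state["others_embedding.layers.0.weight"]
W1_oth = state["others_embedding.layers.1.weight"]
W_others = (W1_oth @ W0_oth).T                         # (7, 64) combined Others embedding matrix
```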
 

Checking the embedding matrix for both embedding layers individually reveals the following:

Fig 5: Comparison of Ego embedding matrices between Untrained and Trained agents.
Difference highlights that y coordinate feature is assigned large weights
Fig 6: Comparison of Others embedding matrices between Untrained and Trained agents.
Difference highlights that x,y coordinates, vy and sinh features are assigned large weights

Notes:

  1. Ego embedding matrix - higher weight values are assigned to the features y and cosh.
  2. Others embedding matrix - higher weight values are assigned to the features x, y, vy and sinh.
  3. The heat maps show that the other features are assigned relatively smaller weights.

This suggests that the sensory function of the model is picking up some interesting signals from the environment.

Next I analyse how these features interact with the attention layers of the network.

 

QK & OV circuit analysis of Transformer block

In their research, the Anthropic interpretability team outlines a mathematical framework for understanding attention layers [3].

Quoting from the paper:

Attention heads can be understood as having two largely independent computations: a QK (“query-key”) circuit which computes the attention pattern, and an OV (“output-value”) circuit which computes how each token affects the output if attended to.

Here, I study the QK and OV circuit patterns that emerge in the attention layer. To study the emergence of learned structure, I compare the untrained and trained networks and compute their element-wise squared differences.

As seen from the code, the agent in this experiment has only one attention layer with 4 attention heads, whose query/key/value vectors are first sliced into 4 parts, processed per head, and then concatenated. Hence, for simplicity, I computed the QK and OV circuits for the combined matrices instead of slicing them into 4 parts.

The QK circuit is computed as

$$C_{QK} = (W_E^{ego} \, W_Q)\,(W_E \, W_K)^\top$$

where $W_E^{ego}$ is the combined Ego embedding matrix, $W_E$ is the combined embedding matrix of the vehicles being attended to (Ego in Fig 7, Others in Fig 8), and $W_Q$, $W_K$ are the attention layer's query and key projection matrices.

The OV circuit is computed as

$$C_{OV} = W_E \, W_V \, W_C \, W_O$$

where $W_V$ is the value projection, $W_C$ is the attention-combine layer, and $W_O$ is the product of the output layers mapping the attention output to the three Q-values.

QK and OV circuits:

Fig 7: QK circuit computed using Ego embedding layer
Left: represents no learned attention patterns
Right: shows y-coordinate and velocities heavily being attended to, in 3rd row from bottom.
Fig 8: QK circuit computed using Others embedding layer
Left: shows no learned attention patterns
Right: shows features (x, y, vy & sinh) being heavily attended to, in 3rd row from bottom.
Fig 9: OV circuits between different embedding layers
Left: shows no learned patterns
Right: represents what action will be taken by the agent when that feature is attended to.
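For reference, a sketch of how these circuits can be computed from the trained network's state dict, continuing from the embedding-matrix snippet above. The layer names follow the architecture printout in the appendix, and per-head slicing is ignored, as noted earlier.

```python
# Continues from the embedding-matrix sketch above (uses `state`, W_ego, W_others).
# Transposes account for nn.Linear storing weights as (out_features, in_features).
Wq = state["attention_layer.query_ego.weight"].T           # (64, 64)
Wk = state["attention_layer.key_all.weight"].T             # (64, 64)
Wv = state["attention_layer.value_all.weight"].T           # (64, 64)
Wc = state["attention_layer.attention_combine.weight"].T   # (64, 64)

qk_ego    = (W_ego @ Wq) @ (W_ego @ Wk).T       # (7, 7) ego feature-feature attention (Fig 7)
qk_others = (W_ego @ Wq) @ (W_others @ Wk).T    # (7, 7) ego-to-others attention (Fig 8)

# Collapse the output MLP into a single (64, 3) map from attention output to Q-values.
W_out = (state["output_layer.predict.weight"]
         @ state["output_layer.layers.1.weight"]
         @ state["output_layer.layers.0.weight"]).T        # (64, 3)
ov = W_others @ Wv @ Wc @ W_out                  # (7, 3) feature-to-action map (Fig 9)
```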

Furthermore, the attention scores and output value matrices shown in Fig 7a and 7b (in the appendix section) show some interesting learned structural patterns.

The above figures show clear distinctions between the two scenarios: one where the agent is untrained and the other where it is trained. The QK and OV circuits show high activations for certain lines and areas of the heat maps, indicating a learned structure/pattern arising from the interplay between the agent, the other vehicles and the environment.

Notes:

  1. Fig 7 -
    1. The agent is attending to its own y-coordinate with a strong positive correlation.
    2. While vy is being attended to with a negative correlation.
    3. This suggests the agent uses its own velocity and heading to assess risk when making decisions.
  2. Fig 8 -
    1. Strong positive attention between the Ego embedding's y coordinate and the other vehicles' x, y coordinates and vy velocity, hinting at the possibility of computing a distance metric.
    2. Strong negative attention between the Ego embedding's y coordinate feature and the Others embedding's presence feature.
    3. The above indicates that the model has learned to focus on other vehicles' movement to adjust its own strategy.
  3. Fig 9 -
    1. Shows strong activations for velocity (vx, vy) and heading (cos_h, sin_h), meaning the model values its own speed and direction when choosing actions.
    2. Presence and position (x, y) of other vehicles become crucial. The agent learns to consider the positions of other vehicles when deciding whether to slow down, idle, or speed up.


Study on activation patterns of an episode run:

Let's study the agent's activations per time step and the intermediate attention matrices collected over a full episode run.

Here, I extract the environment frames and the attention-head-vs-vehicle activations for each time step of the evaluation; they show the following:

Fig 10: 
Frame rendered from the scene at the time step when agent decides to slow down near the intersection.
Fig 11: Heat map of Attention head vs Vehicle number
Vertical lines of activations show that a particular vehicle is being attended to by all 4 attention heads in varying degrees.
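The heat maps above can be collected with a loop along the following lines. This is a rough sketch: `env` is the intersection environment from earlier, `agent` is the trained rl-agents DQN agent, and `agent.act`, `agent.value_net` and the submodule names are assumptions based on the rl-agents library and the architecture printout in the appendix.

```python
import torch
import torch.nn.functional as F

attention_maps, frames = [], []
obs, info = env.reset()
done = truncated = False
while not (done or truncated):
    action = agent.act(obs)                    # greedy action from the trained DQN agent
    frames.append(env.render())                # environment frame for this time step
    with torch.no_grad():
        x = torch.tensor(obs, dtype=torch.float).unsqueeze(0)    # (1, N, 7)
        net = agent.value_net
        ego = net.ego_embedding(x[:, 0:1, :])                    # (1, 1, 64)
        others = net.others_embedding(x[:, 1:, :])               # (1, N-1, 64)
        all_emb = torch.cat([ego, others], dim=1)                # (1, N, 64)
        q = net.attention_layer.query_ego(ego).view(1, 1, 4, 16).transpose(1, 2)
        k = net.attention_layer.key_all(all_emb).view(1, -1, 4, 16).transpose(1, 2)
        attn = F.softmax(q @ k.transpose(-2, -1) / 16 ** 0.5, dim=-1)
        attention_maps.append(attn.squeeze().cpu().numpy())      # (heads, vehicles) heat map
    obs, reward, done, truncated, info = env.step(action)
```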

Notes:

  1. Vertical lines of activations show that a particular vehicle is being attended to by all 4 attention heads in varying degrees.
  2. Vehicle_0 is the green vehicle which is controlled by the agent and hence is always being attended to by all 4 attention heads at all time steps.
  3. For other vehicles the activations increase as they appear closer to the intersection.

Next, I study output layers further to understand what insights I can draw from there.

 

Feature importance

In this section I try to understand which features the model learns to extract from the environment state. A common technique for estimating feature importance is computing integrated gradients, which give an overall measure of the importance of each input feature.
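As a sketch, integrated gradients can be computed with a library such as Captum. The all-zero baseline and the choice of the greedy action as the attribution target are illustrative assumptions here; `value_net` is the trained Q-network.

```python
import torch
from captum.attr import IntegratedGradients

def feature_attributions(value_net, obs):
    """Sketch: per-feature integrated gradients w.r.t. the greedy action's Q-value."""
    x = torch.tensor(obs, dtype=torch.float).unsqueeze(0)      # (1, N, 7)
    action = value_net(x).argmax(dim=-1).item()                 # greedy action index
    ig = IntegratedGradients(value_net)
    baseline = torch.zeros_like(x)                              # all-zero baseline state (assumption)
    attributions = ig.attribute(x, baselines=baseline, target=action)
    return attributions.squeeze(0)                              # (N, 7) per-vehicle, per-feature
```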

For this scenario, I compare the integrated gradients of the untrained vs trained networks, averaged over 30 episodes, and find the following.

Fig 14: Computed integrated gradients for untrained and trained agents. Left: For untrained agent, turns out to be negligible. Right: For trained agent - presence, x, y coordinates and angles are important.

Notes:

The Q-net's output layer makes its final decision by looking mainly at the presence, x-y coordinate and sinh features. The above graphs add further evidence to the notes/observations gathered earlier.


In the following section, I make some speculative interpretations about the agent. I would love to validate some of these interpretations by conducting more thorough experiments in the future. Claims that I am not confident about are marked inline.


Interpretation:

The following is a walkthrough of the agent in action, along with an interpretation of the key observations/notes collected so far.

Step 1: Input Features

The agent observes its environment using the following features:

  1. Ego Features
    • vx, vy: Ego vehicle’s velocity (speed in x and y directions).
    • cos_h, sin_h: Ego vehicle’s heading direction.
    • x, y: Ego vehicle’s position.
  2. Others Features
    • presence: Indicates whether another vehicle is nearby.
    • vx, vy: Other vehicle’s velocity.
    • x, y: Other vehicle’s position.

Step 2: Attention Mechanism (QK Circuits)

The Query-Key (QK) circuits determine which features should be attended to when making a decision.

  1. Ego Embedding (Self-awareness)
    1. The agent queries its own vx, vy, cos_h, and sin_h to understand its motion.
    2. It attends to (y) to determine its position in the intersection. (Why does the x coordinate not have high activations? Something I would like to investigate later.)
  2. Others Embedding (Awareness of other vehicles)
    1. The agent queries presence to check if another vehicle is nearby.
    2. It attends to vx, vy, x, y of other vehicles to predict their movement.
    3. If a vehicle is approaching, the attention on presence and vx increases. (unverified claim, is model computing some distance metric ? can we verify this ?)
  3. When does the agent choose "Slow"?
    1. If the presence of other vehicles (presence feature) is high.
    2. If relative velocity (vx, vy) suggests a collision risk.
    3. Strong correlations in Others embedding QK circuit show that the agent reacts to nearby vehicles. (unverified claim, same reason as above)
  4. When does the agent choose "Idle"?
    1. Likely in neutral situations where no immediate action is needed.
    2. OV circuits show that cos_h and sin_h influence the decision, meaning the agent aligns with road orientation.
  5. When does the agent choose "Fast"?
    1. If there is clear space ahead (low presence activation in Others embedding).
    2. Trained OV circuits show positive activations for vx and sin_h, meaning the model prefers accelerating when aligned with the road.

Step 3: Q-Value Calculation (OV Circuits)

The Output-Value (OV) circuits determine how much each feature contributes to the Q-values for each action.
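In the linearised picture used in the OV analysis (ignoring biases and non-linearities, so only a rough approximation), each attended vehicle's contribution to the Q-values can be sketched as

$$Q(s, \cdot) \;\approx\; \sum_i \alpha_i \, s_i \, W_E \, W_V \, W_C \, W_O$$

where $\alpha_i$ is the attention weight placed on vehicle $i$, $s_i$ is its feature vector, and the matrices are those defined in the OV circuit section above.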

Feature Contributions to Actions:

Discussion

This experiment shows that a DQN agent with an attention-based architecture can learn to cross a road intersection in a dense traffic setting with reasonable levels of safety.

Additionally, the analysis of the attention layer of the agent's Q-network gives some evidence to believe that these layers learn high-level Q-policies that drive the decision making of the agent. Although it was possible to find some high-level policies, more work is needed to find out how different policies combine to form a concrete algorithm.

It was shown, to some extent, that the agent learns to delegate different types of functions to its embedding, attention and output layers. It is fair to say that the embedding, attention and output layers of the agent learn to serve the functions of sensory, processing and motor neurons respectively.


Future work:

This experiment was limited in scope and time (about 4 weeks). For this reason, I chose to focus on replicating the behaviour of the agent and running various types of interpretability techniques, in order to narrow down on a promising approach for pinning down the exact behaviour of the agent in further research.

The following are some of the areas that can be explored in the future:

  1. Does the agent compute a distance metric from the (x, y) coordinates of the other vehicles?
  2. Does changing the y coordinate of the intersection in the environment break the agent's decision making? Has the agent really generalised, or simply memorised?
  3. Enumerate the different policies learned by the agent for the same action. For example, slowing down: high activations of the presence feature and high activations of the vx feature of other vehicles. Do these two policies correlate highly for the agent, or do they stay largely independent?
  4. Train the model with more attention heads and layers, and with more episodes, then repeat the experiments. Do we get any new insights?
     

Acknowledgements:

My sincere thanks to this amazing community, which has made interpretability research easily accessible to the general public. I hope that my experiments bring some value to others and to this community. I look forward to delving deeper into this topic; any support & guidance is highly appreciated.

I would also like to thank BlueDot Impact for running the 12-week online course on AI Safety Fundamentals. I conducted this experiment as part of the project submission phase of the course, and I am grateful to the course facilitators and their team for conducting amazing sessions and providing a comprehensive list of resources on the key topics.

I'm looking forward to collaborating. Reach out to me on 
My portfolio
LinkedIn
 


Appendix:

Github source code

Glossary:

DQN: Deep Q-Network
FCN: Fully Convolutional Net
CNN: Convolutional Neural Net
QK Circuit: Query-Key Circuit
OV Circuit: Output-Value Circuit
 

Model architecture:

EgoAttentionNetwork(
  (ego_embedding): MultiLayerPerceptron(
    (layers): ModuleList(
      (0): Linear(in_features=7, out_features=64, bias=True)
      (1): Linear(in_features=64, out_features=64, bias=True)
    )
  )
  (others_embedding): MultiLayerPerceptron(
    (layers): ModuleList(
      (0): Linear(in_features=7, out_features=64, bias=True)
      (1): Linear(in_features=64, out_features=64, bias=True)
    )
  )
  (attention_layer): EgoAttention(
    (value_all): Linear(in_features=64, out_features=64, bias=False)
    (key_all): Linear(in_features=64, out_features=64, bias=False)
    (query_ego): Linear(in_features=64, out_features=64, bias=False)
    (attention_combine): Linear(in_features=64, out_features=64, bias=False)
  )
  (output_layer): MultiLayerPerceptron(
    (layers): ModuleList(
      (0-1): 2 x Linear(in_features=64, out_features=64, bias=True)
    )
    (predict): Linear(in_features=64, out_features=3, bias=True)
  )
)

 

Environment configuration:

N: number of vehicles
Observations type: Kinematics
Observation space: 7 features per vehicle

where each vehicle state is [presence, x, y, vx, vy, cos_h, sin_h]

Action space: 3 {SLOWER, NO-OP, FASTER}


Hyper-parameters:

Gamma: 0.95
Replay buffer size: 15000
Batch size: 64
Exploration strategy: Epsilon greedy
Tau: 15000
Initial temperature: 1.0
Final temperature: 0.05

 

Evaluation:

Running the evaluation over 10 episodes with display enabled shows high scores and successful navigation through the intersection in most episodes (8 of the 10 episodes end with a positive score).

2025-02-01 10:14:16.251 Python[52966:5211290] +[IMKClient subclass]: chose IMKClient_Modern
2025-02-01 10:14:16.251 Python[52966:5211290] +[IMKInputSession subclass]: chose IMKInputSession_Modern
/Users/mdahra/workspace/machine-learning/rl-interp/.venv/lib/python3.12/site-packages/rl_agents/agents/deep_q_network/pytorch.py:80: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_new.cpp:281.)
  return self.value_net(torch.tensor(states, dtype=torch.float).to(self.device)).data.cpu().numpy()
[INFO] Episode 0 score: 8.6 
[INFO] Episode 1 score: 5.5 
[INFO] Episode 2 score: 3.0 
[INFO] Episode 3 score: 9.6 
[INFO] Episode 4 score: 8.5 
[INFO] Episode 5 score: -1.0 
[INFO] Episode 6 score: 9.0 
[INFO] Episode 7 score: -1.0 
[INFO] Episode 8 score: 6.5 
[INFO] Episode 9 score: 7.6 


Learned Attention scores:

Fig 7a: QK scores comparison between untrained and trained agents.
The difference is computed as a per-term squared difference of the 2-dim tensors.
Fig 7b: OV scores comparison between untrained and trained agents.
The difference is computed as a per-term squared difference of the 2-dim tensors.
  1. ^ Social Attention for Autonomous Decision-Making in Dense Traffic

  2. ^ Understanding RL Vision: https://distill.pub/2020/understanding-rl-vision/

  3. ^ A Mathematical Framework for Transformer Circuits: https://transformer-circuits.pub/2021/framework/index.html
