Interpreting autonomous driving agents with attention based architecture
Post by Manav Dahra (manav-dahra) · 2025-02-01
Abstract
In this experiment, I study the behaviour of a Deep Q-Network (DQN) agent with an attention-based architecture driving a vehicle in a simulated traffic environment. The agent learns to cross an intersection without traffic lights, in a dense traffic setting, while avoiding collisions with other vehicles.
In the first part of this experiment, I train the agent using the attention-based architecture. I then study the behaviour of the agent by applying some interpretability techniques to the trained Q-network, and find evidence that the network comprises three groups of layers serving specific functions, namely sensory (embedding layers), processing (attention layer) and motor (output layers).
The purpose of this experiment is to gain a deeper understanding of the agent from an interpretability perspective, which may be useful for developing safer agents in real-world applications.
Introduction
With the increasing usage and deployment of autonomous driving vehicles on roads, it is important that the behaviour of these autonomous agents is thoroughly tested, understood and analysed from a safety perspective.
To measure safety, one of the traditional approaches involves running a large number of experiments in a simulated environment and collecting statistics on the number of failure cases. While this is certainly useful and gives an overall perspective on the failure modes of the agent, it does not say anything about the specifics of those failures.
One may prefer to delve deeper, study why a particular agent failed, and understand whether that behaviour was a consequence of the agent's actions or a condition of the environment. This calls for applying interpretability techniques to the agent's behaviour and deriving more specific conclusions about what features the agent senses from the environment and what decisions it takes.
For this experiment, I study the behaviour of a trained agent by applying some interpretability techniques on the policy network of the model and share my observations and conclusions derived from the experiment.
The agent under study is only trained and deployed in a simulated environment (with significant simplifications) and is far from a real-world setting and its complexities. While this does not fully represent the behaviour of the agent in the real world, I still think a study like this can be worthwhile in providing some insight into what kind of decision-making process the agent learns and how that insight can be used to make agents safer.
Now, to give some context on the problem, let's understand:
- how the agent and the environment interact
- the model architecture of the agent
Then I will share my observations and insights.
Environment
The environment used in this experiment is intersection-env, a customised Gymnasium environment with the standard agent-environment loop.
The environment contains a total of N = 15 vehicles at any given point in time.
The agent in question controls the green vehicle, while the blue vehicles are simulated by a traffic flow model, in this case the Intelligent Driver Model (IDM). The IDM is less nuanced and lacks any complex behaviour in comparison to the agent in question. The blue vehicles are spawned at random points initially.
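For context, here is a minimal sketch of the standard IDM acceleration rule that governs the blue vehicles; the parameter values below are illustrative defaults, not necessarily the ones used by the environment.

```python
import math

def idm_acceleration(v, v_lead, gap, v0=10.0, T=1.5, a_max=3.0, b=2.0, delta=4.0, s0=2.0):
    """Intelligent Driver Model acceleration (textbook form, illustrative parameters).

    v: own speed, v_lead: speed of the leading vehicle, gap: distance to it.
    """
    dv = v - v_lead                                              # approach rate
    s_star = s0 + v * T + v * dv / (2 * math.sqrt(a_max * b))    # desired gap
    return a_max * (1 - (v / v0) ** delta - (s_star / max(gap, 1e-6)) ** 2)
```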
And below is the animated image of a trained agent crossing the intersection.
State
The joint observation of road traffic with one agent (denoted $s_0$) and $N$ other vehicles is described by a combined list of individual vehicle states:

$$S = (s_0, s_1, \dots, s_N)$$

where each individual state $s_i$ is a vector of the features described below:
| Feature | Description |
|---|---|
| presence | Disambiguates agents at zero offset from non-existent agents. |
| x | World offset of the agent's vehicle, or offset to the agent's vehicle, on the x axis. |
| y | World offset of the agent's vehicle, or offset to the agent's vehicle, on the y axis. |
| vx | Velocity of the vehicle on the x axis. |
| vy | Velocity of the vehicle on the y axis. |
| heading | Heading of the vehicle in radians. |
| cosh | Trigonometric heading of the vehicle (cosine). |
| sinh | Trigonometric heading of the vehicle (sine). |
The vehicle kinematics are described by the Kinematic Bicycle Model. More on this topic can be found here.
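As a reference for the kinematics, here is a minimal sketch of a discrete-time kinematic bicycle model step; the time step and wheelbase are illustrative values, not the environment's exact parameters.

```python
import numpy as np

def bicycle_step(x, y, v, heading, accel, steering, dt=0.1, wheelbase=2.5):
    """One step of a simple kinematic bicycle model (illustrative, not the env's exact code).

    x, y: position; v: speed; heading: yaw angle (rad);
    accel: longitudinal acceleration; steering: front wheel angle (rad).
    """
    beta = np.arctan(0.5 * np.tan(steering))                 # slip angle at the centre of gravity
    x += v * np.cos(heading + beta) * dt
    y += v * np.sin(heading + beta) * dt
    heading += (v / (0.5 * wheelbase)) * np.sin(beta) * dt   # yaw rate from the bicycle geometry
    v += accel * dt
    return x, y, v, heading
```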
Actions & Rewards:
The agent drives the vehicle by controlling its speed, chosen from a finite set of actions A = {SLOWER, NO-OP, FASTER}.
Rewards:
| Reward | Condition |
|---|---|
| 1 | Agent driving at maximum velocity |
| -5 | On collision |
| 0 | Otherwise |
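In code, this reward scheme amounts to something like the following sketch (illustrative only, not the environment's exact implementation):

```python
def compute_reward(crashed: bool, at_max_velocity: bool) -> float:
    """Reward scheme from the table above: -5 on collision, +1 at maximum velocity, 0 otherwise."""
    if crashed:
        return -5.0
    return 1.0 if at_max_velocity else 0.0
```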
Agent:
The agent used in this experiment uses a DQN algorithm with an attention-based architecture, first proposed in the paper Social Attention for Autonomous Decision-Making in Dense Traffic [1].
For this experiment, I delve deeper into the agent's policy network, since that network encodes the decision making of the agent.
Here is what the network looks like:
| Layer Name | Dimensions |
|---|---|
| Ego & Others embedding layer - 0 | 7x64 |
| Ego & Others embedding layer - 1 | 64x64 |
| Attention Layer Query | 64x64 |
| Attention Layer Key | 64x64 |
| Attention Layer Value | 64x64 |
| Output Layer - 0 | 64x64 |
| Output Layer - 1 | 64x64 |
| Output Layer - predict | 64x3 |
Embedding layers:
The network is composed of several identical linear encoders, a stack of attention heads, and a linear decoder.
There are two embedding layers:
1. Ego embedding - dedicated to tracking features of the vehicle driven by the agent itself.
2. Others embedding - dedicated to tracking features of the vehicles driven by other agents.
Attention layers:
Essentially, a single query $Q = [q_0]$, emitted from the ego vehicle's embedding, and a set of keys $K = [k_0, \dots, k_{N-1}]$, one per vehicle, are produced by linear projections of the embedded state of the environment. Here, N is the number of vehicles, including the agent's own vehicle.
The outputs from all heads are finally combined with a linear layer, and the resulting tensor is added back to the ego encoding through a residual connection.
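To make the attention computation concrete, here is a minimal single-head sketch of the ego-attention step, using the module names from the architecture in the appendix; the real model splits these projections into 4 heads and masks absent vehicles, which is omitted here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SingleHeadEgoAttention(nn.Module):
    """Simplified, single-head sketch of the EgoAttention block (the trained model uses 4 heads)."""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.query_ego = nn.Linear(dim, dim, bias=False)         # query from the ego embedding only
        self.key_all = nn.Linear(dim, dim, bias=False)           # keys from all vehicle embeddings
        self.value_all = nn.Linear(dim, dim, bias=False)         # values from all vehicle embeddings
        self.attention_combine = nn.Linear(dim, dim, bias=False)

    def forward(self, ego_emb: torch.Tensor, all_emb: torch.Tensor) -> torch.Tensor:
        # ego_emb: (batch, 1, dim); all_emb: (batch, N, dim) with the ego vehicle first
        q = self.query_ego(ego_emb)                              # (batch, 1, dim)
        k = self.key_all(all_emb)                                # (batch, N, dim)
        v = self.value_all(all_emb)                              # (batch, N, dim)
        scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5    # (batch, 1, N) attention scores
        attn = F.softmax(scores, dim=-1)                         # attention over vehicles
        out = attn @ v                                           # (batch, 1, dim) weighted summary
        return self.attention_combine(out) + ego_emb             # residual connection to the ego stream
```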
The authors of the paper claim that an agent with the proposed attention architecture shows performance gains in autonomous decision making under dense traffic. Their study compares the agent's performance against common architectures such as FCN and CNN, and the social interaction patterns with other vehicles are visualised and studied qualitatively.
Methods
As part of this experiment, I first replicate the behaviour of the agent as described by the authors of the paper. I then study the agent's behaviour by borrowing some well-recognised interpretability techniques from the literature, such as Understanding RL Vision [2] and A Mathematical Framework for Transformer Circuits [3].
After collecting the observations, I derive some key insights on the behaviour of the agent and mention some interesting directions for work in future.
The techniques applied in this experiment are as follows:
- Replication of agent behaviour
- Analysis of embedding matrices
- QK & OV circuit analysis of the transformer block
- Study of activation patterns over an episode run
- Analysis of feature importance with respect to the output layer
Source code for training the agent and details on choice of hyper-parameters, model architecture parameters and extended results are in the Appendix section.
Replication of agent behaviour
The first step is to replicate the results of the paper by training and evaluating the agent in a dense traffic setting with a single road intersection and no traffic lights.
According to my observations, I confirm that the agent successfully learns to cross the intersection while avoiding collisions with other vehicles in most scenarios.
The following clip shows the trained agent navigating through the intersection, along with its attention pattern at the time step when the agent decides to slow down after noticing another vehicle in its way.
Notes:
The above results confirm that the agent learns to navigate through the crossing while avoiding collisions in most scenarios, by paying attention to the other vehicles on the crossing. Attention to other vehicles at every time step of the episode is highlighted by thick coloured lines from the green vehicle to the blue vehicles.
Analysis on Embedding matrices:
First, I analyse what insights I can find in the weights of the embedding matrices. Each embedding layer consists of two linear layers: $W_0$ with dimension 7x64 and $W_1$ with dimension 64x64. The combined embedding matrix can then be computed as

$$W_E = W_0 W_1$$

which has dimension 7x64.
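A minimal sketch of how such a combined embedding matrix could be extracted from the trained network, assuming the module layout shown in the appendix; biases and the nonlinearity between the two layers are ignored, so this is only a linear approximation.

```python
import torch

def combined_embedding_matrix(mlp) -> torch.Tensor:
    """Fold the two linear layers of an embedding MLP into one 7x64 matrix.

    `mlp` is assumed to be one of the MultiLayerPerceptron embeddings from the
    appendix (layers[0]: 7->64, layers[1]: 64->64). Biases and the activation
    between the layers are ignored, so this is a linear approximation.
    """
    w0 = mlp.layers[0].weight.detach()   # (64, 7)
    w1 = mlp.layers[1].weight.detach()   # (64, 64)
    return (w1 @ w0).T                   # (7, 64): rows = input features, columns = hidden units
```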
Inspecting the combined embedding matrix for each of the two embedding layers (ego and others) reveals the following:
Notes:
- Ego embedding matrix - higher weight values assigned to the features `y` and `cosh`.
- Others embedding matrix - higher weight values assigned to the features `x`, `y`, `vy` and `sinh`.
- Heat maps show the other features are assigned relatively smaller weights.
This shows us that the sensory function of the model is picking up some interesting signals from the environment.
Next I analyse how these features interact with the attention layers of the network.
QK & OV circuit analysis of Transformer block
The Anthropic interpretability team outlines a mathematical framework for understanding attention layers [3]. Quoting from the paper:
Attention heads can be understood as having two largely independent computations: a QK (“query-key”) circuit which computes the attention pattern, and an OV (“output-value”) circuit which computes how each token affects the output if attended to.
Here, I study any QK and OV circuit patterns that emerge in the attention layer. To study the emergence of any learned structure, I compare the untrained and trained networks and compute their squared differences.
As seen from the code, the agent in this experiment has only one attention layer with 4 heads, whose query, key and value vectors are first sliced into 4 parts, computed per head, and then concatenated. Hence, for simplicity, I computed the QK and OV circuits for the combined matrices instead of slicing them into 4 parts.
Concretely, the QK circuit is computed as

$$\mathrm{QK} = W_E^{ego} \, W_Q^\top \, W_K \, (W_E^{others})^\top$$

where $W_E^{ego}$ and $W_E^{others}$ are the combined embedding matrices from the previous section and $W_Q$, $W_K$ are the attention layer's query and key matrices; it maps pairs of (ego input feature, other-vehicle input feature) to attention scores. The OV circuit is computed as

$$\mathrm{OV} = W_E^{others} \, W_V^\top \, W_C^\top$$

where $W_V$ is the value matrix and $W_C$ is the attention-combine matrix; it describes how each attended vehicle's input features are written to the attention layer's output.
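A minimal sketch of how these circuit matrices could be computed from the trained network, reusing the combined_embedding_matrix helper sketched above and the module names from the appendix; heads are not separated out, matching the simplification described in the text.

```python
import torch

def qk_ov_circuits(model):
    """Compute combined QK and OV circuit matrices for the single attention layer.

    Assumes the EgoAttentionNetwork layout from the appendix; biases and
    nonlinearities are ignored, so both circuits are linear approximations.
    """
    w_e_ego = combined_embedding_matrix(model.ego_embedding)        # (7, 64)
    w_e_others = combined_embedding_matrix(model.others_embedding)  # (7, 64)

    w_q = model.attention_layer.query_ego.weight.detach()           # (64, 64)
    w_k = model.attention_layer.key_all.weight.detach()             # (64, 64)
    w_v = model.attention_layer.value_all.weight.detach()           # (64, 64)
    w_c = model.attention_layer.attention_combine.weight.detach()   # (64, 64)

    # QK circuit: how ego input features attend to other-vehicle input features (7x7)
    qk = w_e_ego @ w_q.T @ w_k @ w_e_others.T
    # OV circuit: how attended other-vehicle input features are written to the output (7x64)
    ov = w_e_others @ w_v.T @ w_c.T
    return qk, ov
```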
QK and OV circuits:
Furthermore, the attention-score and output-value matrices shown in Fig 9a and 9b (in the appendix) exhibit some interesting learned structural patterns.
The above figures show clear distinctions between the two scenarios, one where the agent is untrained and one where it is trained. The QK and OV circuits show high activations along certain lines and areas of the heat maps, indicating a learned structure/pattern arising from the interplay between the agent, the other vehicles and the environment.
Notes:
- Fig 7 -
  - The agent attends to its own `y` coordinate with a strong positive correlation, while `vy` is attended to with a negative correlation.
  - This suggests the agent uses its own velocity and heading to assess risk when making decisions.
- Fig 8 -
  - Strong positive attention between the ego embedding's `y` coordinate and the other vehicles' `x`, `y` coordinates and `vy` velocity, hinting at the possibility of computing a distance metric.
  - Strong negative attention between the ego embedding's `y` coordinate and the others embedding's `presence` feature.
  - The above indicate that the model has learned to focus on other vehicles' movement to adjust its own strategy.
- Fig 9 -
  - Shows strong activations for velocity (`vx`, `vy`) and heading (`cos_h`, `sin_h`), meaning the model values its own speed and direction when choosing actions.
  - Presence and position (`x`, `y`) of other vehicles become crucial: the agent learns to consider the positions of other vehicles when deciding whether to slow down, idle, or speed up.
Study on activation patterns of an episode run:
Let's study the agent's activations per time step, together with the intermediate attention matrices collected over a full episode run.
Here, I extract the environment frames and the attention head vs. vehicle activations for each time step of the evaluation, which show the following:
Notes:
- Vertical lines of activation show that a particular vehicle is being attended to by all 4 attention heads to varying degrees.
- `Vehicle_0` is the green vehicle controlled by the agent, and is therefore always attended to by all 4 attention heads at every time step.
- For the other vehicles, the activations increase as they get closer to the intersection.
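For reference, here is a minimal sketch of how such per-step attention activations could be collected during an episode, using a PyTorch forward hook on the attention module; what exactly the hook captures (combined output vs. attention weights) depends on what EgoAttention.forward returns, so the recorded objects may need unpacking accordingly.

```python
import torch

def collect_attention_per_step(env, agent, attention_module, max_steps=30):
    """Run one episode and record the attention module's output at every time step.

    `attention_module` is assumed to be the `attention_layer` from the appendix
    architecture; `agent` is assumed to expose an rl-agents style act() method.
    """
    records = []

    def hook(module, inputs, output):
        records.append(output)          # capture whatever the attention module returns

    handle = attention_module.register_forward_hook(hook)
    try:
        obs, info = env.reset()
        for _ in range(max_steps):
            action = agent.act(obs)
            obs, reward, terminated, truncated, info = env.step(action)
            if terminated or truncated:
                break
    finally:
        handle.remove()
    return records
```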
Next, I study output layers further to understand what insights I can draw from there.
Feature importance
In this section I try to understand which features the model learns to extract from the environment state. One of the common techniques for finding feature importance is computing integrated gradients, which give an understanding of the overall importance of each feature.
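A minimal sketch of the integrated-gradients computation for a single observation, with respect to the Q-value of a chosen action; this is a manual Riemann-sum implementation with a zero baseline, and it assumes the Q-network maps an observation tensor of shape (1, N, 7) to Q-values of shape (1, 3), so the actual analysis setup may differ.

```python
import torch

def integrated_gradients(q_net, obs: torch.Tensor, action: int, steps: int = 50) -> torch.Tensor:
    """Integrated gradients of Q(obs, action) with respect to the input features.

    obs: tensor of shape (1, N, 7) - the full kinematic observation.
    Uses a zero baseline and averages gradients along the straight-line path
    from the baseline to obs, then scales by (obs - baseline).
    """
    baseline = torch.zeros_like(obs)
    total_grads = torch.zeros_like(obs)
    for i in range(1, steps + 1):
        alpha = i / steps
        point = (baseline + alpha * (obs - baseline)).detach().requires_grad_(True)
        q_value = q_net(point)[0, action]                 # Q-value of the chosen action
        grad = torch.autograd.grad(q_value, point)[0]     # gradient w.r.t. the interpolated input
        total_grads += grad
    return (obs - baseline) * total_grads / steps         # feature attributions, same shape as obs
```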
For this scenario, I compare the integrated gradients of the untrained and trained networks, averaged over 30 episodes, and find the following.
Notes:
The Q-net's output layer makes its final decision by looking at the `presence`, `x`, `y` and `sinh` features. All of the above graphs add further evidence to the notes/observations gathered earlier.
In the following section, I make some speculative interpretations about the agent. I would love to validate some of them by conducting more thorough experiments in the future. The claims I am not confident about are marked inline.
Interpretation:
The following is a walkthrough of the agent in action, along with my interpretation of the key observations/notes collected so far.
Step 1: Input Features
The agent observes its environment using the following features:
- Ego Features
  - `vx`, `vy`: ego vehicle's velocity (speed in the x and y directions).
  - `cos_h`, `sin_h`: ego vehicle's heading direction.
  - `x`, `y`: ego vehicle's position.
- Others Features
  - `presence`: indicates whether another vehicle is nearby.
  - `vx`, `vy`: other vehicle's velocity.
  - `x`, `y`: other vehicle's position.
Step 2: Attention Mechanism (QK Circuits)
The Query-Key (QK) circuits determine which features should be attended to when making a decision.
- Ego Embedding (self-awareness)
  - The agent queries its own `vx`, `vy`, `cos_h`, and `sin_h` to understand its motion.
  - It attends to `y` to determine its position in the intersection. (Why does the `x` coordinate not have high activations? Something I would like to investigate later.)
- Others Embedding (awareness of other vehicles)
  - The agent queries `presence` to check if another vehicle is nearby.
  - It attends to `vx`, `vy`, `x`, `y` of other vehicles to predict their movement.
  - If a vehicle is approaching, the attention on `presence` and `vx` increases. (Unverified claim: is the model computing some distance metric? Can we verify this?)
- When does the agent choose "Slow"?
  - If the presence of other vehicles (`presence` feature) is high.
  - If the relative velocity (`vx`, `vy`) suggests a collision risk.
  - Strong correlations in the others embedding QK circuit show that the agent reacts to nearby vehicles. (Unverified claim, same reason as above.)
- When does the agent choose "Idle"?
  - Likely in neutral situations where no immediate action is needed.
  - OV circuits show that `cos_h` and `sin_h` influence the decision, meaning the agent aligns with the road orientation.
- When does the agent choose "Fast"?
  - If there is clear space ahead (low `presence` activation in the others embedding).
  - The trained OV circuits show positive activations for `vx` and `sin_h`, meaning the model prefers accelerating when aligned with the road.
Step 3: Q-Value Calculation (OV Circuits)
The Output-Value (OV) circuits determine how much each feature contributes to the Q-values for each action.
Feature Contributions to Actions:
- If `presence` is high, the Slow action gets a higher weight.
- If the other car's `vx` is high, meaning it is moving fast toward the intersection, the agent reduces the Q-value for the Fast action.
- If the agent's own `vx` is high but a collision is possible, the Q-value for Idle is also reduced.
Discussion
This experiment shows that a DQN agent with an attention-based architecture can learn to cross a road intersection in a dense traffic setting with reasonable levels of safety.
Additionally, analysis of the attention layer of the agent's Q-network provides reasonable evidence that it learns some high-level policies that drive the decision making of the agent. Although it was possible to identify some of these high-level policies, more work is needed to find out how the different policies combine into a concrete algorithm.
It was shown, to some extent, that the agent learns to delegate different types of functions to its embedding, attention and output layers. It is fair to say that these layers serve functions analogous to sensory, processing and motor neurons respectively.
Future work:
This experiment was limited in scope and time (about 4 weeks). For this reason, I chose to focus on replicating the behaviour of the agent and running several types of interpretability techniques, to narrow down a promising approach for pinning down the agent's exact behaviour in further research.
Following are some of the areas that can be explored in future:
- Does the agent compute a distance metric from the `(x, y)` coordinates of the other vehicles?
- Does changing the y coordinate of the intersection in the environment break the agent's decision making? Has the agent really generalised, or simply memorised?
- Enumerate the different policies learned by the agent for the same action. For example, slowing down: high activations of the `presence` feature and high activations of the `vx` feature of other vehicles. Do these two policies correlate highly for the agent, or do they stay largely independent?
- Train the model with more attention heads and layers and for more episodes, then repeat the experiments. Do we get any new insights?
Acknowledgements:
My sincere thanks to this amazing community, which has made interpretability research easily accessible to the general public. I hope that my experiments bring some value to others and to this community. I look forward to delving deeper into this topic; any support and guidance is highly appreciated.
I would also like to thank BlueDot Impact for running a 12-week online course on AI Safety fundamentals. I conducted this experiment as part of the project submission phase of the course, and I am grateful to the course facilitators and their team for running excellent sessions and providing a comprehensive list of resources on the key topics.
I'm looking forward to collaborating. Reach out to me on
My portfolio
LinkedIn
Appendix:
Glossary:
DQN: Deep Q-Network
FCN: Fully Convolutional Net
CNN: Convolutional Neural Net
QK Circuit: Query-Key Circuit
OV Circuit: Output-Value Circuit
Model architecture:
EgoAttentionNetwork(
(ego_embedding): MultiLayerPerceptron(
(layers): ModuleList(
(0): Linear(in_features=7, out_features=64, bias=True)
(1): Linear(in_features=64, out_features=64, bias=True)
)
)
(others_embedding): MultiLayerPerceptron(
(layers): ModuleList(
(0): Linear(in_features=7, out_features=64, bias=True)
(1): Linear(in_features=64, out_features=64, bias=True)
)
)
(attention_layer): EgoAttention(
(value_all): Linear(in_features=64, out_features=64, bias=False)
(key_all): Linear(in_features=64, out_features=64, bias=False)
(query_ego): Linear(in_features=64, out_features=64, bias=False)
(attention_combine): Linear(in_features=64, out_features=64, bias=False)
)
(output_layer): MultiLayerPerceptron(
(layers): ModuleList(
(0-1): 2 x Linear(in_features=64, out_features=64, bias=True)
)
(predict): Linear(in_features=64, out_features=3, bias=True)
)
)
Environment configuration:
N: 15 (number of vehicles)
Observation type: Kinematics
Observation space: 7 features per vehicle
Action space: 3 {SLOWER, NO-OP, FASTER}
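For reference, here is one way such a configuration could be applied with highway-env and Gymnasium; the keys below are standard highway-env configuration fields, but the exact values used in the original training run are an assumption.

```python
import gymnasium as gym
import highway_env  # noqa: F401  (registers the intersection environments)

env = gym.make("intersection-v0")
env.unwrapped.configure({
    "observation": {
        "type": "Kinematics",
        "vehicles_count": 15,                 # N = 15 vehicles
        "features": ["presence", "x", "y", "vx", "vy", "cos_h", "sin_h"],
    },
    "action": {
        "type": "DiscreteMetaAction",         # SLOWER / IDLE / FASTER
        "longitudinal": True,
        "lateral": False,
    },
})
obs, info = env.reset()
```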
Hyper-parameters:
Gamma: 0.95
Replay buffer size: 15000
Batch size: 64
Exploration strategy: Epsilon greedy
Tau: 15000
Initial temperature: 1.0
Final temperature: 0.05
Evaluation:
Running evaluation over 10 episodes with display enabled shows high scores and successful navigation through the intersection.
[INFO] Episode 0 score: 8.6
[INFO] Episode 1 score: 5.5
[INFO] Episode 2 score: 3.0
[INFO] Episode 3 score: 9.6
[INFO] Episode 4 score: 8.5
[INFO] Episode 5 score: -1.0
[INFO] Episode 6 score: 9.0
[INFO] Episode 7 score: -1.0
[INFO] Episode 8 score: 6.5
[INFO] Episode 9 score: 7.6
Learned Attention scores:
1. Leurent, E. & Mercat, J. Social Attention for Autonomous Decision-Making in Dense Traffic.
2. Understanding RL Vision. https://distill.pub/2020/understanding-rl-vision/
3. A Mathematical Framework for Transformer Circuits. https://transformer-circuits.pub/2021/framework/index.html