Arrakis - A toolkit to conduct, track and visualize mechanistic interpretability experiments.
post by Yash Srivastava (yash-srivastava) · 2024-07-17T02:02:50.492Z · 2 comments
Introduction
“The greatest enemy of knowledge is not ignorance, it is the illusion of knowledge.” -Daniel J. Boorstin
Understanding how we think is a question that has perplexed us for a long time. There have been countless theories, many thought experiments, and a lot of experiments trying to unravel the mysteries of the brain. Over time, we have learned more about how our brain works, and in my honest (and admittedly biased) opinion, Artificial Intelligence has come the closest to modeling our mystery organ.
This is one of the reasons why Interpretability as a field makes a lot of sense to me. It tries to unravel the inner workings of one of the most successful proxies of the human brain: the Large Language Model. Mechanistic Interpretability (MI) is an approach in AI alignment that aims to reverse engineer neural networks and understand their inner workings.
Although this field is really exciting (and challenging), researchers have made significant progress in coming up with hypotheses and conducting experiments to test them. Heavyweights like Anthropic, Google DeepMind, EleutherAI, Leap Labs and many open-source organizations have been pushing the boundaries of knowledge in this field, and many researchers are pivoting to Interpretability to advance it. With recent breakthroughs such as Dictionary Learning by the Anthropic team, there is a lot of scope in MI.
What is Arrakis and Why do we need it?
Arrakis is a tool I've been building for the past four months (even longer if I include the ideation period) to give researchers working in MI a way to quickly iterate on their ideas and run experiments. Arrakis is a complete suite for conducting MI experiments. It is still very much in its infancy, and I hope to evolve this project with the support of the community.
Key Differentiator
The speed at which a researcher innovates is limited by their iteration speed.
The life of a research engineer mostly consists of coming up with ideas and testing them at different scales. The more experiments you run and the more hypotheses you test, the more you'll realize that the biggest roadblock for any such project is iteration speed. Arrakis is built to remove that roadblock. The core principle behind Arrakis is decomposability: it provides 10+ plug-and-play tools (plus the ability to make your own) for common MI experiments. On top of that, features such as version control, model profiling and logging are provided as add-ons. This keeps experimentation flexible while reducing iteration time. Everything in Arrakis is built in this plug-and-play fashion. Read more about the available tools in the Arrakis documentation.
Walkthrough
Let us look at how Arrakis works in practice. Let's first install the package.
Install Arrakis
Arrakis is available as a package directly from PyPI. Here's how to install it and start conducting MI experiments:
pip install arrakis-mi
Create a HookedAutoModel from HookedAutoConfig
HookedAutoModel is a wrapper around the Hugging Face PreTrainedModel, and the only difference between them is a single decorator on the forward function. Everything works out of the box, and the functionality can be removed without affecting the model simply by removing the decorator. First, create a HookedAutoConfig for the model you want to use, with the required parameters. Then, create a HookedAutoModel from the config. As of now, these models are supported:
gpt2, gpt-neo, gpt-neox, llama, gemma, phi3, qwen2, mistral, stable-lm
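To make the "single decorator on the forward function" idea concrete, here is a minimal, framework-free sketch of how a decorator can record intermediate activations without changing what the forward pass returns. Everything in it (`hooked_forward`, `ToyModel`, the `cache` attribute) is hypothetical and is not Arrakis's actual implementation; it only illustrates why removing the decorator leaves the model's behavior unchanged.

```python
import functools


def hooked_forward(forward):
    """Hypothetical decorator: pass a hook into forward() that records named activations."""

    @functools.wraps(forward)
    def wrapper(self, *args, **kwargs):
        self.cache = {}  # fresh activation cache for every forward pass

        def hook(name, value):
            self.cache[name] = value

        return forward(self, *args, hook=hook, **kwargs)

    return wrapper


class ToyModel:
    """Hypothetical stand-in for a model: two 'layers' applied in sequence."""

    def __init__(self):
        self.layers = [lambda x: 2 * x, lambda x: x + 1]
        self.cache = {}

    @hooked_forward
    def forward(self, x, hook=None):
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if hook is not None:  # without the decorator, no hook is ever passed
                hook(f"layer.{i}", x)
        return x


toy = ToyModel()
print(toy.forward(3))  # 7
print(toy.cache)       # {'layer.0': 6, 'layer.1': 7}
```

The undecorated function is still reachable as `ToyModel.forward.__wrapped__` (set by `functools.wraps`), and calling it gives the same output while leaving the cache empty.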
Here's what the code for creating a model looks like:
from arrakis.src.core_arrakis.activation_cache import *
config = HookedAutoConfig(name="llama",
                          vocab_size=50256,
                          hidden_size=8,
                          intermediate_size=2,
                          num_hidden_layers=4,
                          num_attention_heads=4,
                          num_key_value_heads=4)
model = HookedAutoModel(config)
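The example config is deliberately tiny: `hidden_size=8` split over 4 attention heads leaves 2 dimensions per head. Assuming HookedAutoConfig's fields mean the same thing as in Hugging Face's Llama config (an assumption worth checking against the Arrakis docs), the hidden size must divide evenly across attention heads, and the query heads must divide evenly across key/value heads. A hypothetical helper, not part of Arrakis, for sanity-checking a toy config before building a model:

```python
def check_attention_dims(hidden_size, num_attention_heads, num_key_value_heads):
    """Hypothetical helper: return (head_dim, query_heads_per_kv_head) or raise ValueError."""
    if hidden_size % num_attention_heads != 0:
        raise ValueError("hidden_size must be divisible by num_attention_heads")
    if num_attention_heads % num_key_value_heads != 0:
        raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
    head_dim = hidden_size // num_attention_heads
    # 1 query head per key/value head is ordinary multi-head attention;
    # more than 1 means grouped-query attention.
    return head_dim, num_attention_heads // num_key_value_heads


print(check_attention_dims(hidden_size=8, num_attention_heads=4, num_key_value_heads=4))  # (2, 1)
```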
Set Up the Interpretability Bench
Interpretability bench is the workspace where researchers can conduct experiments. The base bench comes pre-equipped with many tools that help keep track of experiments. Just derive from BaseInterpretabilityBench and instantiate an object (exp in this case). This object provides many functions out of the box, depending on the “tool” you want to use for the experiment, and gives you access to the functions that the tool provides. You can also create your own tool (read about that in the documentation).
from arrakis.src.core_arrakis.base_bench import BaseInterpretabilityBench
class MIExperiment(BaseInterpretabilityBench):
    def __init__(self, model, save_dir="experiments"):
        super().__init__(model, save_dir)
        # CustomFunction stands for your own tool class (see the docs on custom tools).
        self.tools.update({"custom": CustomFunction(model)})
exp = MIExperiment(model)
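The `self.tools.update(...)` line suggests the bench keeps a mapping from tool names to tool objects. Here is a minimal pure-Python sketch of that pattern, assuming a tool is simply an object constructed with the model. `ToyBench`, `CountLayersTool` and the dict-based "model" are hypothetical illustrations, not Arrakis classes.

```python
class ToyBench:
    """Hypothetical stand-in for a bench that holds a registry of tools."""

    def __init__(self, model):
        self.model = model
        self.tools = {}  # tool name -> tool instance


class CountLayersTool:
    """Hypothetical custom tool: reports how many layers the model has."""

    def __init__(self, model):
        self.model = model

    def num_layers(self):
        return len(self.model["layers"])


class MyExperiment(ToyBench):
    def __init__(self, model):
        super().__init__(model)
        # Same pattern as the Arrakis snippet: register a tool under a name.
        self.tools.update({"count_layers": CountLayersTool(model)})


toy_exp = MyExperiment({"layers": ["attn", "mlp", "attn", "mlp"]})
print(toy_exp.tools["count_layers"].num_layers())  # 4
```

Keeping tools in a plain dictionary is what makes them swappable: adding a capability means registering one more entry rather than subclassing the bench again.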
Apart from access to MI tools, the object also provides a convenient way to log your experiments. To log an experiment, just add the @exp.log_experiment decorator to the function you are working with, and that is pretty much it. The decorator creates a local version control of the function's contents, its arguments and its results, and stores them locally in a JSON file. You can run many experiments in parallel, and the version control helps you keep track of them.
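The versioning idea described above can be sketched with the standard library alone. This is a minimal illustration, not Arrakis's implementation: `log_experiment` here is a hypothetical stand-alone decorator that keys each version of a function by a hash of its source and appends every run's arguments and result to a JSON file.

```python
import functools
import hashlib
import inspect
import json
import tempfile
from pathlib import Path


def log_experiment(log_path):
    """Hypothetical decorator: version a function by a hash of its source
    and append every run (arguments and result) to a JSON log."""

    def decorator(func):
        try:
            source = inspect.getsource(func)
        except (OSError, TypeError):
            # Source unavailable (e.g. defined interactively): fall back to the bytecode.
            source = func.__code__.co_code.hex()
        version = hashlib.sha256(source.encode()).hexdigest()[:12]

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            result = func(*args, **kwargs)
            path = Path(log_path)
            log = json.loads(path.read_text()) if path.exists() else {}
            entry = log.setdefault(func.__name__, {}).setdefault(
                version, {"source": source, "runs": []}
            )
            entry["runs"].append(
                {"args": repr(args), "kwargs": repr(kwargs), "result": repr(result)}
            )
            path.write_text(json.dumps(log, indent=2))
            return result

        return wrapper

    return decorator


log_file = Path(tempfile.mkdtemp()) / "experiments.json"


@log_experiment(log_file)
def attention_experiment():
    return 4


attention_experiment()
versions = list(json.loads(log_file.read_text())["attention_experiment"])
print(versions)  # a list holding one 12-character hash
```

Because the version key is a hash of the source, editing the function body starts a new version, while re-running unchanged code appends to the existing one. This sketch does no file locking, so parallel runs writing to the same file could overwrite each other.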
# Step 1: Create a function where you can do operations on the model.
@exp.log_experiment # This is pretty much it. This will log the experiment.
def attention_experiment():
    print("This is a placeholder for the experiment. Use as is.")
    return 4
# Step 2: Then, run the function and get the results. This starts the experiment.
attention_experiment()
# Step 3: Then, look at some of the things the logs keep track of.
versions = exp.list_versions("attention_experiment") # This gives the hashes of the experiment's content.
print("These are the version hashes of the experiment: ", versions)
# Step 4: You can also get the content of the experiment from the saved JSON.
print(exp.get_version("attention_experiment", versions[0])['source']) # This gives the content of the experiment.
# Apart from these tools, there are also `@exp.profile_model` (to profile how many resources the model uses)
# and `@exp.test_hypothesis` (to test hypotheses). Support for more tools will be added as I get more feedback from the community.
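The post doesn't describe how `@exp.profile_model` measures resources. As a stand-in, here is a minimal stdlib-only profiling decorator that records wall-clock time and peak traced Python heap allocation for the last call. `profile_call` is hypothetical and is not Arrakis's implementation.

```python
import functools
import time
import tracemalloc


def profile_call(func):
    """Hypothetical profiler: store the duration and peak traced allocation of the last call.

    Assumes nothing else in the process is using tracemalloc, since this stops it afterwards.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        tracemalloc.start()
        start = time.perf_counter()
        try:
            return func(*args, **kwargs)
        finally:
            elapsed = time.perf_counter() - start
            _, peak_bytes = tracemalloc.get_traced_memory()
            tracemalloc.stop()
            wrapper.last_profile = {"seconds": elapsed, "peak_bytes": peak_bytes}

    return wrapper


@profile_call
def build_squares(n):
    return [i * i for i in range(n)]


squares = build_squares(10_000)
print(build_squares.last_profile)
```

`tracemalloc` only sees Python-level allocations; for a PyTorch model on a GPU, something like `torch.cuda.max_memory_allocated()` would be the more relevant number.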
Create Your Experiments
By default, Arrakis provides many of Anthropic’s Interpretability experiments as tools (Monosemanticity, Read Write Analysis, Feature Visualization and more), as well as EleutherAI’s (Model Surgery). Because these are provided as tools, you can plug and play with them in your own experiments. Here’s an example of how you can do that:
# Making functions for Arrakis to use is pretty easy. Let's look at it in action.
# Step 1: Create a function where you can do operations on the model. Think of all the tools you might need for it.
# Step 2: Use the @exp.use_tools decorator on it, passing the name of the tool.
# Step 3: The extra argument gives you access to the tool. Done.
@exp.use_tools("write_read") # use the `exp.use_tools()` decorator.
def read_write_analysis(read_layer_idx, write_layer_idx, src_idx, write_read=None): # pass an additional argument.
    # Multi-hop attention (write-read)
    # use the extra argument as a tool.
    write_heads = write_read.identify_write_heads(read_layer_idx)
    read_heads = write_read.identify_read_heads(write_layer_idx, dim_idx=src_idx)
    return {
        "write_heads": write_heads,
        "read_heads": read_heads
    }
print(read_write_analysis(0, 1, 0)) # Perfecto!
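The "extra argument" mechanism above is a form of dependency injection through a decorator. Here is a minimal sketch of how such a decorator could work, assuming tools live in a name-keyed registry; `TOOLS`, `use_tools`, `HeadLister` and `list_heads` are hypothetical and are not Arrakis's API.

```python
import functools

# Hypothetical global registry: tool name -> tool object.
TOOLS = {}


def use_tools(*names):
    """Hypothetical decorator: inject each named tool as a keyword argument,
    unless the caller already passed one explicitly."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for name in names:
                kwargs.setdefault(name, TOOLS[name])
            return func(*args, **kwargs)

        return wrapper

    return decorator


class HeadLister:
    """Hypothetical toy tool: names the heads in a layer."""

    def __init__(self, n_heads):
        self.n_heads = n_heads

    def heads_in_layer(self, layer_idx):
        return [f"L{layer_idx}H{h}" for h in range(self.n_heads)]


TOOLS["head_lister"] = HeadLister(n_heads=4)


@use_tools("head_lister")
def list_heads(layer_idx, head_lister=None):
    return head_lister.heads_in_layer(layer_idx)


print(list_heads(1))                                     # ['L1H0', 'L1H1', 'L1H2', 'L1H3']
print(list_heads(1, head_lister=HeadLister(n_heads=2)))  # ['L1H0', 'L1H1']
```

Using `setdefault` is one design choice: an explicitly passed tool takes precedence over the registered one, which makes it easy to substitute a stub when testing an experiment function.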
Visualize the Results
Generating plots in Arrakis is also plug and play: just add the decorator, and plots are generated by default. The plan is to incorporate a library similar to CircuitsVis, with implementations of commonly used functions. As of now, much of the work in this area is still pending, and I want to prioritize it based on the feedback I receive from the community. Currently, a Matplotlib wrapper is supported by default, and more functions will be added as this project evolves.
import matplotlib.pyplot as plt
import torch
from arrakis.src.graph.base_graph import *
# Step 1: Create a function in which you want to draw a plot.
# Step 2: Set the plotting library, then use the @exp.plot_results decorator on the function with the plot spec as an argument.
# Pass input_ids here as well (this part of the interface is still being worked out).
# Step 3: The extra argument gives you access to the figure. Done.
input_ids = torch.randint(0, config.vocab_size, (1, 8)) # Example batch of token ids (batch x sequence).
exp.set_plotting_lib(MatplotlibWrapper) # Set the plotting library.
@exp.plot_results(PlotSpec(plot_type="attention", data_keys="h.1.attn.c_attn"), input_ids=input_ids) # use the `exp.plot_results()` decorator.
def attention_heatmap(fig=None): # pass an additional argument.
    return fig
attention_heatmap() # Done.
plt.show()
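The key `h.1.attn.c_attn` follows Hugging Face's GPT-2 module naming: `c_attn` is layer 1's fused projection, whose output holds the queries, keys and values concatenated along the last dimension. As a rough sketch of what an attention heatmap typically shows, assuming the plotted quantity is the standard causal softmax pattern (an assumption; Arrakis may compute it differently), here is a pure-Python version for one head. `head_slices` and `causal_attention_pattern` are hypothetical helpers.

```python
import math


def head_slices(c_attn_out, d_model, n_heads, head):
    """Slice one head's query and key vectors out of GPT-2-style c_attn rows.

    Each row is the concatenation [q | k | v], each part d_model wide.
    """
    head_dim = d_model // n_heads
    lo, hi = head * head_dim, (head + 1) * head_dim
    q = [row[lo:hi] for row in c_attn_out]
    k = [row[d_model + lo:d_model + hi] for row in c_attn_out]
    return q, k


def causal_attention_pattern(q, k):
    """Row i holds softmax(q_i . k_j / sqrt(head_dim)) over positions j <= i, zeros after."""
    seq_len, head_dim = len(q), len(q[0])
    pattern = []
    for i in range(seq_len):
        scores = [
            sum(a * b for a, b in zip(q[i], k[j])) / math.sqrt(head_dim)
            for j in range(i + 1)
        ]
        top = max(scores)  # subtract the max before exponentiating, for numerical stability
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        pattern.append([w / total for w in weights] + [0.0] * (seq_len - i - 1))
    return pattern


# Three token positions, d_model = 2, one head: each row is [q | k | v].
c_attn_out = [
    [1.0, 0.0, 1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 1.0, 0.0, 0.0],
    [1.0, 1.0, 1.0, 0.0, 0.0, 0.0],
]
q, k = head_slices(c_attn_out, d_model=2, n_heads=1, head=0)
for row in causal_attention_pattern(q, k):
    print([round(w, 3) for w in row])
```

The resulting lower-triangular matrix (each row sums to 1) is what a heatmap would render with query positions on one axis and key positions on the other.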
That's the complete walkthrough of Arrakis.
Conclusion and Future Work
Making the initial version of Arrakis was a long process; it took me around 4-5 months to go from ideation to here. Arrakis is not meant to rival existing libraries like TransformerLens, nnsight or Garcon (all of which I admire and use a lot), but is a different approach to streamlining the work done in an MI research project. Conducting original research in this domain is my dream, and Arrakis is my way to learn more about it.
That being said, a lot of work remains to be done on this project. So far I've received a lot of positive feedback from the community, and I would love to keep contributing to it. Here are the links associated with the project if you want to check it out and contribute (or give feedback):
- Project Name: Arrakis
- Project Description: Arrakis is a library to conduct, track and visualize mechanistic interpretability experiments.
- Project Home: GitHub Repo
- Project Documentation: Read the Docs
- PyPI Home: PyPI Home
- Twitter: Twitter Thread
2 comments
comment by Daniel Tan (dtch1997) · 2024-07-17T08:48:48.107Z
Really interesting! I'm a big proponent of improving the standards of infrastructure in the mech interp community.
Some questions:
- Have you used other things like TransformerLens and NNsight and found those to be insufficient in some way? Your library seems to diverge fundamentally from both of those implementations (pytorch hooks in the former case and "proxy variables" in the latter case). I'm curious about the motivating use case here.
- Do you have examples of reproducing specific mech interp analyses using your library? E.g. Neel Nanda's Indirect Object Identification tutorial, or other simple things like doing activation patching / logit lens.
↑ comment by Yash Srivastava (yash-srivastava) · 2024-07-24T07:10:31.167Z
Thanks a lot for the read. To answer your questions:
1. I am a regular user of TransformerLens (not so much of NNsight), and one of the things that bugged me a lot is the lack of abstractions for common operations (ablations, head compositions, model surgery etc.), so I thought of implementing them myself. In terms of architecture, what I've planned is an outline similar to Meta's Hydra, where you run your experiments from config files and the library does the grunt work. I'm still open to ideas and have been talking about it with people from the OS community.
2. In my docs, I have included example usage of all the tools that currently work (for supported models). There are usage examples for common attention operations (merging/ablating heads), removing/permuting layers, and others such as sparsity analysis and polysemantic scores. I will try to push heavier tutorials, such as IOI, in the near future.