graphpatch: a Python Library for Activation Patching

post by Occam's Laser (evan-lloyd) · 2024-06-05T15:08:47.416Z · 2 comments

Contents

  What is graphpatch?
  Why is graphpatch?
  How do I graphpatch?
2 comments

This post is an announcement for a software library. It is likely only relevant to those working, or looking to start working, in mechanistic interpretability.


What is graphpatch?

graphpatch is a Python library for activation patching on arbitrary PyTorch neural network models. It is designed to minimize the boilerplate needed to run experiments that make causal interventions on a model's intermediate activations. It provides an intuitive API based on the structure of a torch.fx.Graph representation compiled automatically from the original model. For a somewhat silly example, I can make Llama play Taboo by zero-ablating the lm_head output logits for the token representing "Paris":

with patchable_llama.patch(
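  # Zero the logit at vocab index 3681 (the "Paris" token) across all batch and sequence positions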
  {"lm_head.output": ZeroPatch(slice=(slice(None), slice(None), 3681))}
):
  print(
    tokenizer.batch_decode(
      patchable_llama.generate(
        tokenizer(
          "The Eiffel Tower, located in",
          return_tensors="pt"
        ).input_ids,
        max_length=20,
        use_cache=False,
      )
    )
  )

["<s> The Eiffel Tower, located in the heart of the French capital, is the most visited"]

Why is graphpatch?

graphpatch is a tool I wish had existed when I started my descent into madness (that is, my entry into mechanistic interpretability) with an attempt to replicate ROME on Llama. I hope that by reducing inconveniences (trivial and otherwise) I can both ease entry into the field and lower cognitive overhead for existing researchers. In particular, I want to make it easier to start running experiments on "off-the-shelf" models without first handling annoying setup, such as rewriting the model's Python code to expose intermediate activations. Thus, while graphpatch should work equally well on custom-built research models, I focused on integration with the Huggingface ecosystem.

How do I graphpatch?

graphpatch is available on PyPI and can be installed with pip:

pip install graphpatch

You can read an overview on the GitHub page for the project. Full documentation is available on Read the Docs.

I have also provided a Docker image that might be useful for starting your own mechanistic interpretability experiments on cloud hardware. See this directory for some scripts and notes on my development process, which may be adaptable to your own use case.

2 comments


comment by Charlie Steiner · 2024-06-07T12:07:44.785Z

This seems cool. Could you explain a little more why it's a pain to do ROME with vanilla hooks? I would have expected that, although it would be maybe 2x messier, it wouldn't require creating a custom model.

comment by Occam's Laser (evan-lloyd) · 2024-06-07T22:22:03.475Z

Thanks! You’re correct that you can implement ROME with vanilla hooks, since these give you access to module inputs in addition to the outputs. But the fact that this works is contingent on both the specific interventions ROME makes and the way Llama/GPT2 happen to be implemented. To get maybe overly concrete, in this line

return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

ROME wants the result of the multiplication, which isn't the output of any individual submodule. You happen to be able to access it as the input of down_proj, because that happens to be a module, but it didn't have to be implemented this way. (This would be even worse if we wanted to patch the value instead of just observing it, since we'd have to patch every consumer, and those would all also have to be modules or we'd be SOL.) It's easy to imagine ROME-adjacent experiments you might want to run but simply can't with module hooks alone, which bothered me. The TransformerLens answer to this is to wrap everything in a submodule (HookPoint), which works well enough for the models that have already been converted, but it struck me as a sufficiently "wrong" approach (hard to maintain, requires upfront work for every new model) that I wrote a library about it :)
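
To make the contrast concrete, here is a rough sketch of the vanilla-hook workaround described above, using plain PyTorch forward pre-hooks rather than graphpatch; the attribute path and the names here are illustrative and assume a Llama-style transformers implementation with a model already loaded:

captured = {}

def capture_down_proj_input(module, args):
  # args[0] is the elementwise product act_fn(gate_proj(x)) * up_proj(x),
  # visible here only because down_proj happens to be an nn.Module.
  captured["mlp_product"] = args[0].detach()
  # Returning a replacement tuple here would patch the value as seen by
  # down_proj, but any other consumer of that intermediate would need its
  # own hook on its own module, which is the fragility described above.
  return None

# Illustrative attribute path; it depends on the specific implementation.
mlp = model.model.layers[0].mlp
handle = mlp.down_proj.register_forward_pre_hook(capture_down_proj_input)
# ... run a forward pass, then clean up:
handle.remove()

The point of graphpatch's compiled-graph approach is to avoid depending on these module boundaries in the first place.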