Induction heads - illustrated

post by CallumMcDougall (TheMcDouglas) · 2023-01-02T15:35:20.550Z · LW · GW · 9 comments

Contents

  TL;DR
  Introduction
  Prerequisites
  Q-composition
9 comments

 Many thanks to everyone who provided helpful feedback, particularly Aryan Bhatt and Lawrence Chan!

TL;DR

This is my illustrated walkthrough of induction heads. I created it in order to concisely capture all the information about how the circuit works.

There are 2 versions of the walkthrough:

The final image from version 1 is inline below, and depending on your level of familiarity with transformers, looking at this diagram might provide most of the value of this post. If it doesn't make sense to you, then read on for the full walkthrough, where I build up this diagram bit by bit.

 

 

Introduction

Induction heads are a well-studied and understood circuit in transformers. They allow a model to perform in-context learning of a very specific form: if a sequence contains a repeated subsequence, e.g. of the form A B ... A B (where A and B stand for generic tokens, e.g. the first and last name of a person who doesn't appear in any of the model's training data), then the second time this subsequence occurs the transformer will be able to predict that B follows A. Although this might seem like a weirdly specific ability, it turns out that induction circuits are actually a pretty massive deal. They're present even in large models (despite being originally discovered in 2-layer models), they can be linked to macro-level effects like bumps in loss curves during training, and there is some evidence that induction heads may constitute the mechanism for the majority of all in-context learning in large transformer models.

I think induction heads can be pretty confusing unless you fully understand the internal mechanics, and it's easy to come away from them feeling like you get what's going on without actually being able to explain things down to the precise details. My hope is that these diagrams help people form a more precise understanding of what's actually going on.
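To make the behaviour concrete before diving into the mechanics, here's a minimal sketch (in plain Python) of the prediction rule an induction circuit implements: look back for an earlier occurrence of the current token, and predict whatever followed it. The function name and toy sequence are my own illustration, not anything from a real model.

```python
def induction_prediction(tokens):
    """Toy version of the rule an induction circuit implements: if the current
    token appeared earlier in the sequence, predict the token that followed it."""
    current = tokens[-1]
    for i in range(len(tokens) - 2, 0, -1):   # scan backwards over earlier positions
        if tokens[i - 1] == current:          # found an earlier occurrence of A ...
            return tokens[i]                  # ... so predict the B that followed it
    return None                               # no repeat, so no induction-based prediction

# "James Bond ... James" -> "Bond"
print(induction_prediction(["James", "Bond", "went", "to", "see", "James"]))
```

The rest of the post is about how two attention heads, composing across layers, implement this rule.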

Prerequisites

This post is aimed at people who already understand how a transformer is structured (I'd recommend Neel Nanda's tutorial for that), and the core ideas in the Mathematical Framework for Transformer Circuits paper. If you understand everything on this list, it will probably suffice:

Basic concepts of linear algebra (e.g. understanding orthogonal subspaces and the image / rank of linear maps) would also be helpful.

Now for the diagram! (You might have to zoom in to read it clearly.)

Note - part of the reason I wrote this is as a companion piece to other material / as a useful thing to refer to while explaining how induction heads work. I'm not totally sure how well it will function as a stand-alone explanation, and I'd be grateful for feedback!


 

 [4]

 

 

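As a companion to the diagram above, here's a minimal numpy sketch of the K-composition mechanism, using hand-built one-hot embeddings so the attention pattern works out exactly. All the toy weights and variable names here are my own illustrative choices, not anything taken from a real model.

```python
import numpy as np

vocab = 5
W_E = np.eye(vocab)                  # one-hot token embeddings (d_model = vocab, for clarity)
W_Q = np.eye(vocab)                  # induction head's query projection (toy: identity)
W_K = np.eye(vocab)                  # induction head's key projection (toy: identity)

tokens = [0, 1, 2, 3, 0]             # "A B . . A" -- we want the second A to attend to B
resid = W_E[tokens]                  # residual stream after embedding, shape (seq, d_model)

# Layer-0 previous-token head: its OV circuit writes each token's embedding into
# the *next* position's residual stream (faked here with a simple shift).
prev_info = np.zeros_like(resid)
prev_info[1:] = resid[:-1]

# K-composition: the induction head's keys read the previous-token information
# written by the layer-0 head, while its queries read the current token directly.
q = resid @ W_Q
k = prev_info @ W_K
scores = q @ k.T                     # scores[i, j] is large when "token before j" == "token at i"

print(scores[-1])                    # row for the second A: peaks at position 1, i.e. at B
# The induction head's OV circuit would then copy B's embedding forward, boosting the logit for B.
```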
Q-composition

Finally, here is a diagram just like the final one above, but which uses Q-composition rather than K-composition. The result is the same, but these heads seem to form less easily than K-composition heads because they require pointer arithmetic: moving positional information between tokens and doing operations on it, in order to figure out which tokens to attend to (although a lot of this is down to architectural details of the transformer[5]).

 

 

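To make the pointer arithmetic concrete, here's a toy numpy sketch of the Q-composition route, using one-hot positional embeddings and a hard-coded "+1" shift in place of the rotation discussed in footnote 4. Again, all the names and weights are illustrative stand-ins.

```python
import numpy as np

seq = 5
P = np.eye(seq)                        # one-hot positional embeddings (toy stand-in)
tokens = [0, 1, 2, 3, 0]               # "A B . . A"

# Layer-0 head: from the second A, attend back to the earlier A and copy *its
# positional embedding* into the destination's residual stream (faked directly here).
dup_pos_info = np.zeros((seq, seq))
for i in range(seq):
    for j in range(i):
        if tokens[j] == tokens[i]:
            dup_pos_info[i] = P[j]     # position of the earlier occurrence

# Q-composition: the induction head's *query* reads this positional info and applies
# a "+1" shift (the pointer arithmetic); keys are just each position's own embedding.
shift = np.roll(np.eye(seq), 1, axis=1)   # maps position j -> position j + 1 (wraps at the end)
q = dup_pos_info @ shift
k = P
scores = q @ k.T

print(scores[-1])                      # row for the second A: peaks at position 1, i.e. at B
```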
  1. ^

    Note that I'm using notation corresponding to the TransformerLens library, not to the Anthropic paper (this is because I'm hoping this post will help people who are actually working with the library). In particular, I'm following the convention that weight matrices multiply on the right. For instance, if x is a vector in the residual stream and W_Q is the query projection matrix then x W_Q is the query vector. This is also why the QK circuit is different here than in the Anthropic paper.
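    For example, in numpy (the dimensions and variable names are illustrative, not taken from any particular model):

```python
import numpy as np

d_model, d_head = 768, 64
x = np.random.randn(d_model)             # a single residual stream vector
W_Q = np.random.randn(d_model, d_head)   # query projection, stored as (d_model, d_head)
W_K = np.random.randn(d_model, d_head)   # key projection, same shape

query = x @ W_Q                          # right-multiplication convention: shape (d_head,)
W_QK = W_Q @ W_K.T                       # the QK circuit in this convention (see footnote 2)
```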

  2. ^

    This terminology is also slightly different from the Anthropic paper. The paper reserves the name QK circuit for the full version with the embedding matrices included (W_E W_QK W_E^T in this convention), whereas I'm adopting Neel's notation of calling W_QK = W_Q W_K^T the QK circuit, and calling something a full circuit if it includes the W_E or W_U matrices.

  3. ^

    Again, this is different than the Anthropic paper because of the convention that we're right-multiplying matrices. x W_V is the value vector (of size d_head) and x W_V W_O is the embedding of this vector back into the residual stream. So W_OV = W_V W_O is the OV circuit.
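    In the same notation (shapes only, illustrative names):

```python
import numpy as np

d_model, d_head = 768, 64
x = np.random.randn(d_model)
W_V = np.random.randn(d_model, d_head)   # value projection
W_O = np.random.randn(d_head, d_model)   # output projection

v = x @ W_V                              # value vector, size d_head
out = v @ W_O                            # embedded back into the residual stream, size d_model
W_OV = W_V @ W_O                         # the OV circuit: residual stream -> residual stream
assert np.allclose(out, x @ W_OV)
```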

  4. ^

    I described subtracting one from the positional embedding as a "rotation". This is because positional embeddings are often sinusoidal (either because they're chosen to be sinusoidal at initialisation, or because they develop some kind of sinusoidal structure as the model trains), and for sinusoidal embeddings, shifting the position by a fixed amount corresponds to applying a fixed rotation to the embedding vector.
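    A quick numpy check of this claim for the standard sinusoidal embeddings (a sketch; learned embeddings will only approximately have this structure):

```python
import numpy as np

d, n_pos = 8, 32
pos = np.arange(n_pos)[:, None]
freqs = 1.0 / (10000 ** (np.arange(0, d, 2) / d))
emb = np.zeros((n_pos, d))
emb[:, 0::2] = np.sin(pos * freqs)        # standard sinusoidal positional embeddings
emb[:, 1::2] = np.cos(pos * freqs)

# Block-diagonal rotation R, one 2x2 rotation per (sin, cos) pair, chosen so that
# emb[t] @ R == emb[t - 1]: "subtracting one from the position" is a fixed rotation.
R = np.zeros((d, d))
for i, f in enumerate(freqs):
    c, s = np.cos(f), np.sin(f)
    R[2*i:2*i+2, 2*i:2*i+2] = [[c, s], [-s, c]]

print(np.allclose(emb[1:] @ R, emb[:-1]))   # True
```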

  5. ^

    For example, if you specify shortformer=True when loading in models from TransformerLens, the positional embeddings aren't added to the residual stream, but only to the inputs of the query and key projection matrices (i.e. not to the inputs of the value projection matrices). This means positional information can be used in calculating attention patterns, but can't itself be moved around to different tokens. You can see from the diagram how this makes Q-composition impossible[6] (because the positional encodings would need to be moved as part of the OV circuit in the first attention head).
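    Schematically, the shortformer-style attention inputs look something like this (a sketch; the variable names are mine, not TransformerLens internals):

```python
import numpy as np

seq, d_model, d_head = 10, 768, 64
resid = np.random.randn(seq, d_model)      # residual stream (token information only)
pos_emb = np.random.randn(seq, d_model)    # positional embeddings, kept out of the stream
W_Q = np.random.randn(d_model, d_head)
W_K = np.random.randn(d_model, d_head)
W_V = np.random.randn(d_model, d_head)

q = (resid + pos_emb) @ W_Q   # queries and keys see positional information ...
k = (resid + pos_emb) @ W_K
v = resid @ W_V               # ... but values don't, so the OV circuit can't move it around
```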

  6. ^

    That being said, transformers seem to be able to rederive positional information [AF · GW], so they could in theory form induction heads via Q-composition using this rederived information. To my knowledge there's currently no evidence of this happening, but it would be interesting!

9 comments

Comments sorted by top scores.

comment by hold_my_fish · 2023-01-03T00:52:56.912Z · LW(p) · GW(p)

Thanks, the first diagram worked just as suggested: I have enough exposure to transformer internals that a few minutes of staring was enough to understand the algorithm. I'd always wondered why it is that GPT is so strangely good at repetition, and now it makes perfect sense.

Replies from: TheMcDouglas
comment by CallumMcDougall (TheMcDouglas) · 2023-01-03T13:49:43.058Z · LW(p) · GW(p)

Awesome, really glad to hear it was helpful, thanks for commenting!

comment by Perusha Moodley (perusha-moodley) · 2023-06-15T11:47:40.514Z · LW(p) · GW(p)

I'm at the beginning of the MI journey: I read the paper, watched a video and I am working through the notebooks. I have seen the single diagram version of this before but I needed this post to really help me get a feel for how the subspaces and composition work. I think it works well as a stand-alone document and I feel like it has helped set up some mental scaffolding for the next more detailed steps I need to take. Thank you for this!

Replies from: TheMcDouglas
comment by CallumMcDougall (TheMcDouglas) · 2023-07-07T11:12:58.451Z · LW(p) · GW(p)

Thanks so much for this comment, I really appreciate it! Glad it was helpful for you 🙂

comment by LawrenceC (LawChan) · 2023-01-03T01:32:58.552Z · LW(p) · GW(p)

This seems like a typo:

I'm adopting Neel's notation of calling WQK the OV circuit

(Surely you meant QK!)

Replies from: TheMcDouglas
comment by CallumMcDougall (TheMcDouglas) · 2023-01-03T13:49:21.036Z · LW(p) · GW(p)

Yep, fixed, thanks!

comment by Review Bot · 2024-08-11T21:45:11.542Z · LW(p) · GW(p)

The LessWrong Review [? · GW] runs every year to select the posts that have most stood the test of time. This post is not yet eligible for review, but will be at the end of 2024. The top fifty or so posts are featured prominently on the site throughout the year.

Hopefully, the review is better than karma at judging enduring value. If we have accurate prediction markets on the review results, maybe we can have better incentives on LessWrong today. Will this post make the top fifty?

comment by shen yue (shen-yue) · 2023-08-26T08:11:21.175Z · LW(p) · GW(p)

Thanks for your hard work. I wonder why in the layer 0 attention head, the positions of the query and value are 1?

Replies from: TheMcDouglas
comment by CallumMcDougall (TheMcDouglas) · 2023-09-07T09:25:47.344Z · LW(p) · GW(p)

Hi, sorry for the late response! The layer 0 attention head should have query at position 1, and value at position 0 (same as key). Which diagram are you referring to?