Transformers Explained (Again)

post by RohanS · 2024-10-22T04:06:33.646Z

Contents

  Introduction
  Inputs and Outputs
  Exploring Model Internals
    Preface: Loss function and grouping components
    Tracing the path of an input through a transformer
  Summary
  References

I actually wrote this in Spring 2023. I didn't post it then because I had trouble converting the Google doc to any other format; the LessWrong gdoc import feature made that easy. :)

Introduction

This is yet another post about the Transformer neural network architecture, written in large part for my own benefit. There are many other resources for understanding transformers, and you may be better off using one of them. However, I will emphasize some of the things that I did not fully understand after reading a few posts and watching a few videos about transformers, so this may still be useful to others. I filled in a lot of the remaining gaps by watching this video by Neel Nanda and writing all the code to follow along with it myself (though that took a while). The video focuses on a GPT-2 style (decoder-only) transformer, and so does this post.

Inputs and Outputs

This section doesn’t say anything about the internals of the transformer architecture, but it covers the input- and output-related details of transformer models that took me the longest to figure out.

Exploring Model Internals

Preface: Loss function and grouping components

Note that nothing described in the previous section involves learning parameters. That’s because we’ve only discussed the inputs and the outputs of the model, not its internal structure. We’ve already had to do quite a bit of manipulation on both ends, but the model details are yet to come.

We start with a randomly initialized model with a structure we have imposed, and then we tune parameters within that structure to perform a task well. This post is trying to elucidate “the structure we have imposed,” because that is what the Transformer architecture is. I find it useful to take careful note of what things receive random initializations, because those things do not intrinsically serve a certain purpose - instead, they learn to serve a certain purpose because of the way they are used.

Let’s focus on the loss function used to train GPT-2 in order to understand the task it is trained to perform. GPT-2 is an autoregressive language model, which means that given a sequence of tokens, it predicts the next token. This is sometimes called “causal language modeling,” roughly because only tokens before the token-to-be-predicted can have a causal influence on the model’s prediction. An alternative is “masked language modeling,” in which a token is hidden somewhere in the middle of a sequence and the model predicts it using context from both sides. BERT is a popular model trained on a masked language modeling objective.
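To make the distinction concrete, here is a minimal sketch of the training examples each objective produces. The token ids are invented for illustration; a real tokenizer assigns its own ids.

```python
# Invented token ids for, say, "The cat sat on the mat".
tokens = [464, 3797, 3332, 319, 262, 2603]

# Causal (autoregressive) language modeling: at each position, predict the
# next token using only the tokens that come before it.
causal_examples = [(tokens[:i], tokens[i]) for i in range(1, len(tokens))]
# e.g. ([464], 3797), ([464, 3797], 3332), ...

# Masked language modeling (BERT-style): hide a token somewhere in the
# sequence and predict it from the context on both sides.
MASK = -1  # placeholder id for the mask token
masked_input = tokens[:2] + [MASK] + tokens[3:]  # hide "sat"
masked_target = tokens[2]
```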

The way (or at least, one way) to train an autoregressive language model is as follows. Start with all parameters of the model randomly initialized. Run the model on an input string for which you know what the next token should be. Apply a softmax to the logits vector it produces, giving the probability the model assigns to each possible next token. Extract the probability assigned to the correct next token, take its logarithm, and negate it. This is the loss associated with this model output. Use backpropagation and stochastic gradient descent (SGD) to update the model’s parameters to reduce this loss.

More concisely: Loss = -log(softmax(logits)[correct_token_index])
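To make this concrete, here is a minimal PyTorch sketch of the loss computation above; the specific logits and the five-token vocabulary are invented for illustration.

```python
import torch
import torch.nn.functional as F

# Invented logits over a five-token vocabulary; a real GPT-2 logits vector
# has 50257 entries, one per token in its vocabulary.
logits = torch.tensor([2.0, 1.0, 0.5, -1.0, 0.3], requires_grad=True)
correct_token_index = 1

probs = F.softmax(logits, dim=-1)              # probability of each possible next token
loss = -torch.log(probs[correct_token_index])  # negative log probability of the correct one

# The same quantity via PyTorch's built-in, numerically stable version:
same_loss = F.cross_entropy(logits.unsqueeze(0),
                            torch.tensor([correct_token_index]))
assert torch.allclose(loss, same_loss)

# Backpropagation computes gradients of the loss with respect to the
# parameters (the logits stand in for them here), which SGD then uses
# to nudge the parameters toward lower loss.
loss.backward()
```

Note that F.cross_entropy folds the softmax, log, and negation into a single numerically stable call, which is why real training code usually passes raw logits straight to the loss function.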

What’s going on here? Minimizing the loss is equivalent to maximizing the probability that the model assigns to the correct next token, since the loss is the negative log of that probability and log is monotonically increasing (i.e. if you increase x, log(x) also increases). For intuition: if the model assigns probability 0.25 to the correct token, the loss is -log(0.25) ≈ 1.39, while assigning 0.9 gives a loss of only about 0.11. The log also has practical benefits: it turns a product of probabilities across many predictions into a sum of per-token losses, and combined with the softmax it yields simple, well-behaved gradients, which matters for training the model with SGD and backpropagation.

Now that we understand the loss function, we know that the trainable parameters of a GPT-2 style transformer model are randomly initialized and then slowly adjusted to perform the task of next-token prediction more effectively.

Before finally seeing what happens to an input as it passes through a transformer, there is one meta-level point I want to make about how to understand it. Part of the difficulty of understanding transformer models is figuring out how “zoomed in” you are supposed to be - that is, how many parts of the model do you have to pay attention to at the same time in order to understand what’s going on? Which components have functions detached enough from each other that you can look at them separately? Here is a quick walkthrough of the components of a transformer model with my answers to these questions - everything will be explained in more depth in the next section.

(Read this image from bottom to top, and use the text below to clarify things.)

It’s fine if not everything here makes sense yet. Hopefully reading the next section will make things much clearer, and you can refer back to this overview to check if it is starting to make sense. This overview is the collection of high-level concepts that I want you to store as a compressed understanding of transformers at the end of this post, but you need to see the low-level details in order for the high-level summary to really make sense.

Tracing the path of an input through a transformer

Summary

The following is roughly the compressed understanding of transformers that I store in my head - even more compressed than the overview above.

I hope this guide was helpful!

References

  1. Implementing GPT-2 From Scratch (Transformer Walkthrough Part 2/2)
  2. Tokenizer - OpenAI API
  3. Overview of Large Language Models: From Transformer Architecture to Prompt Engineering
  4. A Mathematical Framework for Transformer Circuits
  5. But what is a neural network? | Chapter 1, Deep learning
  6. Key Query Value Attention Explained
