You can remove GPT2’s LayerNorm by fine-tuning for an hour

post by StefanHex (Stefan42) · 2024-08-08T18:33:38.803Z · LW · GW · 10 comments

Contents

  Introduction
  Motivation
  Method
  Implementation
  Results
        GPT2:
        GPT2_noLN:
    Residual stream norms
  Discussion
    Faithfulness to the original model
    Does the noLN model generalize worse?
  Appendix
    Representing the no-LayerNorm model in GPT2LMHeadModel
    Which order to remove LayerNorms in
      Which kinds of LayerNorms to remove first
      Which layer to remove LayerNorms in first
    Data-reuse and seeds
    Infohazards
    Acknowledgements
10 comments

This work was produced at Apollo Research, based on initial research done at MATS.

LayerNorm is annoying for mechanistic interpretability research (“[...] reason #78 for why interpretability researchers hate LayerNorm” – Anthropic, 2023).

Here’s a Hugging Face link to a GPT2-small model without any LayerNorm.

The final model is only slightly worse than a GPT2 with LayerNorm[1]:

| Dataset | Original GPT2 | Fine-tuned GPT2 with LayerNorm | Fine-tuned GPT2 without LayerNorm |
| --- | --- | --- | --- |
| OpenWebText (ce_loss) | 3.095 | 2.989 | 3.014 (+0.025) |
| ThePile (ce_loss) | 2.856 | 2.880 | 2.926 (+0.046) |
| HellaSwag (accuracy) | 29.56% | 29.82% | 29.54% |

I fine-tuned GPT2-small on OpenWebText while slowly removing its LayerNorm layers, waiting for the loss to go back down after each removal:

Introduction

LayerNorm (LN) is a component in Transformer models that normalizes embedding vectors to have constant length; specifically it divides the embeddings by their standard deviation taken over the hidden dimension. It was originally introduced to stabilize and speed up training of models (as a replacement for batch normalization). It is active during training and inference.
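For reference, the standard LayerNorm definition (written here in the notation of the PyTorch `nn.LayerNorm` documentation, with mean and variance taken over the hidden dimension, learnable scale γ and bias β, and a small constant ε for numerical stability) is:

```latex
\mathrm{LN}(x) = \frac{x - \mathbb{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \odot \gamma + \beta
```

The centering, the scale γ, and the bias β are all linear or affine in x; only the division by the square root of the variance depends non-linearly on x.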

The equation includes the standard deviation (std), which makes it a non-linear operation. This hinders interpretability in a variety of ways, from annoyances and inaccuracies such as

In the Docstring circuit [LW · GW] analysis we seriously considered whether the model might be using LN in its algorithm. This post [LW · GW] even shows that LN can be used as the sole non-linearity to solve non-linear classification problems (see also this related work).

Recently, with progress in Sparse Dictionary Learning, agendas (e.g. this one [LW · GW]) imagine decomposing networks into sets of sparsely connected components (SAEs, Transcoders, etc.). A core difficulty to “putting it all together” is that the interactions between different components often route through LayerNorm whose effect we do not understand.

Motivation

It would be pretty neat to have an LLM that still works (speaks English etc.) with fewer or no LN layers. One option would be to train a model without LN from scratch (done for tiny models, e.g. TinyModel), but this is very hard or impossible for larger models (hearsay is that you need a low learning rate and to be very careful).

Taking an existing model and removing the LN layers however seems doable if LN isn’t implementing some important computation.[3] That is, LN “does its thing” and the model has learned to “deal with it”, but it’s not irreplaceable. A reason to be optimistic is that the spread of standard deviations across different samples isn’t that large [LW · GW], so maybe replacing the LN-computed standard deviation with a fixed number might kinda work.

Method

I take GPT2-small, fine-tune it on OpenWebText, and remove LNs one-by-one while fine-tuning.

The only non-linear operation in a LN layer is the division by the standard deviation (std) of the embedding vectors; the remaining operations can be absorbed into later weight matrices (see the fold_ln option in TransformerLens; also discussed in this appendix). Thus I mainly focus on the std part here.
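To make the "absorbed into later weight matrices" claim concrete, here is a minimal sketch in plain Python (lists instead of tensors, so it runs without dependencies). It assumes the LN output feeds directly into a linear layer `W x + b`; the function and variable names are illustrative and not taken from the post's code or from TransformerLens.

```python
# Sketch: folding LayerNorm's scale (gamma) and bias (beta) into the next
# linear layer. All names here are illustrative assumptions.

def matvec(W, v):
    """Multiply matrix W (a list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def fold_ln_into_linear(W, b, gamma, beta):
    """Return (W_folded, b_folded) such that
    W @ (gamma * x + beta) + b == W_folded @ x + b_folded."""
    W_folded = [[w * g for w, g in zip(row, gamma)] for row in W]
    b_folded = [bi + wb for bi, wb in zip(b, matvec(W, beta))]
    return W_folded, b_folded

# Toy example: hidden size 3, output size 2.
W = [[1.0, 2.0, -1.0], [0.5, 0.0, 3.0]]
b = [0.1, -0.2]
gamma = [2.0, 0.5, 1.5]
beta = [0.3, -0.1, 0.0]
x_hat = [0.7, -1.2, 0.5]  # stands in for the centered, std-divided input

# Unfolded: apply scale and bias first, then the linear layer.
scaled = [g * x + bt for g, x, bt in zip(gamma, x_hat, beta)]
y_unfolded = [yi + bi for yi, bi in zip(matvec(W, scaled), b)]

# Folded: the linear layer alone, with adjusted weights and bias.
W_f, b_f = fold_ln_into_linear(W, b, gamma, beta)
y_folded = [yi + bi for yi, bi in zip(matvec(W_f, x_hat), b_f)]

max_diff = max(abs(u - f) for u, f in zip(y_unfolded, y_folded))
print(max_diff)
```

The two outputs agree up to floating-point rounding, which is why only the division by the std needs special treatment.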

My general strategy is to “remove” an LN layer (this makes the loss go up), and then to train the model for some time (on the original training data) until the loss is back near the baseline. For this “remove” step I do the following

Whenever I do the replacement the loss jumps up, from the baseline of 3.0 up to 3.5, sometimes even around 5.0. After 10-100 iterations (learning rate 6e-4 and batch size approx. 488 as recommended here) the loss typically goes down to between 3.0 and 3.1. However, if I’m not careful and change too much at once, the loss can jump very high (around 8.0), and in those cases it usually never recovers. Thus I want to avoid making too big of a change at once.

Here’s the recipe I empirically found to work. After every step, train for 50-200 iterations or until the loss is close to baseline.
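The "train for 50-200 iterations or until the loss is close to baseline" rule could be sketched as the loop below. This is one possible reading, not the post's actual code: the helper name, the `tolerance` of 0.1, and treating 50 and 200 as a minimum and maximum are assumptions, and `train_step` is a hypothetical callable that runs one optimizer step and returns the loss.

```python
def train_until_recovered(train_step, baseline_loss, tolerance=0.1,
                          min_iters=50, max_iters=200):
    """Run train_step() until the loss is within `tolerance` of the baseline
    (but at least min_iters times), or until max_iters is reached.
    Returns (iterations_run, last_loss)."""
    loss = float("inf")
    iters = 0
    for iters in range(1, max_iters + 1):
        loss = train_step()
        if iters >= min_iters and loss <= baseline_loss + tolerance:
            break
    return iters, loss

# Stand-in for real training: a loss that jumps to 3.5 and decays toward 3.0.
state = {"k": 0}

def fake_train_step():
    state["k"] += 1
    return 3.0 + 0.5 * 0.95 ** state["k"]

iters_run, final_loss = train_until_recovered(fake_train_step, baseline_loss=3.0)
print(iters_run, round(final_loss, 3))
```

In a real run, `train_step` would wrap the usual forward, backward, and optimizer calls, and the loop would be invoked once after each LN removal.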

I considered scaling individual LNs down slowly (e.g. interpolate between the actual calculated std and the average std) but I never ended up needing this, and did not really explore it.
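For illustration only, since this option was not explored: the interpolation mentioned above could look like the following sketch. The function name and the mixing parameter `alpha` are hypothetical.

```python
def interpolated_std(actual_std, average_std, alpha):
    """Blend the per-token std with the fixed average std.

    alpha = 0.0 reproduces the real LayerNorm denominator,
    alpha = 1.0 uses only the fixed average (the "removed" LN)."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    return (1.0 - alpha) * actual_std + alpha * average_std

# Gradually move from the real std (2.0) to the fixed average (4.0).
for alpha in (0.0, 0.5, 1.0):
    print(alpha, interpolated_std(2.0, 4.0, alpha))
```

Ramping `alpha` from 0 to 1 over some number of iterations would spread the change over many small steps instead of one jump.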

In general I observed that

Implementation

I implement everything based on the NanoGPT repository. I replace the standard deviation calculation in the LN by a fixed number (set to the average standard deviation). This number is fixed, but it is degenerate with the LN scale (self.weight) which is learnable.

# Divide by a fixed, precomputed std instead of the per-token std
std = self.average_std if std_type == "avg" else self.bos_std
return (x - x.mean(-1, keepdim=True)) / std * self.weight + self.bias
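A dependency-free sketch of the difference between the real LN and the fixed-std replacement, operating on a single token's hidden vector as a Python list. The function names are illustrative; the actual implementation is a PyTorch module in the NanoGPT code base.

```python
from math import sqrt

def real_layer_norm(x, weight, bias, eps=1e-5):
    """Standard LayerNorm over one token's hidden vector (a list of floats)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    std = sqrt(var + eps)  # depends on x, so the whole map is non-linear
    return [(v - mean) / std * w + b for v, w, b in zip(x, weight, bias)]

def dummy_layer_norm(x, weight, bias, fixed_std):
    """Same as above, but divides by a constant instead of the per-token std."""
    mean = sum(x) / len(x)
    return [(v - mean) / fixed_std * w + b for v, w, b in zip(x, weight, bias)]

x = [1.0, -2.0, 0.5, 3.0]
weight = [1.0, 2.0, 0.5, 1.5]
zero_bias = [0.0] * 4

# With a constant std and zero bias, doubling the input doubles the output.
out = dummy_layer_norm(x, weight, zero_bias, fixed_std=2.0)
out_doubled = dummy_layer_norm([2 * v for v in x], weight, zero_bias, fixed_std=2.0)
dummy_is_linear = all(abs(2 * a - b) < 1e-9 for a, b in zip(out, out_doubled))

# The real LayerNorm ignores the input scale (up to the eps term).
real = real_layer_norm(x, weight, zero_bias)
real_doubled = real_layer_norm([2 * v for v in x], weight, zero_bias)
real_is_scale_invariant = all(abs(a - b) < 1e-4 for a, b in zip(real, real_doubled))

print(dummy_is_linear, real_is_scale_invariant)
```

The fixed-std version is an affine map of the input, which is what makes it possible to fold it into neighbouring weights.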

I have two different average stds, self.average_std (average of std over all tokens except for position 0) and self.bos_std (average of stds at position 0). Initially, after replacing the real LN with this “dummy LN” I use the following policy for choosing which std to divide by:

The EOT rule in [brackets] is used only for the LN before the attention v vector. Pseudocode for a simplified version of my implementation:

x_v[eot_mask] = self.ln_1(x, std_type="eot")[eot_mask]
x_v[~eot_mask] = self.ln_1(x, std_type="avg")[~eot_mask]
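The two averages and the mask-based selection can be sketched without any tensor library as follows. The activations are nested lists indexed `[batch][position][hidden]`, and all helper names are hypothetical. Using the position-0 average for masked (EOT) positions is an assumption based on footnote 4, not something the pseudocode above states explicitly.

```python
from math import sqrt
from statistics import fmean

def token_std(vec):
    """Population std of one token's hidden vector."""
    m = fmean(vec)
    return sqrt(fmean([(v - m) ** 2 for v in vec]))

def compute_average_stds(acts):
    """acts: nested lists [batch][position][hidden].
    Returns (average_std, bos_std): the mean std over all positions except 0,
    and the mean std at position 0."""
    bos = [token_std(seq[0]) for seq in acts]
    rest = [token_std(tok) for seq in acts for tok in seq[1:]]
    return fmean(rest), fmean(bos)

def per_position_std(special_mask, average_std, bos_std):
    """Pick the fixed std per position; masked positions get bos_std (assumption)."""
    return [bos_std if special else average_std for special in special_mask]

# Toy activations: position 0 has a much larger norm than the rest.
acts = [
    [[10.0, -10.0], [1.0, -1.0], [2.0, -2.0]],
    [[20.0, -20.0], [1.0, -1.0], [3.0, -3.0]],
]
average_std, bos_std = compute_average_stds(acts)
stds = per_position_std([True, False, False], average_std, bos_std)
print(average_std, bos_std, stds)
```

In a tensor implementation the same selection would be done with boolean indexing, as in the pseudocode above.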

I will upload the full code when I have more time next week (email me if you'd like it earlier).

Results

Well, the models train, the loss is low, and the models still speak English.

v1 model: Manually interrupt & resume training whenever loss went down enough:

v2 model (+ vanilla GPT2 fine-tuning comparison in green) with scheduled LN removals:

I'm currently training a third version employing the gradual removal technique to more of the components, and will update the Hugging Face repository (use the v1 and v2 revisions for models presented here).

Here I compare the cross entropy losses on OpenWebText and ThePile, as well as the score on the HellaSwag benchmark (using Andrej Karpathy's implementation). I choose two baselines: the original GPT2 model straight from Hugging Face, and a second version that I fine-tuned on OpenWebText for 1000 iterations to be more comparable to the noLN ones (which were fine-tuned on that data).

| Dataset | Original GPT2 | GPT2 with LN | GPT2 noLN v1 | GPT2 noLN v2 |
| --- | --- | --- | --- | --- |
| OpenWebText (ce_loss) | 3.095 | 2.989 | 3.139 | 3.014 (+0.025) |
| ThePile (ce_loss) | 2.856 | 2.880 | 3.057 | 2.926 (+0.046) |
| HellaSwag (accuracy) | 29.56% | 29.82% | 29.17% | 29.54% |

Some sample generations with model.generate:

GPT2:

As the last leaf fell from the tree, John realized that it was empty. He took the leaf and turned it over to his wife, who told him that it was still there and that he would have to go to the church to find it. John went to the church, and found that it was empty. He said, "I am going to the church and I am going to find the rest of the leaves, and I am going to look for them and find out where they

 

As the last leaf fell from the tree, John realized that the tree had been torn down. As he turned his head, the other trees started to fall. "Come on," John said, "we're going to get out of here!" The next tree was a wildflower. "How is it?" John asked, "do you see any other way?" "It's a good thing," the other trees replied.

GPT2_noLN:

As the last leaf fell from the tree, John realized that he was going to have to use it. “We’ve all been there. We’ve all been there,” he said. “It’s been a long time. But, it’s good to be back. It’s good to be back.” “It’s good to be back. It’s good to be back

 

As the last leaf fell from the tree, John realized that a large amount of weight had been lifted from him. "I had a little panic attack. I was afraid that I could not walk," he said. "I felt like my legs were going to break." John has since gone back to the tree. "I have to tell you that I'm sorry I did that, but I don't know if that will ever happen," he said.

Residual stream norms

Previously I noticed [LW · GW] that the norm of the residual stream looks very different at position 0 compared to the rest of the model. Does this behavior still exist in the noLN model? No. The noLN model (solid lines) does not show the difference in norms we saw in the original models (dashed lines):

Discussion

Faithfulness to the original model

I expect that the new no-LayerNorm (noLN) model will not have the exact same internal mechanisms as the original model. To some extent I expect lots of similarities (as the new model just had an hour to train), but since the norms changed (see above) and the loss changed, I expect differences.

My goal with this model is more like “have a toy model almost as good as GPT2 but easier to interpret” to replace vanilla GPT2 in interpretability research. Models like GPT2 (and Pythia etc.) are useful not because they are the models we ultimately care about (GPT4, GPT5, …) but because they let us generally explore how LLM internals work. To this extent, I don’t mind if GPT2-noLN differs from GPT2.

A question for the future is whether we want to apply this LN removal method before interpreting GPT4 and other “production models”. This depends on how similar the internals are, and I am currently uncertain about this. I am primarily concerned with the former use-case.

Does the noLN model generalize worse?

I noticed that the noLN performance hit was worse on The Pile than OpenWebText. This might be a coincidence, but it could also suggest that removing LN hurts generalization. While LN was originally introduced for training stability purposes, it may have a side effect on generalization. I have not evaluated the models on more datasets and leave this question for future research. Edit 30th Aug 2024: On the other hand this makes the OpenWebText and ThePile losses of the no-LN model more similar than those of the original model.

Appendix

Representing the no-LayerNorm model in GPT2LMHeadModel

I replace LayerNorms with DummyLayerNorms that use a fixed std, rather than computing the actual variance of each sample. This is equivalent to removing the LNs because the remaining LN operations can be folded into the following layers (e.g. TransformerLens mostly does this; TL does not fold in the centering operations but I do). I perform this folding for all ln_1 and ln_2 layers. Thus I obtain weights with which GPT2 can run without any LayerNorm layers (all LNs replaced by nn.Identity).

To make the model available on Hugging Face without trust_remote_code I want to package it into the GPT2LMHeadModel class. Thus I want to “neuter” the LNs in GPT2LMHeadModel such that the model just works with my “LNs are identities”-weights. I do this by setting ln_eps (epsilon) to a very high value (1e12), and setting the ln weights (gamma) to a corresponding value (1e6). I set the biases to 0. This leaves the centering operation but this doesn’t matter as I also fold a centering operation into the following layers, thus the LNs can be removed without further changes.
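A quick numeric check of why these values neutralize the normalization: with eps = 1e12 and weight = 1e6, the factor weight / sqrt(var + eps) is almost exactly 1 as long as the variance is far below 1e12, so the layer reduces to centering. The sketch below uses plain Python lists and an illustrative `layer_norm` helper rather than the Hugging Face module itself.

```python
from math import sqrt

def layer_norm(x, weight, bias, eps):
    """Standard LayerNorm over one hidden vector (a list of floats)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / sqrt(var + eps) * w + b
            for v, w, b in zip(x, weight, bias)]

x = [3.0, -1.0, 4.0, -2.0, 0.5]
n = len(x)

# "Neutered" LN: huge eps swamps the variance, weight 1e6 cancels sqrt(1e12).
neutered = layer_norm(x, weight=[1e6] * n, bias=[0.0] * n, eps=1e12)

# Plain centering, for comparison.
mean = sum(x) / n
centered = [v - mean for v in x]

max_err = max(abs(a - b) for a, b in zip(neutered, centered))
print(max_err)
```

The remaining error is on the order of var / (2 · 1e12) relative to the activation size, which is negligible for typical residual stream magnitudes.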

There’s one exception to this, the final layer norm. GPT2LMHeadModel uses (a) tied embedding and unembedding weights, and (b) no unembedding bias. Thus it is impossible to fold the final LayerNorm, which includes a (diagonal) weight matrix and a bias, into the other weights here. I still “neuter” the normalizing function of the LayerNorm as above, so ln_final just represents a simple linear layer before the unembedding.

Which order to remove LayerNorms in

There are two kinds of reasons that inform which LayerNorms I want to remove first

So I expect that the order in which the LNs are removed matters.

A meta-choice is whether

I went with the first option for ease of implementation, but have not tried the second option. However the second option would seem more principled to me once we understand which order is optimal.

Which kinds of LayerNorms to remove first

I tried out a few combinations, such as first removing ln_f and then removing ln_2 and ln_1, or vice versa. I haven’t done a systematic sweep of all options, and the current method is just what felt right after a couple of tries. It seems to work well enough though.

Here’s an example of removing ln_f first. The loss reaches a very high level, and even after 400 iterations only goes down to 3.138. So it seems this is a worse choice. Note, however, that in this run I didn’t do “warm up” iterations (training for a couple iterations with LN to reach a good loss on OpenWebText – gpt2 directly loaded from Hugging Face does badly for the first ~10 iterations).

Which layer to remove LayerNorms in first

Here I remove ln_2 in the different layers in different orders:

I start removing LNs at iteration 300, and remove another LN every 10 iterations.

The loss differs during the process (this is expected, some LNs are possibly more important than others) but evens out at the end.
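The schedule used in this experiment (start at iteration 300, then one more LN every 10 iterations) can be written as a small lookup table. The sketch assumes the 12 layers of GPT2-small and Hugging Face style module names; the helper name and the two example orders are illustrative, not necessarily the orders that were tested.

```python
def removal_schedule(layer_order, start_iter=300, gap=10):
    """Map training iteration -> name of the ln_2 module to remove at that step."""
    return {
        start_iter + i * gap: f"transformer.h.{layer}.ln_2"
        for i, layer in enumerate(layer_order)
    }

n_layers = 12  # GPT2-small
first_to_last = removal_schedule(range(n_layers))
last_to_first = removal_schedule(reversed(range(n_layers)))

for it in sorted(first_to_last)[:3]:
    print(it, first_to_last[it], "|", last_to_first[it])
```

A training loop would check the current iteration against this table and swap the named LN for its fixed-std replacement when there is a match.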

Data-reuse and seeds

In my initial tests I used lots of snapshots, and accidentally retrained the model on the same first couple of batches of OpenWebText (fixed seed) every time. I have the impression that this worked slightly better than my later full-pipeline runs, which never re-used data. I might investigate this in the future.

Infohazards

I am not worried that publishing this work accelerates capabilities progress over alignment progress, because (a) this is a pretty obvious idea, (b) it applies only to inference (not training), and (c) it only speeds up inference by a very small amount (that likely is not even worth the loss increase).

Acknowledgements

Thanks to Bilal Chughtai, Neel Nanda, and Rudolf Laine for comments and feedback on the draft. The nanoGPT repository and accompanying video by Andrej Karpathy were very helpful, allowing me to get a working prototype in a day!

  1. ^

     The GPT2 paper claims a loss of log(16)=2.77 on their training dataset (non-public webtext). I guess that must be an easier dataset. In any case, I fine-tune both models on OpenWebText for a total of 1000 iterations (~500k rows, ~500M tokens) to give a fairer comparison.

  2. ^

     For this case, only the final layer norm matters

  3. ^

     This paper discusses confidence regularization as one possibly important aspect.

  4. ^

     I wonder how much of this effect is “just divide by a larger number” vs. actually dividing by the correct average. After all, the position 0 average shouldn’t be a great match for the EOT token average. [In this dataset the first position is not an EOT token. Neel Nanda / TransformerLens recommends this for short prompts (see here for a discussion) but we don’t do it for the full dataset.]

10 comments

Comments sorted by top scores.

comment by Logan Riggs (elriggs) · 2024-08-08T18:49:28.046Z · LW(p) · GW(p)

This is extremely useful for SAE circuit work. Now the connections between features are at most ReLU(Wx + b) which is quite interpretable! (Excluding attn_in->attn_out)

Thanks for doing this!

comment by starship006 (cody-rushing) · 2024-08-08T23:25:44.315Z · LW(p) · GW(p)

Another reason why layernorm is weird (and a shameless plug): the final layernorm also contributes to self-repair in language models

comment by StefanHex (Stefan42) · 2024-08-08T20:11:03.552Z · LW(p) · GW(p)

Here's a quick snippet to load the model into TransformerLens!

import torch
from transformers import GPT2LMHeadModel
from transformer_lens import HookedTransformer

model = GPT2LMHeadModel.from_pretrained("apollo-research/gpt2_noLN").to("cpu")
hooked_model = HookedTransformer.from_pretrained("gpt2", hf_model=model, fold_ln=False, center_unembed=False).to("cpu")
# Kill the LayerNorms because TransformerLens overwrites eps
for block in hooked_model.blocks:
    block.ln1.eps = 1e12
    block.ln2.eps = 1e12
hooked_model.ln_final.eps = 1e12

# Make sure the outputs are the same
prompt = torch.tensor([1,2,3,4], device="cpu")
logits = hooked_model(prompt)
logits2 = model(prompt).logits

print(logits.shape, logits2.shape)
print(logits[0, 0, :10])
print(logits2[0, :10])
Replies from: Stefan42
comment by StefanHex (Stefan42) · 2024-08-08T21:01:46.583Z · LW(p) · GW(p)

And here's the code to do it with replacing the LayerNorms with identities completely:

import torch
from transformers import GPT2LMHeadModel
from transformer_lens import HookedTransformer

model = GPT2LMHeadModel.from_pretrained("apollo-research/gpt2_noLN").to("cpu")

# Undo my hacky LayerNorm removal
for block in model.transformer.h:
    block.ln_1.weight.data = block.ln_1.weight.data / 1e6
    block.ln_1.eps = 1e-5
    block.ln_2.weight.data = block.ln_2.weight.data / 1e6
    block.ln_2.eps = 1e-5
model.transformer.ln_f.weight.data = model.transformer.ln_f.weight.data / 1e6
model.transformer.ln_f.eps = 1e-5

# Properly replace LayerNorms by Identities
class HookedTransformerNoLN(HookedTransformer):
    def removeLN(self):
        for i in range(len(self.blocks)):
            self.blocks[i].ln1 = torch.nn.Identity()
            self.blocks[i].ln2 = torch.nn.Identity()
        self.ln_final = torch.nn.Identity()

hooked_model = HookedTransformerNoLN.from_pretrained("gpt2", hf_model=model, fold_ln=True, center_unembed=False).to("cpu")
hooked_model.removeLN()

prompt = torch.tensor([1,2,3,4], device="cpu")
logits = hooked_model(prompt)

print(logits.shape)
print(logits[0, 0, :10])
Replies from: John Dunbar, elriggs
comment by Quiche Eater (John Dunbar) · 2024-09-16T06:16:55.261Z · LW(p) · GW(p)

You should also set model.cfg.normalization_type = None afterwards. It's mostly a formality since you're doing it after initialization. ActivationCache.apply_ln_to_stack() is the only function I found which behaves incorrectly if you don't change this.

comment by Logan Riggs (elriggs) · 2024-08-18T04:22:21.853Z · LW(p) · GW(p)

And here's the code to convert it to NNsight (Thanks Caden for writing this a while ago!)

import torch
from transformers import GPT2LMHeadModel
from transformer_lens import HookedTransformer
from nnsight.models.UnifiedTransformer import UnifiedTransformer


model = GPT2LMHeadModel.from_pretrained("apollo-research/gpt2_noLN").to("cpu")

# Undo my hacky LayerNorm removal
for block in model.transformer.h:
    block.ln_1.weight.data = block.ln_1.weight.data / 1e6
    block.ln_1.eps = 1e-5
    block.ln_2.weight.data = block.ln_2.weight.data / 1e6
    block.ln_2.eps = 1e-5
model.transformer.ln_f.weight.data = model.transformer.ln_f.weight.data / 1e6
model.transformer.ln_f.eps = 1e-5

# Properly replace LayerNorms by Identities
def removeLN(transformer_lens_model):
    for i in range(len(transformer_lens_model.blocks)):
        transformer_lens_model.blocks[i].ln1 = torch.nn.Identity()
        transformer_lens_model.blocks[i].ln2 = torch.nn.Identity()
    transformer_lens_model.ln_final = torch.nn.Identity()

hooked_model = HookedTransformer.from_pretrained("gpt2", hf_model=model, fold_ln=True, center_unembed=False).to("cpu")
removeLN(hooked_model)

model_nnsight = UnifiedTransformer(model="gpt2", hf_model=model, fold_ln=True, center_unembed=False).to("cpu")
removeLN(model_nnsight)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
prompt = torch.tensor([1,2,3,4], device=device)
logits = hooked_model(prompt)

with torch.no_grad(), model_nnsight.trace(prompt) as runner:
    logits2 = model_nnsight.unembed.output.save()
logits, cache = hooked_model.run_with_cache(prompt)
torch.allclose(logits, logits2)
comment by Chris_Leong · 2024-08-09T09:20:06.102Z · LW(p) · GW(p)

Fascinating. I would love to see follow up work on whether it does harm generalisation, because if we were able to train more interpretable models without damaging generalisation, that would be amazing.

I'd love to see other research along these lines. Like what if we could use interpretability to figure out what a circuit does, replace the circuit with something more principled/transparent, then train for a bit longer with the new circuit in place.

comment by CBiddulph (caleb-biddulph) · 2024-08-09T00:44:31.915Z · LW(p) · GW(p)

This is great! Maybe you'd get better results if you "distill" GPT2-LN into GPT2-noLN by fine-tuning on the entire token probability distribution on OpenWebText.

comment by Daniel Tan (dtch1997) · 2024-08-12T09:03:37.351Z · LW(p) · GW(p)

Interesting stuff! I'm very curious as to whether removing layer norm damages the model in some measurable way. 

One thing that comes to mind is that previous work finds that the final LN is responsible for mediating 'confidence' through 'entropy neurons'; if you've trained sufficiently I would expect all of these neurons to not be present anymore, which then raises the question of whether the model still exhibits this kind of self-confidence-regulation
