A Bunch of Matryoshka SAEs
post by chanind, TomasD (tomas-dulka), Adrià Garriga-alonso (rhaps0dy) · 2025-04-04
This work was done as part of MATS 7.0.
MATS provides a generous compute stipend, and towards the end of the program we found we had some unspent compute. Rather than let it go to waste, we trained batch top-k Matryoshka SAEs on all residual stream layers of Gemma-2-2b, Gemma-2-9b, and Gemma-3-1b, and are now releasing them publicly. The hyperparameters for these SAEs were not aggressively optimized, but they should hopefully be decent. Below we describe how and why these SAEs were trained, along with stats for each SAE. Key decisions:
- We use narrower inner widths than the original Matryoshka SAEs work, and increase each width by a larger factor between levels. We do this to make it easier to study the highest-frequency features of the model.
- We include standard and snap loss variants for Gemma-2-2b. Snap loss and the rationale behind it are described in our Feature Hedging post [LW · GW]. There is probably not much practical difference between the snap and standard versions of the SAEs.
- We do not stop gradients between Matryoshka levels. In toy models, we find that hedging [LW · GW] and absorption [LW · GW] pull the encoder in opposite directions, and allowing gradients to flow between levels helps moderate the severity of feature hedging in the inner levels of the Matryoshka SAE.
I don't care about any of that, just give me the SAEs!
You can load all of the SAEs using SAELens via the following releases:
- Gemma-2-2b snap loss Matryoshka: gemma-2-2b-res-snap-matryoshka-dc
- Gemma-2-2b standard Matryoshka: gemma-2-2b-res-matryoshka-dc
- Gemma-2-9b standard Matryoshka: gemma-2-9b-res-matryoshka-dc
- Gemma-3-1b standard Matryoshka: gemma-3-1b-res-matryoshka-dc
For each release, the SAE ID is the corresponding TransformerLens post-block residual stream hook point, e.g. blocks.5.hook_resid_post for the layer 5 residual stream SAE.
Each SAE can be loaded in SAELens as follows:
from sae_lens import SAE
sae = SAE.from_pretrained("<release>", "<sae_id>")[0]
For instance, to load the layer 5 snap variant SAE for gemma-2-2b, this would look like the following:
sae = SAE.from_pretrained("gemma-2-2b-res-snap-matryoshka-dc", "blocks.5.hook_resid_post")[0]
Neuronpedia
Neuronpedia has generously hosted some of these SAEs, with more coming soon. Check them out at: https://www.neuronpedia.org/res-matryoshka-dc.
Matryoshka SAEs should be much better than standard SAEs at finding general, high-frequency concepts like parts of speech. In standard SAEs, latents tracking these concepts tend to get shredded by feature absorption, since they co-occur with so many other concepts. As Matryoshka SAEs should be much more resilient to absorption, we expect to find more meaningful high-density latents in them (although these latents may be degraded by feature hedging instead). For instance, here is a high-density latent from the innermost Matryoshka level of the layer 12 SAE for Gemma-2-2b, which appears to (very noisily) perform a grammatical function similar to the Penn Treebank IN (preposition or subordinating conjunction) part of speech.
Higher-frequency concepts should be concentrated in earlier latent indices: the highest-frequency concepts should appear in latents 0-127, the next highest in latents 128-511, and so on.
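Continuing the loading sketch above, one simple way to inspect this structure is to reconstruct from only the innermost level by zeroing out everything past latent 127 (the slicing here is our own illustration, not an official API):

```python
# Reconstruct using only the innermost Matryoshka level (latents 0-127).
latents = sae.encode(acts)
inner_latents = latents.clone()
inner_latents[..., 128:] = 0.0  # drop everything outside the first level
inner_recon = sae.decode(inner_latents)
```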
Training info
All Matryoshka SAEs in this release are trained on 750M tokens from the Pile using a modified version of SAELens. The SAEs are all 32k wide with Matryoshka levels of 128, 512, 2048, 8192, and 32768. We included two inner levels (128 and 512) that are much narrower than the model's residual stream to make it easier to study which features the SAE learns first. These are all batch top-k SAEs, following the original Matryoshka SAEs work. We largely did not optimize hyperparameters, so it is likely possible to squeeze out more performance with a better choice of learning rate and more training tokens, but these SAEs should hopefully be decent.
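For reference, the core of the training objective is a sum of reconstruction losses over nested prefixes of the dictionary. Here is a simplified sketch of that nested loss (it omits batch top-k selection, auxiliary losses, and the exact normalization used in our training code):

```python
import torch

MATRYOSHKA_LEVELS = [128, 512, 2048, 8192, 32768]

def matryoshka_recon_loss(
    latents: torch.Tensor,  # [batch, d_sae] sparse codes (already top-k'd)
    W_dec: torch.Tensor,    # [d_sae, d_model] decoder weights
    b_dec: torch.Tensor,    # [d_model] decoder bias
    x: torch.Tensor,        # [batch, d_model] original activations
) -> torch.Tensor:
    """Sum of MSE reconstruction losses over nested dictionary prefixes."""
    loss = x.new_zeros(())
    for width in MATRYOSHKA_LEVELS:
        # Reconstruct using only the first `width` latents / decoder rows.
        recon = latents[:, :width] @ W_dec[:width] + b_dec
        loss = loss + (recon - x).pow(2).mean()
    return loss
```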
Snap loss
One of the notable components of this release is the addition of snap loss variants of all SAEs for Gemma-2-2b. Snap loss is described in our post on Feature Hedging [LW · GW]: it switches the SAE's reconstruction loss from MSE to L2 midway through training. In practice, we do not see much difference in SAEs trained on LLMs with snap loss, but since we trained these SAEs anyway, we are releasing them in case others are curious to investigate the effect. If you notice a meaningful difference in practice between the snap loss and standard variants of these SAEs, please let us know!
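As a rough illustration only (the exact schedule lives in our training code), the switch amounts to changing the per-batch reconstruction term partway through training:

```python
import torch

def snap_recon_loss(recon: torch.Tensor, x: torch.Tensor, step: int, snap_step: int) -> torch.Tensor:
    """MSE reconstruction loss that 'snaps' to an (unsquared) L2 error after snap_step."""
    err = recon - x
    if step >= snap_step:
        return err.norm(dim=-1).mean()  # L2 norm of the error per example, averaged
    return err.pow(2).mean()            # standard MSE
```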
Balancing absorption and hedging
Intuitively, it might seem like we would want the inner levels of Matryoshka SAEs to be insulated from gradients from the outer levels: outer levels pull the inner latents towards absorption, which defeats the purpose of a Matryoshka SAE! However, in toy models, hedging and absorption have opposite effects on the SAE encoder, so allowing some absorption pressure can help counteract hedging and improve performance. We note that the dictionary_learning implementation of Matryoshka SAEs also does not stop gradients between levels, likely because stopping gradients causes hedging to damage the SAE more severely.
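Concretely, the stop-gradient variant we decided against would look roughly like the following, reusing the names from the loss sketch above (again just a sketch, not our exact implementation):

```python
# Stop-gradient variant (NOT what we use): each level adds its own slice of
# latents on top of a detached reconstruction from the inner levels, so
# gradients from outer-level losses never reach inner-level parameters.
loss = 0.0
partial_recon = b_dec
prev_width = 0
for width in MATRYOSHKA_LEVELS:
    new_part = latents[:, prev_width:width] @ W_dec[prev_width:width]
    recon = partial_recon + new_part
    loss = loss + (recon - x).pow(2).mean()
    partial_recon = recon.detach()  # block gradients flowing to inner levels
    prev_width = width
```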
For a further investigation of balancing hedging and absorption in Matryoshka SAEs, check out this Colab notebook.
We suspect it may be possible to intentionally balance hedging against absorption in a more optimal way, and we plan to investigate this in future work.
SAEs and stats
Below we list all the SAEs trained along with some core stats.
Gemma-2-2b
We trained both snap loss and standard variants of SAEs for Gemma-2-2b. These have the release IDs gemma-2-2b-res-snap-matryoshka-dc for the snap loss variant and gemma-2-2b-res-matryoshka-dc for the standard variant.
Snap Matryoshka SAEs
layer | SAE ID | width | L0 | explained variance |
---|---|---|---|---|
0 | blocks.0.hook_resid_post | 32768 | 40 | 0.919964 |
1 | blocks.1.hook_resid_post | 32768 | 40 | 0.863969 |
2 | blocks.2.hook_resid_post | 32768 | 40 | 0.858767 |
3 | blocks.3.hook_resid_post | 32768 | 40 | 0.815844 |
4 | blocks.4.hook_resid_post | 32768 | 40 | 0.821094 |
5 | blocks.5.hook_resid_post | 32768 | 40 | 0.797083 |
6 | blocks.6.hook_resid_post | 32768 | 40 | 0.79815 |
7 | blocks.7.hook_resid_post | 32768 | 40 | 0.78946 |
8 | blocks.8.hook_resid_post | 32768 | 40 | 0.779236 |
9 | blocks.9.hook_resid_post | 32768 | 40 | 0.759022 |
10 | blocks.10.hook_resid_post | 32768 | 40 | 0.743998 |
11 | blocks.11.hook_resid_post | 32768 | 40 | 0.731758 |
12 | blocks.12.hook_resid_post | 32768 | 40 | 0.725974 |
13 | blocks.13.hook_resid_post | 32768 | 40 | 0.727936 |
14 | blocks.14.hook_resid_post | 32768 | 40 | 0.727065 |
15 | blocks.15.hook_resid_post | 32768 | 40 | 0.757408 |
16 | blocks.16.hook_resid_post | 32768 | 40 | 0.751874 |
17 | blocks.17.hook_resid_post | 32768 | 40 | 0.763654 |
18 | blocks.18.hook_resid_post | 32768 | 40 | 0.77644 |
19 | blocks.19.hook_resid_post | 32768 | 40 | 0.768622 |
20 | blocks.20.hook_resid_post | 32768 | 40 | 0.761658 |
21 | blocks.21.hook_resid_post | 32768 | 40 | 0.765593 |
22 | blocks.22.hook_resid_post | 32768 | 40 | 0.741098 |
23 | blocks.23.hook_resid_post | 32768 | 40 | 0.729718 |
24 | blocks.24.hook_resid_post | 32768 | 40 | 0.754838 |
Standard Matryoshka SAEs
layer | SAE ID | width | L0 | explained variance |
---|---|---|---|---|
0 | blocks.0.hook_resid_post | 32768 | 40 | 0.91832 |
1 | blocks.1.hook_resid_post | 32768 | 40 | 0.863454 |
2 | blocks.2.hook_resid_post | 32768 | 40 | 0.841324 |
3 | blocks.3.hook_resid_post | 32768 | 40 | 0.814794 |
4 | blocks.4.hook_resid_post | 32768 | 40 | 0.820418 |
5 | blocks.5.hook_resid_post | 32768 | 40 | 0.796252 |
6 | blocks.6.hook_resid_post | 32768 | 40 | 0.797322 |
7 | blocks.7.hook_resid_post | 32768 | 40 | 0.787601 |
8 | blocks.8.hook_resid_post | 32768 | 40 | 0.779433 |
9 | blocks.9.hook_resid_post | 32768 | 40 | 0.75697 |
10 | blocks.10.hook_resid_post | 32768 | 40 | 0.745011 |
11 | blocks.11.hook_resid_post | 32768 | 40 | 0.732177 |
12 | blocks.12.hook_resid_post | 32768 | 40 | 0.726209 |
13 | blocks.13.hook_resid_post | 32768 | 40 | 0.719405 |
14 | blocks.14.hook_resid_post | 32768 | 40 | 0.719056 |
15 | blocks.15.hook_resid_post | 32768 | 40 | 0.756888 |
16 | blocks.16.hook_resid_post | 32768 | 40 | 0.742889 |
17 | blocks.17.hook_resid_post | 32768 | 40 | 0.757294 |
18 | blocks.18.hook_resid_post | 32768 | 40 | 0.76921 |
19 | blocks.19.hook_resid_post | 32768 | 40 | 0.766661 |
20 | blocks.20.hook_resid_post | 32768 | 40 | 0.760939 |
21 | blocks.21.hook_resid_post | 32768 | 40 | 0.759883 |
22 | blocks.22.hook_resid_post | 32768 | 40 | 0.740612 |
23 | blocks.23.hook_resid_post | 32768 | 40 | 0.729678 |
24 | blocks.24.hook_resid_post | 32768 | 40 | 0.747313 |
Gemma-2-9b
These SAEs have the release ID gemma-2-9b-res-matryoshka-dc.
layer | SAE ID | width | L0 | explained variance |
---|---|---|---|---|
0 | blocks.0.hook_resid_post | 32768 | 60 | 0.942129 |
1 | blocks.1.hook_resid_post | 32768 | 60 | 0.900656 |
2 | blocks.2.hook_resid_post | 32768 | 60 | 0.869154 |
3 | blocks.3.hook_resid_post | 32768 | 60 | 0.84077 |
4 | blocks.4.hook_resid_post | 32768 | 60 | 0.816605 |
5 | blocks.5.hook_resid_post | 32768 | 60 | 0.826656 |
6 | blocks.6.hook_resid_post | 32768 | 60 | 0.798281 |
7 | blocks.7.hook_resid_post | 32768 | 60 | 0.796018 |
8 | blocks.8.hook_resid_post | 32768 | 60 | 0.790385 |
9 | blocks.9.hook_resid_post | 32768 | 60 | 0.775052 |
10 | blocks.10.hook_resid_post | 32768 | 60 | 0.756327 |
12 | blocks.12.hook_resid_post | 32768 | 60 | 0.718319 |
13 | blocks.13.hook_resid_post | 32768 | 60 | 0.714065 |
14 | blocks.14.hook_resid_post | 32768 | 60 | 0.709635 |
15 | blocks.15.hook_resid_post | 32768 | 60 | 0.706622 |
16 | blocks.16.hook_resid_post | 32768 | 60 | 0.687879 |
17 | blocks.17.hook_resid_post | 32768 | 60 | 0.695821 |
18 | blocks.18.hook_resid_post | 32768 | 60 | 0.691723 |
19 | blocks.19.hook_resid_post | 32768 | 60 | 0.690914 |
20 | blocks.20.hook_resid_post | 32768 | 60 | 0.684599 |
21 | blocks.21.hook_resid_post | 32768 | 60 | 0.691355 |
22 | blocks.22.hook_resid_post | 32768 | 60 | 0.705531 |
23 | blocks.23.hook_resid_post | 32768 | 60 | 0.702293 |
24 | blocks.24.hook_resid_post | 32768 | 60 | 0.707655 |
25 | blocks.25.hook_resid_post | 32768 | 60 | 0.721022 |
26 | blocks.26.hook_resid_post | 32768 | 60 | 0.721717 |
27 | blocks.27.hook_resid_post | 32768 | 60 | 0.745809 |
28 | blocks.28.hook_resid_post | 32768 | 60 | 0.753267 |
29 | blocks.29.hook_resid_post | 32768 | 60 | 0.76466 |
30 | blocks.30.hook_resid_post | 32768 | 60 | 0.763025 |
31 | blocks.31.hook_resid_post | 32768 | 60 | 0.765932 |
32 | blocks.32.hook_resid_post | 32768 | 60 | 0.760822 |
33 | blocks.33.hook_resid_post | 32768 | 60 | 0.73323 |
34 | blocks.34.hook_resid_post | 32768 | 60 | 0.746912 |
35 | blocks.35.hook_resid_post | 32768 | 60 | 0.738031 |
36 | blocks.36.hook_resid_post | 32768 | 60 | 0.730805 |
37 | blocks.37.hook_resid_post | 32768 | 60 | 0.722875 |
38 | blocks.38.hook_resid_post | 32768 | 60 | 0.715494 |
39 | blocks.39.hook_resid_post | 32768 | 60 | 0.7044 |
40 | blocks.40.hook_resid_post | 32768 | 60 | 0.711277 |
Gemma-3-1b
These SAEs have the release ID gemma-3-1b-res-matryoshka-dc.
layer | SAE ID | width | L0 | explained variance |
---|---|---|---|---|
0 | blocks.0.hook_resid_post | 32768 | 40 | 0.99118 |
1 | blocks.1.hook_resid_post | 32768 | 40 | 0.985819 |
2 | blocks.2.hook_resid_post | 32768 | 40 | 0.981468 |
3 | blocks.3.hook_resid_post | 32768 | 40 | 0.979252 |
4 | blocks.4.hook_resid_post | 32768 | 40 | 0.973719 |
5 | blocks.5.hook_resid_post | 32768 | 40 | 0.977229 |
6 | blocks.6.hook_resid_post | 32768 | 40 | 0.982247 |
7 | blocks.7.hook_resid_post | 32768 | 40 | 0.989271 |
8 | blocks.8.hook_resid_post | 32768 | 40 | 0.985447 |
9 | blocks.9.hook_resid_post | 32768 | 40 | 0.985869 |
10 | blocks.10.hook_resid_post | 32768 | 40 | 0.98235 |
11 | blocks.11.hook_resid_post | 32768 | 40 | 0.980853 |
12 | blocks.12.hook_resid_post | 32768 | 40 | 0.977682 |
13 | blocks.13.hook_resid_post | 32768 | 40 | 0.969005 |
14 | blocks.14.hook_resid_post | 32768 | 40 | 0.956484 |
15 | blocks.15.hook_resid_post | 32768 | 40 | 0.937399 |
16 | blocks.16.hook_resid_post | 32768 | 40 | 0.928849 |
17 | blocks.17.hook_resid_post | 32768 | 40 | 0.912209 |
18 | blocks.18.hook_resid_post | 32768 | 40 | 0.904198 |
19 | blocks.19.hook_resid_post | 32768 | 40 | 0.895405 |
20 | blocks.20.hook_resid_post | 32768 | 40 | 0.883044 |
21 | blocks.21.hook_resid_post | 32768 | 40 | 0.868396 |
22 | blocks.22.hook_resid_post | 32768 | 40 | 0.831975 |
23 | blocks.23.hook_resid_post | 32768 | 40 | 0.793732 |
24 | blocks.24.hook_resid_post | 32768 | 40 | 0.7452 |