Posts

Takeaways From Our Recent Work on SAE Probing 2025-03-03T19:50:16.692Z
SAE Probing: What is it good for? 2024-11-01T19:23:55.418Z

Comments

Comment by Josh Engels (JoshEngels) on StefanHex's Shortform · 2025-02-07T01:33:46.623Z · LW · GW

I was having trouble reproducing your results on Pythia, and was only able to get 60% variance explained. I may have tracked it down: I think you may be computing FVU incorrectly. 

https://gist.github.com/Stefan-Heimersheim/ff1d3b92add92a29602b411b9cd76cec#file-clustering_pythia-py-L309

I think FVU is correctly computed by subtracting the per-dimension mean when computing the denominator. See the SAEBench implementation here:

https://github.com/adamkarvonen/SAEBench/blob/5204b4822c66a838d9c9221640308e7c23eda00a/sae_bench/evals/core/main.py#L566
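
For concreteness, here's a minimal sketch of the FVU I have in mind (tensor names are mine, not taken from either repo): the denominator is the total variance of the activations, i.e. the squared distance from the per-dimension mean.

```python
import torch

def fvu(acts: torch.Tensor, recon: torch.Tensor) -> torch.Tensor:
    # acts, recon: (n_samples, d_model)
    # Numerator: residual sum of squares of the reconstruction.
    residual = (acts - recon).pow(2).sum()
    # Denominator: total variance, i.e. squared distance from the
    # per-dimension mean (this is the mean subtraction I'm referring to).
    total = (acts - acts.mean(dim=0, keepdim=True)).pow(2).sum()
    return residual / total

# variance explained = 1 - fvu(acts, recon)
```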

When I used your FVU implementation, I got 72% variance explained; this is still less than your number, but much closer, so I think this difference in the FVU computation might be what's driving the apparent improvement over the SAEBench numbers.

In general I think SAEs with low k should do at least as well as k-means clustering, and if they don't I'm a little suspicious (when I first tried this on GPT-2, a TopK SAE trained with k = 4 did about as well as k-means clustering with the nonlinear argmax encoder).

Here's my clustering code: https://github.com/JoshEngels/CheckClustering/blob/main/clustering.py
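
For reference, a rough sketch of what I mean by the nonlinear argmax encoder (not the exact code from the repo above): each activation is hard-assigned to its nearest centroid, and that centroid is used as the reconstruction.

```python
import torch

def kmeans_reconstruct(acts: torch.Tensor, centroids: torch.Tensor) -> torch.Tensor:
    # acts: (n_samples, d_model), centroids: (n_centroids, d_model)
    # Hard ("argmax") assignment: pick the nearest centroid for each activation...
    dists = torch.cdist(acts, centroids)   # (n_samples, n_centroids)
    nearest = dists.argmin(dim=-1)
    # ...and use that centroid directly as the reconstruction.
    return centroids[nearest]

# variance explained = 1 - fvu(acts, kmeans_reconstruct(acts, centroids))
```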

Comment by Josh Engels (JoshEngels) on StefanHex's Shortform · 2025-02-07T00:39:58.703Z · LW · GW

I just tried to replicate this on GPT-2 with expansion factor 4 (so total number of centroids = 768 * 4). I get that clustering recovers ~87% of the variance, while a k = 32 SAE gets more like 95% variance explained. To give k-means the biggest advantage possible, I used the nonlinear nearest-centroid assignment when encoding, and did the k-means clustering itself with the FAISS library.

Definitely take this with a grain of salt; I'm going to look through my code and see if I can reproduce your results on Pythia too, and if so, try a larger model as well. Code: https://github.com/JoshEngels/CheckClustering/tree/main
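
Roughly what the clustering step looks like (a sketch with placeholder data, not my actual pipeline; the real activations come from GPT-2's residual stream):

```python
import faiss
import numpy as np

d = 768                                   # GPT-2 residual stream width
k = 4 * d                                 # expansion factor 4 -> 3072 centroids
acts = np.random.randn(200_000, d).astype("float32")  # placeholder for real activations

# Fit k-means with FAISS.
kmeans = faiss.Kmeans(d, k, niter=20, verbose=True)
kmeans.train(acts)

# Nearest-centroid ("argmax") assignment; the centroid is the reconstruction.
_, idx = kmeans.index.search(acts, 1)
recon = kmeans.centroids[idx[:, 0]]
# variance explained = 1 - FVU(acts, recon), using the mean-subtracted denominator above
```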

Comment by Josh Engels (JoshEngels) on StefanHex's Shortform · 2025-02-06T20:37:14.233Z · LW · GW

What do you mean by encoding/decoding as normal but using the k-means vectors? Shouldn't the SAE training process for a top-k SAE with k = 1 find these vectors then?

In general I'm a bit skeptical that clustering will work as well on larger models: my impression is that small models mostly have token-level features, which might be quite clusterable with k = 1, but for larger models many activations may belong to multiple "clusters", which is what you need dictionary learning for.