Regarding some features not being learnt at all, I was anticipating this might happen when some features activate much more rarely than others, potentially incentivising SAEs to learn more common combinations instead of some of the rarer features. To potentially see this, we'd need to experiment with more variations, as mentioned in my other comment.
Nice work! I was actually planning on doing something along these lines and still have some things I'd like to try.
Interestingly, your SAEs generally appear to fail to find optimal solutions even with respect to the training objective. For example, in your first experiment with perfectly correlated features, I think the optimal solution in terms of combined reconstruction loss and L1 loss (regardless of the L1 loss weighting) would have the learnt feature directions (decoder weights) pointing perfectly diagonally. It looks like very few of your hyperparameter combinations came close to this solution.
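To make the diagonal claim concrete, here's a minimal NumPy sketch of the 2D case, assuming two unit "true" features along the x and y axes that always fire together with the same magnitude c, so every input is c * (1, 1). It compares two hand-picked candidate solutions (not trained SAEs): an axis-aligned dictionary with one latent per true feature, and a single diagonal latent. It ignores encoder weights, biases, and the shrinkage the L1 term would induce on activation magnitudes. The helper `sae_cost` is hypothetical.

```python
import numpy as np

# Two true features (x and y axes) that always co-occur with equal magnitude c,
# so every input lies along (1, 1).
c = 1.0
x = c * np.array([1.0, 1.0])

def sae_cost(decoder_dirs, activations, x, l1_coeff=1.0):
    """Squared reconstruction error plus weighted L1 of the latent activations."""
    recon = activations @ decoder_dirs          # weighted sum of decoder rows
    mse = float(np.sum((x - recon) ** 2))
    l1 = float(np.sum(np.abs(activations)))
    return mse + l1_coeff * l1, mse, l1

# Candidate 1: axis-aligned dictionary, one latent per true feature, each with activation c.
axis_dec = np.array([[1.0, 0.0], [0.0, 1.0]])
axis_act = np.array([c, c])

# Candidate 2: one latent along the diagonal (1, 1)/sqrt(2) with activation c*sqrt(2).
diag_dec = np.array([[1.0, 1.0]]) / np.sqrt(2.0)
diag_act = np.array([c * np.sqrt(2.0)])

axis_total, axis_mse, axis_l1 = sae_cost(axis_dec, axis_act, x)
diag_total, diag_mse, diag_l1 = sae_cost(diag_dec, diag_act, x)
```

Both candidates reconstruct the input exactly, but the diagonal one pays an L1 cost of sqrt(2)·c instead of 2c, so it wins for any positive L1 weighting.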
My post was concerned primarily with the training objective being misaligned with what we really want, but here we're seeing an additional problem of SAEs struggling to even optimise for the training objective. I'm wondering though if this might be largely/entirely a result of the extremely low dimensionality and therefore very few parameters causing them to get stuck in local minima. I'm interested to see what happens with more dimensions and more variation in terms of true feature frequency, true feature correlations, and dictionary size. And orthogonality loss may have more impact in some of those cases.
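As a starting point for those variations, here's a rough sketch of a toy data generator with per-feature firing probabilities (so rare and common features coexist) and optional perfectly correlated groups. Everything here is an assumption for illustration: the function name `make_toy_data`, random unit feature directions, uniform magnitudes in [0.5, 1.5], and "correlated" meaning every member of a group copies the first member's firing mask.

```python
import numpy as np

def make_toy_data(n_samples, n_dims, feature_probs, correlated_groups=(), seed=0):
    """Sample x = sum_i a_i * f_i with sparse non-negative a_i and random unit directions f_i.

    feature_probs: per-feature firing probability.
    correlated_groups: tuples of feature indices; every member copies the first
        member's firing mask, so the group always fires together.
    """
    rng = np.random.default_rng(seed)
    probs = np.asarray(feature_probs, dtype=float)
    n_features = probs.shape[0]

    # Random unit-norm "true" feature directions, one per row.
    dirs = rng.normal(size=(n_features, n_dims))
    dirs /= np.linalg.norm(dirs, axis=1, keepdims=True)

    # Independent firing mask per feature (broadcasts probs across samples).
    fires = rng.random((n_samples, n_features)) < probs

    # Perfectly correlated groups share the lead feature's mask.
    for group in correlated_groups:
        lead = group[0]
        for idx in group[1:]:
            fires[:, idx] = fires[:, lead]

    # Positive magnitudes where a feature fires, zero elsewhere.
    mags = rng.uniform(0.5, 1.5, size=(n_samples, n_features)) * fires
    x = mags @ dirs
    return x, dirs, fires
```

For example, `make_toy_data(10000, 5, [0.3, 0.3, 0.5, 0.01], correlated_groups=((0, 1),))` gives one correlated pair, one common independent feature, and one rare feature; sweeping `n_dims`, the probabilities, and the SAE dictionary size would cover the axes mentioned above.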
Nice, that's promising! It would also be interesting to see how those peaks are affected when you retrain the SAE both on the same target model and on different target models.
Thanks, that's very interesting!
Testing it with Pythia-70M and few enough features to permit the naive calculation sounds like a great approach to start with.
Using the closest neighbour rather than the average over all features sounds sensible. I'm not certain what you mean by unique vs non-unique. If you're referring to situations where there may be several equally close neighbours, then I think we can just take the mean cos-sim of those neighbours, so they all contribute to the loss but its magnitude stays within the normal range.
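Here's a minimal NumPy sketch of that closest-neighbour penalty, assuming decoder directions are the rows of `W_dec`, similarity is abs(cos-sim), and "equally close" means within a small tolerance `tie_tol` (both names are hypothetical). Numerically, the mean over tied neighbours equals the maximum; the point of averaging is that in an autodiff framework the gradient would be shared across all tied neighbours rather than going to an arbitrary one.

```python
import numpy as np

def nearest_neighbour_ortho_penalty(W_dec, tie_tol=1e-6):
    """Mean over decoder rows of |cos-sim| to each row's closest neighbour(s).

    Tied closest neighbours (within tie_tol) are averaged, so every tied
    neighbour contributes while each per-row value stays in [0, 1].
    Returns (penalty, per_row_values, nearest_mask).
    """
    W = W_dec / np.linalg.norm(W_dec, axis=1, keepdims=True)   # unit-norm rows
    sims = np.abs(W @ W.T)                                     # n x n |cos-sim|
    np.fill_diagonal(sims, -np.inf)                            # ignore self-similarity
    row_max = sims.max(axis=1, keepdims=True)
    is_nearest = sims >= row_max - tie_tol                     # all (near-)tied closest neighbours
    per_row = np.where(is_nearest, sims, 0.0).sum(axis=1) / is_nearest.sum(axis=1)
    return float(per_row.mean()), per_row, is_nearest
```

For rows (1, 0), (0, 1), (1, 1), the diagonal row has both axis rows as tied closest neighbours, and the penalty is 1/sqrt(2).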
Applying the penalty only to features that activate also sounds sensible, but the decoder weights of neurons that didn't activate would need to be allowed to update if they were the closest neighbours of neurons that did activate. Otherwise we could get situations like this: one neuron (neuron A) has encoder and decoder weights both pointing in sensible directions to capture a feature, while another neuron has decoder weights aligned with neuron A's but encoder weights occupying a remote region of activation space, so it rarely activates. If we don't allow that second neuron's decoder weights to update, they remain in that direction, blocking neuron A.
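A rough sketch of that variant, under the same assumptions as before (decoder rows in `W_dec`, abs(cos-sim)), plus a hypothetical boolean `active_mask` marking latents that fired in the batch. Only active latents contribute terms, but the neighbour search runs over all decoder rows, so the rows that would receive gradient are the active latents plus their closest neighbours, which may be inactive. For brevity this version takes a single `argmax` neighbour instead of averaging ties.

```python
import numpy as np

def active_ortho_penalty(W_dec, active_mask):
    """|cos-sim| penalty for latents active in the batch, searching neighbours over all rows.

    Returns (penalty, touched): touched holds the indices of decoder rows the
    penalty depends on, i.e. active latents plus their nearest neighbours.
    """
    W = W_dec / np.linalg.norm(W_dec, axis=1, keepdims=True)
    active = np.flatnonzero(active_mask)                 # indices of the m active latents
    if active.size == 0:
        return 0.0, active
    sims = np.abs(W[active] @ W.T)                       # m x n rather than n x n
    sims[np.arange(active.size), active] = -np.inf       # exclude self-similarity
    nearest = sims.argmax(axis=1)                        # closest neighbour of each active latent
    penalty = float(sims[np.arange(active.size), nearest].mean())
    touched = np.union1d(active, nearest)                # rows that would get gradient
    return penalty, touched
```

In the neuron A scenario, if row 0 is A (active) and row 1 is the rarely firing neuron whose decoder points almost the same way, row 1 ends up in `touched` even though it didn't fire, so its decoder weights can move away from A's.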
Yes, I think we want to penalise high cos-sim more heavily. The modified sigmoid flattens out as x->1, but I think the purple function below does what we want.
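To illustrate the shape difference numerically, here's a small sketch comparing slopes near 0.5 and near 1. Both functions are stand-ins: the exact "modified sigmoid" and the purple function aren't reproduced here, so `sigmoid_penalty` is assumed to be a plain logistic curve centred at 0.5, and `steep_penalty` (x**4) is just one hypothetical convex alternative whose slope keeps growing as x -> 1.

```python
import numpy as np

def sigmoid_penalty(x, k=10.0, centre=0.5):
    """Logistic curve in |cos-sim|; its slope shrinks as x -> 1."""
    return 1.0 / (1.0 + np.exp(-k * (x - centre)))

def steep_penalty(x, p=4):
    """Hypothetical convex alternative x**p; its slope grows as x -> 1."""
    return x ** p

def slope(f, x, h=1e-5):
    """Central finite-difference estimate of f'(x)."""
    return (f(x + h) - f(x - h)) / (2 * h)

for x in (0.5, 0.95):
    print(x, slope(sigmoid_penalty, x), slope(steep_penalty, x))
```

With the logistic curve, pushing a pair from cos-sim 0.95 towards 0.9 earns far less reward than pushing from 0.5, which is the opposite of what we want if near-duplicate directions are the main concern.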
Training with a negative orthogonality regulariser could be an option. I think vanilla SAEs already have plenty of geometrically aligned features (e.g. see @jacobcd52's comment below). Depending on the purpose, another way to intentionally generate feature combinations could be to simply add together some of the features learnt by a vanilla SAE. Even if the individual features weren't combinations, their sums certainly would be.
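A minimal sketch of the "add together learnt features" idea, assuming the vanilla SAE's decoder directions are the rows of `W_dec` and that combinations are formed by summing k distinct unit-normalised rows and re-normalising. The function name and sampling scheme are hypothetical. Summing two nearly antipodal rows would give a near-zero vector, so a real version would want to skip such pairs.

```python
import numpy as np

def make_combination_features(W_dec, n_combos, k=2, seed=0):
    """Synthetic 'combination' directions: sums of k random decoder rows, re-normalised.

    Returns (combos, members): combos is (n_combos, d); members[i] lists the
    decoder-row indices that were summed to form combos[i].
    """
    rng = np.random.default_rng(seed)
    W = W_dec / np.linalg.norm(W_dec, axis=1, keepdims=True)
    combos, members = [], []
    for _ in range(n_combos):
        idx = rng.choice(W.shape[0], size=k, replace=False)   # k distinct features
        v = W[idx].sum(axis=0)
        combos.append(v / np.linalg.norm(v))
        members.append(idx)
    return np.stack(combos), np.stack(members)
```

These directions could then be used as known "ground truth" combinations to check whether an SAE trained with the orthogonality loss splits them back into their members.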
I'll be very interested to see results and am happy to help with interpreting them etc. Also more than happy to have a look at any code.
Thanks for clarifying! Indeed the encoder weights here would be orthogonal. But I'm suggesting applying the orthogonality regularisation to the decoder weights which would not be orthogonal in this case.
Thanks, I mentioned this as a potential way forward for tackling quadratic complexity in my edit at the end of the post.
Regarding achieving perfect reconstruction and perfect sparsity in the limit, I was also thinking along those lines, i.e. in the limit you could have a single neuron in the sparse layer for every possible input direction. However, please correct me if I’m wrong, but assuming the SAE has only one hidden layer, I don't think you could prevent neurons from activating for nearby input directions (unless all input directions had equal magnitude), so you'd end up with many neurons activating for any given input and thus imperfect sparsity.
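A tiny numerical sketch of why magnitude matters, assuming a standard single-layer ReLU encoder latent ReLU(u · x + b) tuned to direction u. The 10° angle, the bias of -0.9, and the input norms 1 and 2 are arbitrary illustrative choices. Because the pre-activation is |x| cos θ + b, a larger-norm input in a nearby direction clears any threshold that a unit-norm input exactly along u clears.

```python
import numpy as np

# One ReLU latent "tuned" to direction u: it fires when u . x + b > 0.
theta = np.deg2rad(10.0)                     # v is 10 degrees away from u
u = np.array([1.0, 0.0])
v = np.array([np.cos(theta), np.sin(theta)])
b = -0.9                                     # threshold low enough for u itself (norm 1) to fire

def latent(x):
    """ReLU pre-activation u . x + b, clipped at zero."""
    return max(0.0, float(u @ x + b))

fires_on_u = latent(1.0 * u) > 0             # the intended direction, norm 1
fires_on_v = latent(2.0 * v) > 0             # a nearby direction, norm 2
```

Both fire. Excluding the norm-2 input along v would need b < -2 cos(10°) ≈ -1.97, which would also stop the latent firing on u at norm 1. With equal-magnitude inputs, a threshold between cos θ and 1 would separate them, matching the caveat above.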
Otherwise mostly agreed. Though, as discussed, besides making it necessary to figure out how to break apart feature combinations (as you said), feature splitting would also seem to risk less common “true features” not being represented even within combinations, so those would be missed entirely.
My bad! Yes, since that's just one batch, it does indeed come out as quadratic overall. I'll have a think about more efficient methods.
This looks interesting. I'm having a difficult time understanding the results though. It would be great to see a more detailed write up!
Yeah, I was thinking abs(cos_sim(x, x')).
I'm not sure what you're getting at regarding the inhibitory weights as the image link is broken
If n is the number of features we're trying to discover and m is the number of features in each batch, then I think the naive approach is O(n^2), while the batch approach would be O(m^2 + mn). That's still quadratic in m, but we would have m << n.
Even for a fairly small target model, we might want to discover e.g. 100K features, and the input vectors might be e.g. 768-dimensional. That's a lot of work to compute that matrix!
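Rough arithmetic for those numbers, with n = 10^5 features and d = 768 dimensions, assuming the full n × n cos-sim matrix is computed as a dense float32 product. For the batch approach, m = 10^3 active features per batch is a hypothetical value chosen for illustration.

$$
\begin{aligned}
\text{naive multiply-adds:}\quad & n^2 d = (10^5)^2 \times 768 = 7.68 \times 10^{12} \\
\text{naive matrix size:}\quad & n^2 = 10^{10} \text{ entries} \times 4 \text{ bytes} = 4 \times 10^{10} \text{ bytes} = 40\ \text{GB} \\
\text{batch multiply-adds } (m = 10^3):\quad & (m^2 + mn)\, d = (10^6 + 10^8) \times 768 \approx 7.76 \times 10^{10}
\end{aligned}
$$

So with these assumptions, the batch approach is roughly 100× cheaper per step, and the mn term dominates once m << n.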
Thanks! Yeah, I think those steps make sense for the iterative process, but I'm not sure whether you're proposing that this alone would tackle the problem of feature combinations. I'm still imagining it would require orthogonality regularisation with some weighting.
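To pin down what "orthogonality regularisation with some weighting" could look like in the overall objective, here's a forward-only NumPy sketch of a total SAE loss. It assumes a standard single-layer ReLU SAE (encoder `W_enc` of shape (d, n), decoder `W_dec` of shape (n, d), decoder bias subtracted before encoding), plus a closest-neighbour abs(cos-sim) term on decoder rows weighted by a hypothetical `ortho_coeff`. All names are illustrative, and a real version would live in an autodiff framework.

```python
import numpy as np

def sae_loss(x, W_enc, b_enc, W_dec, b_dec, l1_coeff, ortho_coeff):
    """Reconstruction + l1_coeff * L1 sparsity + ortho_coeff * decoder orthogonality.

    Orthogonality term: mean over decoder rows of |cos-sim| to the closest other row.
    Returns (total, (mse, l1, ortho)).
    """
    acts = np.maximum(0.0, (x - b_dec) @ W_enc + b_enc)      # (batch, n_latents)
    recon = acts @ W_dec + b_dec                              # (batch, d)
    mse = float(np.mean(np.sum((x - recon) ** 2, axis=1)))
    l1 = float(np.mean(np.sum(np.abs(acts), axis=1)))

    W = W_dec / np.linalg.norm(W_dec, axis=1, keepdims=True)  # unit decoder rows
    sims = np.abs(W @ W.T)
    np.fill_diagonal(sims, 0.0)                               # |cos-sim| >= 0, so self is ignored by max
    ortho = float(np.mean(sims.max(axis=1)))

    total = mse + l1_coeff * l1 + ortho_coeff * ortho
    return total, (mse, l1, ortho)
```

Setting `ortho_coeff = 0` recovers a vanilla SAE objective, so the weighting could be swept alongside the L1 coefficient.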