Creating Interpretable Latent Spaces with Gradient Routing
post by Jacob G-W (g-w1) · 2024-12-14T04:00:17.249Z · LW · GW · 6 comments
This is a link post for https://jacobgw.com/blog/ml/2024/12/12/interp-latent.html
Over the past few months, I helped develop Gradient Routing [LW · GW], a non-loss-based method for shaping the internals of neural networks. After my team developed it, I realized I could use the method to do something I had long wanted to do: make an autoencoder with an extremely interpretable latent space.
I created an MNIST variational autoencoder with a 10-dimensional latent space, with each dimension corresponding to a different digit. Before I get into how I did it, feel free to play around with my demo here (it loads the model into the browser): https://jacobgw.com/gradient-routed-vae/.
In the demo, you can both see how a random MNIST image encodes and directly manipulate the encoding itself, creating different digits just by moving the sliders.
The reconstruction is not that good, and I assume this is due to some combination of (1) using the simplest possible architecture of MLP layers and ReLU, (2) allowing only a 10-dimensional latent space, which could constrain the representation a lot, (3) not doing data augmentation, so it might not generalize that well, and (4) gradient routing targeting an unnatural internal representation, causing the autoencoder to fit the data less well. This was just supposed to be a fun proof-of-concept project, so I'm not too worried about the reconstruction quality here.
How it works
My implementation of gradient routing is super simple and easy to add onto a variational autoencoder. During training, after I run the encoder, I just detach every dimension of the encoding except for the one corresponding to the label of the image:
import torch.nn.functional as F
from torch import Tensor

def encode_and_mask(self, images: Tensor, labels: Tensor):
    encoded_unmasked, zeta, mean, cov_diag = self.encode(images)
    # Build a one-hot mask: 1 at the latent dimension matching each image's label.
    mask = F.one_hot(labels, num_classes=self.latent_size).float()
    # Route gradients: only the labeled dimension keeps its gradient;
    # every other dimension is detached from the computation graph.
    encoded = mask * encoded_unmasked + (1 - mask) * encoded_unmasked.detach()
    return encoded, zeta, mean, cov_diag
This causes each dimension of the latent space to "specialize" in representing its corresponding digit, since the reconstruction error for that digit class can only propagate back through that single dimension of the latent space.
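To make the routing concrete, here's a tiny standalone check (my own illustration, not code from the post) showing that the gradient only reaches the labeled dimension:

import torch
import torch.nn.functional as F

# A fake "encoding" for a single image labeled 5.
encoded_unmasked = torch.randn(1, 10, requires_grad=True)
mask = F.one_hot(torch.tensor([5]), num_classes=10).float()
# Same masking trick as in encode_and_mask above.
encoded = mask * encoded_unmasked + (1 - mask) * encoded_unmasked.detach()
encoded.sum().backward()
print(encoded_unmasked.grad)  # nonzero only at index 5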
It turns out that if you do this, nothing forces the model to represent "more of a digit" in the positive direction. Sometimes the model represented "5-ness" in the negative direction in the latent space (e.g. as [0, 0, 0, 0, 0, -1.0, 0, 0, 0, 0]). This messed with my demo a bit, since I wanted all the sliders to go only in the positive direction. My solution? Just apply ReLU to the encoding so it can only represent positive numbers! This is obviously not practical, and I only included it so the demo would look nice.[1]
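For what it's worth, the demo-only fix amounts to a single extra line; a minimal sketch, with variable names assumed from the snippet above:

import torch.nn.functional as F

# Demo-only hack: clamp the encoding to be non-negative so every slider
# only moves in the positive direction. Cosmetic, not principled.
encoded = F.relu(encoded)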
In our Gradient Routing paper, we found that models sometimes needed regularization to split the representations well. However, in this setting, I’m not applying any regularization besides the default regularization that comes with a variational autoencoder. I guess it turns out that this regularization is enough to effectively split the digits.
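For reference, that default regularization is the KL term in the standard VAE objective. Here's a minimal sketch of the loss, assuming mean and cov_diag returned by encode are the per-dimension Gaussian means and variances (my paraphrase, not the post's exact code):

import torch
import torch.nn.functional as F

def vae_loss(reconstruction, images, mean, cov_diag):
    # Reconstruction term: how well the decoder reproduces the input.
    recon_loss = F.mse_loss(reconstruction, images)
    # KL term: pulls each latent dimension toward a unit Gaussian.
    # This is the only explicit regularization used here.
    kl = -0.5 * torch.sum(1 + torch.log(cov_diag) - mean**2 - cov_diag)
    return recon_loss + kl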
Classification
It turns out that even though no loss term pushes the encoding to activate most strongly on the dimension corresponding to the digit being encoded, that's exactly what happens! In fact, we can classify digits with 92.58% accuracy just by taking the argmax over the encoding (i.e., whichever slider is most positive), which I find pretty amazing.
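A sketch of that classifier, reusing the encode signature from the earlier snippet (the exact evaluation code may differ):

import torch

@torch.no_grad()
def classify(model, images: torch.Tensor) -> torch.Tensor:
    # The predicted digit is whichever latent dimension ("slider") is most positive.
    encoded_unmasked, zeta, mean, cov_diag = model.encode(images)
    return encoded_unmasked.argmax(dim=-1)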
You can see the code here.
[1] I did have to train the model a few times to get something that behaved nicely enough for the demo.
6 comments
comment by James Camacho (james-camacho) · 2024-12-14T19:15:44.631Z · LW(p) · GW(p)
If you're not already aware of the information bottleneck, I'd recommend The Information Bottleneck Method, Efficient Compression in Color Naming and its Evolution, and Direct Validation of the Information Bottleneck Principle for Deep Nets. You can use this with routing for forward training.
EDIT: Probably wasn't super clear why you should look into this. An optimal autoencoder should try to maximize the mutual information between the encoding and the original image. You wouldn't even need to train a decoder at the same time as the encoder! But, unfortunately, it's pretty expensive to even approximate the mutual information. Maybe, if you route to different neurons based on image captions, you could significantly decrease this cost.
comment by Daniel Tan (dtch1997) · 2024-12-14T19:35:46.281Z · LW(p) · GW(p)
Is this surprising for you, given that you’ve applied the label for the MNIST classes already to obtain the interpretable latent dimensions?
It seems like this process didn’t yield any new information - we knew there was structure in the dataset, imposed that structure in the training objective, and then observed that structure in the model
↑ comment by Jacob G-W (g-w1) · 2024-12-14T21:48:03.844Z · LW(p) · GW(p)
I didn't impose any structure in the objective/loss function relating to the label. The loss function is just the regular VAE loss; all I did was detach the gradients in some places. So it is a bit surprising to me that such a simple modification can cause the internals to specialize in this way. After I had seen gradient routing work in other experiments, I predicted that it would work here, but I don't think gradient routing working was a priori obvious (in the sense that I would have gotten zero new information from running the experiment because I had predicted it with p=1).
↑ comment by Daniel Tan (dtch1997) · 2024-12-14T22:41:46.174Z · LW(p) · GW(p)
I agree, but the point I’m making is that you had to know the labels in order to know where to detach the gradient. So it’s kind of like making something interpretable by imposing your interpretation on it, which I feel is tautological
For the record I’m excited by gradient routing, and I don’t want to come across as a downer, but this application doesn’t compel me
Edit: Here’s an intuition pump. Would you be similarly excited by having 10 different autoencoders which each reconstruct a single digit, then stitching them together into a single global autoencoder? Because conceptually that seems like what you’re doing
↑ comment by Jacob G-W (g-w1) · 2024-12-14T23:23:18.466Z · LW(p) · GW(p)
I disagree that this is the same as just stitching together different autoencoders. Presumably the encoder has some shared computation before specializing at the encoding level. I also don't see how you could use 10 different autoencoders to classify an image from the encodings. I guess you could look at the reconstruction loss, and the autoencoder with the lowest loss would probably correspond to the label, but that seems different from what I'm doing. However, I agree that this application is not useful. I shared it because I (and others) thought it was cool. It's not really practical at all. Hope this addresses your question :)
↑ comment by Daniel Tan (dtch1997) · 2024-12-14T23:40:46.323Z · LW(p) · GW(p)
I see, if you disagree with the characterization then I've likely misunderstood what you were doing in this post, in which case I no longer endorse the above statements. Thanks for clarifying!