Sparsity and interpretability?

post by Ada Böhm (spirali), RobertKirk, Tomáš Gavenčiak (tomas-gavenciak) · 2020-06-01T13:25:46.557Z · LW · GW · 3 comments

Contents

    Assumptions and motivation
  Introduction [Optional]
  Introduction of the Experiment
  Initial results
  Making it clearer
  Visualising behaviour on specific inputs
  Pruning once more
  Domains other than MNIST
  Conclusion
    References
3 comments

A series of interactive visualisations of small-scale experiments to generate intuitions for the question: "Are sparse models more interpretable? If so, why?"

Assumptions and motivation

Many of the visualisations here are interactive. Due to technical forum limitations, the interactive versions are on an external page here and are linked to from every visualisation snapshot in this post.

Introduction [Optional]

We want to understand under which circumstances (if any) a sparser model is more interpretable. A mathematical formulation of “transparent or interpretable” is currently not available. Sparsity has been proposed as a candidate proxy for these properties, so we want to investigate how well the link between sparsity and interpretability holds, and which kinds of sparsity aid which kinds of interpretability.

(A note on the field of interpretability: we’ve talked a bit previously about what interpretability is, as the topic is broad and the term is very underdefined. It’s important to note that, when discussing whether a particular property (like sparsity) makes a model more or less interpretable, we need to specify what kind of interpretability method we’re using, as different methods will benefit from different properties of the model.)

We have several intuitions that sparsity in neural networks might imply greater interpretability, and some which point in the opposite direction:

In this post we investigate a specific instantiation of an interpretability method which intuitively seems to benefit greatly from sparsity, and discuss ways in which we can make the most of these benefits without making the network so sparse that it loses performance.

Introduction of the Experiment

To test whether sparsity is good for interpretability, the obvious thing to do is to get a sparse neural network and see whether it's easier to interpret. To do this we need the following things: a neural network, a pruning method to produce a sparse network, and an interpretability method to apply to this sparse network.

Neural network. We will use a simple network of the following architecture trained on MNIST (2 hidden layers of size 32 with ReLU activations; in total 26432 weights in kernels):

MNIST Network

This architecture gets 95+% accuracy on the MNIST test set after three epochs of training.
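The weight counts quoted above can be checked with a few lines of arithmetic. This is a minimal sketch, assuming the 28x28 MNIST images are flattened into 784 inputs and that there are 10 outputs, one per digit; the variable names are ours.

```python
# Layer widths: flattened 28x28 input, two hidden layers of 32, 10 digit outputs.
layer_sizes = [28 * 28, 32, 32, 10]

# One kernel (weight matrix) per pair of consecutive layers.
kernel_sizes = [n_in * n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:])]

# One bias per non-input unit.
bias_sizes = layer_sizes[1:]

total_kernel_weights = sum(kernel_sizes)
total_bias_weights = sum(bias_sizes)

print(kernel_sizes)          # [25088, 1024, 320]
print(total_kernel_weights)  # 26432
print(total_bias_weights)    # 74
```

Almost all of the weights (25088 of 26432) sit in the first kernel, which connects the pixels to the first hidden layer.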

Pruning method. The lottery ticket hypothesis [2] claims that neural networks can be pruned, reset to their initial weights, and retrained, and that the retrained network will often learn faster, attain better asymptotic performance and generalise better than the original network. The pruning method used in this paper is relatively simple:

Note that in our experiment we are pruning only kernels, not biases. As there are only 74 weights in biases, this isn't a consequential number of parameters compared to the 26432 kernel weights, and it makes the implementation simpler.
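One round of lottery-ticket-style pruning on a single kernel might look roughly like the sketch below. It assumes the magnitude-based procedure described in [2] (train, switch off the smallest-magnitude surviving weights, reset the survivors to their initial values, repeat). The helper `prune_smallest`, the 25% pruning fraction, and the random stand-in for training are all hypothetical and are not the code used for the experiments in this post.

```python
import numpy as np

def prune_smallest(weights, mask, fraction):
    """Return a new mask in which the given fraction of the still-active
    weights (those with the smallest magnitude) is switched off."""
    active = np.flatnonzero(mask)
    n_prune = int(round(fraction * active.size))
    magnitudes = np.abs(weights.ravel()[active])
    # Positions (in the flattened kernel) of the n_prune smallest magnitudes.
    drop = active[np.argsort(magnitudes, kind="stable")[:n_prune]]
    new_mask = mask.ravel().copy()
    new_mask[drop] = False
    return new_mask.reshape(mask.shape)

rng = np.random.default_rng(0)
initial = rng.normal(size=(4, 4))            # weights at initialisation
mask = np.ones_like(initial, dtype=bool)     # True = weight is kept

# Stand-in for a real training run: perturb the initial weights.
trained = initial + rng.normal(scale=0.1, size=initial.shape)

mask = prune_smallest(trained, mask, fraction=0.25)
reset = initial * mask                       # survivors return to their initial values
```

Only the kernel is masked here, which matches the choice above of leaving biases unpruned. Repeating the train, prune, and reset steps shrinks the mask a little further in each round.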

Gaining insight. Having matrices with many zeros is not in itself helpful. We need to use an interpretation method to take advantage of this sparsity. Because we hope for quite a sparse representation at the end, we chose to visualize the network as its flowgraph.

To demonstrate how this works, we'll use the following computation:

This can be represented as the following flowgraph:

Flowgraph
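As a rough illustration of why zeros simplify such a graph, the sketch below turns a small sparse kernel into a list of edges, one per non-zero weight. It is a simplification: it assumes one vertex per unit and a kernel of shape (inputs, units), whereas the flowgraphs in this post also contain vertices for the individual operations. The helper `kernel_to_edges` and the example values are hypothetical.

```python
import numpy as np

def kernel_to_edges(kernel, in_prefix, out_prefix):
    """List (source, target, weight) for every non-zero kernel entry."""
    rows, cols = np.nonzero(kernel)
    return [
        (f"{in_prefix}{i}", f"{out_prefix}{j}", float(kernel[i, j]))
        for i, j in zip(rows, cols)
    ]

# Hypothetical pruned kernel: 3 inputs, 2 units, half of the weights are zero.
kernel = np.array([
    [0.5, 0.0],
    [0.0, 0.0],
    [-1.25, 2.0],
])

edges = kernel_to_edges(kernel, "x", "h")
vertices = {name for src, dst, _ in edges for name in (src, dst)}
```

Here the input `x1` has no non-zero weights, so it never appears among the vertices at all.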

Initial results

When we apply the lottery ticket pruning method on all layers, we can prune the network to 2% of its original size (leaving only 539 weights) and still achieve 90.5% accuracy on test set. For more detail on the number of weights per layer see the following figure:

Pruned network weights

Now we can apply our flow-graph visualisation on the pruned network. The output is a graph with 963 vertices and 1173 edges. Note that for this figure and all following, we are rounding values to two decimal places.

Visualisation one

Making it clearer

As we can see, the flowgraph is still quite dense, and it's difficult to extract any insight into how the network behaves from it. To make it clearer, we can exploit the following observation: many vertices and their corresponding edges lie between the input layer and the first hidden layer. We can detect dot product operations in the first layer and simply visualize their weights as a 2D image, because the first layer is directly connected to the pixels. For example, instead of the following dot product:

(where each input variable is the pixel at the given position) we can instead use the following image:

Example dotproduct visualisation

Colour intensity is used to show absolute value; green is used for positive values, red for negative values, and grey for zero.
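A colour mapping of this kind could be sketched as follows. The exact colours and the normalisation by the largest absolute weight are our assumptions, not necessarily what the visualisations in this post use; `weights_to_rgb` is a hypothetical helper.

```python
import numpy as np

GREY = np.array([0.5, 0.5, 0.5])
GREEN = np.array([0.0, 1.0, 0.0])
RED = np.array([1.0, 0.0, 0.0])

def weights_to_rgb(weights, shape=(28, 28)):
    """Map one unit's incoming weights to an RGB image.

    Zero weights are grey; positive weights blend towards green and
    negative weights towards red, in proportion to their magnitude.
    """
    w = np.asarray(weights, dtype=float).reshape(shape)
    peak = np.abs(w).max()
    strength = np.abs(w) / peak if peak > 0 else np.zeros_like(w)
    t = strength[..., None]                        # shape (28, 28, 1)
    colour = np.where((w > 0)[..., None], GREEN, RED)
    return (1.0 - t) * GREY + t * colour           # shape (28, 28, 3)

weights = np.zeros(28 * 28)
weights[0] = 2.0     # strongest positive weight: pure green
weights[1] = -1.0    # half as strong and negative: halfway from grey to red
image = weights_to_rgb(weights)
```

The resulting array has values in the range 0 to 1 and can be passed to any image display routine that accepts RGB floats.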

A further simplification is to display only the subgraph of the flowgraph that is relevant to each output, separately for every output. While this loses information about what is reused between outputs, it provides more readable graphs.
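Extracting the part of a graph that is relevant to one output amounts to walking backwards from that output along the edges. A minimal sketch, assuming the flowgraph is given as a list of (source, target) pairs; the function name and the toy graph are hypothetical:

```python
def relevant_subgraph(edges, output):
    """Keep only the edges that lie on some path leading into `output`."""
    parents = {}
    for src, dst in edges:
        parents.setdefault(dst, []).append(src)

    keep, stack = {output}, [output]
    while stack:                       # walk backwards from the output
        node = stack.pop()
        for parent in parents.get(node, []):
            if parent not in keep:
                keep.add(parent)
                stack.append(parent)

    return [(src, dst) for src, dst in edges if dst in keep]

# Toy graph: h0 feeds both outputs, h1 feeds only output_1.
edges = [
    ("x0", "h0"), ("x1", "h0"), ("x2", "h1"),
    ("h0", "output_0"), ("h0", "output_1"), ("h1", "output_1"),
]
sub0 = relevant_subgraph(edges, "output_0")
sub1 = relevant_subgraph(edges, "output_1")
```

In the toy graph, `h0` appears in both subgraphs, which is exactly the reuse between outputs that is no longer visible once the graphs are drawn separately.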

Combining these two simplifications, we get the following visualisation:

Visualisation two

Here we can observe that some filters on the input are relatively interpretable, and it is clear what shapes they detect:

Filter 1 Filter 2 Filter 3

On the other hand, some of them look more like random pixel detections:

Filter 4 Filter 5 Filter 6

The computation after the initial layer still remains relatively dense and hard to understand. It is now possible to track the flow of individual values through the graph in most places, but we can do better.

Visualising behaviour on specific inputs

To get more insight, we can also visualize flowgraphs on individual inputs. This allows us to visualize values flowing between nodes: the thickness of a line corresponds to the absolute value of the scalar produced by its source operation. Green lines carry positive values, red lines carry negative values, and grey lines are exactly 0.

Also, for each visualized dot product, we show the result of applying this filter to the input image (that is, the result of the dot product multiplication of the input and the weights). This is shown as the right-hand figure next to the dot product filter visualisations at the top of the flowgraph.
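One way to read this per-input picture is as the element-wise product of the filter and the image, whose sum is the dot product that enters the rest of the flowgraph. The sketch below assumes that reading; `filter_response` and the toy filter and image are hypothetical.

```python
import numpy as np

def filter_response(weights, image):
    """Per-pixel contributions of a filter on one image, and their sum."""
    contributions = weights * image       # element-wise, same shape as the image
    return contributions, float(contributions.sum())

# Toy filter with a negative weight in the centre and a positive one at the bottom.
weights = np.zeros((28, 28))
weights[14, 14] = -1.5
weights[27, 10] = 0.5

# Toy image with a lit pixel in the centre and one in the top-left corner.
image = np.zeros((28, 28))
image[14, 14] = 1.0
image[0, 0] = 1.0

contributions, value = filter_response(weights, image)
print(value)  # -1.5
```

Only pixels that are both lit in the image and have a non-zero weight contribute, so the contribution map shows directly which part of the input triggered the detector.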

The following figure shows this visualisation for six random images from the test set (the values in square brackets in the nodes are the actual values that flow through the graph):

Visualisation three

The figure shows that it is relatively easy to track which parts of the network contribute to the final decision. For example, let us see how the probability that the input is digit 0 (output_0) is computed. For the digit 1 input, it is clear that the decision not to classify it as a 0 is heavily influenced by one specific dot product that detects pixels in the middle of the image. Another observation is that two main branches in the flow play an important role in classifying digit 0 correctly (i.e. input digit 0 and output_0). They mix together a detector of bottom pixels and a detector of some specific pixels. When the sixth input (digit 6) is used, it is not classified as 0 even though some of the detectors are positively triggered (similarly to the case with input 0); however, the detector of the central pixels outweighs them and produces a negative score.

Pruning once more

Let us revisit the idea of pruning. We have pruned the network as much as possible to get a more understandable object than a fully connected network. However, there is room for trade-offs: we don't need to prune each layer equally, and could prune some layers more in exchange for pruning other layers less.

Specifically, in our case we have a visualization of the first layer whose interpretability is not influenced as much by sparsity. This allows us to completely disable pruning of the first layer. With this approach we get the following network flowgraph, which has 93.5% accuracy on the test set:

New Pruned Architecture

The network has 95% of its original weights, but the last two kernels are several times smaller than in the previous case. This gives us a significantly simpler flowgraph. It can now easily be shown in one diagram for all outputs without losing clarity. The following figure shows the resulting flowgraphs applied to six images from the test set.

Visualisation four

One interesting observation is that this network returns a constant value for output_2. This means that the digit 2 is a kind of reference point; an input is classified as digit 2 if other outputs are sufficiently suppressed.
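The "reference point" behaviour can be made concrete with a small sketch. It assumes the prediction is the argmax over ten outputs, of which only the one for digit 2 is constant; the value -2.39 is the one reported in a comment on this post, and the other numbers are made up for illustration.

```python
import numpy as np

CONSTANT_LOGIT_2 = -2.39  # illustrative value, as reported in a comment on this post

def predict(other_logits, constant_2=CONSTANT_LOGIT_2):
    """Predicted digit when the output for digit 2 is a constant.

    `other_logits` holds the nine input-dependent outputs for digits
    0, 1, 3, 4, 5, 6, 7, 8, 9 (in that order).
    """
    logits = np.insert(np.asarray(other_logits, dtype=float), 2, constant_2)
    return int(np.argmax(logits))

# All input-dependent outputs pushed below the constant: the answer is 2.
suppressed = [-5.0, -4.0, -3.0, -6.0, -7.0, -5.0, -4.0, -8.0, -3.0]
# The output for digit 7 rises above the constant: the answer is 7.
one_active = [-5.0, -4.0, -3.0, -6.0, -7.0, -5.0, 1.2, -8.0, -3.0]

print(predict(suppressed))  # 2
print(predict(one_active))  # 7
```

In other words, the network never has to find positive evidence for a 2; it only has to find evidence against every other digit.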

We can again easily track down properties such as digit 1 not being classified as 0 because of the pixels in the middle part. In this version, it is less clear what each dot product detects, but on the other hand it is very straightforward to see what happens to each dot product value in the rest of the computation.

Domains other than MNIST

The approach of pruning and visualizing flowgraphs may also be applied to other domains, such as reinforcement learning for continuous control. We use the classic OpenAI CartPole-v0 environment. The following figure shows a network trained with DQN on the CartPole environment and then pruned once it reached the maximal score. For this network, we can observe that one input (i_1, which is the cart velocity) is completely ignored, while the agent still gains the maximal score.

Visualisation five
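An ignored input shows up in the pruned weights as an input whose outgoing first-layer weights are all zero. A minimal sketch of that check, assuming a first kernel of shape (inputs, hidden units); the kernel values below are invented for illustration and are not the weights of the network in the figure.

```python
import numpy as np

# Observation components of CartPole-v0, in order.
CARTPOLE_INPUTS = ["cart position", "cart velocity", "pole angle", "pole angular velocity"]

def ignored_inputs(first_kernel):
    """Indices of inputs whose outgoing weights have all been pruned to zero."""
    return [i for i, row in enumerate(first_kernel) if not np.any(row)]

# Hypothetical pruned first-layer kernel: 4 inputs, 3 hidden units.
first_kernel = np.array([
    [0.8, 0.0, -0.3],
    [0.0, 0.0, 0.0],
    [1.1, -0.6, 0.0],
    [0.0, 0.9, 0.4],
])

unused = ignored_inputs(first_kernel)
print([CARTPOLE_INPUTS[i] for i in unused])  # ['cart velocity']
```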

On the other hand, our approach doesn't seem useful for convolutional neural networks, since applying each convolutional filter over the whole output of the previous layer is part of the algorithm and is not reduced by pruning weights. Therefore, even a convolutional filter with a single weight may produce many connections in the flowgraph. The presented approach still produces dense flowgraphs even for heavily pruned convolutional networks.

The following figure shows a network built from 3 convolutional layers (6 channels, 3x3 filters, without padding, kernel sizes: 54, 324, 324) followed by a dense layer (kernel size: 29040, no activation). The layers were pruned to 6 (11%), 23 (7%), 23 (7%), and 135 (0.46%) weights, given as percentages of their original sizes. The resulting network has only 80% accuracy and quite a low number of weights, but still produces relatively large and dense graphs. The graph for computing output_7 is the largest one (1659 vertices).

Visualisation six
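The kernel sizes and percentages above, and the reason a single convolutional weight fans out into many flowgraph connections, can be reproduced with some arithmetic. This sketch assumes a single-channel 28x28 input, stride 1, and 10 outputs; those assumptions reproduce the quoted kernel sizes.

```python
channels, k = 6, 3

# Spatial size after each unpadded 3x3 convolution: 28 -> 26 -> 24 -> 22.
sizes = [28]
for _ in range(3):
    sizes.append(sizes[-1] - (k - 1))

in_channels = [1, channels, channels]
conv_kernels = [c_in * channels * k * k for c_in in in_channels]
dense_kernel = sizes[-1] ** 2 * channels * 10

# Each convolutional weight is applied once per output position.
uses_per_weight = [s * s for s in sizes[1:]]

kept = [6, 23, 23, 135]
percent_kept = [100 * n / total for n, total in zip(kept, conv_kernels + [dense_kernel])]

print(conv_kernels, dense_kernel)  # [54, 324, 324] 29040
print(uses_per_weight)             # [676, 576, 484]
```

Under these assumptions, each surviving weight in the first convolutional layer is used at 676 positions, so even the 6 remaining weights there can account for thousands of multiplications in the flowgraph.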

Conclusion

While we may not have conclusively answered the question posed in the title, we have generated insight and intuitions with this simple example. As expected, pruning a dense object may lead to a more interpretable object; however, there are trade-offs. Simply counting the number of weights in the final object is probably not the ideal metric for obtaining the most interpretable result. In our experiment, the later flowgraphs (of networks with a dense first layer) were more interpretable despite having many more weights than the previous examples. This points to a way of reconciling the opposing views that sparsity gives both more and less interpretability: in some settings sparsity is beneficial in giving humans understanding and transparency, but when a visualisation which doesn't rely on sparsity can still produce results that humans can understand and interpret, the sparsity isn't necessarily useful.

Acknowledgement: This text was created as a part of AISRP.

References

[1] Adversarial Robustness as a Prior for Learned Representations. Engstrom, L., Ilyas, A., Santurkar, S., Tsipras, D., Tran, B. and Madry, A., 2019.

[2] The Lottery Ticket Hypothesis: Training Pruned Neural Networks. Frankle, J. and Carbin, M., 2018. CoRR, Vol abs/1803.03635.

3 comments

Comments sorted by top scores.

comment by Charlie Steiner · 2020-06-03T01:40:58.538Z · LW(p) · GW(p)

I feel like this is trying to apply a neural network where the problem specification says "please train a decision tree." Even when you are fine with part of the NN not being sparse, it seems like you're just using the gradient descent training as an elaborate img2vec method.

Maybe the idea is that you think a decision tree is too restrictive, and you want to allow more weightings and nonlinearities? Still, it seems like if you can specify from the top down what operations are "interpretable," this will give you some tree-like structure that can be trained in a specialized way.

comment by Rohin Shah (rohinmshah) · 2020-06-18T04:46:49.357Z · LW(p) · GW(p)

Planned summary for the Alignment Newsletter:

If you want to visualize exactly what a neural network is doing, one approach is to visualize the entire computation graph of multiplies, additions, and nonlinearities. While this is extremely complex even on MNIST, we can make it much simpler by making the networks _sparse_, since any zero weights can be removed from the computation graph. Previous work has shown that we can remove well over 95% of weights from a model without degrading accuracy too much, so the authors do this to make the computation graph easier to understand.

They use this to visualize an MLP model for classifying MNIST digits, and a DQN agent trained to play CartPole. In the MNIST case, the computation graph can be drastically simplified by visualizing the first layer of the net as a list of 2D images, where the kth activation is given by the dot product of the 2D image with the input image. This deals with the vast majority of the weights in the neural net.

Planned opinion:

This method has the nice property that it visualizes exactly what the neural net is doing -- it isn’t “rationalizing” an explanation, or eliding potentially important details. It is possible to gain interesting insights about the model: for example, the logit for digit 2 is always -2.39, implying that everything else is computed relative to -2.39. Looking at the images for digit 7, it seems like the model strongly believes that sevens must have the top few rows of pixels be blank, which I found a bit surprising. (I chose to look at the digit 7 somewhat arbitrarily.)
Of course, since the technique doesn’t throw away any information about the model, it becomes very complicated very quickly, and wouldn’t scale to larger models.

comment by ErickBall · 2020-06-10T17:51:34.988Z · LW(p) · GW(p)

I like this visualization tool. There are some very interesting things going on here when you look into the details of the network in the second-to-last MNIST figure. One is that it seems to mostly identify each digit by ruling out the others. For instance, the first two dot-product boxes (on the lower left) could be described as "not-a-0" detectors, and will give a positive result if they detect pixels in the center, near the corners, or at the extreme edges. The next two boxes could be loosely called "not-a-9" detectors (though they also contribute to the 0 and 4 classifications) and the three after that are "not-a-4" detectors. (The use of ReLU makes this functionally different from having a "4" detector with a negative weight.)

Now, take a look at the two boxes that contribute to output 1 (I would call these "not-a-1" detectors). If you select the "7" input and see how those two boxes respond to it, they both react pretty strongly to the very bottom of the 7 (the fact that it's near the edge) and that's what allows the network to distinguish a 7 from a 1. Intuitively, this seems like a shared feature--so why is the network so sure that anything near the bottom cannot be part of a 1?

It looks to me like it's taking advantage of the way the MNIST images are preprocessed, with each digit's center-of-mass translated to the center of the image. Because the 7 has more of its mass near the top, its lower extremity can reach farther from the center. The 1, on the other hand, is not top-heavy or bottom-heavy, so it won't be translated by any significant amount in preprocessing and its extremities can't get near the edges.

The same thing happens with the "not-a-3" detector box when you select input 2. The "not-a-3" detector triggers quite strongly because of the tail that stretches out to the right. That area could never be occupied by a 3, because the 3 has most of its pixel mass near its right edge and will be translated left to get its center of mass centered in the image. The "7" detector (an exception to the pattern of ruling digits out) mostly identifies a 7 by the fact that it does not have any pixels near the top of the image (and to a lesser extent, does not have pixels in the lower-right corner).

What does this pattern tell us? First, that a different preprocessing technique (centering a bounding box in the image, for instance, instead of the digit's center of mass) would require a very different strategy. I don't know off hand what it would look like--maybe there's some trick for making the problems equivalent, maybe not. Second, that it can succeed without noticing most of what humans would consider the key features of these digits. For the most part it doesn't need to know the difference between straight lines and curved lines, or whether the parts connect the way they're supposed to, or even whether lines are horizontal or vertical. It can use simple cues like how far each digit extends in different directions from its center of mass. Maybe with so few layers (and no convolution) it has to use those simple cues.

As far as interpretability, this seems difficult to generalize to non-visual data, since humans won't intuitively grasp what's going on as easily. But it certainly seems worthwhile to explore ideas for how it could work.