Exploring toy neural nets under node removal. Section 1.
post by Donald Hobson (donald-hobson) · 2022-04-13T23:30:40.012Z · LW · GW · 7 comments
Contents
- Introduction
- Training code
- Video of interactive heatmap
- Plotting the Loss
- Generating the data
- Probability density functions of loss
- Probability density functions of log loss
- Gradients of the loss
- Is that node on
- How much do N flips matter
- Heatmap of Gradient
- 7 comments
Introduction
This post is a long, graph-heavy exploration of a tiny toy neural network.
Suppose you have some very small neural network. For some inexplicable reason, you want to make it even smaller. Can we understand how the network behaviour changes when some nodes are deleted?
Training code
The network, and how it was trained.
```python
# This module trains a small neural network to determine whether a point is inside or outside a circle, and then saves that model to a file.
# First import some libraries.
import tensorflow as tf
import numpy as np
from math import pi

# This function accepts n, and returns a generated dataset of n datapoints. Each individual datapoint consists of 2 float inputs and one bool output.
# The bool is actually stored as the integers 0, 1, because the tensorflow library was written for the general case of categorization into any number of categories.
# The x input is uniformly sampled from the [-1, 1] square.
# The y output is 1 if the point is outside the circle centred on the origin with radius sqrt(2/pi). This radius was chosen so that exactly half the points
# (on average) lie within the circle: the circle's area, pi * r^2 = 2, is half the square's area of 4.
def generate(n):
    x = np.random.uniform(-1, 1, size=[n, 2])
    # Compare the squared distance from the origin with the squared radius r^2 = 2/pi.
    y = (np.sum(x * x, 1) > 2 / pi).astype(int)
    return x, y

# The number of hidden layer neurons.
N = 20
# The keras model accepts an input for x.
inputs = tf.keras.Input(shape=(2,))
# To make removing neurons easier, the network also accepts a vector on_nodes. Each entry is 1 for a node considered on, and 0 for a node considered off.
# During training this is set to a constant block of ones. The network is usually evaluated on many datapoints at once. Although this input accepts
# any floating point values, in practice it will be all 0's and 1's, and will consist of BATCH repetitions of the same vector, i.e. within any one
# batch evaluation, the set of neurons that are off is fixed.
# Another advantage of this setup is that it lets us take gradients of the performance with respect to on_nodes.
on_nodes = tf.keras.Input(shape=(N,))
# A single hidden layer containing N = 20 relu nodes.
hidden_layer = tf.keras.layers.Dense(N, activation=tf.nn.relu)(inputs)
# Multiply by on_nodes to turn off any nodes we want off.
hidden_layer_picked = hidden_layer * on_nodes
# An output layer of 2 neurons, passed through a softmax, giving the probability assigned to each class.
outputs = tf.keras.layers.Dense(2, activation=tf.nn.softmax)(hidden_layer_picked)
model = tf.keras.Model(inputs=[inputs, on_nodes], outputs=outputs)
# Compile the model. This optimiser and loss were chosen because they seemed like standard, sensible choices, and are often used on bigger nets.
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.01),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
BATCH = 1000
# A whole lot of 1's. Guess what this is used for. That's right, on_nodes: all the nodes are on during training.
ones = np.ones([BATCH, N])
x, y = generate(BATCH)
# Why 20 epochs, not more or less? Fiddling about until it seemed to converge in practice, that's why.
model.fit([x, ones], y, epochs=20)
# Keras lets you save your models with a single API call. How convenient.
model.save('circle_test')
```
Video of interactive heatmap
Given that the model only has 2 inputs, we can plot it to get a good idea of what it is doing. The network was only trained on inputs within the square, but out-of-distribution behaviour can be interesting, so the plot extends beyond the square. The circle that the network is trained to approximate is shown. In the upper plot, yellow represents the network's certainty that a point is inside the circle, and navy blue represents certainty that a point is outside the circle. The colour scheme is simply the plotting library's default.
Key takeaways:
- With all the nodes on, the network accurately approximates a circle
- Network performance varies a lot depending on which neurons in particular are removed
- Flipping one neuron can have a significant effect, but doesn't usually completely change the network behaviour.
- The network generally performs better with more neurons.
Plotting the Loss
Let's analyse the loss (as measured by sparse categorical cross-entropy).
We can let each node exist independently with probability $p$, and plot the probability density function of the resulting loss.
Generating the data
The following code runs the network with a random subset of the nodes removed, and saves the losses to file.
```python
import Network_loader_neat as net
import numpy as np
from random_utils_neat import RandomStateHolder
import pickle

# N = 20, the number of hidden nodes.
N = net.N
# Because 17 is the most random number. ;-)
np.random.seed(17)
data = {}
for prob in [0.25, 0.5, 0.75]:
    result_list = []
    for i in range(1000):
        # Each node is independently on with probability prob.
        on_nodes = (np.random.rand(N) < prob).astype(int)
        # Use a fixed random seed, and so a fixed test dataset, within net.test,
        # while still allowing on_nodes to be different every time.
        with RandomStateHolder(seed=23):
            loss, grad = net.test(on_nodes)
        result_list.append((loss, grad, on_nodes))
    data[prob] = result_list
    print("Finished calculating values for prob=%s" % prob)
with open("generated_data_1", mode="wb") as file:
    pickle.dump(data, file)
```
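The `random_utils_neat` module is not shown in the post. As a guess at what `RandomStateHolder` does, here is a hypothetical stand-in (deliberately named `random_state_holder`, since it is not the author's code). It seeds NumPy's global RNG for the duration of a `with` block and then restores the previous state, so the test data inside the block is identical on every call while the outer sampling of `on_nodes` carries on undisturbed.

```python
import numpy as np
from contextlib import contextmanager


@contextmanager
def random_state_holder(seed):
    """Hypothetical stand-in for RandomStateHolder: temporarily seed NumPy's
    global RNG, then restore whatever state it had before the block."""
    saved_state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        # Restore the outer random stream, so code after the block
        # sees the same random numbers it would have seen without it.
        np.random.set_state(saved_state)
```

With this design, two calls wrapped in `random_state_holder(23)` draw the same numbers, and the draws outside the block are unaffected.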
Probability density functions of loss
The results are then loaded and plotted to create the plot below. Note that the lines are smoothed using a Gaussian kernel of width 0.03. This width is enough to stop the graph being unreadably squiggly, while preserving the structure of the PDF. Some small wiggles remain.
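The plotting code is not shown. A minimal sketch of Gaussian kernel smoothing follows, assuming that "width 0.03" means the kernel's standard deviation (the function name is hypothetical). Each sample contributes a small normal density centred on itself, and the curve is the average of these.

```python
import numpy as np


def gaussian_kde_curve(samples, grid, width=0.03):
    """Smooth samples into a PDF estimate evaluated on grid, using a Gaussian
    kernel whose standard deviation is width (an assumed interpretation)."""
    samples = np.asarray(samples, dtype=float)[:, None]  # shape (n, 1)
    grid = np.asarray(grid, dtype=float)[None, :]        # shape (1, m)
    z = (grid - samples) / width
    kernels = np.exp(-0.5 * z ** 2) / (width * np.sqrt(2 * np.pi))
    # Averaging the n unit-mass kernels gives a density that integrates to 1.
    return kernels.mean(axis=0)
```

A larger `width` gives a smoother curve but blurs real structure; a smaller one leaves more sampling noise.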
| Prob | Mean loss |
|------|-----------|
| 0.25 | 1.0277158 |
| 0.5 | 0.87821335 |
| 0.75 | 0.584365 |
Notes:
- The small local wiggles are probably just random noise.
- The grey line represents a loss of 0.117, the network's loss with no neurons missing.
- The order of the green, orange and blue lines shows that networks with more nodes missing generally have higher loss.
- The green line touches the grey line. This is partly because, in 1000 samples, 5 happen to contain all the nodes. That is not unexpected given $1000 \times 0.75^{20} \approx 3.2$.
- If only node 14 is missing, the loss is 0.119, a barely noticeable difference. This configuration occurs 3 times in the $p = 0.75$ data.
- Despite what you might think from the above, the $p = 0.75$ data contains 952 distinct configurations. (Configurations with an excess of 1's are unusually likely to be sampled repeatedly.)
- The data looks right-skewed. This makes sense, as there are more ways to do exceptionally badly than exceptionally well.
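The counts in these notes can be compared against what independent sampling predicts. This sketch assumes each of the 20 nodes is on independently with probability 0.75 across 1000 draws, as in the data-generation script. It computes the expected number of all-on samples, the expected number of samples with exactly one given node (such as node 14) off, and the expected number of distinct configurations, which can be set beside the 952 observed.

```python
from math import comb

N = 20          # hidden nodes
P_ON = 0.75     # probability that each node is kept
SAMPLES = 1000  # number of sampled subsets


def config_prob(k_on, n=N, p=P_ON):
    """Probability of one specific on/off configuration with k_on nodes on."""
    return p ** k_on * (1 - p) ** (n - k_on)


# Expected number of samples with every node on, and with exactly one given node off.
expected_all_on = SAMPLES * config_prob(N)
expected_one_off = SAMPLES * config_prob(N - 1)


def expected_distinct(n=N, p=P_ON, samples=SAMPLES):
    """Expected number of distinct configurations among `samples` independent draws."""
    total = 0.0
    for k in range(n + 1):
        q = config_prob(k, n, p)
        # Each of the comb(n, k) configurations with k nodes on
        # appears at least once with probability 1 - (1 - q) ** samples.
        total += comb(n, k) * (1 - (1 - q) ** samples)
    return total


print(f"expected all-on samples:    {expected_all_on:.2f}")
print(f"expected one-node-off:      {expected_one_off:.2f}")
print(f"expected distinct configs:  {expected_distinct():.1f}")
```

The repeats concentrate in configurations with many 1's, because these are the individually most probable configurations.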
Probability density functions of log loss
To test the hypothesis that the loss is log-normally distributed, let's look at the same data, but with the logarithm of the loss.
If the data were perfectly log-normally distributed, the plot above would be a bell curve. The bell curves of best fit are shown for comparison. I would say that these curves look pretty close to bell curves, and the variation could well be written off as noise.
Here is a table of the means and standard deviations of those bell curves.
| Prob | Mean log loss | STD log loss |
|------|---------------|--------------|
| 0.25 | -0.08020699 | 0.44782004 |
| 0.5 | -0.28468585 | 0.5544732 |
| 0.75 | -0.7286095 | 0.6133959 |
Even if the small deviations from the (dotted line) bell curves aren't noise, assuming a log-normal distribution looks like a reasonably good approximation, so I will do that going forward.
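One quick consistency check on the log-normal assumption: if $\log L \sim \mathcal{N}(\mu, \sigma^2)$, then $E[L] = e^{\mu + \sigma^2/2}$. The sketch below plugs the fitted means and standard deviations from the log-loss table into this formula and compares the result with the mean losses from the earlier table (the numbers are copied from those two tables). With these values, the predictions land within about 1% of the observed means.

```python
import math

# prob: (mean loss, mean log loss, std log loss), copied from the tables above.
TABLE = {
    0.25: (1.0277158, -0.08020699, 0.44782004),
    0.5: (0.87821335, -0.28468585, 0.5544732),
    0.75: (0.584365, -0.7286095, 0.6133959),
}


def lognormal_mean(mu, sigma):
    """Mean of a variable whose logarithm is Normal(mu, sigma^2)."""
    return math.exp(mu + sigma ** 2 / 2)


predictions = {}
for prob, (mean_loss, mu, sigma) in TABLE.items():
    predictions[prob] = lognormal_mean(mu, sigma)
    print(f"p={prob}: observed mean loss {mean_loss:.4f}, "
          f"log-normal prediction {predictions[prob]:.4f}")
```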
Gradients of the loss
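These nodes are turned on or off by multiplying each node's output by $x_i \in \{0, 1\}$, so we can also look at the gradient $\partial L / \partial x_i$ of the loss with respect to each $x_i$.
The multiplication happens after the ReLU activation function. For the network's outputs this makes no difference, because ReLU is positively homogeneous: $\mathrm{ReLU}(x_i z) = x_i \, \mathrm{ReLU}(z)$ for any $x_i \ge 0$. (Multiplying after the activation does matter for the gradient at $x_i = 0$, which would otherwise be zero.)
Why look at these gradients? They can be computed cheaply by backpropagation, and they could give insight into the structure of the network.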
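The post computes these gradients with TensorFlow inside `net.test`, which is not shown. To illustrate what backpropagation yields for this architecture (2 inputs, one hidden ReLU layer multiplied by `on_nodes`, a 2-way softmax, mean cross-entropy), here is a hand-derived NumPy sketch. The function name and the weight names `W1, b1, W2, b2` are hypothetical, and any weights can be plugged in.

```python
import numpy as np


def loss_and_node_grad(X, y, W1, b1, W2, b2, on_nodes):
    """Mean cross-entropy of a one-hidden-layer ReLU net whose hidden outputs
    are multiplied by on_nodes, and the gradient of that loss w.r.t. on_nodes."""
    h = np.maximum(X @ W1 + b1, 0.0)              # hidden activations, (n, N)
    logits = (h * on_nodes) @ W2 + b2             # masked hidden layer -> logits, (n, 2)
    logits -= logits.max(axis=1, keepdims=True)   # numerical stability; softmax unchanged
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)
    n = len(y)
    rows = np.arange(n)
    loss = -np.mean(np.log(probs[rows, y]))
    # Backpropagate: dL/dlogits = (probs - onehot(y)) / n.
    d_logits = probs.copy()
    d_logits[rows, y] -= 1.0
    d_logits /= n
    d_masked = d_logits @ W2.T                    # dL/d(h * on_nodes), (n, N)
    # Chain rule through the elementwise product h * on_nodes.
    grad = (d_masked * h).sum(axis=0)             # (N,)
    return loss, grad
```

Note that `grad[i]` is proportional to the node's activation `h[:, i]`, not to `on_nodes[i]`, so it stays informative even when the node is switched off.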
We can see that the gradients tend to be more sharply peaked and heavier-tailed than their fitted bell curves predict.
The table of means and standard deviations
| Prob | Mean loss gradient | STD loss gradient |
|------|--------------------|-------------------|
| 0.25 | -0.0693278 | 0.36804602 |
| 0.5 | -0.04751453 | 0.33512998 |
| 0.75 | -0.027095174 | 0.26713666 |
Note that the means are slightly negative, indicating that, on average, the loss decreases as nodes are added. On average, the 0.75 samples should contain about 10 more nodes than the 0.25 samples (15 versus 5). Taking the mean loss gradient at 0.5 and scaling it by a factor of 10, we get -0.4751453. The mean loss at 0.75 minus the mean loss at 0.25 is 0.584365 - 1.0277158 = -0.4433508. These numbers are reasonably close, which seems like a good sanity check.
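In the previous section, we saw that the log loss had a nicer distribution. What is the gradient of the log loss doing?
By the chain rule, $\frac{\partial \log L}{\partial x_i} = \frac{1}{L} \frac{\partial L}{\partial x_i}$, so we can find the gradient of the log loss just by dividing the gradient of the loss by the loss.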
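Applied to the saved data, this is a one-line transformation. The sketch below assumes that each entry of a `result_list` is a `(loss, grad, on_nodes)` tuple with `grad` a length-N array, matching what the data-generation script appends. It also assumes the table's statistics pool every node of every sample; the function name is hypothetical.

```python
import numpy as np


def log_loss_gradient_stats(result_list):
    """Mean and standard deviation of d(log L)/dx_i = (dL/dx_i) / L, pooled over
    all samples and all nodes, from (loss, grad, on_nodes) tuples."""
    log_grads = np.array([np.asarray(grad, dtype=float) / loss
                          for loss, grad, _ in result_list])
    return log_grads.mean(), log_grads.std()
```

For example, `log_loss_gradient_stats(data[0.5])` on the pickled dictionary would give one row of the table.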
| Prob | Mean log loss gradient | STD log loss gradient |
|------|------------------------|-----------------------|
| 0.25 | -0.0746937 | 0.36064485 |
| 0.5 | -0.05837615 | 0.38386792 |
| 0.75 | -0.046731997 | 0.44563505 |
The only really surprising thing about the above graph is the consistent trough at 0.
Is that node on
One hypothesis we might form is that we are observing a mixture of two different distributions, each with a single peak.
Let's take the data for Prob=0.5, where each neuron is equally likely to be present or removed.
If $f$ is a function that takes a list $x$ of 1's and 0's, representing the presence or absence of each neuron, and returns the log loss, then the $i$th component of the gradient is $\frac{\partial f}{\partial x_i}(x)$.
An obvious distinction to make here is whether $x_i$ is 0 or 1.
$x_i = 0$ corresponds to looking at a node operating at 0% (i.e. turned off) and asking whether the network would do better if the node were at 1% instead.
$x_i = 1$ corresponds to looking at a node operating at 100% (i.e. turned on) and asking whether the network would do better if the node were at 101% instead.
| Node state | Mean log loss gradient | STD log loss gradient |
|------------|------------------------|-----------------------|
| On | 0.055790696 | 0.35940725 |
| Off | -0.17263436 | 0.37343097 |
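Splitting the gradients this way only needs a boolean mask. The sketch below assumes the log-loss gradients and the matching `on_nodes` vectors are stacked into two arrays of shape (samples, N). The function name is hypothetical.

```python
import numpy as np


def split_by_node_state(log_grads, on_nodes):
    """log_grads and on_nodes have shape (samples, N). Returns {state: (mean, std)}
    over the gradient components whose node was on, and those whose node was off."""
    log_grads = np.asarray(log_grads, dtype=float)
    on = np.asarray(on_nodes).astype(bool)
    return {
        "On": (log_grads[on].mean(), log_grads[on].std()),
        "Off": (log_grads[~on].mean(), log_grads[~on].std()),
    }
```

Indexing with the mask pools all neurons together. Selecting one column first (for example `log_grads[:, 11]`) gives the per-neuron breakdown.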
OK. That wasn't quite what I expected. Let's split this up by neuron and see how that affects the picture.
At a glance, some of these plots look smoother, and others look more jagged. For example, neuron 11 has smooth-looking curves, and neuron 4 has more jagged curves. Actually, both are still being smoothed by a Gaussian kernel of width 0.03. Inspecting the axes, we see that the gradients on neuron 11 are much smaller. Some neurons just don't make as much of a difference.
How much do N flips matter
Consider taking a random starting position (independent Bernoulli, $p = 0.5$) and a random permutation of the neurons. Flip each neuron in the permuted order, until every neuron that was on is now off and every neuron that was off is now on.
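This process is easy to simulate. A minimal sketch follows, under the assumption that configurations are 0/1 integer vectors as in `on_nodes` (the function name is hypothetical). It returns all N + 1 configurations visited, from the random start to its complement.

```python
import numpy as np


def flip_path(n, p=0.5, rng=None):
    """Random start (each node on with probability p), then flip the nodes one
    at a time in a random order. Returns an (n + 1, n) array of configurations."""
    rng = np.random.default_rng() if rng is None else rng
    current = (rng.random(n) < p).astype(int)
    path = [current.copy()]
    for i in rng.permutation(n):
        current[i] = 1 - current[i]   # flip node i
        path.append(current.copy())
    return np.array(path)
```

Evaluating the log loss on each row of `flip_path(20)` gives one path's worth of the 21 values analysed below.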
We can visualize this process.
Here is a cube with edges labelled with coordinates. Each corner of the cube has a list of 1's and 0's, which represents a way some nodes could be missing. (Of course, the network being visualized has 20 nodes, not just 3, so imagine this cube, but 20-dimensional.) We start by picking a random corner and drawing the red line, which takes exactly one step along each axis.
There are N=20 red arrows, so each path visits 21 configurations (counting both endpoints), giving 21 loss values.
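The distribution over paths is symmetric under reversal: with $p = 0.5$, the end point (the complement of the start) has the same distribution as the start, and a reversed uniformly random permutation is still uniformly random.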
The heat map shows the covariance matrix between log losses along the paths.
The turquoise line on the graph beneath shows the covariance between the log loss at the start of the path and the log loss along the path, as a function of distance from the start.
The pink line shows the covariance between the log loss at the end of the path and the log loss along the path, as a function of distance from the end.
These should be identical under symmetry by path reversal. And indeed the lines look close enough that any difference can reasonably be attributed to sampling error.
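Both plots can be read off one covariance matrix. The sketch below assumes the log losses are stored as an array with one row per sampled path and one column per step (21 columns for N = 20). The function name is hypothetical.

```python
import numpy as np


def path_covariances(log_losses):
    """log_losses has shape (num_paths, N + 1): row k holds the log losses along path k.
    Returns the covariance heat map plus the start-anchored and end-anchored curves."""
    cov = np.cov(log_losses, rowvar=False)  # (N + 1, N + 1): columns are the variables
    from_start = cov[0, :]                  # turquoise: start vs. distance from start
    from_end = cov[-1, ::-1]                # pink: end vs. distance from end
    return cov, from_start, from_end
```

Under reversal symmetry, `from_start` and `from_end` estimate the same curve, so their difference gives a rough sense of the sampling error.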
This graph shows that it takes around 5 or 6 node-flip operations before the correlations decay into insignificance.
The surprising thing shown is the significant positive correlation between log loss of a network, and log loss of the reverse. This means if you take a pruned network that scores well, and reverse it, turning on all the nodes that were off, and turning off all the nodes that were on, the result is usually still a well scoring network.
I suspect that a pruned network scores well when its nodes balance out, with roughly equally many nodes focusing on the top, bottom, left and right of the input space.
As the full network is well balanced, if the kept nodes are balanced then the removed nodes must be balanced too, so the reversal is also well balanced.
Heatmap of Gradient
So far, all the gradients we have considered have been averaged over a test sample. But the gradient of the loss is well defined at every input point.
Here is a heatmap of the gradient of the loss with respect to each node (without pruning).
Blue represents a part of the input space where a node is lowering the loss. If the network is confident in its predictions one way or the other, then a marginal change to the neurons has a negligible effect on the loss. The network is only uncertain in the annulus around the boundary circle, so away from that circle we see the mid blue/green colour that represents 0.
Nodes 0, 9, 10 and 19 are acting as a bias. They ignore their input, assigning all points to within the circle. Hence blue (improved predictions) inside the circle, and yellow (worse predictions) just outside the circle. The other nodes slice a part of the space away at the edges, saying that every point sufficiently far in some direction is outside the circle. Slicing off an edge makes the part outside the circle bluer, and the part inside that gets caught on this edge somewhat yellow.
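To produce such a heatmap, the per-point gradient is simply left unaveraged. The following is a NumPy sketch for a network with all nodes on. It assumes hypothetical weights `W1, b1, W2, b2` and a grid over $[-\text{extent}, \text{extent}]^2$, with labels defined by the squared radius $2/\pi$. For each grid point it returns $\partial L / \partial x_i$ of that point's cross-entropy.

```python
import numpy as np


def per_point_node_gradients(W1, b1, W2, b2, resolution=100, extent=1.0):
    """Per-point gradient of the cross-entropy w.r.t. each node multiplier x_i,
    evaluated at x = all ones. Returns an array of shape (N, resolution, resolution)."""
    xs = np.linspace(-extent, extent, resolution)
    gx, gy = np.meshgrid(xs, xs)
    X = np.stack([gx.ravel(), gy.ravel()], axis=1)
    y = (np.sum(X * X, axis=1) > 2 / np.pi).astype(int)  # 1 = outside the circle
    h = np.maximum(X @ W1 + b1, 0.0)                      # hidden activations
    logits = h @ W2 + b2                                  # all nodes on
    logits -= logits.max(axis=1, keepdims=True)
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)
    d_logits = probs.copy()
    d_logits[np.arange(len(y)), y] -= 1.0                 # per-point, no averaging
    grads = (d_logits @ W2.T) * h                         # (points, N)
    # Reshape to (node, row = y coordinate, column = x coordinate).
    return grads.T.reshape(-1, resolution, resolution)
```

A node whose ReLU is inactive everywhere gets an all-zero panel. Points where the network is confident (probabilities near 0 or 1) also come out near zero, which matches the description above.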
Here is a plot of the loss, showing most of the predictive loss occurs around the edge of the circle.
7 comments
comment by Maxwell Peterson (maxwell-peterson) · 2022-04-15T19:31:49.271Z · LW(p) · GW(p)
This is super cool. I’d have thought this was a great post if it was just the content of the video, so the additional analysis is, like, super great.
↑ comment by Donald Hobson (donald-hobson) · 2022-04-21T19:22:35.458Z · LW(p) · GW(p)
I've shared the draft of the next post with you, in case you want to look at it.
For anyone else reading this, my supervisors don't want this public until the idea has been submitted to a conference (plagiarism concerns). But DM me, and if your profile shows you can come up with your own ideas, I'll let you in.
↑ comment by Maxwell Peterson (maxwell-peterson) · 2022-04-21T19:57:30.556Z · LW(p) · GW(p)
Thanks! I’ll give it a read
↑ comment by Donald Hobson (donald-hobson) · 2022-04-15T22:54:49.637Z · LW(p) · GW(p)
I'm a PhD student. My supervisors like getting reports on what I've been doing. LessWrong has a good user interface. The comments I get on LessWrong have so far been about as insightful as my supervisors' comments.
The only slight problem is that my supervisors want to discourage plagiarists from reading and copying my work by getting me to write in incomprehensible formalese in pay-per-view journals. After all, without self-appointed corporate gatekeepers, how would anyone know if my work was any good? By looking at it?
comment by TLW · 2022-04-14T11:29:32.442Z · LW(p) · GW(p)
Probability density functions of loss
I might plot the CDF instead. That way you don't need to smooth.
The surprising thing shown is the significant positive correlation between log loss of a network, and log loss of the reverse. This means if you take a pruned network that scores well, and reverse it, turning on all the nodes that were off, and turning off all the nodes that were on, the result is usually still a well scoring network.
I suspect that this is only because you have a single hidden layer.
↑ comment by Donald Hobson (donald-hobson) · 2022-04-14T13:21:13.095Z · LW(p) · GW(p)
I might plot the CDF instead. That way you don't need to smooth.
Only by applying a very smoothing transformation, namely integration. I think it's harder to see what is going on in CDF plots, because it's easy to see a line falling by 5%, but hard to notice a line getting 5% less steep.
For example, which of these plots is easier to read?
Or
Plotting the CDF has turned a very obvious massive spike into a slightly flatter section. One of these curves is from normally distributed data. You can tell at a glance which it is from the top plot. The bottom plot makes it less obvious.
Yep. Testing this on bigger networks is on my todo list.