Influence functions - why, what and how
post by Nina Panickssery (NinaR) · 2023-09-15T20:42:08.653Z
Anthropic recently published the paper Studying Large Language Model Generalization with Influence Functions, which describes a scalable technique for measuring which training examples were most influential for a particular set of weights/outputs of a trained model. This can help us better understand model generalization, offering insights into the emergent properties of AI systems. For instance, influence functions could help us answer questions like "is the model asking not to be shut down because it has generalized that this is a generically good strategy for pursuing some goal, or simply because texts where AIs ask not to be shut down are commonly found in the training corpus?".
In this post, I aim to summarize the approximations used in the paper to calculate the influence of different training examples and outline how the approximations can be implemented in PyTorch to form the basis of further research on influence functions by the AI safety community.
(Note: most formulae are copied or adapted from the original paper, with a few additional derivation steps / simplified notation used for clarity in some places.)
Deriving the exact form of the influence function
Before we go into approximations, it is necessary to understand what specifically we are trying to measure.
Given some element of a dataset $z_m$, we define the response function as the optimal solution $\theta^\star$ (the weights that minimize the expected loss $J(\theta)$) as a function of the weighting $\epsilon$ of this example:

$$\theta^\star(\epsilon) = \arg\min_\theta\; J(\theta) + \epsilon\,\mathcal{L}(z_m, \theta)$$

We define the influence of $z_m$ on $\theta^\star$ using the first-order Taylor approximation to the response function at $\epsilon = 0$:

$$\theta^\star(\epsilon) \approx \theta^\star(0) + \frac{d\theta^\star}{d\epsilon}\bigg|_{\epsilon=0}\,\epsilon$$
We can get $\frac{d\theta^\star}{d\epsilon}\big|_{\epsilon=0}$ the following way:

We know $\theta^\star(\epsilon)$ is a minimum of $J(\theta) + \epsilon\,\mathcal{L}(z_m, \theta)$, and so the gradient wrt $\theta$ is zero at that point:

$$\nabla_\theta J(\theta^\star(\epsilon)) + \epsilon\,\nabla_\theta \mathcal{L}(z_m, \theta^\star(\epsilon)) = 0$$

Differentiating each side wrt $\epsilon$:

$$\nabla_\theta^2 J(\theta^\star)\,\frac{d\theta^\star}{d\epsilon} + \frac{d}{d\epsilon}\Big[\epsilon\,\nabla_\theta \mathcal{L}(z_m, \theta^\star(\epsilon))\Big] = 0$$

(The LHS both directly depends on $\epsilon$, and indirectly via $\theta^\star(\epsilon)$, so we use the Implicit Function Theorem.)

The second term can be simplified: $\frac{d}{d\epsilon}\big[\epsilon\,\nabla_\theta \mathcal{L}(z_m, \theta^\star)\big] = \nabla_\theta \mathcal{L}(z_m, \theta^\star) + \epsilon\,\nabla_\theta^2 \mathcal{L}(z_m, \theta^\star)\,\frac{d\theta^\star}{d\epsilon}$, and the second of these terms vanishes at $\epsilon = 0$.

And so we can rearrange to get an expression for $\frac{d\theta^\star}{d\epsilon}\big|_{\epsilon=0}$:

$$\frac{d\theta^\star}{d\epsilon}\bigg|_{\epsilon=0} = -\big(\nabla_\theta^2 J(\theta^\star)\big)^{-1}\,\nabla_\theta \mathcal{L}(z_m, \theta^\star)$$
This tells us how the optimal parameters change with a perturbation to the weighting of an added data point . The change is proportional to the negative product of the inverse Hessian of the loss on all the data and the gradient of the loss on the data point in question with respect to the model parameters (both evaluated at the optimal parameters).
For simplicity, as in the paper, we'll denote $\nabla_\theta^2 J(\theta^\star)$ as $\mathbf{H}$.

Therefore, $\mathcal{I}_{\theta^\star}(z_m) = \frac{d\theta^\star}{d\epsilon}\big|_{\epsilon=0} = -\mathbf{H}^{-1}\,\nabla_\theta \mathcal{L}(z_m, \theta^\star)$. (This corresponds to Equation 3 in the paper.)
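As a quick sanity check on this formula (my own sketch, not from the paper), here is a small linear regression where both the Hessian and the reweighted optimum are available in closed form; all names are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 50, 3
X = rng.normal(size=(N, d))
y = X @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.normal(size=N)

# Average loss J(theta) = (1/2N) ||X theta - y||^2, with Hessian H = X^T X / N
H = X.T @ X / N
theta_star = np.linalg.solve(H, X.T @ y / N)  # unperturbed optimum

# Gradient of the loss on one training point z_m = (x_m, y_m) at theta_star
x_m, y_m = X[0], y[0]
grad_m = x_m * (x_m @ theta_star - y_m)

# Influence on the parameters: d theta* / d eps = -H^{-1} grad_m
influence = -np.linalg.solve(H, grad_m)

# Compare against actually re-solving with z_m upweighted by a small eps
eps = 1e-4
theta_eps = np.linalg.solve(H + eps * np.outer(x_m, x_m),
                            X.T @ y / N + eps * x_m * y_m)
finite_diff = (theta_eps - theta_star) / eps
print(np.max(np.abs(finite_diff - influence)))  # small: first-order match
```

The finite-difference estimate of the response-function derivative agrees with $-\mathbf{H}^{-1}\nabla_\theta\mathcal{L}(z_m,\theta^\star)$ up to $O(\epsilon)$ error.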
Influence on some function of the model weights
So far, we have derived an expression for the influence of an added data point on the parameters $\theta^\star$. However, we are more interested in the influence of particular data points on some measurable properties of the model, such as the output logits or validation loss. We can see any such property as a function $f$ of the trained parameters.

By the chain rule, $\frac{d f(\theta^\star)}{d\epsilon} = \nabla_\theta f(\theta^\star)^\top \frac{d\theta^\star}{d\epsilon}$, and so

$$\mathcal{I}_f(z_m) = -\nabla_\theta f(\theta^\star)^\top\,\mathbf{H}^{-1}\,\nabla_\theta \mathcal{L}(z_m, \theta^\star)$$

(This corresponds to Equation 5 in the paper.)
Problems with this expression
- The Hessian could have zero eigenvalues and thus not be invertible (the optimal parameters may be underspecified by the loss function in the case of overparameterized models)
- We often don't train to convergence, so the first derivative of the loss wrt the parameters is not necessarily zero, as previously assumed
The paper mentions that because of these problems, "past works have found influence functions to be inaccurate for modern neural networks."
How do we fix this?
One approach is to define a new objective that:
- Has a single defined optimum in parameter space
- Is fully optimized when the model stops training
This is what the proximal Bregman objective (PBO) attempts to define.
The PBO for a training set $\{(x_i, y_i)\}_{i=1}^N$ and (partially) trained parameters $\theta^s$ is

$$\mathcal{L}^{\text{PB}}(\theta) = \frac{1}{N}\sum_{i=1}^N D_{\mathcal{L}}\big(f(\theta, x_i),\, f(\theta^s, x_i)\big) + \frac{\lambda}{2}\,\lVert\theta - \theta^s\rVert^2$$

($f(\theta, x)$ here is the output of the model at the parameters $\theta$ on input $x$, $f(\theta^s, x)$ is the output of the model at parameters $\theta^s$ on input $x$, and $D_{\mathcal{L}}$ is the Bregman divergence of the loss as a function of the network outputs.)

The PBO basically introduces a penalty for diverging too far from the parameters the retraining is initialized from (the partially trained parameters $\theta^s$), so there is some defined optimum that balances moving too far from those parameters and achieving good loss.
So we can redefine the gradients used in $\mathcal{I}_f(z_m)$ in terms of this new loss function, which considers both the loss on a new training data point and the divergence from the current parameters.
From Bae et al.'s 2022 paper If Influence Functions are the Answer, Then What is the Question?:
...while influence functions for neural networks are often a poor match to LOO [Leave One Out] retraining, they are a much better match to what we term the proximal Bregman response function (PBRF). Intuitively, the PBRF approximates the effect of removing a data point while trying to keep the predictions consistent with those of the (partially) trained model.
...
In addition, although the PBRF may not necessarily align with LOO retraining due to the warm-start[1], proximity, and non-convergence gaps, the motivating use cases for influence functions typically do not rely on exact LOO retraining. This means that the PBRF can be used in place of LOO retraining for many tasks such as identifying influential or mislabelled examples
Applying the Implicit Function Theorem to the PBO, we can obtain an influence function with respect to this objective (Equation 9 in the paper):

$$\mathcal{I}_f(z_m) = -\nabla_\theta f(\theta^s)^\top\,(\mathbf{G} + \lambda\mathbf{I})^{-1}\,\nabla_\theta \mathcal{L}(z_m, \theta^s)$$

where $\mathbf{G} = \mathbf{J}_{\hat{y}\theta}^\top\,\mathbf{H}_{\hat{y}}\,\mathbf{J}_{\hat{y}\theta}$ is the Gauss-Newton Hessian. $\mathbf{J}_{\hat{y}\theta}$ is the Jacobian (the first derivative of the network's outputs with respect to the parameters), and $\mathbf{H}_{\hat{y}}$ is the Hessian of the loss with respect to the network's outputs.
Efficient calculation
So, we want to get $\mathcal{I}_f(z_m) = -\nabla_\theta f(\theta^s)^\top\,(\mathbf{G} + \lambda\mathbf{I})^{-1}\,\nabla_\theta \mathcal{L}(z_m, \theta^s)$.
Let's assume we have the following:
- A trained network with parameters $\theta^s$
- An observable property of the network, $f(\theta)$, for instance its output logits for some chosen input
- The training dataset $\{z_i\}_{i=1}^N$
- The loss function $\mathcal{L}$ the model was trained on
The key ingredients needed to calculate $\mathcal{I}_f(z_m)$ are:

1. The gradient of the property of interest with respect to the parameters, $\nabla_\theta f$, evaluated at $\theta^s$
2. A way of getting the inverse damped Gauss-Newton Hessian-vector product, $(\mathbf{G} + \lambda\mathbf{I})^{-1} v$
3. The gradient of the loss on the training data points (which we want to calculate the influence of) with respect to the parameters, $\nabla_\theta \mathcal{L}(z_m, \theta)$, evaluated at $\theta^s$
Notice that only key ingredient 1 depends on the property of interest. We can pre-compute ingredients 2 and 3 and then use this to test a bunch of different properties (for example, find the most influential training examples for a bunch of different model input-output pairs).
We can also calculate the influence as a batched operation over many training data points (batching over multiple $z_m$'s) to increase efficiency via vectorization.
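As an illustration of how the ingredients combine (a sketch with made-up dimensions; a dense solve stands in for the approximate inverse-Hessian-vector product):

```python
import numpy as np

rng = np.random.default_rng(1)
p, B, lam = 6, 10, 0.1  # num params, num training points, damping

# Stand-in damped Gauss-Newton Hessian (symmetric positive definite)
M = rng.normal(size=(p, p))
G_damped = M @ M.T + lam * np.eye(p)

query_grad = rng.normal(size=p)        # ingredient 1: grad of f wrt params
train_grads = rng.normal(size=(B, p))  # ingredient 3: per-example loss grads

# Ingredient 2: inverse damped Hessian-vector product, computed once per query
ihvp = np.linalg.solve(G_damped, query_grad)

# Influence of each training point, batched as one matrix-vector product
influences = -train_grads @ ihvp       # one score per training point
```

The expensive solve happens once per query; scoring each additional training point is then just a dot product.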
Which leaves the final key question: how do we get $(\mathbf{G} + \lambda\mathbf{I})^{-1} v$?
Kronecker-Factored Approximate Curvature (KFAC)
Originally introduced in the 2015 paper Optimizing Neural Networks with Kronecker-factored Approximate Curvature by Martens and Grosse, KFAC is an approximation to the Fisher information matrix (FIM) that can be inverted very efficiently. For many models where the loss is given by the negative log probability associated with a simple predictive distribution, the FIM is equal to the Gauss-Newton Hessian.
KFAC for MLP models involves the following:
Given a fully connected model with $L$ layers, let's assume each layer $l$'s output is:

$$a_l = \phi_l(s_l), \qquad s_l = W_l\,\bar{a}_{l-1}$$

where $\bar{a}_{l-1}$ is the layer's input (the previous layer's activations), and $\phi_l$ is a nonlinear activation function. [2]

When we backpropagate to get $\nabla_\theta \mathcal{L}$, we need to calculate the derivative of the loss with respect to intermediate stages of the computation at each layer. So as we go backward through the computational graph, once we get to the output of $W_l\,\bar{a}_{l-1}$, we'll have computed $\mathcal{D}s_l := \nabla_{s_l}\mathcal{L}$.

By the chain rule, and using the fact that $s_l = W_l\,\bar{a}_{l-1}$:

$$\mathcal{D}W_l = \mathcal{D}s_l\,\bar{a}_{l-1}^\top$$
This means we can decompose the gradient[3] of the log-likelihood loss on some data point with respect to the weight matrix into the intermediate gradients of the loss with respect to the output of applying the weight matrix and the activations prior to that layer.
Working with gradients of weight matrices is inconvenient though, as we end up with 3D tensors for the Jacobian. We can instead consider $w_l = \text{vec}(W_l)$, the unrolled weight matrix for layer $l$.

Then, defining $\mathcal{D}w_l = \text{vec}(\mathcal{D}W_l)$, $\mathbf{G}_l = \mathbb{E}\big[\mathcal{D}w_l\,\mathcal{D}w_l^\top\big]$, and $\otimes$ as the Kronecker product:

$$\mathcal{D}w_l = \bar{a}_{l-1} \otimes \mathcal{D}s_l$$

$$\mathbf{G}_l = \mathbb{E}\big[(\bar{a}_{l-1}\bar{a}_{l-1}^\top) \otimes (\mathcal{D}s_l\,\mathcal{D}s_l^\top)\big]$$
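These identities are easy to verify numerically. A small sketch (illustrative shapes; column-stacking convention for vec):

```python
import numpy as np

rng = np.random.default_rng(1)
m, n = 4, 3
W = rng.normal(size=(m, n))   # weight matrix of one layer
a_bar = rng.normal(size=n)    # layer input activations
t = rng.normal(size=m)        # regression target

s = W @ a_bar                 # pre-nonlinearity output s_l
Ds = s - t                    # dL/ds for L = 0.5 * ||s - t||^2
DW = np.outer(Ds, a_bar)      # claimed gradient: Ds a^T

# Finite-difference check that DW really is dL/dW
eps = 1e-6
base = 0.5 * np.sum((s - t) ** 2)
DW_fd = np.zeros_like(W)
for i in range(m):
    for j in range(n):
        Wp = W.copy()
        Wp[i, j] += eps
        DW_fd[i, j] = (0.5 * np.sum((Wp @ a_bar - t) ** 2) - base) / eps
assert np.allclose(DW, DW_fd, atol=1e-4)

# vec(DW), column-stacked, equals the Kronecker product of a_bar and Ds
assert np.allclose(DW.flatten(order="F"), np.kron(a_bar, Ds))
```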
So far, so exact... But now, time for approximations. KFAC makes things simpler by assuming:
- Gradients are uncorrelated between different layers
- Activations are independent of pre-activation gradients
This allows us to write down a simple block-diagonal approximation for $\mathbf{G}$:

$$\mathbf{G}_l \approx \hat{\mathbf{G}}_l = A_{l-1} \otimes S_l$$

where $A_{l-1} = \mathbb{E}\big[\bar{a}_{l-1}\bar{a}_{l-1}^\top\big]$ and $S_l = \mathbb{E}\big[\mathcal{D}s_l\,\mathcal{D}s_l^\top\big]$ are uncentered covariance matrices for the layer's input activations and pre-nonlinearity gradients, respectively.
This structure enables us to efficiently get the inverse (approximate) Gauss-Newton Hessian vector product:
Let $V_l$ denote the entries of $v$ for layer $l$, reshaped to match $W_l$, and let $v_l = \text{vec}(V_l)$.

Using various Kronecker product identities, we can compute the inverse (approximate) Gauss-Newton Hessian-vector product as:

$$\hat{\mathbf{G}}_l^{-1} v_l = \big(A_{l-1}^{-1} \otimes S_l^{-1}\big)\, v_l = \text{vec}\big(S_l^{-1}\, V_l\, A_{l-1}^{-1}\big)$$
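The identity can be checked directly on small random factors (a sketch; column-stacking vec convention, made-up shapes):

```python
import numpy as np

rng = np.random.default_rng(2)
m, n = 4, 3  # W_l is m x n, so S_l is m x m and A_{l-1} is n x n

def random_spd(k):
    """Random symmetric positive definite matrix."""
    M = rng.normal(size=(k, k))
    return M @ M.T + np.eye(k)

A, S = random_spd(n), random_spd(m)
V = rng.normal(size=(m, n))    # v_l reshaped to match W_l
v = V.flatten(order="F")       # column-stacking vec

# Direct dense solve of (A kron S) x = v
direct = np.linalg.solve(np.kron(A, S), v)

# Factored solve: vec(S^{-1} V A^{-1}), never forming the Kronecker product
factored = (np.linalg.solve(S, V) @ np.linalg.inv(A)).flatten(order="F")
print(np.allclose(direct, factored))  # True
```

The factored route only ever inverts the small per-layer factors, which is what makes KFAC tractable for large models.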
Eigenvalue correction
We made an approximation earlier when we went from $\mathbb{E}\big[(\bar{a}_{l-1}\bar{a}_{l-1}^\top) \otimes (\mathcal{D}s_l\,\mathcal{D}s_l^\top)\big]$ to $\mathbb{E}\big[\bar{a}_{l-1}\bar{a}_{l-1}^\top\big] \otimes \mathbb{E}\big[\mathcal{D}s_l\,\mathcal{D}s_l^\top\big]$.

Using the eigendecompositions of $A_{l-1}$ and $S_l$:

$$A_{l-1} = Q_A \Lambda_A Q_A^\top, \qquad S_l = Q_S \Lambda_S Q_S^\top$$

we can write a more accurate expression for $\hat{\mathbf{G}}_l$:

$$\hat{\mathbf{G}}_l = (Q_A \otimes Q_S)\,\Lambda\,(Q_A \otimes Q_S)^\top$$

where the diagonal matrix $\Lambda$ is defined as:

$$\Lambda_{ii} = \mathbb{E}\Big[\big((Q_A \otimes Q_S)^\top \mathcal{D}w_l\big)_i^2\Big]$$
which "captures the variances of the pseudo-gradient projected onto each eigenvector of the K-FAC approximation".
We can get the damped inverse Gauss-Newton Hessian-vector product approximation by adding a damping term $\lambda$ to the eigenvalues, obtaining:

$$\big(\hat{\mathbf{G}}_l + \lambda\mathbf{I}\big)^{-1} v_l = (Q_A \otimes Q_S)\,(\Lambda + \lambda\mathbf{I})^{-1}\,(Q_A \otimes Q_S)^\top v_l$$
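Here is a sketch of the damped solve done in the Kronecker eigenbasis, using the plain KFAC eigenvalues $\Lambda_A \otimes \Lambda_S$ so the result can be checked against a dense solve; the eigenvalue-corrected version would substitute the fitted variances for `Lam` below. Shapes and names are illustrative:

```python
import numpy as np

rng = np.random.default_rng(3)
m, n, lam = 4, 3, 0.1

def random_spd(k):
    """Random symmetric positive definite matrix."""
    M = rng.normal(size=(k, k))
    return M @ M.T + np.eye(k)

A, S = random_spd(n), random_spd(m)
V = rng.normal(size=(m, n))
v = V.flatten(order="F")

# Reference: dense damped solve
direct = np.linalg.solve(np.kron(A, S) + lam * np.eye(m * n), v)

# Eigenbasis route: never materialize the Kronecker product
eig_A, Q_A = np.linalg.eigh(A)
eig_S, Q_S = np.linalg.eigh(S)
V_proj = Q_S.T @ V @ Q_A                  # (Q_A kron Q_S)^T v, reshaped
Lam = np.outer(eig_S, eig_A)              # KFAC eigenvalues; the corrected
                                          # version uses fitted variances
X = Q_S @ (V_proj / (Lam + lam)) @ Q_A.T  # divide by damped eigenvalues,
                                          # rotate back
print(np.allclose(X.flatten(order="F"), direct))  # True
```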
Influence functions for autoregressive models
A few details change when we want to calculate $\mathcal{I}_f(z_m)$ for a Transformer language model trained with an autoregressive loss function.

In this case, the property of interest considered (the thing we are calculating the influence on) is the log-likelihood of a particular token string completion $z_c$, given a token string prompt $z_p$[4]:

$$f(\theta) = \log p(z_c \mid z_p;\, \theta)$$
The paper only considers measuring the influence on a subset of the Transformer's weights - only the MLP layers - so the MLP approximation derived above applies almost exactly.
However, the parameter gradients are now summed over token indices $t$:

$$\mathcal{D}W_l = \sum_t \mathcal{D}s_{l,t}\,\bar{a}_{l-1,t}^\top$$

Each diagonal block of $\mathbf{G}$ is given by $\mathbb{E}\big[\mathcal{D}w_l\,\mathcal{D}w_l^\top\big]$; however, we want to take into account how this second moment is affected by the inter-token correlations, and so we cannot approximate it directly with $A_{l-1} \otimes S_l$ as accurately as before.
The paper presents the following middle-ground between efficiency and accuracy:
We first fit the covariance factors $A_{l-1}$ and $S_l$ as if the tokens were fully independent, and compute their respective eigendecompositions. Then, when fitting the diagonal matrix $\Lambda$, we use the exact pseudo-gradients which are summed over tokens. This way, at least the estimated diagonal entries of the moments in the Kronecker eigenbasis are unbiased.
Implementing in PyTorch
As described above, the key ingredients for $\mathcal{I}_f(z_m)$ are:

1. The gradient of the property of interest with respect to the parameters, $\nabla_\theta f$, evaluated at $\theta^s$
2. The inverse damped Gauss-Newton Hessian, which we can calculate from the expectations of the following quantities:
   - $\bar{a}_{l-1}$: the MLP layer inputs
   - $\mathcal{D}s_l$: the MLP pre-nonlinearity gradients (gradients of the loss wrt the output of the linear transformation $s_l = W_l\,\bar{a}_{l-1}$)
3. The gradient of the loss on the training data points (which we want to calculate the influence of) with respect to the parameters, $\nabla_\theta \mathcal{L}(z_m, \theta)$, evaluated at $\theta^s$

We can get 1) and 3) by simply fetching parameter.grad[5] after performing a backward pass of the loss on some (input, target) pair.
We can get 2a) using a forward hook that saves the input to a layer during the forward pass. We can get 2b) using a backward hook on the linear layer that saves the gradient wrt the linear layer's output.
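A minimal sketch of that hook machinery on a single linear layer (illustrative; bias omitted, and the squared-error loss is just for the check):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(3, 4, bias=False)
captured = {}

def fwd_hook(module, inputs, output):
    # 2a) save the layer's input activations a during the forward pass
    captured["a"] = inputs[0].detach()

def bwd_hook(module, grad_input, grad_output):
    # 2b) save the gradient of the loss wrt the layer's output s
    captured["Ds"] = grad_output[0].detach()

layer.register_forward_hook(fwd_hook)
layer.register_full_backward_hook(bwd_hook)

x = torch.randn(5, 3)
t = torch.randn(5, 4)
loss = 0.5 * ((layer(x) - t) ** 2).sum()
loss.backward()

# Check the factorization: dL/dW = sum over the batch of Ds_b a_b^T
DW = captured["Ds"].T @ captured["a"]
assert torch.allclose(DW, layer.weight.grad, atol=1e-5)
```

Accumulating `a a^T` and `Ds Ds^T` from these captured tensors over many batches gives the uncentered covariance factors needed for the KFAC approximation.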
You can find my implementation attempt on GitHub here[6]; it includes code applying influence function analysis to a vanilla MLP trained on MNIST and a 2-layer transformer trained on a basic next-character prediction task.
Results of small experiment on MNIST
I trained an MLP on MNIST (with flattened images) and then used the influence function approximation code to extract influential training examples for particular predicted test set labels.
I found that influential training digits were usually more sloppy / unclear compared to the average MNIST digit, and shared some resemblance with the query image. Not all top influential images shared the same label as the query. I only searched a subset of the training corpus, for efficiency.
Here are some examples, filtered to cases where the influence was non-negligible (some queries returned ~0 influence for all sampled training datapoints). The first image on the left is the query, followed by the top most influential training images given by the approximation:
- ^
The warm-start problem referenced by Bae et al. refers to the fact that for a not strictly convex objective, the influence of a training example in the neighborhood of a minimum may be different from the influence at a different initialization point.
- ^
The paper uses homogeneous vector notation to account for biases / affine transformations - you can assume there is a 1 appended to the activations and a bias column appended to $W_l$ to cover this case.
- ^
The paper refers to these as "pseudo-gradients" since they are sampled from the final output distribution and are distinct from gradients during training.
- ^
The $z_p$, $z_c$ pair is referred to as the "query" in the paper, as we are "querying" which training examples were most influential for the model producing $z_c$ given $z_p$.
- ^
Specifically, concatenate a linear layer's `.weight` and `.bias` grads.
- ^
If you look through the code and find any bugs (quite possible) or performance improvements (definitely findable; e.g. more batching + splitting of GPU ops - WIP) I'd be super happy to merge PRs and/or hear from you! I hope to gradually improve this codebase and run larger experiments.
6 comments
comment by Troof · 2023-09-19T15:03:58.090Z · LW(p) · GW(p)
Thanks for this! One thing I don't understand about influence functions is: why should I care about the proximal Bregman objective? To interpret a model, I'm really interested in the LOO retraining, right? Can we still say things like "it seems that the model relied on this training sample for producing this output" with the PBO interpretation?
↑ comment by Nina Panickssery (NinaR) · 2023-09-20T13:16:06.703Z · LW(p) · GW(p)
I agree that approximating the PBO makes this method more lossy (not all interesting generalization phenomena can be found). However, I think we can still glean useful information about generalization by considering "retraining" from a point closer to the final model than random initialization. The downside is if, for example, some data was instrumental in causing a phase transition at some point in training, this will not be captured by the PBO approximation.
Indeed, the paper concedes:
Influence functions are approximating the sensitivity to the training set locally around the final weights and might not capture nonlinear training phenomena
Purely empirically, I think Anthropic's results indicate there are useful things that can be learnt, even via this local approximation:
One of the most consistent patterns we have observed is that the influential sequences reflect increasingly sophisticated patterns of generalization as the model scale increases. While the influential sequences for smaller models tend to have short overlapping sequences of tokens, the top sequences for larger models are related at a more abstract thematic level, and the influence patterns show increasing robustness to stylistic changes, including the language.
My intuition here is that even if we are not exactly measuring the counterfactual "what if this datum was not included in the training corpus?", we could be estimating "what type of useful information is the model extracting from training data that looks like this?".
comment by Gurkenglas · 2023-09-16T10:00:13.785Z · LW(p) · GW(p)
I found that influential training digits were usually more sloppy / unclear compared to the average MNIST digit, and shared some resemblance with the query image.
It's pbzcnevat gb gur arnerfg cbvagf ba gur obhaqnel bs qvtvg-pyhfgref! Bonus points if you made your observation without that interpretation in mind. What if you do Jacobian regularization?
↑ comment by Hoagy · 2023-09-17T00:01:42.540Z · LW(p) · GW(p)
How do you know?
↑ comment by Gurkenglas · 2023-09-17T07:02:42.291Z · LW(p) · GW(p)
It's the same training datums I would look at to resolve an ambiguous case.
comment by Sonia Joseph (redhat) · 2024-03-27T05:34:11.105Z · LW(p) · GW(p)
Thank you for this. How would you think about the pros/cons of influence functions vs activation patching or direct logit attribution in terms of localizing a behavior in the model?