Measuring Nonlinear Feature Interactions in Sparse Crosscoders [Project Proposal]
post by Jason Gross (jason-gross), rajashree (miraya) · 2025-01-06
TL;DR
-
Problem: Sparse crosscoders are powerful tools for compressing neural network representations into interpretable features. However, we don’t understand how features interact.
-
Perspective: We need systematic procedures to measure and rank nonlinear feature interactions. This will help us identify which interactions deserve deeper interpretation. Success can be measured by how useful these metrics are for applications like model diffing and finding adversarial examples.
-
Starting contribution: We develop a procedure based on compact proofs [AF · GW]. Working backwards from the assumption that features are linearly independent, we derive mathematical formulations for measuring feature interactions in ReLU and softmax attention.
Introduction
Training a sparse crosscoder (henceforth simply a "crosscoder") can be thought of as stacking the activations of the residual stream across all layers and training an SAE on the stacked vector. This gives us a mapping between model activations and compressed, interpretable features. We’re excited about this approach to automated interpretability, as it paves the way for accounting for feature interaction in addition to feature discovery.
In this draft project proposal, we describe the linear interaction assumption, and provide a compact-proofs-based explanation for why we’d like to measure feature interaction. We are sharing concrete metrics for measuring feature interaction in ReLU and softmax attention. We close with a discussion of empirical projects that we’re excited about, and some interesting takeaways. Please reach out to us if you have feedback or would like to collaborate!
Linear Interaction Assumption
Crosscoders would be all we need if we could make the following two assumptions about how neural networks process information:
-
Linear Representation Hypothesis: Features are encoded explicitly and linearly in the residual stream. For example, if a model needs to work with geometric areas, this hypothesis suggests the area value would be directly encoded, rather than the model linearly storing just the radius and computing or nonlinearly extracting the area when needed.
-
Linear Interaction Assumption: Features are independent. This assumption suggests that the nonlinear computation in the model distributes over features in some approximately linear way.
If the linear representation hypothesis breaks down, then crosscoders would perform poorly, getting high reconstruction error. However, even if a crosscoder achieves 100% reconstruction accuracy, it may still be the case that the linear interaction assumption does not hold for the model, and thus that the crosscoder explanation is incomplete.
Consider a concrete example: Suppose we prompt the model with “The first letter of January is”. To complete this task, an attention head needs to:
- Look at the word “January”
- Extract its first letter
- Move this information to the position after “is”
Before this attention operation occurs, the residual stream at the “is” position cannot predict what will appear there after the operation, because the crucial information (“J”) is stored at a different position. The crosscoder might detect that “J” appears in the output, but it misses the complex interaction between:
- The attention mechanism that knows to look for “January”
- The position-specific information about where “January” is located
- The logic for extracting the first letter
Measuring nonlinear feature interactions
The open problem is measuring how strongly features interact in ways that violate the linear interaction assumption. The basic idea is to:
- Look at pairs of features that co-occur in the model’s activations on a given dataset
- Measure how much their effects on the model’s behavior interact/interfere
- Rank these interactions by strength and frequency
This will give us a precise way to identify where our current feature-based understanding is incomplete, where we need some notion of crosscoder circuits. For each interaction, we can ask:
- Which features are involved?
- At which layer/location do they interact?
- How strong is the interaction?
- How often does it occur?
By focusing on the strongest and most frequent nonlinear interactions first, we get a principled way to gradually improve our understanding of the model.
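To make this concrete, here is a minimal Python/NumPy sketch of the ranking loop described above. The `interaction_strength` callable is a hypothetical stand-in for the metrics developed later in this post:

```python
import itertools
import numpy as np

def rank_feature_interactions(feature_acts, interaction_strength, min_cooccurrences=10):
    """Rank feature pairs by interaction strength and co-occurrence frequency.

    feature_acts: (n_tokens, n_features) array of crosscoder feature activations.
    interaction_strength: hypothetical callable (i, j) -> float; any of the
        metrics developed later in this post could be plugged in here.
    """
    active = feature_acts > 0                       # which features fire on which tokens
    n_tokens, n_features = feature_acts.shape
    results = []
    for i, j in itertools.combinations(range(n_features), 2):
        cooccurrences = int(np.count_nonzero(active[:, i] & active[:, j]))
        if cooccurrences < min_cooccurrences:
            continue                                 # too rare to estimate reliably
        strength = interaction_strength(i, j)
        results.append((i, j, strength, cooccurrences / n_tokens))
    # strongest interactions first; frequency breaks ties
    results.sort(key=lambda r: (r[2], r[3]), reverse=True)
    return results
```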
Theoretical intuition
We start by providing theory-grounded intuition for feature interaction. We use the compact proofs approach, where the goal for mechanistic understanding is to compress the length of the explanation of the model (the proof of the model’s behavior).
If crosscoder features genuinely have only linear interactions, then we can prove a tight bound with linear work. On the other hand, if crosscoder features interact strongly, then the linear proof will give vacuous bounds and making the bounds tighter will bring the proofs much closer to brute-force.
The proof workflow is as follows:
-
We treat the crosscoder as producing an encoding of the dataset in feature space, and treat the decoder as approximating the model
-
We seek to bound the reconstruction error between the logits and the unembed applied to the decoded last-layer activations
-
We can perform exhaustive enumeration on the entire data distribution, but this is as expensive as running the entire model on every datapoint.
-
Instead, we may be able to reduce the time complexity of this bound
- by assuming that the crosscoder contains the full extent of our understanding of the model (what we do in this blog post), and
- by clustering the input distribution into sets of datapoints that behave similarly[1]
The next subsection discusses the math in detail; feel free to skip.
The Math
Defining sparse crosscoders
Training a sparse crosscoder can be thought of as stacking the activations of the residual stream across all layers and training an SAE on the stacked vector. Quoting Anthropic:
The basic setup of a crosscoder is as follows. First, we compute the vector of feature activations $f(x_j)$ on a data point $x_j$ by summing over contributions from the activations of different layers $a^l(x_j)$ for layers $l \in \{1, \ldots, L\}$:
$$f(x_j) = \mathrm{ReLU}\Bigl(\sum_{l} W_{\mathrm{enc}}^{l}\, a^{l}(x_j) + b_{\mathrm{enc}}\Bigr)$$
where $W_{\mathrm{enc}}^{l}$ is the encoder weights at layer $l$, and $a^{l}(x_j)$ is the activations on data point $x_j$ at layer $l$. We then try to reconstruct the layer activations with approximations $a^{l\prime}(x_j)$ of the activations at layer $l$:
$$a^{l\prime}(x_j) = W_{\mathrm{dec}}^{l}\, f(x_j) + b_{\mathrm{dec}}^{l}$$
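For concreteness, here is a minimal PyTorch sketch of this encode/decode structure; the shapes and names (`n_layers`, `d_model`, `n_features`) are our own choices, not Anthropic's code:

```python
import torch
import torch.nn as nn

class SparseCrosscoder(nn.Module):
    """Minimal crosscoder: encode stacked per-layer activations into sparse
    features, then decode per-layer approximations of those activations."""

    def __init__(self, n_layers, d_model, n_features):
        super().__init__()
        self.W_enc = nn.Parameter(torch.randn(n_layers, d_model, n_features) * 0.01)
        self.b_enc = nn.Parameter(torch.zeros(n_features))
        self.W_dec = nn.Parameter(torch.randn(n_layers, n_features, d_model) * 0.01)
        self.b_dec = nn.Parameter(torch.zeros(n_layers, d_model))

    def encode(self, acts):          # acts: (batch, n_layers, d_model)
        # sum encoder contributions over layers, then apply the nonlinearity
        pre = torch.einsum("bld,ldf->bf", acts, self.W_enc) + self.b_enc
        return torch.relu(pre)       # feature activations f(x): (batch, n_features)

    def decode(self, feats):         # feats: (batch, n_features)
        # reconstruct each layer's activations from the shared feature vector
        return torch.einsum("bf,lfd->bld", feats, self.W_dec) + self.b_dec
```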
Now let us define our architecture hyperparameters. Consider:
- $M$, a transformer with
- $d_{\mathrm{vocab}}$ vocab tokens,
- $n_{\mathrm{layers}}$ layers, and
- $d_{\mathrm{model}}$ dimensions in the residual stream, so that
- $d_{\mathrm{stacked}} = n_{\mathrm{layers}} \cdot d_{\mathrm{model}}$ is the dimension of the stacked activation vector. Let
- $D$ be the (finite) set of data points, underlying
- $\mathcal{D}$, the distribution, and let
- $|D|$ denote the number of samples in $D$ and let
- $n_{\mathrm{ctx}}$ denote the average number of tokens in data points in $D$ (averaged over the uniform distribution rather than with respect to $\mathcal{D}$), so that
- $n_{\mathrm{tok}} = |D| \cdot n_{\mathrm{ctx}}$ is the number of tokens in the dataset.
Suppose our crosscoder has
- $F$ features (the width) and
- $k$ as its sparsity (L0 norm), so that
- $W_{\mathrm{enc}}$ and $W_{\mathrm{dec}}$ generally have shape $F \times d_{\mathrm{stacked}}$ and $d_{\mathrm{stacked}} \times F$ respectively. Let
- $W_E$ and $W_U$ be the embedding and unembedding matrices of dimensions $d_{\mathrm{vocab}} \times d_{\mathrm{model}}$ and $d_{\mathrm{model}} \times d_{\mathrm{vocab}}$ respectively.
We will sweep the details of positional encoding under the rug. Let
- $p$ be the floating-point precision we are using.
Bounding reconstruction error
Our proof will consist of two parts[2]: first the proof that the crosscoder (and therefore, we hope, the model) achieves low loss on the distribution $\mathcal{D}$; and second the proof that the crosscoder and the model correspond.
The crosscoder achieves low loss
Take the set of data points $D$ and exhaustively encode them into the feature space by running the encoder, let us say $f(x) := \mathrm{ReLU}(W_{\mathrm{enc}}\, a(x) + b_{\mathrm{enc}})$.[3] This data will require roughly $n_{\mathrm{tok}} \cdot k \cdot (\log_2 F + p)$ bits to encode naïvely. To exhaustively establish low loss on the distribution of interest, we simply compute our approximation for the loss on data point $x$ as $\ell\bigl(W_U\, a^{n_{\mathrm{layers}}\prime}(x),\, y(x)\bigr)$ for each data point $x \in D$. This costs roughly $n_{\mathrm{tok}} \cdot (k \cdot d_{\mathrm{model}} + d_{\mathrm{model}} \cdot d_{\mathrm{vocab}})$ floating point operations. We can take a weighted average over $\mathcal{D}$ to establish the overall expected loss.
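A minimal sketch of this exhaustive loss computation, reusing the crosscoder sketch above (variable names and the use of cross-entropy loss are illustrative assumptions):

```python
import torch
import torch.nn.functional as F

def crosscoder_loss_on_dataset(crosscoder, W_U, acts, targets, weights=None):
    """Exhaustively evaluate the crosscoder's approximation of the model's loss.

    acts: (n_points, n_layers, d_model) stacked activations, targets: (n_points,)
    next-token labels, weights: optional probabilities of each point under D."""
    feats = crosscoder.encode(acts)                   # f(x) for every x in D
    recon_last = crosscoder.decode(feats)[:, -1, :]   # decoded last-layer activations
    logits = recon_last @ W_U                         # unembed: (n_points, d_vocab)
    losses = F.cross_entropy(logits, targets, reduction="none")
    if weights is None:
        return losses.mean()
    return (losses * weights).sum() / weights.sum()   # weighted average over D
```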
The crosscoder and the model correspond
We must now establish that the crosscoder and its feature vectors are actually good approximations of the model activations. That is, we want to compute an upper bound on the reconstruction error at the logits,
$$\max_{x \in D}\, \bigl\|W_U\, a^{n_{\mathrm{layers}}}(x) - W_U\, a^{n_{\mathrm{layers}}\prime}(x)\bigr\|.$$
We must ultimately compute this error recursively in layers, where the error term on layer $l$ is computed in terms of the error term on layer $l-1$.[4] That is, we want to establish, for each layer $l$, an upper bound on
$$\varepsilon_l := \max_{x \in D}\, \bigl\|a^{l}(x) - a^{l\prime}(x)\bigr\|,$$
likely by computing this bound recursively in terms of $\varepsilon_{l-1}$.[5]
The base case
This procedure bottoms out in establishing how far the true token embed sequences $W_E\, t(x)$ are from the crosscoder approximation $a^{0\prime}(x)$ of the post-embedding activations:
$$\varepsilon_0(x) := \bigl\|W_E\, t(x) - a^{0\prime}(x)\bigr\|,$$
and then we might want to find an upper bound on this error, $\varepsilon_0 := \max_{x \in D} \varepsilon_0(x)$. Note that we are sweeping under the rug details of the positional embedding of the sequence of tokens in $x$ here.
We can compute this embedding bound exhaustively in time roughly $n_{\mathrm{tok}} \cdot k \cdot d_{\mathrm{model}}$.
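A sketch of the base-case bound, again with illustrative names, treating each row as a single token position and ignoring positional embeddings:

```python
import torch

def embedding_error_bound(crosscoder, W_E, tokens, acts):
    """Base case: distance between the true token embeddings and the crosscoder's
    reconstruction of the first (post-embedding) layer, ignoring positional terms."""
    feats = crosscoder.encode(acts)                  # (n_points, n_features)
    recon_embed = crosscoder.decode(feats)[:, 0, :]  # approximation of the embed layer
    true_embed = W_E[tokens]                         # (n_points, d_model)
    per_point = (true_embed - recon_embed).norm(dim=-1)
    return per_point.max()                           # worst-case epsilon_0 over D
```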
The recursive case: this is where the magic happens
If we have to compute all the error terms exhaustively, we will have gained nothing compression-wise. However, by making an additional assumption that is only somewhat unrealistic, we can compute this error in only as many forward passes as the square of our feature count (time roughly $F^2 \cdot n_{\mathrm{layers}} \cdot d_{\mathrm{model}}^2$)[6].
The additional assumption we need is the linear interaction assumption, that features interact only in linear ways. This means, for example, that not only do we need the residual stream to be a sum of linearly represented features, but also that the ReLU of a sum of these features must be the sum of the ReLUs, or at least not far off from that.[7] Violations of this assumption are sure to abound, and define the scope of what interpretability work remains to be done, once we have crosscoders.
Feature interaction metrics
When we seek to bound the reconstruction error systematically in time roughly $F^2 \cdot n_{\mathrm{layers}} \cdot d_{\mathrm{model}}^2$, we are forced to compute a crude uniform bound on the interaction between features. We can measure how crude this computation is by measuring the size of the error terms it introduces, giving us a metric on feature interaction strength.
Probabilistic bounds or proof bounds
Broadly, two approaches are possible: probabilistic (average-case over model weights) bounds, and proof-based (worst-case) bounds. We will briefly sketch out the metrics grounded in probabilistic bounds, which we may develop more in the future, before diving into the details of proof-based bounds.
If we ground our metrics in probabilistic bounds, then we want to measure deviation from the feature independence hypothesis: that features can be well-modeled as independent random variables. We can, for example:
- fit the feature activations to a multidimensional Gaussian
- use the off-diagonal entries of the covariance matrix as a metric on feature independence
Alternatively, we can:
- analytically evaluate nonlinearities on a binned distribution under the assumption that the features are independent[8]
- measure the KL divergence between this distribution and the actual distribution of activations (a two-feature sketch of both metrics follows below)
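Here is that two-feature sketch in NumPy; the bin count and the choice of ReLU as the nonlinearity are our own assumptions:

```python
import numpy as np

def covariance_interaction(feature_acts):
    """Off-diagonal entries of the feature covariance matrix as a dependence metric."""
    cov = np.cov(feature_acts, rowvar=False)          # (n_features, n_features)
    return np.abs(cov - np.diag(np.diag(cov)))

def binned_kl_interaction(pre_i, pre_j, n_bins=32, eps=1e-9):
    """KL divergence between the true distribution of relu(x_i + x_j) and the
    distribution obtained by pushing binned, independent marginals through relu."""
    relu = lambda z: np.maximum(z, 0.0)
    true_vals = relu(pre_i + pre_j)
    out_edges = np.histogram_bin_edges(true_vals, bins=n_bins)
    true_hist, _ = np.histogram(true_vals, bins=out_edges)
    true_p = true_hist / true_hist.sum()

    # binned marginals, treated as independent
    pi_hist, pi_edges = np.histogram(pre_i, bins=n_bins)
    pj_hist, pj_edges = np.histogram(pre_j, bins=n_bins)
    pi_p, pj_p = pi_hist / pi_hist.sum(), pj_hist / pj_hist.sum()
    pi_c = 0.5 * (pi_edges[:-1] + pi_edges[1:])       # bin centers
    pj_c = 0.5 * (pj_edges[:-1] + pj_edges[1:])

    # analytically push the product distribution through the nonlinearity
    pred_vals = relu(pi_c[:, None] + pj_c[None, :]).ravel()
    pred_w = (pi_p[:, None] * pj_p[None, :]).ravel()
    pred_hist, _ = np.histogram(pred_vals, bins=out_edges, weights=pred_w)
    pred_p = pred_hist / max(pred_hist.sum(), eps)

    mask = true_p > 0
    return float(np.sum(true_p[mask] * np.log(true_p[mask] / (pred_p[mask] + eps))))
```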
If we ground our metrics in worst-case (proof) bounds, then we want to measure deviation from the linear feature interaction assumption: that nonlinearities distribute over features in approximately linear ways. Testing this assumption requires a different procedure for each kind of nonlinearity.
In the next subsections, we discuss ReLU and softmax-attention: broadly, we test the hypothesis that co-occurring features at each neuron do not interact at all; and that co-occurring features at the same attention head obey a sort of strict dominance principle where the head as a whole behaves as if only the most strongly activating feature were present.
For ReLU
Applying the compact-proof based approach, we define the following procedure for measuring feature interaction for ReLUs, testing the hypothesis that co-occurring features at each neuron do not interact at all.
Given a pair of features $f_i$, $f_j$ that decode at layer $l$ to directions $d_i$ and $d_j$ (the corresponding columns of $W_{\mathrm{dec}}^{l}$), if the MLP input matrix is $W_{\mathrm{in}}$, we can measure the interaction between $f_i$ and $f_j$ at a neuron $n$ via their elementwise absolute ratios, averaged over the empirical distribution of their strengths $c_i$, $c_j$. Taking $u = W_{\mathrm{in}} d_i$ and $v = W_{\mathrm{in}} d_j$, we may define the interaction strength at neuron $n$ to be
$$\rho_n := \frac{c_j v_n}{c_i u_n},$$
so that we may express the neuron-$n$ preactivation from the linear combination $c_i d_i + c_j d_j$ as
$$c_i u_n + c_j v_n = c_i u_n (1 + \rho_n).$$
Then we might measure the overall interaction strength by averaging over the dataset:
$$I(i, j) := \mathbb{E}_{(c_i, c_j) \sim D}\Bigl[\max_n |\rho_n|\Bigr].$$
We can then bound the postactivations by the triangle inequality.
(A simpler but less useful measure would be the norm of the Hadamard (elementwise) product $(W_{\mathrm{in}} d_i) \odot (W_{\mathrm{in}} d_j)$.)
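A NumPy sketch of the ratio metric above; the max-over-neurons aggregation and variable names are our own choices:

```python
import numpy as np

def relu_interaction_strength(d_i, d_j, W_in, c_i_samples, c_j_samples, eps=1e-9):
    """Ratio metric from above: how strongly feature j perturbs the neuron
    preactivations written by feature i, averaged over empirical strengths.

    d_i, d_j: decoder directions (d_model,); W_in: MLP input matrix (d_mlp, d_model);
    c_*_samples: activation strengths on data points where both features fire."""
    u = W_in @ d_i                                # per-neuron contribution of feature i
    v = W_in @ d_j                                # per-neuron contribution of feature j
    strengths = []
    for c_i, c_j in zip(c_i_samples, c_j_samples):
        rho = (c_j * v) / (c_i * u + eps)         # elementwise ratio rho_n per neuron
        strengths.append(np.max(np.abs(rho)))     # worst neuron on this data point
    return float(np.mean(strengths))              # average over the dataset

def hadamard_interaction(d_i, d_j, W_in):
    """Simpler (but less useful) measure: norm of the elementwise product."""
    return float(np.linalg.norm((W_in @ d_i) * (W_in @ d_j)))
```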
For softmax-attention
Applying the compact-proof based approach, we define the following procedure for measuring feature interaction for softmax attention, testing the hypothesis that co-occurring features at the same attention head obey a sort of strict dominance principle where the head as a whole behaves as if only the most strongly activating feature were present. Because attention mixes information at multiple positions in the stream, we have more interaction terms than for ReLU MLPs.
Cross-position interaction
Given a pair of features $f_i$, $f_j$ that decode at layer $l$ to directions $d_i$ and $d_j$ (the corresponding columns of $W_{\mathrm{dec}}^{l}$), ignoring biases, we may express the QK interaction term with $f_i$ in the query position as $d_i^\top W_Q^\top W_K\, d_j$. Features interact insofar as this value is different for these features in particular than for the average of this term across features.
Cross-feature interaction
Two features can be said to interact as queries insofar as they interact with overlapping sets of keys.
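A sketch of both attention metrics, assuming `W_Q` and `W_K` of shape `(d_head, d_model)`; the baseline comparison and the cosine-overlap reading of "overlapping sets of keys" are our own choices:

```python
import numpy as np

def qk_cross_position_interaction(d_i, d_j, W_Q, W_K, d_all):
    """QK term with feature i as query and feature j as key, compared against the
    average such term over all features. W_Q, W_K: (d_head, d_model);
    d_i, d_j: decoder directions (d_model,); d_all: (n_features, d_model)."""
    W_QK = W_Q.T @ W_K                             # (d_model, d_model) bilinear form
    score_ij = d_i @ W_QK @ d_j                    # this feature pair's attention score
    baseline = np.mean(d_all @ (W_QK.T @ d_i))     # average score of feature i over all keys
    return float(score_ij - baseline)

def query_feature_overlap(d_i, d_j, W_Q, W_K, d_all, eps=1e-9):
    """Two features interact 'as queries' insofar as they attend to overlapping keys:
    cosine similarity of their score vectors over all candidate key directions."""
    W_QK = W_Q.T @ W_K
    s_i = d_all @ (W_QK.T @ d_i)                   # scores of every key for query feature i
    s_j = d_all @ (W_QK.T @ d_j)
    return float(s_i @ s_j / (np.linalg.norm(s_i) * np.linalg.norm(s_j) + eps))
```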
Applications
How does measuring interactions help? There are several ideas that we’d love to test out:
Model diffing
In Stage-Wise Model Diffing, Bricken et al. propose a procedure for detecting what features change the most when a model is fine-tuned. We can build on this method to answer the questions:
-
Which features have the strongest interactions with the features that change under fine-tuning? Broadly, this provides an automated method for understanding the algorithms that models implement in order to learn the new objective.
-
Which feature interactions change the most during fine-tuning, even or especially when the features themselves do not change? While simple behavioral changes are likely to show up in the features themselves, more subtle or complicated changes may show up only or primarily in the ways features interact. Broadly, this allows us to make our detection and control methods more robust to adversarial pressure (see the sketch below).
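A minimal sketch of the second question, ranking feature pairs by how much their interaction strength changes between the base and fine-tuned models (names are illustrative):

```python
import numpy as np

def rank_interaction_changes(interactions_base, interactions_finetuned, top_k=50):
    """Rank feature pairs by how much their interaction strength changes under
    fine-tuning. Inputs: (n_features, n_features) matrices of interaction strengths
    computed with any metric in this post, on the base and fine-tuned models."""
    delta = np.abs(interactions_finetuned - interactions_base)
    n = delta.shape[0]
    pairs = [(i, j, float(delta[i, j])) for i in range(n) for j in range(i + 1, n)]
    pairs.sort(key=lambda r: r[2], reverse=True)   # largest interaction changes first
    return pairs[:top_k]
```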
Describe the model mechanisms in more detail
We can use the ranked list of (feature, feature, location) triplets as a guide for which parts of the model have the most unexplained dimensionality. We might be able to visualize the model’s circuits just by color-coding features and making a plot of feature interaction strength vs location in the model. Potentially, we will develop the mathematics necessary to find even cleaner abstractions of model computation.
There are probably a number of clusters of similar sorts of nonlinear feature interactions. For example, we might expect to see:
- some features that interact only by worst-case (proof) metrics, not by probabilistic metrics, or vice versa
- features whose encoder (and/or decoder) have high cosine similarity and are therefore likely to interact (we may see this in feature splitting, for example)
- some features that are “obviously supposed to interact” and do so in simple ways (for example, a “UK English vs US English” feature might interact strongly with features for expressions that differ between these dialects, but only very early and/or very late in the model)
Suppress undesired interactions
We can edit or fine-tune a language model to remove undesired feature interactions.
-
We can ablate the overlap of the two features entirely, assigning each dimension to one of the two interacting features and suppressing the other one (see the sketch after this list).
-
Alternatively, we can use the feature interaction metric as a penalty in the loss function for the specific features of interest, and fine-tune the language model to suppress their interaction.
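A minimal sketch of the first option, reading "assign each dimension to one of the two interacting features" elementwise over the decoder directions:

```python
import numpy as np

def ablate_feature_overlap(d_i, d_j):
    """For each residual-stream dimension where both decoder directions have weight,
    keep only the feature with the larger absolute weight and zero out the other."""
    d_i, d_j = d_i.copy(), d_j.copy()
    i_wins = np.abs(d_i) >= np.abs(d_j)
    d_j[i_wins] = 0.0       # suppress feature j where feature i dominates
    d_i[~i_wins] = 0.0      # and vice versa
    return d_i, d_j
```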
Adversarial examples
Some pairs of features might co-occur very rarely (or never), but have strong interactions; these features can perhaps be leveraged to generate adversarial examples, especially if the lack of co-occurrence is an artifact of the dataset.
Training objective of the crosscoder
The methods for measuring feature interactions can be added as penalties to the crosscoder training loss function. While adding this penalty with a large coefficient would probably distort the features and increase reconstruction error significantly, a weak penalty should encourage the crosscoder to eliminate spurious feature interactions.
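A sketch of what such a penalty might look like in the training objective; the loss form, coefficients, and the `interaction_metric` callable are illustrative assumptions:

```python
import torch

def crosscoder_loss_with_interaction_penalty(crosscoder, acts, interaction_metric,
                                              l1_coeff=1e-3, interaction_coeff=1e-4):
    """Standard reconstruction + sparsity loss, plus a weak penalty on a
    differentiable feature-interaction metric (names and coefficients illustrative)."""
    feats = crosscoder.encode(acts)
    recon = crosscoder.decode(feats)
    recon_loss = (recon - acts).pow(2).mean()
    sparsity_loss = feats.abs().mean()
    interaction_loss = interaction_metric(crosscoder, feats)  # hypothetical soft metric
    return recon_loss + l1_coeff * sparsity_loss + interaction_coeff * interaction_loss
```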
Discussion
Anthropic’s Sparse Crosscoders for Cross-Layer Features and Model Diffing post closes with some interesting questions. We respond to these questions from the compact proofs frame.
Crosscoders vs. SAEs
From the compact proofs frame, crosscoders seem clearly much better than SAEs in the sense that they should get at features the model actually uses much more cleanly. The above methodology for generating proofs does not work for SAEs without a correspondingly good story for SAE circuits, and even with circuits it’s not entirely clear how to either validate and leverage, or do away with, the linear representation hypothesis or anything similar to the linear interaction assumption.
The encoder and the decoder should not be treated symmetrically
Information content
In the proof, we must store the encoded dataset, and we must perform a computation using the decoder. This means that when measuring information, measuring entropy of the entries of the decoder matrix is fine, but for the encoder, we should instead consider the entropy of the encoded dataset.
Meaning: we don't need causal crosscoders
This asymmetry suggests that the encoder is relevant only to the quality of the features, and the interpretation / meaning of the features lives in the decoder. That is, using causal crosscoders or other fancier schemes may give us more accurate features, but they do not change the interpretation of the features we get, except inasmuch as they change how features interact.
Mechanistic faithfulness and measuring the richness of the crosscoder explanation
In the compact proofs approach, we measure the richness of an explanation by the length of the proof. A crosscoder-based proof has overall length approximately $n_{\mathrm{tok}} \cdot k \cdot d_{\mathrm{model}} + F^2 \cdot n_{\mathrm{layers}} \cdot d_{\mathrm{model}}^2$, where $k$ is the (L0) sparsity and $F$ is the crosscoder width (number of features).
More precisely, we have[9]
$$n_{\mathrm{tok}} \cdot k \cdot (\log_2 F + p) \;+\; n_{\mathrm{tok}} \cdot d_{\mathrm{vocab}} \cdot d_{\mathrm{model}} \;+\; n_{\mathrm{tok}} \cdot (k \cdot d_{\mathrm{model}} + d_{\mathrm{model}} \cdot d_{\mathrm{vocab}}) \;+\; F^2 \cdot n_{\mathrm{layers}} \cdot d_{\mathrm{model}}^2$$
or, collecting terms and taking $d_{\mathrm{vocab}} \gg k$ for the dataset-size-dependent terms,
$$n_{\mathrm{tok}} \cdot d_{\mathrm{vocab}} \cdot d_{\mathrm{model}} \;+\; F^2 \cdot n_{\mathrm{layers}} \cdot d_{\mathrm{model}}^2.$$
Note that the leading dataset-size-dependent asymptotic term here is the cost of embedding the dataset (in the regime where $d_{\mathrm{vocab}} \gg k$). If we include the pre-embed and post-unembed vectors in the stacked activations, we can avoid this term and the leading dataset-size-dependent asymptotic term becomes $n_{\mathrm{tok}} \cdot k \cdot d_{\mathrm{model}}$ (or $n_{\mathrm{tok}} \cdot k \cdot (d_{\mathrm{model}} + \log_2 F + p)$ if you want to be slightly more precise).
The leading asymptotic term that is independent of dataset-size is $F^2 \cdot n_{\mathrm{layers}} \cdot d_{\mathrm{model}}^2$, from the pairwise feature-interaction bound.
Recall that the brute-force “run inference on the dataset” baseline cost is roughly $n_{\mathrm{tok}} \cdot (n_{\mathrm{layers}} \cdot d_{\mathrm{model}}^2 + d_{\mathrm{model}} \cdot d_{\mathrm{vocab}})$.
We can imagine plotting a Pareto frontier of mechanistic faithfulness (tightness of bound, whether probabilistic or proofs-based) vs. cost of computation, and measuring the richness of crosscoder explanations by looking at how much these explanations improve on the baseline of running model inference.
Crosscoder errors measure violation of the linear representation hypothesis
While crosscoder “error features” allow for a kind of exact isomorphism with the original model, they take significantly more compute to account for than actual features. This is the formal analogue of these “error features” being potentially extremely difficult to interpret. One might see the technical content of this post as being centrally about addressing the question of interpreting crosscoder error features:
-
If the error features are large at a given compute budget, this suggests that there is no simple linear decomposition of the model’s representations at that level of abstraction. Insofar as the errors are large across a wide range of compute budgets, this would be evidence against the linear representation hypothesis.
-
If the error features are small but our bounds on them are loose, this suggests that features are interacting nonlinearly, that crosscoders fail to fully capture feature interaction, and we need to discover crosscoder circuits.
Acknowledgements
Thanks to Paul Christiano for pointing out that we can assume independence of features instead of noninteraction of features as a starting point. Thanks to Adria Garriga-Alonso, Neel Nanda, Louis Jaburi, and Kola Ayonrinde for discussion and comments on a draft of this post.
Citation
Please cite as:
Gross, Jason and Agrawal, Rajashree, “Measuring Nonlinear Feature Interactions in Sparse Crosscoders [Project Proposal]”, AI Alignment Forum, 2025.
BibTeX Citation:
@article{gross2025measuring,
title={Measuring Nonlinear Feature Interactions in Sparse Crosscoders [Project Proposal]},
author={Jason Gross and Rajashree Agrawal},
year={2025},
journal={AI Alignment Forum},
note={\url{https://www.alignmentforum.org/posts/RjrGAqJbk849Q7PHP/measuring-nonlinear-feature-interactions-in-sparse}}
}
This is beyond the scope of this post. ARC Theory is currently working on automating heuristic argument discovery following roughly this direction, though, and if you believe the SLT folks that symmetries define all interesting structure, then this sort of clustering may be essentially all you need to automate the remainder of the interpretability work. ↩︎
To make this post self-contained, we include in this footnote a review of the compact proofs approach: In the compact proofs approach, we write down a theorem saying that some particular model $M$ does what it does on the distribution of interest:
$$\mathbb{E}_{(x,y) \sim \mathcal{D}}\bigl[\ell(M(x), y)\bigr] \le b,$$
where $\mathcal{D}$ is the distribution of input-label pairs (such as The Pile or OpenWebText or the SCC training distribution), $\ell$ is the loss function, and $b$ is the upper bound on loss we wish to prove.
To prove this theorem, we first construct a function $V$ which computes a valid bound for any model weights. A proof of model performance then consists of a proof that $V$ gives valid upper bounds for all model weights together with a proof that on some particular model $M$ we have $V(M) \le b$ for some concrete $b$ (constructed by running $V$). Mathematically, we want to prove the two inequalities:
$$\forall M'.\ \mathbb{E}_{(x,y) \sim \mathcal{D}}\bigl[\ell(M'(x), y)\bigr] \le V(M') \qquad\text{and}\qquad V(M) \le b.$$
The first proof we do in the standard way (in LaTeX or a proof assistant or pen-and-paper), and the second proof consists of a transcript of running the program $V$ (which may be written in PyTorch, for example). In general, the length of the proof can be well-approximated (asymptotically) by the running time of $V$.
The baseline approach is to just run model inference on some fraction of the points in $D$. For the other points, we merely need to establish cheaply that the model never produces NaN nor $\pm\infty$. (I hope this can be done by bounds propagation and/or standard neural network verification tools, but we’ve not yet tested this approach for establishing the baseline for loss rather than accuracy.) The linear baseline is then just a line between the true loss of the model (which requires running inference on the full distribution) and the largest finite float (which should not cost more than a couple forward passes).
The compact proofs approach says that having a rich understanding consists of being able to get a better tradeoff between lossiness (tightness of bound ) and compression ratio (proof length or running time of ).
Note that in general there are two steps to any interpretation: finding the interpretation, and communicating (or presenting or validating) it. As is typical in estimating complexity, we don’t care how long it takes to find an interpretation, only how long it takes to communicate it. (In our case, we “communicate” the interpretation to a proof verifier.) ↩︎
As far as proof length — measuring the richness of understanding — is concerned, the computational cost of encoding the dataset is irrelevant; the proof strategy is valid for any feature vector, and incorrectly encoded feature vectors will simply result in loose bounds. ↩︎
We may divide the computation up into clusters of datapoints that behave similarly, and compute the error separately on each cluster. The brute force approach has one datapoint per cluster, and the error is tight. The most compact proof we can get without additional insight over the crosscoder has a single cluster for all datapoints. ↩︎
Note that if there’s a location where the error estimate is bad, we can insert more intermediate points into the cross-coder training — I imagine that every post-nonlinearity would be useful, as might the $q$, $k$, and $v$ vectors. If we insert enough intermediate points, I expect the recursive estimate to be relatively tight. ↩︎
Note that if we relax from worst-case to average case and go for probabilistic bounds, we can assume independence of features and independence of errors and compute a binned distribution of features which we push through each of the layers in sequence, giving something like $n_{\mathrm{layers}} \cdot F \cdot n_{\mathrm{bins}} \cdot d_{\mathrm{model}}$ where $n_{\mathrm{bins}}$ is the number of bins. ↩︎
This is why I am so excited about crosscoders: for the first time, the explanation captured by an automated interpretability procedure is close enough to a proof that we can write out an almost reasonable proof with a clear picture of what holes remain to be filled, rather than having most of the proof be unspecified! ↩︎
ARC Theory has shown how to do this in the context of VAEs. ↩︎
We take $p$ to also denote the asymptotic cost of multiplying $p$-bit floating-point numbers. ↩︎