[PAPER] Jacobian Sparse Autoencoders: Sparsify Computations, Not Just Activations

post by Lucy Farnik (lucy.fa) · 2025-02-26T12:50:04.204Z

Contents

  Summary of the paper
    TLDR
    Why we care about computational sparsity
    Jacobians ≈ computational sparsity
    Core results
  How this fits into a broader mech interp landscape
    JSAEs as hypothesis testing
    JSAEs as a (necessary?) step towards solving interp
  Call for collaborators
  In-the-weeds tips for training JSAEs

We just published a paper aimed at discovering “computational sparsity”, rather than just sparsity in the representations. In it, we propose a new architecture, Jacobian sparse autoencoders (JSAEs), which induces sparsity in both computations and representations. CLICK HERE TO READ THE FULL PAPER.

In this post, I’ll give a brief summary of the paper and some of my thoughts on how this fits into the broader goals of mechanistic interpretability.

Summary of the paper

TLDR

We want the computational graph corresponding to LLMs to be sparse (i.e. have a relatively small number of edges). We developed a method for doing this at scale. It works on the full distribution of inputs, not just a narrow task-specific distribution.

Why we care about computational sparsity

It’s pretty common to think of LLMs in terms of computational graphs. In order to make this computational graph interpretable, we broadly want two things: we want each node to be interpretable and monosemantic, and we want there to be relatively few edges connecting the nodes.

To illustrate why this matters, think of each node as corresponding to a little, simple Python function, which outputs a single monosemantic variable. If you have a Python function which takes tens of thousands of arguments (and actually uses all of them to compute its output), you’re probably gonna have a bad time trying to figure out what on Earth it’s doing. Meanwhile, if your Python function takes 5 arguments, understanding it will be much easier. This is what we mean by computational sparsity.

More specifically in our paper, we “sparsify” the computation performed by the MLP layers in LLMs. We do this by training a pair of SAEs — the first SAE is trained on the MLP’s inputs, the second SAE is trained on the MLP’s outputs. We then treat the SAE features as the nodes in our computational graph, and we want there to be as few edges connecting them as possible without damaging reconstruction quality.

A diagram illustrating our setup. The function $f$ which we sparsify is the composition of the TopK activation function of the first (input) SAE, the decoder of the first SAE, the MLP, and the encoder of the second (output) SAE, i.e. $f = \mathrm{enc}_2 \circ \mathrm{MLP} \circ \mathrm{dec}_1 \circ \mathrm{TopK}_1$. We note that the TopK activation function is included for computational efficiency only; see Section 4.2 of the paper for details.
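To make the composition concrete, here is a rough PyTorch sketch of the two SAEs and the function $f$. Everything here (the `TopKSAE` class, the toy sizes, and the stand-in `mlp` module) is an illustrative assumption, not our actual training code:

```python
import torch
import torch.nn as nn

class TopKSAE(nn.Module):
    """A bare-bones TopK SAE -- an illustrative stand-in, not our actual code."""
    def __init__(self, d_model: int, d_sae: int, k: int):
        super().__init__()
        self.enc = nn.Linear(d_model, d_sae)
        self.dec = nn.Linear(d_sae, d_model)
        self.k = k

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        pre = self.enc(x)
        # TopK activation: keep the k largest pre-activations, zero out the rest
        vals, idx = pre.topk(self.k, dim=-1)
        return torch.zeros_like(pre).scatter(-1, idx, vals)

    def decode(self, s: torch.Tensor) -> torch.Tensor:
        return self.dec(s)

d_model, d_sae, k = 64, 512, 16            # tiny toy sizes (real SAEs are much wider)
sae_in = TopKSAE(d_model, d_sae, k)        # trained on the MLP's inputs
sae_out = TopKSAE(d_model, d_sae, k)       # trained on the MLP's outputs
mlp = nn.Sequential(                       # stand-in for the frozen LLM MLP layer
    nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
)

def f(s1: torch.Tensor) -> torch.Tensor:
    """Input-SAE features -> reconstructed MLP input -> MLP -> output-SAE features.
    This is the map whose edges (Jacobian entries) we want to be sparse."""
    return sae_out.encode(mlp(sae_in.decode(s1)))
```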

Jacobians ≈ computational sparsity

We approximate computational sparsity as the sparsity of the Jacobian. This approximation, like all approximations, is imperfect, but it is highly accurate in our case because the function we are sparsifying is approximately linear (see the paper for details).
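To spell out the object we penalize (notation here is mine and slightly informal): write $s$ for the input-SAE feature vector and $f(s)$ for the resulting output-SAE feature vector. The Jacobian is the matrix

$$J_{ij}(s) = \frac{\partial f_i(s)}{\partial s_j},$$

and $J_{ij}(s) \approx 0$ means that output feature $i$ is (locally) unaffected by input feature $j$, i.e. there is no edge between those two nodes. A Jacobian with few large entries therefore corresponds to a computational graph with few edges.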

Importantly, we can actually compute the Jacobian at scale. This required some thought: naively, the Jacobian is a $d_{\mathrm{SAE}} \times d_{\mathrm{SAE}}$ matrix for every token, which would mean trillions of elements even for very small LLMs and would lead to out-of-memory errors. But we came up with a few tricks to do this really efficiently (again, see the paper for details). With our setup, training a pair of JSAEs takes only about twice as long as training a single standard SAE.

We compute the Jacobians at each training step and add an L1 sparsity penalty on the flattened Jacobian to our loss function.
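As a rough illustration of what that loss term looks like, here is a naive version continuing the hypothetical sketch above. It uses PyTorch's torch.func to materialize the dense per-token Jacobians directly, rather than our efficient tricks, and the reconstruction losses and coefficient below are placeholders:

```python
import torch
from torch.func import jacrev, vmap

# x_batch: a batch of MLP inputs, shape [batch, d_model] (placeholder data here)
x_batch = torch.randn(8, d_model)
s1_batch = sae_in.encode(x_batch)          # sparse input-SAE codes

# Naive dense per-token Jacobian of f, shape [batch, d_sae, d_sae].
# At realistic widths this is the object that blows up memory; the paper's
# tricks avoid materialising it in full, but the penalty is the same in spirit.
jac = vmap(jacrev(f))(s1_batch)

# L1 sparsity penalty on the flattened Jacobian ...
jac_l1 = jac.abs().sum(dim=(-2, -1)).mean()

# ... added to the usual SAE reconstruction losses
mlp_out = mlp(x_batch)
recon_in = (sae_in.decode(s1_batch) - x_batch).pow(2).mean()
recon_out = (sae_out.decode(sae_out.encode(mlp_out)) - mlp_out).pow(2).mean()
lambda_jac = 1.0                           # made-up sparsity coefficient
loss = recon_in + recon_out + lambda_jac * jac_l1
```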

Core results

How this fits into a broader mech interp landscape

(Epistemic status of this section: speculations backed by intuitions, empirical evidence for some of them is limited/nonexistent)

JSAEs as hypothesis testing

If we ever want to properly understand what LLMs are doing under the hood, I claim that we need to find a “lens” for looking at them which induces computational sparsity.

Here is my reasoning behind this claim. It rests on a fuzzy intuition about “meaning” and “understanding”, but it’s hopefully not too controversial: for something to be understandable, it needs to be decomposable. You need to be able to break it down into small chunks, and you might need to understand how each chunk relates to a few other parts of the system. But if understanding each chunk requires understanding how it relates to 100,000 other chunks, and there is no simpler way to understand it, you’re probably doomed.

“Just automate mech interp research with GPT-6” doesn’t solve this. If there doesn’t exist any lens in which the connections between the different parts are sparse, what would GPT-6 give you? It would only be able to give you an explanation (whether in natural language, computer code, or equations) that has these very dense semantic connections between different concepts, and my intuition is that this would be effectively impossible to understand (at least for humans). Again, imagine a Python function with 10,000 arguments, all of which are used in the computation, which cannot be rewritten in a simpler form.

Either we can find a lens in which the semantic graph is sparse, or interpretability is almost certainly doomed. One way to read our paper, then, is as a hypothesis test: if computational sparsity is necessary (though not sufficient) for mech interp to succeed, does it actually exist? We find that it does, at least in MLPs.

JSAEs as a (necessary?) step towards solving interp

In order to solve interpretability, we need to solve a couple of sub-problems. One of them is disentangling activations into monosemantic representations, which are hopefully human-interpretable. Another one is finding a lens using those monosemantic representations which has sparsity in the graph relating these representations to one another. SAEs aim to solve the former, JSAEs aim to solve the latter.

On a conceptual level, JSAEs take a black box (the MLP layer) and break it down into tons of tiny black boxes. Each tiny black box takes as input a small handful of input features (which are mostly monosemantic) and returns a single output feature (which is likely also monosemantic).

Of course, JSAEs are almost certainly not perfect for this. Just like SAEs, they aim to approximate a fundamentally messy thing by using sparsity as a proxy for meaningfulness. JSAEs likely share many of the open problems of SAEs (e.g. finding the complete set of learned representations, including the ones not represented in the dataset used to train the SAE, which is especially important for addressing deceptive alignment). In addition to that, they probably add other axes along which things can get messy: maybe the computational graphs found by JSAEs miss some important causal connections, because the sparsity penalty probably doesn’t perfectly incentivize what we actually care about. That is one exciting space for follow-up work: what do good evals for JSAEs look like? How do you determine whether the sparse causal graph JSAEs find is “the true computational graph”, and what proxy metrics could we use here?

Call for collaborators

If you’re excited about JSAEs and want to do some follow-up work on this, I’d love to hear from you! Right now, as far as I’m aware, I’m the only person doing hands-on work in this space full-time. JSAEs (and ‘computational sparsity discovery’ in general) are severely talent-constrained. There are tons of things I would’ve loved to do in this paper if there had been more engineering-hours.

If you’re new to AI safety, I’d be excited to mentor you if we have personal fit! If you’re a bit more experienced, we could collaborate, or you could just bounce follow-up paper ideas off of me. Either way, feel free to drop me an email or DM me on Twitter.

Relatedly, I have a bunch of ideas for follow-up projects on computational sparsity that I’d be excited about. If you’d like to increase the probability of me writing up a “Concrete open problems in computational sparsity” LessWrong post, let me know in the comments that you’re interested in seeing this.

In-the-weeds tips for training JSAEs

If you’re just reading this post to get a feel for what JSAEs are and what they mean for mech interp, I’d recommend skipping this section.
