Deep learning models might be secretly (almost) linear

post by beren · 2023-04-24T18:43:28.188Z · LW · GW · 29 comments

Contents

  Evidence for this:
  Why might networks actually be linear-ish?
29 comments

Crossposted from my personal blog.

Epistemic status: Pretty speculative, but there is a surprising amount of circumstantial evidence.

I have been increasingly thinking about NN representations and slowly coming to the conclusion that they are (almost) completely secretly linear inside[1]. This means that, theoretically, if we can understand their directions, we can very easily exert very powerful control on the internal representations, as well as compose and reason about them in a straightforward way. Finding linear directions for a given representation would allow us to arbitrarily amplify or remove it and interpolate along it as desired. We could also then directly 'mix' it with other representations as desired. Measuring these directions during inference would let us detect the degree of each feature that the network assigns to a given input. For instance, this might let us create internal 'lie detectors' (towards which there is some progress) which can tell whether the model is telling the truth or being deceptive. While nothing is super definitive (and clearly networks are not 100% linear), I think there is a large amount of fairly compelling circumstantial evidence for this position. Namely:

Evidence for this:

  1. All the work from way back when about interpolating through VAE/GAN latent space. I.e. in the latent space of a VAE trained on CelebA there are natural 'directions' for recognizable semantic features like 'wearing sunglasses' and 'long hair', and linear interpolations along these directions produce highly recognizable images.
  2. Rank 1 or low rank editing techniques such as ROME work so well (not perfectly but pretty well). These are effectively just emphasizing a linear direction in the weights.
  3. You can apparently add and mix LoRAs and it works about how you would expect.
  4. You can merge totally different models. People in the Stable Diffusion community literally merge model weights additively with a weighted sum and it works! (A sketch of this kind of weight arithmetic follows this list.)
  5. Logit lens [LW · GW] works.
  6. SVD directions [LW · GW].
  7. Linear probing definitely gives you a fair amount of signal.
  8. Linear mode connectivity and git rebasin.
  9. Colin Burns' unsupervised linear probing method works even for semantic features like 'truth'.
  10. You can merge together different models finetuned from the same initialization.
  11. You can do a moving average over model checkpoints and this improves performance and is better than any individual checkpoint!
  12. The fact that linear methods work pretty tolerably well in neuroscience.
  13. Various naive diff-based editing and control techniques work at all.
  14. General linear transformations between models and modalities.
  15. You can often remove specific concepts from the model by erasing a linear subspace.
  16. Task vectors (and, in general, things like finetune diffs being composable linearly).
  17. People keep [LW · GW] finding [LW · GW] linear representations inside of neural networks when doing interpretability or just randomly.
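Several of the points above (merging models, averaging checkpoints, task vectors) come down to simple arithmetic on weights. Here is a minimal sketch of what that looks like, assuming two PyTorch models finetuned from the same initialization; the function names and the `alpha`/`scale` parameters are illustrative, not any particular library's API:

```python
import copy

def merge_state_dicts(model_a, model_b, alpha=0.5):
    """Linearly interpolate two models finetuned from the same initialization:
    theta_merged = alpha * theta_a + (1 - alpha) * theta_b."""
    sd_a, sd_b = model_a.state_dict(), model_b.state_dict()
    merged = {k: alpha * sd_a[k] + (1 - alpha) * sd_b[k] for k in sd_a}
    merged_model = copy.deepcopy(model_a)
    merged_model.load_state_dict(merged)
    return merged_model

def apply_task_vector(base_model, finetuned_model, scale=1.0):
    """Treat the finetune diff (theta_ft - theta_base) as a 'task vector' and
    add a scaled copy of it back onto the base weights."""
    sd_base, sd_ft = base_model.state_dict(), finetuned_model.state_dict()
    edited = {k: sd_base[k] + scale * (sd_ft[k] - sd_base[k]) for k in sd_base}
    edited_model = copy.deepcopy(base_model)
    edited_model.load_state_dict(edited)
    return edited_model
```

If weight space really is linear-ish, `scale` can be pushed above 1 to amplify the finetuned behaviour or made negative to subtract it, which is roughly what the task-vector work reports.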

If this is true, then we should be able to achieve quite a high level of control and understanding of NNs solely by straightforward linear methods and interventions. This would mean that deep networks might end up being pretty understandable and controllable artefacts in the near future. It is just that, at this moment, we have not yet found the right levers (or rather, lots of existing work does show this, but it hasn't really been normalized or applied at scale for alignment). Linear-ish network representations are a best case scenario for both interpretability and control.
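As a concrete illustration of what such a linear intervention could look like: if a feature is a direction in activation space, then amplifying or suppressing it is just adding a scaled vector during the forward pass. A minimal sketch using a PyTorch forward hook; the choice of layer, the `direction` vector, and the `coeff` value are all placeholders for whatever a probe or activation-diffing method would actually supply:

```python
import torch

def make_steering_hook(direction: torch.Tensor, coeff: float):
    """Return a forward hook that adds coeff * (unit direction) to a layer's output.
    coeff > 0 amplifies the feature, coeff < 0 suppresses it, coeff = 0 is a no-op."""
    unit = direction / direction.norm()

    def hook(module, inputs, output):
        # Returning a value from a forward hook replaces the module's output.
        return output + coeff * unit

    return hook

# Hypothetical usage on a model whose blocks output plain tensors:
# handle = model.blocks[12].register_forward_hook(make_steering_hook(direction, coeff=5.0))
# ...run the model as usual, then handle.remove() to undo the intervention.
```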

For a mechanistic, circuits-level understanding, there is still the problem of superposition of the linear representations. However, if the representations are indeed mostly linear, then once superposition is solved there seem to be few other obstacles in the way of a complete mechanistic understanding of the network. Moreover, superposition is not even a problem for black-box linear methods for controlling and manipulating features, where the optimiser handles the superposition for you.

This hypothesis also gets at a set of intuitions I've slowly been developing. Basically, almost all of alignment thinking assumes that NNs are bad -- 'giant inscrutable matrices' -- and success looks like fighting against the NN. This can either be through minimizing the amount of the system that is NN-based, surrounding the NN with monitoring and various other schemes, or by interpreting their internals and trying to find human-understandable circuits inside. I feel like this approach is misguided and makes the problem way more difficult than it needs to be. Instead we should be working with the NNs. Actual NNs appear to be very far from the maximally bad case and appear to possess a number of very convenient properties -- including this seeming linearity -- that we should be exploiting rather than ignoring. Especially if this hypothesis is true, there is a great deal of control we can get just by applying black-boxish methods to the right levers. If there is a prevailing linearity, then this should make a number of interpretability methods much more straightforward as well. Solving superposition might just resolve a large degree of the entire problem of interpretability. We may actually be surprisingly close to success at automated interpretability.

Why might networks actually be linear-ish?

1.  Natural abstractions hypothesis. Most abstractions are naturally linear and compositional in some sense (why?).

2.  NNs or SGD has strong Occam's razor priors towards simplicity and linear = simple.

3. Linear and compositional representations are very good for generalisation and compression which becomes increasingly important for underfit networks on large and highly varied natural datasets. This is similar in spirit to the way that biology evolves to be modular [LW · GW].

4.  Architectural evolution. Strongly nonlinear functions are extremely hard to learn with SGD due to poor conditioning. Linear functions are naturally easier to learn and find with SGD. Our networks use almost-linear nonlinearities such as ReLU/GeLU, which strongly encourages the formation of nearly-linear representations.

5.  Some NTK-like theory. Specifically, as NNs get larger, they move less from their initial condition to the solution, so we can increasingly approximate them with first-order Taylor expansions (written out after this list). If the default 'representations' are linear and Gaussian due to the initialisation of the network, then perhaps SGD just finds solutions very close to the initialisation which preserve most of the properties.

6.  Our brains can only really perceive linear features, so everything we successfully observe in NNs is linear too, and we just miss all the massively nonlinear stuff. This is the anthropic argument and would be the failure case: the nonlinear stuff is exactly what we miss, and that is where the danger lies. Also, if we are applying any implicit selection pressure to the model -- for instance, optimising against interpretability tools -- then this might push dangerous behaviour into nonlinear representations to evade our sensors.
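For point 5, the NTK-style picture can be stated concretely: if training keeps the parameters close to their initialisation $\theta_0$, the network is well approximated by its first-order Taylor expansion in the parameters, which is a linear model over a fixed set of features (the parameter-gradients at initialisation):

$$f(x; \theta) \;\approx\; f(x; \theta_0) + \nabla_\theta f(x; \theta_0)^\top (\theta - \theta_0)$$

Note this is linearity in the weight change $\theta - \theta_0$, not in the input $x$, which is the sense of 'linearity' this point appeals to.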

 

  1. ^

    Of course the actual function the network implements cannot be completely linear, otherwise we would just be doing a glorified (and expensive) linear regression.

29 comments

Comments sorted by top scores.

comment by Quintin Pope (quintin-pope) · 2023-04-24T22:39:01.856Z · LW(p) · GW(p)

Some counter evidence:

  • Kernelized Concept Erasure: concept encodings do have nonlinear components. Nonlinear kernels can erase certain parts of those encodings, but they cannot prevent other types of nonlinear kernels from extracting concept info from other parts of the embedding space.
  • Limitations of the NTK for Understanding Generalization in Deep Learning: the neural tangent kernels of realistic neural networks continuously change throughout their training. Further, neither the initial kernels nor any of the empirical kernels from mid-training can reproduce the asymptotic scaling laws of the actual neural network, which are better than predicted by said kernels.
  • Mechanistic Mode Connectivity: LMs often have non-connected solution basins, which correspond to different underlying mechanisms by which they make their classification decisions.
Replies from: beren
comment by beren · 2023-04-25T12:25:27.625Z · LW(p) · GW(p)

Thanks for these links! This is exactly what I was looking for as per Cunningham's law. For the mechanistic mode connectivity, I still need to read the paper, but there is definitely a more complex story relating to the symmetries rendering things non-connected by default; once you account for symmetries and project things into an isometric space where all the symmetries are collapsed, things become connected and linear again. Is this different to that?

 

I agree about the NTK. I think this explanation is bad in its specifics, although I think the NTK does give useful explanations at a very coarse level of granularity. In general, to put a completely uncalibrated number on it, I feel like NNs are probably '90% linear' in their feature representations. Of course they have to have somewhat nonlinear representations as well. But otoh, if we could get 90% of the way to features, that would be massive progress and might be relatively easy.

Replies from: sharmake-farah
comment by Noosphere89 (sharmake-farah) · 2023-04-26T15:45:06.073Z · LW(p) · GW(p)

One other problem of NTK/GP theory is that it isn't able to capture feature learning/transfer learning, and in general starts to break down as models get more complicated. In essence, NTK/GP fails to capture some empirical realities.

From the post "NTK/GP Models of Neural Nets Can't Learn Features":

Since people are talking about the NTK/GP hypothesis of neural nets again, I thought it might be worth bringing up some recent research in the area that casts doubt on their explanatory power. The upshot is: NTK/GP models of neural networks can't learn features. By 'feature learning' I mean the process where intermediate neurons come to represent task-relevant features such as curves, elements of grammar, or cats. Closely related to feature learning is transfer learning, the typical practice whereby a neural net is trained on one task, then 'fine-tuned' with a lower learning rate to fit another task, usually with less data than the first. This is often a powerful way to approach learning in the low-data regime, but NTK/GP models can't do it at all.

The reason for this is pretty simple. During training on the 'old task', NTK stays in the 'tangent space' of the network's initialization. This means that, to first order, none of the functions/derivatives computed by the individual neurons change at all; only the output function does.[1] Feature learning requires the intermediate neurons to adapt to structures in the data that are relevant to the task being learned, but in the NTK limit the intermediate neurons' functions don't change at all. Any meaningful function like a 'car detector' would need to be there at initialization -- extremely unlikely for functions of any complexity. This lack of feature learning implies a lack of meaningful transfer learning as well: since the NTK is just doing linear regression using an (infinite) fixed set of functions, the only 'transfer' that can occur is shifting where the regression starts in this space. This could potentially speed up convergence, but it wouldn't provide any benefits in terms of representation efficiency for tasks with few data points[2]. This property holds for the GP limit as well -- the distribution of functions computed by intermediate neurons doesn't change after conditioning on the outputs, so networks sampled from the GP posterior wouldn't be useful for transfer learning either.

This also makes me skeptical of the Mingard et al. result about SGD being equivalent to picking a random neural net with given performance, given that picking a random net is equivalent to running a GP regression in the wide-width limit. In particular, it makes me skeptical that this result will generalize to the complex models and tasks we care about. 'GP/NTK performs similarly to SGD on simple tasks' has been found before, but it tends to break down as the tasks become more complex.[3]

In essence, NTK/GP can't transfer-learn because it stays in the tangent space of its initialization, and this doesn't change even in the NTK limit.

A link to the post is below:

https://www.lesswrong.com/posts/76cReK4Mix3zKCWNT/ntk-gp-models-of-neural-nets-can-t-learn-features [LW · GW]

comment by Steven Byrnes (steve2152) · 2023-04-24T19:16:38.130Z · LW(p) · GW(p)

Sorry for stupid question but when you say $f$ is “(almost) linear”, I presume that means $f(X+Y) \approx f(X) + f(Y)$ and $f(aX) \approx af(X)$. But what are X & Y here? Activations? Weights? Weight-changes-versus-initialization? Some-kind-of-semantic-meaning-encoded-as-a-high-dimensional-vector-of-pattern-matches-against-every-possible-abstract-concept? More than one of the above? Something else?

Replies from: Jon Garcia
comment by Jon Garcia · 2023-04-24T23:25:15.035Z · LW(p) · GW(p)

For an image-classification network, if we remove the softmax nonlinearity from the very end, then $X$ would represent the input image in pixel space, and $Y = f(X)$ would represent the class logits. Then $f(X_1 + X_2) = f(X_1) + f(X_2)$ would represent an image with two objects leading to an ambiguous classification (high log-probability for both classes), and $f(aX) = af(X)$ would represent higher class certainty (softmax temperature $= 1/a$) when the image has higher contrast. I guess that kind of makes sense, but yeah, I think for real neural networks, this will only be linear-ish at best.

Replies from: steve2152
comment by Steven Byrnes (steve2152) · 2023-04-25T01:48:53.732Z · LW(p) · GW(p)

Well, averaging / adding two images in pixel space usually gives a thing that looks like two semi-transparent images overlaid, as opposed to “an image with two objects”.

Replies from: Jon Garcia
comment by Jon Garcia · 2023-04-25T03:37:00.194Z · LW(p) · GW(p)

If both images have the main object near the middle of the image or taking up most of the space (which is usually the case for single-class photos taken by humans), then yes. Otherwise, summing two images with small, off-center items will just look like a low-contrast, noisy image of two items.

Either way, though, I would expect this to result in class-label ambiguity. However, in some cases of semi-transparent-object-overlay, the overlay may end up mixing features in such a jumbled way that neither of the "true" classes is discernible. This would be a case where the almost-linearity of the network breaks down.

Maybe this linearity story would work better for generative models, where adding latent vector representations of two different objects would lead the network to generate an image with both objects included (an image that would have an ambiguous class label to a second network). It would need to be tested whether this sort of thing happens by default (e.g., with Stable Diffusion) or whether I'm just making stuff up here.

Replies from: beren
comment by beren · 2023-04-25T12:54:01.625Z · LW(p) · GW(p)

Maybe this linearity story would work better for generative models, where adding latent vector representations of two different objects would lead the network to generate an image with both objects included (an image that would have an ambiguous class label to a second network). It would need to be tested whether this sort of thing happens by default (e.g., with Stable Diffusion) or whether I'm just making stuff up here.
 

Yes, this is exactly right. This is precisely the kind of linearity that I am talking about, not the input->output mapping, which is clearly nonlinear. The idea being that hidden inside the network is a linear latent space where we can perform linear operations and they (mostly) work. In the points of evidence in the post there is discussion of exactly this kind of latent space editing for Stable Diffusion. A nice example is this paper. Interestingly, this also works for fine-tuning weight diffs, e.g. for style transfer.

comment by johnswentworth · 2023-04-24T22:13:14.535Z · LW(p) · GW(p)

Natural abstractions hypothesis. Most abstractions are naturally linear and compositional in some sense (why?).

One of my main current hypotheses about natural abstractions is that natural summary statistics are approximately additive across subsystems. It's the same idea as "extensivity" in statistical physics, i.e. how energy and entropy are both approximately-additive across mesoscale subsystems. And it would occur for similar reasons: if not-too-close-together parts of the system are independent given some natural abstract latent variables, then we can break the system into a bunch of mesosize chunks with some space between each of them, ignore the relatively-small handful of variables in between the mesoscale chunks, and find that log probability of state is approximately additive across the chunks. That log probability is, in turn, "approximately a sufficient statistic" in some sense, because log likelihood is a universal sufficient statistic. So, we get an approximate sufficient statistic which is additive across the chunks.

... unfortunately the approximation is very loose, and more generally this whole argument dovetails with open questions about how to handle approximation for natural abstractions. So the math is not yet ready for prime time. But there is at least a qualitative argument for why we'd expect additivity across subsystems from natural abstractions.

My guess is that that rough argument is the main step in understanding why linearity seems to capture natural abstractions so well empirically.

Replies from: beren
comment by beren · 2023-04-25T13:09:06.172Z · LW(p) · GW(p)

I think this is a good intuition. I think this comes down to the natural structure of the graph and the fact that information disappears at larger distances. This means that for dense graphs such as lattices, regions only implicitly interact through much lower-dimensional max-ent variables, which are then additive; for other causal graph structures, such as the power-law small-world graphs that are probably sensible for many real-world datasets, you get a similar thing where each cluster can be modelled mostly independently apart from a few long-range interactions, which can themselves be modelled as interacting with some general 'cluster sum'. Interestingly, this is what many approximate Bayesian inference algorithms for factor graphs look like -- such as the region graph algorithm (http://pachecoj.com/courses/csc665-1/papers/Yedidia_GBP_InfoTheory05.pdf).

I definitely agree it would be really nice to have the math of this all properly worked out, as I think this -- as well as the reason why we see power-law spectra of features so often in natural datasets (which must have a max-ent explanation) -- is a super common and deep feature of the world.

comment by Ariel Kwiatkowski (ariel-kwiatkowski) · 2023-04-24T20:58:33.434Z · LW(p) · GW(p)

Isn't this extremely easy to directly verify empirically? 

Take a neural network $f$ trained on some standard task, like ImageNet or something. Evaluate $|f(kx) - kf(x)|$ on a bunch of samples $x$ from the dataset, and $f(x+y) - f(x) - f(y)$ on samples $x, y$. If it's "almost linear", then the difference should be very small on average. I'm not sure right now how to define "very small", but you could compare it e.g. to the distance distribution $|f(x) - f(y)|$ of independent samples, also depending on what the head is.
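A minimal sketch of the test being proposed, assuming a pretrained torchvision classifier whose forward pass returns pre-softmax logits (the model choice and the random stand-in inputs are purely illustrative; real dataset samples would go in their place):

```python
import torch
from torchvision.models import resnet18, ResNet18_Weights

model = resnet18(weights=ResNet18_Weights.DEFAULT).eval()

def linearity_gaps(x, y, k=2.0):
    """How far f is from homogeneity (f(kx) = k f(x)) and additivity
    (f(x+y) = f(x) + f(y)), compared against a typical logit distance."""
    with torch.no_grad():
        fx, fy = model(x), model(y)
        homogeneity_gap = (model(k * x) - k * fx).norm(dim=-1)
        additivity_gap = (model(x + y) - fx - fy).norm(dim=-1)
        baseline = (fx - fy).norm(dim=-1)  # scale reference for "very small"
    return homogeneity_gap, additivity_gap, baseline

x = torch.randn(8, 3, 224, 224)  # stand-ins for dataset samples
y = torch.randn(8, 3, 224, 224)
print(linearity_gaps(x, y))
```

Note that, as the replies below point out, this checks literal input->output linearity rather than linearity of internal feature directions.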

FWIW my opinion is that all this "circumstantial evidence" is a big non sequitur, and the base statement is fundamentally wrong. But it seems like such an easily testable hypothesis that it's more effort to discuss it than actually verify it.

Replies from: TurnTrout
comment by TurnTrout · 2023-04-25T01:04:18.886Z · LW(p) · GW(p)

At least how I would put this -- I don't think the important part is that NNs are literally almost linear, when viewed as input-output functions. More like, they have linearly represented features (i.e. directions in activation space, either in the network as a whole or at a fixed layer), or there are other important linear statistics of their weights (linear mode connectivity) or activations (linear probing). 

Maybe beren can clarify what they had in mind, though.

Replies from: beren
comment by beren · 2023-04-25T12:26:44.537Z · LW(p) · GW(p)

Yes. The idea is that the latent space of the neural network's 'features' is 'almost linear', which is reflected in both the linear-ish properties of the weights and activations. Not that the literal I/O mapping of the NN is linear, which is clearly false.

 

More concretely, as an oversimplified version of what I am saying, it might be possible to think of neural networks as a combined encoder and decoder to a linear vector space. I.e. we have nonlinear functions f and g, where f encodes the input x to a latent space z and g decodes it to the output y -- i.e. f(x) = z and g(z) = y. We then hypothesise that the latent space z is approximately linear, such that we can perform addition and weighted sums of zs, as well as scaling individual directions in z, and these get decoded to the appropriate outputs, which correspond to sums or scalings of the 'natural' semantic features we should expect in the input or output.
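A minimal sketch of this encoder/decoder picture, assuming hypothetical `encode` / `decode` callables around whatever latent space is under study (a VAE, a diffusion model's latents, or a chosen hidden layer); none of these names refer to a specific library:

```python
import torch

def mix_in_latent_space(encode, decode, x1, x2, w1=0.5, w2=0.5):
    """Encode two inputs, take a weighted sum in the latent space z, and decode.
    If z is 'almost linear', the output should combine the semantic features
    of both inputs."""
    z1, z2 = encode(x1), encode(x2)
    return decode(w1 * z1 + w2 * z2)

def scale_direction(encode, decode, x, direction: torch.Tensor, alpha: float):
    """Amplify (alpha > 1) or suppress (alpha < 1) one semantic direction in z."""
    z = encode(x)
    unit = direction / direction.norm()
    component = (z * unit).sum() * unit  # projection of z onto the direction
    return decode(z + (alpha - 1.0) * component)
```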

comment by Hoagy · 2023-10-04T11:15:22.077Z · LW(p) · GW(p)

Reposting from a shortform post but I've been thinking about a possible additional argument that networks end up linear that I'd like some feedback on:

the tldr is that overcomplete bases necessitate linear representations

  • Neural networks use overcomplete bases to represent concepts. Especially in vector spaces without non-linearity, such as the transformer's residual stream, there are just many more things stored in there than there are dimensions, and as the Johnson-Lindenstrauss lemma shows, there are exponentially many almost-orthogonal directions to store them in (of course, we can't assume that they're stored linearly as directions, but if they were then there's lots of space -- see the numerical sketch at the end of this comment). (see also Toy models of transformers, sparse [LW · GW] coding work)
  • Many different concepts may be active at once, and the model's ability to read a representation needs to be robust to this kind of interference.
  • Highly non-linear information storage is going to be very fragile to interference because, by the definition of non-linearity, the model will respond differently to the input depending on the existing level of that feature. For example, if the response is quadratic or higher in the feature direction, then the impact of turning that feature on will be much different depending on whether certain not-quite orthogonal features are also on. If feature spaces are somehow curved then they will be similarly sensitive.

Of course, linear representations will still be sensitive to this kind of interference, but I suspect there's a mathematical proof for why linear features are the most robust way to represent information in this kind of situation; I'm just not sure where to look for existing work or how to start trying to prove it.
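A quick numerical illustration of the near-orthogonality point above, assuming nothing beyond NumPy: sample many more random unit vectors than there are dimensions and look at how small the pairwise cosine similarities stay.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 512, 5000  # dimensionality of the space vs. number of stored "feature" directions

# Random unit vectors as stand-ins for linearly represented features.
vectors = rng.standard_normal((n, d))
vectors /= np.linalg.norm(vectors, axis=1, keepdims=True)

cosines = vectors @ vectors.T
np.fill_diagonal(cosines, 0.0)

# Even with ~10x more directions than dimensions, the worst-case interference
# between any two of them stays far below 1 (roughly 0.2-0.3 for these sizes),
# illustrating how many almost-orthogonal directions a high-dimensional space holds.
print("max |cosine similarity| between distinct directions:", np.abs(cosines).max())
```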

Replies from: beren
comment by beren · 2023-10-04T20:13:02.509Z · LW(p) · GW(p)

This is an interesting idea. I feel this also has to be related to increasing linearity with scale and generalization ability -- i.e. if you have a memorised solution, then nonlinear representations are fine, because you can easily tune the 'boundaries' of the nonlinear representation to precisely delineate the datapoints (in fact, the nonlinearity of the representation can be used to strongly reduce interference when memorising, as is done in the recent research on modern Hopfield networks). On the other hand, if you require a kind of reasonably large-scale smoothness of the solution space, as you would expect from a generalising solution in a flat basin, then this cannot work and you need to accept interference between nearly orthogonal features as the cost of preserving generalisation of the behaviour across many different inputs which activate the same vector.

Replies from: Hoagy
comment by Hoagy · 2023-10-06T11:06:57.335Z · LW(p) · GW(p)

Yes, that makes a lot of sense, that linearity would come hand in hand with generalization. I'd recently been reading Krotov on non-linear Hopfield networks but hadn't made the connection. They say that they're planning on using them to create more theoretically grounded transformer architectures, and your comment makes me think that these wouldn't succeed -- but then the article also says:

This idea has been further extended in 2017 by showing that a careful choice of the activation function can even lead to an exponential memory storage capacity. Importantly, the study also demonstrated that dense associative memory, like the traditional Hopfield network, has large basins of attraction of size O(Nf). This means that the new model continues to benefit from strong associative properties despite the dense packing of memories inside the feature space.

which perhaps corresponds to them also being able to find good linear representations and to mix generalization and memorization like a transformer?

comment by simeon_c (WayZ) · 2023-04-26T08:23:24.497Z · LW(p) · GW(p)

@beren [LW · GW] in this post [AF(p) · GW(p)], we find that our method (Causal Direction Extraction) allows capturing a lot of the gender difference with 2 dimensions in a linearly separable way. Skimming that post might be of interest to you and your hypothesis.

In the same post though, we suggest that it's unclear how much logit lens "works", to the extent that the direction which best encodes a given concept likely changes by a small angle at each layer, which causes two directions that best encode a concept 15 layers apart to have a cosine similarity < 0.5.

But what seems plausible to me is that basically ~all of the information relevant to a feature is encoded in a very small number of directions, which are slightly different for each layer.

comment by Bogdan Ionut Cirstea (bogdan-ionut-cirstea) · 2023-04-30T22:01:30.607Z · LW(p) · GW(p)

Another reason to expect approximate linearity in deep learning models: point 12 + arguments about approximate (linear) isomorphism between human and artificial representations (e.g. search for 'isomorph' in Understanding models understanding language and in Grounding the Vector Space of an Octopus: Word Meaning from Raw Text).

comment by Zach Furman (zfurman) · 2023-04-26T06:43:13.281Z · LW(p) · GW(p)

Great discussion here!

Leaving a meta-comment about priors: on one hand, almost-linear features seem very plausible (a priori) for almost-linear neural networks; on the other, linear algebra is probably the single mathematical tool I'd expect ML researchers to be incredibly well-versed in, and the fact that we haven't found a "smoking gun" at this point, with so much potential scrutiny, makes me suspicious.

And while this is a very natural hypothesis to test, and I'm excited for people to do so, it seems possible that the field's familiarity with linear methods is a hammer that makes everything look like a nail. It's easy to focus on linear interpretability because the alternative seems too hard (a response I often get) - I think this is wrong, and there are tractable directions in the nonlinear case too, as long as you're willing to go slightly further afield.

I also have some skepticism on the object-level here too, but it was taking me too long to write it up, so that will have to wait. I think this is definitely a topic worth spending more time on - appreciate the post!

comment by rotatingpaguro · 2023-04-24T19:15:16.008Z · LW(p) · GW(p)

Note: not an NN expert, I'm speculating out of my depth.

This is related to the Gaussian process limit of infinitely wide neural networks. Bayesian Gaussian process regression can be expressed as Bayesian linear regression; they are really the same thing mathematically. They mostly perform worse than finitely wide neural networks, but they work decently. The fact that they work at all suggests that linear stuff is indeed fine, and that even actual NNs may be working in a regime close to linearity in some sense.

A counterargument to linearity could be: since the activations are ReLUs, i.e., flat then linear, the NN is locally mostly linear, which allows weight averaging and linear interpretations, but then the global behavior is not linear. Then a counter-counter-argument is: is the nonlinear region actually useful?

Combining the hypotheses/interpretations 1) the GP limit works although it is worse 2) various local stuff based on linearity works, I make the guess that maybe the in-distribution behavior of the NN is mostly linear, and going out-of-distribution brings the NN to strongly non-linear territory?

Replies from: dsj
comment by dsj · 2023-04-24T21:32:04.566Z · LW(p) · GW(p)

A key distinction is between linearity in the weights vs. linearity in the input data.

For example, the function $f(a, b, x, y) = a \sin(x) + b \cos(y)$ is linear in the arguments $a$ and $b$ but nonlinear in the arguments $x$ and $y$, since $\sin$ and $\cos$ are nonlinear.

Similarly, we have evidence that wide neural networks $f(x; \theta)$ are (almost) linear in the parameters $\theta$, despite being nonlinear in the input data $x$ (due e.g. to nonlinear activation functions such as ReLU). So nonlinear activation functions are not a counterargument to the idea of linearity with respect to the parameters.

If this is so, then neural networks are almost a type of kernel machine, doing linear learning in a space of features which are themselves a fixed nonlinear function of the input data.

comment by Bogdan Ionut Cirstea (bogdan-ionut-cirstea) · 2024-05-10T19:53:50.740Z · LW(p) · GW(p)

If this is true, then we should be able to achieve quite a high level of control and understanding of NNs solely by straightforward linear methods and interventions. This would mean that deep networks might end up being pretty understandable and controllable artefacts in the near future. It is just that, at this moment, we have not yet found the right levers (or rather, lots of existing work does show this, but it hasn't really been normalized or applied at scale for alignment). Linear-ish network representations are a best case scenario for both interpretability and control.

For a mechanistic, circuits-level understanding, there is still the problem of superposition of the linear representations. However, if the representations are indeed mostly linear, then once superposition is solved there seem to be few other obstacles in the way of a complete mechanistic understanding of the network. Moreover, superposition is not even a problem for black-box linear methods for controlling and manipulating features, where the optimiser handles the superposition for you.

Here's a potential operationalization / formalization of why assuming the linear representation hypothesis seems to imply that finding and using the directions might be easy-ish (and significantly easier than full reverse-engineering / enumerative interp). From Learning Interpretable Concepts: Unifying Causal Representation Learning and Foundation Models (with apologies for the poor formatting):

'We focus on the goal of learning identifiable human-interpretable concepts from complex high-dimensional data. Specifically, we build a theory of what concepts mean for complex high-dimensional data and then study under what conditions such concepts are identifiable, i.e., when can they be unambiguously recovered from data. To formally define concepts, we leverage extensive empirical evidence in the foundation model literature that surprisingly shows that, across multiple domains, human-interpretable concepts are often linearly encoded in the latent space of such models (see Section 2), e.g., the sentiment of a sentence is linearly represented in the activation space of large language models [96]. Motivated by this rich empirical literature, we formally define concepts as affine subspaces of some underlying representation space. Then we connect it to causal representation learning by proving strong identifiability theorems for only desired concepts rather than all possible concepts present in the true generative model. Therefore, in this work we tread the fine line between the rigorous principles of causal representation learning and the empirical capabilities of foundation models, effectively showing how causal representation learning ideas can be applied to foundation models.


Let us be more concrete. For observed data X that has an underlying representation Zu with X = fu(Zu) for an arbitrary distribution on Zu and a (potentially complicated) nonlinear underlying mixing map fu, we define concepts as affine subspaces AZu = b of the latent space of Zus, i.e., all observations falling under a concept satisfy an equation of this form. Since concepts are not precise and can be fuzzy or continuous, we will allow for some noise in this formulation by working with the notion of concept conditional distributions (Definition 3). Of course, in general, fu and Zu are very high-dimensional and complex, as they can be used to represent arbitrary concepts. Instead of ambitiously attempting to reconstruct fu and Zu as CRL [causal representation learning] would do, we go for a more relaxed notion where we attempt to learn a minimal representation that represents only the subset of concepts we care about; i.e., a simpler decoder f and representation Z—different from fu and Zu—such that Z linearly captures a subset of relevant concepts as well as a valid representation X = f(Z). With this novel formulation, we formally prove that concept learning is identifiable up to simple linear transformations (the linear transformation ambiguity is unavoidable and ubiquitous in CRL). This relaxes the goals of CRL to only learn relevant representations and not necessarily learn the full underlying model. It further suggests that foundation models do in essence learn such relaxed representations, partially explaining their superior performance for various downstream tasks.
Apart from the above conceptual contribution, we also show that to learn n (atomic) concepts, we only require n + 2 environments under mild assumptions. Contrast this with the adage in CRL [41, 11] where we require dim(Zu) environments for most identifiability guarantees, where as described above we typically have dim(Zu) ≫ n + 2.'


'The punchline is that when we have rich datasets, i.e., sufficiently rich concept conditional datasets, then we can recover the concepts. Importantly, we only require a number of datasets that depends only on the number of atoms n we wish to learn (in fact, O(n) datasets), and not on the underlying latent dimension dz of the true generative process. This is a significant departure from most works on causal representation learning, since the true underlying generative process could have dz = 1000, say, whereas we may be interested to learn only n = 5 concepts, say. In this case, causal representation learning necessitates at least ∼ 1000 datasets, whereas we show that ∼ n + 2 = 7 datasets are enough if we only want to learn the n atomic concepts.'

comment by Review Bot · 2024-03-04T01:50:44.662Z · LW(p) · GW(p)

The LessWrong Review [? · GW] runs every year to select the posts that have most stood the test of time. This post is not yet eligible for review, but will be at the end of 2024. The top fifty or so posts are featured prominently on the site throughout the year.

Hopefully, the review is better than karma at judging enduring value. If we have accurate prediction markets on the review results, maybe we can have better incentives on LessWrong today. Will this post make the top fifty?

comment by Joseph Van Name (joseph-van-name) · 2023-10-13T19:45:27.404Z · LW(p) · GW(p)

I trained a (plain) neural network on a couple of occasions to predict the output of the function $x_1 \oplus x_2 \oplus \dots \oplus x_n$ where $x_1, \dots, x_n$ are bits and $\oplus$ denotes the XOR operation. The neural network was hopelessly confused despite the fact that neural networks usually do not have any trouble memorizing large quantities of random information. This time the neural network could not even memorize the truth table for XOR. While the operation $\oplus$ is linear over the field $\mathbb{F}_2$, it is quite non-linear over $\mathbb{R}$. The inability of a simple neural network to learn this function indicates that neural networks are better at learning when they are not required to stray too far away from linearity.
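A minimal sketch of the kind of experiment described, assuming a small fully connected ReLU network trained on the parity (iterated XOR) of random bit strings; all sizes and hyperparameters here are illustrative:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_bits, n_samples = 16, 4096
X = torch.randint(0, 2, (n_samples, n_bits)).float()
y = X.sum(dim=1).remainder(2)  # parity = XOR of all bits

model = nn.Sequential(nn.Linear(n_bits, 128), nn.ReLU(),
                      nn.Linear(128, 128), nn.ReLU(),
                      nn.Linear(128, 1))
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.BCEWithLogitsLoss()

for step in range(2000):
    opt.zero_grad()
    loss = loss_fn(model(X).squeeze(-1), y)
    loss.backward()
    opt.step()

with torch.no_grad():
    acc = ((model(X).squeeze(-1) > 0).float() == y).float().mean()
print(f"train accuracy on {n_bits}-bit parity: {acc:.3f}")
# Parity over many bits is maximally nonlinear over the reals, and plain MLPs
# trained with SGD/Adam typically struggle with it even on the training set.
```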

comment by Joseph Van Name (joseph-van-name) · 2023-10-13T19:31:48.262Z · LW(p) · GW(p)

Neural networks with ReLU activation are the things you obtain when you combine two kinds of linearity, namely the standard linearity that we all should be familiar with and tropical linearity.

Give $\mathbb{R} \cup \{-\infty\}$ two operations $\oplus, \otimes$ defined by setting $x \oplus y = \max(x, y)$ and $x \otimes y = x + y$. Then the operations $\oplus, \otimes$ are associative, commutative, and they satisfy the distributivity property $x \otimes (y \oplus z) = (x \otimes y) \oplus (x \otimes z)$. We shall call the operations $\oplus, \otimes$ tropical operations on $\mathbb{R} \cup \{-\infty\}$.

We can even perform matrix and vector operations by replacing the operations $+, \cdot$ with their tropical counterparts $\oplus, \otimes$. More explicitly, we can define the tropical matrix addition and multiplication operations by setting

$(A \oplus B)_{i,j} = A_{i,j} \oplus B_{i,j} = \max(A_{i,j}, B_{i,j})$ and

$(A \otimes B)_{i,j} = \bigoplus_k A_{i,k} \otimes B_{k,j} = \max_k (A_{i,k} + B_{k,j})$.

Here, the ReLU operation is just $\mathrm{ReLU}(x) = x \oplus 0 = \max(x, 0)$, and if $\mathbf{0}$ is the zero vector and $x$ is a real vector, then $\mathrm{ReLU}(x) = x \oplus \mathbf{0}$, so ReLU does not even rely on tropical matrix multiplication.

Of course, one can certainly construct and train neural networks using tropical matrix multiplication in the layers, i.e. layers of the form $x \mapsto (A \otimes x) \oplus b$ where $A$ are weight matrices and $b$ are bias vectors, but I do not know of any experiments done with these sorts of neural networks, so I am uncertain of what advantages they offer.

Since ReLU neural networks are a combination of two kinds of linearity, one might expect for ReLU neural networks to behave nearly linearly. And it is not surprising that ReLU networks look more like the standard linear transformations than the tropical linear transformations since the standard linear transformations in a neural network are far more complex than the ReLU. ReLU just provides the bare minimum non-linearity for a neural network without doing anything fancy.

comment by dr_s · 2023-04-25T09:59:37.316Z · LW(p) · GW(p)

I wonder if perhaps a weaker and more defensible thesis would be that deep learning models are mostly linear (and maybe the few non-linearities could be separated and identified? Has anyone tried applying ReLUs only to some outputs, leaving most of the rest untouched?). It would seem really weird to me if they really were linear. If that were the case, it would mean that:

  1. activation functions are essentially unnecessary
  2. forget SGD, you can just do one shot linear regression to train them (well, ok, no, they're still so big that you probably need gradient descent, but it's a much more deterministic process if it's a linear function that you're fitting)

You wouldn't even need multiple layers, just one big tensor. It feels weird that an entire field might have just overlooked such a trivial solution.