Sparse trinary weighted RNNs as a path to better language model interpretability

post by Am8ryllis · 2022-09-17T19:48:24.330Z · 13 comments

Contents

  Introduction
  Objections
    We cannot train, quantize and sparsify RNNs of comparable accuracy to modern transformers.
    Even if we have such a SoTA LLM, it would not improve interpretability.
    LLM interpretability does not ultimately help with alignment.
    Trinarised LLMs will make inference very cheap and this will accelerate capabilities more than they improve interpretability and therefore be net negative.
  Conclusion
  Acknowledgments

Epistemic status: Strongly arguing for what I feel is a neglected approach. May somewhat overstate the case and fail to adequately steelman counter arguments. I hope and expect that readers will point out flaws in my logic.

Introduction

Currently, large transformers with dense floating point weights are state of the art in language modeling. Despite recent progress by Anthropic and others, they remain difficult to understand.

Why are they hard to understand?

We can fix these issues.

Once we quantize our RNN to trinary weights, we can replace the additions with adder trees and the tanh activations with digital comparators, and then apply Boolean logic simplification tools to simplify the circuit further. Now we have a finite state machine whose update is built entirely from combinational logic[3]. Understanding the behavior of such a machine intuitively feels far more tractable than understanding current transformers. Once we can understand the internal workings of large language models, we can likely use this understanding to improve safety/alignment.
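A minimal Python sketch of one step of such a cell, under stated assumptions: weights are trits, inputs and state are single bits, the signed sum stands in for an adder tree, and an integer threshold stands in for the comparator that replaces tanh. The function names and example weights are hypothetical. Because the state is a handful of bits, the whole cell can be enumerated as a transition table, which is what makes it a finite state machine.

```python
from itertools import product
from typing import Dict, List, Sequence, Tuple

# Hypothetical toy cell: weights are trits in {-1, 0, +1}; inputs and state are bits.


def neuron_fires(weights: Sequence[int], bits: Sequence[int], threshold: int) -> int:
    """Stand-in for one neuron in hardware.

    The signed sum plays the role of an adder tree (zero weights need no wiring),
    and the comparison against an integer threshold replaces tanh.
    """
    total = 0
    for w, b in zip(weights, bits):
        if w == 1:
            total += b
        elif w == -1:
            total -= b
    return 1 if total >= threshold else 0


def rnn_step(w_in, w_rec, thresholds, x_bits, h_bits) -> List[int]:
    """Compute the next binary state from the input bits and the current binary state."""
    concat = list(x_bits) + list(h_bits)
    return [
        neuron_fires(list(wi) + list(wr), concat, t)
        for wi, wr, t in zip(w_in, w_rec, thresholds)
    ]


def transition_table(w_in, w_rec, thresholds) -> Dict[Tuple[Tuple[int, ...], Tuple[int, ...]], Tuple[int, ...]]:
    """Enumerate every (input, state) pair: the cell is exactly a finite state machine."""
    n_in, n_state = len(w_in[0]), len(w_rec[0])
    table = {}
    for x in product((0, 1), repeat=n_in):
        for h in product((0, 1), repeat=n_state):
            table[(x, h)] = tuple(rnn_step(w_in, w_rec, thresholds, x, h))
    return table


if __name__ == "__main__":
    # Two state neurons and one input bit, with arbitrary illustrative weights.
    W_IN = [[1], [0]]
    W_REC = [[0, -1], [1, 1]]
    THRESH = [1, 2]
    for (x, h), nxt in transition_table(W_IN, W_REC, THRESH).items():
        print(f"input={x} state={h} -> next={nxt}")
```

Enumeration is only feasible for tiny state sizes; for a real model the same circuit would instead be handed to logic-simplification and analysis tools.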

 

Objections

We cannot train, quantize and sparsify RNNs of comparable accuracy to modern transformers.

RNNs are less powerful than transformers, hard to train on GPUs, and quantization kills accuracy. What good is it if we can understand small toy language models? They are not what is being deployed. If it costs 10x more to train, few will actually use it and SoTA will stay uninterpretable.

A: We use transformers instead of large RNNs because they are easier to train on GPUs, not because they are a fundamentally better architecture. RWKV-LM appears to be making progress on large RNN training on GPUs. TernaryBERT is one example of successful trinary quantization of transformers (although it does not trinarize activations). Some work has also been done on trinary weights in RNNs. I suspect that both could be significantly improved upon with further effort. Trinary weights are effectively sparse binary weights, and once we convert activations to binary, more weights will die.
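The claim that trinary weights are sparse binary can be made concrete. The sketch below assumes a simple post-hoc ternarization rule, the Ternary Weight Networks threshold of 0.7 times the mean absolute weight, chosen purely for illustration and not necessarily what TernaryBERT or the author uses. It then splits the trits into a positive mask P and a negative mask N with W = P - N, so that with binary activations a dot product reduces to two popcounts.

```python
from typing import List, Sequence, Tuple


def ternarize(weights: Sequence[float], scale: float = 0.7) -> List[int]:
    """Map float weights to trits with a magnitude threshold.

    The threshold delta = 0.7 * mean(|w|) is the Ternary Weight Networks heuristic,
    used here only as an illustrative assumption.
    """
    delta = scale * sum(abs(w) for w in weights) / len(weights)
    return [1 if w > delta else -1 if w < -delta else 0 for w in weights]


def split_trits(trits: Sequence[int]) -> Tuple[List[int], List[int]]:
    """Rewrite a trit vector as two sparse binary masks P and N with W = P - N."""
    pos = [1 if t == 1 else 0 for t in trits]
    neg = [1 if t == -1 else 0 for t in trits]
    return pos, neg


def dot_binary(trits: Sequence[int], bits: Sequence[int]) -> int:
    """With binary activations, W . x = popcount(P AND x) - popcount(N AND x)."""
    pos, neg = split_trits(trits)
    return sum(p & b for p, b in zip(pos, bits)) - sum(n & b for n, b in zip(neg, bits))


if __name__ == "__main__":
    w = [0.9, -0.05, -1.2, 0.1, 0.4]
    t = ternarize(w)
    print(t)  # [1, 0, -1, 0, 1]
    print(split_trits(t))
    print(dot_binary(t, [1, 0, 0, 0, 1]))  # 2
```

The zero trits never appear in either mask, which is where the sparsity comes from.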

Even if we have such a SoTA LLM, it would not improve interpretability.

It is still a huge mess of gates. This is hard to understand.

A: If the whole model is logic gates end to end, we can potentially apply SAT solvers to it. This feels like a significant improvement over current transformer interpretability. Also, I suspect[4] that if we sparsify it sufficiently and apply logic simplification transformations, the logic graph will tend to fall apart into mostly independent modules which can be studied independently[5].
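A minimal sketch of both ideas on a toy gate graph, assuming a hypothetical netlist format (gate name mapped to an operation and input names). First, union-find splits the graph into weakly connected components, i.e. independent modules. Then a property of one module is posed as a satisfiability question ("is there an input that makes this gate output 0 while input `a` is 0?"). Exhaustive enumeration stands in for a real SAT solver here, which is only viable because each module is small.

```python
from itertools import product
from typing import Callable, Dict, List, Optional, Set, Tuple

# Hypothetical netlist format: gate name -> (operation, list of input names).
# Names that never appear as keys are primary inputs.
Netlist = Dict[str, Tuple[str, List[str]]]

OPS = {
    "AND": lambda xs: int(all(xs)),
    "OR": lambda xs: int(any(xs)),
    "NOT": lambda xs: 1 - xs[0],
}


def modules(netlist: Netlist) -> List[Set[str]]:
    """Split the gate graph into weakly connected components using union-find."""
    parent: Dict[str, str] = {}

    def find(a: str) -> str:
        parent.setdefault(a, a)
        while parent[a] != a:
            parent[a] = parent[parent[a]]  # path halving
            a = parent[a]
        return a

    for gate, (_, inputs) in netlist.items():
        find(gate)  # register gates even if they have no inputs
        for name in inputs:
            parent[find(name)] = find(gate)
    groups: Dict[str, Set[str]] = {}
    for node in list(parent):
        groups.setdefault(find(node), set()).add(node)
    return list(groups.values())


def evaluate(netlist: Netlist, assignment: Dict[str, int]) -> Dict[str, int]:
    """Compute every gate value given values for the primary inputs."""
    values = dict(assignment)

    def value(name: str) -> int:
        if name not in values:
            op, inputs = netlist[name]
            values[name] = OPS[op]([value(i) for i in inputs])
        return values[name]

    for gate in netlist:
        value(gate)
    return values


def find_witness(netlist: Netlist, module: Set[str],
                 constraint: Callable[[Dict[str, int]], bool]) -> Optional[Dict[str, int]]:
    """Return an input assignment for one module satisfying `constraint`, or None.

    Brute-force enumeration is a stand-in for a SAT solver in this sketch.
    """
    sub = {g: netlist[g] for g in module if g in netlist}
    inputs = sorted(module - set(sub))
    for bits in product((0, 1), repeat=len(inputs)):
        assignment = dict(zip(inputs, bits))
        if constraint(evaluate(sub, assignment)):
            return assignment
    return None


if __name__ == "__main__":
    net = {
        "g1": ("AND", ["a", "b"]),
        "g2": ("NOT", ["g1"]),
        "g3": ("OR", ["c", "d"]),
    }
    for m in modules(net):
        print(sorted(m))
    left = next(m for m in modules(net) if "g2" in m)
    # None here means: no input makes g2 == 0 while a == 0, so a == 0 implies g2 == 1.
    print(find_witness(net, left, lambda v: v["g2"] == 0 and v["a"] == 0))
```

An unsatisfiable query is a proof that the module never behaves that way, which is the kind of guarantee that is hard to get from float weights.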

LLM interpretability does not ultimately help with alignment.

A: We have little idea what sort of model will end up being dangerous over the coming decade, but currently, LLMs are the frontier of general reasoning capabilities. Practicing on what we currently have seems better than not practicing.

Trinarised LLMs will make inference very cheap and this will accelerate capabilities more than they improve interpretability and therefore be net negative.

A: Assuming that it does not make training cheaper, cheaper inference is probably not a large accelerant to AI timelines? (I am not at all sure about this.)

Conclusion

I make four assertions:

I am looking for disagreements with any of these assertions, but particularly the last two assertions[6].

Acknowledgments

I thank Joel Burget and Nathan Showell for insightful discussion and comments on a draft of this post.

  1. ^

    The mask probably does provide some additional sense of ordering. I do not yet understand transformers well enough to understand the full implications of the mask. But still, this is not very helpful for interpretability.

  2. ^

    Or LSTM. When the state is binary, the distinction becomes less clear.

  3. ^

    A traditional RNN is a continuous state machine, and if we quantize the state to binary, it becomes a proper finite state machine.

  4. ^

     I do not have a solid argument for this suspicion, only vague intuition about generalization requiring information destruction. I intend to make it fall apart, by whatever means are needed.

  5. ^

    While SAT solvers may have difficulty scaling to billions or even millions of edges, force directed graph layout algorithms have been able to scale to hundreds of millions of edges on a single GPU for some years, allowing us to visually identify interesting subgraphs to which SAT solvers may be applied.

  6. ^

     I am aware that both generating quantized LLMs, and interpreting them once generated, have low probability of success, but I intend to attempt both as long as A) the attempt will not cause significant harm, and B) the result if successful is actually useful.

13 comments


comment by Charlie Steiner · 2022-09-17T20:30:20.773Z

Of all of these claims, I am most on board with the notion that RNNs might have better interpretability properties than transformers, and have some degree of competitiveness. But I disagree with basically all the other things.

If you want to be competitive with SOTA, a more quantized net will need a lot more neurons (have you read the new article on superposition?). I am quite confident that this would not be more interpretable, and that you would still need specialized tools to get anywhere.

Replies from: Am8ryllis
comment by Am8ryllis · 2022-09-18T00:55:51.507Z

I'm glad we agree that RNNs are nice.

So if I understand correctly, you are saying:

  • A trinary weighted LLM with accuracy comparable to Chinchilla (70B weights) would need significantly more (dense) trits, let's say >140B?
  • An LLM with significantly more trit weights is less interpretable than an LLM with a smaller number of float weights?
  • Do you disagree regarding harm if successful?

Consider that most of the trits will be 0 and thus removable, and that we will be replacing the activations with Boolean logic and applying logic simplification transformations to discard even more nodes. The number of trits in the weights is not the same as the number of gates in the resulting logic graph. I think it plausible that even if we are forced to start with an LLM larger than Chinchilla to achieve comparable accuracy, after sparsification and logic simplification we will end up with significantly fewer gates. Would such an LLM still be less interpretable?
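To illustrate why trit count and gate count differ, here is a sketch of two standard logic simplification passes: constant propagation (AND with 0 is 0, AND with 1 is a plain wire, and so on) followed by dead-gate elimination (anything not reachable backwards from an output is removed). The netlist format, with integer 0/1 for constants and strings for wires, is a hypothetical choice for this example; real synthesis tools apply many more rewrites.

```python
from typing import Dict, List, Tuple, Union

Signal = Union[int, str]  # constant 0/1 or the name of a wire
Netlist = Dict[str, Tuple[str, List[Signal]]]


def simplify(netlist: Netlist, outputs: List[str]) -> Tuple[Netlist, Dict[str, Signal]]:
    """Constant propagation plus dead-gate elimination.

    Assumes gates are listed in topological order (inputs defined before use).
    Returns the reduced netlist and what each output now resolves to.
    """
    subst: Dict[str, Signal] = {}  # gates replaced by a constant or by another wire
    kept: Netlist = {}
    for gate, (op, raw_inputs) in netlist.items():
        ins = [subst.get(i, i) if isinstance(i, str) else i for i in raw_inputs]
        if op == "NOT":
            if isinstance(ins[0], int):
                subst[gate] = 1 - ins[0]
                continue
        elif op in ("AND", "OR"):
            absorbing, identity = (0, 1) if op == "AND" else (1, 0)
            if absorbing in ins:  # AND(..., 0, ...) = 0 and OR(..., 1, ...) = 1
                subst[gate] = absorbing
                continue
            ins = [i for i in ins if i != identity]  # drop AND-with-1 / OR-with-0 inputs
            if len(ins) <= 1:
                subst[gate] = ins[0] if ins else identity  # one input left: just a wire
                continue
        else:
            raise ValueError(f"unknown op {op!r}")
        kept[gate] = (op, ins)

    # Walk backwards from the outputs; anything not reached is dead logic.
    live = set()
    stack = [s for s in (subst.get(o, o) for o in outputs) if isinstance(s, str)]
    while stack:
        name = stack.pop()
        if name in live or name not in kept:
            continue
        live.add(name)
        stack.extend(i for i in kept[name][1] if isinstance(i, str))
    pruned = {g: op_ins for g, op_ins in kept.items() if g in live}
    return pruned, {o: subst.get(o, o) for o in outputs}


if __name__ == "__main__":
    net = {
        "g1": ("AND", ["a", 1]),      # becomes the wire "a"
        "g2": ("OR", ["b", 1]),       # becomes the constant 1
        "g3": ("AND", ["g1", "g2"]),  # becomes AND(a, 1), i.e. the wire "a"
        "g4": ("NOT", ["c"]),         # never reaches an output: dead
        "g5": ("OR", ["g3", "d"]),
    }
    print(simplify(net, ["g5"]))  # five gates reduce to one OR gate
```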

 

If you want to be competitive with SOTA, a more quantized net will need a lot more neurons (have you read the new article on superposition?).

I agree that lower precision weights will likely require somewhat more weights; however, I do not see the connection to superposition. It is possible to embed >n features in n bits (assuming some feature sparsity). The features will be on the corners of the unit hypercube, but most of the space is there anyway, so I do not think it would be a very large decrease in available space.
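One rough way to probe this intuition: generate more feature directions than dimensions, once restricted to cube corners (entries of plus or minus 1) and once with Gaussian float entries, and compare the average pairwise interference, measured as mean absolute cosine similarity. This sketch only measures pairwise interference between random directions; it says nothing about whether such features can be added together and decomposed cleanly, and the sizes and seed are arbitrary assumptions.

```python
import math
import random
from typing import List


def unit(v: List[float]) -> List[float]:
    norm = math.sqrt(sum(a * a for a in v))
    return [a / norm for a in v]


def mean_abs_interference(vectors: List[List[float]]) -> float:
    """Average |cosine similarity| over all pairs of feature directions."""
    us = [unit(v) for v in vectors]
    total, pairs = 0.0, 0
    for i in range(len(us)):
        for j in range(i + 1, len(us)):
            total += abs(sum(a * b for a, b in zip(us[i], us[j])))
            pairs += 1
    return total / pairs


def sign_features(m: int, n: int, rng: random.Random) -> List[List[float]]:
    """m feature directions restricted to corners of the cube (entries +-1)."""
    return [[float(rng.choice((-1, 1))) for _ in range(n)] for _ in range(m)]


def gaussian_features(m: int, n: int, rng: random.Random) -> List[List[float]]:
    """m unconstrained float feature directions for comparison."""
    return [[rng.gauss(0.0, 1.0) for _ in range(n)] for _ in range(m)]


if __name__ == "__main__":
    rng = random.Random(0)
    n, m = 128, 100  # fewer features than 2**n corners, but small enough to run quickly
    print("corners:", mean_abs_interference(sign_features(m, n, rng)))
    print("floats :", mean_abs_interference(gaussian_features(m, n, rng)))
    print("sqrt(2 / (pi * n)) =", math.sqrt(2 / (math.pi * n)))
```

In this random setting both kinds of directions should show similar, small average interference; learned features under quantization constraints may behave differently.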

 

and that you would still need specialized tools to get anywhere.

I agree with this. I am currently attempting to build the needed tooling. It's nontrivial work, but I think it is doable.

Replies from: Charlie Steiner
comment by Charlie Steiner · 2022-09-18T02:42:15.495Z

If you want to be competitive with SOTA, a more quantized net will need a lot more neurons (have you read the new article on superposition?).

I agree that lower precision weights will likely require somewhat more weights; however, I do not see the connection to superposition. It is possible to embed >n features in n bits (assuming some feature sparsity). The features will be on the corners of the unit hypercube, but most of the space is there anyway, so I do not think it would be a very large decrease in available space.

The more quantized the weights and activations, the harder it is to embed >n features in n bits without them interfering with each other - interference that stops you from adding together features in semantically sensible ways, or decomposing a state into features. So those small bits aren't just being wasted - at least I think not, in most parts of modern NNs.

Replies from: nathan-helm-burger
comment by Nathan Helm-Burger (nathan-helm-burger) · 2022-09-19T03:24:13.927Z

I agree that I think you would need a LOT more weights. Kind of a ridiculous seeming amount perhaps, like maybe 10000x or more. But I actually think that's a potential strength. I think that reducing super-position and having a very sparse wide network with only a small portion of that network active at any one time could actually be made to be both compute efficient and interpretable. If each of those sparse weights does fewer things, then it becomes much easier to label those specific things, and to see what logic went into any given decision.
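A very wide net in which only a small portion is active at once can be sketched with a k-winners-take-all activation: only the k largest pre-activations fire, and the next layer only needs to read the outgoing weights of those k units, so compute scales with k rather than with layer width. k-WTA is one common way to enforce this kind of sparsity, assumed here for illustration; the function names and numbers are hypothetical.

```python
import heapq
from typing import List, Sequence


def k_winners_take_all(pre: Sequence[float], k: int) -> List[int]:
    """Indices of the k largest pre-activations; these are the only units that fire."""
    return sorted(heapq.nlargest(k, range(len(pre)), key=pre.__getitem__))


def as_binary(winners: Sequence[int], width: int) -> List[int]:
    """Dense binary activation vector, for inspection only."""
    active = set(winners)
    return [1 if i in active else 0 for i in range(width)]


def sparse_forward(outgoing: Sequence[Sequence[float]], winners: Sequence[int]) -> List[float]:
    """Next-layer input that touches only the outgoing weight rows of active units.

    outgoing[i] holds unit i's weights to every next-layer unit, so the work done
    here is proportional to len(winners), not to the number of hidden units.
    """
    out = [0.0] * len(outgoing[0])
    for i in winners:
        for j, w in enumerate(outgoing[i]):
            out[j] += w
    return out


if __name__ == "__main__":
    pre = [0.1, 0.9, -0.3, 0.7, 0.2]
    winners = k_winners_take_all(pre, 2)
    print(winners, as_binary(winners, len(pre)))
    outgoing = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 1], [5, 5, 5]]
    print(sparse_forward(outgoing, winners))
```

Because each active unit is binary and few units fire at once, it is also easy to list exactly which units contributed to a given decision.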

As for whether it's computationally tractable... There's good reason to think that that's possible. The brain is basically a very wide sparse net that's quite computationally efficient. Here's a recent interview from Yannic Kilcher on the subject: 

My view is slightly different, in that I don't think we should prune down the networks and leave them pruned. I think we want absurdly huge networks with clear labels. I'm currently imagining something like a mixture of experts implemented in this giant wide network, but with experts that overlap significantly with each other. Maybe this could be built with a series of learn-prune-learn-prune-learn cycles, building up an increasingly complex, very sparse space.
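A learn-prune cycle can be sketched with iterative magnitude pruning on a toy linear model: train, zero out a fraction of the smallest surviving weights, retrain with the pruned weights held at zero, and repeat. Iterative magnitude pruning is a swapped-in standard technique, not necessarily what the comment has in mind, and the toy data (a sparse ground-truth weight vector), learning rate and pruning fraction are assumptions chosen so that the example runs quickly.

```python
import random
from typing import List, Sequence, Tuple

Data = List[Tuple[List[float], float]]


def make_data(true_w: Sequence[float], n_samples: int, seed: int = 0) -> Data:
    """Noise-free linear data from a hypothetical sparse ground truth."""
    rng = random.Random(seed)
    data = []
    for _ in range(n_samples):
        x = [rng.uniform(-1.0, 1.0) for _ in true_w]
        data.append((x, sum(t * xi for t, xi in zip(true_w, x))))
    return data


def train(w: List[float], mask: List[int], data: Data,
          lr: float = 0.1, steps: int = 500) -> List[float]:
    """Gradient descent on mean squared error; masked (pruned) weights stay at 0."""
    for _ in range(steps):
        grad = [0.0] * len(w)
        for x, y in data:
            err = sum(wi * xi for wi, xi in zip(w, x)) - y
            for j, xj in enumerate(x):
                grad[j] += 2.0 * err * xj / len(data)
        w = [(wi - lr * g) * m for wi, g, m in zip(w, grad, mask)]
    return w


def prune(w: List[float], mask: List[int], fraction: float) -> Tuple[List[float], List[int]]:
    """Remove the given fraction of surviving weights with the smallest magnitude."""
    alive = [j for j, m in enumerate(mask) if m]
    drop = sorted(alive, key=lambda j: abs(w[j]))[: int(len(alive) * fraction)]
    mask = [0 if j in drop else m for j, m in enumerate(mask)]
    return [wi * m for wi, m in zip(w, mask)], mask


def learn_prune_cycles(data: Data, n: int, rounds: int, fraction: float) -> Tuple[List[float], List[int]]:
    """learn -> prune -> learn -> prune -> ... -> learn."""
    w, mask = [0.0] * n, [1] * n
    for _ in range(rounds):
        w = train(w, mask, data)
        w, mask = prune(w, mask, fraction)
    return train(w, mask, data), mask


if __name__ == "__main__":
    true_w = [2.0, 0.0, 0.0, -3.0, 0.0, 0.0, 0.0, 1.0]
    w, mask = learn_prune_cycles(make_data(true_w, 50), len(true_w), rounds=2, fraction=0.4)
    print(mask)
    print([round(v, 3) for v in w])
```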

If we can get the unwanted cognition/behaviors to sit entirely in their own section of weights, we can then ablate the unwanted behaviors without losing wanted capability. That's my hope anyway.
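If behaviors really do sit in their own labeled sections, ablation is just zeroing those units, and it is possible to check beforehand which outputs could be affected: only outputs that read from an ablated unit can change. The section labels, weights and activations below are hypothetical, and the sketch assumes the separation has already been achieved.

```python
from typing import Dict, List, Sequence, Set


def ablate(hidden: Sequence[float], sections: Dict[str, Set[int]], unwanted: str) -> List[float]:
    """Zero every hidden unit assigned to the unwanted section."""
    drop = sections.get(unwanted, set())
    return [0.0 if i in drop else h for i, h in enumerate(hidden)]


def readout(w_out: Sequence[Sequence[float]], hidden: Sequence[float]) -> List[float]:
    """Linear output layer: one row of weights per output unit."""
    return [sum(w * h for w, h in zip(row, hidden)) for row in w_out]


def outputs_touched(w_out: Sequence[Sequence[float]], dropped: Set[int]) -> List[int]:
    """Output units with a nonzero weight on any ablated unit; only these can change."""
    return [o for o, row in enumerate(w_out) if any(row[i] != 0 for i in dropped)]


if __name__ == "__main__":
    hidden = [1.0, 2.0, 3.0, 4.0]
    sections = {"wanted": {0, 1}, "unwanted": {2, 3}}
    w_out = [[1, 1, 0, 0], [0, 0, 1, 1], [1, 0, 0, 1]]
    print(readout(w_out, hidden))                                # before ablation
    print(readout(w_out, ablate(hidden, sections, "unwanted")))  # after ablation
    print(outputs_touched(w_out, sections["unwanted"]))          # outputs that may change
```

Output 2 mixes both sections, which is exactly the kind of overlap that would make a clean ablation lose wanted capability.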

Replies from: Am8ryllis
comment by Am8ryllis · 2022-09-19T17:30:46.083Z

I agree that reducing superposition is probably valuable even if it requires a significantly larger network. I still don't understand why the transition from float to binary would cause a dramatic reduction in superposition capacity. But if it does prevent superposition, great! I'll just give it more parameters as needed. But if we still get superposition, I will need to apply other techniques to make it stop.

(I have not yet finished my closer re-read of Toy Models of Superposition after my initial skimming. Perhaps once I do I will understand better.)

Hopefully in a few months I will have empirical data regarding how many more neurons we need. Then I can stop hand-waving about vague intuitions.

 

If we can get the unwanted cognition/behaviors to sit entirely in their own section of weights, we can then ablate the unwanted behaviors without losing wanted capability. That's my hope anyway.

My thoughts and hope as well.

comment by Nathan Helm-Burger (nathan-helm-burger) · 2022-09-22T19:16:38.033Z

Curious to hear your feelings on Linear Transformers: 

Replies from: None
comment by [deleted] · 2022-09-23T23:34:01.832Z

Do you happen to know how this compares with https://github.com/BlinkDL/RWKV-LM which is described as an RNN with performance comparable to a transformer / linear attention?
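For context on why linear attention and RNNs come up together: with a positive feature map, causal linear attention can be computed either in attention form (each position re-weights all earlier values) or in recurrent form (a fixed-size state is updated once per token), and the two give the same outputs. The sketch below uses the generic kernelized linear attention formulation with the elu(x) + 1 feature map as an assumption; it is not RWKV's actual time-mixing rule.

```python
import math
from typing import List, Sequence

Vec = List[float]


def phi(x: Sequence[float]) -> Vec:
    """Positive feature map elu(x) + 1 (one common choice, assumed here)."""
    return [v + 1.0 if v > 0 else math.exp(v) for v in x]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def causal_linear_attention_parallel(qs, ks, vs) -> List[Vec]:
    """Attention form: each step re-weights all earlier values (quadratic in length)."""
    outs = []
    for t in range(len(qs)):
        fq = phi(qs[t])
        weights = [dot(fq, phi(ks[i])) for i in range(t + 1)]
        norm = sum(weights)
        outs.append([sum(w * vs[i][d] for i, w in enumerate(weights)) / norm
                     for d in range(len(vs[0]))])
    return outs


def causal_linear_attention_recurrent(qs, ks, vs) -> List[Vec]:
    """RNN form: a fixed-size state (S, z) is updated once per token."""
    dk, dv = len(ks[0]), len(vs[0])
    S = [[0.0] * dv for _ in range(dk)]  # running sum of phi(k) v^T
    z = [0.0] * dk                       # running sum of phi(k)
    outs = []
    for q, k, v in zip(qs, ks, vs):
        fk = phi(k)
        for a in range(dk):
            z[a] += fk[a]
            for b in range(dv):
                S[a][b] += fk[a] * v[b]
        fq = phi(q)
        norm = dot(fq, z)
        outs.append([sum(fq[a] * S[a][b] for a in range(dk)) / norm for b in range(dv)])
    return outs


if __name__ == "__main__":
    qs = [[0.1, -0.2], [0.3, 0.5], [-0.4, 0.2]]
    ks = [[0.2, 0.1], [-0.1, 0.4], [0.5, -0.3]]
    vs = [[1.0, 0.0, 2.0], [0.5, -1.0, 0.0], [0.0, 1.0, 1.0]]
    print(causal_linear_attention_parallel(qs, ks, vs))
    print(causal_linear_attention_recurrent(qs, ks, vs))
```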

Replies from: nathan-helm-burger
comment by Nathan Helm-Burger (nathan-helm-burger) · 2022-09-24T23:47:30.494Z

I don't know, but I'd love to know! If you find out, please tell me!

comment by Jonathan_Graehl · 2022-09-17T20:33:46.837Z

how is a discretized weight/activation set amenable to the usual gradient descent optimizers?

Replies from: Am8ryllis
comment by Am8ryllis · 2022-09-17T23:22:37.692Z

Discretized weights/activations are very much not amenable to the usual gradient descent. :) Hence the usual practice is to train in floating point and then quantize afterwards. Doing this naively tends to cause a big drop in accuracy, but there are tricks involving gradually quantizing during training, or quantizing layer by layer.
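The comment does not name a specific trick, but one widely used way to train through a quantizer is the straight-through estimator: keep float "shadow" weights, use their ternarized version in the forward pass, and pretend the quantizer has derivative 1 so the gradient with respect to the trits updates the float weights directly. A minimal sketch on a toy linear problem, assuming a fixed threshold of 0.5 and a full factorial set of plus/minus 1 inputs (both choices are illustrative only):

```python
from itertools import product
from typing import List, Sequence, Tuple

Data = List[Tuple[List[int], int]]


def ternary(latent: Sequence[float], delta: float = 0.5) -> List[int]:
    """Quantize float shadow weights to trits with a fixed threshold."""
    return [1 if v > delta else -1 if v < -delta else 0 for v in latent]


def full_factorial_data(true_w: Sequence[int]) -> Data:
    """Every +-1 input vector, labeled by a hypothetical ternary ground truth."""
    return [(list(x), sum(t * xi for t, xi in zip(true_w, x)))
            for x in product((-1, 1), repeat=len(true_w))]


def train_ste(data: Data, n: int, lr: float = 0.05, steps: int = 100) -> Tuple[List[float], List[int]]:
    latent = [0.0] * n  # float shadow weights that the optimizer updates
    for _ in range(steps):
        q = ternary(latent)  # forward pass uses only trits
        grad = [0.0] * n
        for x, y in data:
            err = sum(qi * xi for qi, xi in zip(q, x)) - y
            for j in range(n):
                grad[j] += 2.0 * err * x[j] / len(data)
        # Straight-through: treat d(ternary)/d(latent) as 1, and clip latent to [-1, 1].
        latent = [max(-1.0, min(1.0, v - lr * g)) for v, g in zip(latent, grad)]
    return latent, ternary(latent)


if __name__ == "__main__":
    data = full_factorial_data([1, 0, -1, 1])
    latent, trits = train_ste(data, 4)
    print(latent, trits)
```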

comment by Nathan Helm-Burger (nathan-helm-burger) · 2022-09-17T20:16:12.902Z

I agree that alternative, more interpretable, architectures are a plausible path to alignment. I think maybe there's some tradeoff between alignment tax (e.g. reduced ease of training, diversion from mainstream path) and increased interpretability. I, myself, am working on an experiment with unusually sparse nets with architecture much closer to (and hopefully interoperable with) a GPT-like transformer.

Replies from: Am8ryllis
comment by Am8ryllis · 2022-09-17T23:33:39.640Z

I am hopeful that we can get interpretability and easy training. But you may well be right.

After skimming some of your progress reports, I am very excited about your sparse nets work!

Replies from: nathan-helm-burger
comment by Nathan Helm-Burger (nathan-helm-burger) · 2022-09-19T03:18:40.348Z

Thanks! And I'm excited to hear more about your work. It sounds like if it did work, the results would be quite interesting.