Appendix: Interpretable by Design - Constraint Sets with Disjoint Limit Points
post by Ronak_Mehta · 2025-05-08T21:09:23.689Z
Here are a bunch of other ideas that I couldn't format well for the main post [LW · GW]. They are relevant, but were blocking me from just sharing the main ideas. This appendix is significantly messier and rougher, with random pieces all over the place.
Maps for Simplex-Valued Vectors
Once we have "stuff on the simplex", we still need to do computation. We need differentiable maps that take simplex-valued inputs to simplex-valued outputs.
Linear Maps
What is the equivalent "linear map"? That is, what is $f(x) = Wx$ for $x \in \Delta^{d-1}$? We can get this by requiring $W$ to be in the space of stochastic matrices.
We can easily project a weight matrix $W \in \mathbb{R}^{d' \times d}$ to a stochastic matrix by normalizing over the output dimension with a softmax. Then, if we guarantee the input is on the simplex, the output will be as well! We also get a guarantee on reachability, i.e., we don't lose any expressive power: there is always a parameter setting that takes any given input on the simplex to any desired output on the simplex.
In practice,
    def forward(self, x):
        # Softmax over the output dimension makes the weight column-stochastic,
        # so a simplex-valued input stays on the simplex.
        stochastic_weight = F.softmax(self.weight, dim=0)
        return F.linear(x, stochastic_weight)
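For concreteness, here is a minimal self-contained sketch of such a layer (the class name SimplexLinear, the initialization, and the quick check are my own illustrative choices):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimplexLinear(nn.Module):
    """Linear layer whose weight is projected to a column-stochastic matrix,
    so simplex-valued inputs map to simplex-valued outputs."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Softmax over dim=0 normalizes each column of the weight to sum to 1.
        stochastic_weight = F.softmax(self.weight, dim=0)
        return F.linear(x, stochastic_weight)

# Quick check: rows of the output sum to 1 whenever rows of the input do.
layer = SimplexLinear(8, 4)
x = F.softmax(torch.randn(2, 8), dim=-1)
print(layer(x).sum(dim=-1))  # ~1.0 for every row
```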
Other Layers and Network Pieces
Residual Connections. We also need binary maps $g: \Delta^{d-1} \times \Delta^{d-1} \to \Delta^{d-1}$. As a first pass, two operations on the simplex make sense here. First, elementwise multiplication plus renormalization is a natural default, i.e., $g(x, y) = \frac{x \odot y}{\sum_i x_i y_i}$.
If $y$ is the uniform distribution then $g(x, y) = x$, so a "no-op" can be the default or initial operation. This could also be seen as a bias term in some sense.
Alternatively, we can map back to "real space" (e.g., with an elementwise log), do a typical addition there, and then map back to the simplex with a softmax.
We can also add a fixed or learnable bias term to this, which can act as a trade-off parameter or as a way to understand the "influence" of certain parts of the network.
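A sketch of both combination rules (the log/softmax version is just one plausible reading of "map back to real space"; the function names are illustrative):

```python
import torch
import torch.nn.functional as F

def simplex_residual_mul(x: torch.Tensor, y: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    """Elementwise multiplication followed by renormalization onto the simplex."""
    z = x * y
    return z / (z.sum(dim=-1, keepdim=True) + eps)

def simplex_residual_logadd(x: torch.Tensor, y: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    """Map to real space with a log, add there, and map back with a softmax."""
    return F.softmax(torch.log(x + eps) + torch.log(y + eps), dim=-1)

# With y uniform, the multiplicative rule is a no-op (up to numerical error).
x = F.softmax(torch.randn(3, 8), dim=-1)
u = torch.full_like(x, 1.0 / x.shape[-1])
print(torch.allclose(simplex_residual_mul(x, u), x, atol=1e-5))
```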
Attention. This should compose naturally because all of the operations are linear, but there is probably more to do here; e.g., the outer product of the query and key vectors may correspond more closely to a joint distribution. Perhaps there is something more interesting here where token positions are more appropriately weighted by some optimal-transport-based cost function.
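A quick check of the composition claim (a sketch; it only assumes the value vectors are already simplex-valued, with arbitrary queries and keys):

```python
import torch
import torch.nn.functional as F

# Attention weights are row-stochastic, so the output is a convex combination
# of the value vectors; if every value vector lies on the simplex, so does the output.
T, d = 5, 8
q, k = torch.randn(T, d), torch.randn(T, d)
v = F.softmax(torch.randn(T, d), dim=-1)      # simplex-valued values
attn = F.softmax(q @ k.T / d**0.5, dim=-1)    # each row sums to 1
out = attn @ v
print(out.sum(dim=-1))                         # ~1.0 for every row
```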
Other Notes, Connections, References
- There is no scaling: things can't blow up, they can only become more or less discrete.
- We can add temperatures in many different places to adjust "interpretability" and trade-offs.
- There is a strong notion of convexity: all activations are on the simplex and so linear combinations of activations are explicitly "valid".
- If you bias the linear-layer initializations toward some decaying exponential over the dimensions, will this prefer a basis with a certain type of distribution (low entropy, if possible)?
- The first half of this blog post has a good introduction to some of these ideas: https://iclr-blog-track.github.io/2022/03/25/non-monotonic-autoregressive-ordering/
- We can set our hidden dimension equal to our vocabulary size and bias the basis by initializing our stochastic matrices close to the identity. Does this lead to more interpretable models, because everything is some combination or mix of tokens?
- Probability! Everything is probability, so all of those great keywords can get slapped on (Bayesian updating, likelihood estimation, etc.).
- Variational inference: https://proceedings.mlr.press/v84/linderman18a.html
- Gradient updates on the simplex and Birkhoff polytope probably have some connection with information updates, maybe likelihood updates?
- If we use extreme temperatures to push all elements toward the corners, we effectively have "hard" gradient updates that move us explicitly between discrete points (see the sketch after this list). Maybe this is connected to some existing work on discrete backpropagation? https://writings.stephenwolfram.com/2024/08/whats-really-going-on-in-machine-learning-some-minimal-models/
- There are probably a lot of new optimization schemes to explore here. Adam, SGD, etc. are all focused on real-valued parameter spaces, often with prior distributions implicitly fixed to isotropic Normals. We could probably re-derive new optimization schemes based on Dirichlet priors, alongside tightly coupled information regularizers based on entropy priors.
- Some stuff on convex duality of attention https://proceedings.mlr.press/v162/sahiner22a.html
- The softmax network stuff by Anthropic (SoLU) is related, but not exactly the same. Their results might suggest that something like this is unlikely to work in practice, but IMO they didn't push hard enough in this direction. https://transformer-circuits.pub/2022/solu/index.html
- Learning Latent Permutations with Gumbel-Sinkhorn Networks. Gonzalo Mena, David Belanger, Scott Linderman, Jasper Snoek. https://arxiv.org/abs/1802.08665
- DAG Learning on the Permutahedron. Valentina Zantedeschi, Luca Franceschi, Jean Kaddour, Matt Kusner, Vlad Niculae. https://openreview.net/forum?id=OxNQXyZK-K8
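A sketch of the temperature point from the list above (placing the temperature on the weight softmax is my own illustrative choice):

```python
import torch
import torch.nn.functional as F

# A temperature on the weight softmax interpolates between "soft" and "hard"
# stochastic matrices: high temperature gives near-uniform columns, low
# temperature gives near one-hot columns (close to a deterministic, discrete map).
w = torch.randn(4, 4)
for tau in (10.0, 1.0, 0.01):
    stochastic = F.softmax(w / tau, dim=0)
    col_entropy = -(stochastic * stochastic.clamp_min(1e-12).log()).sum(dim=0)
    print(tau, col_entropy.mean().item())  # entropy shrinks toward 0 as tau -> 0
```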