Appendix: Interpretable by Design - Constraint Sets with Disjoint Limit Points

post by Ronak_Mehta · 2025-05-08

Contents

  Maps for Simplex-Valued Vectors
    Linear Maps
    Other Layers and Network Pieces
  Other Notes, Connections, References

A bunch of other ideas that I couldn't format well for the main post [LW · GW] live here. They are relevant, but were blocking me from just sharing the main ideas. This appendix is significantly messier and rougher, with random pieces all over the place.

Maps for Simplex-Valued Vectors

Once we have "stuff on the simplex", we still need to do computation. We need some differentiable maps $f: \Delta^{n-1} \to \Delta^{m-1}$ that do something.

Linear Maps

What is the equivalent "Linear Map"? What is $Wx$ for $x \in \Delta^{n-1}$? We can get this by requiring $W$ to be in the space of stochastic matrices.

We can easily project a weight matrix in $\mathbb{R}^{m \times n}$ to a stochastic matrix by normalizing the output dimension using a softmax. Then, if we guarantee the input is on the simplex, the output will be as well! We also have guarantees on reachability, i.e., we don't lose any expressive power, and can always find a parameter setting that takes any given input to any desired output on the simplex.
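
To spell out why this works: if every column of $W \in \mathbb{R}^{m \times n}$ is nonnegative and sums to one (column-stochastic, matching the softmax over the output dimension below), then for $x \in \Delta^{n-1}$,

\[
\sum_j (Wx)_j = \sum_j \sum_i W_{ji} x_i = \sum_i x_i \sum_j W_{ji} = \sum_i x_i = 1,
\]

and every entry of $Wx$ is nonnegative, so $Wx \in \Delta^{m-1}$.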

In practice,

def forward(self, x):
    # Softmax over dim=0 (the output dimension) makes every column of the
    # weight sum to one, i.e., projects it onto the stochastic matrices.
    stochastic_weight = F.softmax(self.weight, dim=0)
    # F.linear computes x @ W^T; a stochastic W keeps x on the simplex.
    return F.linear(x, stochastic_weight)
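
As a quick check, here is a self-contained version (the StochasticLinear wrapper module and its initialization are my own names, not from the post):

import torch
import torch.nn as nn
import torch.nn.functional as F

class StochasticLinear(nn.Module):
    # Hypothetical module wrapping the forward above.
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))

    def forward(self, x):
        stochastic_weight = F.softmax(self.weight, dim=0)
        return F.linear(x, stochastic_weight)

x = F.softmax(torch.randn(5), dim=-1)  # a point on the simplex
y = StochasticLinear(5, 3)(x)
assert torch.allclose(y.sum(), torch.tensor(1.0))  # output is on the simplex too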

Other Layers and Network Pieces

Residual Connections. We also need binary maps, $g: \Delta^{n-1} \times \Delta^{n-1} \to \Delta^{n-1}$. As a first pass, two operations on the simplex make sense here. First, elementwise multiplication plus renormalization is a natural default, i.e.,

\[
(x \oplus y)_i = \frac{x_i y_i}{\sum_j x_j y_j}.
\]

If $y$ is the uniform distribution then $x \oplus y = x$, so a "no-op" can be the default or initial operation. This could also be seen as a bias term in some sense.
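
A minimal sketch of this operation (the function name and the eps guard are my additions):

import torch

def mult_residual(x, y, eps=1e-8):
    # Elementwise product, renormalized to lie back on the simplex.
    z = x * y
    return z / (z.sum(dim=-1, keepdim=True) + eps)

x = torch.softmax(torch.randn(5), dim=-1)
u = torch.full((5,), 1 / 5)  # uniform distribution
assert torch.allclose(mult_residual(x, u), x, atol=1e-6)  # uniform y is a no-op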

Alternatively, we can map back to "real space" with a log, do a typical addition there, and then map back to the simplex:

\[
x \oplus y = \mathrm{softmax}(\log x + \log y).
\]

We can also add a fixed or learnable bias term $b$ inside the softmax, which can act as a trade-off parameter or as a way to understand the "influence" of certain parts of the network.
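
A sketch, under my assumption that the real-space round trip is the log / softmax pair (the bias handling is illustrative):

import torch

def log_space_residual(x, y, bias=None, eps=1e-8):
    # Map to real space via log, add there, and softmax back to the simplex.
    logits = torch.log(x + eps) + torch.log(y + eps)
    if bias is not None:
        logits = logits + bias  # fixed or learnable trade-off term
    return torch.softmax(logits, dim=-1)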

Attention. This should compose naturally because all of the operations are linear, but there is probably something more to do here, e.g., the outer product of the query-key vectors may correspond more closely to a joint distribution? Perhaps there is something more interesting here where token positions are more appropriately weighted by some optimal transport-based cost function.
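
One concrete reading of "composes naturally" (a sketch under my own assumptions, not from the post): the attention weights are already stochastic after the softmax, so if the value rows live on the simplex, each output row is a convex combination of simplex points and stays on the simplex.

import torch
import torch.nn.functional as F

def simplex_attention(q, k, v):
    # q, k: (seq, d); v: (seq, n) with each row on the simplex.
    scores = q @ k.T / q.shape[-1] ** 0.5
    attn = F.softmax(scores, dim=-1)  # each row is a stochastic vector
    return attn @ v  # convex combinations of simplex points stay on the simplex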

Other Notes, Connections, References
