Using the probabilistic method to bound the performance of toy transformers
post by Alex Gibson · 2025-01-21T23:01:38.067Z · LW · GW
Introduction:
Transformers are statistical artefacts. If a model achieves 99% accuracy, it might have learnt an algorithm that relies upon some property of the input data that only holds 99% of the time.
Let's say that our training input to the transformer is 100 independent unbiased coin tosses. Then the transformer might implicitly rely upon there being between 35 and 65 heads, because there is only a ~0.18% chance of a more extreme number of heads.
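As a quick sanity check on that figure (a minimal scipy snippet, not from the original post):

```python
from scipy.stats import binom

n, p = 100, 0.5
# P(fewer than 35 heads or more than 65 heads) in 100 fair coin tosses
tail = binom.cdf(34, n, p) + binom.sf(65, n, p)  # sf(65) = P(X >= 66)
print(f"P(heads < 35 or heads > 65) = {tail:.5f}")  # ~0.00179
```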
This potentially makes transformers vulnerable to adversarial inputs. But it also tells us how we should be approaching formal proofs. We should not necessarily expect transformers to learn algorithms that work on every input, but instead for them to learn algorithms which work on the "typical set" of inputs, like "the set of inputs with between 35 and 65 heads".
Then we can bound transformer performance in two stages:
1.) Identify a "typical set" of inputs, where we can show the transformer gets high accuracy
2.) Bound the volume of this "typical set" using probabilistic techniques (the two stages combine as in the inequality below).
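Concretely, if stage 1 shows that the transformer is correct on every input in the typical set $T$, and stage 2 gives an upper bound on $P(x \notin T)$ under the input distribution, the two combine into an accuracy bound:

$$P(\text{model correct}) \;\ge\; P(\text{model correct and } x \in T) \;=\; P(x \in T) \;\ge\; 1 - P(x \notin T),$$

so an upper bound on the probability of the atypical inputs translates directly into a lower bound on accuracy.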
max-of-k transformer:
To demonstrate this approach, I derive bounds on the performance of a toy one-layer attention-only model, the max-of-k transformer. This transformer takes as input a sequence of $k$ numbers between $0$ and $d_{\text{vocab}} - 1$ (writing $d_{\text{vocab}}$ for the vocabulary size), and is trained to output the maximum number in the sequence. It is a single-layer transformer with a single attention head, and has no layer norm or MLP layer.
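Since the architecture is this small, the whole forward pass fits in a few lines. Here is a minimal numpy sketch of a model of this shape (the matrix names, the $1/\sqrt{d_{\text{head}}}$ scaling, and the convention of reading the prediction off at the final position are my assumptions, not code from the trained model):

```python
import numpy as np

def toy_forward(tokens, E, P, W_Q, W_K, W_V, W_O, W_U):
    """Logits at the final position of a 1-layer, 1-head, attention-only
    transformer with no LayerNorm and no MLP.

    Shapes: E (d_vocab, d_model), P (n_ctx, d_model),
            W_Q, W_K, W_V (d_model, d_head), W_O (d_head, d_model),
            W_U (d_model, d_vocab).
    """
    k = len(tokens)
    x = E[tokens] + P[:k]                  # residual stream at each position
    q = x[-1] @ W_Q                        # query comes from the final position
    scores = (x @ W_K) @ q / np.sqrt(W_K.shape[1])
    attn = np.exp(scores - scores.max())
    attn /= attn.sum()                     # softmax attention weights
    head_out = attn @ (x @ W_V @ W_O)      # weighted sum of OV outputs
    return (x[-1] + head_out) @ W_U        # direct path + head output, unembedded
```

The predicted maximum is then just `np.argmax` of the returned logits.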
Intuitively, the transformer solves this task by using its QK circuit to attend more strongly to larger tokens, and its OV circuit to copy whatever it attends to. It attends most to the largest token, copies that token most strongly, and so outputs the maximum token as its highest logit.
But there are adversarial inputs where toy transformers trained on this task fail. For instance, take a sequence consisting of a single copy of a token $M$ together with many copies of $M - 1$. A sequence this extreme is very unlikely to appear in the training data, so the model won't be prepared to handle it. Even though the QK circuit attends slightly more to an individual $M$ than to an individual $M - 1$, the huge number of $(M-1)$'s in a single input sequence can overwhelm the transformer and lead it to erroneously output $M - 1$ rather than $M$.
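As a rough illustration of the failure mode (ignoring positional contributions, and writing $s_M$, $s_{M-1}$ for the attention scores on the two token values): with one copy of $M$ and $k - 1$ copies of $M - 1$, the attention weight on the maximum is

$$\frac{e^{s_M}}{e^{s_M} + (k-1)\,e^{s_{M-1}}} \;=\; \frac{e^{\delta}}{e^{\delta} + (k-1)}, \qquad \delta = s_M - s_{M-1},$$

which shrinks toward $0$ as $k$ grows. So if the score gap $\delta$ is small, the copying signal from the single $M$ gets drowned out by the crowd of $(M-1)$'s.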
We'd like a condition on the input sequence that measures how "extreme" it is, and that is simple enough that we can use probabilistic techniques to estimate how many "extreme" sequences there are.
For max-of-k, we can get a pretty simple condition by working through the algebra.
Equation for logits:
Let's say we have an input sequence of $k$ tokens between $0$ and $d_{\text{vocab}} - 1$, say $x_1, x_2, \dots, x_k$.
Then let $a_i = e^{s_i}$ be the exponentiated attention score at position $i$, where $s_i$ is the pre-softmax score that the final position places on position $i$.
The output from the single attention head is $\displaystyle\sum_{i=1}^{k} \frac{a_i}{\sum_{j=1}^{k} a_j}\,(E_{x_i} + P_i)\,VO$.
This accounts for the EVO circuit and the PVO circuit. But we also have the direct circuit due to E and P, which we can bring inside the sum (since the attention weights sum to 1), so that the logit for output token $t$, read off at the final position and writing $U$ for the unembedding, is:
$$\ell_t \;=\; \sum_{i=1}^{k} \frac{a_i}{\sum_{j=1}^{k} a_j}\,\Big((E_{x_i} + P_i)\,VOU + (E_{x_k} + P_k)\,U\Big)_t.$$
We are only interested in the relative ordering of the logits, and dividing every logit by the same positive number $\sum_j a_j$ doesn't change which one is largest, so we can cancel out the softmax denominator and just look at
$$\tilde\ell_t \;=\; \sum_{i=1}^{k} a_i\,\Big((E_{x_i} + P_i)\,VOU + (E_{x_k} + P_k)\,U\Big)_t.$$
Write $v_t(x_i, i)$ for $\big((E_{x_i} + P_i)\,VOU\big)_t$ and $d_t$ for $\big((E_{x_k} + P_k)\,U\big)_t$, so that $\tilde\ell_t = \sum_{i=1}^{k} a_i\,\big(v_t(x_i, i) + d_t\big)$.
Hyperplane form:
If we restrict to sequences whose maximum is $M$, then we are interested in finding conditions under which $\tilde\ell_M > \tilde\ell_m$ for all $m$ between $0$ and $M - 1$.
Now
$$\tilde\ell_M - \tilde\ell_m \;=\; \sum_{i=1}^{k} a_i\,\big(v_M(x_i, i) - v_m(x_i, i) + d_M - d_m\big).$$
Now we fix the final token $x_k$ as well as fixing $M$. Then the exponentiated attention score at position $i$ is a function $a(x_i, i)$ of the token $x_i$ and the position $i$ alone, and $d_M - d_m$ is a constant, so each summand above depends only on $(x_i, i)$.
Now bound each summand from below by its worst case over positions: write $y_{M,m}(x) = \min_{1 \le i \le k}\, a(x, i)\,\big(v_M(x, i) - v_m(x, i) + d_M - d_m\big)$. Then $\tilde\ell_m \ge \tilde\ell_M$ implies:
$$\sum_{i=1}^{k} y_{M,m}(x_i) \;\le\; 0.$$
Finally this is in a form we can work with, because the set of sequences satisfying this inequality is just the set of sequences which lie on one side of a certain hyperplane in sequence space.
Chernoff bounds:
It's in general quite hard to determine precisely the size of discrete sets which lie above hyperplanes, but we can bound their size from above using Chernoff bounds.
Let $Y_i = y_{M,m}(x_i)$ for $i = 1, \dots, k$, so that the condition above reads $\sum_{i=1}^{k} Y_i \le 0$.
Then for arbitrary $\lambda > 0$, we have:
$$P\Big(\sum_{i=1}^{k} Y_i \le 0\Big) \;=\; P\Big(e^{-\lambda \sum_i Y_i} \ge 1\Big) \;\le\; \mathbb{E}\Big[e^{-\lambda \sum_i Y_i}\Big] \;=\; \prod_{i=1}^{k} \mathbb{E}\big[e^{-\lambda Y_i}\big].$$
Where the inequality is Markov's inequality, and the last equality comes from independence of $Y_1, Y_2, \dots, Y_k$. Technically, because we fixed the maximum to be $M$, we don't strictly have independence, but we can fix this with some bookkeeping by partitioning sequences by the positions which have $M$'s in them.
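One way to spell out that bookkeeping (a sketch, assuming the tokens are i.i.d. uniform before conditioning): the event $\{\max_i x_i = M\}$ is the disjoint union, over nonempty sets $S \subseteq \{1, \dots, k\}$, of the events $\{x_i = M \text{ exactly for } i \in S\}$. So for any event $A$ about the sequence,

$$P\big(A \,\big|\, \max_i x_i = M\big) \;=\; \sum_{\emptyset \ne S} P\big(\{i : x_i = M\} = S \,\big|\, \max_i x_i = M\big)\; P\big(A \,\big|\, \{i : x_i = M\} = S\big),$$

and conditioned on $\{i : x_i = M\} = S$ the tokens outside $S$ are again independent and uniform on $\{0, \dots, M - 1\}$, so the product form above applies within each cell of the partition.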
There is no closed form for the optimal $\lambda$, but we can use gradient descent to obtain a good numerical value for $\lambda$.
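For concreteness, here is a sketch of that numerical step (assumptions: `y_vals` stands for the finite set of per-token values $y_{M,m}(x)$, taken as given, and I use scipy's bounded scalar minimizer in place of gradient descent):

```python
import numpy as np
from scipy.optimize import minimize_scalar

def chernoff_bound(y_vals, k):
    """Upper bound on P(sum_{i=1}^k Y_i <= 0), where each Y_i is drawn
    independently and uniformly from `y_vals`:
        min over lam > 0 of  E[exp(-lam * Y)] ** k."""
    y = np.asarray(y_vals, dtype=float)

    def log_bound(lam):
        # log of E[exp(-lam * Y)]**k, computed stably (log-sum-exp trick)
        m = (-lam * y).max()
        return k * (m + np.log(np.mean(np.exp(-lam * y - m))))

    res = minimize_scalar(log_bound, bounds=(1e-6, 50.0), method="bounded")
    return float(np.exp(res.fun))

# Toy usage with made-up per-token values whose mean is positive:
print(chernoff_bound([-1.0, 0.5, 1.0, 2.0], k=10))
```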
We can then sum this bound over the possible values of the fixed final token $x_k$, to get a probability just in terms of $M$ and $m$.
Full bounds:
Now we have
$$P(\text{the model outputs } M \mid \max_i x_i = M) \;=\; P\big(\tilde\ell_M > \tilde\ell_m \text{ for all } 0 \le m \le M - 1 \,\big|\, \max_i x_i = M\big).$$
And
$$P\big(\exists\, 0 \le m \le M - 1 :\; \tilde\ell_m \ge \tilde\ell_M \,\big|\, \max_i x_i = M\big) \;\le\; \sum_{m=0}^{M-1} p_{M,m},$$
where $p_{M,m}$ is our Chernoff bound from before, and we have used a union bound.
We should expect the union bound to work well, because there will tend to be only one $m$ with a greater logit than $M$, if such an $m$ exists at all. So we get a lower bound on $P(\text{the model outputs } M \mid \max_i x_i = M)$.
Then we have:
$$P(\text{the model is correct}) \;=\; \sum_{M} P\big(\max_i x_i = M\big)\, P\big(\text{the model outputs } M \,\big|\, \max_i x_i = M\big),$$
and we can get a lower bound on this by plugging in the per-$M$ lower bounds above.
When to use probabilistic bound:
This bound works best for longer sequences; for shorter sequences it's easier to just enumerate the sequences which lie above the hyperplane directly. The probabilistic method adapts easily to longer sequences, where manual enumeration quickly becomes infeasible.
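As a point of comparison, here is what exact enumeration looks like (a toy sketch: `y` again stands for hypothetical per-token values $y_{M,m}(x)$, and we count offending sequences by brute force over all sequences, which is only feasible for small $k$ and small vocabularies):

```python
import itertools

def exact_violation_fraction(y, k):
    """Exact fraction of length-k sequences over tokens 0..len(y)-1 whose
    summed per-token values are <= 0. Brute force: len(y)**k sequences."""
    bad = 0
    for seq in itertools.product(range(len(y)), repeat=k):
        if sum(y[t] for t in seq) <= 0:
            bad += 1
    return bad / len(y) ** k

y = [-1.0, 0.5, 1.0, 2.0]                  # hypothetical per-token values y_{M,m}(x)
print(exact_violation_fraction(y, k=8))    # 4**8 = 65,536 sequences: fast
# At k = 20 this is already 4**20 ~ 10**12 sequences, while the Chernoff
# bound from the previous section is a single one-dimensional optimization.
```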