Ambiguous out-of-distribution generalization on an algorithmic task

post by Wilson Wu (wilson-wu), Louis Jaburi (Ansatz) · 2025-02-13T18:24:36.160Z

Contents

  Introduction
    Setup
  Experiments
    Ambiguous grokking
      Grokking either group
      Grokking the intersect
      Grokking only one group
      No grokking
    Measuring complexity
      Complexity of the grokked solution
      Complexity over time
    Determination and differentiation
      Perturbation sensitivity
      Total variation
      Determination across distribution shift
      Training Jacobian
  Discussion

Introduction

It's now well known that simple neural network models often "grok" algorithmic tasks. That is, when trained for many epochs on a subset of the full input space, the model quickly attains perfect train accuracy and then, much later, near-perfect test accuracy. In the former phase, the model memorizes the training set; in the latter, it generalizes out-of-distribution to the test set.

In the algorithmic grokking literature, there is typically exactly one natural generalization from the training set to the test set. What if, however, the training set were instead under-specified in such a way that there were multiple possible generalizations? Would the model grok at all? If so, which of the generalizing solutions would it choose? If the model followed Occam's razor, it would choose the simplest solution -- but what does "simplest" mean here? We explore these questions for the task of computing a finite group operation.

Setup

This section assumes some basic familiarity with group theory and can be skipped. The point is just that each model is trained on the intersection of two datasets (the full multiplication tables of two groups $G_1$ and $G_2$), and that this intersection always makes up at least half of the full input space.

In existing work on grokking finite group operations (e.g. Chughtai et al.), a one-hidden-layer MLP model with two inputs is trained on the multiplication operation of a finite group $G$. The model takes as input a pair of elements $(a, b) \in G \times G$, each encoded as a one-hot vector in $\mathbb{R}^{|G|}$, and is expected to output logits maximized at the product $ab$. Thus, the input space is the set of all pairs $(a, b) \in G \times G$, and the model is evaluated on its accuracy over this entire space.
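
For concreteness, the following is a minimal sketch (in PyTorch) of the kind of model used in this line of work. The hidden width, the ReLU nonlinearity, and the choice to concatenate the two one-hot encodings are illustrative placeholders, not necessarily the exact architecture and hyperparameters used in our experiments.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class OneHiddenLayerMLP(nn.Module):
    """Maps a pair (a, b) of group elements, given as integer indices,
    to logits over the group, aiming to predict the index of the product a*b."""

    def __init__(self, group_order: int, hidden_width: int = 128):
        super().__init__()
        self.group_order = group_order
        self.net = nn.Sequential(
            nn.Linear(2 * group_order, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, group_order),
        )

    def forward(self, pairs: torch.Tensor) -> torch.Tensor:
        # pairs: integer tensor of shape (batch, 2); one-hot encode and concatenate.
        a = F.one_hot(pairs[:, 0], self.group_order).float()
        b = F.one_hot(pairs[:, 1], self.group_order).float()
        return self.net(torch.cat([a, b], dim=-1))
```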

Previous work finds that training on an iid subsample of 40% of the total input space (so the training set has $0.4|G|^2$ points) is enough to grok the full multiplication table for various choices of $G$. (The most well-studied choice is a symmetric group $S_n$.) In our setup, we leave everything the same except for the choice of training set. We now choose two groups $G_1$ and $G_2$ with the same underlying set of elements. Thinking of the two groups as two operations $\cdot_1$ and $\cdot_2$ on this set, our ambiguous training set is the set of all pairs of elements on which the two group operations agree:

$D_\cap = \{(a, b) : a \cdot_1 b = a \cdot_2 b\}.$

Note then that $D_\cap$ can be completed to the full multiplication table of either $G_1$ or $G_2$.

To ensure that there are enough elements in $D_\cap$ for grokking to be possible, we need to construct $G_1$ and $G_2$ with large overlap. One way is to set $G_1 = H \times \mathbb{Z}/2$ and $G_2 = H \rtimes_\varphi \mathbb{Z}/2$ for some group $H$ and automorphism $\varphi$ of $H$ with $\varphi^2 = \mathrm{id}$. Then, by construction, the two products agree whenever the $\mathbb{Z}/2$-component of the first element is trivial:

$(h_1, 0) \cdot_1 (h_2, a_2) = (h_1 h_2, a_2) = (h_1, 0) \cdot_2 (h_2, a_2).$

Thus, the overlap between the multiplication tables of $G_1$ and $G_2$ is at least 50%. (In fact, it is generally somewhat larger than 50%, owing to the fixed points of $\varphi$.) All examples we discuss will be of this form.
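
To make the construction concrete, here is a sketch of how the ambiguous training set can be built. For illustration we take $H = \mathbb{Z}/n$ with the inversion automorphism $\varphi(h) = -h$; this is just one instance of the recipe above, not necessarily a group pair used in our experiments.

```python
import numpy as np

def build_tables(n: int):
    """Multiplication tables of G1 = H x Z/2 (direct product) and
    G2 = H semidirect Z/2 via phi(h) = -h (an illustrative choice of H and phi).

    Elements (h, a) are indexed as 2*h + a, so both groups live on the same
    underlying set {0, ..., 2n - 1}."""
    order = 2 * n
    idx = lambda h, a: 2 * (h % n) + a
    t1 = np.zeros((order, order), dtype=int)  # direct product
    t2 = np.zeros((order, order), dtype=int)  # semidirect product
    for h1 in range(n):
        for a1 in range(2):
            for h2 in range(n):
                for a2 in range(2):
                    i, j = idx(h1, a1), idx(h2, a2)
                    t1[i, j] = idx(h1 + h2, (a1 + a2) % 2)
                    # In the semidirect product, the first factor acts on the
                    # second through phi when a1 = 1.
                    h2_twisted = h2 if a1 == 0 else -h2
                    t2[i, j] = idx(h1 + h2_twisted, (a1 + a2) % 2)
    return t1, t2

t1, t2 = build_tables(24)
agree = (t1 == t2)          # ambiguous training set: pairs on which the
pairs = np.argwhere(agree)  # two group operations coincide
print(agree.mean())         # overlap fraction, slightly above 0.5
```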

Experiments

Ambiguous grokking

Grokking either group

We run the training setup described above with a particular pair of groups of the form described in the previous section. We use the entire intersection of the two groups' multiplication tables as the training set and train 100 models with varying initialization seeds, all other hyperparameters held fixed. In this setup, the vast majority of models do not fully grok either solution (~90%), and instead just memorize the training set. However, we do find both models that grok one of the two groups (~4%) and models that grok the other (~6%).[1]

Accuracy of 100 models over training, evaluated on both groups.
Accuracy of 100 models over training, evaluated on the first group.
Accuracy of 100 models over training, evaluated on the second group.

Grokking the intersect

Although models often fail to grok either of the two groups, they always successfully grok the intersection: when we train models on an iid random 60% subset of the intersection (so 33% of the full input space), we find that they always attain full accuracy on the entire intersection, and in some cases full accuracy on one of the two groups as well.

Accuracy of 100 models trained on a 60% iid subset of the intersection between the two groups, evaluated on the full input space. Notice that all models first memorize the training set (33% of the input space), then grok the full intersection (55% of the input space). Some models then proceed to grok the entire input space.
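
As a reference for how these curves are produced, accuracy on the training subset, the full intersection, or either group's complete multiplication table can all be computed with a helper along the lines of the sketch below. The pair-indexed model interface and the table/mask arrays match the earlier sketches; the names are placeholders.

```python
import torch

@torch.no_grad()
def table_accuracy(model, table, mask=None):
    """Accuracy of `model` on the multiplication table `table` (a 2D integer
    array with table[i, j] = index of the product), optionally restricted to
    the entries where `mask` is True (e.g. the intersection or the train set)."""
    order = table.shape[0]
    i, j = torch.meshgrid(torch.arange(order), torch.arange(order), indexing="ij")
    all_pairs = torch.stack([i.reshape(-1), j.reshape(-1)], dim=1)
    labels = torch.as_tensor(table).reshape(-1)
    keep = torch.ones(order * order, dtype=torch.bool) if mask is None \
        else torch.as_tensor(mask).reshape(-1)
    preds = model(all_pairs[keep]).argmax(dim=-1)
    return (preds == labels[keep]).float().mean().item()

# e.g. table_accuracy(model, t1), table_accuracy(model, t2),
#      table_accuracy(model, t1, mask=agree)
```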

Grokking only one group

When we instead choose a pair in which $H$ is abelian, so that $G_1 = H \times \mathbb{Z}/2$ is commutative while $G_2 = H \rtimes_\varphi \mathbb{Z}/2$ is not (in this case, the intersection size is 75% of the input space), the model only groks $G_1$. Our intuition is that $G_1$, being commutative, is much simpler than $G_2$, and thus may be preferred by the model. This may just be a coincidence, however: we have not been able to find an intrinsic measure of group complexity such that models always prefer simpler groups. See more discussion below.

Accuracy of 100 models over training, evaluated on $G_1$.
Accuracy of 100 models over training, evaluated on $G_2$.

No grokking

When we choose yet another pair of groups of this form (intersection size 51% of the input space), the model never groks either group. However, this example is a little unusual: the proportion of label classes in the intersection training set is non-uniform, and thus differs from the proportion over the entire input space (for either of the two groups). We speculate that this class imbalance may be the reason for the lack of grokking.

Measuring complexity

Complexity of the grokked solution

Intuitively, in cases where the model may grok either of two generalizing solutions, it should prefer the simpler of the two. There are two distinct things that we could mean by a solution implemented by a model being simple:

1. The solution is implemented by an efficient circuit, for instance one realizable with low weight norm (the notion of efficiency underlying the margin-maximization analysis of Morwani et al.).
2. The solution has a low local learning coefficient (LLC), the model complexity measure from singular learning theory.

To test whether the LLC or some measure of circuit efficiency is able to predict which group the model groks, we train 10,000 models on the intersection of a pair of groups of the form described in the setup. In this setting, we again find that most models don't grok (93%). Among those that do grok, there is a strong preference for one of the two groups (6.5%) over the other (0.5%). We then record the training loss, weight norm, and estimated LLC of all models at the end of training.
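
To make the LLC estimation procedure more concrete, here is a rough sketch of an SGLD-based estimator in the style of Lau et al. (libraries such as devinterp implement this more carefully). The step size, localization strength, and chain length below are placeholders that would need tuning per model, and we use full-batch gradients since the algorithmic datasets are small.

```python
import copy
import math
import torch

def estimate_llc(model, loss_fn, xs, ys,
                 eps=1e-4, gamma=100.0, n_steps=500, n_burnin=200):
    """Rough SGLD-based estimate of the local learning coefficient (LLC).

    Implements hat_lambda = n * beta * (E[L_n(w)] - L_n(w*)), where the
    expectation is over an SGLD chain sampling from a posterior localized
    around the trained parameters w*. Hyperparameters are illustrative."""
    n = len(xs)
    beta = 1.0 / math.log(n)

    def full_loss(m):
        return loss_fn(m(xs), ys)

    with torch.no_grad():
        loss_star = full_loss(model).item()

    sampler = copy.deepcopy(model)                              # chain state w
    anchor = [p.detach().clone() for p in model.parameters()]   # w*

    running, kept = 0.0, 0
    for step in range(n_steps):
        loss = full_loss(sampler)
        grads = torch.autograd.grad(loss, list(sampler.parameters()))
        with torch.no_grad():
            for p, g, a in zip(sampler.parameters(), grads, anchor):
                drift = -(eps / 2) * (n * beta * g + gamma * (p - a))
                p.add_(drift + math.sqrt(eps) * torch.randn_like(p))
            if step >= n_burnin:
                running += full_loss(sampler).item()
                kept += 1

    return n * beta * (running / kept - loss_star)
```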

All three measures are somewhat predictive of which group is grokked, in a direction that aligns with our intuition that the model prefers simpler solutions. Models that grok the preferred group (6.5%) tend to have lower training loss, weight norm, and LLC, and models that grok the other group (0.5%) tend to have larger LLC and training loss, though still lower than what is typical for models that do not grok at all. However, we also find plenty of examples of models that do not grok either group yet still have low LLC and training loss. Possibly, these models are finding solutions that do not correspond to either of the two groups (or any group), yet still are "simple" in some appropriate sense.[3]

Surprisingly, however, we observe that final training loss and weight norm are moderately correlated, and that final training loss and LLC are highly correlated.

Training loss vs. weight norm for 10,000 models trained on the intersection of the two groups.
Training loss vs. LLC for a subset of the 10,000 models trained on the intersection of the two groups.

The correlation between training loss and LLC is especially unexpected to us, and we do not have a good explanation for its presence. Since (to our knowledge) this correlation has not been noted in other settings, we suspect that it is a quirk specific to one-hidden-layer MLPs trained on groups. In any case, while our results suggest that models prefer simpler solutions as measured by some combination of circuit efficiency and LLC, this correlation means that we cannot disentangle the two measures in the groups setting.

Complexity over time

Besides checking model complexity at the end of training, we also measure it across training time. For models trained on the intersection of two groups, the LLC tracks grokking in the sense that, whenever the model groks either of the two groups, the LLC decreases sharply at the same time that the test accuracy rises. For models that do not grok either group, we observe both examples where the LLC stays large throughout training and examples where it falls. As mentioned above, we speculate that the latter cases correspond to the model learning some simple solution distinct from either of the two groups.

Accuracy on both groups, along with the estimated LLC, for example model instances that grok one group (1), the other group (2), and neither (3, 4).

Are there any complexity measures that track grokking over time better than the LLC? We check those listed in Jiang et al. For simplicity, we measure these across checkpoints for a model trained on an iid subset of a single group's multiplication table instead of an intersection set -- such models consistently grok. We notice that the LLC estimates in the iid case tend to be smoother over time than in the intersection experiments.

Complexity measures over training on an iid subset of a single group's multiplication table. See Jiang et al. for definitions of the metrics measured.

From these plots, it appears that LLC (bottom left) best tracks the generalization error, measured as the difference between train accuracy and test accuracy. Inverse margin and training loss (last two plots) also do well (perhaps this is related to the high correlation between training loss and LLC at the end of training, discussed above), but they are both large at the start of training, when generalization error is low because both training and test accuracy are low. The LLC is correctly low both at the beginning of training, before memorization, and at the end, after grokking.
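
As an illustration of how such measures are computed across checkpoints, here is a sketch for two of the simpler ones (the generalization gap and the global weight norm); see Jiang et al. for the precise definitions of the remaining metrics. The checkpoint and dataset handling below is a placeholder.

```python
import torch

@torch.no_grad()
def accuracy(model, xs, ys):
    return (model(xs).argmax(dim=-1) == ys).float().mean().item()

@torch.no_grad()
def weight_norm(model):
    return torch.sqrt(sum(p.pow(2).sum() for p in model.parameters())).item()

def measures_over_checkpoints(make_model, checkpoints, train_set, test_set):
    """`checkpoints` is a list of state_dicts saved during training; `make_model`
    builds a fresh model of the matching architecture. Returns one row of
    measures per checkpoint."""
    rows = []
    for state in checkpoints:
        model = make_model()
        model.load_state_dict(state)
        gap = accuracy(model, *train_set) - accuracy(model, *test_set)
        rows.append({"gen_gap": gap, "weight_norm": weight_norm(model)})
    return rows
```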

Determination and differentiation

Perturbation sensitivity

Somewhat separately from the previous investigations, one might wonder when the model "decides" which of the groups (if any) it will grok. In a literal sense, the answer is that the model's fate is determined at initialization, because in our experiments there is no stochasticity during training (we use full-batch gradient descent). However, this is not really a satisfying answer. Rather, when we say that the model has "decided" on a future solution, we should expect that, from the decision point onwards, its decision is robust: small perturbations in training should not be able to make the model "change its mind". 

Hence, to measure the point at which a model decides its fate, we make small Gaussian perturbations to the model weights at evenly spaced checkpoints throughout training. We then continue training anew from these perturbed weights. We find evidence that the model makes a decision not at initialization but still well before its choice is apparent in its test-set behavior.[4]

Sensitivity to perturbations at initialization, for a model trained on the intersection of two groups.
Sensitivity to perturbations throughout training. Perturbation size is 0.01.
Test-set accuracy on the two groups (blue and orange).

The example above is particularly interesting in that the model briefly veers towards one of the two groups around epoch 600, corresponding to a bump in perturbation sensitivity, before returning to its path towards the group it eventually groks.[5]
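
The perturbation experiments follow a simple recipe, sketched below; `resume_training` stands in for the full-batch training loop described earlier, and the noise scale corresponds to the perturbation size quoted in the figures.

```python
import copy
import torch

def perturb_and_resume(checkpoint_model, resume_training, sigma=0.01, n_runs=10):
    """Add isotropic Gaussian noise to a checkpoint's weights, then keep training.

    Comparing which (if any) group the perturbed runs end up grokking against
    the unperturbed run measures how "determined" the model's fate is at that
    checkpoint."""
    outcomes = []
    for _ in range(n_runs):
        perturbed = copy.deepcopy(checkpoint_model)
        with torch.no_grad():
            for p in perturbed.parameters():
                p.add_(sigma * torch.randn_like(p))
        outcomes.append(resume_training(perturbed))
    return outcomes
```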

Total variation

For models trained on the intersection of two groups, we notice that those that grok either of the two groups tend to have more stable outputs late in training than those that grok neither. We quantify this by measuring the total variation of the logits over a fixed training interval:

$\mathrm{TV}(t_0, t_1) = \sum_{t = t_0}^{t_1 - 1} \lVert f_{\theta_{t+1}} - f_{\theta_t} \rVert,$

where $\theta_t$ are the model parameters at epoch $t$ and $f_\theta$ denotes the logits the model produces over the full input space. For models trained on a group intersection for 3000 epochs, we measure the total variation over the final stretch of training.

Total variation of models trained on the intersection of two groups.

Models that grok either of the two groups always "lock in" their solutions -- the functions they compute no longer change after grokking. Those that grok neither often continue to oscillate in logit space even late in training.[6] However, similarly to the case with the complexity measure experiments, there are many examples of models that grok neither and yet still have zero total variation, possibly corresponding to simple solutions, distinct from both groups, that we are unaware of.

Note also that this is a purely test-set behavior. On the training set, all models have zero total variation by the end of training: once they attain perfect training accuracy, they no longer vary their training-set outputs.
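
Concretely, the total variation above can be computed from saved checkpoints as in the following sketch (we use the Frobenius norm of the change in the full logit table; the exact choice of norm is a detail we do not pin down here).

```python
import torch

@torch.no_grad()
def total_variation(make_model, checkpoints, all_pairs, t0, t1):
    """Sum of the norms of the change in logits between consecutive epochs.

    `checkpoints[t]` is the state_dict at epoch t and `all_pairs` is the full
    input space (every pair of group elements); both names are placeholders."""
    def logits_at(t):
        model = make_model()
        model.load_state_dict(checkpoints[t])
        return model(all_pairs)

    tv, prev = 0.0, logits_at(t0)
    for t in range(t0 + 1, t1 + 1):
        cur = logits_at(t)
        tv += torch.linalg.norm(cur - prev).item()
        prev = cur
    return tv
```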

Determination across distribution shift

As seen above, the models that are trained on an intersection dataset and eventually grok have their fate stably determined relatively early in the training process. Is this stability a property of all possible parameter settings that implement the same solution? We investigate this question by first training models on iid data subsampled from one of the two groups so that the models consistently grok; we then "transplant" these models to the intersection of the two groups and resume training. Perhaps surprisingly, many model instances (7%) partially ungrok the group they originally learned, while retaining perfect accuracy on the intersection. A few instances then proceed to regrok it, returning to perfect test accuracy later in training.[7] Repeating the same experiment with the roles of the two groups swapped results in the same behavior (8% ungrokked).

Models first trained on iid data from one group until grokked, then trained on the intersection of the two groups. The plot shows accuracy on the originally grokked group during the intersection phase of training.
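
The transplant experiment itself is just a change of training set partway through, as in this sketch; `train_step` and `evaluate` are placeholders for one full-batch gradient step and for evaluation on each group's full multiplication table, and the epoch counts are illustrative.

```python
def transplant(model, train_step, evaluate, iid_data, intersection_data,
               n_epochs_iid=1000, n_epochs_intersection=3000):
    """Train on an iid subsample of one group's table until grokked, then
    resume training the same model on the ambiguous intersection set,
    recording test behavior throughout the second phase."""
    for _ in range(n_epochs_iid):
        train_step(model, iid_data)
    history = []
    for _ in range(n_epochs_intersection):
        train_step(model, intersection_data)
        history.append(evaluate(model))
    return history
```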

Training Jacobian

The training Jacobian (Belrose & Scherlis) is the Jacobian of the training map from initial model parameters to final model parameters.[8] Thus, if the model has $d$ parameters, then the training map is a function $\mathbb{R}^d \to \mathbb{R}^d$ and the training Jacobian is a $d \times d$ matrix. In experiments with 25 epochs of training and without weight decay (hence no grokking), Belrose & Scherlis find that the training Jacobian preserves most directions in parameter space (corresponding to singular values close to 1 and left singular vectors approximately equal to right singular vectors). The subspace approximately preserved by the training Jacobian is referred to as the bulk.

Singular values of an example training Jacobian. Most singular values are close to 1. Source: Belrose & Scherlis.
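
For small models, the training Jacobian can be computed with automatic differentiation by treating a full training run as a function of the initial parameters and differentiating through it. The sketch below (using torch.func, not the authors' code from Belrose & Scherlis) illustrates the idea for a short full-batch gradient descent loop; the learning rate, step count, and cross-entropy loss are placeholders.

```python
import torch
from torch.func import functional_call, grad, jacrev

def train_map(theta0, model, x, y, lr=1e-2, n_steps=25):
    """Flattened initial parameters -> flattened parameters after n_steps of
    full-batch gradient descent. Differentiating this map with jacrev gives
    the training Jacobian."""
    shapes = {k: v.shape for k, v in model.named_parameters()}

    def unflatten(flat):
        out, i = {}, 0
        for k, s in shapes.items():
            out[k] = flat[i:i + s.numel()].reshape(s)
            i += s.numel()
        return out

    def loss_fn(params):
        logits = functional_call(model, params, (x,))
        return torch.nn.functional.cross_entropy(logits, y)

    params = unflatten(theta0)
    for _ in range(n_steps):
        g = grad(loss_fn)(params)                        # gradients as a dict
        params = {k: params[k] - lr * g[k] for k in params}
    return torch.cat([p.reshape(-1) for p in params.values()])

# Usage (placeholder model and data):
# theta0 = torch.cat([p.detach().reshape(-1) for p in model.parameters()])
# J = jacrev(train_map)(theta0, model, x_train, y_train)   # d x d matrix
# svals = torch.linalg.svdvals(J)
```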

We compute the training Jacobian in our groups setting and observe that its largest singular values grow over the course of training, seemingly without leveling off. The results are similar across both models trained on group intersections and those trained on iid subsamples of one group.

Singular values of the training Jacobian over time, for a model trained on an iid sample from a single group.

Assuming this growth isn't an artifact of numerical instability (which is entirely possible), we think it suggests that the limiting infinite-training map, sending initial parameters to the fully converged final parameters, is discontinuous as a function $\mathbb{R}^d \to \mathbb{R}^d$. That is, arbitrarily small perturbations to the initial weights might cause non-negligible changes to the model's final weights.[9] When the training set is iid, the different resulting model parameters all lead to the same model behavior -- neural networks are non-identifiable. When the training set is under-determined, these changes in weight space may manifest as changes in model behavior.

As an aside, we note that this high sensitivity to model initialization seems to somewhat contradict (a sufficiently strong version of) the lottery ticket hypothesis. If there really existed a subnetwork at initialization that is significantly more amenable to training than the rest of the model, we'd expect its prominence to be at least somewhat robust to perturbations.

Discussion

We anticipate that ambiguous out-of-distribution generalization is a phenomenon that may arise not only in toy algorithmic tasks but also in more realistic settings. Indeed, recent work (Qin et al., Mészáros et al., Reizinger et al.) finds that the ability of language models to apply grammatical rules generalizes unstably out-of-distribution. Our concern is that ambiguous generalization may pose serious obstacles to the safety and alignment of LLM-based AI systems. Safety training for large language models is typically performed with datasets much smaller than those used for pre-training. Thus, safety training may be significantly more prone to under-specification in a manner leading to ambiguous generalization. Such under-specification may manifest in what is called deceptive alignment: the model appears to its designers to be well-aligned within the training environment, but in reality learns a solution that is harmful in deployment. Heuristic counting arguments (Hubinger, Carlsmith) suggest that, in the presence of such ambiguity, true alignment may be vanishingly unlikely compared to deceptive alignment or scheming.

Our main motivation for studying models trained on the intersection of two groups was to 1) exhibit a crisp and concrete example of ambiguous generalization and 2) use the setting as a testbed for various hypotheses surrounding this phenomenon. Regarding 2), our hope was to relate some precise measure of a solution's complexity to the frequency with which it is learned, thus providing a quantitative version of the aforementioned counting arguments. While we were not able to fully attain this goal, we did find some evidence in favor of the local learning coefficient and circuit complexity. On the other hand, we were not able to disentangle these two measures in this toy setting, perhaps suggesting the need to move to other, more realistic experiments. Regardless, we continue to believe that ambiguous out-of-distribution generalization is an important, safety-relevant phenomenon that merits further study.

Acknowledgements: This work was mainly conducted during MATS 6.0 and 6.1. Many thanks to Jesse Hoogland, Daniel Filan, and Jacob Drori for helpful conversations. Wilson Wu was supported by an LTFF grant.

 

  1. ^

    Since in this experiment the sample size is small and grokking is somewhat rare, the proportion estimates should be treated as fairly low-confidence. In particular, we don't claim that the model prefers to grok one of these two groups over the other. Our 10,000-model experiments in the complexity section have a 100x larger sample size, and thus for those groups we are able to draw more confident conclusions about model preferences.

  2. ^

    One hypothesis we had was that the model would prefer the group with the lower value of $\sum_\rho (\dim \rho)^p$, summing over irreducible representations $\rho$, for some exponent $p$. This quantity is larger for groups with larger irreps, and in particular is minimized for abelian groups. It also appears, for a particular exponent, in the expression for maximum margin computed in Theorem 9 of Morwani et al.'s paper. However, in the limited experiments we ran, this hypothesis seemed not to pan out.

  3. ^

    The functions computed by non-grokking models on the full input space are diverse and, in general, not simply combinations of the two group operations.

  4. ^

    One might draw an analogy to cellular determination. "When a cell 'chooses' a particular fate, it is said to be determined, although it still 'looks' just like its undetermined neighbors. Determination implies a stable change -- the fate of determined cells does not change." (source)

  5. ^

    A speculative, likely incorrect cartoon picture: around epoch 300, the model is funneled into a slide in the loss landscape whose final destination is the solution for the group it eventually groks. Around epoch 600, there's a sharp bend in the slide (and/or it becomes narrower and/or shallower...). If the model takes the bend too quickly, it shoots over the edge and lands in the other group's solution.

  6. ^

    We use a constant learning rate schedule. Likely this effect would not have been as apparent if we had instead used, say, cosine annealing.

  7. ^

    Performing some cursory mechanistic interpretability on these examples, we find that the original grokked parameters and the regrokked parameters generally differ, though they do tend to share many neurons in common (in the sense that the same neuron indices are supported in the same irreps). Since regrokking is a rare phenomenon that we only quickly spot-checked, we don't claim that this observation is representative.

  8. ^

    It follows immediately from the chain rule that, assuming the final model parameters are at a local minimum, the image of the training Jacobian must lie in the null space of the loss Hessian at the final parameters. That is, the training Jacobian is only nontrivial if the model is singular.

  9. ^

    On the other hand, it is probably not too hard to show that any finite number of iterations of SGD, Adam, etc., is continuous. Only in the infinite training limit can discontinuities appear.
