The slingshot helps with learning

post by Wilson Wu (wilson-wu) · 2024-10-31

Contents

  1 Introduction
    1.1 The k-sparse parity task
    1.2 The slingshot effect
  2 Why do slingshots happen in the k-sparse parity setting?
    2.1 Hinge loss
    2.2 Cross-entropy loss
    2.3 Resetting AdamW mitigates slingshots
  3 Slingshots and generalization
    3.1 Slingshots tend to decrease test loss
    3.2 Slingshots vs random jumps
    3.3 Slingshots and LLC
  4 Discussion

The slingshot effect is a late-stage training anomaly found in various adaptive gradient optimization methods. In particular, slingshots are present with AdamW, the optimizer most widely used for modern transformer training. The original slingshot paper observes that slingshots tend to occur alongside grokking, a phenomenon in which neural networks trained on algorithmic tasks generalize to the test set long after perfectly fitting the training set. In this post we take a closer look at slingshots and their effect on generalization in the setting of 1-hidden-layer MLPs trained on $k$-sparse parity, a specific algorithmic task. The main results are 1) an explanation of why slingshots occur in models trained with hinge loss that partially transfers to models trained with cross-entropy loss and 2) empirical evidence that slingshots are biased towards decreasing test loss.

1 Introduction

1.1 The $k$-sparse parity task

A $k$-sparse parity function is a function that takes in a bitstring of length $n$ and returns the parity (XOR) of a specific size-$k$ subset of the input bits. Thus, for fixed $n$ and $k$, there are $\binom{n}{k}$ different $k$-sparse parity functions. We start by sampling one of these functions at random; the model's task is then to learn this function from a training set of input/output sample pairs. Intuitively, the heart of this problem is distinguishing from the training samples alone which of the $n$ input bits are relevant. Learning $k$-sparse parities is a well-studied task with known computational[1] lower bounds. See Barak et al. 2023 for more details.

In our experiments, we train a 1-hidden-layer ReLU MLP on this task. Input bitstrings of length $n$ are encoded as vectors in $\mathbb{R}^n$ for input to the model. We fix $n$, $k$, and the MLP hidden dimensionality throughout. All experiments (unless noted otherwise) use the AdamW optimizer with a fixed learning rate and all other hyperparameters left as default.
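For concreteness, here is a minimal sketch of this setup in PyTorch. The specific values of $n$, $k$, the hidden width, and the learning rate below are placeholders, and the $\pm 1$ input encoding is an assumption; see the notebook linked at the end of the post for the actual experimental code.

```python
# Minimal sketch of the k-sparse parity setup. Hyperparameter values and the
# +/-1 input encoding are placeholders/assumptions, not the post's exact choices.
import torch
import torch.nn as nn

n, k, width, num_train = 40, 3, 1024, 1000  # placeholder values

# Sample a random k-sparse parity function, i.e. a hidden subset S of k indices.
S = torch.randperm(n)[:k]

def sample_batch(num_samples):
    # Random bitstrings encoded as +/-1 vectors; the label is the parity (XOR)
    # of the k bits indexed by S.
    x = 2.0 * torch.randint(0, 2, (num_samples, n)).float() - 1.0
    y = ((x[:, S] == -1.0).sum(dim=1) % 2).long()
    return x, y

x_train, y_train = sample_batch(num_train)

# 1-hidden-layer ReLU MLP with two output logits (for cross-entropy training).
model = nn.Sequential(nn.Linear(n, width), nn.ReLU(), nn.Linear(width, 2))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)  # placeholder lr; other hyperparameters default
```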

1.2 The slingshot effect

A slingshot is a spike in training loss co-occurring with a sudden increase in weight norm. In our experiments, we observe that it is often also associated with a sudden drop in test loss.

Figure 1: Example of a training run exhibiting a slingshot (red highlight). Blue curve is training loss, red curve is test loss, and green curve is input layer weight norm. In this example only one slingshot occurs, but in general there may be multiple. Model is trained on cross-entropy loss.

2 Why do slingshots happen in the $k$-sparse parity setting?

We first consider models trained with hinge loss, for which we have an explanation of the slingshot phenomenon. Then we consider models trained with cross-entropy loss; there the picture is somewhat messier, but some aspects of the same explanation still transfer over.

When discussing AdamW's internal state we will use the same notation as in the pseudocode here.

2.1 Hinge loss

For hinge loss, we remap the expected model output from 0/1 to $\pm 1$. The loss given model output $\hat{y}$ and ground truth label $y \in \{-1, +1\}$ is then $\max(0,\, 1 - y\hat{y})$. As opposed to the case with cross-entropy loss, with hinge loss there are wide (topologically open) regions of output space, and thus parameter space, for which the loss (and hence the gradient) is precisely zero.
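As a sketch (assuming the standard form of the hinge loss written above, with a mean reduction over the batch):

```python
import torch

def hinge_loss(output, target):
    # output: real-valued model scores; target: labels in {-1, +1}.
    # The loss (and its gradient) is exactly zero whenever target * output >= 1,
    # giving the open zero-loss regions discussed below.
    return torch.clamp(1.0 - target * output, min=0.0).mean()
```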

The story goes:

  1. Training proceeds until the model perfectly fits the training data and model parameters enter a zero training loss basin.
  2. Within this basin, because there is no gradient, momentum quickly decays to zero. Movement in parameter space is then driven entirely by weight decay. Meanwhile, the variance estimate decays to zero more slowly. (The default AdamW decay factors, and the ones we use in our experiments, are $\beta_1 = 0.9$ for momentum and $\beta_2 = 0.999$ for variance.)
  3. Eventually, weight decay causes the model parameters to hit a wall of the basin; i.e., they have shrunk to the point where they no longer maintain perfect accuracy on the training set. The gradient $g_t$ is no longer zero; momentum is updated to approximately $(1-\beta_1)\,g_t$, variance to approximately $(1-\beta_2)\,g_t^2$,[2] and the parameters by approximately $-\gamma\,(1-\beta_1)\,g_t \,/\, \big(\sqrt{1-\beta_2}\,|g_t|\big)$ (elementwise).
  4. This large parameter update (a factor of $\frac{1-\beta_1}{\sqrt{1-\beta_2}} \approx 3.16$ larger than is typical[3]; see the numerical sketch after this list) slingshots the model parameters out of the zero loss basin. Geometrically, the slingshot must be in the direction of increasing weight norm, because the parameters hit the basin's wall while moving in the direction of decreasing weight norm.
  5. If the gradient wherever the model parameters land happens to be aligned with $g_t$, then the effect is compounded.[4] E.g., if it so happens that $g_{t+1} \approx g_t$, then in the next step the numerator of the parameter update is roughly doubled, while the denominator is multiplied by a factor of roughly $\sqrt{2}$. Immediately after the slingshot, the numerator scales linearly with the number of following steps aligned with the initial gradient, while the denominator scales with the square root.
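The factor in Step 4 can be checked with a short standalone simulation of the AdamW moment recursions. This is illustrative code, not the post's training code; bias corrections are omitted since they are negligible late in training (see footnote 2).

```python
import numpy as np

beta1, beta2, lr, eps = 0.9, 0.999, 1e-3, 1e-8  # PyTorch AdamW defaults

m, v = 1.0, 1.0           # arbitrary nonzero moments when the zero-loss basin is entered
for _ in range(50_000):   # Step 2: zero gradient inside the basin, so both moments decay
    m, v = beta1 * m, beta2 * v

g = 1e-4                  # Step 3: the basin wall is hit and a small gradient reappears
m = beta1 * m + (1 - beta1) * g
v = beta2 * v + (1 - beta2) * g**2
step = lr * m / (np.sqrt(v) + eps)
print(step / lr)          # ~ (1 - beta1) / sqrt(1 - beta2) ~= 3.16, roughly independent of |g|
```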

Note that, with AdamW, the weight decay is decoupled from the momentum and variance estimates. Thus, if the gradient at each step is near zero, both these estimates decay towards zero. Simultaneously, the model still moves in parameter space under the effect of weight decay.

If we were using Adam instead, the corresponding condition would be for the loss gradient and weight decay terms to cancel each other out.  In this case, the model parameters would necessarily be near-stationary and unable to "drift into a wall" as in Step 3 above. Indeed, we do not find any slingshots when using the Adam optimizer. Note that this seems to contradict the results of Thilak et al.; see Sec. 4 for further discussion of this and other discrepancies.

Figure 2: Model trained with hinge loss. In this example, there are three slingshots. Blue line is train loss, red line is test loss, and green line is weight norm.
Figure 3: Same example as in Fig. 2. Transparent green line is gradient norm and orange line is AdamW variance estimate. When the model enters a zero loss plateau, the gradient is zero and the variance estimate decays exponentially to zero. (Note that gradient and variance estimate norm are on a log scale.)
Figure 4: Same as Fig. 3 but with inclusion of the momentum estimate norm (purple line). There are two points in training where the model parameters hit a wall of the basin (gradient and momentum spike) but there is no associated slingshot. Possibly this is because the variance estimate had not yet decayed to a low enough value.
Figure 5: Same example as in Fig. 2. In purple is parameter update size $\|\theta_t - \theta_{t-1}\|$. Update size during the slingshot is larger than at the start of training.
Figure 6: Same example as in Fig. 2. In orange is the cosine similarity between the parameter update $\theta_t - \theta_{t-1}$ and the parameters $\theta_{t-1}$. Negative values indicate shrinkage of parameters driven by weight decay, while positive values indicate movement in the direction of increasing weight norm. Note that the largest positive values occur during the slingshot.

2.2 Cross-entropy loss

When using cross-entropy loss, the training loss and its gradient are never zero. Rather, once the model has attained perfect train accuracy, the loss gradient and weight decay are competing forces: the model can always decrease cross-entropy loss by scaling up weights. Yet, slingshots still do occur, and still appear alongside 1) spikes in gradient, momentum, and variance; and 2) large optimizer steps taken in the direction of increasing weight norm.
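To spell out the competing-forces claim at the level of the logits (a standard observation, not specific to this post's setup): writing $z \in \mathbb{R}^C$ for the logits on a training example with correct class $y$,

$$\mathrm{CE}(z, y) \;=\; -\log \frac{e^{z_y}}{\sum_j e^{z_j}} \;=\; \log\Big(1 + \sum_{j \neq y} e^{z_j - z_y}\Big).$$

Once the model fits the training set, $z_j - z_y < 0$ for every $j \neq y$ on every training example, so replacing $z$ by $\alpha z$ with $\alpha > 1$ (e.g., by scaling up the final-layer weights) strictly decreases the loss, with the loss tending to zero as $\alpha \to \infty$; weight decay pushes in the opposite direction.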

Figure 7: Same example as in Fig. 1, now plotting gradient norm (green), momentum (purple), and variance (orange). Although the gradient never vanishes, late in training it is small enough that the variance still decays exponentially towards zero.
Figure 8: Same example as in Fig. 1. In purple is parameter update size $\|\theta_t - \theta_{t-1}\|$. Update size during the slingshot is larger than at the start of training.
Figure 9: Same example as in Fig. 1. In orange is the cosine similarity between the parameter update $\theta_t - \theta_{t-1}$ and the parameters $\theta_{t-1}$. Note that the largest positive values occur during the slingshot.

2.3 Resetting AdamW mitigates slingshots

The explanation given in Steps 4 and 5 of Sec. 2.1 relies on the number of iterations $t$ being large and thus the normalizations of the momentum and variance by $1-\beta_1^t$ and $1-\beta_2^t$ being negligible. Thus, if we were to reset AdamW's state right before a slingshot, we should expect to find the instability substantially mitigated. This is indeed what we find, although the spikes in train and test loss do not entirely disappear (Fig. 10). A reset of AdamW's internal state is itself a discontinuity in the training process, which may explain why the loss curves are not entirely smoothed out. Note also that this iteration-count resetting trick only temporarily prevents the slingshot, so it needs to be applied several times over the course of a training run.
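A sketch of the reset intervention in a standard PyTorch training loop (this is an assumption about the implementation, not the post's exact code; `should_reset` stands in for whatever slingshot-detection criterion triggers the reset):

```python
import torch

def maybe_reset_adamw(model, optimizer, should_reset, lr=1e-3):
    # Re-instantiating the optimizer clears the momentum and variance estimates
    # and the step count t, so the bias corrections 1 - beta1^t and 1 - beta2^t
    # are no longer negligible immediately after the reset.
    if should_reset:
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    return optimizer
```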

Figure 10: Model trained with cross-entropy loss from the same initial seed as in Fig. 1, but with AdamW state reset before slingshots occur. Blue line is training loss, red line is test loss, and green line is weight norm. Note that final test loss is significantly higher than that in Fig. 1.

3 Slingshots and generalization

3.1 Slingshots tend to decrease test loss

We carry out 100 training runs on the $k$-sparse parity task with cross-entropy loss, and automatically detect slingshots. We look at the difference in test loss before and after each slingshot, and compare these to the difference in test loss before and after intervals of the same length without slingshots. We find that slingshots tend to reduce test loss more than training for the same amount of time without slingshots (Fig. 11).
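For reference, the kind of threshold-based detector we have in mind is sketched below; the two loss thresholds are hypothetical placeholders, and footnote 6 discusses why this choice is somewhat arbitrary.

```python
def detect_slingshot_intervals(train_losses, low=1e-4, high=1e-1):
    # Hypothetical placeholder thresholds. An interval starts when the training
    # loss, having previously dropped below `low`, exceeds `high`; it ends when
    # the loss drops back below `low`.
    intervals, start, armed = [], None, False
    for t, loss in enumerate(train_losses):
        if start is None:
            if loss < low:
                armed = True
            elif armed and loss > high:
                start = t
        elif loss < low:
            intervals.append((start, t))
            start, armed = None, True
    return intervals
```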

3.2 Slingshots vs random jumps

Why do slingshots help with generalization? One possible story is that they are simply a mechanism for jumping out of poorly-generalizing local minima in the training loss landscape. One might then further suppose that a jump in any arbitrary direction results in the same benefits. We test this directly by starting at points in training right before a slingshot and sampling jumps in random directions with the same magnitude as the slingshot. We find preliminary[5] evidence that, contrary to this supposition, the original slingshot tends to reduce test loss more than a random jump does (Fig. 11, 12). 
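Concretely, the comparison jump can be constructed as below (a sketch; `theta_before` and `theta_after` are hypothetical names for the flattened parameter vectors bracketing a detected slingshot):

```python
import torch

def random_jump_like(theta_before, theta_after):
    # Displace theta_before in a uniformly random direction by the same
    # Euclidean distance that the slingshot moved the parameters.
    displacement_norm = (theta_after - theta_before).norm()
    direction = torch.randn_like(theta_before)
    direction = direction / direction.norm()
    return theta_before + displacement_norm * direction
```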

Currently, we do not have a good explanation for why slingshots induce better generalization than random jumps. One possibility is that an artificially introduced random jump does not update AdamW's internal state in the same way as a slingshot does, thus attenuating the effect described in Step 5 of Sec. 2.1. It should also be noted that our method for detecting slingshots, and thus the size and starting point of the random jump, is somewhat unprincipled[6]; possibly, a more "correct" jump size and starting point would weaken or eliminate the difference between slingshots and random jumps we observe here.

Figure 11: Log ratio of test loss after vs. before an interval where 1) a slingshot occurred (blue, median=-0.438), 2) a slingshot did not occur (red, median=-0.039), and 3) we introduce a jump in a random direction (green, median=-0.369). All models are trained with cross-entropy loss.
Figure 12: Same example as in Fig. 1 with 10 randomly sampled jumps introduced for comparison with the original slingshot trajectory (black). All lines plot test cross-entropy loss.

3.3 Slingshots and LLC

The local learning coefficient (LLC) of a point in a model's parameter space is a measure of model complexity in a neighborhood around that point. That the LLC varies with the model parameters is a crucial difference from the classical statistical picture, in which the effective dimensionality is constant across all parameters. (For example, consider linear regression: the dimensionality of the model is equal to the number of weight parameters, regardless of the setting of the weights.) At least in the Bayesian setting, the LLC is closely related to generalization properties of the model. See the Distilling SLT sequence for more details.
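For concreteness, the LLC estimator we have in mind is the standard SGLD-based one from this literature (sketched here in our own notation, which is not taken from the post): for model parameters $w^*$, training set size $n$, empirical loss $L_n$, and inverse temperature $\beta$ (typically of order $1/\log n$),

$$\hat{\lambda}(w^*) \;=\; n\beta\,\Big(\mathbb{E}^{\beta}_{w \mid w^*}\big[L_n(w)\big] \;-\; L_n(w^*)\Big),$$

where the expectation is over a tempered posterior localized near $w^*$, approximated in practice by SGLD sampling.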

We estimate the LLC for models trained on cross-entropy loss[7] and find that it tends to decrease after a slingshot (Fig. 13). Intuitively, we can interpret this observation as the slingshot moving the model from a more complex memorizing solution to a simpler generalizing solution.

Figure 13: Same example as in Fig. 1. Green line is LLC. Before the slingshot, the LLC and the test loss are both high (memorizing solution); immediately afterwards, they are both lower (generalizing solution). Note that the LLC is a function of only the model parameters and the training data, not the test data.
Figure 14: Training run from same initial seed as Fig. 13, but with AdamW internal state reset before slingshots occur (Sec. 2.3). Green line is LLC. Note that LLC at end of training is higher than that in Fig. 13.

4 Discussion

The scope of this post is fairly limited: we only discuss a very simple model (a one-hidden-layer MLP) trained on a synthetic algorithmic task. The explanation of the slingshot effect given in Sec. 2.1 holds in its entirety only for models trained with hinge loss (or, possibly, for other setups with open regions in parameter space with vanishing gradient). It is unclear how well this explanation carries over to the wide range of settings investigated by Thilak et al. 2024. In particular, there are several slingshot regimes in their paper that 1) do not fit our explanation and 2) we are unable to empirically replicate.

Perhaps there are multiple possible sources of instability in late stages of training with adaptive optimizers, all of which manifest in the training loss curve as a "slingshot-like" effect. If so, the topic of this post is then only a single member in this family of instabilities.

Despite these caveats, we do find evidence in this toy setting that, during slingshots, the model tends to jump to better-generalizing solutions. The current lack of understanding of this phenomenon indicates that there is still much work to do in connecting generalization to the inductive biases of adaptive optimization methods.

Code for running the experiments in this post can be found in this notebook.

Acknowledgements: This work was completed during MATS 6.0. Many thanks to @Jesse Hoogland for mentorship and @DanielFilan for research management, as well as to @LawrenceC and @jacob_drori for feedback on drafts of this post. I am supported by a grant from the Long-Term Future Fund.

  1. ^

    Even though $O(k \log n)$ samples are information-theoretically sufficient to distinguish the $k$-sparse parity function, $n^{\Omega(k)}$ constant-noise queries are necessary to learn the function in the statistical query model, which encapsulates gradient-based methods.

  2. ^

    We can ignore the rescaling by $1-\beta_1^t$ and $1-\beta_2^t$, respectively, because late in training both $\beta_1^t$ and $\beta_2^t$ are approximately zero. We also ignore the influence of the $\epsilon$ term in the denominator of the parameter update.

  3. ^

    If $\beta_1$ and $\beta_2$ were set to more comparable values, this factor would disappear. However, the explanation in Step 5 still goes through. We still observe slingshots when setting $\beta_2 = 0.99$, in which case $\frac{1-\beta_1}{\sqrt{1-\beta_2}} = 1$. Anecdotally, slingshots in this regime appear to be smaller but more frequent (as variance decays to zero more rapidly).

  4. ^

    This effect depends on adjacent gradients being correlated: if they were independent, then both the numerator and the denominator of the parameter updates would scale as the square root of the number of steps immediately following a slingshot. (The standard deviation of the sum of $m$ independent Gaussians scales as $\sqrt{m}$.) Empirically, we find that cosine similarities between gradients of adjacent steps are on average positive throughout training, not just during slingshots. In the case of cross-entropy loss, correlated gradients can be explained by the negative loss gradient pointing outwards once the training data has been fit perfectly. (Scaling up the model decreases loss.)

  5. ^

    The difference in log test loss ratio between slingshot intervals and non-slingshot intervals is statistically significant (p<0.05, Mood's median test), but the difference between slingshots and random jumps is not (p>0.05).

  6. ^

    We made the arbitrary choice of defining a slingshot interval as 1) starting when the training loss, after previously decreasing below a low threshold, exceeds a higher threshold, and 2) ending when the training loss returns to below the lower threshold. We did not check whether our results are robust to other methods of detecting spikes in training loss.

  7. ^

    It would not make sense to estimate the LLC for models trained on hinge loss, because the model parameters quickly reach flat basins of zero training loss. If the training loss landscape is flat in a neighborhood of a point then the LLC at that point is zero.
