How do scaling laws work for fine-tuning?
post by Daniel Kokotajlo (daniel-kokotajlo) · 2021-04-04T12:18:34.559Z · LW · GW · No comments
This is a question post.
The scaling laws, at least as interpreted in Ajeya's framework [LW(p) · GW(p)] (an interpretation that seems to be basically endorsed by tons of people I respect on this matter), say roughly the following: if you increase parameter count by an order of magnitude, you also need to increase training steps/data points by about an order of magnitude, or else you are wasting your compute and could get the same performance with a smaller parameter count. For example, a 10^14-parameter model (basically the size of the human brain) would need about 10^13 training steps/data points.
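To make the bookkeeping concrete, here is a toy sketch of that relationship (my own illustration, not a formula taken from Ajeya's report; it just assumes the roughly proportional scaling described above, anchored to the 10^14-parameter / 10^13-data-point example):

```python
# Toy sketch: data requirements scaling roughly in proportion to parameter count,
# anchored to the 1e14-parameter / 1e13-data-point example above.

def data_needed(params: float, anchor_params: float = 1e14,
                anchor_data: float = 1e13) -> float:
    """Data points needed so that a model with `params` parameters isn't
    'wasting' compute, under the proportional-scaling reading above."""
    return anchor_data * (params / anchor_params)

for n in (1e12, 1e13, 1e14):
    print(f"N = {n:.0e}  ->  D ~ {data_needed(n):.0e}")
```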
Now we have papers like this one claiming that pre-trained transformers can be fine-tuned to do well at completely different tasks (incl. different modalities!) by modifying only 0.1% of the parameters.
Does this mean that this fine-tuning process can be thought of as training a NN that is 3 OOMs smaller, and thus needs 3 OOMs fewer training steps according to the scaling laws? I'm guessing the answer is no, but I don't know why, so I'm asking.
(If the answer is yes, how does that not contradict the scaling laws for transfer described here and used in this calculation by Rohin [LW(p) · GW(p)]?)
Answers
answer by Rohin Shah (rohinmshah)

> Does this mean that this fine-tuning process can be thought of as training a NN that is 3 OOMs smaller, and thus needs 3 OOMs fewer training steps according to the scaling laws?
My guess is that the answer is mostly yes (maybe not the exact numbers predicted by existing scaling laws, but similar ballpark).
> how does that not contradict the scaling laws for transfer described here and used in this calculation by Rohin [LW(p) · GW(p)]?
I think this is mostly irrelevant to timelines / previous scaling laws for transfer:
- You still have to pretrain the Transformer, which will take the usual amount of compute (my calculation that you linked takes this into account).
- The models trained in the new paper are not particularly strong. They are probably equivalent in performance to models that are multiple orders of magnitude smaller trained from scratch. (I think when comparing against training from scratch, the authors did use smaller models because that was more stable, though with a quick search I couldn't find anything confirming that right now.) So if you think of the "default" as "train an X-parameter model from scratch", then to get equivalent performance you'd probably want to do something like "pretrain a 100X-parameter model, then finetune 0.1% of its weights". (Numbers completely made up.)
- I expect there are a bunch of differences in how exactly models are trained. For example, the scaling law papers work almost exclusively with compute-optimal training, whereas this paper probably works with models trained to convergence.
You probably could come to a unified view that incorporates both this new paper and previous scaling law papers, but I expect you'd need to spend a bunch of time getting into the minutiae of the details across the two methods. (Probably high tens to low hundreds of hours.)
↑ comment by Daniel Kokotajlo (daniel-kokotajlo) · 2021-04-04T19:48:46.922Z · LW(p) · GW(p)
Thanks! Your answer no. 2 is especially convincing to me; I didn't realize the authors used smaller models as the comparison, which seems unfair! I would like to see how these 0.1%-tuned transformers compare to similarly-sized transformers trained from scratch.
Replies from: rohinmshah
↑ comment by Rohin Shah (rohinmshah) · 2021-04-04T19:57:16.049Z · LW(p) · GW(p)
I don't think similarly-sized transformers would do much better and might do worse. Section 3.4 shows that large models trained from scratch massively overfit to the data. I vaguely recall the authors saying that similarly-sized transformers tended to be harder to train as well.
answer by Charlie Steiner

I think it's plausible that the data dependence will act like it's 3 OOMs smaller. Compute dependence will be different, though, right? Even if you're just finetuning part of the model, you have to run the whole thing to do evaluation. In a sense this actually seems like the worst of both worlds (but you get the benefit of pretraining).
Edit: Actually, I'm confused about why you say a smaller model needs that factor fewer steps. I thought the slope on that one was actually quite gentle; it's just that smaller models are cheaper per step. Or am I getting it wrong?
↑ comment by Daniel Kokotajlo (daniel-kokotajlo) · 2021-04-05T10:33:40.045Z · LW(p) · GW(p)
I think compute cost equals data × parameters, so even if the parameter count is the same, if the data is 3 OOMs smaller then the compute cost will be 3 OOMs smaller.
I'm not sure I understand your edit question. I'm referring to the scaling laws as discussed and interpreted by Ajeya. Perhaps part of what's going on is that at the model sizes we've explored so far, bigger models only need a little bit more data, because bigger models are more data-efficient. But very soon it is prophesied that this will stop and we will transition to a slower scaling law according to which we need to increase data by almost as much as we increase parameter count. So that's the relevant one I'm thinking about when thinking about TAI/AGI/etc.
Replies from: Charlie Steiner
↑ comment by Charlie Steiner · 2021-04-05T18:58:31.283Z · LW(p) · GW(p)
I'm not sure how your reply relates to my guess, so I'm a little worried.
If you're intending the compute comment to be in opposition to my first paragraph, then no - when finetuning a subset of the parameters, compute is not simply proportional to the size of the subset you're finetuning, because you still have to do all the matrix multiplications of the original model, both for inference and for gradient propagation. I think the point of the paper finetuning only a subset was to make a scientific point, not to save compute.
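A minimal PyTorch sketch of that point (my own illustration, not code from the paper; I'm using "freeze everything except the layer norms" as a stand-in for the small subset the paper actually tunes):

```python
# Sketch: freeze all parameters except the layer norms, then count what is
# actually trainable. The forward and backward passes still go through every
# frozen matrix multiplication; only the optimizer updates (and its state) shrink.
import torch
from torch import nn

model = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=512, nhead=8),
    num_layers=12,
)

for name, param in model.named_parameters():
    # Keep only LayerNorm parameters trainable (the paper also tunes
    # input/output layers; this is just the rough idea).
    param.requires_grad = "norm" in name

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable: {trainable:,} / {total:,} ({100 * trainable / total:.3f}%)")

# Gradients still have to flow through every frozen layer to reach the
# trainable layer norms, so a training step costs roughly as much as a full
# fine-tuning step; only the optimizer updates and state are smaller.
```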
My edit question was just because you said something about expecting the # of steps to be 3 OOMs smaller for a 3 OOMs smaller model. But iirc it's really more like the compute will be smaller, while the # of steps won't change much (they're just cheaper).
Do you have a reference for this picture of "need lots more data to get performance improvements"? I've also heard some things about a transition, but as a transition from compute-limited to data-limited, which means "need lots more compute to get performance improvements."
Replies from: daniel-kokotajlo
↑ comment by Daniel Kokotajlo (daniel-kokotajlo) · 2021-04-05T20:21:28.655Z · LW(p) · GW(p)
I totally agree that you still have to do all the matrix multiplications of the original model etc. etc. I'm saying that you'll need to do them fewer times, because you'll be training on less data.
Each step costs, say, 6*N FLOP, where N is parameter count. And then you do D steps, where D is how many data points you train on. So the total FLOP cost is 6*N*D. When you fine-tune, you still spend 6*N for each data point, but you only need to train on 0.001*D data points, at least according to the orthodox interpretation of the scaling laws around here.
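In code, that arithmetic looks something like this (a minimal sketch; the values of N and D are just illustrative, and the 0.001 factor is the made-up 0.1% number from above):

```python
# Rough FLOP accounting for pretraining vs. fine-tuning on 0.1% as much data.
N = 1e14                      # parameter count (illustrative)
D = 1e13                      # data points for training from scratch (illustrative)
flops_per_step = 6 * N        # ~6*N FLOP per data point, as stated above

pretrain_cost = flops_per_step * D
finetune_cost = flops_per_step * (0.001 * D)   # same cost per step, far fewer steps

print(f"from scratch: {pretrain_cost:.1e} FLOP")
print(f"fine-tune:    {finetune_cost:.1e} FLOP  (~3 OOMs cheaper)")
```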
I'd recommend reading Ajeya's report (found here) [AF · GW] for more on the scaling laws. There's also this comment thread. [AF · GW]
Replies from: Charlie Steiner
↑ comment by Charlie Steiner · 2021-04-06T02:06:06.841Z · LW(p) · GW(p)
Sure, but if you're training on less data it's because fewer parameters is worse :P
Replies from: daniel-kokotajlo
↑ comment by Daniel Kokotajlo (daniel-kokotajlo) · 2021-04-06T06:44:39.300Z · LW(p) · GW(p)
Not according to this paper! They were able to get performance comparable to full-size networks, it seems. IDK.
Replies from: Charlie Steiner
↑ comment by Charlie Steiner · 2021-04-06T14:19:11.893Z · LW(p) · GW(p)
I am frankly skeptical that this (section 3.9 in the pretrained frozen transformer paper) will hold up to Grad Student Descent on training parameters. But hey, maybe I'm wrong and there's some nice property of the pretrained weights that can only be pushed into overfitting by finetuning.