How much AI inference can we do?
post by Benjamin_Todd · 2024-05-14T15:10:58.539Z · 7 comments
This is a link post for https://benjamintodd.substack.com/p/how-much-ai-inference-can-we-do
Contents
What about more advanced GPUs?
What about future models and algorithms?
What about input tokens?
Summing up
Suppose you have a bunch of GPUs. How many LLM forward passes can you do with them?[1]
This is relevant to figuring out how profitable AI will be in the short term, how powerful AI systems might become in the near future, how large the compute overhang will be, and other strategic questions.
Here’s my attempt to understand this topic as a non-specialist. I’ve had it checked over by some technical advisors, but I don’t claim any special expertise. I wrote it because I haven’t been able to find an accessible explainer elsewhere. I’d appreciate corrections.
The most obvious approach – the one I often see people in the community taking – is to look up how many FLOP per second your GPU can process, then how many FLOP it takes to run a forward pass, and then divide the two.
For example, Nvidia’s A100 GPU is listed at 312 teraflop per second (3e14) on its spec sheet (FP16 tensor), and a forward pass of GPT-4 requires about 5.6e11 FLOP.[2] Dividing the two implies a single GPU can do about 560 forward passes per second.
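The naive calculation can be written out as a minimal Python sketch, using only the spec-sheet and Semianalysis figures quoted above (the constant and function names are illustrative, not from any library):

```python
# Naive FLOP-based upper bound: spec-sheet FLOP/s divided by FLOP per forward pass.
A100_FP16_FLOP_PER_S = 312e12   # A100 spec sheet, FP16 tensor
GPT4_FLOP_PER_PASS = 5.6e11     # Semianalysis estimate for one GPT-4 forward pass


def flop_upper_bound(flop_per_s: float, flop_per_pass: float) -> float:
    """Forward passes per second if every spec-sheet FLOP could be used."""
    return flop_per_s / flop_per_pass


upper = flop_upper_bound(A100_FP16_FLOP_PER_S, GPT4_FLOP_PER_PASS)
print(f"upper bound: {upper:.0f} forward passes/s")
```

This prints 557, which rounds to the ~560 above.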
But this turns out to be much too high.
Even if it were possible to achieve spec sheet FLOP in a real life application (it’s not), this wouldn’t be the relevant figure because, in practice, inference is limited more by memory than by FLOP:
Each forward pass requires all the parameters to also pass through the GPU’s memory. If 280 billion parameters are activated, and each parameter requires 16-bits = 2 bytes to encode it, then 560 gigabytes must pass through memory.[3]
But the A100’s memory bandwidth is about 2,000 gigabytes per second, only enough for about 4 forward passes per second.
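The memory-based estimate is the same kind of division, with bytes in place of FLOP. A minimal sketch, assuming 16-bit weights, the rounded 2,000 GB/s bandwidth figure, and that each forward pass streams every active weight once with no sharing between passes:

```python
# Memory-bandwidth lower bound: every active weight is read once per forward pass,
# and nothing is shared between passes (effectively a batch of one).
A100_BANDWIDTH_BYTES_PER_S = 2000e9  # as rounded in the text
ACTIVE_PARAMS = 280e9                # GPT-4 parameters activated per forward pass
BYTES_PER_PARAM = 2                  # 16-bit weights


def memory_lower_bound(bandwidth: float, active_params: float,
                       bytes_per_param: float) -> float:
    """Forward passes per second when each pass must stream all active weights."""
    bytes_per_pass = active_params * bytes_per_param
    return bandwidth / bytes_per_pass


lower = memory_lower_bound(A100_BANDWIDTH_BYTES_PER_S, ACTIVE_PARAMS, BYTES_PER_PARAM)
print(f"lower bound: {lower:.1f} forward passes/s")
```

This prints 3.6, which is the "about 4 forward passes per second" above.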
However, 4 forward passes per second is also not right.
In practice, GPUs are parallelised, so multiple forward passes are processed in batches and many other optimisations are applied, allowing real world efficiency to be much higher than the memory bandwidth of an individual GPU would suggest.
So, the first FLOP-based method is an upper bound, the second memory-based method a lower bound, and the real world lies somewhere in between.
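One way to see why batching moves real-world throughput between the two bounds is a simple roofline-style model. This is a standard simplification, not the author's or Semianalysis's method: a decoding step over a batch reads the weights once but does one forward pass of arithmetic per sequence, and the step takes as long as whichever of compute or memory traffic is slower. It ignores the KV cache, communication between GPUs and kernel inefficiencies, and (like the bounds above) accounts per GPU as if the whole model were served from one chip.

```python
def batched_throughput(batch_size: int,
                       flop_per_s: float,
                       flop_per_pass: float,
                       bandwidth: float,
                       weight_bytes: float) -> float:
    """Idealized forward passes per second for one decoding step over a batch.

    Assumes compute and memory transfers overlap perfectly, so the step time
    is the larger of the two.
    """
    compute_time = batch_size * flop_per_pass / flop_per_s  # arithmetic grows with batch
    memory_time = weight_bytes / bandwidth                   # weights read once per step
    return batch_size / max(compute_time, memory_time)


# A100 and GPT-4 figures from the text
FLOP_PER_S = 312e12
FLOP_PER_PASS = 5.6e11
BANDWIDTH = 2000e9
WEIGHT_BYTES = 280e9 * 2

for batch in (1, 16, 64, 256):
    rate = batched_throughput(batch, FLOP_PER_S, FLOP_PER_PASS, BANDWIDTH, WEIGHT_BYTES)
    print(f"batch {batch:>3}: {rate:6.1f} forward passes/s")
```

At a batch of 1 this reproduces the memory-based lower bound (~3.6 per second). From a batch of roughly 156 upward it hits the FLOP-based upper bound (~557). Real systems land below this idealized curve.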
Figuring out where real world efficiency lies is tricky. Not only is it an area of active research, but it also depends on many factors, such as acceptable latency, context length, batch size, size of the model, etc.
The best estimate I’ve seen so far is from Semianalysis. In their article on GPT-4 architecture, they estimate that a cluster of 128 A100s can output 1 million tokens for $4.9 of compute (assuming fairly high utilisation and a context length of 8k).
If A100s cost $1/hour on the cloud in 2023, running the cluster costs $128 per hour. This means the cluster must produce ($128 / $4.9) × 1 million ≈ 26 million forward passes per hour.
That’s about 60 forward passes per chip per second – about 10% of the theoretical max, but 15 times better than the lower bound.
(Interestingly, it’s significantly worse than the ~33% utilisation of FLOP that can be achieved in training, which means that even if a certain number of FLOP were used for training, the same GPUs couldn’t produce that many FLOP if applied to inference.)
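The cost-based estimate can be reproduced in a few lines. This sketch just restates the arithmetic above; the $1/hour price is the assumption stated in the text, and the two bounds are recomputed from the A100 and GPT-4 figures used earlier:

```python
# Reproduce the cost-based throughput estimate from the Semianalysis figures.
GPUS = 128
DOLLARS_PER_GPU_HOUR = 1.0         # assumed 2023 cloud price for an A100
DOLLARS_PER_MILLION_TOKENS = 4.9   # Semianalysis estimate for GPT-4 on A100s

cluster_dollars_per_hour = GPUS * DOLLARS_PER_GPU_HOUR
tokens_per_hour = cluster_dollars_per_hour / DOLLARS_PER_MILLION_TOKENS * 1e6
tokens_per_chip_per_s = tokens_per_hour / GPUS / 3600

upper_bound = 312e12 / 5.6e11         # FLOP-based bound (A100 FP16, GPT-4)
lower_bound = 2000e9 / (280e9 * 2)    # memory-based bound (16-bit weights)

print(f"{tokens_per_hour / 1e6:.1f}M tokens/hour")
print(f"{tokens_per_chip_per_s:.1f} tokens per chip per second")
print(f"{tokens_per_chip_per_s / upper_bound:.0%} of the FLOP upper bound")
print(f"{tokens_per_chip_per_s / lower_bound:.1f}x the memory lower bound")
```

It prints about 56.7 tokens per chip per second and roughly 10% of the upper bound. Without rounding, the ratio to the lower bound comes out near 16x rather than 15x. Because the price is charged per GPU, the per-chip figure does not depend on the cluster size.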
What about more advanced GPUs?
In the same article, Semianalysis provides a similar figure for the H100. They also have a more recent post analysing the inference performance of Nvidia’s Blackwell chips, which gives a rough sense of how the B200 compares to the H100.[4] For the H200, I looked at some comparisons of inference performance with the H100 and guessed a 1.7x improvement.
From this, I can make the following table:
One interesting point is that inference throughput has increased about 20x in 5 years, compared to only an 8x increase in FLOP.
This seems to be at least partly because Nvidia has crammed a lot more memory bandwidth into the newest chips, and memory is still usually a bigger constraint on inference than FLOP/s. More memory also allows for larger batch sizes. And more broadly, it seems like the recent generation of chips has been more optimised for inference relative to training.
However, the underlying speed of memory bandwidth has been increasing more slowly than FLOP for years (the so-called ‘memory wall’), so while memory has been able to catch up recently, my best guess is that this is a temporary effect.
My figures are also based on FP16 and 16-bit encoding of model weights, but it seems like inference is switching to FP8 and 8-bit encoding, which could also roughly double how much inference can be done per GPU.[5]
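Since the memory-based estimate divides bandwidth by bytes of weights, halving the bits per weight doubles it. A minimal sketch, reusing the A100 bandwidth figure purely for illustration (the A100 itself has no FP8 tensor cores, so treat this as a hypothetical FP8-capable chip with the same bandwidth):

```python
def memory_bound_passes(bandwidth: float, active_params: float,
                        bits_per_param: int) -> float:
    """Forward passes per second when every active weight is streamed once per pass."""
    bytes_per_pass = active_params * bits_per_param / 8
    return bandwidth / bytes_per_pass


BANDWIDTH = 2000e9     # bytes/s, the A100 figure used earlier (illustrative)
ACTIVE_PARAMS = 280e9  # GPT-4 active parameters

fp16 = memory_bound_passes(BANDWIDTH, ACTIVE_PARAMS, 16)
fp8 = memory_bound_passes(BANDWIDTH, ACTIVE_PARAMS, 8)
print(f"FP16: {fp16:.1f}/s, FP8: {fp8:.1f}/s, ratio {fp8 / fp16:.1f}x")
```

On recent Nvidia chips such as the H100, the spec-sheet FP8 tensor throughput is also about double the FP16 figure, so the FLOP-based upper bound scales in a similar way.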
What about future models and algorithms?
As a first pass, FLOP and memory requirements per forward pass scale linearly with the number of model parameters activated per forward pass. So if a model is 10 times bigger, all else equal we’ll only be able to perform about one tenth as many forward passes.
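A common rule of thumb (an assumption here, though it matches footnote 2’s ~280B active parameters and ~560 GFLOP per forward pass) is about 2 FLOP per active parameter per token, one multiply and one add. Under that rule, the linear scaling looks like this:

```python
def flop_per_forward_pass(active_params: float) -> float:
    """Rough forward-pass cost: ~2 FLOP (a multiply and an add) per active parameter.

    Ignores attention over the context, which adds cost for long sequences.
    """
    return 2 * active_params


FLOP_PER_S = 312e12  # A100 FP16 spec sheet, from the text

base = FLOP_PER_S / flop_per_forward_pass(280e9)        # GPT-4-sized active model
ten_times = FLOP_PER_S / flop_per_forward_pass(2.8e12)  # hypothetical 10x larger model
print(f"{flop_per_forward_pass(280e9):.2e} FLOP per pass")
print(f"{base:.0f} vs {ten_times:.1f} passes/s (ratio {ten_times / base:.2f})")
```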
In reality, there are many complications.
For example, if users really value long contexts and low latency, that makes it harder to batch, which pushes throughput towards the lower bound. If, however, most inference is short-context and can tolerate higher latency, throughput could be much closer to the theoretical max (but probably not above ~50%).
We should also probably expect chips, parallelisation and algorithms to become better optimised for inference over time, allowing throughput to get closer to the max, or to achieve more performance with less compute. These effects can be big.
One example is that new parallelisation techniques could be discovered, allowing inference to get closer to the upper bound of available FLOP.
A more interesting example is that by using a mixture of experts structure, GPT-4 only needs to activate roughly a sixth of its parameters on a typical forward pass (about 280 billion of ~1.8 trillion), so it only requires about a sixth of the compute suggested by its total number of parameters. Future models might be able to activate an even smaller fraction of parameters to achieve similar performance.
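The mixture-of-experts arithmetic can be checked against the Semianalysis figures quoted in footnote 2. A short sketch (the figures are Semianalysis’s estimates, not confirmed by OpenAI):

```python
# GPT-4 mixture-of-experts figures as reported by Semianalysis (footnote 2)
N_EXPERTS = 16
PARAMS_PER_EXPERT = 111e9   # MLP parameters per expert
EXPERTS_PER_TOKEN = 2       # experts routed to per forward pass
SHARED_PARAMS = 55e9        # shared attention parameters

total_params = N_EXPERTS * PARAMS_PER_EXPERT + SHARED_PARAMS
active_params = EXPERTS_PER_TOKEN * PARAMS_PER_EXPERT + SHARED_PARAMS

print(f"total:  {total_params / 1e12:.2f}T parameters")
print(f"active: {active_params / 1e9:.0f}B parameters per token")
print(f"active fraction: {active_params / total_params:.0%}")
```

This gives about 1.83T total and 277B active parameters, an active fraction of about 15%, close to the ratio of ~560 to ~3,700 GFLOP per forward pass in the same footnote.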
Models can also be trained on more data with fewer parameters, which makes them cheaper to run.
As a concrete example, by using 10 times as many experts, a lot of data and some other architecture improvements, it seems like DeepSeek has been able to achieve performance approaching GPT-4 while only activating about a tenth as many parameters.
As a final aside, Dylan Patel of Semianalysis claims that the computing requirements to run a forward pass will increase more slowly than linearly with model size. I’m not sure exactly why this is, but it could be because larger models open up more possibilities for optimisation.
What about input tokens?
Everything above has been just about the compute needed to produce one output token. But in reality, the compute required also depends on the number of tokens that are input into the model before producing the output.
I’ve heard conflicting accounts of the relationship between input tokens and compute, but a technical advisor told me that, for FLOP, simply adding input and output tokens together works as a rough rule of thumb.
For example, if you input 1,000 tokens and get 50 output tokens, then you need about 1,050 times the FLOP required for one forward pass.
That also lines up with this article by a16z (and would be consistent with the fees for using LLMs being linear in input tokens, though there are other reasons for this).
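A minimal sketch of this rule of thumb, covering FLOP only (not memory traffic), with the GPT-4 figure from earlier; the function name is illustrative:

```python
def request_flop(input_tokens: int, output_tokens: int, flop_per_pass: float) -> float:
    """Rule-of-thumb FLOP for one request: one forward pass's worth per token, in or out."""
    return (input_tokens + output_tokens) * flop_per_pass


GPT4_FLOP_PER_PASS = 5.6e11  # Semianalysis estimate from the text
flop = request_flop(1_000, 50, GPT4_FLOP_PER_PASS)
print(f"{flop / GPT4_FLOP_PER_PASS:.0f} forward passes' worth, {flop:.2e} FLOP")
```

For the example request this is 1,050 forward passes’ worth, or about 5.88e14 FLOP.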
So, we can roughly say the throughput numbers above hold for the number of input or output tokens in most cases.
Though my understanding is that this could break down if the number of input tokens (the context) is very large, in which case the memory requirements can increase faster than linearly, pushing performance per token closer to the lower bound.
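A rough sketch of why long contexts hurt, under loudly stated assumptions: during generation each sequence’s key/value (KV) cache is read once per output token (a point raised in the comments below), so per-step memory traffic is the weights plus every sequence’s cache. The model shape here (80 layers, 8 KV heads, head dimension 128, 70B parameters at 16 bits) is a hypothetical configuration loosely resembling public 70B models, not GPT-4, and the sketch ignores compute and memory-capacity limits.

```python
def kv_bytes_per_token(n_layers: int, n_kv_heads: int, head_dim: int,
                       bytes_per_value: int = 2) -> int:
    """Bytes of KV cache per token: one key and one value vector per layer and KV head."""
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_value


def decode_tokens_per_s(batch: int, context: int, bandwidth: float,
                        weight_bytes: float, kv_per_token: int) -> float:
    """Memory-bound output tokens/s: each decoding step reads the weights once
    and the full KV cache of every sequence in the batch."""
    bytes_per_step = weight_bytes + batch * context * kv_per_token
    return batch * bandwidth / bytes_per_step


# Hypothetical configuration (assumptions, not GPT-4's real architecture)
KV_PER_TOKEN = kv_bytes_per_token(n_layers=80, n_kv_heads=8, head_dim=128)
WEIGHT_BYTES = 70e9 * 2   # 70B parameters at 16 bits
BANDWIDTH = 2000e9        # bytes/s, the A100 figure used in the text

for context in (1_000, 8_000, 32_000):
    rate = decode_tokens_per_s(32, context, BANDWIDTH, WEIGHT_BYTES, KV_PER_TOKEN)
    print(f"context {context:>6}: {rate:6.0f} tokens/s at batch 32")
```

At a fixed batch size, throughput falls as the context grows because the KV cache reads start to dominate the weight reads. In practice the cache also uses up GPU memory, which limits how large a batch fits.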
Summing up
We can look at the FLOP per second of future chips to get a rough upper bound on future ability to do inference, and memory bandwidth to get a lower bound, and think about where real life performance might fall within that range. Then we can compare that to the size of future models.
Historically, our ability to use maximum available FLOP in inference has been worse than in training.
However, inference throughput has been getting closer to the upper bound recently, as chips have been more adapted to inference (especially through having more memory bandwidth), and parallelisation & batching techniques have improved. This trend could continue (up to a max of maybe around 50% of the upper bound) if we discover more parallelisation techniques. Or, it could start to reverse due to the memory wall.
Algorithmic improvements have also allowed models to achieve the same performance while using much less compute, and that trend seems likely to continue.
Switching from FP16 to FP8 could also roughly double how much inference we can do with a given cluster of GPUs.
This was originally posted on benjamintodd.substack.com.
- ^
(A forward pass is an activation of all the parameters in the model, which produces one token of output, roughly equivalent to one word.)
- ^
From Semianalysis, “GPT-4 architecture”, July 2023:
GPT-4 is more than 10x the size of GPT-3. We believe it has a total of ~1.8 trillion parameters across 120 layers versus the ~175 billion parameters of GPT-3…
Furthermore, OpenAI utilizes 16 experts within their model, each is about ~111B parameters for MLP. 2 of these experts are routed to per forward pass.
While the literature talks a lot about advanced routing algorithms for choosing which experts to route each token to, OpenAI’s is allegedly quite simple, for the current GPT-4 model.
Furthermore, there are roughly ~55B shared parameters for attention.
Each forward pass inference (generation of 1 token) only utilizes ~280B parameters and ~560 GFLOPs. This contrasts with the ~1.8 trillion parameters and ~3,700 GFLOP that would be required per forward pass of a purely dense model.
- ^
One complication is that GPT-4 uses a mixture of experts structure, and so isn’t a dense model. The total number of parameters is ~6-7x larger than the number that get activated in a forward pass. However, these extra parameters also need to pass through the memory in some situations, which would further decrease the lower bound. I’m ignoring this complication and treating GPT-4 as a dense model with 280 billion parameters.
- ^
Unfortunately there’s not a completely direct comparison. The new post doesn’t cover the A100, and the H100 analysis is for 32k input tokens rather than 8k. However, they do say “As such, in a large model like GPT-4 B200 brings ~4x to ~7x performance gains for GPT-4 inference, when quantization is set fairly, depending on the point in the interactivity curve chosen.”
- ^
FP16 means that 16 bits are used to encode each number in the computation. The more bits used, the more precisely the number can be encoded, reducing rounding errors. However, it turns out that ML algorithms can often perform about as well with less precise encodings. If fewer bits are used for each number, each number takes less memory and bandwidth to store and move, and the same hardware can do more calculations.
7 comments
comment by ryan_greenblatt · 2024-05-15T00:00:33.809Z
I think this article fails to list the key consideration around generation: output tokens require using a KV cache which requires substantial memory bandwidth and takes up a considerable amount of memory.
From my understanding the basic situation is:
- For input (not output) tokens, you can get pretty close to the maximum flop utilization for realistic workloads. To make this efficient (and avoid memory bandwidth issues), you'll need to batch up a bunch of tokens at once. This can be done by batching multiple input sequences, or even a single long sequence can be ok. So, memory bandwidth isn't currently a binding constraint for input tokens.
- (You might also note that input tokens have a pretty similar work profile to model training as the forward pass and backward pass are pretty structurally similar.)
- However, for generating output tokens a key bottleneck is that you have to use the entire KV (key value) cache for each output token in order to implement attention. In practice, this means that on long sequences, the memory bandwidth for attention (due to needing to touch the whole KV cache) can be a limiting constraint. A further issue is that KV cache memory consumption forces us to use a smaller batch size. More details:
- It will still be key to batch up tokens, but now we're just doing computation on a single token, which means we'll need to batch up many more sequences: the optimal number of sequences to batch for generating output tokens will be very different than the optimal number of sequences to batch for input tokens (where we can run the transformer on the whole sequence at once).
- A further difficulty is that because we need a higher batch size, we need a larger amount of KV cache data. I think it's common to use an otherwise suboptimally small batch size for generation due to constraints on VRAM (at least on consumer applications (e.g. llama-70b inference on 8xH100), I assume this also comes up for bigger models). We could store the KV cache on CPU, but then we might get bottlenecked on memory bandwidth to the CPU.
- Note that in some sense the operations for each output token are the same as for each input token. So, why are the memory bandwidth requirements worse? The key thing is that we potentially get much worse cache locality on output tokens due to only computing one token, but needing to read the KV for many tokens (while on input we do many to many).
- However, it is possible to substantially reduce the KV sizes using various optimizations like sparse attention and mamba. This can substantially improve inference speeds due to reducing memory bandwidth in inference and also allowing for higher batch sizes. See e.g. the mamba paper, where allowing for higher batch sizes results in substantially higher speeds.
One additional note: I recently set up an inference setup for llama-3-70b on 8xH100. I can get about 100,000 tok/s on inputs, which is pretty close to full utilization (1e15 flop/s * 8 gpus / 7e10 flop per forward pass). However, I get dramatically worse performance on generation, perhaps 3,200 tok/s. I'm doing generation with long prompts and llama-3-70b has no sparse attention or other feature for reducing KV cache (beyond multi-query attention, which is standard these days), so KV cache bites pretty hard. My setup probably isn't very close to optimal, especially on output tok/s; I'm just using basic out-of-the-box stuff (vllm).
↑ comment by Benjamin_Todd · 2024-05-15T08:26:13.551Z
Thanks that's interesting!
Can I double check, do you think this affects the bottom lines?
The bottom line is supposed to be that FLOP/s vs. FLOP per forward pass can be used as an upper bound, and memory bandwidth vs. model size can be used as a lower bound, and real life efficiency falls somewhere in the middle depending on many factors (inc. length of KV cache), which I don't try to get into, but is plausibly around 15% of the upper bound for GPT-4 on H100s.
Are you saying that the lower bound for output tokens should maybe be even lower, because the KV cache can be larger than the model weights?
↑ comment by ryan_greenblatt · 2024-05-15T18:40:23.481Z
The lower bound of "memory bandwidth vs. model size" is effectively equivalent to assuming that the batch size is a single token. I think this isn't at all close to realistic operating conditions and thus won't be a very tight lower bound. (Or reflect the most important bottlenecks.)
I think that the KV cache for a single sequence won't be larger than the model weights for realistic workloads, so the lower bound should still be a valid lower bound. (Though not a tight one.)
I think the bottom line number you provide for "rough estimate of actual throughput" ends up being pretty reasonable for output tokens and considerably too low for input tokens. (I think input tokens probably get more like 50% or 75% flop utilization rather than 15%. See also the difference in price for Anthropic models.)
That said, it doesn't seem like a good mechanism for estimating throughput will be to aggregate the lower and upper bounds you have as the lower bound doesn't have much correspondence with actual bottlenecks. (For instance, this lower bound would miss that mamba would get much higher throughput.)
I also think that insofar as you care about factors of 3-5 on inference efficiency, you need to do different analysis for input tokens and output tokens.
(I also think that input tokens get pretty close to the pure FLOP estimate. So, another estimation approach you can use if you don't care about factors of 5 is to just take the pure flop estimate and then halve it to account for other slowdowns. I think this estimate gets input tokens basically right and is wrong by a factor of 3-5 for output tokens.)
It seems like your actual mechanism for making this estimate for the utilization on output tokens was to take the number from semi-analysis and extrapolate it to other GPUs. (At least the number matches this?) This does seem like a reasonable approach, but it isn't particularly tethered to your lower bound.
↑ comment by Benjamin_Todd · 2024-05-18T18:31:53.408Z
I agree the lower bound for output isn't very tight. I'd be very interested to hear other simple rules of thumb you could use to provide a tighter one.
I'll add a note to the section on input tokens that since they don't require KV cache, it's possible to get much closer to the upper bound.
comment by kotrfa · 2024-07-25T08:19:11.980Z
Even though some commenters mentioned some issues with the article, I really want to express my appreciation for the attempt and for being upfront with the estimates. It's very relevant to the thing I am now trying to figure out. As I have almost no intuitions about this except about some raw FLOPS, it pointed to important flaws my analysis would have had. There are not many public sources that explain this [that are not a book and don't require me to read one-to-many of them to understand it].
comment by Seth Herd · 2024-05-14T16:17:18.598Z
Algorithmic improvements are, on average, roughly similar in speed to hardware improvements. In the area of deep nets I believe they're on average larger, although I haven't looked deeply enough to say this with confidence or have a ref handy. So how much you can do is a function of how far in the future you're talking about, on two fronts. The opportunities for algorithmic improvements go far beyond the parallelization and mixture of experts methods you mention.
↑ comment by Benjamin_Todd · 2024-05-15T08:34:21.711Z
The opportunities for algorithmic improvements go far beyond the parallelization and mixture of experts methods you mention.
I agree. I'd be very interested in anyone's forecasts for how they might evolve.
I've been working with (very roughly) another ~10x or so improvement in "inference efficiency" by 2030 (or how to measure this and make sure it's independent from other factors).
By this I mean that if we were able to train a model with 10^26 FLOP this year, achieving a fixed level of learning efficiency, it would require 10X FLOP to generate useful output, while by 2030 it would only require X FLOP to get the same output.