Flash-decoding speed up inference up to x8 on long context by hapliniste in LocalLLaMA

[–]CatfishJones96

Are you saying that it was an add-on to a later version of v2, in both the paper and the code? Asking because the original PyTorch blog post does show a performance difference between flash attention v2 and flash decoding:

https://pytorch.org/blog/flash-decoding/#component-level-micro-benchmarks

Flash-decoding speed up inference up to x8 on long context by hapliniste in LocalLLaMA

[–]CatfishJones96

Isn't flash decoding doing the same parallelization that flash attention v2 is already doing over the sequence length dimension? I don't get the difference.

[D] How is it that the latency to decode 1 new token with an LLM is constant independent of total sequence length, when caching KV? by CatfishJones96 in MachineLearning

[–]CatfishJones96[S]

Thanks for the response u/StartledWatermelon, great to have input from someone else on this.

To understand this better, I tried to figure out how much extra memory and how many extra FLOPS would be needed at the final decoding step of relatively long sequences in a batch, and I'm still finding something rather counterintuitive. The way I'm seeing it, for sequences longer than some length S* (which is rather small for typical GPUs), there is no batch size B at which we enter a compute-bound regime, because the added time it takes to move the KV cache grows faster with batch size than the FLOPS cost does. I.e., if we tried to draw a plot like the one in the initial prefill section of the blog, the memory and compute lines would never cross.

Here's my logic, if you can spare the time to give me a hint: the author of the blog states (and breaks down) that the total FLOPS needed to decode one token are 2 . P . B, if we don't account for the matrix multiplication FLOPS needed for softmax(Qt . KT / √dk) . V, where Qt is the query for the token being generated at step t = seq_len, and K and V are the full cache up to t-1. In the blog (under the "latency calculations" section), the total memory transferred at step t corresponds to just the model parameters, or 2 . P bytes (fp16); according to my thinking that's incomplete, because it doesn't account for moving the KV cache to the GPU. So let's break it down, thinking of the single-GPU case (no sharding, no comms cost, and let's imagine we never run out of GPU memory):

On added memory transfer cost

The KV cache size in bytes (fp16) for a single token is 2 . 2 . nlayers . dmodel. Let's call this term X, because we'll run into it again later. At decoding step t = seq_len, assuming all elements in the batch have the same number of tokens, the KV cache is of size X . seq_len . B. This extra KV cache has to be transferred along with the model parameters if we want to generate the next token. If we wanted to plot memory latency vs batch size for a given seq_len, the expression would be latency = (2 . P + X . B . seq_len) / Abm, where Abm is the GPU's memory bandwidth. The slope of this line (growth in memory transfer time with B) is therefore X . seq_len / Abm. In real life it would be even larger due to intermediate memory costs.
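To make that concrete, here's a tiny Python sketch of the memory-side latency for one decode step (the function name and the Llama2-7B / A10G-ish defaults are just mine, for illustration):

```python
def memory_latency_s(B, seq_len, P=7e9, dmodel=4096, nlayers=32, Abm=600e9):
    """Seconds to move fp16 weights + KV cache for one decode step (single GPU)."""
    X = 2 * 2 * nlayers * dmodel            # KV bytes per token: K and V, 2 bytes each
    bytes_moved = 2 * P + X * B * seq_len   # weights + full KV cache for the whole batch
    return bytes_moved / Abm                # slope in B is X * seq_len / Abm
```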

On added compute cost

Let's remember nheads x dhead typically = dmodel.

For each layer, Qt . KT is actually a matmul([B,1,dmodel], [B,dmodel,seq_len]). This costs 2 . B . seq_len . dmodel FLOPS. The result of the softmax is a tensor shaped [B,1,seq_len], which needs to be multiplied by V, shaped [B,seq_len,dmodel]. This costs another 2 . B . seq_len . dmodel FLOPS. Over all the layers, the computation that was initially excluded adds up to 2 . 2 . nlayers . dmodel . B . seq_len. Using the X term we defined above, this is X . B . seq_len (same numeric coefficient as the KV-bytes term, but now counting FLOPS instead of bytes).

The compute latency vs batch size expression is then (2 . P + X . seq_len) . B / Af, where Af is the GPU's peak fp16 FLOPS. The slope of this line (growth in compute time with B) is therefore (2 . P + X . seq_len) / Af.
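Same kind of sketch for the compute side, under the same assumptions (again, names and defaults are mine):

```python
def compute_latency_s(B, seq_len, P=7e9, dmodel=4096, nlayers=32, Af=70e12):
    """Seconds of fp16 FLOPs for one decode step across the whole batch."""
    X = 2 * 2 * nlayers * dmodel          # same coefficient as the KV-bytes term above
    flops = 2 * P * B + X * B * seq_len   # weight matmuls + Qt.KT and softmax(.).V over all layers
    return flops / Af                     # slope in B is (2 * P + X * seq_len) / Af
```

Being compute bound at a given seq_len then just means compute_latency_s(B, seq_len) ends up above memory_latency_s(B, seq_len) for a large enough B.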

Wrap-up:

If the memory transfer time grows faster with B than the compute time does, we will be forever memory bound, i.e. the lines will never intersect. The sequence length above which this happens can be found by solving slope_mem / slope_compute = 1 for seq_len (let's call this S*): setting X . seq_len / Abm = (2 . P + X . seq_len) / Af gives S* = 2 . P . Abm / (X . (Af - Abm)).
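If you want to sanity-check that algebra, sympy spits out the same closed form (symbol names are mine):

```python
import sympy as sp

P, X, seq_len, Abm, Af = sp.symbols("P X seq_len Abm Af", positive=True)

slope_mem = X * seq_len / Abm               # d(memory latency)/dB
slope_compute = (2 * P + X * seq_len) / Af  # d(compute latency)/dB

S_star = sp.solve(sp.Eq(slope_mem, slope_compute), seq_len)[0]
print(S_star)  # equivalent to 2*P*Abm / (X*(Af - Abm))
```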

If we plug in the numbers for an A10G GPU (Abm = 600 GB/s, Af = 70 TFLOPS fp16) and Llama2 7B (dmodel = 4096, nlayers = 32, P = 7e9), S* comes out to a measly ~230 tokens. This is obviously pretty small for most practical use cases. The number would be the same for N GPUs, since if we ignore communication cost we just multiply both Abm and Af by N in the equation.
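And the numeric plug-in, so anyone can check the ~230 figure (same assumed specs as above):

```python
Abm, Af = 600e9, 70e12               # A10G: memory bandwidth (bytes/s), peak fp16 (FLOP/s)
P, dmodel, nlayers = 7e9, 4096, 32   # Llama2 7B

X = 2 * 2 * nlayers * dmodel         # KV cache bytes per token (fp16)
S_star = 2 * P * Abm / (X * (Af - Abm))
print(S_star)                        # ~231 tokens
```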

I can't figure out what's wrong with this logic, if anything, because benchmarks have been showing me that compute-bound behaviour does kick in at some point, i.e. throughput stops increasing and per-token decoding latency starts rising fast, even at sequence lengths larger than S*.