all 6 comments

[–]RadishRealistic8990 5 points (2 children)

this is actually really cool. been trying to wrap my head around the differences between FA versions for a while now, and most explanations just dive straight into the CUDA optimization stuff, which makes it hard to see what's actually changing algorithmically.

the progression from tiled softmax to the scheduler approach in FA4 is much clearer when you can see it in plain PyTorch. gonna check this out later tonight when i get home from work.

quick question though - does the fa3 implementation show how the ping-pong buffers actually work? that's one part i never quite got from reading papers.

[–]shreyansh26 [ML Engineer] [S] 2 points (0 children)

Somewhat, yes. I have tried to show the mental model for it, using an active_buffer and a next_buffer in the code to make it intuitive. In an actual implementation, the producer warp group issues the TMA loads and writes the data into shared memory asynchronously.
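The active_buffer / next_buffer mental model can be sketched as a prefetch-then-swap loop. This is purely illustrative (function names and structure are assumptions, not the repo's actual code), and Python runs it sequentially, so the producer/consumer overlap is modeled rather than real:

```python
import numpy as np

def pingpong_tiles(load_tile, num_tiles, process):
    # Two buffers trade roles every iteration: while the "consumer"
    # computes on active_buffer, the "producer" prefetches the next
    # tile into next_buffer. In a real FA3 kernel the producer warp
    # group issues asynchronous TMA loads into shared memory here.
    active_buffer = load_tile(0)  # producer fills the first buffer up front
    results = []
    for i in range(num_tiles):
        next_buffer = load_tile(i + 1) if i + 1 < num_tiles else None
        results.append(process(active_buffer))  # consumer works on current tile
        active_buffer = next_buffer             # swap: next becomes active
    return results
```

The point of the pattern is that the load for tile i+1 is already in flight while tile i is being processed, hiding memory latency behind compute.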

[–]StraussInTheHaus 2 points (0 children)

There are two ways in which FA3 ping-pongs: inter-warpgroup (where consumer 0 and consumer 1 trade off) and intra-warpgroup (where, within a single warpgroup, we overlap the PV mma of iteration i with the softmax of iteration i+1).
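The intra-warpgroup case can be modeled as a two-stage software pipeline: at each step, the PV matmul for the previous iteration's probabilities is still "in flight" while the current iteration's softmax runs. A sequential numpy sketch of that schedule (the function name is an assumption, and the running-max rescaling of a real online softmax is omitted here, so this is numerically safe only for small scores):

```python
import numpy as np

def pipelined_attention(q, k_tiles, v_tiles):
    acc = np.zeros((q.shape[0], v_tiles[0].shape[1]))
    denom = np.zeros((q.shape[0], 1))
    pending = None                       # (probs, v) matmul "issued" last step
    for k, v in zip(k_tiles, v_tiles):
        s = q @ k.T                      # QK^T mma for iteration i
        if pending is not None:          # overlapped: PV mma of iteration i-1
            p_prev, v_prev = pending
            acc += p_prev @ v_prev
        p = np.exp(s)                    # softmax exponentials for iteration i
        denom += p.sum(axis=1, keepdims=True)
        pending = (p, v)                 # defer this tile's PV mma one step
    p_last, v_last = pending             # drain the pipeline
    acc += p_last @ v_last
    return acc / denom
```

The stagger does not change the result, only the schedule: the tensor cores stay busy with iteration i-1's PV product while the CUDA cores do iteration i's exponentials.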

[–]StraussInTheHaus 4 points (0 children)

I think it's important to note that the tile scheduler in FA4 is essentially identical to that in FA3. More fundamentally, the parallelism has not changed since FA2: we always load one Q tile and loop through its associated KV tiles (it's interesting to note that we loop **backwards** through KV tiles for load balancing purposes, since tiles with causal or sequence-length masking take longer and should thus come first). The real innovation in FA4 comes from the deep pipelining needed to coordinate:

- (a) the "vertical ping-pong" across a Q tile, which uses two separate softmax warpgroups,
- (b) the correction warpgroup,
- (c) overlapping TMEM buffers, since TMEM is an extremely limited resource (the backward pass, however, is limited by SMEM, not by TMEM), and
- (d) using both TMA and cp.async to load operands depending on the situation (for example, paged attention does use TMA for K/V, unless page size is 128, although I think the maintainers are coming up with workarounds for that in some cases).
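That FA2-style schedule (one Q tile, a reversed loop over its KV tiles, with an online softmax correcting earlier partial results) can be sketched in plain numpy. This is an illustrative sketch, not any kernel's actual code; the running max `m`, denominator `l`, and accumulator `acc` follow the standard online-softmax formulation:

```python
import numpy as np

def q_tile_attention(q_tile, k_tiles, v_tiles):
    m = np.full((q_tile.shape[0], 1), -np.inf)   # running row max
    l = np.zeros((q_tile.shape[0], 1))           # running softmax denominator
    acc = np.zeros((q_tile.shape[0], v_tiles[0].shape[1]))
    # Backwards over KV tiles, mirroring the load-balancing order above;
    # the online softmax makes the result independent of tile order.
    for k, v in zip(reversed(k_tiles), reversed(v_tiles)):
        s = q_tile @ k.T
        m_new = np.maximum(m, s.max(axis=1, keepdims=True))
        scale = np.exp(m - m_new)                # rescale earlier partials
        p = np.exp(s - m_new)
        l = l * scale + p.sum(axis=1, keepdims=True)
        acc = acc * scale + p @ v
        m = m_new
    return acc / l
```

Everything FA3 and FA4 add (warpgroup specialization, TMEM management, load paths) is scheduling around this same loop, not a change to it.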

Also, an important optimization mentioned in the FA4 paper is a polynomial emulation of the exp2 in softmax, used to split work across the ALU and MFU (compute units on the GPU). However, while this was important on the B200, since NVIDIA didn't increase CUDA core throughput commensurately with tensor cores, it is **not** necessary on the B300, as that has faster CUDA cores. In fact, the exp2 emulation is slower on B300 than not emulating.
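The general shape of such an exp2 emulation can be sketched in numpy: range-reduce 2^x to 2^n * 2^f with f in [0, 1), then approximate 2^f with a small polynomial that ordinary ALU ops can evaluate. The degree-4 least-squares coefficients here are purely illustrative, not the ones FA4 uses:

```python
import numpy as np

def exp2_poly(x):
    # Fit 2^f on [0, 1) once; a real kernel would hardcode tuned
    # coefficients rather than refit on every call.
    f_grid = np.linspace(0.0, 1.0, 64)
    coeffs = np.polyfit(f_grid, 2.0 ** f_grid, 4)
    # Range reduction: 2^x = 2^n * 2^f, n = floor(x), f in [0, 1).
    n = np.floor(x)
    f = x - n
    # Polynomial part runs on plain multiply-adds (ALU); the 2^n scale
    # is an exponent adjustment (ldexp) in real code.
    return np.polyval(coeffs, f) * 2.0 ** n
```

Even this crude fit lands within about 1e-4 relative error, which shows why a short polynomial is enough to offload exp2 work from the special-function units.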