CUDA: skip masked KV slices for all FA kernels #14924
Conversation
Yes, this should be correct. It's not specifically the SWA rework, but rather the split KV cache combined with the new "set rows" approach for filling the batch into the KV cache. This combination should technically result in KV buffers that always have the masked-out slices only at the end. For SWA layers this is not always the case, because the KV buffer works more like a ring buffer. But it does not matter much in this case anyway, because SWA KV buffers are typically very small.
I have a dumb question: does this make supporting masks where the masked-out regions are not at the tails impossible, or just slower?
Support for masks in general is not affected; the skipping logic is limited to the tails (this could in principle be extended).
vllm lets you benchmark the throughput of OAI-compatible servers via
With a llama.cpp server running the FP16 model I get:
So after this PR llama.cpp is at ~62% of vllm's throughput. One anecdote from my work in physics that I think is relevant: a colleague of mine needed to host a language model themselves; the first choice was a solution using ggml since that was the easiest, and a switch to e.g. vllm was never done afterwards because the additional throughput was not necessary (this was prior to the recent efforts to improve llama.cpp throughput performance).
I noticed that the TTFT for llama.cpp is higher but that the time between output tokens is lower. So presumably since
Which puts llama.cpp request throughput, TTFT, and generation throughput at ~75% of vllm's. vllm is to my knowledge doing the sampling with CUDA while llama.cpp uses a single CPU thread, and crucially the time needed for sampling is proportional to the number of parallel requests, so I'm wondering whether it's a significant performance bottleneck. Does the server have a built-in way to report the total time spent on sampling? Edit: By default vllm is frontloading all of the requests.
In case it may be helpful: @ggerganov showed me in #14644 how to measure this precisely, as there are two components to the sampling: the logits-to-CPU data transfer time and the actual sampling time. In that case the sampling time did dominate and was around 25% of the total time (the rest being the decode step). But do note that it's sampling on the entire context
We don't report the sampling time at the moment.
It could be more, because for 128 slots the logits data transfer could start to become significant, and this is not measured by the
I think it's unavoidable that we transfer some amount of data between RAM and VRAM for each graph evaluation, so we'll just have to eat the latency (but only once for all parallel sequences). In terms of bandwidth, the GPUs are connected via PCIe 4 x16; for the given model with a vocab size of 151936, a total of 607744 bytes need to be copied. This then requires ~20 µs per token (vs. ~90 µs for the sampling itself).
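For reference, the arithmetic behind those numbers (assuming FP32 logits and an effective PCIe 4.0 x16 bandwidth of ~32 GB/s, which is my assumption rather than a measured value):

$$
151936 \cdot 4\ \text{bytes} = 607744\ \text{bytes}, \qquad \frac{607744\ \text{bytes}}{32\ \text{GB/s}} \approx 19\ \mu\text{s per token}.
$$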
Question: I ran the following command to get the sampler overhead: `./cli --model models/opt/${model_name}-${quantization}.gguf --n-gpu-layers 99 -f navy_seals_copypasta.txt -n 1000 --ignore-eos -fa` and I got this output:
I simply took the total sampling time as the time needed to sample 1000 tokens and scaled that up, but in the print the sampling time "per token" is normalized to the number of generated tokens + the number of prompt tokens. Is this a bug?
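In other words (my formalization of the comment above), the question is whether the print should report

$$
\frac{t_\text{sample,total}}{n_\text{generated}} \qquad\text{instead of}\qquad \frac{t_\text{sample,total}}{n_\text{generated} + n_\text{prompt}}.
$$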
Yes. We feed the prompt tokens to the sampler chain through the
Btw, I don't think we utilize CUDA graphs for bs>1 yet, correct?
I don't think we do, but for large batch sizes the kernel launch overhead takes up a comparatively smaller percentage of the total runtime, so it's not as impactful.
I'm thinking that increasing the slots to 128 is not optimal. What would happen is that when a new request comes in, the slot selection logic will find an unused slot and put the request there. This way, all 128 slots would eventually become "occupied" (remember, the server does not have logic to determine that a slot is no longer used). Therefore, think about how things look somewhere in the middle of the benchmark: all slots have been used at least once by now and have some KV cache data in them. Let's say request 300 arrives and it gets assigned to the least-recently used (LRU) slot due to the lack of "free" slots. This means that even if at this point there are, let's say, just 20 active requests, they could be distributed in any pattern across the 128 slots (i.e. KV cache streams). This is not optimal for stream slicing - we are going to be splitting the batch multiple times to find contiguous sets of streams.

Ideally, I would look for a benchmark that has a maximum of 8 requests at a given time. During runtime it's allowed to have fewer than 8 active requests, but it is guaranteed that the script would not send a 9th parallel request while there are already 8 active. With 8 slots, the stream slicing issue is much less significant. Indeed, 8 slots is much fewer than 32, but realistically, having 32 concurrent users all the time is a use case that not that many users have (excluding Anthropic, OpenAI, etc.).
Alternatively, we can actually fix the stream slicing issue completely by adding a simple "sequence defrag" logic to the server: we would keep a map of

(sorry for the spam in the PR, we can move the discussion somewhere else if needed)
I think this depends on what you are trying to benchmark. In a realistic scenario where you assume that the requests are made independently from one another, the way vllm is doing it with a Poisson distribution is correct. The way I did it in

Generally speaking, I think that 32 is already relatively high in terms of the number of slots (though I have yet to systematically investigate this).
Personally I don't really mind either way.
I should mention that this will not necessarily result in better performance. With the CUDA FA code I'm for example padding the Q tensor to 4, 8, 16, 32, or 64 tokens anyway (this number is reduced by the GQA ratio: 4 for LLaMA, 8 for Qwen 3 4b). MMQ has a granularity of 8 for <= 48 tokens and a granularity of 16 for > 48 tokens. How much impact sequence defragmentation would have would also depend on whether we are talking about dense or MoE models, since I would expect MoE to have a higher ceiling in terms of batch size until optimal performance is reached.
Actually, this doesn't matter. The FA code gets multiple sequences with a single token each; if you scale up the number of sequences that doesn't improve the arithmetic intensity like it does when you increase the number of tokens per sequence. What increases is the amount of work that each streaming multiprocessor can do in a single kernel launch. If the mask for an inactive sequence is all -inf
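A back-of-the-envelope way to see the arithmetic-intensity point (my own sketch, not from the PR): with head size $d$, one sequence with $n_q$ query tokens over $n_{kv}$ KV positions does on the order of $n_q \cdot n_{kv} \cdot d$ FLOPs while loading on the order of $n_{kv} \cdot d$ K/V values, whereas $n_s$ decode sequences with one token each scale both terms by $n_s$:

$$
\frac{\text{FLOPs}}{\text{K/V loads}} \sim \frac{n_q\, n_{kv}\, d}{n_{kv}\, d} = n_q
\qquad\text{vs.}\qquad
\frac{n_s\, n_{kv}\, d}{n_s\, n_{kv}\, d} = 1 .
$$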
To do that I think I would have to add padding tokens for the inactive sequences, and by doing so the graph-level implementation in
Ultimately I think that as long as the context is deep enough it doesn't matter anyway (the default of the vllm benchmark is 128 tokens at a depth of 1024). In the edge case of a very long context the runtime will be dominated by the attention, so it doesn't matter whether or not you can use the weight matrices more efficiently. And for the attention kernel itself you will have enough work per sequence to compensate for a low number of sequences.
The
At the default settings llama.cpp seems to be faster for <= 2 parallel requests. I'll try to isolate the impacts of prompt processing, the weights, and the attention by varying the prompt length and generation length. It would probably make sense to start a new discussion at that point.
This PR adds support for skipping fully masked-out KV slices for all CUDA FlashAttention kernels. The main benefit is higher batched server throughput with `LLAMA_SET_ROWS`, but it also makes prompt processing with batch sizes >= 1024 a bit faster. The implementation works by checking with an auxiliary kernel (overhead <1% end-to-end) how many values at the end of the KV cache in dimension 1 are fully masked out with `-inf`. This then provides an upper limit for the iteration of the FA kernels over the KV cache, allowing them to skip the fully masked-out section with minimal changes. This of course assumes that there are no significant regions that should be skipped other than at the end - I think with the SWA rework there aren't any, but please correct me if I'm wrong. It would in principle be possible to pre-compute the max. extents of the KV cache in CPU code, but I don't think it would be worth the additional complexity.

I tried variants which also skipped loading the mask if it's all zero, or which reweighted the workloads per SM to match the sequence lengths, but I was not able to get meaningful performance gains.
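To illustrate the skip logic, here is a hedged sketch (not the code from this PR; the real implementation works on ggml's mask layout and a coarser slice granularity, and the kernel name and layout below are made up for illustration):

```cuda
#include <cuda_fp16.h>
#include <math.h>

// Sketch: compute an upper bound on how many KV positions the FA kernels need
// to iterate over, assuming a row-major f16 mask of shape [n_q][n_kv] where
// masked-out entries are -inf. Everything at or beyond *kv_max can be skipped.
__global__ void count_unmasked_kv(const half * mask, const int n_q, const int n_kv, int * kv_max) {
    int thread_max = 0; // highest KV index (+1) this thread found that is still needed

    // Each thread checks a strided subset of KV columns.
    for (int j = threadIdx.x; j < n_kv; j += blockDim.x) {
        bool fully_masked = true;
        for (int i = 0; i < n_q; ++i) {
            if (__half2float(mask[(size_t) i*n_kv + j]) != -INFINITY) {
                fully_masked = false;
                break;
            }
        }
        if (!fully_masked) {
            thread_max = j + 1; // this KV position still contributes for some query
        }
    }

    // Combine the per-thread results into a single upper bound.
    atomicMax(kv_max, thread_max);
}

// Host side (kv_max must be zero-initialized, e.g. with cudaMemset):
//   count_unmasked_kv<<<1, 256, 0, stream>>>(mask, n_q, n_kv, kv_max);
// The FA kernels then only loop over KV indices [0, *kv_max).
```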
To measure the performance difference vs. master I used `server-bench.py` with commands like this:

(not the exact commands, there are some changes that I will still need to upstream)

For the hardware I used 6x RTX 4090s with their frequencies limited to 1350 MHz. The GPUs ran 6 parallel benchmarks for each combination of setup and mean prediction length; I then did a linear regression with a quadratic model to isolate the part of the runtime that scales with the mean prediction length. With 32 parallel slots the end-to-end speedup for the range I looked at is up to 1.15; asymptotically it's something like 1.25. The speedup comes from the mma kernels scaling better with large workloads. Skipping masked-out KV slices and using the mma kernel instead of the vector kernel seem to be of about equal importance.
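My reading of the regression setup, with $n$ the mean prediction length (a paraphrase of the description, not the exact fit used):

$$
t_\text{end-to-end}(n) \approx a + b\,n + c\,n^2 ,
$$

where the reported speedups compare the $n$-dependent part $b\,n + c\,n^2$ between master and this PR rather than the constant offset $a$.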
Server performance plot
pp8192 t/s numbers