
Conversation

JohannesGaessler
Collaborator

@JohannesGaessler commented Jul 28, 2025

This PR adds support for skipping fully masked-out KV slices in all CUDA FlashAttention kernels. The main benefit is higher batched server throughput with LLAMA_SET_ROWS, but it also makes prompt processing with batch sizes >= 1024 a bit faster. The implementation works by checking, with an auxiliary kernel (overhead <1% end-to-end), how many values at the end of the KV cache in dimension 1 are fully masked out with -inf. This provides an upper limit for the iteration of the FA kernels over the KV cache, allowing them to skip the fully masked-out section with minimal changes. This of course assumes that there are no significant regions that should be skipped other than at the end - I think with the SWA rework there aren't any, but please correct me if I'm wrong. It would in principle be possible to pre-compute the maximum extent of the KV cache in CPU code, but I don't think it would be worth the additional complexity.
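
Roughly, the auxiliary step looks like this (a minimal sketch, not the actual kernel from this PR; the kernel name, mask layout, and launch configuration are illustrative):

```cuda
#include <cuda_fp16.h>
#include <cstdint>

// Sketch: for a mask of shape [n_q, n_kv], stored row-major in half precision with
// -inf marking masked-out entries, compute the number of leading KV columns that
// contain at least one unmasked value. Everything past this bound is fully masked
// out and can be skipped. Launched with a single thread block, e.g.
// count_usable_kv<<<1, 256>>>(mask, n_q, n_kv, kv_max).
__global__ void count_usable_kv(const half * mask, const int n_q, const int n_kv, int * kv_max) {
    __shared__ int s_max;
    if (threadIdx.x == 0) {
        s_max = 0;
    }
    __syncthreads();

    int local_max = 0; // exclusive upper bound of the KV columns that still matter
    for (int j = threadIdx.x; j < n_kv; j += blockDim.x) {
        for (int i = 0; i < n_q; ++i) {
            if (!__hisinf(mask[(int64_t) i*n_kv + j])) {
                local_max = j + 1; // column j has at least one unmasked value
                break;
            }
        }
    }
    atomicMax(&s_max, local_max);
    __syncthreads();

    if (threadIdx.x == 0) {
        *kv_max = s_max; // the FA kernels then only iterate over KV indices [0, *kv_max)
    }
}
```

In practice the bound would presumably also be rounded up to the KV tile granularity of the respective FA kernel.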

I tried variants which also skipped loading the mask if it is all zero, or which reweighted the per-SM workloads to match the sequence lengths, but I was not able to get meaningful performance gains.

To measure the performance difference vs. master I used server-bench.py with commands like this:

export np=2000
LLAMA_ARG_N_PARALLEL=32 LLAMA_ARG_CTX_SIZE=700000 LLAMA_ARG_MODEL=/opt/models/llama_3.2_instruct-1b-q4_k_m.gguf python3 server-bench.py --path_server /home/johannesg/Projects/llama.cpp/build/bin/llama-server --n_prompt 1000 --prompt_source rng-0-19000 --n_predict_min $(($np/2)) --n_predict $(($np*3/2))

(These are not the exact commands; there are some changes that I will still need to upstream.) For the hardware I used 6x RTX 4090s with their frequencies limited to 1350 MHz. The GPUs ran 6 parallel benchmarks for each combination of setup and mean prediction length; I then did a linear regression with a quadratic model to isolate the part of the runtime that scales with the mean prediction length. With 32 parallel slots the end-to-end speedup for the range I looked at is up to 1.15; asymptotically it's something like 1.25. The speedup comes from the mma kernels scaling better with large workloads. Skipping masked-out KV slices and using the mma kernel instead of the vector kernel seem to be of about equal importance.

Server performance plot (Figure_1)
pp8192 t/s numbers
| GPU | Model | Microbatch size | Test | t/s master | t/s ddede2c | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| P40 | llama 8B Q4_0 | 1 | pp8192 | 46.42 | 46.33 | 1.00 |
| P40 | llama 8B Q4_0 | 2 | pp8192 | 81.40 | 81.32 | 1.00 |
| P40 | llama 8B Q4_0 | 4 | pp8192 | 103.14 | 103.07 | 1.00 |
| P40 | llama 8B Q4_0 | 8 | pp8192 | 114.72 | 113.11 | 0.99 |
| P40 | llama 8B Q4_0 | 16 | pp8192 | 326.02 | 325.81 | 1.00 |
| P40 | llama 8B Q4_0 | 32 | pp8192 | 458.74 | 458.53 | 1.00 |
| P40 | llama 8B Q4_0 | 64 | pp8192 | 519.93 | 517.83 | 1.00 |
| P40 | llama 8B Q4_0 | 128 | pp8192 | 563.78 | 561.29 | 1.00 |
| P40 | llama 8B Q4_0 | 256 | pp8192 | 595.52 | 594.80 | 1.00 |
| P40 | llama 8B Q4_0 | 512 | pp8192 | 606.82 | 604.50 | 1.00 |
| P40 | llama 8B Q4_0 | 1024 | pp8192 | 601.37 | 617.57 | 1.03 |
| P40 | llama 8B Q4_0 | 2048 | pp8192 | 579.87 | 621.40 | 1.07 |
| P40 | llama 8B Q4_0 | 4096 | pp8192 | 566.37 | 622.50 | 1.10 |
| P40 | llama 8B Q4_0 | 8192 | pp8192 | 579.64 | 617.56 | 1.07 |
| RTX 4090 | llama 8B Q4_0 | 1 | pp8192 | 167.71 | 168.13 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 2 | pp8192 | 303.63 | 302.71 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 4 | pp8192 | 593.33 | 590.97 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 8 | pp8192 | 1006.47 | 1004.24 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 16 | pp8192 | 1692.53 | 1687.89 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 32 | pp8192 | 3093.43 | 3107.36 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 64 | pp8192 | 5417.20 | 5416.87 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 128 | pp8192 | 7988.06 | 7988.70 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 256 | pp8192 | 10329.99 | 10342.08 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 512 | pp8192 | 11552.18 | 11570.88 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 1024 | pp8192 | 11772.45 | 11830.59 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 2048 | pp8192 | 11416.92 | 11534.25 | 1.01 |
| RTX 4090 | llama 8B Q4_0 | 4096 | pp8192 | 11400.16 | 11559.12 | 1.01 |
| RTX 4090 | llama 8B Q4_0 | 8192 | pp8192 | 11402.69 | 11592.19 | 1.02 |
| RX 6800 | llama 8B Q4_0 | 1 | pp8192 | 46.40 | 45.48 | 0.98 |
| RX 6800 | llama 8B Q4_0 | 2 | pp8192 | 67.43 | 67.47 | 1.00 |
| RX 6800 | llama 8B Q4_0 | 4 | pp8192 | 73.61 | 73.48 | 1.00 |
| RX 6800 | llama 8B Q4_0 | 8 | pp8192 | 75.18 | 75.20 | 1.00 |
| RX 6800 | llama 8B Q4_0 | 16 | pp8192 | 93.02 | 93.03 | 1.00 |
| RX 6800 | llama 8B Q4_0 | 32 | pp8192 | 115.83 | 116.52 | 1.01 |
| RX 6800 | llama 8B Q4_0 | 64 | pp8192 | 126.99 | 127.33 | 1.00 |
| RX 6800 | llama 8B Q4_0 | 128 | pp8192 | 153.67 | 153.98 | 1.00 |
| RX 6800 | llama 8B Q4_0 | 256 | pp8192 | 160.67 | 160.94 | 1.00 |
| RX 6800 | llama 8B Q4_0 | 512 | pp8192 | 157.56 | 157.98 | 1.00 |
| RX 6800 | llama 8B Q4_0 | 1024 | pp8192 | 146.14 | 157.55 | 1.08 |
| RX 6800 | llama 8B Q4_0 | 2048 | pp8192 | 132.32 | 155.34 | 1.17 |
| RX 6800 | llama 8B Q4_0 | 4096 | pp8192 | 132.32 | 155.52 | 1.18 |
| RX 6800 | llama 8B Q4_0 | 8192 | pp8192 | 132.41 | 155.55 | 1.17 |

@github-actions bot added the Nvidia GPU and ggml labels Jul 28, 2025
@ggerganov
Member

> This of course assumes that there are no significant regions that should be skipped other than at the end - I think with the SWA rework there aren't any, but please correct me if I'm wrong.

Yes, this should be correct. It's not specifically the SWA rework, but rather the split KV cache, combined with the new "set rows" approach for filling the batch into the KV cache. This combination should technically result in KV buffers that always have -inf only at their tails.

For SWA layers this is not always the case, because the KV buffer is more like a ring buffer. But it does not matter much in this case anyway, because SWA KV buffers are typically very small.

@am17an
Collaborator

am17an commented Jul 29, 2025

I have a dumb question: does this make supporting masks whose masked-out regions are not at the tail impossible, or just slower?

@JohannesGaessler
Collaborator Author

Support for masks in general is not affected. The skipping logic is being narrowed to only skip the tails (this could in principle be extended).

@JohannesGaessler
Collaborator Author

vllm lets you benchmark the throughput of OAI-compatible servers via vllm bench serve. Using Qwen 2.5 Instruct 4b, with vllm I get:

============ Serving Benchmark Result ============
Successful requests:                     989       
Benchmark duration (s):                  130.76    
Total input tokens:                      1010389   
Total generated tokens:                  122680    
Request throughput (req/s):              7.56      
Output token throughput (tok/s):         938.17    
Total Token throughput (tok/s):          8664.94   
---------------Time to First Token----------------
Mean TTFT (ms):                          61240.27  
Median TTFT (ms):                        60653.48  
P99 TTFT (ms):                           126005.17 
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          214.44    
Median TPOT (ms):                        241.81    
P99 TPOT (ms):                           243.99    
---------------Inter-token Latency----------------
Mean ITL (ms):                           214.41    
Median ITL (ms):                         242.34    
P99 ITL (ms):                            246.97    
==================================================

With a llama.cpp server running the FP16 model I get:

============ Serving Benchmark Result ============
Successful requests:                     989       
Benchmark duration (s):                  209.46    
Total input tokens:                      1010382   
Total generated tokens:                  120650    
Request throughput (req/s):              4.72      
Output token throughput (tok/s):         576.00    
Total Token throughput (tok/s):          5399.73   
---------------Time to First Token----------------
Mean TTFT (ms):                          102603.02 
Median TTFT (ms):                        104404.92 
P99 TTFT (ms):                           203475.61 
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          54.59     
Median TPOT (ms):                        52.79     
P99 TPOT (ms):                           105.28    
---------------Inter-token Latency----------------
Mean ITL (ms):                           53.06     
Median ITL (ms):                         32.49     
P99 ITL (ms):                            206.20    
==================================================

So after this PR llama.cpp is at ~62% of vllm's throughput. One anecdote from my work in physics that I think is relevant: a colleague of mine needed to host a language model themselves; the first choice was a ggml-based solution since that was the easiest, and a switch to e.g. vllm was never made because the additional throughput was not necessary (this was prior to the recent efforts to improve llama.cpp throughput).

@JohannesGaessler
Collaborator Author

JohannesGaessler commented Jul 29, 2025

I noticed that the TTFT for llama.cpp is higher but that the time between output tokens is lower. Presumably, since the requests sent by vllm follow a Poisson distribution, some of the requests are stalling because all slots are in use. With 128 slots I get:

============ Serving Benchmark Result ============
Successful requests:                     989       
Benchmark duration (s):                  174.37    
Total input tokens:                      1010382   
Total generated tokens:                  120907    
Request throughput (req/s):              5.67      
Output token throughput (tok/s):         693.41    
Total Token throughput (tok/s):          6488.02   
---------------Time to First Token----------------
Mean TTFT (ms):                          81330.91  
Median TTFT (ms):                        79789.26  
P99 TTFT (ms):                           166280.86 
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          174.02    
Median TPOT (ms):                        176.90    
P99 TPOT (ms):                           351.10    
---------------Inter-token Latency----------------
Mean ITL (ms):                           170.33    
Median ITL (ms):                         170.78    
P99 ITL (ms):                            267.44    
==================================================

This puts llama.cpp at ~75% of vllm in terms of request throughput, TTFT, and generation throughput. To my knowledge vllm does the sampling with CUDA while llama.cpp uses a single CPU thread. Crucially, the time needed for sampling is proportional to the number of parallel requests, so I'm wondering whether it's a significant performance bottleneck. Does the server have a built-in way to report the total time spent on sampling?

Edit: By default vllm is frontloading all of the requests.

@am17an
Collaborator

am17an commented Jul 29, 2025

> Does the server have a built-in way to report the total time spent on sampling?

In case it may be helpful, @ggerganov showed me in #14644 how to measure this precisely; there are two components to the sampling: the logits-to-CPU data transfer time and the actual sampling time. In that case the sampling time did dominate, at around ~25% of the total time (the rest being the decode step). But do note that it was sampling on the entire context.

@JohannesGaessler
Collaborator Author

JohannesGaessler commented Jul 29, 2025

The /metrics endpoint does not seem to report the time spent sampling, but based on how long sampling takes with llama-cli I think ~10 s of my benchmark were spent sampling.

@ggerganov
Member

We don't report the sampling time at the moment.

> The /metrics endpoint does not seem to report the time spent sampling, but based on how long sampling takes with llama-cli I think ~10 s of my benchmark were spent sampling.

It could be more, because for 128 slots the logits data transfer could start to become significant and this is not measured by llama-cli; llama-cli just measures the CPU time for sampling, i.e. after the logits have been transferred to RAM.

@JohannesGaessler
Collaborator Author

I think it's unavoidable that we transfer some amount of data between RAM and VRAM for each graph evaluation, so we'll just have to eat the latency (but only once for all parallel sequences). In terms of bandwidth, the GPUs are connected via PCIe 4.0 x16; for the given model with a vocab size of 151936, a total of 607744 bytes need to be copied per token. That requires ~20 µs per token (vs. ~90 µs for the sampling itself).
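
For reference, the arithmetic behind those numbers (a back-of-the-envelope sketch; the ~32 GB/s effective PCIe 4.0 x16 rate is an assumed round number):

```cpp
#include <cstdio>

int main() {
    const long long n_vocab        = 151936;     // vocabulary size of the model above
    const long long bytes_per_tok  = n_vocab*4;  // FP32 logits: 607744 bytes per token
    const double    pcie_bandwidth = 32.0e9;     // assumed effective PCIe 4.0 x16 rate in bytes/s

    printf("bytes per token: %lld\n", bytes_per_tok);
    printf("transfer time:   %.1f us per token\n", 1e6*bytes_per_tok/pcie_bandwidth); // ~19 us
    return 0;
}
```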

@JohannesGaessler
Collaborator Author

JohannesGaessler commented Jul 29, 2025

Question: I ran the following command to get the sampler overhead:

./cli --model models/opt/${model_name}-${quantization}.gguf --n-gpu-layers 99 -f navy_seals_copypasta.txt -n 1000 --ignore-eos -fa

And I got this output:

llama_perf_sampler_print:    sampling time =      87,49 ms /  1354 runs   (    0,06 ms per token, 15476,59 tokens per second)
llama_perf_context_print:        load time =    1499,07 ms
llama_perf_context_print: prompt eval time =      96,63 ms /   354 tokens (    0,27 ms per token,  3663,34 tokens per second)
llama_perf_context_print:        eval time =    9650,69 ms /   999 runs   (    9,66 ms per token,   103,52 tokens per second)
llama_perf_context_print:       total time =   13842,70 ms /  1353 tokens
llama_perf_context_print:    graphs reused =        995

I simply took the total sampling time as the time needed to sample 1000 tokens and scaled that up, but in the printout the sampling time "per token" is normalized to the number of generated tokens plus the number of prompt tokens. Is this a bug?

@ggerganov
Member

> Is this a bug?

Yes. We feed the prompt tokens to the sampler chain through llama_sampler_accept and these tokens are counted as sampled. We should have a separate counter that keeps track of just the sampled tokens.

@ggerganov
Member

Btw, I don't think we utilize CUDA graphs for bs>1 yet, correct?

@JohannesGaessler
Collaborator Author

I don't think we do, but for large batch sizes the kernel launch overhead takes up a comparatively smaller percentage of the total runtime so it's not as impactful.

@ggerganov
Member

I'm thinking that increasing the slots to 128 is not optimal. What would happen is that when a new request comes in, the slot selection logic will find an unused slot and put the request there. This way, all 128 slots would eventually become "occupied" (remember, the server does not have logic to determine that a slot is no longer used).

Therefore, consider how things look somewhere in the middle of the benchmark: all slots have been used at least once by now and have some KV cache data in them. Let's say request 300 arrives and gets assigned to the least-recently used (LRU) slot due to the lack of "free" slots.

This means that even if at this point there are, say, just 20 active requests, they could be distributed in any pattern across the 128 slots (i.e. KV cache streams). This is not optimal for stream slicing - we are going to be splitting the batch multiple times to find contiguous sets of streams.

Ideally, I would look for a benchmark that has a maximum of 8 requests at a given time. During runtime it's allowed to have fewer than 8 active requests, but it is guaranteed that the script would not send a 9th parallel request when there are already 8 active. With 8 slots, the stream slicing issue is much less significant. Indeed, 8 slots is much fewer than 32, but realistically, having 32 concurrent users all the time is a use case that not many deployments have (excluding Anthropic, OpenAI, etc.).

@ggerganov
Member

Alternatively, we can actually fix the stream slicing issue completely by adding a simple "sequence defrag" logic to the server: we would keep a map of slot id -> seq id, and whenever we have active slots that result in sequence id gaps (for example, current active sequences are 1, 10, 12, 40, 43) we would remap to fill the gaps and use llama_memory_seq_cp() to move the KV cache data into new sequence ids that are contiguous (e.g. 0, 1, 2, 3, 4).
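
A rough sketch of what that could look like on the server side (illustrative only: the slot structure is made up, and the exact llama_memory_seq_cp / llama_memory_seq_rm signatures are assumed):

```cpp
#include <algorithm>
#include <vector>

#include "llama.h" // llama_memory_t, llama_seq_id; the seq_cp/seq_rm signatures below are assumed

// Stand-in for the server's per-slot bookkeeping.
struct slot_info {
    llama_seq_id seq_id;
    bool         active;
};

// Remap the active slots to contiguous sequence ids 0, 1, 2, ... so that the
// batch does not have to be split over gaps between the KV cache streams.
static void defrag_sequences(llama_memory_t mem, std::vector<slot_info> & slots) {
    // process the active slots in ascending seq_id order so that a target id is
    // never still occupied by a not-yet-moved active sequence
    std::vector<slot_info *> active;
    for (slot_info & s : slots) {
        if (s.active) {
            active.push_back(&s);
        }
    }
    std::sort(active.begin(), active.end(),
        [](const slot_info * a, const slot_info * b) { return a->seq_id < b->seq_id; });

    for (size_t i = 0; i < active.size(); i++) {
        const llama_seq_id dst = (llama_seq_id) i;
        const llama_seq_id src = active[i]->seq_id;
        if (src == dst) {
            continue;
        }
        llama_memory_seq_rm(mem, dst, -1, -1);      // clear any stale data under the new id
        llama_memory_seq_cp(mem, src, dst, -1, -1); // copy the KV cache data to the new id
        llama_memory_seq_rm(mem, src, -1, -1);      // drop the old id
        active[i]->seq_id = dst;
    }
}
```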

(sorry for the spam in the PR, we can move the discussion to somewhere else if needed)

@JohannesGaessler
Collaborator Author

> Ideally, I would look for a benchmark that has a maximum of 8 requests at a given time. During runtime it's allowed to have fewer than 8 active requests, but it is guaranteed that the script would not send a 9th parallel request when there are already 8 active. With 8 slots, the stream slicing issue is much less significant. Indeed, 8 slots is much fewer than 32, but realistically, having 32 concurrent users all the time is a use case that not many deployments have (excluding Anthropic, OpenAI, etc.).

I think this depends on what you are trying to benchmark. In a realistic scenario where you assume that the requests are made independently of one another, the way vllm does it with a Poisson distribution is correct. The way I did it in server-bench.py, with a constant number of concurrent requests, can also be correct for a scenario where you're just benchmarking the throughput for a single user with a large number of prompts, or if you assume that there is another server in between that makes sure a constant number of requests is assigned to the llama.cpp server. The main reason I chose the benchmarking parameters the way I did was to establish a baseline for how close llama.cpp is to vllm in terms of performance (and also because the vllm tool worked out of the box). 64 slots gives almost the same performance as 128 slots, but in terms of how similar the performance metrics are, 128 was simply the better point of comparison. A request waiting for a free slot is not necessarily bad, depending on whether the server should prioritize pre-existing or new requests.

Generally speaking, I think that 32 is already relatively high in terms of the number of slots (though I have yet to systematically investigate this).

> (sorry for the spam in the PR, we can move the discussion to somewhere else if needed)

Personally I don't really mind either way.

@JohannesGaessler
Collaborator Author

JohannesGaessler commented Jul 29, 2025

> Alternatively, we can actually fix the stream slicing issue completely by adding a simple "sequence defrag" logic to the server

I should mention that this will not necessarily result in better performance. The CUDA FA code, for example, already pads the Q tensor to 4, 8, 16, 32, or 64 tokens anyway (this number is reduced by the GQA ratio: 4 for LLaMA, 8 for Qwen 3 4b). MMQ has a granularity of 8 for <= 48 tokens and a granularity of 16 for > 48 tokens.
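
For illustration, the rounding this implies (the helper below is made up; the granularities are the MMQ values mentioned above):

```cpp
#include <cstdio>

// Hypothetical helper: round a token count up to the next multiple of the granularity.
static int round_up(const int n, const int granularity) {
    return ((n + granularity - 1)/granularity) * granularity;
}

int main() {
    const int n_tokens_tests[] = {3, 20, 48, 49, 100};
    for (const int n_tokens : n_tokens_tests) {
        const int granularity = n_tokens <= 48 ? 8 : 16; // MMQ granularity as described above
        printf("n_tokens = %3d -> padded to %3d\n", n_tokens, round_up(n_tokens, granularity));
    }
    return 0;
}
```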

How much impact sequence defragmentation would have would also depend on whether we are talking about dense or MoE models, since I would expect MoE models to have a higher ceiling in terms of the batch size needed to reach optimal performance.

@JohannesGaessler
Collaborator Author

> The CUDA FA code, for example, already pads the Q tensor to 4, 8, 16, 32, or 64 tokens anyway (this number is reduced by the GQA ratio: 4 for LLaMA, 8 for Qwen 3 4b).

Actually, this doesn't matter. The FA code gets multiple sequences with a single token each; scaling up the number of sequences doesn't improve the arithmetic intensity the way increasing the number of tokens per sequence does. What increases is the amount of work that each streaming multiprocessor can do in a single kernel launch. If the mask for an inactive sequence is all -inf it can be skipped, though the workload across streaming multiprocessors could become unbalanced. But I think that can be fixed by simply permuting the indices with which the FA kernel distributes the work.
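
To illustrate that last point (purely a sketch, not code from this PR): if consecutive block indices are mapped round-robin across sequences instead of sequence-major, the blocks that exit early because their sequence is fully masked out end up spread roughly evenly over the streaming multiprocessors.

```cuda
// Illustrative work distribution for a kernel that processes n_seq sequences,
// each split into n_tiles tiles. Sequence-major order would hand all tiles of a
// fully masked-out sequence to one contiguous range of blocks; interleaving the
// sequence index spreads the early-exiting blocks across the SMs.
__device__ void map_work_item(const int item, const int n_seq, int * seq, int * tile) {
    *seq  = item % n_seq; // round-robin over the sequences first ...
    *tile = item / n_seq; // ... then advance the tile index
}

__global__ void fattn_sketch(const int n_seq, const int n_tiles /* , tensor pointers, ... */) {
    int seq, tile;
    map_work_item(blockIdx.x, n_seq, &seq, &tile);

    // If the mask for (seq, tile) is fully -inf: return early.
    // Otherwise: run the usual FlashAttention tile loop for this (seq, tile) pair.
}
```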

@ggerganov
Member

> If the mask for an inactive sequence is all -inf it can be skipped

To do that I think I have to add padding tokens for the inactive sequences, and by doing so the graph-level implementation in llama.cpp would become very complicated.

@JohannesGaessler
Collaborator Author

JohannesGaessler commented Jul 29, 2025

Ultimately I think that as long as the context is deep enough it doesn't matter anyway (the default of the vllm benchmark is 128 generated tokens at a depth of 1024). In the edge case of a very long context the runtime will be dominated by the attention, so it doesn't matter whether or not you can use the weight matrices more efficiently. And for the attention kernel itself you will have enough work per sequence to compensate for a low number of sequences.

@JohannesGaessler merged commit 92b8810 into ggml-org:master Jul 30, 2025
47 checks passed
@JohannesGaessler
Collaborator Author

JohannesGaessler commented Jul 30, 2025

The vllm bench utility lets you specify a maximum number of concurrent requests; in combination with all requests being frontloaded, this lets you run each server at a constant load:

| Prompt length | Generation length | Parallel requests/slots | Runtime llama.cpp PR [s] | Runtime vllm [s] |
| --- | --- | --- | --- | --- |
| 1024 | 128 | 1 | 1340.83 | 1394.91 |
| 1024 | 128 | 2 | 815.55 | 885.74 |
| 1024 | 128 | 3 | 616.54 | 484.95 |
| 1024 | 128 | 4 | 526.53 | 485.92 |
| 1024 | 128 | 8 | 349.66 | 288.40 |
| 1024 | 128 | 16 | 254.05 | 186.55 |
| 1024 | 128 | 32 | 208.52 | 131.92 |
| 1024 | 128 | 64 | 184.03 | 114.80 |
| 1024 | 128 | 128 | 176.16 | 115.19 |

At the default settings llama.cpp seems to be faster for <= 2 parallel requests. I'll try to isolate the impacts of prompt processing, the weights, and the attention by varying the prompt length and generation length. It would probably make sense to start a new discussion at that point.
