Skip to content

Conversation

mseeger
Copy link
Contributor

@mseeger mseeger commented May 30, 2025

This continues from #1934 . I created a new branch, because the history of the previous one was messed up with a merge operation.

Adds abstraction for key-value caches, implements batched inference.

I am also adding two baseline KV caches, the default one from before (all KV are stored) and a last-recent one.

OK, this PR contains the following parts:

  • Small things: Start of layer hook in GPT.forward, skip_lm_head in GPT.forward. I need these for gradient computation, but also to put proper head models on top of the transformer. This is generally useful.
  • Refactoring of multi-head attention: This is needed in order to implement the KV cache abstraction in the way @t-vi suggested (in a phone call). But it also really simplifies things. It also removes a major issue: mask_cache requires lots of memory, it is now computed on demand, with particular attention to inference (where query is much smaller than key)
  • Proper KV cache abstraction, which modifies slightly how GPT.forward is called (namely, input_pos as int). This simplifies things, though. I also provide a few default implementations. DenseKVCache replicates what is currently in place.

In the library I am writing, there are a number of additional more powerful KV caches, such as H2O and quantization-aware H2O. I am also working on fine-tuning in the presence of KV caches. The abstraction I propose here, enables all of that.

If these changes are not done, I'd have to copy and change quite a bit of your code. This would be hard to maintain, and would run the risk that KV caches are implemented differently at a later point, and then things really diverge.

As I said in the comments above, I found KV caching to be super-important to make large context inference work on a moderate GPU budget, which should be of interest to your customers as well.

Edit: Since I opened this, I am working a lot on gradient computation in the presence of long context models. This is stress-testing the abstraction here quite a bit.

@mseeger
Copy link
Contributor Author

mseeger commented May 30, 2025

Started work to make sure all tests pass.

@mseeger
Copy link
Contributor Author

mseeger commented May 30, 2025

@t-vi , @Borda , just a heads-up, I continue work in this PR, from #1934

@mseeger
Copy link
Contributor Author

mseeger commented May 30, 2025

Tests fail for me that should fail in mainline as well. For example, test_against_multimodal_gemma_3 in test_models.py fails in copy_weights_gemma_3, because the skip logic there checks for prefix "vision_tower" or "language_model", but the keys really start with "model.vision_tower" or "model.language_model".

??

@mseeger
Copy link
Contributor Author

mseeger commented May 30, 2025

I'll submit a PR with a fix.

@mseeger mseeger force-pushed the kvcache4 branch 10 times, most recently from e652799 to 0d6360b Compare June 6, 2025 10:16
@Borda
Copy link
Collaborator

Borda commented Jun 10, 2025

I'll submit a PR with a fix.

Could you also link the PR here?

@Borda Borda added enhancement New feature or request waiting on author labels Jun 18, 2025
@mseeger
Copy link
Contributor Author

mseeger commented Jun 18, 2025

I need to spend some more work on this one. Sorry was busy with other things.

@mseeger mseeger force-pushed the kvcache4 branch 2 times, most recently from 474797d to b8efa8c Compare June 18, 2025 19:06
@mseeger
Copy link
Contributor Author

mseeger commented Jun 18, 2025

In my local installation, there is only one test failing which I think I still need to attend to:
FAILED tests/test_adapter.py::test_against_original_gemma_3[device0-dtype0-gemma-3-27b-it] - AssertionError: Tensor-likes are not close!

There are also fails in these tests:

tests/generate/test_adapter.py: Something with mock??
tests/generate/test_main.py: Something with mock??
tests/test_tokenizer.py: Probably only works on CI/CD??

But for them, I don't really understand what is going on.

@mseeger
Copy link
Contributor Author

mseeger commented Jun 18, 2025

OK, same thing in your CI/CD. I'd need help with these two tests, which use mocking in a way I do not understand. The one where "tensors are not close", I can deal with.

@mseeger mseeger force-pushed the kvcache4 branch 2 times, most recently from 06255ac to c8b8895 Compare June 20, 2025 10:12
@Borda
Copy link
Collaborator

Borda commented Jun 23, 2025

I'd need help with these two tests, which use mocking in a way I do not understand.

sure, which two?

@mseeger
Copy link
Contributor Author

mseeger commented Jul 2, 2025

@Borda , I had to extend generate/sequentially.py a bit in order to make some tests pass. I think this may be interesting on its own, it is making sure that layer_to_device always works.

@mseeger
Copy link
Contributor Author

mseeger commented Jul 2, 2025

OK, we are down to 7 errors. 6 of which are about EFFICIENT_ATTENTION, and one more using mocking:

FAILED tests/test_generate_speculatively.py::test_main - AssertionError: assert [call(_Fabric...culative_k=3)] == [call(, ...culative_k=3)]

The one mocking error I'd need help resolving.

For the remaining 6, you need to decide. Maybe I see this wrong, but this EFFICIENT_ATTENTION kernel is not available on the GPU instances I tried on. Maybe I am doing something wrong, but it seems to be the same in your CI/CD system. Just curious why this did not happen before?

@mseeger mseeger force-pushed the kvcache4 branch 2 times, most recently from 2ed4013 to 98dfe11 Compare July 4, 2025 08:42
@mseeger
Copy link
Contributor Author

mseeger commented Jul 4, 2025

Working on test_model.py. There is something else going on here, seems like enable_gqa=True does not work with flash attention. Need to investigate.

I know now why this passes in main, because you do not use enable_gqa=True there and instead expand the key and value matrices. That is not great, bit wasteful of memory.

@mseeger
Copy link
Contributor Author

mseeger commented Jul 5, 2025

OK, now we are down to a single test failing, something to do with mocking.

My solution here is to expand keys and values in the special case when query and key have the same length. This is what happens during training, or for inference with input_pos=0. In this case, it is more important to make use of the efficient fused kernels than to save some memory not expanding the inputs.

However, in subsequent cases, where query is much smaller than key, expanding is avoided, to save GPU memory. This turns out to be important for long context inference and fine-tuning.

@mseeger
Copy link
Contributor Author

mseeger commented Jul 7, 2025

@t-vi , this is ready for review. The last test needs to be fixed on your side, I don't understand this mocking.

Let me know whether I should split the PR into two. The first would generalize attention, the second introduce the KV cache abstraction.

Also let me know if something more fundamental is not OK.

@mseeger mseeger force-pushed the kvcache4 branch 4 times, most recently from 04b74fa to 40db977 Compare July 12, 2025 07:52
@mseeger
Copy link
Contributor Author

mseeger commented Jul 20, 2025

@Borda , @t-vi . I am about to get open source approval from my employer to release a substantial amount of code for KV caching and also fine-tuning. This will be a separate repository, but sit on top of LitGPT.

I am on holiday now from July 25 until August 18, but would take next steps after that.

Also, I am helping with an university course, where we'd like to use LitGPT, starts in fall 2025.

@t-vi
Copy link
Collaborator

t-vi commented Aug 1, 2025

Supercool. Enjoy your vacation!

So from my side

  • we're a bit weary about API-breaking changes. They can be done, but ideally we should minimize them, as an approximation we should go over the changes to the tests and have a good think whether we can reasonably avoid them,
  • we need to have the default kvcache work with compilers well. This means:
    • don't pass in ints that change during generation, (ie no index position). The index_pos array worked very well for us.
    • don't index with slices from Tensor and avoid other casts from tensor to int ("data dependent control flow"),
    • do a functional interface thing and put things that can be tricky for compilers into functions (that the compiler can then divert / support with a different implementation).

I know the second part is not 100% intuitive if you haven't worked in the space but so

  • I have seen attention kernels take a "max_input_pos" tensor to avoid unnecessary work,
  • paged attention also does this.

To my mind, having reasonable compiler support as part of the design these days rather than an afterthought greatly increases the impact that we can make.

@mseeger
Copy link
Contributor Author

mseeger commented Aug 7, 2025

@t-vi , could you be more concrete about what changes you'd suggest? I have no particular preference.

The input_pos argument for KVCache.forward could be removed, instead we could have a boolean flag called prefill or so, which signals this is the first chunk to be processed. This is needed, because the cache is used for a number of inference calls. An alternative would be to require that a reset method (or so) is called up front for every inference call.

Apart from that input_pos is just a safety check, because most caches will keep this position internally and compare to input_pos passed in.

AFAIU understand, graph-based inference is spanning a compute graph, where the different calls to KVCache.forward are also different nodes, so I don't see why input_pos would be problematic. But I can remove it.

@mseeger
Copy link
Contributor Author

mseeger commented Aug 7, 2025

To avoid unnecessary cycles, could you sketch a brief list of what changes you recommend?

@mseeger
Copy link
Contributor Author

mseeger commented Aug 7, 2025

As for compiler: Note that none of the KVCache implementations are stateless operators. They do maintain some members which change depending on how they are called. The position in the sequence is just one example.

To me, a static graph is assembled and compiled, where if KVCache.forward is called several times, this maps to different nodes in the graph. For example, the chunk sizes are naturally different for many of these calls.

But these inner states just depend on the subsequent chunk sizes, not on the content of the input tensors. If you do inference for a number of batches which can be chunked in the same way, the inference graph will be the same. Keeping
the chunk sizes and sequence lengths the same across several inference calls may be a challenge, but that happens outside.

@mseeger
Copy link
Contributor Author

mseeger commented Aug 18, 2025

It would be good to move forward. @t-vi , maybe you can put forward some concrete asks, in terms of missing tests, etc?

In terms of API change. I can keep index_pos, but it essentially needs to be contiguous ranges, so it is really redundant to have to provide this index.

As I stated above, we could even get rid of this, because the KV cache internally keeps track of how many tokens have been processed. It is just to avoid silly errors to have input_pos.

We also need to understand when input_pos == 0, because this starts a new inference pass (which works in chunks). Unless we ask the user to signal that in a different way, e.g. some reset method or such.

@mseeger
Copy link
Contributor Author

mseeger commented Aug 18, 2025

I'd like to understand the concerns about compilation and graphs. Each input batch has a different sequence length and batch size (potentially), so shapes will change. Depending on how inference works, the chunk sizes are going to be the same.

Asking the graph to be exactly the same all the time would mean:

  • You need to have the same batch size
  • You need to pad everything to maximum length
  • You need to keep generating tokens until this max length is reached

Is this really what happens in practice?

@mseeger mseeger force-pushed the kvcache4 branch 2 times, most recently from cc736d6 to ce91bbb Compare August 20, 2025 21:34
@Borda
Copy link
Collaborator

Borda commented Sep 8, 2025

cc: @lantiga @t-vi

@t-vi
Copy link
Collaborator

t-vi commented Sep 9, 2025

Asking the graph to be exactly the same all the time would mean:

  • You need to have the same batch size
  • You need to pad everything to maximum length
  • You need to keep generating tokens until this max length is reached
    Is this really what happens in practice?

This is a bit more subtle. You can have a small number of shape configs, so e.g. batch size is not as much of a problem.
For next token generation, you actually typically have 1-element (or with speculative decoding a few) inputs, so it's
Similarly, you can have "maxpos aware attention", the maxpos being computed and fed as a CUDA tensor to the attention kernel. This is a bit involved (and for slicing the attention inputs this does not work well because CUDA->CPU induces a sync point), but it is what people actually do.

Note that e.g. HF transformers has a "static" KV Cache class for compilation. I'd probably prefer to get the good stuff by default...

@mseeger
Copy link
Contributor Author

mseeger commented Sep 29, 2025

@t-vi You may be informed about what happens with an exact (growing) KV cache. In that case, yes, a static compilation would not work. But modern KV caches are fixed size, with entries being overwritten, using scatter. This is fully compliant with what static graphs need. Even with the exact cache, you need to set its maximum size beforehand.

@mseeger
Copy link
Contributor Author

mseeger commented Sep 29, 2025

However, I have not fully checked whether my code would work with graph compilation. I'll do that.

@mseeger
Copy link
Contributor Author

mseeger commented Sep 29, 2025

Maybe this change for LitGPT is too big for the moment. In order to make efficient inference with sparse attention, and a form of gradient computation, work, I need a sizable number of changes, in particular to the naive multi-head attention. PyTorch SDPA does not help here, because they only really cater for the "training case" (where query and key are for the same tokens). Any other case already fails with the mask matrix being too large.

@mseeger
Copy link
Contributor Author

mseeger commented Sep 29, 2025

I'll first open source the complete library I have. I'll find a way to do this, by copying and modifying your model.py. We can then see how things develop.

I am a bit shocked how little fast inference with long contexts is supported. vLLM does not even support basic things like H2O. The existing implementations of sparse attention are rather poor. Everybody seems to be reluctant to move away from the cozy "attention with a single forward" case you have in training. Rather people parallelize across sequence length, as in ring attention (or context parallelism). That is all nice if you have many GPUs.

But I guess, for the moment, you are right, and this is maybe more of a niche.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants