Refactoring of multi-head attention and support for KV caching #2061
Conversation
Started work to make sure all tests pass.
Tests fail for me that should also fail in ??
I'll submit a PR with a fix.
Force-pushed e652799 to 0d6360b
Could you also link the PR here?
I need to put some more work into this one. Sorry, I was busy with other things.
Force-pushed 474797d to b8efa8c
In my local installation, there is only one failing test which I think I still need to attend to:
There are also failures in these tests:
But for those, I don't really understand what is going on.
OK, same thing in your CI/CD. I'd need help with these two tests, which use mocking in a way I do not understand. The one where "tensors are not close" I can deal with.
Force-pushed 06255ac to c8b8895
Sure, which two?
@Borda, I had to extend
OK, we are down to 7 errors, 6 of which are about
I'd need help resolving the one mocking error. For the remaining 6, you need to decide. Maybe I see this wrong, but this
Force-pushed 2ed4013 to 98dfe11
Working on it. I now know why this passes in
OK, now we are down to a single failing test, something to do with mocking.
My solution here is to expand keys and values in the special case where query and key have the same length. This is what happens during training, or for inference with
However, in subsequent calls, where the query is much smaller than the keys, expansion is avoided to save GPU memory. This turns out to be important for long-context inference and fine-tuning.
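To illustrate the point, here is a rough sketch of my own (not the code in this PR; the function and argument names are made up): grouped KV heads are only expanded when query and key cover the same positions, i.e. during training or prefill, while decoding uses the grouped tensors directly.

```python
# Sketch only; assumes PyTorch >= 2.5 for enable_gqa. Names are illustrative.
import torch
import torch.nn.functional as F

def attend(q, k, v, n_query_groups: int):
    # q: (B, n_head, T_q, head_dim); k, v: (B, n_query_groups, T_kv, head_dim)
    n_head, T_q, T_kv = q.shape[1], q.shape[2], k.shape[2]
    q_per_kv = n_head // n_query_groups
    if T_q == T_kv:
        # Training or prefill: expanding k, v is cheap relative to the activations.
        if q_per_kv > 1:
            k = k.repeat_interleave(q_per_kv, dim=1)
            v = v.repeat_interleave(q_per_kv, dim=1)
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)
    # Decoding: the new queries attend to every cached position, so no causal
    # mask is needed, and k, v stay unexpanded to save GPU memory.
    return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=True)
```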
@t-vi, this is ready for review. The last test needs to be fixed on your side; I don't understand this mocking. Let me know whether I should split the PR into two: the first would generalize attention, the second would introduce the KV cache abstraction. Also let me know if something more fundamental is not OK.
Force-pushed 04b74fa to 40db977
@Borda, @t-vi: I am about to get open source approval from my employer to release a substantial amount of code for KV caching and also fine-tuning. This will be a separate repository, but sit on top of
I am on holiday now from July 25 until August 18, but would take next steps after that. Also, I am helping with a university course, where we'd like to use
Supercool. Enjoy your vacation! So from my side
I know the second part is not 100% intuitive if you haven't worked in the space, but so
To my mind, having reasonable compiler support as part of the design these days, rather than as an afterthought, greatly increases the impact that we can make.
@t-vi, could you be more concrete about what changes you'd suggest? I have no particular preference. The
Apart from that, as far as I understand, graph-based inference builds up a compute graph, where the different calls to
To avoid unnecessary cycles, could you sketch a brief list of what changes you recommend?
As for the compiler: note that none of the
To my mind, a static graph is assembled and compiled, where if
But these inner states depend only on the subsequent chunk sizes, not on the content of the input tensors. If you run inference for a number of batches which can be chunked in the same way, the inference graph will be the same. Keeping
It would be good to move forward. @t-vi, maybe you can put forward some concrete asks, in terms of missing tests, etc.?
In terms of API changes: I can keep
As I stated above, we could even get rid of this, because the KV cache internally keeps track of how many tokens have been processed. It is just to avoid silly errors to have
We also need to understand when
I'd like to understand the concerns about compilation and graphs. Each input batch has a different sequence length and (potentially) batch size, so shapes will change. Depending on how inference works, the chunk sizes are going to be the same. Asking the graph to be exactly the same all the time would mean:
Is this really what happens in practice?
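To make the chunk-size argument concrete, here is a rough illustration of my own (not code from this PR; `model` is assumed to maintain its KV cache internally, and `chunk_size` is an assumption): if prefill always uses the same chunk size, every call to the compiled forward sees the same shapes, so the graph is traced once and reused across batches.

```python
# Illustration only; assumes T is a multiple of chunk_size (pad the batch beforehand).
import torch

def prefill_in_fixed_chunks(model: torch.nn.Module, tokens: torch.Tensor, chunk_size: int):
    forward = torch.compile(model, dynamic=False)
    logits = None
    for start in range(0, tokens.shape[1], chunk_size):
        # Every chunk has shape (B, chunk_size), so the compiled graph is reused.
        logits = forward(tokens[:, start : start + chunk_size])
    return logits
```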
Force-pushed cc736d6 to ce91bbb
This is a bit more subtle. You can have a small number of shape configs, so e.g. batch size is not as much of a problem. Note that e.g. HF transformers has a "static" KV cache class for compilation. I'd probably prefer to get the good stuff by default...
@t-vi, your view may be informed by what happens with an exact (growing) KV cache. In that case, yes, static compilation would not work. But modern KV caches are fixed-size, with entries being overwritten, using
However, I have not fully checked whether my code would work with graph compilation. I'll do that.
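For concreteness, a minimal sketch of what I mean by a fixed-size, overwrite-in-place cache (my own illustration, not the code in this PR; the class and method names are made up): the tensors never change shape, so a compiler sees static shapes, and new entries overwrite old slots in place.

```python
# Hypothetical sketch with a simple wrap-around (last-recent) eviction policy.
import torch

class FixedSizeKVCache:
    def __init__(self, batch_size, n_kv_heads, capacity, head_dim, device=None, dtype=None):
        shape = (batch_size, n_kv_heads, capacity, head_dim)
        self.k = torch.zeros(shape, device=device, dtype=dtype)
        self.v = torch.zeros(shape, device=device, dtype=dtype)
        self.capacity = capacity

    def update(self, input_pos: torch.Tensor, k_new: torch.Tensor, v_new: torch.Tensor):
        # input_pos: 1D positions of the new tokens; k_new, v_new have shape
        # (batch_size, n_kv_heads, len(input_pos), head_dim).
        slots = input_pos % self.capacity  # overwrite the oldest entries in place
        self.k.index_copy_(2, slots, k_new)
        self.v.index_copy_(2, slots, v_new)
        return self.k, self.v
```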
Maybe this change for
I'll first open source the complete library I have. I'll find a way to do this by copying and modifying your
I am a bit shocked at how little fast inference with long contexts is supported. But I guess, for the moment, you are right, and this is maybe more of a niche.
This continues from #1934. I created a new branch because the history of the previous one was messed up by a merge operation.
Adds an abstraction for key-value caches and implements batched inference.
I am also adding two baseline KV caches: the default one from before (all KVs are stored) and a last-recent one.
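To give an idea of the shape of the abstraction, here is a rough sketch under my own naming (not the actual API of this PR):

```python
# Illustrative only; tensors are (B, n_kv_heads, T, head_dim), time is dim 2.
import torch

class KVCache:
    """Maps a new (k, v) chunk to the full keys/values attention should see."""
    def forward(self, k_new: torch.Tensor, v_new: torch.Tensor):
        raise NotImplementedError

class DenseKVCache(KVCache):
    # Default behaviour from before: all keys and values are stored.
    def __init__(self):
        self.k = self.v = None

    def forward(self, k_new, v_new):
        self.k = k_new if self.k is None else torch.cat([self.k, k_new], dim=2)
        self.v = v_new if self.v is None else torch.cat([self.v, v_new], dim=2)
        return self.k, self.v

class LastRecentKVCache(DenseKVCache):
    # Keeps only the most recent `window` tokens.
    def __init__(self, window: int):
        super().__init__()
        self.window = window

    def forward(self, k_new, v_new):
        k, v = super().forward(k_new, v_new)
        self.k, self.v = k[:, :, -self.window:], v[:, :, -self.window:]
        return self.k, self.v
```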
OK, this PR contains the following parts:
In the library I am writing, there are a number of additional, more powerful KV caches, such as H2O and quantization-aware H2O. I am also working on fine-tuning in the presence of KV caches. The abstraction I propose here enables all of that.
If these changes are not made, I'd have to copy and change quite a bit of your code. This would be hard to maintain and would run the risk that KV caches are implemented differently at a later point, and then things really diverge.
As I said in the comments above, I found KV caching to be super important for making long-context inference work on a moderate GPU budget, which should be of interest to your customers as well.
Edit: Since I opened this, I have been working a lot on gradient computation in the presence of long-context models. This is stress-testing the abstraction here quite a bit.