🚨All attention refactor🚨 #35235
Conversation
Force-pushed from 0dc9253 to d1aa9ce (compare)
src/transformers/modeling_utils.py (Outdated)

    class GradientCheckpointLayer(torch.nn.Module):
This should help with kwargs as well
Force-pushed from 8b56823 to ecd814b (compare)
…ngface#36024)

* update
* update
* update
* dev-ci
* more changes
* fix
* fix
* fix

--------

Co-authored-by: ydshieh <[email protected]>
Breaking change in transformers is huggingface/transformers#35235. Need to make changes to unpin nv-a6000 workflow. Signed-off-by: gyou2021 <[email protected]>
Breaking change in transformers is huggingface/transformers#35235. Need to make changes to unpin nv-a6000 workflow. Signed-off-by: yisheng <[email protected]>
* refactor LlamaAttention
* minimal changes
* fix llama
* update
* modular gemmas
* modular nits
* modular updates
* nits
* simplify
* gpt2
* more modualr and fixes
* granite
* modular modular modular
* nits
* update
* qwen2 + starcoder2
* mostly gemma2
* Update image_processing_auto.py
* fix
* Update modular_starcoder2.py
* fix
* remove all copied from attentions
* remove gcv
* make fix-copies
* oups
* oups2.0
* fix some modulars + all copied from
* should be good now
* revert unwanted changes
* Update modeling_decision_transformer.py
* finish cleanup
* Update modeling_olmo.py
* consistency
* re-add gradient checkpointing attribute
* fix
* style
* make config necessary
* bis
* bis
* Update modeling_my_new_model2.py
* is_causal attr
* fix
* remove past kv return from decoder layer
* fix
* default rope config
* correctly fix rope config
* fix bias
* fix gpt2 attention output
* fix test
* fix inits
* fix default sdpa
* fix default sdpa implementation
* harmonize classes
* fix mistral
* fix sliding window models
* mixtral
* be more explicit
* style
* fix
* several fixes
* Update modeling_dbrx.py
* fix test
* olmo + phi
* rotary
* syle
* phi
* phi again
* again
* kwargs
* Update test_modeling_common.py
* skip fx tracing tests
* Update modeling_utils.py
* gemma 2
* again
* Update modeling_recurrent_gemma.py
* gemma2
* granite
* style
* starcoder
* Update sdpa_attention.py
* switch args
* Update modeling_mllama.py
* fix
* cache type tests
* gpt2
* Update test_modeling_common.py
* fix
* consistency
* fix shape with encoder
* should be the last one
* tests non model
* most comments
* small oupsi
* be more explicit in modulars
* more explicit modulars
* CIs! it works locally
* add kwargs to _flash_attention_forward

---------

Co-authored-by: Cyril Vallez <[email protected]>
# Adds support for `transformers` as a backend

Following huggingface/transformers#35235, a bunch of models should already be supported, and we are ramping up support for more models. Thanks @Isotr0py for the TP support, and @hmellor for his help as well!

This includes:
- `trust_remote_code=True` support: any model on the hub, if it implements attention the correct way, can be natively supported!
- tensor parallel support

Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
Co-authored-by: Harry Mellor <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
See huggingface/transformers#35235 (comment) for context. A refactor in transformers moved the rotary embedding of Mistral (and probably other models) to the model level. This made the device map used in one of the tests incorrect. This PR fixes the device map. Note that this fix doesn't really have anything to do with prefix tuning; the error occurred even before prefix tuning was used.
The changes in huggingface/transformers#35235 caused a couple of adaption prompt tests to fail. This PR fixes these failures while maintaining compatibility with older transformers versions. Required changes:
- hidden_size attribute removed from the model, now config.hidden_size
- num_heads attribute removed from the model, now config.num_attention_heads
- forward now returns 2 outputs instead of 3; rewritten to be agnostic towards the number of outputs
        **kwargs,
    ) -> Tuple[torch.Tensor, None]:
        if hasattr(module, "num_key_value_groups"):
            key = repeat_kv(key, module.num_key_value_groups)
@ArthurZucker Do you know whether this repeat_kv is needed here? torch.nn.functional.scaled_dot_product_attention allows the number of heads to differ between q and k/v, so this expansion just defeats the purpose of doing GQA.
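For context, a minimal sketch of what `repeat_kv` does, based on the helper commonly found in transformers modeling code (not the exact source; shapes and names are illustrative):

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand key/value heads so they match the number of query heads.

    (batch, num_kv_heads, seq_len, head_dim)
        -> (batch, num_kv_heads * n_rep, seq_len, head_dim)
    """
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat dimension, broadcast it, then fold it into the head dim.
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, head_dim
    )
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)


# With 32 query heads and 8 key/value heads, the KV tensors are expanded 4x
# before SDPA; this is the expansion the question above refers to.
key = torch.randn(1, 8, 16, 64)
print(repeat_kv(key, 4).shape)  # torch.Size([1, 32, 16, 64])
```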
@yuanyao-nv Sadly, yes, it is there for multiple reasons:

- GQA support was only added to sdpa later (we still support torch versions as old as 2.1)
- The GQA support in sdpa is very limited
  - Quoting https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html: "It currently works only for Flash_attention and math kernel on CUDA tensor"
- However, since we more often than not use a mask, we require the memory-efficient backend (fa cannot work with masks in sdpa) to avoid the least efficient math backend (see the sketch below this list)
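A rough illustration of that mask constraint. This is a sketch, not the transformers code path; it assumes a CUDA device, a torch version with `torch.nn.attention`, and fp16 tensors, and the exact error text may differ across versions:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

assert torch.cuda.is_available(), "this sketch assumes a CUDA device"

q = torch.randn(1, 32, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 32, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 32, 128, 64, device="cuda", dtype=torch.float16)
# An arbitrary additive float mask (e.g. padding), not just a causal flag.
mask = torch.zeros(1, 1, 128, 128, device="cuda", dtype=torch.float16)

# The flash kernel inside SDPA does not accept an explicit attn_mask, so
# restricting SDPA to it while passing a mask leaves no kernel to run.
try:
    with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
        F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
except RuntimeError as err:
    print("flash-only + mask:", err)

# The memory-efficient kernel does accept the mask, which is why the sdpa
# path wants it available rather than falling back to the math kernel.
with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.shape)
```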
On point!
Thanks for the reply!
With regard to the flash attention support gap, how does introducing repeat_kv solve the issue? Does flash attention have better support for masked cases in non-GQA than in GQA?
The reason I'm asking is that I'm exporting models to an FX graph and to ONNX, and I'm trying to preserve the SDPA op for attention in the FX graph. repeat_kv introduces additional ops that can lead to inefficiencies when a non-torch backend parses such graphs (see the export sketch below).
I'm wondering if the use of repeat_kv should be configurable based on which backend is intended?
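To make the export concern concrete, a small sketch. The module here is hypothetical, standing in for the sdpa attention path, and it assumes a torch version that ships `torch.export`:

```python
import torch
import torch.nn.functional as F

class GqaSdpa(torch.nn.Module):
    """Hypothetical stand-in for an sdpa attention call preceded by repeat_kv."""

    def forward(self, q, k, v):
        n_rep = q.shape[1] // k.shape[1]
        # Stand-in for repeat_kv: materializes the extra key/value heads.
        k = k.repeat_interleave(n_rep, dim=1)
        v = v.repeat_interleave(n_rep, dim=1)
        return F.scaled_dot_product_attention(q, k, v)

q = torch.randn(1, 32, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# The exported graph contains repeat_interleave/expand nodes next to the
# scaled_dot_product_attention node; these are the extra ops a non-torch
# backend has to pattern-match away.
exported = torch.export.export(GqaSdpa(), (q, k, v))
print(exported)
```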
@vasqu @ArthurZucker any thoughts on the above questions? Thanks.
Another follow-up question (might be a naive one):
The use case you're referring to above seems to be attn_implementation="sdpa" combined with explicitly requesting the FA backend (maybe via a context manager: with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):)?
However, if the user really intends to use FA, shouldn't they directly set attn_implementation="flash_attention_2" or "flash_attention_3"? I.e., why would they set attn_implementation="sdpa" and then request the FA kernel?
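For concreteness, the two options being contrasted look roughly like this. The checkpoint name is a placeholder, and the sketch assumes a CUDA GPU, the flash-attn package for the first variant, and a torch version with `torch.nn.attention`:

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "your-org/your-causal-lm"  # placeholder checkpoint

# Option 1: request the dedicated flash-attention code path in transformers.
model_fa = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Option 2: keep the sdpa path and restrict which kernel SDPA may dispatch to
# via torch's context manager.
model_sdpa = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="sdpa",
    torch_dtype=torch.bfloat16,
).to("cuda")

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Hello", return_tensors="pt").to("cuda")
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
    outputs = model_sdpa(**inputs)
```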
Yes, in general using flash_attention_2 is just the best / recommended way!
Thanks for confirming. In that case, is there anything stopping us from removing repeat_kv from the sdpa_attention path? It sounds like if the user really wants to use the FA backend, they should not need to use this path in the first place.
This would also make the exported FX and ONNX graphs more efficient and easier for other backends to work with.
I'd be happy to submit a PR for this.
I think #35235 (comment) still explains most reasons:

- `enable_gqa` is not available for earlier versions of torch (it starts with 2.5.x)
- If we have a mask, we need to avoid the math kernel (glad to see benchmarks to prove me wrong there) and instead use the xformers one (memory-efficient backend)
- `enable_gqa` works only with fa or math:
  - If we have a mask, it already cuts out the fa kernel (torch internal restriction)
  - Fallback to math kernel
- It's not about users requesting specific backends but more so about avoiding inefficient branches (math kernel); if a user uses SDPA they should expect the more efficient backends...

The original flash attention is better in that way, but SDPA is the standard attention as it's native torch.
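A small sketch of the `enable_gqa` path versus the manual expansion that `repeat_kv` performs. It requires torch >= 2.5, and per the docs quoted earlier the flag is only guaranteed for the flash and math kernels on CUDA tensors, so the sketch assumes a CUDA device; shapes are made up:

```python
import torch
import torch.nn.functional as F

assert torch.cuda.is_available(), "this sketch assumes a CUDA device"
device = "cuda"

batch, q_heads, kv_heads, seq_len, head_dim = 1, 32, 8, 16, 64
query = torch.randn(batch, q_heads, seq_len, head_dim, device=device)
key = torch.randn(batch, kv_heads, seq_len, head_dim, device=device)
value = torch.randn(batch, kv_heads, seq_len, head_dim, device=device)

# torch >= 2.5: SDPA broadcasts the 8 KV heads over the 32 query heads itself,
# so no explicit repeat_kv is needed.
out_gqa = F.scaled_dot_product_attention(query, key, value, enable_gqa=True)

# Pre-2.5 equivalent (what repeat_kv effectively does): materialize the heads.
key_rep = key.repeat_interleave(q_heads // kv_heads, dim=1)
value_rep = value.repeat_interleave(q_heads // kv_heads, dim=1)
out_manual = F.scaled_dot_product_attention(query, key_rep, value_rep)

torch.testing.assert_close(out_gqa, out_manual, atol=1e-4, rtol=1e-4)
```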
FYI, there is #39412 which partially enables this kwarg
What does this PR do?
Todo in this PR: