
Conversation

hlky
Contributor

@hlky hlky commented Feb 14, 2025

What does this PR do?

Replaces #33522 to avoid conflicts and allow those using it to continue while we get it updated for #35235

Initial commit of this PR adds auxiliary code so we can discuss the core FAv3 integration.

cc @ArthurZucker

  • Integrate FAv3 into _flash_attention_forward/flash_attention_forward as before or create new functions?
  • Some models still have FlashAttention2 classes; is refactoring all models to use the new style planned? Should FAv3 be integrated as before, or should that refactor happen in this PR?

Also to check:

  • Status of dropout, softcap, etc.
  • Status of FP8
  • Packaging

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Contributor

@vasqu vasqu left a comment


Just a heads-up to warn/inform you about a few things regarding the current status of FA3:

if torch.version.cuda:
    compute_capability = torch.cuda.get_device_capability()
    major, _ = compute_capability
    if major < 9:

A100 support has recently been added: Dao-AILab/flash-attention#1481 (comment)
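Given that, the compute-capability guard quoted above could presumably be relaxed so it no longer requires Hopper. A minimal sketch, assuming only the data-center parts (sm80 and sm90) are in scope; the helper name is hypothetical:

```python
import torch


def _gpu_supports_fa3() -> bool:
    # Sketch only: FA3 originally targeted Hopper (sm90); per the linked issue,
    # A100 (sm80) kernels have since been added upstream. Support for other
    # Ampere parts is not assumed here.
    if not (torch.cuda.is_available() and torch.version.cuda):
        return False
    return torch.cuda.get_device_capability() in {(8, 0), (9, 0)}
```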

@vasqu
Contributor

vasqu commented Feb 14, 2025

cc @bn999 if you're interested in the progress

@bn999

bn999 commented Feb 14, 2025

@vasqu Yup, I'm following. Good stuff.

@hlky
Contributor Author

hlky commented Feb 18, 2025

Thanks for the info @vasqu

@hlky
Contributor Author

hlky commented Feb 24, 2025

Gentle ping @ArthurZucker

  • Integrate FAv3 into _flash_attention_forward/flash_attention_forward as before or create new functions?
  • Some models still have FlashAttention2 classes; is refactoring all models to use the new style planned? Should FAv3 be integrated as before, or should that refactor happen in this PR?

@jianguoz

jianguoz commented Mar 14, 2025

Hi @ArthurZucker @hlky @vasqu @muellerzr, thanks for the great effort to integrate Flash Attention 3 😁. Do we have any plans to merge this PR?

@sam-h-bean
Contributor

Hey, quick thing here @hlky: if you have FA3 installed but not FA2 (which I believe is a valid setup in other repos like TE), you end up failing the is_flash_attn_2_available check, so _flash_attention_forward is never reached even though the check-and-enable FA3 function passes. Not sure if this is intentional or a bug, but if it's intentional, a better guard could help tell people that both FA2 and FA3 are required?
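For illustration, a guard along these lines could surface the requirement explicitly. `is_flash_attn_2_available` does exist in `transformers.utils`; the FAv3 check and the error wording below are hypothetical stand-ins for whatever this PR ends up exposing:

```python
import importlib.util

from transformers.utils import is_flash_attn_2_available


def _is_flash_attn_3_installed() -> bool:
    # Hypothetical check: the real PR may look for a different package/module name.
    return importlib.util.find_spec("flash_attn_interface") is not None


def _require_flash_attention(use_fa3: bool = False) -> None:
    # Raise early with a message that names both requirements instead of
    # silently falling back when only FA3 is installed.
    if use_fa3 and not _is_flash_attn_3_installed():
        raise ImportError("Flash Attention 3 was requested but its package is not installed.")
    if not is_flash_attn_2_available():
        raise ImportError(
            "Flash Attention 2 is also required at the moment (it provides the pad/unpad "
            "helpers); please install flash-attn as well."
        )
```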

@hlky
Contributor Author

hlky commented Mar 21, 2025

Hi @sam-h-bean. At the time this PR was started (more specifically, the original PR #33522), pad functions were not available in FAv3, so FAv2 was required. As per #36190 (review), this is likely no longer required and will be updated when this PR is finished. At the moment we are waiting for comments from a core maintainer, @ArthurZucker, regarding #36190 (comment).
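For context, a minimal sketch of the dependency being described: the (un)padding helpers for variable-length batches ship with the FAv2 package, while the kernel itself would come from FAv3 (the FAv3 import path here is an assumption):

```python
# FAv2 supplies the padding helpers used to build varlen inputs...
from flash_attn.bert_padding import pad_input, unpad_input

# ...while the attention kernel itself comes from the FAv3 ("hopper") build.
# The module name is an assumption, not necessarily what this PR imports.
from flash_attn_interface import flash_attn_varlen_func
```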

Collaborator

@ArthurZucker ArthurZucker left a comment


Answering!

Integrate FAv3 into _flash_attention_forward/flash_attention_forward as before or create new functions?

I think if the API changes are not too big we can use the same one.

Some models still have FlashAttention2 classes, is refactoring all models to use the new style planned? Integrate FAv3 as before or do the refactor in this PR?

would be nice to have in a separate PR!

Happy to merge as is!
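For illustration only, reusing the existing entry point and branching on the kernel version might look roughly like this; the `use_fa3` flag, the FAv3 module name, and the return-value handling are assumptions rather than this PR's final API:

```python
from flash_attn import flash_attn_func as flash_attn_2_func

try:
    # FAv3 ("hopper") build; the module name is an assumption.
    from flash_attn_interface import flash_attn_func as flash_attn_3_func

    _IS_FLASH_ATTN_3_AVAILABLE = True
except ImportError:
    _IS_FLASH_ATTN_3_AVAILABLE = False


def _flash_attention_forward_sketch(q, k, v, dropout=0.0, softmax_scale=None, causal=True, use_fa3=False):
    # Keep a single entry point and branch on the kernel version.
    if use_fa3 and _IS_FLASH_ATTN_3_AVAILABLE:
        # Note: FAv3 does not take a dropout argument, and early builds returned
        # (out, softmax_lse) rather than just the output tensor.
        out = flash_attn_3_func(q, k, v, softmax_scale=softmax_scale, causal=causal)
        return out[0] if isinstance(out, tuple) else out
    return flash_attn_2_func(q, k, v, dropout_p=dropout, softmax_scale=softmax_scale, causal=causal)
```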

Comment on lines 353 to +354
_supports_flash_attn_2 = True
_supports_flash_attn_3 = True

Supporting 2 or 3 is equivalent for the model here, so can we just keep 2 <=> 3?
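If the two flags really are interchangeable for a given model, one hypothetical way to keep a single source of truth is to derive one from the other; a sketch, not what this PR does:

```python
class PreTrainedModelSketch:
    # Models declare FAv2 support once...
    _supports_flash_attn_2 = True

    @classmethod
    def supports_flash_attn_3(cls) -> bool:
        # ...and FAv3 support is treated as equivalent (2 <=> 3), so no second flag is needed.
        return cls._supports_flash_attn_2
```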

@hlky hlky closed this Apr 15, 2025
@hlky hlky deleted the fav3 branch April 15, 2025 12:28
@ArthurZucker
Collaborator

@hlky sorry I probably forgot to merge 😓 don't worry, we'll push through and add support!

@hlky
Contributor Author

hlky commented May 16, 2025

@ArthurZucker lol it's cool, the PR wasn't finished because I had been waiting for your response, didn't have time in Paris then I was fired so I closed it 🤷‍♂️

EduardDurech added a commit to swiss-ai/transformers that referenced this pull request Jun 18, 2025
Implements fwd and tests for Flash Attention 3 https://github.com/Dao-AILab/flash-attention/commits/main/hopper

- Includes checks for dropout>0 and ALiBi in `modeling_utils.PreTrainedModel._check_and_enable_flash_attn_3` (dropout will likely be supported soon, so this will need to be updated, as will `modeling_flash_attention_utils._flash_attention_forward` at the `if _IS_FLASH_ATTN_3_AVAILABLE: ...` branch)

An example Llama implementation is included in `modeling_llama.py` but other models would still need to be updated

Based on huggingface#36190 which has model implementations and examples which could be merged
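The guard described in the commit message above might look roughly like the following sketch; the config attribute names and error wording are assumptions:

```python
def _check_and_enable_flash_attn_3_sketch(config):
    # Refuse to enable FA3 for configurations it cannot handle yet.
    if getattr(config, "attention_dropout", 0.0) > 0.0:
        raise ValueError("Flash Attention 3 does not support attention dropout > 0 yet.")
    if getattr(config, "alibi", False):
        raise ValueError("Flash Attention 3 does not support ALiBi.")
    config._attn_implementation = "flash_attention_3"
    return config
```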
EduardDurech added a commit to swiss-ai/transformers that referenced this pull request Jun 22, 2025
EduardDurech added a commit to swiss-ai/transformers that referenced this pull request Jun 22, 2025
ArthurZucker pushed a commit that referenced this pull request Jun 25, 2025
* Support `flash_attn_3`
Implements fwd and tests for Flash Attention 3 https://github.com/Dao-AILab/flash-attention/commits/main/hopper

- Includes checks for dropout>0 and ALiBi in `modeling_utils.PreTrainedModel._check_and_enable_flash_attn_3` (dropout will likely be supported soon, so this will need to be updated, as will `modeling_flash_attention_utils._flash_attention_forward` at the `if _IS_FLASH_ATTN_3_AVAILABLE: ...` branch)

An example Llama implementation is included in `modeling_llama.py` but other models would still need to be updated

Based on #36190 which has model implementations and examples which could be merged

* Add tests for Flash Attention 2 and 3 parity

* ci fix

* FA2 compatibiity
- `_prepare_flash_attention_from_position_ids` ->`prepare_fa2_from_position_ids`
- Remove bettertransformer check in Flash Attention 3
- Merge tests
- Add licensing

* ci fix

* Test naming consistency

* ci fix

* Deprecation warning for `prepare_fa2_from_position_ids`

* ci fix
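For reference, a deprecation shim for the renamed helper typically just warns and forwards; a sketch assuming the FA2-specific name is the one being retired, with the new helper stubbed out:

```python
import warnings


def _prepare_flash_attention_from_position_ids(query, key, value, position_ids):
    # Stand-in for the real helper, which builds cu_seqlens/indices from
    # position_ids so packed sequences can be fed to the varlen kernels.
    ...


def prepare_fa2_from_position_ids(*args, **kwargs):
    # Sketch of a deprecation shim; the direction of the rename is an assumption
    # based on the commit bullets above.
    warnings.warn(
        "`prepare_fa2_from_position_ids` is deprecated; use "
        "`_prepare_flash_attention_from_position_ids` instead.",
        FutureWarning,
    )
    return _prepare_flash_attention_from_position_ids(*args, **kwargs)
```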