Flash Attention v3 #36190
Closed
Commits (11, all authored by hlky):
efa7189  _supports_flash_attn_3
b1fc52e  modeling_utils/import_utils
9a80143  config._attn_implementation/_use_flash_attention_3
a526189  testing_utils
a9717e7  make
1b5f20c  sliding_window
aeb1d55  Merge branch 'main' into fav3
ea85044  Update modeling_granitemoe.py
af0d015  Update modeling_granitemoe.py
823386a  Merge branch 'main' into fav3
c3aea43  Merge remote-tracking branch 'upstream/main' into fav3
Diff (modeling_utils.py):

@@ -95,6 +95,7 @@
    is_accelerate_available,
    is_bitsandbytes_available,
    is_flash_attn_2_available,
    is_flash_attn_3_available,
    is_offline_mode,
    is_optimum_available,
    is_peft_available,
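For context, an availability helper like the one imported above is typically a thin wrapper around a package probe. The sketch below is hypothetical and is not the helper added by this PR (which lives in the import utilities); it only mirrors the `flash_attn_interface` check performed later in `_check_and_enable_flash_attn_3`.

# Hypothetical sketch of an availability helper; the real implementation may differ.
import importlib.util

def is_flash_attn_3_available() -> bool:
    # FA3 ships its Python bindings as the `flash_attn_interface` package,
    # the same module name probed in _check_and_enable_flash_attn_3 below.
    return importlib.util.find_spec("flash_attn_interface") is not None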
@@ -1812,6 +1813,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
    # Flash Attention 2 support
    _supports_flash_attn_2 = False

    # Flash Attention 3 support
    _supports_flash_attn_3 = False

    # SDPA support
    _supports_sdpa = False
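To illustrate how the new flag is meant to be used, a model architecture opts in the same way it already does for Flash Attention 2. The class below is a hypothetical, schematic example and is not part of this diff.

# Hypothetical opt-in, mirroring the existing _supports_flash_attn_2 pattern.
from transformers import PreTrainedModel

class MyModelPreTrainedModel(PreTrainedModel):
    _supports_flash_attn_2 = True
    _supports_flash_attn_3 = True  # new flag introduced by this PR
    _supports_sdpa = True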
@@ -2087,6 +2091,8 @@ def _autoset_attn_implementation(
        message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
        if cls._supports_flash_attn_2:
            message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
        if cls._supports_flash_attn_3:
            message += ', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)'
        if cls._supports_sdpa:
            message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
        if cls._supports_flex_attn:

@@ -2129,6 +2135,14 @@ def _autoset_attn_implementation(
                hard_check_only=False,
                check_device_map=check_device_map,
            )
        elif config._attn_implementation == "flash_attention_3":
            cls._check_and_enable_flash_attn_3(
                config,
                torch_dtype=torch_dtype,
                device_map=device_map,
                hard_check_only=False,
                check_device_map=check_device_map,
            )
        elif requested_attn_implementation == "flex_attention":
            config = cls._check_and_enable_flex_attn(config, hard_check_only=True)
        elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available():
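As a usage example of the branch added above (reusing the checkpoint name from the warning message later in this diff; whether a given checkpoint actually accepts it depends on its `_supports_flash_attn_3` flag and on Flash Attention 3 being installed):

import torch
from transformers import AutoModel

# Requesting the new implementation routes through _check_and_enable_flash_attn_3.
model = AutoModel.from_pretrained(
    "openai/whisper-tiny",
    attn_implementation="flash_attention_3",
    torch_dtype=torch.float16,
)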
@@ -2331,6 +2345,90 @@ def _check_and_enable_flash_attn_2(
        config._attn_implementation = "flash_attention_2"
        return config

    @classmethod
    def _check_and_enable_flash_attn_3(
        cls,
        config,
        torch_dtype: Optional[torch.dtype] = None,
        device_map: Optional[Union[str, Dict[str, int]]] = None,
        check_device_map: bool = True,
        hard_check_only: bool = False,
    ) -> PretrainedConfig:
        """
        Checks the availability of Flash Attention 3 and compatibility with the current model.
        If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_3" so that the model can initialize the correct attention module.
        """
        if not cls._supports_flash_attn_3:
            raise ValueError(
                f"{cls.__name__} does not support Flash Attention 3.0 yet. Please request to add support where"
                f" the model is hosted, on its model hub page: https://huggingface.co/{config._name_or_path}/discussions/new"
                " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new"
            )

        if not is_flash_attn_3_available():
            preface = "FlashAttention3 has been toggled on, but it cannot be used due to the following error:"
            # TODO: docs
            install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-3 to install Flash Attention 3."

            if importlib.util.find_spec("flash_attn_interface") is None:
                raise ImportError(
                    f"{preface} the package flash_attn_interface seems to be not installed. {install_message}"
                )

            if torch.version.cuda:
                compute_capability = torch.cuda.get_device_capability()
                major, _ = compute_capability
                if major < 9:

[Inline review comment on the check above: A100 support has recently been added, see Dao-AILab/flash-attention#1481 (comment).]

                    raise ValueError("Flash Attention 3 requires NVIDIA GPU with compute capability >= 9.0")
            else:
                raise ValueError("Flash Attention 3 requires NVIDIA GPU with compute capability >= 9.0")

        _is_bettertransformer = getattr(cls, "use_bettertransformer", False)

        if _is_bettertransformer:
            raise ValueError(
                "Flash Attention 3 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()"
            )

        if torch_dtype is None:
            logger.warning_once(
                "You are attempting to use Flash Attention 3.0 without specifying a torch dtype. This might lead to unexpected behaviour"
            )
        elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
            logger.warning_once(
                "Flash Attention 3.0 only supports torch.float16 and torch.bfloat16 dtypes, but"
                f" the current dtype in {cls.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"
                ' or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_3", torch_dtype=torch.float16)`'
            )

        # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
        # or the model may be initialized under the context manager `with torch.device("cuda"):`.
        if check_device_map and device_map is None and torch.empty(0).device.type != "cuda":
            if torch.cuda.is_available():
                logger.warning_once(
                    "You are attempting to use Flash Attention 3.0 with a model not initialized on GPU. Make sure to move the model to GPU"
                    " after initializing it on CPU with `model.to('cuda')`."
                )
            else:
                raise ValueError(
                    "You are attempting to use Flash Attention 3.0 with a model not initialized on GPU and with no GPU available. "
                    "This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
                    "or initialising the model on CPU and then moving it to GPU."
                )
        elif (
            check_device_map
            and device_map is not None
            and isinstance(device_map, dict)
            and ("cpu" in device_map.values() or "disk" in device_map.values())
        ):
            raise ValueError(
                "You are attempting to use Flash Attention 3.0 with a model dispatched on CPU or disk. This is not supported. Please make sure to "
                "initialise the model on a GPU by passing a device_map that contains only GPU devices as keys."
            )
        if not hard_check_only:
            config._attn_implementation = "flash_attention_3"
        return config

    @classmethod
    def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
        """
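As a quick way to reproduce the hardware gate above outside of model loading, a user could run a minimal check like the one below (note the inline review comment about A100 support being added upstream, which may relax the compute-capability requirement).

import torch

# Mirrors the compute-capability gate in _check_and_enable_flash_attn_3.
if torch.cuda.is_available() and torch.version.cuda:
    major, minor = torch.cuda.get_device_capability()
    status = "eligible" if major >= 9 else "below the >= 9.0 requirement"
    print(f"Compute capability {major}.{minor}: {status}")
else:
    print("No CUDA device detected; Flash Attention 3 cannot be enabled.")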
@@ -5889,6 +5987,7 @@ def get_disk_only_shard_files(device_map, weight_map):
ALL_ATTENTION_FUNCTIONS.update(
    {
        "flash_attention_2": flash_attention_forward,
        "flash_attention_3": flash_attention_forward,
        "flex_attention": flex_attention_forward,
        "sdpa": sdpa_attention_forward,
    }
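A minimal, self-contained illustration of the dispatch pattern above: both keys can point at the same callable, which is what makes the "2 <=> 3" equivalence raised in the closing review comment possible. The names below are stand-ins, not the library's actual objects.

from typing import Callable, Dict

def flash_attention_forward_stub(*args, **kwargs):
    # Stand-in for the shared flash-attention forward mapped to both keys above.
    return "flash-attention kernel would run here"

ATTENTION_REGISTRY: Dict[str, Callable] = {
    "flash_attention_2": flash_attention_forward_stub,
    "flash_attention_3": flash_attention_forward_stub,
}

print(ATTENTION_REGISTRY["flash_attention_3"]())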
Review comment:
Supporting 2 or 3 is equivalent for the model here, so can we just keep 2 <=> 3?