3 changes: 2 additions & 1 deletion examples/modular-transformers/modeling_dummy.py
@@ -351,6 +351,7 @@ class DummyPreTrainedModel(PreTrainedModel):
_no_split_modules = ["DummyDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
_supports_flash_attn_3 = True
Comment on lines 353 to +354 (Collaborator):
Supporting 2 or 3 is equivalent for the model here, so we can just keep 2 <=> 3?

_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
@@ -594,7 +595,7 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if "flash_attention" in self.config._attn_implementation:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
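Note (editorial, not part of the diff): the new substring check widens the existing `flash_attention_2` branch so the same early-return path is shared by Flash Attention 2 and 3, while other implementations still fall through to the regular mask construction. A minimal sketch of what the condition matches:

```python
# Illustrative only: the substring test accepts both flash attention variants.
for impl in ("eager", "sdpa", "flex_attention", "flash_attention_2", "flash_attention_3"):
    print(impl, "flash_attention" in impl)
# eager False, sdpa False, flex_attention False, flash_attention_2 True, flash_attention_3 True
```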
3 changes: 2 additions & 1 deletion examples/modular-transformers/modeling_multimodal1.py
@@ -351,6 +351,7 @@ class Multimodal1TextPreTrainedModel(PreTrainedModel):
_no_split_modules = ["Multimodal1TextDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
_supports_flash_attn_3 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
@@ -594,7 +595,7 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if "flash_attention" in self.config._attn_implementation:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
1 change: 1 addition & 0 deletions examples/modular-transformers/modeling_multimodal2.py
@@ -637,6 +637,7 @@ class Multimodal2VisionPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_supports_sdpa = True
_supports_flash_attn_2 = True
_supports_flash_attn_3 = True

def _init_weights(self, module):
"""Initialize the weights"""
3 changes: 2 additions & 1 deletion examples/modular-transformers/modeling_my_new_model2.py
@@ -351,6 +351,7 @@ class MyNewModel2PreTrainedModel(PreTrainedModel):
_no_split_modules = ["MyNewModel2DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
_supports_flash_attn_3 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
@@ -600,7 +601,7 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if "flash_attention" in self.config._attn_implementation:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
3 changes: 2 additions & 1 deletion examples/modular-transformers/modeling_new_task_model.py
@@ -109,6 +109,7 @@ class NewTaskModelPreTrainedModel(PreTrainedModel):
_supports_quantized_cache = True
_supports_static_cache = True
_supports_flash_attn_2 = True
_supports_flash_attn_3 = True
_supports_sdpa = True

def _init_weights(self, module):
@@ -258,7 +259,7 @@ def _update_causal_mask(
input_tensor,
is_training: bool = False,
):
- if self.config.text_config._attn_implementation == "flash_attention_2":
+ if "flash_attention" in self.config.text_config._attn_implementation:
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
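Aside (illustrative, not part of the diff): the mask check is spelled two ways across these files, `(attention_mask == 0.0).any()` and `0.0 in attention_mask`. For tensors the two are equivalent, since `Tensor.__contains__` reduces to an element-wise comparison followed by `any()`:

```python
import torch

mask = torch.tensor([[1.0, 1.0, 0.0]])  # a mask with one padded position
assert (0.0 in mask) == bool((mask == 0.0).any())
```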
3 changes: 2 additions & 1 deletion examples/modular-transformers/modeling_super.py
@@ -351,6 +351,7 @@ class SuperPreTrainedModel(PreTrainedModel):
_no_split_modules = ["SuperDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
_supports_flash_attn_3 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
@@ -516,7 +517,7 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if "flash_attention" in self.config._attn_implementation:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
99 changes: 99 additions & 0 deletions src/transformers/modeling_utils.py
@@ -95,6 +95,7 @@
is_accelerate_available,
is_bitsandbytes_available,
is_flash_attn_2_available,
is_flash_attn_3_available,
is_offline_mode,
is_optimum_available,
is_peft_available,
@@ -1812,6 +1813,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Flash Attention 2 support
_supports_flash_attn_2 = False

# Flash Attention 3 support
_supports_flash_attn_3 = False

# SDPA support
_supports_sdpa = False

@@ -2087,6 +2091,8 @@ def _autoset_attn_implementation(
message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
if cls._supports_flash_attn_2:
message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
if cls._supports_flash_attn_3:
message += ', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)'
if cls._supports_sdpa:
message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
if cls._supports_flex_attn:
@@ -2129,6 +2135,14 @@ def _autoset_attn_implementation(
hard_check_only=False,
check_device_map=check_device_map,
)
elif config._attn_implementation == "flash_attention_3":
cls._check_and_enable_flash_attn_3(
config,
torch_dtype=torch_dtype,
device_map=device_map,
hard_check_only=False,
check_device_map=check_device_map,
)
elif requested_attn_implementation == "flex_attention":
config = cls._check_and_enable_flex_attn(config, hard_check_only=True)
elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available():
@@ -2331,6 +2345,90 @@ def _check_and_enable_flash_attn_2(
config._attn_implementation = "flash_attention_2"
return config

@classmethod
def _check_and_enable_flash_attn_3(
cls,
config,
torch_dtype: Optional[torch.dtype] = None,
device_map: Optional[Union[str, Dict[str, int]]] = None,
check_device_map: bool = True,
hard_check_only: bool = False,
) -> PretrainedConfig:
"""
Checks the availability of Flash Attention 3 and compatibility with the current model.
If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_3" so that the model can initialize the correct attention module.
"""
if not cls._supports_flash_attn_3:
raise ValueError(
f"{cls.__name__} does not support Flash Attention 3.0 yet. Please request to add support where"
f" the model is hosted, on its model hub page: https://huggingface.co/{config._name_or_path}/discussions/new"
" or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new"
)

if not is_flash_attn_3_available():
preface = "FlashAttention3 has been toggled on, but it cannot be used due to the following error:"
# TODO: docs
install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-3 to install Flash Attention 3."

if importlib.util.find_spec("flash_attn_interface") is None:
raise ImportError(
f"{preface} the package flash_attn_interface seems to be not installed. {install_message}"
)

if torch.version.cuda:
compute_capability = torch.cuda.get_device_capability()
major, _ = compute_capability
if major < 9:
Comment (Contributor):
A100 support has recently been added: Dao-AILab/flash-attention#1481 (comment)

raise ValueError("Flash Attention 3 requires NVIDIA GPU with compute capability >= 9.0")
else:
raise ValueError("Flash Attention 3 requires NVIDIA GPU with compute capability >= 9.0")

_is_bettertransformer = getattr(cls, "use_bettertransformer", False)

if _is_bettertransformer:
raise ValueError(
"Flash Attention 3 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()"
)

if torch_dtype is None:
logger.warning_once(
"You are attempting to use Flash Attention 3.0 without specifying a torch dtype. This might lead to unexpected behaviour"
)
elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
logger.warning_once(
"Flash Attention 3.0 only supports torch.float16 and torch.bfloat16 dtypes, but"
f" the current dype in {cls.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"
' or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_3", torch_dtype=torch.float16)`'
)

# The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
# or the model may be initialized under the context manager `with torch.device("cuda"):`.
if check_device_map and device_map is None and torch.empty(0).device.type != "cuda":
if torch.cuda.is_available():
logger.warning_once(
"You are attempting to use Flash Attention 3.0 with a model not initialized on GPU. Make sure to move the model to GPU"
" after initializing it on CPU with `model.to('cuda')`."
)
else:
raise ValueError(
"You are attempting to use Flash Attention 3.0 with a model not initialized on GPU and with no GPU available. "
"This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
"or initialising the model on CPU and then moving it to GPU."
)
elif (
check_device_map
and device_map is not None
and isinstance(device_map, dict)
and ("cpu" in device_map.values() or "disk" in device_map.values())
):
raise ValueError(
"You are attempting to use Flash Attention 3.0 with a model dispatched on CPU or disk. This is not supported. Please make sure to "
"initialise the model on a GPU by passing a device_map that contains only GPU devices as keys."
)
if not hard_check_only:
config._attn_implementation = "flash_attention_3"
return config
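For reference, a hedged usage sketch of the opt-in path this check enables, reusing the checkpoint named in the warning message above (whether that checkpoint's model class actually sets `_supports_flash_attn_3 = True` is an assumption):

```python
import torch
from transformers import AutoModel

# Request FA3 at load time and pass a half-precision dtype so the dtype check above does not warn.
model = AutoModel.from_pretrained(
    "openai/whisper-tiny",
    attn_implementation="flash_attention_3",
    torch_dtype=torch.float16,
)
```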

@classmethod
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
"""
@@ -5889,6 +5987,7 @@ def get_disk_only_shard_files(device_map, weight_map):
ALL_ATTENTION_FUNCTIONS.update(
{
"flash_attention_2": flash_attention_forward,
"flash_attention_3": flash_attention_forward,
"flex_attention": flex_attention_forward,
"sdpa": sdpa_attention_forward,
}
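Worth noting (editorial): both registry keys resolve to the same `flash_attention_forward` callable, so attention modules that dispatch through `ALL_ATTENTION_FUNCTIONS[config._attn_implementation]` need no FA3-specific branch. A quick check, assuming `ALL_ATTENTION_FUNCTIONS` remains importable from `transformers.modeling_utils` as in this file:

```python
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

# Same forward implementation registered under both names.
assert ALL_ATTENTION_FUNCTIONS["flash_attention_2"] is ALL_ATTENTION_FUNCTIONS["flash_attention_3"]
```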
3 changes: 2 additions & 1 deletion src/transformers/models/aria/modeling_aria.py
@@ -1017,7 +1017,7 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool = False,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if "flash_attention" in self.config._attn_implementation:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
@@ -1393,6 +1393,7 @@ def __init__(self, config: AriaConfig):
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self._use_flash_attention_2 = config.text_config._attn_implementation == "flash_attention_2"
self._use_flash_attention_3 = config.text_config._attn_implementation == "flash_attention_3"
self.post_init()

def _create_patch_attention_mask(self, pixel_mask):
3 changes: 3 additions & 0 deletions src/transformers/models/aria/modular_aria.py
@@ -1208,6 +1208,7 @@ class AriaTextPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = False
_supports_flash_attn_3 = False
_supports_sdpa = True
_supports_cache_class = True

@@ -1348,6 +1349,7 @@ class AriaCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
config_class = AriaConfig
_supports_flash_attn_2 = False
_supports_flash_attn_3 = False
_supports_flex_attn = False
_supports_sdpa = False
_tied_weights_keys = ["language_model.lm_head.weight"]
@@ -1361,6 +1363,7 @@ def __init__(self, config: AriaConfig):
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self._use_flash_attention_2 = config.text_config._attn_implementation == "flash_attention_2"
self._use_flash_attention_3 = config.text_config._attn_implementation == "flash_attention_3"
self.post_init()

def _create_patch_attention_mask(self, pixel_mask):
2 changes: 1 addition & 1 deletion src/transformers/models/bamba/modeling_bamba.py
@@ -1318,7 +1318,7 @@ def _update_causal_mask(
past_key_values: HybridMambaAttentionDynamicCache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if "flash_attention" in self.config._attn_implementation:
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
3 changes: 2 additions & 1 deletion src/transformers/models/bamba/modular_bamba.py
@@ -809,6 +809,7 @@ class BambaPreTrainedModel(PreTrainedModel):
_no_split_modules = ["BambaDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache
_is_stateful = True
@@ -1059,7 +1060,7 @@ def _update_causal_mask(
past_key_values: HybridMambaAttentionDynamicCache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if "flash_attention" in self.config._attn_implementation:
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
43 changes: 41 additions & 2 deletions src/transformers/models/bark/modeling_bark.py
@@ -376,6 +376,7 @@ class BarkPreTrainedModel(PreTrainedModel):
config_class = BarkConfig
supports_gradient_checkpointing = False
_supports_flash_attn_2 = True
_supports_flash_attn_3 = True

def _init_weights(self, module):
"""Initialize the weights."""
@@ -561,6 +562,7 @@ def __init__(self, config):

self.layers = nn.ModuleList([BarkBlock(config, is_causal=True) for _ in range(config.num_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"

self.layernorm_final = BarkLayerNorm(config.hidden_size, bias=config.bias)

@@ -703,7 +705,7 @@ def forward(
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
attention_mask = attention_mask if 0 in attention_mask else None
else:
attention_mask = attention_mask.view(batch_size, -1)
@@ -1158,6 +1160,7 @@ def __init__(self, config):

self.layers = nn.ModuleList([BarkBlock(config, is_causal=False) for _ in range(config.num_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"

self.layernorm_final = nn.LayerNorm(config.hidden_size)

@@ -1353,7 +1356,7 @@ def forward(
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
attention_mask = attention_mask if 0 in attention_mask else None
else:
# [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
@@ -1838,6 +1841,42 @@ def _check_and_enable_flash_attn_2(
config.fine_acoustics_config._attn_implementation = config._attn_implementation
return config

@classmethod
def _check_and_enable_flash_attn_3(
cls,
config,
torch_dtype: Optional[torch.dtype] = None,
device_map: Optional[Union[str, Dict[str, int]]] = None,
hard_check_only: bool = False,
check_device_map: bool = False,
):
"""
`_check_and_enable_flash_attn_3` does not, by default, propagate flash attention enabling to the model
sub-configurations. We override the original method to make sure that the Bark sub-models use Flash Attention
if necessary.

If you don't know about Flash Attention, check out the official repository of flash attention:
https://github.com/Dao-AILab/flash-attention

To use Flash Attention 1.0 you can go through the `BetterTransformer` API directly; have a look at this
specific section of the documentation to learn more about it:
https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#decoder-models

The method checks if the current setup is compatible with Flash Attention as it requires the model to be in
half precision and not run on CPU.

If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flash_attention_3" so that the model
can initialize the correct attention module
"""
config = super()._check_and_enable_flash_attn_3(
config, torch_dtype, device_map, hard_check_only=hard_check_only, check_device_map=check_device_map
)

config.semantic_config._attn_implementation = config._attn_implementation
config.coarse_acoustics_config._attn_implementation = config._attn_implementation
config.fine_acoustics_config._attn_implementation = config._attn_implementation
return config


__all__ = [
"BarkFineModel",