1 change: 1 addition & 0 deletions pyproject.toml
@@ -52,6 +52,7 @@ line-ending = "auto"
addopts = "--doctest-glob='**/*.md'"
doctest_optionflags="NUMBER NORMALIZE_WHITESPACE ELLIPSIS"
markers = [
"flash_attn_3_test: marks tests related to flash attention 3 (deselect with '-m \"not flash_attn_3_test\"')",
"flash_attn_test: marks tests related to flash attention (deselect with '-m \"not flash_attn_test\"')",
"bitsandbytes: select (or deselect with `not`) bitsandbytes integration tests",
"generate: marks tests that use the GenerationTesterMixin"
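For reference, a minimal sketch of the test-side usage of the new marker; the test name below is hypothetical and only illustrates the deselect workflow registered above (run `pytest -m "not flash_attn_3_test"` on machines without FA3):

import pytest

@pytest.mark.flash_attn_3_test
def test_fa3_attention_output():
    # Hypothetical test body; the real tests live in the transformers test suite.
    ...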
1 change: 1 addition & 0 deletions src/transformers/integrations/flash_attention.py
@@ -75,6 +75,7 @@ def flash_attention_forward(
softcap=softcap,
use_top_left_mask=_use_top_left_mask,
target_dtype=target_dtype,
attn_implementation=module.config._attn_implementation,
**kwargs,
)

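For context, a hedged sketch of what this pass-through enables at the user level, assuming a checkpoint that supports flash attention and a GPU with the FA3 kernels installed (the model name is illustrative, not taken from this PR):

import torch
from transformers import AutoModelForCausalLM

# Requesting FA3 sets config._attn_implementation, which the integration above now forwards
# into _flash_attention_forward so that the FA3 kernels are dispatched instead of FA2.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",  # illustrative checkpoint
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_3",
)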
198 changes: 174 additions & 24 deletions src/transformers/modeling_flash_attention_utils.py
@@ -14,13 +14,15 @@

import inspect
import os
import warnings
from typing import Optional, TypedDict

import torch
import torch.nn.functional as F

from .utils import (
is_flash_attn_2_available,
is_flash_attn_3_available,
is_flash_attn_greater_or_equal,
is_flash_attn_greater_or_equal_2_10,
is_torch_npu_available,
@@ -32,18 +34,123 @@
flash_attn_func = None


if is_flash_attn_2_available():
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.layers.rotary import apply_rotary_emb # noqa
def _index_first_axis(tensor, indices):
"""
A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis,
after flattening the first two dimensions of the tensor. This is functionally equivalent to
FA2's `index_first_axis` and replaces the need to import it.
"""
# The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first
# two dimensions to get (total_tokens, ...) before indexing.
reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:])
return reshaped_tensor[indices]
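# [Editor's illustration, not part of this diff] A quick shape check for `_index_first_axis`,
# assuming a toy (batch=2, seqlen=3, hidden=4) tensor; indices address the flattened (6, 4) view:
example = torch.arange(24).reshape(2, 3, 4)
flat_indices = torch.tensor([0, 2, 4])
assert _index_first_axis(example, flat_indices).shape == (3, 4)
assert torch.equal(_index_first_axis(example, flat_indices)[0], example[0, 0])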


def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None):
"""
FA3-compatible unpad_input function.

Arguments:
hidden_states: (batch, seqlen, ...)
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
Return:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
max_seqlen_in_batch: int
seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
"""
all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))

return (
_index_first_axis(hidden_states, indices),
indices,
cu_seqlens,
max_seqlen_in_batch,
used_seqlens_in_batch,
)


def _fa3_pad_input(hidden_states, indices, batch, seqlen):
"""
FA3-compatible pad_input function.

Arguments:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
batch: int, batch size for the padded sequence.
seqlen: int, maximum sequence length for the padded sequence.
Return:
hidden_states: (batch, seqlen, ...)
"""
dim = hidden_states.shape[1:]
output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype)
output[indices] = hidden_states
return output.view(batch, seqlen, *dim)
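# [Editor's illustration, not part of this diff] A round trip through the two FA3 helpers,
# assuming a toy batch of 2 sequences padded to length 4 with 3 and 2 valid tokens
# (torch is already imported at the top of this module):
hidden = torch.randn(2, 4, 8)
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
states, indices, cu_seqlens, max_len, seqused = _fa3_unpad_input(hidden, mask)
# states: (5, 8), cu_seqlens: tensor([0, 3, 5]), max_len: 3, seqused: tensor([3, 2])
repadded = _fa3_pad_input(states, indices, batch=2, seqlen=4)
assert torch.equal(repadded[0, :3], hidden[0, :3])  # valid tokens restored in place
assert repadded[0, 3].abs().sum() == 0              # padded positions come back as zeros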


FA_VERSION = None
if is_flash_attn_2_available():
from flash_attn import flash_attn_func as flash_attn_2_func
from flash_attn import flash_attn_varlen_func as flash_attn_2_varlen_func
from flash_attn.bert_padding import pad_input as pad_input_fa2
from flash_attn.bert_padding import unpad_input as unpad_input_fa2
from flash_attn.layers.rotary import apply_rotary_emb

HAS_FA2 = True
FA_VERSION = 2
else:
flash_attn_2_func = None
flash_attn_2_varlen_func = None
pad_input_fa2 = None
unpad_input_fa2 = None
apply_rotary_emb = None
HAS_FA2 = False

if is_flash_attn_3_available():
from flash_attn_interface import flash_attn_func as flash_attn_3_func
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func

pad_input_fa3 = _fa3_pad_input
unpad_input_fa3 = _fa3_unpad_input
HAS_FA3 = True
FA_VERSION = 3
else:
flash_attn_3_func = None
flash_attn_3_varlen_func = None
pad_input_fa3 = None
unpad_input_fa3 = None
HAS_FA3 = False


# Current Flash Attention implementations
if FA_VERSION:
flash_attn_func = globals()[f"flash_attn_{FA_VERSION}_func"]
flash_attn_varlen_func = globals()[f"flash_attn_{FA_VERSION}_varlen_func"]
unpad_input = globals()[f"unpad_input_fa{FA_VERSION}"]
pad_input = globals()[f"pad_input_fa{FA_VERSION}"]

# patch functions in package `flash-attn` when using flash-attention on Ascend NPU.
if is_torch_npu_available():
from .integrations.npu_flash_attention import index_first_axis, pad_input, unpad_input
from .integrations.npu_flash_attention import npu_apply_rotary_emb as apply_rotary_emb # noqa
from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
from .integrations.npu_flash_attention import (
npu_apply_rotary_emb as apply_rotary_emb, # noqa: F401
)
from .integrations.npu_flash_attention import (
npu_flash_attn_func as flash_attn_func,
)
from .integrations.npu_flash_attention import (
npu_flash_attn_varlen_func as flash_attn_varlen_func,
)
from .integrations.npu_flash_attention import (
pad_input,
unpad_input,
)


_flash_supports_window_size = False
@@ -56,6 +163,9 @@
def is_flash_attn_available():
"""Determine whether flash-attention can be used or not."""

if is_flash_attn_3_available():
return True

# if package `flash-attn` is available, flash-attention can be used natively.
if is_flash_attn_2_available():
return True
@@ -70,6 +180,9 @@ def is_flash_attn_available():
def flash_attn_supports_top_left_mask():
"""Determine whether flash-attention uses top-left or down-right mask"""

if is_flash_attn_3_available():
return False

if is_flash_attn_2_available():
# top-left mask is used in package `flash-attn` with version lower than 2.1.0
return not is_flash_attn_greater_or_equal_2_10()
@@ -116,6 +229,7 @@ def _upad_input(
value_layer: torch.Tensor,
attention_mask: torch.Tensor,
query_length: int,
unpad_input_func,
):
"""
Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.
@@ -134,6 +248,8 @@
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
query_length (`int`):
Target length.
unpad_input_func:
The function to use for unpadding the input tensors.

Return:
query_layer (`torch.Tensor`):
@@ -158,12 +274,10 @@

batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
key_layer = _index_first_axis(key_layer, indices_k)
value_layer = _index_first_axis(value_layer, indices_k)
if query_length == kv_seq_len:
query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k)
query_layer = _index_first_axis(query_layer, indices_k)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
@@ -177,7 +291,7 @@
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input(query_layer, attention_mask)
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask)

return (
query_layer,
Expand All @@ -189,7 +303,7 @@ def _upad_input(
)


def prepare_fa2_from_position_ids(query, key, value, position_ids):
def _prepare_flash_attention_from_position_ids(query, key, value, position_ids):
"""
This function returns necessary arguments to call `flash_attn_varlen_func`.
All three query, key, value states will be flattened.
@@ -239,6 +353,14 @@ def prepare_fa2_from_position_ids(query, key, value, position_ids):
return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
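# [Editor's illustration, not part of this diff] How packed position_ids map to the returned
# cumulative sequence lengths, assuming two sequences of lengths 3 and 2 packed into one row
# (shapes are illustrative; the helper only reshapes and never calls the attention kernels):
q = torch.randn(1, 5, 8, 64)
k = torch.randn(1, 5, 2, 64)
v = torch.randn(1, 5, 2, 64)
packed_position_ids = torch.tensor([[0, 1, 2, 0, 1]])
q2, k2, v2, idx_q, (cu_q, cu_k), (max_q, max_k) = _prepare_flash_attention_from_position_ids(
    q, k, v, packed_position_ids
)
# q2: (5, 8, 64) flattened tokens; cu_q == cu_k == tensor([0, 3, 5], dtype=torch.int32); max_q == max_k == 3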


def prepare_fa2_from_position_ids(*args, **kwargs):
warnings.warn(
"The function `prepare_fa2_from_position_ids` in `transformers.modeling_flash_attention_utils` is deprecated and will be removed in a future version. Please use `_prepare_flash_attention_from_position_ids` instead.",
FutureWarning,
)
return _prepare_flash_attention_from_position_ids(*args, **kwargs)


def fa_peft_integration_check(
query: torch.Tensor,
key: torch.Tensor,
@@ -303,6 +425,7 @@ def _flash_attention_forward(
max_length_q: Optional[int] = None,
max_length_k: Optional[int] = None,
target_dtype: Optional[torch.dtype] = None,
attn_implementation: Optional[str] = None,
**kwargs,
):
"""
@@ -329,7 +452,28 @@
Softcap for the attention logits, used e.g. in gemma2.
deterministic (`bool`, *optional*):
Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
attn_implementation (`str`, *optional*):
The attention implementation to use. If `None`, defaults to the implementation detected from the environment (FA3 if available, otherwise FA2).
"""
if attn_implementation is None:
_flash_attn_varlen_func = flash_attn_varlen_func
_flash_attn_func = flash_attn_func
_pad_input = pad_input
_unpad_input = unpad_input
_is_fa3 = HAS_FA3
elif attn_implementation == "flash_attention_3":
_flash_attn_varlen_func = flash_attn_3_varlen_func
_flash_attn_func = flash_attn_3_func
_pad_input = pad_input_fa3
_unpad_input = unpad_input_fa3
_is_fa3 = True
elif attn_implementation == "flash_attention_2":
_flash_attn_varlen_func = flash_attn_2_varlen_func
_flash_attn_func = flash_attn_2_func
_pad_input = pad_input_fa2
_unpad_input = unpad_input_fa2
_is_fa3 = False

if not use_top_left_mask:
causal = is_causal
else:
@@ -342,6 +486,12 @@
)
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}

if _is_fa3:
if dropout > 0.0:
logger.warning_once("Flash Attention 3 does not support dropout. Setting dropout to 0.0.")
else:
flash_kwargs["dropout_p"] = dropout

if flash_241:
if deterministic is None:
global deterministic_g
@@ -362,25 +512,24 @@
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
query_states, key_states, value_states, attention_mask, query_length
query_states, key_states, value_states, attention_mask, query_length, _unpad_input
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

attn_output_unpad = flash_attn_varlen_func(
attn_output_unpad = _flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
**flash_kwargs,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
attn_output = _pad_input(attn_output_unpad, indices_q, batch_size, query_length)

# If position_ids is provided and not all examples contain exactly one sequence: if the tensor is strictly increasing,
# we probably have a single sequence; otherwise the batch is packed. Additionally, check that we are in the pre-fill/training stage.
@@ -394,7 +543,7 @@

if cu_seq_lens_q is None or cu_seq_lens_k is None:
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = (
prepare_fa2_from_position_ids(query_states, key_states, value_states, position_ids)
_prepare_flash_attention_from_position_ids(query_states, key_states, value_states, position_ids)
)

cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
@@ -405,15 +554,14 @@
key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))

attn_output = flash_attn_varlen_func(
attn_output = _flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seq_lens_q,
cu_seqlens_k=cu_seq_lens_k,
max_seqlen_q=max_length_q,
max_seqlen_k=max_length_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
**flash_kwargs,
Expand All @@ -422,10 +570,12 @@ def _flash_attention_forward(
attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))

else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
attn_output = _flash_attn_func(
query_states, key_states, value_states, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
)

if isinstance(attn_output, tuple):
return attn_output[0]
return attn_output
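# [Editor's illustration, not part of this diff] A direct call sketch, assuming the usual signature
# shown above, a CUDA device, and the FA3 wheels installed; the tuple check above guards against the
# FA3 interface returning (out, softmax_lse)-style tuples while FA2 returns the output tensor directly:
q = torch.randn(2, 16, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 16, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 16, 8, 64, device="cuda", dtype=torch.float16)
out = _flash_attention_forward(
    q, k, v, attention_mask=None, query_length=16, is_causal=True,
    attn_implementation="flash_attention_3",  # or "flash_attention_2", or None for the import-time default
)
# out: (2, 16, 8, 64), same layout as the query states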

