-
Notifications
You must be signed in to change notification settings - Fork 6.4k
[core] support flash attention through kernels
#12387
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
c386f22
d252c02
1b96ed7
474b995
b0fc7af
029975e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -83,12 +83,15 @@ | |
raise ImportError( | ||
"To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`." | ||
) | ||
from ..utils.kernels_utils import _get_fa3_from_hub | ||
from ..utils.kernels_utils import _get_fa3_from_hub, _get_fa_from_hub | ||
|
||
flash_attn_interface_hub = _get_fa3_from_hub() | ||
flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func | ||
fa3_interface_hub = _get_fa3_from_hub() | ||
flash_attn_3_func_hub = fa3_interface_hub.flash_attn_func | ||
fa_interface_hub = _get_fa_from_hub() | ||
flash_attn_func_hub = fa_interface_hub.flash_attn_func | ||
else: | ||
flash_attn_3_func_hub = None | ||
flash_attn_func_hub = None | ||
|
||
if _CAN_USE_SAGE_ATTN: | ||
from sageattention import ( | ||
|
@@ -173,6 +176,8 @@ class AttentionBackendName(str, Enum): | |
# `flash-attn` | ||
FLASH = "flash" | ||
FLASH_VARLEN = "flash_varlen" | ||
FLASH_HUB = "flash_hub" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Flash Attention is stable. So, we don't have to mark it private like FA3. |
||
# FLASH_VARLEN_HUB = "flash_varlen_hub" # not supported yet. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this related to the kernel or it just needs more time to be integrated ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't have models that use varlen. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @sayakpaul qwen image uses varlen. also, native fused qkv+mlp attn requires varlen function. |
||
_FLASH_3 = "_flash_3" | ||
_FLASH_VARLEN_3 = "_flash_varlen_3" | ||
_FLASH_3_HUB = "_flash_3_hub" | ||
|
@@ -403,15 +408,15 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None | |
f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source." | ||
) | ||
|
||
# TODO: add support Hub variant of FA3 varlen later | ||
elif backend in [AttentionBackendName._FLASH_3_HUB]: | ||
# TODO: add support Hub variant of FA and FA3 varlen later | ||
elif backend in [AttentionBackendName.FLASH_HUB, AttentionBackendName._FLASH_3_HUB]: | ||
if not DIFFUSERS_ENABLE_HUB_KERNELS: | ||
raise RuntimeError( | ||
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`." | ||
f"Flash Attention Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`." | ||
) | ||
if not is_kernels_available(): | ||
raise RuntimeError( | ||
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`." | ||
f"Flash Attention Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`." | ||
) | ||
|
||
elif backend in [ | ||
|
@@ -1228,6 +1233,35 @@ def _flash_attention( | |
return (out, lse) if return_lse else out | ||
|
||
|
||
@_AttentionBackendRegistry.register( | ||
AttentionBackendName.FLASH_HUB, | ||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], | ||
) | ||
def _flash_attention_hub( | ||
query: torch.Tensor, | ||
key: torch.Tensor, | ||
value: torch.Tensor, | ||
dropout_p: float = 0.0, | ||
is_causal: bool = False, | ||
scale: Optional[float] = None, | ||
return_lse: bool = False, | ||
) -> torch.Tensor: | ||
lse = None | ||
out = flash_attn_func_hub( | ||
q=query, | ||
k=key, | ||
v=value, | ||
dropout_p=dropout_p, | ||
softmax_scale=scale, | ||
causal=is_causal, | ||
return_attn_probs=return_lse, | ||
) | ||
if return_lse: | ||
out, lse, *_ = out | ||
|
||
return (out, lse) if return_lse else out | ||
|
||
|
||
@_AttentionBackendRegistry.register( | ||
AttentionBackendName.FLASH_VARLEN, | ||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are we fetching both kernels here ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because of the way APIs for attention backends are designed and also to support
torch.compile
with fullgraph traceability (when possible).We will let it grow a bit and upon feedback, we can revisit how to better deal with this.