docs/source/en/optimization/attention_backends.md (3 changes: 2 additions & 1 deletion)
@@ -99,10 +99,11 @@ Refer to the table below for a complete list of available attention backends and
| `_native_npu` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | NPU-optimized attention |
| `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |
| `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
+| `flash_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 from `kernels` |
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
-| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
+| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from `kernels` |
| `sage` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) |
| `sage_varlen` | [SageAttention](https://github.com/thu-ml/SageAttention) | Variable length SageAttention |
| `_sage_qk_int8_pv_fp8_cuda` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (CUDA) |
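As a usage sketch for the two Hub-backed rows above: the snippet below assumes diffusers' `set_attention_backend()` model helper, uses `black-forest-labs/FLUX.1-dev` purely as a placeholder checkpoint, and requires the `kernels` package plus the `DIFFUSERS_ENABLE_HUB_KERNELS` env var to be set before `diffusers` is imported.

```python
import os

# Hub-kernel backends are gated behind this env var; set it before importing diffusers.
os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"

import torch
from diffusers import DiffusionPipeline

# Placeholder checkpoint; any model that routes attention through the dispatcher
# can be switched the same way.
pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# "flash_hub" fetches FlashAttention-2 through `kernels` instead of requiring a
# local flash-attn build; "_flash_3_hub" is the analogous FlashAttention-3 backend.
pipeline.transformer.set_attention_backend("flash_hub")

image = pipeline("a photo of a cat", num_inference_steps=28).images[0]
```

With the env var unset, the updated `_check_attention_backend_requirements` branch in `attention_dispatch.py` below raises a `RuntimeError` pointing at `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`.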
src/diffusers/models/attention_dispatch.py (48 changes: 41 additions & 7 deletions)
@@ -83,12 +83,15 @@
raise ImportError(
"To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
)
-from ..utils.kernels_utils import _get_fa3_from_hub
+from ..utils.kernels_utils import _get_fa3_from_hub, _get_fa_from_hub

-flash_attn_interface_hub = _get_fa3_from_hub()
-flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
+fa3_interface_hub = _get_fa3_from_hub()
+flash_attn_3_func_hub = fa3_interface_hub.flash_attn_func
+fa_interface_hub = _get_fa_from_hub()
+flash_attn_func_hub = fa_interface_hub.flash_attn_func
Comment on lines +88 to +91:

Why are we fetching both kernels here?

Member Author: Because of the way the APIs for attention backends are designed, and also to support torch.compile with fullgraph traceability (when possible). We will let it grow a bit and, upon feedback, we can revisit how to better deal with this.

else:
flash_attn_3_func_hub = None
+flash_attn_func_hub = None

if _CAN_USE_SAGE_ATTN:
from sageattention import (
@@ -173,6 +176,8 @@ class AttentionBackendName(str, Enum):
# `flash-attn`
FLASH = "flash"
FLASH_VARLEN = "flash_varlen"
+FLASH_HUB = "flash_hub"
Member Author (on `FLASH_HUB`): Flash Attention is stable, so we don't have to mark it private like FA3.

+# FLASH_VARLEN_HUB = "flash_varlen_hub" # not supported yet.

Comment thread on `FLASH_VARLEN_HUB`:

Is this related to the kernel, or does it just need more time to be integrated?

Member Author: We don't have models that use varlen.

Contributor: @sayakpaul Qwen-Image uses varlen. Also, native fused QKV+MLP attention requires a varlen function.

_FLASH_3 = "_flash_3"
_FLASH_VARLEN_3 = "_flash_varlen_3"
_FLASH_3_HUB = "_flash_3_hub"
@@ -403,15 +408,15 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
)

-# TODO: add support Hub variant of FA3 varlen later
-elif backend in [AttentionBackendName._FLASH_3_HUB]:
+# TODO: add support for Hub variants of FA and FA3 varlen later
+elif backend in [AttentionBackendName.FLASH_HUB, AttentionBackendName._FLASH_3_HUB]:
if not DIFFUSERS_ENABLE_HUB_KERNELS:
raise RuntimeError(
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
f"Flash Attention Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
)
if not is_kernels_available():
raise RuntimeError(
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
f"Flash Attention Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
)

elif backend in [
@@ -1228,6 +1233,35 @@ def _flash_attention(
return (out, lse) if return_lse else out


+@_AttentionBackendRegistry.register(
+    AttentionBackendName.FLASH_HUB,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_attention_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    return_lse: bool = False,
+) -> torch.Tensor:
+    lse = None
+    out = flash_attn_func_hub(
+        q=query,
+        k=key,
+        v=value,
+        dropout_p=dropout_p,
+        softmax_scale=scale,
+        causal=is_causal,
+        return_attn_probs=return_lse,
+    )
+    if return_lse:
+        out, lse, *_ = out
+
+    return (out, lse) if return_lse else out
+
+
@_AttentionBackendRegistry.register(
    AttentionBackendName.FLASH_VARLEN,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
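For reference, a minimal sketch of calling the Hub-fetched FlashAttention-2 kernel directly, mirroring the keyword arguments the new `_flash_attention_hub` wrapper passes. It assumes a CUDA device, the `kernels` package, and that `kernels-community/flash-attn` exposes `flash_attn_func` with the upstream `(batch, seq_len, num_heads, head_dim)` input layout.

```python
import torch
from kernels import get_kernel

# Same Hub repo the new _get_fa_from_hub() helper resolves to.
fa_interface = get_kernel("kernels-community/flash-attn")

batch, seq_len, num_heads, head_dim = 2, 1024, 8, 64
q = torch.randn(batch, seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# With return_attn_probs=True the kernel also returns the log-sum-exp, which the
# wrapper above unpacks as `out, lse, *_`.
out, lse, *_ = fa_interface.flash_attn_func(
    q=q,
    k=k,
    v=v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    return_attn_probs=True,
)
print(out.shape)  # same (batch, seq_len, num_heads, head_dim) shape as the query
```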
src/diffusers/utils/kernels_utils.py (36 changes: 23 additions & 13 deletions)
@@ -2,22 +2,32 @@
from .import_utils import is_kernels_available


-logger = get_logger(__name__)
+if is_kernels_available():
+    from kernels import get_kernel

+logger = get_logger(__name__)

-_DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
+_DEFAULT_HUB_IDS = {
+    "fa3": ("kernels-community/flash-attn3", {"revision": "fake-ops-return-probs"}),
+    "fa": ("kernels-community/flash-attn", {}),
+}


-def _get_fa3_from_hub():
+def _get_from_hub(key: str):
    if not is_kernels_available():
        return None
-    else:
-        from kernels import get_kernel

-        try:
-            # TODO: temporary revision for now. Remove when merged upstream into `main`.
-            flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-probs")
-            return flash_attn_3_hub
-        except Exception as e:
-            logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
-            raise

+    hub_id, kwargs = _DEFAULT_HUB_IDS[key]
+    try:
+        return get_kernel(hub_id, **kwargs)
+    except Exception as e:
+        logger.error(f"An error occurred while fetching kernel '{hub_id}' from the Hub: {e}")
+        raise


+def _get_fa3_from_hub():
+    return _get_from_hub("fa3")


+def _get_fa_from_hub():
+    return _get_from_hub("fa")