diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index e1694910997a..447d8a7b1783 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -17,7 +17,8 @@
 import inspect
 import math
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from functools import partial
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 
@@ -83,12 +84,20 @@
         raise ImportError(
             "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
         )
-    from ..utils.kernels_utils import _get_fa3_from_hub
+    from ..utils.kernels_utils import _DEFAULT_HUB_ID_FA3, _DEFAULT_HUB_ID_SAGE, _get_kernel_from_hub
+    from ..utils.sage_utils import _get_sage_attn_fn_for_device
 
-    flash_attn_interface_hub = _get_fa3_from_hub()
+    flash_attn_interface_hub = _get_kernel_from_hub(_DEFAULT_HUB_ID_FA3)
     flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
+
+    sage_interface_hub = _get_kernel_from_hub(_DEFAULT_HUB_ID_SAGE)
+    sage_fn_with_kwargs = _get_sage_attn_fn_for_device()
+    sage_attn_func_hub = getattr(sage_interface_hub, sage_fn_with_kwargs["func"])
+    sage_attn_func_hub = partial(sage_attn_func_hub, **sage_fn_with_kwargs["kwargs"])
+
 else:
     flash_attn_3_func_hub = None
+    sage_attn_func_hub = None
 
 if _CAN_USE_SAGE_ATTN:
     from sageattention import (
@@ -162,10 +171,6 @@ def wrap(func):
 # - CP with sage attention, flex, xformers, other missing backends
 # - Add support for normal and CP training with backends that don't support it yet
 
-_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
-_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
-_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
-
 
 class AttentionBackendName(str, Enum):
     # EAGER = "eager"
@@ -190,6 +195,7 @@ class AttentionBackendName(str, Enum):
 
     # `sageattention`
     SAGE = "sage"
+    SAGE_HUB = "sage_hub"
     SAGE_VARLEN = "sage_varlen"
     _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"
     _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90"
@@ -1756,6 +1762,31 @@ def _sage_attention(
     return (out, lse) if return_lse else out
 
 
+@_AttentionBackendRegistry.register(
+    AttentionBackendName.SAGE_HUB,
+    constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=False,
+)
+def _sage_attention_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    lse = None
+    if _parallel_config is None:
+        out = sage_attn_func_hub(q=query, k=key, v=value)
+        if return_lse:
+            out, lse, *_ = out
+    else:
+        raise NotImplementedError("SAGE attention doesn't yet support parallelism.")
+
+    return (out, lse) if return_lse else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.SAGE_VARLEN,
     constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
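A minimal usage sketch for the `SAGE_HUB` backend registered above, assuming a CUDA machine with `kernels` installed, Hub kernels enabled in the environment, and the existing `set_attention_backend` helper on `ModelMixin`; the checkpoint name is only illustrative:

```python
import torch
from diffusers import DiffusionPipeline

# Illustrative checkpoint; any model whose attention goes through attention_dispatch works.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Route the transformer's attention through the Hub-fetched SageAttention kernel.
pipe.transformer.set_attention_backend("sage_hub")

image = pipe("an astronaut riding a horse", num_inference_steps=28).images[0]
```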
diff --git a/src/diffusers/utils/kernels_utils.py b/src/diffusers/utils/kernels_utils.py
index 26d6e3972fb7..3470692cca09 100644
--- a/src/diffusers/utils/kernels_utils.py
+++ b/src/diffusers/utils/kernels_utils.py
@@ -6,18 +6,25 @@
 
 
 _DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
+_DEFAULT_HUB_ID_SAGE = "kernels-community/sage_attention"
+_KERNEL_REVISION = {
+    # TODO: temporary revision for now. Remove when merged upstream into `main`.
+    _DEFAULT_HUB_ID_FA3: "fake-ops-return-probs",
+    _DEFAULT_HUB_ID_SAGE: "compile",
+}
 
 
-def _get_fa3_from_hub():
+def _get_kernel_from_hub(kernel_id):
     if not is_kernels_available():
         return None
     else:
         from kernels import get_kernel
 
         try:
-            # TODO: temporary revision for now. Remove when merged upstream into `main`.
-            flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-probs")
-            return flash_attn_3_hub
+            if kernel_id not in _KERNEL_REVISION:
+                raise NotImplementedError(f"{kernel_id} is not implemented in Diffusers.")
+            kernel_hub = get_kernel(kernel_id, revision=_KERNEL_REVISION.get(kernel_id))
+            return kernel_hub
         except Exception as e:
-            logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
+            logger.error(f"An error occurred while fetching kernel '{kernel_id}' from the Hub: {e}")
             raise
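A short sketch of the generalized helper's behaviour, using only names introduced in this patch; it assumes the `kernels` package is installed and the Hub is reachable:

```python
from diffusers.utils.kernels_utils import _DEFAULT_HUB_ID_SAGE, _get_kernel_from_hub

# Loads kernels-community/sage_attention at the pinned "compile" revision
# (returns None when the `kernels` package is not installed).
sage_kernel = _get_kernel_from_hub(_DEFAULT_HUB_ID_SAGE)

# Kernel IDs without a pinned revision now fail fast instead of silently pulling `main`.
try:
    _get_kernel_from_hub("some-org/some-unpinned-kernel")  # hypothetical ID
except NotImplementedError as err:
    print(err)
```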
diff --git a/src/diffusers/utils/sage_utils.py b/src/diffusers/utils/sage_utils.py
new file mode 100644
index 000000000000..28e4e17941eb
--- /dev/null
+++ b/src/diffusers/utils/sage_utils.py
@@ -0,0 +1,137 @@
+"""
+Copyright (c) 2024 by SageAttention, The HuggingFace team.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the
+License. You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+"""
+
+"""
+Modified from
+https://github.com/thu-ml/SageAttention/blob/68de3797d163b89d28f9a38026c3b7313f6940d2/sageattention/core.py
+"""
+
+
+import torch  # noqa
+
+
+SAGE_ATTENTION_DISPATCH = {
+    "sm80": {
+        "func": "sageattn_qk_int8_pv_fp16_cuda",
+        "kwargs": {
+            "tensor_layout": "NHD",
+            "is_causal": False,
+            "sm_scale": None,
+            "return_lse": False,
+            "pv_accum_dtype": "fp32",
+        },
+    },
+    "sm89": {
+        "func": "sageattn_qk_int8_pv_fp8_cuda",
+        "kwargs": {
+            "tensor_layout": "NHD",
+            "is_causal": False,
+            "sm_scale": None,
+            "return_lse": False,
+            "pv_accum_dtype": "fp32+fp16",
+        },
+    },
+    "sm90": {
+        "func": "sageattn_qk_int8_pv_fp8_cuda_sm90",
+        "kwargs": {
+            "tensor_layout": "NHD",
+            "is_causal": False,
+            "sm_scale": None,
+            "return_lse": False,
+            "pv_accum_dtype": "fp32+fp32",
+        },
+    },
+    "sm120": {
+        "func": "sageattn_qk_int8_pv_fp8_cuda",
+        "kwargs": {
+            "tensor_layout": "NHD",
+            "is_causal": False,
+            "qk_quant_gran": "per_warp",
+            "sm_scale": None,
+            "return_lse": False,
+            "pv_accum_dtype": "fp32+fp16",
+        },
+    },
+}
+
+
+def get_cuda_version():
+    if torch.cuda.is_available():
+        major, minor = torch.cuda.get_device_capability()
+        return major, minor
+    else:
+        raise EnvironmentError("CUDA not found.")
+
+
+def get_cuda_arch_versions():
+    if not torch.cuda.is_available():
+        raise EnvironmentError("CUDA not found.")
+    cuda_archs = []
+    for i in range(torch.cuda.device_count()):
+        major, minor = torch.cuda.get_device_capability(i)
+        cuda_archs.append(f"sm{major}{minor}")
+    return cuda_archs
+
+
+# Unlike the actual implementation, we just maintain function names rather than actual
+# implementations.
+def _get_sage_attn_fn_for_device():
+    """
+    Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute
+    capability.
+
+    Parameters
+    ----------
+    q : torch.Tensor
+        The query tensor. Shape:
+        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
+        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
+
+    k : torch.Tensor
+        The key tensor. Shape:
+        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
+        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
+
+    v : torch.Tensor
+        The value tensor. Shape:
+        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
+        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
+
+    tensor_layout : str
+        The tensor layout, either "HND" or "NHD". Default: "HND".
+
+    is_causal : bool
+        Whether to apply a causal mask to the attention matrix. Only applicable when qo_len == kv_len. Default: False.
+
+    sm_scale : Optional[float]
+        The scale used in softmax; if not provided, it is set to ``1.0 / sqrt(head_dim)``.
+
+    return_lse : bool
+        Whether to return the log-sum-exp of the attention weights. Used for cases like Ring Attention.
+        Default: False.
+
+    Returns
+    -------
+    torch.Tensor
+        The output tensor. Shape:
+        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
+        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
+
+    torch.Tensor
+        The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). Shape:
+        ``[batch_size, num_qo_heads, qo_len]``. Only returned if `return_lse` is True.
+
+    Note
+    ----
+    - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
+    - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
+    - All tensors must be on the same CUDA device.
+    """
+    device_index = torch.cuda.current_device()
+    arch = get_cuda_arch_versions()[device_index]
+    return SAGE_ATTENTION_DISPATCH[arch]
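Putting the three files together, the import-time wiring added to `attention_dispatch.py` reduces to the following sketch; the tensor shapes and the sm90 comment are illustrative, and an Ada/Hopper-class GPU with the `kernels` package installed is assumed:

```python
from functools import partial

import torch
from diffusers.utils.kernels_utils import _DEFAULT_HUB_ID_SAGE, _get_kernel_from_hub
from diffusers.utils.sage_utils import _get_sage_attn_fn_for_device

# Fetch the Hub kernel once and pick the per-arch entry point; e.g. on sm90 this
# resolves to "sageattn_qk_int8_pv_fp8_cuda_sm90" with pv_accum_dtype="fp32+fp32".
sage_interface_hub = _get_kernel_from_hub(_DEFAULT_HUB_ID_SAGE)
entry = _get_sage_attn_fn_for_device()
sage_attn_func_hub = partial(getattr(sage_interface_hub, entry["func"]), **entry["kwargs"])

# NHD layout: [batch_size, seq_len, num_heads, head_dim]; shapes are only an example.
q = k = v = torch.randn(1, 4096, 24, 128, device="cuda", dtype=torch.bfloat16)
out = sage_attn_func_hub(q=q, k=k, v=v)  # same call `_sage_attention_hub` makes
```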