
Commit de81b7d

Further relax constraints to cuDNN 9.13 for disabling fused attn for kv caching (#2121)
Signed-off-by: Kshitij Lakhani <[email protected]>
1 parent 1e2c68d · commit de81b7d

File tree

  • transformer_engine/pytorch/attention/dot_product_attention

1 file changed: +2 −2 lines

transformer_engine/pytorch/attention/dot_product_attention/utils.py

Lines changed: 2 additions & 2 deletions
@@ -434,8 +434,8 @@ def get_attention_backend(
     # | FP8 | non-paged/paged | sm90 | thd | >= 1
     # Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1
     if inference_params is not None:
-        if device_compute_capability == (8, 9) and cudnn_version <= (9, 12, 0):
-            logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN <= 9.12")
+        if device_compute_capability == (8, 9) and cudnn_version <= (9, 13, 0):
+            logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN <= 9.13")
             use_fused_attention = False
         if context_parallel:
             logger.debug("Disabling all backends for KV caching with context parallelism")
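For context, the sketch below illustrates the tuple-based version gate this commit adjusts: on sm89 (compute capability 8.9, Ada) devices, FusedAttention is skipped for KV caching whenever the detected cuDNN version is at or below 9.13.0. The wrapper function allow_fused_attention_for_kv_cache is hypothetical and only mirrors the comparison shown in the diff; it is not part of Transformer Engine's API.

# Minimal sketch (hypothetical helper, not Transformer Engine's actual API)
# of the version gate changed in this commit. Python compares version tuples
# element-wise, so (9, 13, 0) <= (9, 13, 0) is True while
# (9, 14, 0) <= (9, 13, 0) is False.

def allow_fused_attention_for_kv_cache(device_compute_capability, cudnn_version):
    """Return False when KV caching should fall back from FusedAttention."""
    if device_compute_capability == (8, 9) and cudnn_version <= (9, 13, 0):
        return False
    return True


# sm89 with cuDNN 9.13.0: fused attention is disabled for KV caching
assert not allow_fused_attention_for_kv_cache((8, 9), (9, 13, 0))
# sm89 with a newer cuDNN (e.g. 9.14.0): fused attention stays enabled
assert allow_fused_attention_for_kv_cache((8, 9), (9, 14, 0))
# sm90 (Hopper) is not affected by this check
assert allow_fused_attention_for_kv_cache((9, 0), (9, 13, 0))

Because the comparison uses an inclusive upper bound, relaxing it from (9, 12, 0) to (9, 13, 0) extends the fallback to every cuDNN release up to and including 9.13.0.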

Comments (0)