Commit 330ffd5

Fix failing tests for dropout=0.1 and bias in fused attention on Blackwell
Signed-off-by: Kshitij Lakhani <[email protected]>
1 parent 06a38cc commit 330ffd5

File tree: 1 file changed (+4 −1 lines)

tests/jax/test_fused_attn.py

Lines changed: 4 additions & 1 deletion
@@ -41,6 +41,7 @@
 from transformer_engine_jax import (
     NVTE_Fused_Attn_Backend,
     get_cudnn_version,
+    get_device_compute_capability,
 )
 
 from distributed_test_base import assert_equal_collectives
@@ -347,7 +348,9 @@ def _check_configs(self):
             pytest.skip(
                 "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
             )
-
+
+        if get_device_compute_capability(0) == 100 and self.dropout_prob == 0.1 and self.attn_bias_type is not AttnBiasType.NO_BIAS:
+            pytest.skip("For Blackwell, there is no bprop kernel support for dropout + deterministic (bias) config ")
         # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
         # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
         if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate():
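For reference, a minimal, self-contained sketch of the same skip guard is shown below. The names get_device_compute_capability, AttnBiasType, the compute-capability value 100 (Blackwell), and the dropout/bias condition are taken from the diff above; the import path for AttnBiasType and the helper name _skip_blackwell_dropout_bias are assumptions made for illustration and are not part of the commit.

# Hedged sketch, not part of the commit: isolates the skip condition that the
# diff adds to _check_configs so the intent is visible in one place.
import pytest

from transformer_engine_jax import get_device_compute_capability  # same import the diff adds

# Assumed import path for AttnBiasType (the enum compared against in the diff).
from transformer_engine.jax.attention import AttnBiasType


def _skip_blackwell_dropout_bias(dropout_prob, attn_bias_type):
    """Skip fused-attention test configs that Blackwell cannot run in bprop.

    Compute capability 100 identifies Blackwell; per the commit message, there is
    no bprop kernel support for the dropout + deterministic (bias) configuration.
    """
    if (
        get_device_compute_capability(0) == 100
        and dropout_prob == 0.1  # the dropout value the test suite parametrizes
        and attn_bias_type is not AttnBiasType.NO_BIAS
    ):
        pytest.skip(
            "For Blackwell, there is no bprop kernel support for dropout +"
            " deterministic (bias) config"
        )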
