Skip to content

Commit 4a2a993

Browse files
pre-commit-ci[bot]KshitijLakhani
authored andcommitted
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 729b7d2 commit 4a2a993

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

tests/jax/test_fused_attn.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -348,9 +348,16 @@ def _check_configs(self):
348348
pytest.skip(
349349
"seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
350350
)
351-
352-
if get_device_compute_capability(0) == 100 and self.dropout_prob == 0.1 and self.attn_bias_type is not AttnBiasType.NO_BIAS:
353-
pytest.skip("For Blackwell, there is no bprop kernel support for dropout + deterministic (bias) config ")
351+
352+
if (
353+
get_device_compute_capability(0) == 100
354+
and self.dropout_prob == 0.1
355+
and self.attn_bias_type is not AttnBiasType.NO_BIAS
356+
):
357+
pytest.skip(
358+
"For Blackwell, there is no bprop kernel support for dropout + deterministic (bias)"
359+
" config "
360+
)
354361
# Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
355362
# are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
356363
if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate():

0 commit comments

Comments
 (0)