Skip to content

Commit cbdfb77

Browse files
authored
Enable FlashInfer support encoder models and add head_dim padding workaround (#6230)
1 parent 282eb59 commit cbdfb77

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

python/sglang/srt/layers/attention/flashinfer_backend.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
2626
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
2727
from sglang.srt.layers.dp_attention import get_attention_tp_size
28+
from sglang.srt.layers.radix_attention import AttentionType
2829
from sglang.srt.layers.utils import is_sm100_supported
2930
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
3031
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
@@ -486,12 +487,20 @@ def forward_extend(
486487
v_scale=layer.v_scale,
487488
)
488489
else:
490+
causal = True
491+
if layer.attn_type == AttentionType.ENCODER_ONLY:
492+
save_kv_cache = False
493+
causal = False
494+
489495
if self.forward_metadata.extend_no_prefix:
496+
# NOTE: FlashInfer currently has limitations with head_dim = 32 or other dimensions
497+
# The FlashInfer head_dim limitation itself is tracked here:
498+
# https://github.com/flashinfer-ai/flashinfer/issues/1048
490499
o = self.prefill_wrapper_ragged.forward(
491500
q.view(-1, layer.tp_q_head_num, layer.head_dim),
492501
k.view(-1, layer.tp_k_head_num, layer.head_dim),
493502
v.view(-1, layer.tp_v_head_num, layer.head_dim),
494-
causal=True,
503+
causal=causal,
495504
sm_scale=layer.scaling,
496505
logits_soft_cap=logits_soft_cap,
497506
)

test/srt/models/test_encoder_embedding_models.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727

2828
MODELS = [("BAAI/bge-small-en", 1, 1e-5), ("BAAI/bge-m3", 1, 1e-5)]
2929

30-
ATTENTION_BACKEND = ["torch_native", "triton"]
30+
ATTENTION_BACKEND = ["torch_native", "triton", "flashinfer"]
3131
BATCH_SIZE = [1, 2]
32-
TORCH_DTYPES = [torch.float32]
32+
TORCH_DTYPES = [torch.float32, torch.float16]
3333
sgl_to_st_ratio = []
3434

3535

@@ -126,6 +126,19 @@ def test_prefill_logits(self):
126126
for attention_backend in ATTENTION_BACKEND:
127127
for batch_size in BATCH_SIZE:
128128
for torch_dtype in TORCH_DTYPES:
129+
# NOTE: FlashInfer currently has limitations with head_dim = 32 or
130+
# other dimensions.
131+
# The FlashInfer head_dim limitation itself is tracked here:
132+
# https://github.com/flashinfer-ai/flashinfer/issues/1048
133+
#
134+
# Flashinfer does not support torch.float32 for dtype_q, so skip it
135+
if attention_backend == "flashinfer":
136+
if (
137+
model == "BAAI/bge-small-en"
138+
or torch_dtype == torch.float32
139+
):
140+
continue
141+
129142
self.assert_close_prefill_logits(
130143
DEFAULT_PROMPTS,
131144
model,

0 commit comments

Comments
 (0)