@@ -20,7 +20,7 @@
 # pylint: disable=too-many-statements,too-many-lines,too-many-arguments,invalid-name
 import enum
 import math
-from typing import Any, Dict, List, Literal, Optional, Tuple
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union
 
 import tvm
 from tvm import relax as rx
@@ -86,6 +86,7 @@ class AttnKind(enum.IntEnum):
 
     MHA = 0
     MLA = 1
+    MHA_SLIDING = 3
 
 
 class RopeMode(enum.IntEnum):
@@ -301,7 +302,7 @@ class FlashInferPagedKVCache(PagedKVCache):  # pylint: disable=too-few-public-methods
 
     def __init__(  # pylint: disable=too-many-locals
         self,
-        attn_kind: Literal["mha", "mla"],
+        attn_kind: Union[Literal["mha", "mla"], List[Literal["mha", "mla", "mha_sliding"]]],
         max_batch_size: tir.Var,
         max_total_seq_len: tir.Var,
         prefill_chunk_size: tir.Var,
@@ -377,8 +378,16 @@ def __init__(  # pylint: disable=too-many-locals
                 dtype_q=dtype,
                 dtype_kv=dtype,
                 dtype_o=dtype,
-                qk_head_dim=qk_head_dim if attn_kind == "mha" else mla_original_qk_head_dim,
-                v_head_dim=v_head_dim if attn_kind == "mha" else mla_original_v_head_dim,
+                qk_head_dim=(
+                    qk_head_dim
+                    if (attn_kind == "mha" or isinstance(attn_kind, List))
+                    else mla_original_qk_head_dim
+                ),
+                v_head_dim=(
+                    v_head_dim
+                    if (attn_kind == "mha" or isinstance(attn_kind, List))
+                    else mla_original_v_head_dim
+                ),
                 target=target,
                 enable_inline_rope=rope_mode == RopeMode.INLINE,
             )
@@ -391,7 +400,7 @@ def __init__(  # pylint: disable=too-many-locals
                 v_head_dim=v_head_dim,
                 target=target,
             )
-            if attn_kind == "mha"
+            if (attn_kind == "mha" or isinstance(attn_kind, List))
             else []
         )
         flashinfer_mla_mods = (
@@ -420,7 +429,7 @@ def __init__(  # pylint: disable=too-many-locals
                 rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]),
                 rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]),
             ]
-            if attn_kind == "mha"
+            if (attn_kind == "mha" or isinstance(attn_kind, List))
             else [rx.Tuple([]) for _ in range(6)]
         )
         mla_function = rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_mla_paged_attention_run"), rx.ExternFunc("batch_mla_paged_attention_plan")] if attn_kind == "mla" else [])
@@ -430,6 +439,11 @@ def __init__(  # pylint: disable=too-many-locals
         if attn_kind == "mla":
             attn_merge_functions.append(bb.add_func(_merge_state_inplace(num_attention_heads, mla_original_v_head_dim, dtype, target, "tir_attention_merge_state_mla"), "tir_attention_merge_state_mla"))
 
+
+        if isinstance(attn_kind, List):
+            attn_kind = [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind]
+        else:
+            attn_kind = [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)]
         args = [
             rx.ShapeExpr(
                 [
@@ -482,7 +496,7 @@ class TIRPagedKVCache(PagedKVCache):  # pylint: disable=too-few-public-methods
 
     def __init__(  # pylint: disable=too-many-locals
         self,
-        attn_kind: Literal["mha", "mla"],
+        attn_kind: Union[Literal["mha", "mla"], List[Literal["mha", "mla", "mha_sliding"]]],
         max_batch_size: tir.Var,
         max_total_seq_len: tir.Var,
         prefill_chunk_size: tir.Var,
@@ -553,7 +567,12 @@ def __init__(  # pylint: disable=too-many-locals
         target : Target
             The target to build the model to.
         """
-
+        if isinstance(attn_kind, List):
+            attn_kind = [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind]
+        else:
+            attn_kind = [
+                int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)
+            ]
         bb = rx.BlockBuilder.current()
         args = [
             rx.ShapeExpr(
@@ -570,9 +589,7 @@ def __init__(  # pylint: disable=too-many-locals
             rx.PrimValue(num_key_value_heads),
             rx.PrimValue(qk_head_dim),
             rx.PrimValue(v_head_dim),
-            rx.ShapeExpr(
-                [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)]
-            ),
+            rx.ShapeExpr(attn_kind),
             rx.PrimValue(enable_disaggregation),
             rx.PrimValue(rope_mode),
             rx.PrimValue(rope_scale),
@@ -614,9 +631,9 @@ def __init__(  # pylint: disable=too-many-locals
         else:
             # pylint: disable=line-too-long
             # fmt: off
-            ragged_qk_head_dim = qk_head_dim if attn_kind == "mha" else mla_original_qk_head_dim
-            ragged_v_head_dim = v_head_dim if attn_kind == "mha" else mla_original_v_head_dim
-            args.append(rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_ragged(num_key_value_heads if attn_kind == "mha" else num_attention_heads, num_attention_heads, ragged_qk_head_dim, ragged_v_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged")]))
+            ragged_qk_head_dim = qk_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_qk_head_dim
+            ragged_v_head_dim = v_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_v_head_dim
+            args.append(rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_ragged(num_key_value_heads if (attn_kind == "mha" or isinstance(attn_kind, List)) else num_attention_heads, num_attention_heads, ragged_qk_head_dim, ragged_v_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged")]))
             mha_functions = (
                 [
                     rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, False, rope_scaling, target), "tir_attention_prefill")]),
@@ -626,7 +643,7 @@ def __init__(  # pylint: disable=too-many-locals
                     rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]),
                     rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]),
                 ]
-                if attn_kind == "mha"
+                if (attn_kind == "mha" or isinstance(attn_kind, List))
                 else [rx.Tuple([]) for _ in range(6)]
             )
             mla_function = rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_mla(num_attention_heads, v_head_dim, qk_head_dim - v_head_dim, dtype, False, target), "tir_attention_prefill_mla")] if attn_kind == "mla" else [])
@@ -641,7 +658,7 @@ def __init__(  # pylint: disable=too-many-locals
             [
                 rx.Tuple(attn_merge_functions),
                 bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, qk_head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"),
-                bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if attn_kind == "mha" else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"),
+                bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if (attn_kind == "mha" or isinstance(attn_kind, List)) else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"),
                 bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, qk_head_dim, dtype), "kv_cache_debug_get_kv"),
                 bb.add_func(_compact_kv_copy(num_key_value_heads, qk_head_dim, dtype, target), "kv_cache_compact_kv_copy"),
             ]
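
Taken together, the hunks above let attn_kind be passed either as a single kind string or as a per-layer list (so sliding-window MHA layers can coexist with regular MHA layers in one cache), and both FlashInferPagedKVCache and TIRPagedKVCache expand it into one AttnKind integer per layer before passing it via rx.ShapeExpr. Below is a minimal standalone sketch of that normalization; the normalize_attn_kind helper name and the 4-layer example pattern are illustrative and not part of the patch.

import enum
from typing import List, Literal, Union


class AttnKind(enum.IntEnum):
    # Mirrors the enum in the diff above.
    MHA = 0
    MLA = 1
    MHA_SLIDING = 3


def normalize_attn_kind(
    attn_kind: Union[Literal["mha", "mla"], List[Literal["mha", "mla", "mha_sliding"]]],
    num_hidden_layers: int,
) -> List[int]:
    """Expand a single kind or a per-layer list into one AttnKind value per layer."""
    if isinstance(attn_kind, list):
        # Per-layer list: map each entry to its enum value.
        return [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind]
    # Single string: repeat the same kind for every hidden layer.
    return [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)]


# Hypothetical 4-layer model interleaving sliding-window and global attention:
assert normalize_attn_kind(["mha_sliding", "mha", "mha_sliding", "mha"], 4) == [3, 0, 3, 0]
assert normalize_attn_kind("mla", 4) == [1, 1, 1, 1]

With a plain string argument the resulting per-layer list is uniform, so behavior matches the code before this patch.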