
Commit 19e2477

joshua-j-hong authored and ShiboXing committed
[KVCache] Per Layer Sliding Window (apache#17928)
Adds per-layer sliding window functionality to the KV cache. Correctness is mostly achieved, but there are still cases where individual tokens come out strange. The corresponding MLC-LLM PR is mlc-ai/mlc-llm#3248.

A full list of changes and additions is below:

- Add a new attention type for per-layer sliding window called `MHA_SLIDING`
- Add corresponding vectors for per-layer sliding window offset calculations
- For a KV cache with sliding window attention enabled, regular sliding window is disabled to prevent page eviction
- Gemma3 uses different RoPE parameters for its local sliding window layers. These should be passed as parameters to the KV cache, but currently the values are hardcoded
1 parent c1e4054 commit 19e2477
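
For illustration only (not part of this commit): a minimal sketch of how a caller might build the per-layer attn_kind list that the updated KV cache constructors accept. The layer count and the interval of global layers below are hypothetical, loosely mimicking a Gemma3-style mix of local sliding-window and global attention layers.

# Hypothetical example: interleave local sliding-window layers with a global
# full-attention layer (here every 6th layer). The strings match the new
# Literal values accepted by the KV cache constructors.
num_hidden_layers = 12      # hypothetical layer count
global_layer_interval = 6   # hypothetical interval of full-attention layers

attn_kind = [
    "mha" if (i + 1) % global_layer_interval == 0 else "mha_sliding"
    for i in range(num_hidden_layers)
]
# attn_kind can then be passed to TIRPagedKVCache / FlashInferPagedKVCache in
# place of the previous single "mha" string; passing a plain string still
# applies one attention kind to every layer.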

File tree

3 files changed: +216 -47 lines changed

python/tvm/relax/frontend/nn/llm/kv_cache.py

Lines changed: 33 additions & 16 deletions
@@ -20,7 +20,7 @@
 # pylint: disable=too-many-statements,too-many-lines,too-many-arguments,invalid-name
 import enum
 import math
-from typing import Any, Dict, List, Literal, Optional, Tuple
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union

 import tvm
 from tvm import relax as rx
@@ -86,6 +86,7 @@ class AttnKind(enum.IntEnum):

     MHA = 0
     MLA = 1
+    MHA_SLIDING = 3


 class RopeMode(enum.IntEnum):
@@ -301,7 +302,7 @@ class FlashInferPagedKVCache(PagedKVCache):  # pylint: disable=too-few-public-me

     def __init__(  # pylint: disable=too-many-locals
         self,
-        attn_kind: Literal["mha", "mla"],
+        attn_kind: Union[Literal["mha", "mla"], List[Literal["mha", "mla", "mha_sliding"]]],
         max_batch_size: tir.Var,
         max_total_seq_len: tir.Var,
         prefill_chunk_size: tir.Var,
@@ -377,8 +378,16 @@ def __init__(  # pylint: disable=too-many-locals
                 dtype_q=dtype,
                 dtype_kv=dtype,
                 dtype_o=dtype,
-                qk_head_dim=qk_head_dim if attn_kind == "mha" else mla_original_qk_head_dim,
-                v_head_dim=v_head_dim if attn_kind == "mha" else mla_original_v_head_dim,
+                qk_head_dim=(
+                    qk_head_dim
+                    if (attn_kind == "mha" or isinstance(attn_kind, List))
+                    else mla_original_qk_head_dim
+                ),
+                v_head_dim=(
+                    v_head_dim
+                    if (attn_kind == "mha" or isinstance(attn_kind, List))
+                    else mla_original_v_head_dim
+                ),
                 target=target,
                 enable_inline_rope=rope_mode == RopeMode.INLINE,
             )
@@ -391,7 +400,7 @@ def __init__(  # pylint: disable=too-many-locals
                 v_head_dim=v_head_dim,
                 target=target,
             )
-            if attn_kind == "mha"
+            if (attn_kind == "mha" or isinstance(attn_kind, List))
             else []
         )
         flashinfer_mla_mods = (
@@ -420,7 +429,7 @@ def __init__(  # pylint: disable=too-many-locals
                 rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]),
                 rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]),
             ]
-            if attn_kind == "mha"
+            if (attn_kind == "mha" or isinstance(attn_kind, List))
             else [rx.Tuple([]) for _ in range(6)]
         )
         mla_function = rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_mla_paged_attention_run"), rx.ExternFunc("batch_mla_paged_attention_plan")] if attn_kind == "mla" else [])
@@ -430,6 +439,11 @@ def __init__(  # pylint: disable=too-many-locals
         if attn_kind == "mla":
             attn_merge_functions.append(bb.add_func(_merge_state_inplace(num_attention_heads, mla_original_v_head_dim, dtype, target, "tir_attention_merge_state_mla"), "tir_attention_merge_state_mla"))

+
+        if isinstance(attn_kind, List):
+            attn_kind = [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind]
+        else:
+            attn_kind = [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)]
         args = [
             rx.ShapeExpr(
                 [
@@ -482,7 +496,7 @@ class TIRPagedKVCache(PagedKVCache):  # pylint: disable=too-few-public-methods

     def __init__(  # pylint: disable=too-many-locals
         self,
-        attn_kind: Literal["mha", "mla"],
+        attn_kind: Union[Literal["mha", "mla"], List[Literal["mha", "mla", "mha_sliding"]]],
         max_batch_size: tir.Var,
         max_total_seq_len: tir.Var,
         prefill_chunk_size: tir.Var,
@@ -553,7 +567,12 @@ def __init__(  # pylint: disable=too-many-locals
         target : Target
             The target to build the model to.
         """
-
+        if isinstance(attn_kind, List):
+            attn_kind = [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind]
+        else:
+            attn_kind = [
+                int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)
+            ]
         bb = rx.BlockBuilder.current()
         args = [
             rx.ShapeExpr(
@@ -570,9 +589,7 @@ def __init__(  # pylint: disable=too-many-locals
             rx.PrimValue(num_key_value_heads),
             rx.PrimValue(qk_head_dim),
             rx.PrimValue(v_head_dim),
-            rx.ShapeExpr(
-                [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)]
-            ),
+            rx.ShapeExpr(attn_kind),
             rx.PrimValue(enable_disaggregation),
             rx.PrimValue(rope_mode),
             rx.PrimValue(rope_scale),
@@ -614,9 +631,9 @@ def __init__(  # pylint: disable=too-many-locals
         else:
             # pylint: disable=line-too-long
             # fmt: off
-            ragged_qk_head_dim = qk_head_dim if attn_kind == "mha" else mla_original_qk_head_dim
-            ragged_v_head_dim = v_head_dim if attn_kind == "mha" else mla_original_v_head_dim
-            args.append(rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_ragged(num_key_value_heads if attn_kind == "mha" else num_attention_heads, num_attention_heads, ragged_qk_head_dim, ragged_v_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged")]))
+            ragged_qk_head_dim = qk_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_qk_head_dim
+            ragged_v_head_dim = v_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_v_head_dim
+            args.append(rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_ragged(num_key_value_heads if (attn_kind == "mha" or isinstance(attn_kind, List)) else num_attention_heads, num_attention_heads, ragged_qk_head_dim, ragged_v_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged")]))
             mha_functions = (
                 [
                     rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, False, rope_scaling, target), "tir_attention_prefill")]),
@@ -626,7 +643,7 @@ def __init__(  # pylint: disable=too-many-locals
                     rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]),
                     rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]),
                 ]
-                if attn_kind == "mha"
+                if (attn_kind == "mha" or isinstance(attn_kind, List))
                 else [rx.Tuple([]) for _ in range(6)]
             )
             mla_function = rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_mla(num_attention_heads, v_head_dim, qk_head_dim - v_head_dim, dtype, False, target), "tir_attention_prefill_mla")] if attn_kind == "mla" else [])
@@ -641,7 +658,7 @@ def __init__(  # pylint: disable=too-many-locals
             [
                 rx.Tuple(attn_merge_functions),
                 bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, qk_head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"),
-                bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if attn_kind == "mha" else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"),
+                bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if (attn_kind == "mha" or isinstance(attn_kind, List)) else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"),
                 bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, qk_head_dim, dtype), "kv_cache_debug_get_kv"),
                 bb.add_func(_compact_kv_copy(num_key_value_heads, qk_head_dim, dtype, target), "kv_cache_compact_kv_copy"),
             ]
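
The attn_kind normalization added in both constructors above can be summarized with a small standalone sketch (illustration only; normalize_attn_kind is a hypothetical helper, and it checks against a plain list rather than typing.List for simplicity). A single kind string is broadcast to every hidden layer, while a per-layer list is converted element-wise before being wrapped in rx.ShapeExpr(attn_kind).

import enum
from typing import List, Union


class AttnKind(enum.IntEnum):
    """Mirrors the enum in kv_cache.py: MHA = 0, MLA = 1, MHA_SLIDING = 3."""
    MHA = 0
    MLA = 1
    MHA_SLIDING = 3


def normalize_attn_kind(attn_kind: Union[str, List[str]], num_hidden_layers: int) -> List[int]:
    # Per-layer list: convert each kind string to its integer enum value.
    if isinstance(attn_kind, list):
        return [int(getattr(AttnKind, kind.upper())) for kind in attn_kind]
    # Single string: broadcast the same kind to every hidden layer.
    return [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)]


assert normalize_attn_kind("mha", 3) == [0, 0, 0]
assert normalize_attn_kind(["mha_sliding", "mha_sliding", "mha"], 3) == [3, 3, 0]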

src/runtime/vm/attn_utils.h

Lines changed: 2 additions & 1 deletion
@@ -62,13 +62,14 @@ enum class AttnKind : int {
   kMHA = 0,
   kMLA = 1,
   kLinearAttn = 2,
+  kMHASliding = 3,
 };

 /*! \brief Given the attention kind and other metadata, return the one-layer KV cache shape. */
 inline ffi::Shape GetKVCacheShape(AttnKind attn_kind, int64_t num_total_pages, int num_sequence,
                                   int64_t num_kv_heads, int64_t page_size, int64_t qk_head_dim,
                                   int64_t v_head_dim) {
-  if (attn_kind == AttnKind::kMHA) {
+  if (attn_kind == AttnKind::kMHA || attn_kind == AttnKind::kMHASliding) {
     // Ignore v_head_dim since multi-head attention requires K/V to have the same head dim.
     return {num_total_pages, 2, num_kv_heads, page_size, qk_head_dim};
   } else if (attn_kind == AttnKind::kMLA) {
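
As an illustration (a Python restatement, not the runtime code): with the change above, GetKVCacheShape gives sliding-window MHA layers the same paged K/V layout as regular MHA layers.

# Page shape shared by kMHA and kMHASliding layers; the literal 2 holds the
# separate K and V planes of each page (MLA layers take a different branch).
def mha_page_shape(num_total_pages, num_kv_heads, page_size, qk_head_dim):
    return (num_total_pages, 2, num_kv_heads, page_size, qk_head_dim)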
