diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py index 919d37364606..a1d742739aca 100644 --- a/python/tvm/relax/frontend/nn/llm/kv_cache.py +++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py @@ -20,7 +20,7 @@ # pylint: disable=too-many-statements,too-many-lines,too-many-arguments,invalid-name import enum import math -from typing import Any, Dict, List, Literal, Optional, Tuple +from typing import Any, Dict, List, Literal, Optional, Tuple, Union import tvm from tvm import relax as rx @@ -86,6 +86,7 @@ class AttnKind(enum.IntEnum): MHA = 0 MLA = 1 + MHA_SLIDING = 3 class RopeMode(enum.IntEnum): @@ -301,7 +302,7 @@ class FlashInferPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-me def __init__( # pylint: disable=too-many-locals self, - attn_kind: Literal["mha", "mla"], + attn_kind: Union[Literal["mha", "mla"], List[Literal["mha", "mla", "mha_sliding"]]], max_batch_size: tir.Var, max_total_seq_len: tir.Var, prefill_chunk_size: tir.Var, @@ -377,8 +378,16 @@ def __init__( # pylint: disable=too-many-locals dtype_q=dtype, dtype_kv=dtype, dtype_o=dtype, - qk_head_dim=qk_head_dim if attn_kind == "mha" else mla_original_qk_head_dim, - v_head_dim=v_head_dim if attn_kind == "mha" else mla_original_v_head_dim, + qk_head_dim=( + qk_head_dim + if (attn_kind == "mha" or isinstance(attn_kind, List)) + else mla_original_qk_head_dim + ), + v_head_dim=( + v_head_dim + if (attn_kind == "mha" or isinstance(attn_kind, List)) + else mla_original_v_head_dim + ), target=target, enable_inline_rope=rope_mode == RopeMode.INLINE, ) @@ -391,7 +400,7 @@ def __init__( # pylint: disable=too-many-locals v_head_dim=v_head_dim, target=target, ) - if attn_kind == "mha" + if (attn_kind == "mha" or isinstance(attn_kind, List)) else [] ) flashinfer_mla_mods = ( @@ -420,7 +429,7 @@ def __init__( # pylint: disable=too-many-locals rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]), rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]), ] - if attn_kind == "mha" + if (attn_kind == "mha" or isinstance(attn_kind, List)) else [rx.Tuple([]) for _ in range(6)] ) mla_function = rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_mla_paged_attention_run"), rx.ExternFunc("batch_mla_paged_attention_plan")] if attn_kind == "mla" else []) @@ -430,6 +439,11 @@ def __init__( # pylint: disable=too-many-locals if attn_kind == "mla": attn_merge_functions.append(bb.add_func(_merge_state_inplace(num_attention_heads, mla_original_v_head_dim, dtype, target, "tir_attention_merge_state_mla"), "tir_attention_merge_state_mla")) + + if isinstance(attn_kind, List): + attn_kind = [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind] + else: + attn_kind = [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)] args = [ rx.ShapeExpr( [ @@ -482,7 +496,7 @@ class TIRPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-methods def __init__( # pylint: disable=too-many-locals self, - attn_kind: Literal["mha", "mla"], + attn_kind: Union[Literal["mha", "mla"], List[Literal["mha", "mla", "mha_sliding"]]], max_batch_size: tir.Var, max_total_seq_len: tir.Var, prefill_chunk_size: tir.Var, @@ -553,7 +567,12 @@ def __init__( # pylint: disable=too-many-locals target : Target The target to build the model to. """ - + if isinstance(attn_kind, List): + attn_kind = [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind] + else: + attn_kind = [ + int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers) + ] bb = rx.BlockBuilder.current() args = [ rx.ShapeExpr( @@ -570,9 +589,7 @@ def __init__( # pylint: disable=too-many-locals rx.PrimValue(num_key_value_heads), rx.PrimValue(qk_head_dim), rx.PrimValue(v_head_dim), - rx.ShapeExpr( - [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)] - ), + rx.ShapeExpr(attn_kind), rx.PrimValue(enable_disaggregation), rx.PrimValue(rope_mode), rx.PrimValue(rope_scale), @@ -614,9 +631,9 @@ def __init__( # pylint: disable=too-many-locals else: # pylint: disable=line-too-long # fmt: off - ragged_qk_head_dim = qk_head_dim if attn_kind == "mha" else mla_original_qk_head_dim - ragged_v_head_dim = v_head_dim if attn_kind == "mha" else mla_original_v_head_dim - args.append(rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_ragged(num_key_value_heads if attn_kind == "mha" else num_attention_heads, num_attention_heads, ragged_qk_head_dim, ragged_v_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged")])) + ragged_qk_head_dim = qk_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_qk_head_dim + ragged_v_head_dim = v_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_v_head_dim + args.append(rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_ragged(num_key_value_heads if (attn_kind == "mha" or isinstance(attn_kind, List)) else num_attention_heads, num_attention_heads, ragged_qk_head_dim, ragged_v_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged")])) mha_functions = ( [ rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, False, rope_scaling, target), "tir_attention_prefill")]), @@ -626,7 +643,7 @@ def __init__( # pylint: disable=too-many-locals rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]), rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]), ] - if attn_kind == "mha" + if (attn_kind == "mha" or isinstance(attn_kind, List)) else [rx.Tuple([]) for _ in range(6)] ) mla_function = rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_mla(num_attention_heads, v_head_dim, qk_head_dim - v_head_dim, dtype, False, target), "tir_attention_prefill_mla")] if attn_kind == "mla" else []) @@ -641,7 +658,7 @@ def __init__( # pylint: disable=too-many-locals [ rx.Tuple(attn_merge_functions), bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, qk_head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"), - bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if attn_kind == "mha" else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"), + bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if (attn_kind == "mha" or isinstance(attn_kind, List)) else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"), bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, qk_head_dim, dtype), "kv_cache_debug_get_kv"), bb.add_func(_compact_kv_copy(num_key_value_heads, qk_head_dim, dtype, target), "kv_cache_compact_kv_copy"), ] diff --git a/src/runtime/vm/attn_utils.h b/src/runtime/vm/attn_utils.h index dce577da0889..290ca02653d2 100644 --- a/src/runtime/vm/attn_utils.h +++ b/src/runtime/vm/attn_utils.h @@ -62,13 +62,14 @@ enum class AttnKind : int { kMHA = 0, kMLA = 1, kLinearAttn = 2, + kMHASliding = 3, }; /*! \brief Given the attention kind and other metadata, return the one-layer KV cache shape. */ inline ffi::Shape GetKVCacheShape(AttnKind attn_kind, int64_t num_total_pages, int num_sequence, int64_t num_kv_heads, int64_t page_size, int64_t qk_head_dim, int64_t v_head_dim) { - if (attn_kind == AttnKind::kMHA) { + if (attn_kind == AttnKind::kMHA || attn_kind == AttnKind::kMHASliding) { // Ignore v_head_dim since multi-head attention requires K/V to have the same head dim. return {num_total_pages, 2, num_kv_heads, page_size, qk_head_dim}; } else if (attn_kind == AttnKind::kMLA) { diff --git a/src/runtime/vm/paged_kv_cache.cc b/src/runtime/vm/paged_kv_cache.cc index 5b70c0317719..2af4b19b06b1 100644 --- a/src/runtime/vm/paged_kv_cache.cc +++ b/src/runtime/vm/paged_kv_cache.cc @@ -98,6 +98,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { const int64_t prefill_chunk_size_; /*! \brief A boolean flag indicating if the KV cache supports sliding window. */ const bool support_sliding_window_; + /*! \brief A boolean flag indicating if the KV cache has per layer sliding window. */ + const bool support_layer_sliding_window_; /*! \brief The attention kinds for each layer. */ const std::vector attn_kinds_; @@ -195,10 +197,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { std::vector qo_indptr_on_depths_host_; std::vector page_indptr_on_depths_host_; std::vector page_indices_on_depths_host_; + std::vector page_indptr_sliding_window_on_depths_host_; + std::vector page_indices_sliding_window_on_depths_host_; std::vector last_page_len_on_depths_host_; std::vector sliding_window_offset_on_depths_host_; std::vector sink_size_on_depths_host_; std::vector k_rope_pos_offset_on_depths_host_; + std::vector k_rope_pos_offset_sliding_window_on_depths_host_; HostMemoryVector k_ragged_rope_pos_offset_host_; HostMemoryVector q_rope_position_map_host_; HostMemoryVector append_position_map_host_; @@ -236,8 +241,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { std::vector qo_indptr_on_depths_view_; std::vector page_indptr_on_depths_view_; std::vector page_indices_on_depths_view_; + std::vector page_indptr_sliding_window_on_depths_view_; + std::vector page_indices_sliding_window_on_depths_view_; std::vector length_info_on_depths_view_; + std::vector layer_sliding_window_length_info_on_depths_view_; std::vector k_rope_pos_offset_view_; + std::vector k_rope_pos_offset_sliding_window_view_; std::vector tree_attn_mask_view_; std::vector tree_attn_mn_indptr_view_; @@ -298,7 +307,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { v_head_dim_(v_head_dim), num_total_pages_(num_total_pages), prefill_chunk_size_(prefill_chunk_size), - support_sliding_window_(support_sliding_window), + support_sliding_window_(std::find(attn_kinds.begin(), attn_kinds.end(), + AttnKind::kMHASliding) != attn_kinds.end() + ? false + : support_sliding_window), + support_layer_sliding_window_(std::find(attn_kinds.begin(), attn_kinds.end(), + AttnKind::kMHASliding) != attn_kinds.end()), attn_kinds_(std::move(attn_kinds)), rope_mode_(support_sliding_window && rope_mode != RoPEMode::kNone ? RoPEMode::kInline : rope_mode), @@ -378,6 +392,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device)); page_indices_on_depths_host_.push_back( HostMemoryVector(num_total_pages, dtype_aux_, preferred_host_device)); + page_indptr_sliding_window_on_depths_host_.push_back( + HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device)); + page_indices_sliding_window_on_depths_host_.push_back( + HostMemoryVector(num_total_pages, dtype_aux_, preferred_host_device)); last_page_len_on_depths_host_.push_back( HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device)); sliding_window_offset_on_depths_host_.push_back( @@ -386,6 +404,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device)); k_rope_pos_offset_on_depths_host_.push_back( HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device)); + k_rope_pos_offset_sliding_window_on_depths_host_.push_back( + HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device)); tree_attn_mask_host_.push_back(HostMemoryVector(kTreeAttnMaxTreeSize * 2 * reserved_num_seqs, dtype_aux_, preferred_host_device)); tree_attn_mn_indptr_host_.push_back( @@ -428,8 +448,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { qo_indptr_on_depths_view_.push_back(NDArray()); page_indptr_on_depths_view_.push_back(NDArray()); page_indices_on_depths_view_.push_back(NDArray()); + page_indptr_sliding_window_on_depths_view_.push_back(NDArray()); + page_indices_sliding_window_on_depths_view_.push_back(NDArray()); length_info_on_depths_view_.push_back(NDArray()); + layer_sliding_window_length_info_on_depths_view_.push_back(NDArray()); k_rope_pos_offset_view_.push_back(NDArray()); + k_rope_pos_offset_sliding_window_view_.push_back(NDArray()); tree_attn_mask_view_.push_back(NDArray()); tree_attn_mn_indptr_view_.push_back(NDArray()); is_chain_on_depths_.push_back(true); @@ -716,7 +740,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { void EnableSlidingWindowForSeq(int64_t seq_id, int32_t sliding_window_size, int32_t attn_sink_size) final { - CHECK(support_sliding_window_) << "The KV cache does not support sliding window."; + // If per layer sliding window exists, enable sliding window for sequence + CHECK(support_sliding_window_ || support_layer_sliding_window_) + << "The KV cache does not support sliding window."; auto it = seq_map_.find(seq_id); CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache."; CHECK_GE(attn_sink_size, 0) @@ -938,28 +964,40 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { HostMemoryVector& qo_indptr_h = qo_indptr_on_depths_host_[d]; HostMemoryVector& page_indptr_h = page_indptr_on_depths_host_[d]; HostMemoryVector& page_indices_h = page_indices_on_depths_host_[d]; + HostMemoryVector& page_indptr_sliding_window_h = + page_indptr_sliding_window_on_depths_host_[d]; + HostMemoryVector& page_indices_sliding_window_h = + page_indices_sliding_window_on_depths_host_[d]; HostMemoryVector& last_page_len_h = last_page_len_on_depths_host_[d]; HostMemoryVector& sliding_window_offset_h = sliding_window_offset_on_depths_host_[d]; HostMemoryVector& sink_size_h = sink_size_on_depths_host_[d]; HostMemoryVector& k_rope_pos_offset_h = k_rope_pos_offset_on_depths_host_[d]; + HostMemoryVector& k_rope_pos_offset_sliding_window_h = + k_rope_pos_offset_sliding_window_on_depths_host_[d]; qo_indptr_h.clear(); page_indptr_h.clear(); page_indices_h.clear(); + page_indptr_sliding_window_h.clear(); + page_indices_sliding_window_h.clear(); last_page_len_h.clear(); sliding_window_offset_h.clear(); sink_size_h.clear(); k_rope_pos_offset_h.clear(); + k_rope_pos_offset_sliding_window_h.clear(); qo_indptr_h.push_back(0); page_indptr_h.push_back(0); + page_indptr_sliding_window_h.push_back(0); for (int i = 0; i < static_cast(chunked_block_ids_arr[d].size()); ++i) { const auto& [block_id, chunk_append_length] = chunked_block_ids_arr[d][i]; qo_indptr_h.push_back(qo_indptr_h.back() + chunk_append_length); if (block_id == -1) { page_indptr_h.push_back(page_indptr_h.back()); + page_indptr_sliding_window_h.push_back(page_indptr_sliding_window_h.back()); last_page_len_h.push_back(0); sliding_window_offset_h.push_back(0); sink_size_h.push_back(0); k_rope_pos_offset_h.push_back(0); + k_rope_pos_offset_sliding_window_h.push_back(0); } else { if (d < kPagedKVCacheMaxBlockDepth - 1) { // Blocks not at maximum depth @@ -967,16 +1005,44 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { page_indptr_h.push_back(page_indptr_h.back() + block.page_ids.size()); for (int32_t page_id : block.page_ids) { page_indices_h.push_back(page_id); + // Do the same for page_indices_sliding_window } + + // For sliding window, the first page and last page will both be partially used + page_indptr_sliding_window_h.push_back( + page_indptr_sliding_window_h.back() + + std::min(static_cast(block.page_ids.size()), + static_cast(1024 / page_size_ + + (block.seq_length % page_size_ ? 1 : 0)))); + for (int i = page_indices_h.size() - page_indptr_sliding_window_h.back(); + i < static_cast(page_indices_h.size()); i++) { + page_indices_sliding_window_h.push_back(page_indices_h[i]); + } + // set up the page indices properly by choosing the last (sliding_window_size / + // page_size_) pages (at most) last_page_len_h.push_back( block.seq_length == 0 ? 0 : (block.seq_length - block.sink_length + block.sliding_window_offset - 1) % page_size_ + 1); - sliding_window_offset_h.push_back(block.sliding_window_offset); + if (support_layer_sliding_window_) { + if (block.seq_length < 1024) { + sliding_window_offset_h.push_back(0); + } else { + sliding_window_offset_h.push_back(block.seq_length % page_size_); + } + } else { + sliding_window_offset_h.push_back(block.sliding_window_offset); + } sink_size_h.push_back(block.sink_length); k_rope_pos_offset_h.push_back(block.start_pos); + + // If sliding window, we need to calculate the positional offset + if (support_layer_sliding_window_) { + k_rope_pos_offset_sliding_window_h.push_back( + std::max(0, block.start_pos + block.seq_length - 1024)); + } } else { // Blocks at maximum depth const Block& block = global_block_pool_[block_id]; @@ -997,6 +1063,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { last_block_id = id; } page_indptr_h.push_back(page_indptr_h.back() + num_pages); + page_indptr_sliding_window_h.push_back( + page_indptr_sliding_window_h.back() + + std::min(static_cast(block.page_ids.size()), + static_cast(1024 / page_size_ + + (block.seq_length % page_size_ ? 1 : 0)))); + for (int i = page_indices_h.size() - page_indptr_sliding_window_h.back(); + i < static_cast(page_indices_h.size()); i++) { + page_indices_sliding_window_h.push_back(page_indices_h[i]); + } const Block& last_block = global_block_pool_[last_block_id]; last_page_len_h.push_back(total_seq_length == 0 ? 0 @@ -1004,9 +1079,21 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { last_block.sliding_window_offset - 1) % page_size_ + 1); - sliding_window_offset_h.push_back(last_block.sliding_window_offset); + if (support_layer_sliding_window_) { + if (last_block.seq_length < 1024) { + sliding_window_offset_h.push_back(0); + } else { + sliding_window_offset_h.push_back(last_block.seq_length % page_size_); + } + } else { + sliding_window_offset_h.push_back(last_block.sliding_window_offset); + } sink_size_h.push_back(last_block.sink_length); k_rope_pos_offset_h.push_back(block.start_pos); + if (support_layer_sliding_window_) { + k_rope_pos_offset_sliding_window_h.push_back( + std::max(0, block.start_pos + block.seq_length - 1024)); + } } } } @@ -1192,7 +1279,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { NDArray pages = pages_[local_layer_id]; CHECK(qkv_data.DataType() == pages.DataType()); CHECK(o_data.DataType() == pages.DataType()); - CHECK(attn_kinds_[layer_id] == AttnKind::kMHA); + CHECK(attn_kinds_[layer_id] == AttnKind::kMHA || + attn_kinds_[layer_id] == AttnKind::kMHASliding); // qkv_data: (num_total_length, num_qo_heads + 2 * num_kv_heads, qk_head_dim) // o_data: (num_total_length, num_qo_heads, qk_head_dim) @@ -1768,7 +1856,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { */ void SlideWindowForSequence(Sequence* seq) { // - No action when the sequence is not enabled for sliding window. - if (seq->sliding_window_size == -1) { + if (seq->sliding_window_size == -1 || !support_sliding_window_) { return; } // - No action when the sequence length does not exceed the window size. @@ -1854,7 +1942,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { for (int64_t page_idx = cur_npage; page_idx < tgt_npage; ++page_idx) { // When sliding window is enabled for the seq, we can "borrow temporary pages (-1)", // since the pages need to be slidden out might not have been released. - if (free_page_ids_.empty() && seq->sliding_window_size != -1) { + if (free_page_ids_.empty() && seq->sliding_window_size != -1 && support_sliding_window_) { block.page_ids.push_back(kPagedKVCacheTempPageId); } else { block.page_ids.push_back(GetFreePage()); @@ -1865,10 +1953,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // ==================== Slide ==================== // Slide the sequences so that the pages exceed the sliding window are released. SlideWindowForSequence(seq); - for (int i = 0; i < static_cast(block.page_ids.size()); ++i) { - if (block.page_ids[i] == kPagedKVCacheTempPageId) { - // Re-allocate the temporary pages after sliding window release. - block.page_ids[i] = GetFreePage(); + if (support_sliding_window_) { + for (int i = 0; i < static_cast(block.page_ids.size()); ++i) { + if (block.page_ids[i] == kPagedKVCacheTempPageId) { + // Re-allocate the temporary pages after sliding window release. + block.page_ids[i] = GetFreePage(); + } } } @@ -1926,7 +2016,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { if (page_indices_on_depths_view_[d]->shape[0] == 0) { continue; } - CHECK(!support_sliding_window_) << "Kernel BeginForward doesn't support sliding window."; + CHECK(!support_sliding_window_ || !support_layer_sliding_window_) + << "Kernel BeginForward doesn't support sliding window."; if (use_decode_kernel_[d]) { if (f_attention_decode_ != nullptr && f_attention_decode_->backend_kind == AttnBackendKind::kFlashInfer) { @@ -2044,9 +2135,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { bool MHACrossAttnInternal(int64_t local_layer_id, NDArray q_data, NDArray o_data, NDArray lse_data, double sm_scale, bool is_first_kernel) { std::unique_ptr& f_prefill = - !support_sliding_window_ ? f_attention_prefill_ : f_attention_prefill_sliding_window_; + (!support_sliding_window_ && + attn_kinds_[local_layer_id + layer_id_begin_offset_] != AttnKind::kMHASliding) + ? f_attention_prefill_ + : f_attention_prefill_sliding_window_; std::unique_ptr& f_decode = - !support_sliding_window_ ? f_attention_decode_ : f_attention_decode_sliding_window_; + (!support_sliding_window_ && + attn_kinds_[local_layer_id + layer_id_begin_offset_] != AttnKind::kMHASliding) + ? f_attention_decode_ + : f_attention_decode_sliding_window_; CHECK_GE(num_depths_, 1) << "The number of effective depths must be greater or equal to 1."; bool cross_attn_computed = false; @@ -2063,30 +2160,50 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { attn_output = temp_attn_output_view_; attn_lse = temp_attn_lse_view_; } + // If layer is sliding window, use sliding window index pointer/indices + NDArray page_indptr; + NDArray page_indices; + NDArray length_info; + NDArray k_rope_pos; + double rotary_theta; + double rotary_scale; + + if (attn_kinds_[local_layer_id + layer_id_begin_offset_] == AttnKind::kMHASliding) { + page_indptr = page_indptr_sliding_window_on_depths_view_[d]; + page_indices = page_indices_sliding_window_on_depths_view_[d]; + length_info = layer_sliding_window_length_info_on_depths_view_[d]; + k_rope_pos = k_rope_pos_offset_sliding_window_view_[d]; + rotary_theta = 10000; + rotary_scale = 1; + } else { + page_indptr = page_indptr_on_depths_view_[d]; + page_indices = page_indices_on_depths_view_[d]; + length_info = length_info_on_depths_view_[d]; + k_rope_pos = k_rope_pos_offset_view_[d]; + rotary_theta = rotary_theta_; + rotary_scale = rotary_scale_; + } + if (append_before_attn_ && !is_chain_on_depths_[d]) { ICHECK_NOTNULL(f_attention_prefill_with_tree_mask_paged_kv_); f_attention_prefill_with_tree_mask_paged_kv_->MHA( - q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], - page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d], - length_info_on_depths_view_[d], k_rope_pos_offset_view_[d], q_rope_position_map_view_, - tree_attn_mn_indptr_view_[d], tree_attn_mask_view_[d], rope_mode_, rotary_scale_, - rotary_theta_, sm_scale, attn_output, attn_lse, compute_stream_); + q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], page_indptr, page_indices, + length_info, k_rope_pos, q_rope_position_map_view_, tree_attn_mn_indptr_view_[d], + tree_attn_mask_view_[d], rope_mode_, rotary_scale, rotary_theta, sm_scale, attn_output, + attn_lse, compute_stream_); } else if (use_decode_kernel_[d]) { // Use decode kernel for depth d ICHECK_NOTNULL(f_decode); - f_decode->MHA(d, q_data, pages_[local_layer_id], page_indptr_on_depths_view_[d], - page_indices_on_depths_view_[d], length_info_on_depths_view_[d], - k_rope_pos_offset_view_[d], q_rope_position_map_view_, rope_mode_, - rotary_scale_, rotary_theta_, sm_scale, attn_output, attn_lse, - compute_stream_); + f_decode->MHA(d, q_data, pages_[local_layer_id], page_indptr, page_indices, length_info, + k_rope_pos, q_rope_position_map_view_, rope_mode_, rotary_scale, rotary_theta, + sm_scale, attn_output, attn_lse, compute_stream_); } else { // Use prefill kernel for depth d ICHECK_NOTNULL(f_prefill); - f_prefill->MHA(d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], - page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d], - length_info_on_depths_view_[d], q_rope_position_map_view_, - k_rope_pos_offset_view_[d], /*causal=*/false, - /*rotary_mode=*/rope_mode_, rotary_scale_, rotary_theta_, sm_scale, + f_prefill->MHA(d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], page_indptr, + page_indices, length_info, q_rope_position_map_view_, k_rope_pos, + /*causal=*/false, + /*rotary_mode=*/rope_mode_, rotary_scale, rotary_theta, sm_scale, attn_output, attn_lse, compute_stream_); } @@ -2198,7 +2315,27 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { page_indices_on_depths_view_[d] = aux_data_manager_->CopyPageIndicesOnDepthAsync(&page_indices_on_depths_host_[d], d); } - // 5. length_info_on_depths + + // If per layer sliding window exists, must copy additional vectors + if (support_layer_sliding_window_) { + // 5. page_indptr_sliding_window_on_depths + for (int d = 0; d < num_depths_; ++d) { + ICHECK_EQ(page_indptr_sliding_window_on_depths_host_[d].size(), + qo_indptr_on_depths_host_[d].size()); + page_indptr_sliding_window_on_depths_view_[d] = + aux_data_manager_->CopyPageIndptrOnDepthAsync( + &page_indptr_sliding_window_on_depths_host_[d], d); + } + // 6. page_indices_sliding_window_on_depths + for (int d = 0; d < num_depths_; ++d) { + ICHECK_EQ(page_indices_sliding_window_on_depths_host_[d].size(), + page_indptr_sliding_window_on_depths_host_[d].back()); + page_indices_sliding_window_on_depths_view_[d] = + aux_data_manager_->CopyPageIndicesOnDepthAsync( + &page_indices_sliding_window_on_depths_host_[d], d); + } + } + // 7. length_info_on_depths // last_page_len_on_depths_host_; // sliding_window_offset_on_depths_host_; // sink_size_on_depths_host_; @@ -2217,6 +2354,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { &last_page_len_on_depths_host_[d], &sliding_window_offset_on_depths_host_[d], &sink_size_on_depths_host_[d], d); } + + if (support_layer_sliding_window_) { + layer_sliding_window_length_info_on_depths_view_[d] = + aux_data_manager_->CopyLengthInfoOnDepthAsync(&last_page_len_on_depths_host_[d], + &sliding_window_offset_on_depths_host_[d], + &sink_size_on_depths_host_[d], d); + } } // 6. k_rope_pos_offset_on_depths for (int d = 0; d < num_depths_; ++d) { @@ -2224,6 +2368,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { qo_indptr_on_depths_host_[d].size()); k_rope_pos_offset_view_[d] = aux_data_manager_->CopyKRoPEPosOffsetOnDepthAsync( &k_rope_pos_offset_on_depths_host_[d], d); + if (support_layer_sliding_window_) { + ICHECK_EQ(k_rope_pos_offset_sliding_window_on_depths_host_[d].size() + 1, + qo_indptr_on_depths_host_[d].size()); + k_rope_pos_offset_sliding_window_view_[d] = + aux_data_manager_->CopyKRoPEPosOffsetOnDepthAsync( + &k_rope_pos_offset_sliding_window_on_depths_host_[d], d); + } } // 7. cur_append_lengths_indptr cur_append_length_indptr_view_ =