diff --git a/site/docs/concepts/optimization-techniques/kvcache-eviction-algorithm.md b/site/docs/concepts/optimization-techniques/kvcache-eviction-algorithm.md index c52ea367fd..0584ba6d29 100644 --- a/site/docs/concepts/optimization-techniques/kvcache-eviction-algorithm.md +++ b/site/docs/concepts/optimization-techniques/kvcache-eviction-algorithm.md @@ -60,3 +60,40 @@ It can be enabled by setting the `CacheEvictionConfig.apply_rotation` field to ` * Cache rotation is only targeted for the regular, linear LLaMa-like RoPE application and may degrade accuracy on models that use other RoPE schemes. * Cache rotation is currently only supported for the models with uniform V embedding sizes across the layers. + +## (Optional) KVCrush + +KVCrush enhances the standard H2O/SnapKV eviction by selecting the most representative blocks from the evictable area using clustering analysis, rather than simply evicting the low score blocks. + +### Algorithm Overview + +1. **Indicator Creation**: Generate binary indicators for tokens based on importance scores +2. **Anchor Point Generation**: Create reference patterns using configurable modes +3. **Distance Calculation**: Measure Hamming distance between block patterns and the anchor point +4. **Representative Selection**: Select blocks to best represent context diversity + +### Configuration +Setup KVCrush config parameters and pass it to ```CacheEvictionConfig```. Sample code to allocate KVCrush a budget of 2 blocks and use MEAN anchor mode is following. +```cpp +const ov::genai::CacheEvictionConfig EXAMPLE_CACHE_EVICTION_CONFIG = + {32, 32, 192, ov::genai::AggregationMode::NORM_SUM, false, 8, KVCrushConfig(2, KVCrushAnchorPointMode::MEAN)}; +``` +```python +CacheEvictionConfig( + start_size=32, + recent_size=128, + max_cache_size=448, + aggregation_mode=AggregationMode.NORM_SUM, + apply_rotation=False, + snapkv_window_size=8, + kvcrush_config=KVCrushConfig(budget=2, anchor_point_mode=KVCrushAnchorPointMode.MEAN) + ) +``` + +**Anchor Point Modes:** +- `RANDOM`: Random binary pattern +- `ZEROS`: All zeros pattern +- `ONES`: All ones pattern +- `MEAN`: Mean of indicators across blocks +- `ALTERNATE`: Alternating 0-1 pattern + diff --git a/src/cpp/include/openvino/genai/cache_eviction.hpp b/src/cpp/include/openvino/genai/cache_eviction.hpp index 3869680d1c..e5f54b24ec 100644 --- a/src/cpp/include/openvino/genai/cache_eviction.hpp +++ b/src/cpp/include/openvino/genai/cache_eviction.hpp @@ -19,6 +19,55 @@ enum class AggregationMode { * of a given token in cache */ }; +/** + * @brief Represents the mode of how anchor points are formed in KVCrush Cache eviction algorithm + */ +enum class KVCrushAnchorPointMode { + RANDOM, /** */ + ZEROS, /** (start_size + recent_size), "CacheEvictionConfig.max_cache_size must be larger than CacheEvictionConfig.start_size + CacheEvictionConfig.recent_size"); m_evictable_size = m_max_cache_size - m_start_size - m_recent_size; - } /** @return Number of tokens between the "start" and "recent" areas of KV cache that @@ -76,6 +137,11 @@ class CacheEvictionConfig { * score aggregation. **/ size_t snapkv_window_size = 8; + /** KVCrush configuration for this cache eviction algorithm. + * KVCrush is an additional mechanism that allows to retain some tokens in the cache + * even if they are not among the most important ones.*/ + KVCrushConfig kvcrush_config; + private: /** Number of tokens in the *beginning* of KV cache that should be retained * in the KV cache for this sequence during generation. Must be non-zero and a multiple of the KV cache block size for diff --git a/src/cpp/src/continuous_batching/cache_eviction.cpp b/src/cpp/src/continuous_batching/cache_eviction.cpp index 4cbd3511d9..7196b29e56 100644 --- a/src/cpp/src/continuous_batching/cache_eviction.cpp +++ b/src/cpp/src/continuous_batching/cache_eviction.cpp @@ -220,7 +220,7 @@ namespace ov::genai { CacheEvictionAlgorithm::CacheEvictionAlgorithm(const CacheEvictionConfig &eviction_config, size_t block_size, size_t num_decoder_layers, size_t max_pool_window_size) : m_eviction_config(eviction_config), m_block_size(block_size), m_num_decoder_layers(num_decoder_layers), - m_score_manager(block_size, num_decoder_layers, max_pool_window_size, eviction_config.aggregation_mode, eviction_config.get_start_size() / block_size, eviction_config.snapkv_window_size) + m_score_manager(block_size, num_decoder_layers, max_pool_window_size, eviction_config.aggregation_mode, eviction_config.get_start_size() / block_size, eviction_config.snapkv_window_size), m_kvcrush_algo(eviction_config.kvcrush_config, block_size) { OPENVINO_ASSERT(!(m_eviction_config.get_start_size() % m_block_size), "CacheEvictionConfig.start_size in tokens must be a multiple of block size ", m_block_size); @@ -265,6 +265,38 @@ namespace ov::genai { size_t num_blocks_to_evict = get_num_blocks_to_evict(decoder_layer_idx); auto evicted_block_indices = get_indices_of_blocks_to_evict(scores_for_all_evictable_blocks, num_blocks_to_evict); + // KVCrush: start + bool should_apply_kvcrush = (m_eviction_config.kvcrush_config.budget > 0) && + (evicted_block_indices.size() >= m_eviction_config.kvcrush_config.budget); + if (should_apply_kvcrush) { + size_t num_tokens_in_evictable_blocks = scores_for_all_evictable_blocks.size() * m_block_size; + + auto kvcrush_retained_block_indices = m_kvcrush_algo.get_indices_of_blocks_to_retain_using_kvcrush( + num_tokens_in_evictable_blocks, + evicted_block_indices, + m_score_manager.get_scores()[decoder_layer_idx]); + + // Remove the indices in kvcrush_retained_block_indices from evicted_block_indices + if (!kvcrush_retained_block_indices.empty()) { + // Convert both vectors to sets for efficient operations + std::unordered_set retained_set(kvcrush_retained_block_indices.begin(), + kvcrush_retained_block_indices.end()); + + // Create a new vector containing only elements not in retained_set + std::vector filtered_evicted_indices; + filtered_evicted_indices.reserve(evicted_block_indices.size()); + + for (const auto& idx : evicted_block_indices) { + if (retained_set.find(idx) == retained_set.end()) { + filtered_evicted_indices.push_back(idx); + } + } + // Replace the original vector with the filtered one + evicted_block_indices = std::move(filtered_evicted_indices); + } + } + // KVCrush: end + m_num_evicted_tokens += evicted_block_indices.size() * m_block_size; // No longer need to track the overall "heavy-hitter" attention scores for freshly evicted blocks diff --git a/src/cpp/src/continuous_batching/cache_eviction.hpp b/src/cpp/src/continuous_batching/cache_eviction.hpp index fc5fbfbb59..1069ea6241 100644 --- a/src/cpp/src/continuous_batching/cache_eviction.hpp +++ b/src/cpp/src/continuous_batching/cache_eviction.hpp @@ -11,6 +11,7 @@ #include "openvino/openvino.hpp" #include "continuous_batching/attention_output.hpp" #include "openvino/genai/cache_eviction.hpp" +#include "continuous_batching/kvcrush.hpp" namespace ov::genai { @@ -215,6 +216,7 @@ class CacheEvictionAlgorithm { void remove_scores_of_evicted_blocks(const std::vector& evicted_block_indices, size_t decoder_layer_idx); CacheEvictionConfig m_eviction_config; + KVCrushAlgorithm m_kvcrush_algo; std::size_t m_block_size; std::size_t m_num_evicted_tokens = 0; std::size_t m_num_decoder_layers; diff --git a/src/cpp/src/continuous_batching/kvcrush.cpp b/src/cpp/src/continuous_batching/kvcrush.cpp new file mode 100644 index 0000000000..808ec85e43 --- /dev/null +++ b/src/cpp/src/continuous_batching/kvcrush.cpp @@ -0,0 +1,176 @@ +// Copyright (C) 2023-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "continuous_batching/kvcrush.hpp" + +#include +namespace ov::genai { + +KVCrushAlgorithm::KVCrushAlgorithm(const KVCrushConfig& kvcrush_config, size_t block_size) + : m_kvcrush_config(kvcrush_config), + m_block_size(block_size), + rng(std::mt19937(kvcrush_config.rng_seed)) {} + +// step 1: create_indicators_kvcrush() +std::vector KVCrushAlgorithm::create_indicators_kvcrush(size_t num_tokens_in_evictable_blocks, + + std::vector& evicted_block_indices, + const std::vector& layer_scores) { + // Step 1: Sort the scores of the blocks to be evicted + const auto& blocks_eligible_for_kvcrush = evicted_block_indices; + std::vector indices(num_tokens_in_evictable_blocks); + std::iota(indices.begin(), indices.end(), 0); + std::partial_sort(indices.begin(), + indices.begin() + num_tokens_in_evictable_blocks / 2, + indices.end(), + [&](size_t i, size_t j) { + return layer_scores[i] > layer_scores[j]; + }); + + std::vector indicators(num_tokens_in_evictable_blocks, 0); + for (size_t i = 0; i < num_tokens_in_evictable_blocks / 2; ++i) { + indicators[indices[i]] = 1; + } + return indicators; +} +// step 2: create_anchor_point_kvcrush() +std::vector KVCrushAlgorithm::create_anchor_point_kvcrush(size_t num_tokens_in_evictable_blocks, + + std::vector& indicators) { + // Step 2: Create a binary vector of size block_size as anchor point + std::vector anchor_point(m_block_size); + // Initialize anchor_point based on anchor using switch-case + switch (m_kvcrush_config.anchor_point_mode) { + case KVCrushAnchorPointMode::RANDOM: { + std::uniform_int_distribution dist(0, 1); + std::generate(anchor_point.begin(), anchor_point.end(), [&]() { + return dist(rng); + }); + } break; + case KVCrushAnchorPointMode::ZEROS: + std::fill(anchor_point.begin(), anchor_point.end(), 0); + break; + case KVCrushAnchorPointMode::ONES: + std::fill(anchor_point.begin(), anchor_point.end(), 1); + break; + case KVCrushAnchorPointMode::MEAN: { + size_t num_blocks = num_tokens_in_evictable_blocks / m_block_size; + for (size_t pos = 0; pos < m_block_size; pos++) { + // Calculate sum of indicators at this position across all blocks + size_t sum = 0; + for (size_t block_idx = 0; block_idx < num_blocks; block_idx++) { + size_t idx = block_idx * m_block_size + pos; + sum += indicators[idx]; + } + + // Calculate mean and set anchor point based on threshold (0.5) + double mean = static_cast(sum) / num_blocks; + anchor_point[pos] = (mean > 0.5) ? 1 : 0; + } + break; + } + case KVCrushAnchorPointMode::ALTERNATE: + for (size_t i = 0; i < m_block_size; ++i) { + anchor_point[i] = i % 2; + } + break; + default: + OPENVINO_THROW("Invalid anchor point type"); + } + return anchor_point; +} + +// step 3: calculate_hamming_distance() +std::vector> KVCrushAlgorithm::calculate_hamming_distance_kvcrush( + size_t num_tokens_in_evictable_blocks, + + std::vector& indicators, + std::vector& anchor_point) { + // Step 3: Calculate Hamming distances between anchor point and each block + size_t num_blocks = num_tokens_in_evictable_blocks / m_block_size; + std::vector> block_distances; // pair + block_distances.reserve(num_blocks); + + for (size_t block_idx = 0; block_idx < num_blocks; ++block_idx) { + size_t hamming_distance = 0; + for (size_t j = 0; j < m_block_size; ++j) { + size_t token_idx = block_idx * m_block_size + j; + if (token_idx < num_tokens_in_evictable_blocks) { + // Use the indicators vector to determine the bit value of this position + int bit_value = indicators[token_idx]; + if (bit_value != anchor_point[j]) { + hamming_distance++; + } + } + } + block_distances.emplace_back(hamming_distance, block_idx); + } + return block_distances; +} + +// step 4: get_representative_blocks() +std::vector KVCrushAlgorithm::get_representative_blocks_kvcrush( + + size_t num_tokens_in_evictable_blocks, + std::vector>& block_distances, + const std::vector& blocks_eligible_for_kvcrush) { + // Step 4: Find the representative blocks + // Filter block indices that are in blocks_eligible_for_kvcrush + std::vector filtered_block_indices; + filtered_block_indices.reserve(block_distances.size()); + + for (const auto& entry : block_distances) { + size_t block_idx = entry.second; + // Check if block_idx is in blocks_eligible_for_kvcrush + if (std::find(blocks_eligible_for_kvcrush.begin(), blocks_eligible_for_kvcrush.end(), block_idx) != + blocks_eligible_for_kvcrush.end()) { + filtered_block_indices.push_back(block_idx); + } + } + // Sort filtered_block_indices based on Hamming distance + std::sort(filtered_block_indices.begin(), filtered_block_indices.end(), [&](size_t a, size_t b) { + return block_distances[a].first < block_distances[b].first; + }); + // select kvcrush_budget number of blocks from filtered_block_indices, uniformly spaced + size_t num_blocks_to_retain = std::min(filtered_block_indices.size(), m_kvcrush_config.get_budget()); + size_t step = filtered_block_indices.size() / num_blocks_to_retain; + std::vector kvcrush_retained_block_indices; + kvcrush_retained_block_indices.reserve(num_blocks_to_retain); + for (size_t i = 0; i < num_blocks_to_retain; ++i) { + size_t idx = i * step; + if (idx < filtered_block_indices.size()) { + kvcrush_retained_block_indices.push_back(filtered_block_indices[idx]); + } + } + + return kvcrush_retained_block_indices; +} + +std::vector KVCrushAlgorithm::get_indices_of_blocks_to_retain_using_kvcrush( + + size_t num_tokens_in_evictable_blocks, + std::vector& evicted_block_indices, + const std::vector& layer_scores) { + // step 1: Create indicators_kvcrush makes binary feature vectors based on top-k/2 scores + const auto& blocks_eligible_for_kvcrush = evicted_block_indices; // only the blocks that are evicted by the score + // based eviction are eligible for kvcrush + + std::vector indicators = + create_indicators_kvcrush(num_tokens_in_evictable_blocks, evicted_block_indices, layer_scores); + + // Step 2: Create anchor_point based on the selected anchor point type + std::vector anchor_point = create_anchor_point_kvcrush(num_tokens_in_evictable_blocks, indicators); + + // Step 3: Calculate Hamming distances between anchor point and each block, where each block is represented by + // its binary feature vector called indicators + std::vector> block_distances = + calculate_hamming_distance_kvcrush(num_tokens_in_evictable_blocks, indicators, anchor_point); + + // Step 4: Find the representative blocks + // Filter block indices that are in blocks_eligible_for_kvcrush + return get_representative_blocks_kvcrush(num_tokens_in_evictable_blocks, + block_distances, + blocks_eligible_for_kvcrush); +} + +} // namespace ov::genai \ No newline at end of file diff --git a/src/cpp/src/continuous_batching/kvcrush.hpp b/src/cpp/src/continuous_batching/kvcrush.hpp new file mode 100644 index 0000000000..6c22d4d7b8 --- /dev/null +++ b/src/cpp/src/continuous_batching/kvcrush.hpp @@ -0,0 +1,62 @@ +// Copyright (C) 2023-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include +#include + +#include "continuous_batching/attention_output.hpp" +#include "openvino/genai/cache_eviction.hpp" +#include "openvino/openvino.hpp" + +namespace ov::genai { + +class KVCrushAlgorithm { +private: + /** + * @brief Random number generator used for stochastic operations within the class. + * + * This instance of std::mt19937 is initialized with the seed specified in + * m_kvcrush_config.rng_seed, ensuring reproducible random sequences for operations + * that require randomness, such as shuffling or sampling. + */ + std::mt19937 rng; + +public: + KVCrushAlgorithm() = default; // Default constructor for compatibility + explicit KVCrushAlgorithm(const KVCrushConfig& eviction_config, size_t block_size); + + /** @return A vector of size kvcrush_budget, where each element contains the index of a + * representative block selected using KVCrush algorithm */ + std::vector get_indices_of_blocks_to_retain_using_kvcrush( + size_t num_tokens_in_evictable_blocks, + std::vector& evicted_block_indices, + const std::vector& layer_scores); + /** @return A binary (feature) vector of size num_tokens_in_evictable_blocks, where each element indicates wheather + * the corresponding token has a high score */ + std::vector create_indicators_kvcrush(size_t num_tokens_in_evictable_blocks, + std::vector& evicted_block_indices, + const std::vector& layer_scores); + /** @return A binary vector of size block_size, where each individual element is selected based on the anchor point + * mode */ + std::vector create_anchor_point_kvcrush(size_t num_tokens_in_evictable_blocks, + std::vector& indicators); + /** @return A vector of size num_blocks, where each individual element (pair) + * represents it's hamming distance from the anchor point */ + std::vector> calculate_hamming_distance_kvcrush(size_t num_tokens_in_evictable_blocks, + std::vector& indicators, + std::vector& anchor_point); + /** @return A vector of size kvcrush_budget, where each element contains the index of a + * representative block selected using KVCrush algorithm */ + std::vector get_representative_blocks_kvcrush(size_t num_tokens_in_evictable_blocks, + std::vector>& block_distances, + const std::vector& keep_clus_eligible); + + KVCrushConfig m_kvcrush_config; + size_t m_block_size; +}; + +} // namespace ov::genai \ No newline at end of file diff --git a/src/python/openvino_genai/__init__.py b/src/python/openvino_genai/__init__.py index 2a40a707c0..218f782e29 100644 --- a/src/python/openvino_genai/__init__.py +++ b/src/python/openvino_genai/__init__.py @@ -96,7 +96,9 @@ CacheEvictionConfig, AggregationMode, SparseAttentionMode, - SparseAttentionConfig + SparseAttentionConfig, + KVCrushAnchorPointMode, + KVCrushConfig ) # RAG diff --git a/src/python/openvino_genai/__init__.pyi b/src/python/openvino_genai/__init__.pyi index ca6b3e725e..175df870eb 100644 --- a/src/python/openvino_genai/__init__.pyi +++ b/src/python/openvino_genai/__init__.pyi @@ -25,6 +25,8 @@ from openvino_genai.py_openvino_genai import Image2ImagePipeline from openvino_genai.py_openvino_genai import ImageGenerationConfig from openvino_genai.py_openvino_genai import ImageGenerationPerfMetrics from openvino_genai.py_openvino_genai import InpaintingPipeline +from openvino_genai.py_openvino_genai import KVCrushAnchorPointMode +from openvino_genai.py_openvino_genai import KVCrushConfig from openvino_genai.py_openvino_genai import LLMPipeline from openvino_genai.py_openvino_genai import PerfMetrics from openvino_genai.py_openvino_genai import RawImageGenerationPerfMetrics @@ -62,5 +64,5 @@ from openvino_genai.py_openvino_genai import draft_model from openvino_genai.py_openvino_genai import get_version import os as os from . import py_openvino_genai -__all__: list[str] = ['Adapter', 'AdapterConfig', 'AggregationMode', 'AutoencoderKL', 'CLIPTextModel', 'CLIPTextModelWithProjection', 'CacheEvictionConfig', 'ChunkStreamerBase', 'ContinuousBatchingPipeline', 'CppStdGenerator', 'DecodedResults', 'EncodedResults', 'FluxTransformer2DModel', 'GenerationConfig', 'GenerationFinishReason', 'GenerationResult', 'GenerationStatus', 'Generator', 'Image2ImagePipeline', 'ImageGenerationConfig', 'ImageGenerationPerfMetrics', 'InpaintingPipeline', 'LLMPipeline', 'PerfMetrics', 'RawImageGenerationPerfMetrics', 'RawPerfMetrics', 'SD3Transformer2DModel', 'Scheduler', 'SchedulerConfig', 'SparseAttentionConfig', 'SparseAttentionMode', 'SpeechGenerationConfig', 'SpeechGenerationPerfMetrics', 'StopCriteria', 'StreamerBase', 'StreamingStatus', 'StructuralTagItem', 'StructuralTagsConfig', 'StructuredOutputConfig', 'T5EncoderModel', 'Text2ImagePipeline', 'Text2SpeechDecodedResults', 'Text2SpeechPipeline', 'TextEmbeddingPipeline', 'TextRerankPipeline', 'TextStreamer', 'TokenizedInputs', 'Tokenizer', 'TorchGenerator', 'UNet2DConditionModel', 'VLMPipeline', 'WhisperGenerationConfig', 'WhisperPerfMetrics', 'WhisperPipeline', 'WhisperRawPerfMetrics', 'draft_model', 'get_version', 'openvino', 'os', 'py_openvino_genai'] +__all__: list[str] = ['Adapter', 'AdapterConfig', 'AggregationMode', 'AutoencoderKL', 'CLIPTextModel', 'CLIPTextModelWithProjection', 'CacheEvictionConfig', 'ChunkStreamerBase', 'ContinuousBatchingPipeline', 'CppStdGenerator', 'DecodedResults', 'EncodedResults', 'FluxTransformer2DModel', 'GenerationConfig', 'GenerationFinishReason', 'GenerationResult', 'GenerationStatus', 'Generator', 'Image2ImagePipeline', 'ImageGenerationConfig', 'ImageGenerationPerfMetrics', 'InpaintingPipeline', 'KVCrushAnchorPointMode', 'KVCrushConfig', 'LLMPipeline', 'PerfMetrics', 'RawImageGenerationPerfMetrics', 'RawPerfMetrics', 'SD3Transformer2DModel', 'Scheduler', 'SchedulerConfig', 'SparseAttentionConfig', 'SparseAttentionMode', 'SpeechGenerationConfig', 'SpeechGenerationPerfMetrics', 'StopCriteria', 'StreamerBase', 'StreamingStatus', 'StructuralTagItem', 'StructuralTagsConfig', 'StructuredOutputConfig', 'T5EncoderModel', 'Text2ImagePipeline', 'Text2SpeechDecodedResults', 'Text2SpeechPipeline', 'TextEmbeddingPipeline', 'TextRerankPipeline', 'TextStreamer', 'TokenizedInputs', 'Tokenizer', 'TorchGenerator', 'UNet2DConditionModel', 'VLMPipeline', 'WhisperGenerationConfig', 'WhisperPerfMetrics', 'WhisperPipeline', 'WhisperRawPerfMetrics', 'draft_model', 'get_version', 'openvino', 'os', 'py_openvino_genai'] __version__: str diff --git a/src/python/openvino_genai/py_openvino_genai.pyi b/src/python/openvino_genai/py_openvino_genai.pyi index 77be053fcc..8d475312a9 100644 --- a/src/python/openvino_genai/py_openvino_genai.pyi +++ b/src/python/openvino_genai/py_openvino_genai.pyi @@ -5,7 +5,7 @@ from __future__ import annotations import collections.abc import openvino._pyopenvino import typing -__all__: list[str] = ['Adapter', 'AdapterConfig', 'AggregationMode', 'AutoencoderKL', 'CLIPTextModel', 'CLIPTextModelWithProjection', 'CacheEvictionConfig', 'ChunkStreamerBase', 'ContinuousBatchingPipeline', 'CppStdGenerator', 'DecodedResults', 'EncodedGenerationResult', 'EncodedResults', 'ExtendedPerfMetrics', 'FluxTransformer2DModel', 'GenerationConfig', 'GenerationFinishReason', 'GenerationHandle', 'GenerationOutput', 'GenerationResult', 'GenerationStatus', 'Generator', 'Image2ImagePipeline', 'ImageGenerationConfig', 'ImageGenerationPerfMetrics', 'InpaintingPipeline', 'LLMPipeline', 'MeanStdPair', 'PerfMetrics', 'PipelineMetrics', 'RawImageGenerationPerfMetrics', 'RawPerfMetrics', 'SD3Transformer2DModel', 'SDPerModelsPerfMetrics', 'SDPerfMetrics', 'Scheduler', 'SchedulerConfig', 'SparseAttentionConfig', 'SparseAttentionMode', 'SpeechGenerationConfig', 'SpeechGenerationPerfMetrics', 'StopCriteria', 'StreamerBase', 'StreamingStatus', 'StructuralTagItem', 'StructuralTagsConfig', 'StructuredOutputConfig', 'SummaryStats', 'T5EncoderModel', 'Text2ImagePipeline', 'Text2SpeechDecodedResults', 'Text2SpeechPipeline', 'TextEmbeddingPipeline', 'TextRerankPipeline', 'TextStreamer', 'TokenizedInputs', 'Tokenizer', 'TorchGenerator', 'UNet2DConditionModel', 'VLMDecodedResults', 'VLMPerfMetrics', 'VLMPipeline', 'VLMRawPerfMetrics', 'WhisperDecodedResultChunk', 'WhisperDecodedResults', 'WhisperGenerationConfig', 'WhisperPerfMetrics', 'WhisperPipeline', 'WhisperRawPerfMetrics', 'draft_model', 'get_version'] +__all__: list[str] = ['Adapter', 'AdapterConfig', 'AggregationMode', 'AutoencoderKL', 'CLIPTextModel', 'CLIPTextModelWithProjection', 'CacheEvictionConfig', 'ChunkStreamerBase', 'ContinuousBatchingPipeline', 'CppStdGenerator', 'DecodedResults', 'EncodedGenerationResult', 'EncodedResults', 'ExtendedPerfMetrics', 'FluxTransformer2DModel', 'GenerationConfig', 'GenerationFinishReason', 'GenerationHandle', 'GenerationOutput', 'GenerationResult', 'GenerationStatus', 'Generator', 'Image2ImagePipeline', 'ImageGenerationConfig', 'ImageGenerationPerfMetrics', 'InpaintingPipeline', 'KVCrushAnchorPointMode', 'KVCrushConfig', 'LLMPipeline', 'MeanStdPair', 'PerfMetrics', 'PipelineMetrics', 'RawImageGenerationPerfMetrics', 'RawPerfMetrics', 'SD3Transformer2DModel', 'SDPerModelsPerfMetrics', 'SDPerfMetrics', 'Scheduler', 'SchedulerConfig', 'SparseAttentionConfig', 'SparseAttentionMode', 'SpeechGenerationConfig', 'SpeechGenerationPerfMetrics', 'StopCriteria', 'StreamerBase', 'StreamingStatus', 'StructuralTagItem', 'StructuralTagsConfig', 'StructuredOutputConfig', 'SummaryStats', 'T5EncoderModel', 'Text2ImagePipeline', 'Text2SpeechDecodedResults', 'Text2SpeechPipeline', 'TextEmbeddingPipeline', 'TextRerankPipeline', 'TextStreamer', 'TokenizedInputs', 'Tokenizer', 'TorchGenerator', 'UNet2DConditionModel', 'VLMDecodedResults', 'VLMPerfMetrics', 'VLMPipeline', 'VLMRawPerfMetrics', 'WhisperDecodedResultChunk', 'WhisperDecodedResults', 'WhisperGenerationConfig', 'WhisperPerfMetrics', 'WhisperPipeline', 'WhisperRawPerfMetrics', 'draft_model', 'get_version'] class Adapter: """ Immutable LoRA Adapter that carries the adaptation matrices and serves as unique adapter identifier. @@ -359,7 +359,8 @@ class CacheEvictionConfig: """ aggregation_mode: AggregationMode apply_rotation: bool - def __init__(self, start_size: typing.SupportsInt, recent_size: typing.SupportsInt, max_cache_size: typing.SupportsInt, aggregation_mode: AggregationMode, apply_rotation: bool = False, snapkv_window_size: typing.SupportsInt = 8) -> None: + kvcrush_config: KVCrushConfig + def __init__(self, start_size: typing.SupportsInt, recent_size: typing.SupportsInt, max_cache_size: typing.SupportsInt, aggregation_mode: AggregationMode, apply_rotation: bool = False, snapkv_window_size: typing.SupportsInt = 8, kvcrush_config: typing.Any = None) -> None: ... def get_evictable_size(self) -> int: ... @@ -1445,6 +1446,86 @@ class InpaintingPipeline: ... def set_scheduler(self, scheduler: Scheduler) -> None: ... +class KVCrushAnchorPointMode: + """ + Represents the anchor point types for KVCrush cache eviction + :param KVCrushAnchorPointMode.RANDOM: Random binary vector will be used as anchor point + :param KVCrushAnchorPointMode.ZEROS: Vector of all zeros will be used as anchor point + :param KVCrushAnchorPointMode.ONES: Vector of all ones will be used as anchor point + :param KVCrushAnchorPointMode.MEAN: Mean of indicator feature vector to be used as anchor point + :param KVCrushAnchorPointMode.ALTERNATE: Alternating 0s and 1s will be used as anchor point + + Members: + + RANDOM + + ZEROS + + ONES + + MEAN + + ALTERNATE + """ + ALTERNATE: typing.ClassVar[KVCrushAnchorPointMode] # value = + MEAN: typing.ClassVar[KVCrushAnchorPointMode] # value = + ONES: typing.ClassVar[KVCrushAnchorPointMode] # value = + RANDOM: typing.ClassVar[KVCrushAnchorPointMode] # value = + ZEROS: typing.ClassVar[KVCrushAnchorPointMode] # value = + __members__: typing.ClassVar[dict[str, KVCrushAnchorPointMode]] # value = {'RANDOM': , 'ZEROS': , 'ONES': , 'MEAN': , 'ALTERNATE': } + def __eq__(self, other: typing.Any) -> bool: + ... + def __getstate__(self) -> int: + ... + def __hash__(self) -> int: + ... + def __index__(self) -> int: + ... + def __init__(self, value: typing.SupportsInt) -> None: + ... + def __int__(self) -> int: + ... + def __ne__(self, other: typing.Any) -> bool: + ... + def __repr__(self) -> str: + ... + def __setstate__(self, state: typing.SupportsInt) -> None: + ... + def __str__(self) -> str: + ... + @property + def name(self) -> str: + ... + @property + def value(self) -> int: + ... +class KVCrushConfig: + """ + Configuration for KVCrush cache eviction algorithm + """ + anchor_point_mode: KVCrushAnchorPointMode + @typing.overload + def __init__(self) -> None: + """ + Default constructor + """ + @typing.overload + def __init__(self, budget: typing.SupportsInt, anchor_point_mode: KVCrushAnchorPointMode = ..., rng_seed: typing.SupportsInt = 0) -> None: + """ + Constructor with budget, anchor point mode, and RNG seed + """ + @property + def budget(self) -> int: + ... + @budget.setter + def budget(self, arg0: typing.SupportsInt) -> None: + ... + @property + def rng_seed(self) -> int: + ... + @rng_seed.setter + def rng_seed(self, arg0: typing.SupportsInt) -> None: + ... class LLMPipeline: """ This class is used for generation with LLMs diff --git a/src/python/py_continuous_batching_pipeline.cpp b/src/python/py_continuous_batching_pipeline.cpp index 05af43b7a7..3a25667936 100644 --- a/src/python/py_continuous_batching_pipeline.cpp +++ b/src/python/py_continuous_batching_pipeline.cpp @@ -30,6 +30,8 @@ using ov::genai::GenerationFinishReason; using ov::genai::GenerationStatus; using ov::genai::SchedulerConfig; using ov::genai::PipelineMetrics; +using ov::genai::KVCrushAnchorPointMode; +using ov::genai::KVCrushConfig; namespace { @@ -277,17 +279,52 @@ void init_continuous_batching_pipeline(py::module_& m) { :param AggregationMode.NORM_SUM: Same as SUM, but the importance scores are additionally divided by the lifetime (in tokens generated) of a given token in cache)") .value("SUM", AggregationMode::SUM) .value("NORM_SUM", AggregationMode::NORM_SUM); - + py::enum_(m, + "KVCrushAnchorPointMode", + R"(Represents the anchor point types for KVCrush cache eviction + :param KVCrushAnchorPointMode.RANDOM: Random binary vector will be used as anchor point + :param KVCrushAnchorPointMode.ZEROS: Vector of all zeros will be used as anchor point + :param KVCrushAnchorPointMode.ONES: Vector of all ones will be used as anchor point + :param KVCrushAnchorPointMode.MEAN: Mean of indicator feature vector to be used as anchor point + :param KVCrushAnchorPointMode.ALTERNATE: Alternating 0s and 1s will be used as anchor point)") + .value("RANDOM", KVCrushAnchorPointMode::RANDOM) + .value("ZEROS", KVCrushAnchorPointMode::ZEROS) + .value("ONES", KVCrushAnchorPointMode::ONES) + .value("MEAN", KVCrushAnchorPointMode::MEAN) + .value("ALTERNATE", KVCrushAnchorPointMode::ALTERNATE); + + py::class_(m, "KVCrushConfig", "Configuration for KVCrush cache eviction algorithm") + .def(py::init<>(), "Default constructor") + .def(py::init(), + "Constructor with budget, anchor point mode, and RNG seed", + py::arg("budget"), + py::arg_v("anchor_point_mode", KVCrushAnchorPointMode::RANDOM, "KVCrushAnchorPointMode.RANDOM"), + py::arg("rng_seed") = 0) + .def_readwrite("budget", &KVCrushConfig::budget) + .def_readwrite("anchor_point_mode", &KVCrushConfig::anchor_point_mode) + .def_readwrite("rng_seed", &KVCrushConfig::rng_seed); py::class_(m, "CacheEvictionConfig", cache_eviction_config_docstring) .def(py::init<>([](const size_t start_size, size_t recent_size, size_t max_cache_size, AggregationMode aggregation_mode, bool apply_rotation, - size_t snapkv_window_size) { - return CacheEvictionConfig{start_size, recent_size, max_cache_size, aggregation_mode, apply_rotation, snapkv_window_size}; }), + size_t snapkv_window_size, py::object kvcrush_config) { + if (kvcrush_config.is_none()) { + return CacheEvictionConfig{start_size, recent_size, max_cache_size, aggregation_mode, apply_rotation, snapkv_window_size}; + } else { + const auto& config_ref = kvcrush_config.cast(); + return CacheEvictionConfig{start_size, + recent_size, + max_cache_size, + aggregation_mode, + apply_rotation, + snapkv_window_size, + config_ref}; + } }), py::arg("start_size"), py::arg("recent_size"), py::arg("max_cache_size"), py::arg("aggregation_mode"), py::arg("apply_rotation") = false, - py::arg("snapkv_window_size") = 8) + py::arg("snapkv_window_size") = 8, py::arg("kvcrush_config") = py::none()) .def_readwrite("aggregation_mode", &CacheEvictionConfig::aggregation_mode) .def_readwrite("apply_rotation", &CacheEvictionConfig::apply_rotation) .def_readwrite("snapkv_window_size", &CacheEvictionConfig::snapkv_window_size) + .def_readwrite("kvcrush_config", &CacheEvictionConfig::kvcrush_config) .def("get_start_size", &CacheEvictionConfig::get_start_size) .def("get_recent_size", &CacheEvictionConfig::get_recent_size) .def("get_max_cache_size", &CacheEvictionConfig::get_max_cache_size) diff --git a/tests/cpp/cache_eviction.cpp b/tests/cpp/cache_eviction.cpp index 000d27f99a..7c7f6209f5 100644 --- a/tests/cpp/cache_eviction.cpp +++ b/tests/cpp/cache_eviction.cpp @@ -11,9 +11,15 @@ #include "gtest/gtest.h" #include "openvino/genai/cache_eviction.hpp" - -const ov::genai::CacheEvictionConfig DEFAULT_CACHE_EVICTION_CONFIG = {32, 32, 192, ov::genai::AggregationMode::NORM_SUM, false, 0}; -const ov::genai::CacheEvictionConfig SHORT_RECENT_EVICTION_CONFIG = {32, 32, 72, ov::genai::AggregationMode::NORM_SUM, false, 0}; +using ov::genai::KVCrushAnchorPointMode; +using ov::genai::KVCrushConfig; + +const ov::genai::CacheEvictionConfig DEFAULT_CACHE_EVICTION_CONFIG = + {32, 32, 192, ov::genai::AggregationMode::NORM_SUM, false, 0, KVCrushConfig(0, KVCrushAnchorPointMode::MEAN)}; +const ov::genai::CacheEvictionConfig SHORT_RECENT_EVICTION_CONFIG = + {32, 32, 72, ov::genai::AggregationMode::NORM_SUM, false, 0, KVCrushConfig(0, KVCrushAnchorPointMode::MEAN)}; +const ov::genai::CacheEvictionConfig KVCRUSH_CACHE_EVICTION_CONFIG = + {32, 32, 192, ov::genai::AggregationMode::NORM_SUM, false, 0, KVCrushConfig(2, KVCrushAnchorPointMode::MEAN)}; constexpr size_t DEFAULT_BLOCK_SIZE = 4; constexpr size_t DEFAULT_NUM_DECODER_LAYERS = 2; constexpr size_t DEFAULT_MAX_POOL_WINDOW_SIZE = 8; @@ -818,6 +824,39 @@ const std::vector LOW_SCORE_BLOCK_EVICTION_TEST_CASES {{0, 9, 10, 11, 13}, {12, 11, 8, 9, 17}}, {{8, 9, 10, 11, 13}, {8, 9, 10, 11, 12}} }, + //multiple blocks in evictable area - with KVCrush + { + "three_blocks_kvcrush", + 2 * 4 + 2, // 2 blocks worth of overflow + 2 tokens, amounting to 3 blocks to be evicted + KVCRUSH_CACHE_EVICTION_CONFIG, + {{28, 10, 11}, {18, 8, 31}}, + {{28, 10, 11}, {18, 8, 31}} + }, + //if there are more blocks with same low score than should be evicted, the lower-indexed ones should take precedence - with KVCrush + { + "four_zeroed_two_to_evict_kvcrush", + 1 * 4 + 2, // 2 blocks to be evicted + KVCRUSH_CACHE_EVICTION_CONFIG, + {{15, 36, 13, 10}, {9, 39, 31, 11}}, // 4 zeroed blocks + {{10, 13}, {9, 11}} + }, + //will prefer to evict lower-indexed blocks if there are multiple same-scored blocks - with KVCrush + { + "less_zeroed_than_to_evict_kvcrush", + 5 * 4 + 2, // 6 blocks to be evicted + KVCRUSH_CACHE_EVICTION_CONFIG, + {{}, {30, 22}}, // 1st layer has no zeroed blocks, 2nd has only 2 zeroed blocks + {{8, 9, 10, 11, 12, 13}, {8, 9, 10, 11, 22, 30}} // non-zeroed blocks to evict are taken from the beginning of evictable range + }, + + //low-scored blocks in non-evictable range do not lead to eviction - with KVCrush + { + "zeros_also_in_non_evictable_areas_kvcrush", + 5 * 4 + 2, // 6 blocks to be evicted + KVCRUSH_CACHE_EVICTION_CONFIG, + {{0, 2, 7, 24, 31, 49}, {5, 19, 27, 39, 50, 52}}, // 1st layer has 0, 2, 7 in start_area, 49 in recent_area; 2nd has 5 in start_area, 50, 54 in recent_area + {{8, 9, 10, 11, 24, 31}, {8, 9, 10, 19, 27, 39}} // eviction padded up to 6 blocks by blocks in the beginning of the evictable_area + } }; // clang-format on @@ -843,7 +882,24 @@ TEST_P(CacheEvictionLowScoreBlocksParameterizedTest, EvictsLowestScoredBlocks) { auto test_evicted_blocks = algo.evict_logical_blocks(); auto ref_evicted_blocks = test_struct.ref_evicted_blocks; for (size_t layer_idx = 0; layer_idx < num_decoder_layers; layer_idx++) { - EXPECT_EQ(test_evicted_blocks[layer_idx], ref_evicted_blocks[layer_idx]); + bool should_apply_kvcrush = + (test_struct.eviction_config.kvcrush_config.budget > 0) && + (std::ceil(static_cast(test_struct.tokens_over_max_cache_size) / DEFAULT_BLOCK_SIZE) >= + (test_struct.eviction_config.kvcrush_config.budget)); + if (should_apply_kvcrush) { + // check test_evicted_blocks is a subset of ref_evicted_blocks + EXPECT_TRUE(std::includes(ref_evicted_blocks[layer_idx].begin(), + ref_evicted_blocks[layer_idx].end(), + test_evicted_blocks[layer_idx].begin(), + test_evicted_blocks[layer_idx].end())); + EXPECT_EQ(static_cast(test_evicted_blocks[layer_idx].size()), + static_cast(ref_evicted_blocks[layer_idx].size()) - + static_cast(test_struct.eviction_config.kvcrush_config.budget)); + } + // if kvcrush is disabled + else { + EXPECT_EQ(test_evicted_blocks[layer_idx], ref_evicted_blocks[layer_idx]); + } } } @@ -1001,23 +1057,24 @@ struct CacheEvictionConfigInitParamsForTest { size_t recent_size; size_t max_cache_size; size_t snapkv_window_size; + size_t kvcrush_budget; }; using CacheEvictionConfigInitializationTest = ::testing::TestWithParam; const std::vector INVALID_CONFIG_INIT_PARAMS_CASES = { // zero area sizes - {32, 32, 64, 8}, - {0, 13, 39, 1}, - {128, 0, 384, 7}, + {32, 32, 64, 8, 10}, + {0, 13, 39, 1, 10}, + {128, 0, 384, 7, 10}, // max_cache_size less than start_size + recent_size - {32, 64, 32}, + {32, 64, 32, 8, 10}, }; TEST_P(CacheEvictionConfigInitializationTest, ThrowsForInvalidConfigParams) { auto params = GetParam(); - EXPECT_THROW(ov::genai::CacheEvictionConfig(params.start_size, params.recent_size, params.max_cache_size, ov::genai::AggregationMode::NORM_SUM, /* apply_rotation = */ false, params.snapkv_window_size), ov::Exception); + EXPECT_THROW(ov::genai::CacheEvictionConfig(params.start_size, params.recent_size, params.max_cache_size, ov::genai::AggregationMode::NORM_SUM, /* apply_rotation = */ false, params.snapkv_window_size, KVCrushConfig(2, KVCrushAnchorPointMode::MEAN)), ov::Exception); } INSTANTIATE_TEST_SUITE_P(VariousInvalidInitParams, CacheEvictionConfigInitializationTest, @@ -1033,12 +1090,12 @@ using CacheEvictionAlgoInitializationTest = ::testing::TestWithParam INVALID_ALGO_INIT_PARAMS_CASES = { // area sizes not multiple of block size - { {32, 32, 97, ov::genai::AggregationMode::SUM}, 16, 8}, - { {11, 13, 50, ov::genai::AggregationMode::NORM_SUM}, 13, 1}, - { {128, 200, 584, ov::genai::AggregationMode::NORM_SUM}, 128, 19}, + { {32, 32, 97, ov::genai::AggregationMode::SUM, false, 8, KVCrushConfig(2, KVCrushAnchorPointMode::MEAN)}, 16, 8}, + { {11, 13, 50, ov::genai::AggregationMode::NORM_SUM, false, 8, KVCrushConfig(2, KVCrushAnchorPointMode::MEAN)}, 13, 1}, + { {128, 200, 584, ov::genai::AggregationMode::NORM_SUM, false, 8, KVCrushConfig(2, KVCrushAnchorPointMode::MEAN)}, 128, 19}, // zero decoder layers - { {32, 64, 192, ov::genai::AggregationMode::SUM}, 32, 0}, + { {32, 64, 192, ov::genai::AggregationMode::SUM, false, 8, KVCrushConfig(2, KVCrushAnchorPointMode::MEAN)}, 32, 0}, }; TEST_P(CacheEvictionAlgoInitializationTest, ThrowsForInvalidConfigs) { auto params = GetParam(); diff --git a/tests/cpp/kvcrush.cpp b/tests/cpp/kvcrush.cpp new file mode 100644 index 0000000000..3b92116cb3 --- /dev/null +++ b/tests/cpp/kvcrush.cpp @@ -0,0 +1,360 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "continuous_batching/kvcrush.hpp" + +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "openvino/genai/cache_eviction.hpp" + +namespace ov::genai::tests { +// Test fixture for KVCrushAlgorithm tests +class KVCrushAlgorithmTest : public ::testing::Test { +protected: + void SetUp() override { + // Default test configuration + m_kvcrush_config.budget = 8; + m_kvcrush_config.anchor_point_mode = KVCrushAnchorPointMode::RANDOM; + m_kvcrush_config.rng_seed = 42; // Fixed seed for reproducibility + + m_block_size = 4; + m_num_decoder_layers = 2; + + // Create the algorithm instance + m_kvcrush_algo = KVCrushAlgorithm(m_kvcrush_config, m_block_size); + } + + // Helper to create mock attention scores + std::vector> create_mock_scores(size_t num_layers, size_t sequence_length) { + std::vector> scores(num_layers); + for (size_t i = 0; i < num_layers; i++) { + scores[i].resize(sequence_length); + for (size_t j = 0; j < sequence_length; j++) { + scores[i][j] = static_cast(j % 10) / 10.0; // Repeating pattern 0.0, 0.1, ..., 0.9 + } + } + return scores; + } + + // Helper function to setup anchor points with given mode + std::vector setup_anchor_points(KVCrushAnchorPointMode mode, size_t num_tokens = 16) { + // Common setup for anchor point tests + std::vector indicators(num_tokens, 1); // All indicators set to 1 + + m_kvcrush_config.anchor_point_mode = mode; + m_kvcrush_algo = KVCrushAlgorithm(m_kvcrush_config, m_block_size); + + return m_kvcrush_algo.create_anchor_point_kvcrush(num_tokens, indicators); + } + + KVCrushConfig m_kvcrush_config; + KVCrushAlgorithm m_kvcrush_algo; + size_t m_block_size; + size_t m_num_decoder_layers; +}; + +// Test KVCrushAlgorithm constructor +TEST_F(KVCrushAlgorithmTest, ConstructorTest) { + KVCrushAlgorithm algo(m_kvcrush_config, m_block_size); + EXPECT_EQ(algo.m_kvcrush_config.budget, 8); + EXPECT_EQ(algo.m_kvcrush_config.anchor_point_mode, KVCrushAnchorPointMode::RANDOM); + EXPECT_EQ(algo.m_kvcrush_config.rng_seed, 42); +} + +// Test create_indicators_kvcrush +TEST_F(KVCrushAlgorithmTest, CreateIndicatorsTest) { + // Setup + size_t decoder_layer_idx = 0; + size_t num_tokens_in_evictable_blocks = 20; + std::vector evicted_block_indices = {0, 1, 2, 3, 4}; + auto scores = create_mock_scores(m_num_decoder_layers, num_tokens_in_evictable_blocks); + + // Execute + auto indicators = m_kvcrush_algo.create_indicators_kvcrush(num_tokens_in_evictable_blocks, + evicted_block_indices, + scores[decoder_layer_idx]); + + // Verify + EXPECT_EQ(indicators.size(), num_tokens_in_evictable_blocks); + + // Check that half of the indicators are set to 1 (top-k/2) + size_t count_ones = std::count(indicators.begin(), indicators.end(), 1); + EXPECT_EQ(count_ones, num_tokens_in_evictable_blocks / 2); +} + +// Test create_anchor_point_kvcrush with RANDOM mode +TEST_F(KVCrushAlgorithmTest, CreateAnchorPointRandomTest) { + // Setup using helper function + auto anchor_point = setup_anchor_points(KVCrushAnchorPointMode::RANDOM, 16); + + // Verify basic properties + EXPECT_EQ(anchor_point.size(), m_block_size); + + // 1. Check that all values are less than 1 (should be 0 or 1 only) + for (const auto& val : anchor_point) { + EXPECT_LE(val, 1) << "Anchor point values should be 0 or 1 (less than or equal to 1)"; + EXPECT_GE(val, 0) << "Anchor point values should be non-negative"; + } + + // 2. Check that they are bitwise different (not all same value) + bool all_same = true; + size_t first_value = anchor_point[0]; + for (size_t i = 1; i < anchor_point.size(); i++) { + if (anchor_point[i] != first_value) { + all_same = false; + break; + } + } + EXPECT_FALSE(all_same) << "Random anchor point should not have all identical values"; + + // 3. Additional randomness check: run multiple times with different seeds + // and verify we get different results + std::vector> multiple_results; + for (size_t seed = 0; seed < 5; seed++) { + KVCrushConfig temp_config = m_kvcrush_config; + temp_config.rng_seed = seed; + KVCrushAlgorithm temp_algo(temp_config, m_block_size); + + size_t num_tokens_in_evictable_blocks = 16; + std::vector indicators(num_tokens_in_evictable_blocks, 1); + auto temp_anchor_point = temp_algo.create_anchor_point_kvcrush(num_tokens_in_evictable_blocks, indicators); + multiple_results.push_back(temp_anchor_point); + } + + // Check that different seeds produce different results (proving randomness) + bool found_different = false; + for (size_t i = 1; i < multiple_results.size(); i++) { + if (multiple_results[0] != multiple_results[i]) { + found_different = true; + break; + } + } + EXPECT_TRUE(found_different) << "Different random seeds should produce different anchor points"; + + // 4. Check that we have a mix of 0s and 1s (not all zeros or all ones) + size_t count_zeros = std::count(anchor_point.begin(), anchor_point.end(), 0); + size_t count_ones = std::count(anchor_point.begin(), anchor_point.end(), 1); + + EXPECT_GT(count_zeros, 0) << "Random anchor point should contain some zeros"; + EXPECT_GT(count_ones, 0) << "Random anchor point should contain some ones"; + EXPECT_EQ(count_zeros + count_ones, anchor_point.size()) << "All values should be either 0 or 1"; +} + +// Test create_anchor_point_kvcrush with ALTERNATE mode +TEST_F(KVCrushAlgorithmTest, CreateAnchorPointAlternateTest) { + // Setup using helper function + auto anchor_point = setup_anchor_points(KVCrushAnchorPointMode::ALTERNATE, 16); + + // Verify + EXPECT_EQ(anchor_point.size(), m_block_size); + + // In ALTERNATE mode, we expect alternating blocks to be selected + // We should have non-zero values at regular intervals + bool found_pattern = false; + for (size_t i = 0; i < anchor_point.size(); i += m_block_size) { + bool has_nonzero = false; + for (size_t j = 0; j < m_block_size; j++) { + if (anchor_point[i + j] > 0) { + has_nonzero = true; + break; + } + } + if (has_nonzero) { + found_pattern = true; + break; + } + } + EXPECT_TRUE(found_pattern); +} + +// Test create_anchor_point_kvcrush with ZEROS mode +TEST_F(KVCrushAlgorithmTest, CreateAnchorPointZerosTest) { + // Setup using helper function + auto anchor_point = setup_anchor_points(KVCrushAnchorPointMode::ZEROS, 16); + + // Verify - For ZEROS mode, we expect the anchor point to be all zeros + EXPECT_EQ(anchor_point.size(), m_block_size); + + // Check that all values are zeros + for (const auto& val : anchor_point) { + EXPECT_EQ(val, 0) << "Anchor point value should be zero in ZEROS mode"; + } +} + +// Test create_anchor_point_kvcrush with ONES mode +TEST_F(KVCrushAlgorithmTest, CreateAnchorPointOnesTest) { + // Setup using helper function + auto anchor_point = setup_anchor_points(KVCrushAnchorPointMode::ONES, 16); + + // Verify - For ONES mode, we expect the anchor point to be all ones + EXPECT_EQ(anchor_point.size(), m_block_size); + + // Check that all values are ones + for (const auto& val : anchor_point) { + EXPECT_EQ(val, 1) << "Anchor point value should be one in ONES mode"; + } +} + +// Test create_anchor_point_kvcrush with MEAN mode +TEST_F(KVCrushAlgorithmTest, CreateAnchorPointMeanTest) { + // Setup + m_kvcrush_config.anchor_point_mode = KVCrushAnchorPointMode::MEAN; + KVCrushAlgorithm algo(m_kvcrush_config, m_block_size); + + size_t num_tokens_in_evictable_blocks = 20; + // Create indicators with a mix of 0s and 1s to test mean calculation + std::vector indicators = {1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0}; + + // Execute + auto anchor_point = algo.create_anchor_point_kvcrush(num_tokens_in_evictable_blocks, indicators); + + // Verify + // For MEAN mode, we expect anchor points to be the mean of indicators in each block + size_t expected_size = num_tokens_in_evictable_blocks / m_block_size; + EXPECT_EQ(anchor_point.size(), m_block_size); + + // Expected means (rounded or converted to binary based on implementation) + if (anchor_point.size() == 4) { + // If using binary threshold (e.g., mean ≥ 0.5 -> 1, else 0) + EXPECT_EQ(anchor_point[0], 1) << "Block 1: 50% ones should round to 0"; + EXPECT_EQ(anchor_point[1], 0) << "Block 2: 50% ones should round to 0"; + EXPECT_EQ(anchor_point[2], 1) << "Block 3: 75% ones should round to 1"; + EXPECT_EQ(anchor_point[3], 0) << "Block 4: 25% ones should round to 0"; + } +} +// Test calculate_hamming_distance_kvcrush (parameterized) +TEST_F(KVCrushAlgorithmTest, CalculateHammingDistanceParameterizedTest) { + struct TestCase { + std::vector indicators; + std::vector anchor_point; + std::vector expected_distances; + }; + + std::vector test_cases = { + // Case 1: Alternating 1s and 0s, anchor_point {0, 1, 0, 1} + { + {1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0}, + {0, 1, 0, 1}, + {4, 4, 4} + }, + // Case 2: All indicators match anchor_point + { + {0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1}, + {0, 1, 0, 1}, + {0, 0, 0} + }, + // Case 3: All indicators are 1, anchor_point all 0 + { + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + {0, 0, 0, 0}, + {4, 4, 4} + }, + // Case 4: All indicators are 0, anchor_point all 1 + { + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {1, 1, 1, 1}, + {4, 4, 4} + }, + // Case 5: Mixed indicators, anchor_point {1, 0, 1, 0} + { + {1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0}, + {1, 0, 1, 0}, + {0, 4, 2} + } + }; + + size_t num_tokens_in_evictable_blocks = 12; + for (const auto& tc : test_cases) { + ASSERT_EQ(tc.indicators.size(), num_tokens_in_evictable_blocks); + ASSERT_EQ(tc.anchor_point.size(), m_block_size); + + // Create mutable copies since the function expects non-const references + std::vector indicators_copy = tc.indicators; + std::vector anchor_point_copy = tc.anchor_point; + + auto block_distances = + m_kvcrush_algo.calculate_hamming_distance_kvcrush(num_tokens_in_evictable_blocks, indicators_copy, anchor_point_copy); + + size_t num_blocks = num_tokens_in_evictable_blocks / m_block_size; + EXPECT_EQ(block_distances.size(), num_blocks); + + for (size_t i = 0; i < num_blocks; ++i) { + EXPECT_EQ(block_distances[i].first, tc.expected_distances[i]) + << "Block " << i << " failed for indicators: " + << ::testing::PrintToString(tc.indicators) + << " and anchor_point: " + << ::testing::PrintToString(tc.anchor_point); + } + } +} + +// Parameterized test for get_representative_blocks_kvcrush +TEST_F(KVCrushAlgorithmTest, GetRepresentativeBlocksParameterizedTest) { + struct TestCase { + std::vector> block_distances; + std::vector keep_clus_eligible; + size_t budget; + std::vector expected_blocks; + }; + + std::vector test_cases = { + // Case 1: Uniform distances, select 2 out of 4 + { + {{0, 0}, {1, 1}, {2, 2}, {3, 3}}, + {0, 1, 2, 3}, + 2, + {0, 2} + }, + // Case 2: All blocks eligible, budget equals number of blocks + { + {{0, 0}, {1, 1}, {2, 2}, {3, 3}}, + {0, 1, 2, 3}, + 4, + {0, 1, 2, 3} + }, + // Case 3: Only some blocks eligible + { + {{0, 0}, {1, 1}, {2, 2}, {3, 3}}, + {1, 3}, + 2, + {1, 3} + }, + // Case 4: Budget larger than eligible blocks + { + {{0, 0}, {1, 1}, {2, 2}, {3, 3}}, + {2, 3}, + 5, + {2, 3} + } + }; + + size_t num_tokens_in_evictable_blocks = 16; + for (const auto& tc : test_cases) { + m_kvcrush_config.budget = tc.budget; + KVCrushAlgorithm algo(m_kvcrush_config, m_block_size); + + // Create mutable copies since the function expects non-const references + std::vector> block_distances_copy = tc.block_distances; + + auto representative_blocks = + algo.get_representative_blocks_kvcrush(num_tokens_in_evictable_blocks, block_distances_copy, tc.keep_clus_eligible); + + EXPECT_EQ(representative_blocks.size(), tc.expected_blocks.size()) + << "Failed for budget " << tc.budget << " and eligible blocks " + << ::testing::PrintToString(tc.keep_clus_eligible); + + for (size_t i = 0; i < tc.expected_blocks.size(); ++i) { + EXPECT_EQ(representative_blocks[i], tc.expected_blocks[i]) + << "Block " << i << " failed for distances: " + << ::testing::PrintToString(tc.block_distances) + << " and eligible: " + << ::testing::PrintToString(tc.keep_clus_eligible); + } + } +} + +} // namespace ov::genai::tests \ No newline at end of file diff --git a/tests/python_tests/test_kv_cache_eviction.py b/tests/python_tests/test_kv_cache_eviction.py index ecd81ba4c2..1206c88da5 100644 --- a/tests/python_tests/test_kv_cache_eviction.py +++ b/tests/python_tests/test_kv_cache_eviction.py @@ -9,7 +9,7 @@ from typing import Optional from tqdm import tqdm -from openvino_genai import ContinuousBatchingPipeline, SchedulerConfig, GenerationConfig, CacheEvictionConfig, AggregationMode, SparseAttentionMode +from openvino_genai import ContinuousBatchingPipeline, SchedulerConfig, GenerationConfig, CacheEvictionConfig, AggregationMode, SparseAttentionMode, KVCrushAnchorPointMode, KVCrushConfig from utils.ov_genai_pipelines import PipelineType, generate_and_compare from utils.longbench import dataset2maxlen, evaluate, preprocess_prompt, post_process_pred @@ -49,6 +49,27 @@ class CacheOptTestStruct: SHORT_CACHE_EVICTION_CONFIG = CacheEvictionConfig(start_size=32, recent_size=32, max_cache_size=96, aggregation_mode=AggregationMode.NORM_SUM) LONGBENCH_CACHE_EVICTION_CONFIG = CacheEvictionConfig(start_size=32, recent_size=128, max_cache_size=672, aggregation_mode=AggregationMode.NORM_SUM) +# KVCrush test configurations +KVCRUSH_SNAPKV_BASELINE_CONFIG = CacheEvictionConfig( + start_size=32, + recent_size=128, + max_cache_size=512, + aggregation_mode=AggregationMode.NORM_SUM, + apply_rotation=False, + snapkv_window_size=8, + kvcrush_config=KVCrushConfig(budget=0) +) + +KVCRUSH_TEST_CONFIG = CacheEvictionConfig( + start_size=32, + recent_size=128, + max_cache_size=480, + aggregation_mode=AggregationMode.NORM_SUM, + apply_rotation=False, + snapkv_window_size=8, + kvcrush_config=KVCrushConfig(budget=1, anchor_point_mode=KVCrushAnchorPointMode.ALTERNATE) +) + @pytest.mark.precommit @pytest.mark.skipif( @@ -190,6 +211,11 @@ class LongBenchTestData: avg_cache_usage_optimization_ratio: float +@dataclass +class KVCrushTestData: + subset: str + + @pytest.mark.precommit @pytest.mark.parametrize("test_struct", [ LongBenchTestData("samsum", 4, 1.6, 2.5), @@ -266,3 +292,75 @@ def test_optimized_generation_longbench(test_struct): assert ref_score - score <= test_struct.threshold assert max_optimization_ratio >= test_struct.max_cache_usage_optimization_ratio assert avg_optimization_ratio >= test_struct.avg_cache_usage_optimization_ratio + + +@pytest.mark.nightly +@pytest.mark.parametrize("device", ["CPU"]) +@pytest.mark.parametrize("test_struct", [ + KVCrushTestData(subset="samsum"), + KVCrushTestData(subset="trec"), +], ids=["samsum", "trec"]) +def test_kvcrush_vs_snapkv_baseline(device, test_struct): + """Test that KVCrush performs equal or better than SnapKV baseline on LongBench datasets.""" + seqs_per_request = 32 + num_kv_blocks = 1000 if device == "CPU" else 500 + model_id = "Qwen/Qwen2-0.5B-Instruct" + _, _, models_path = download_and_convert_model(model_id) + + # Setup baseline and KVCrush configurations + scheduler_config_baseline = get_scheduler_config(num_kv_blocks) + scheduler_config_baseline.use_cache_eviction = True + scheduler_config_baseline.cache_eviction_config = KVCRUSH_SNAPKV_BASELINE_CONFIG + + scheduler_config_kvcrush = get_scheduler_config(num_kv_blocks) + scheduler_config_kvcrush.use_cache_eviction = True + scheduler_config_kvcrush.cache_eviction_config = KVCRUSH_TEST_CONFIG + + model_cb_baseline = ContinuousBatchingPipeline(models_path, scheduler_config_baseline, device, {}, get_default_llm_properties()) + model_cb_kvcrush = ContinuousBatchingPipeline(models_path, scheduler_config_kvcrush, device, {}, get_default_llm_properties()) + + model_name = "/".join(models_path.parts[-2:]) + subset = test_struct.subset + max_new_tokens = dataset2maxlen[subset] + + generation_config = GenerationConfig() + generation_config.num_return_sequences = 1 + generation_config.max_new_tokens = max_new_tokens + generation_config.apply_chat_template = False + + data = datasets.load_dataset('THUDM/LongBench', subset, split='test[:32]') + with tqdm(total=len(data)) as progress_bar: + batch = [] + baseline_answers = [] + kvcrush_answers = [] + for p_idx, data_sample in enumerate(data): + prompt = preprocess_prompt(data_sample, subset, model_name) + progress_bar.update(1) + batch.append(prompt) + baseline_answers.append({"answers": data_sample["answers"], "all_classes": data_sample["all_classes"]}) + kvcrush_answers.append({"answers": data_sample["answers"], "all_classes": data_sample["all_classes"]}) + + if len(batch) == seqs_per_request or p_idx == len(data) - 1: + baseline_batch = model_cb_baseline.generate( + batch, [generation_config] * len(batch) + ) + kvcrush_batch = model_cb_kvcrush.generate( + batch, [generation_config] * len(batch) + ) + for i, (baseline_output, kvcrush_output) in enumerate(zip(baseline_batch, kvcrush_batch), start=p_idx-len(batch)+1): + baseline_answers[i]["pred"] = post_process_pred(baseline_output.m_generation_ids[0], subset, model_name) + kvcrush_answers[i]["pred"] = post_process_pred(kvcrush_output.m_generation_ids[0], subset, model_name) + batch.clear() + + baseline_score = evaluate(baseline_answers, subset) + kvcrush_score = evaluate(kvcrush_answers, subset) + + print(f"Baseline (SnapKV) score: {baseline_score}") + print(f"KVCrush score: {kvcrush_score}") + + assert kvcrush_score >= baseline_score, f"KVCrush score ({kvcrush_score}) is worse than baseline ({baseline_score}) on {subset} dataset" + + del model_cb_baseline + del model_cb_kvcrush + import gc + gc.collect() diff --git a/tools/continuous_batching/benchmark/continuous_batching_benchmark.cpp b/tools/continuous_batching/benchmark/continuous_batching_benchmark.cpp index d7cad80fd0..afb6dd4265 100644 --- a/tools/continuous_batching/benchmark/continuous_batching_benchmark.cpp +++ b/tools/continuous_batching/benchmark/continuous_batching_benchmark.cpp @@ -488,7 +488,7 @@ int main(int argc, char* argv[]) try { scheduler_config.max_num_seqs = 256; // not used if dynamic_split_fuse=True if (use_cache_eviction) { scheduler_config.use_cache_eviction = true; - scheduler_config.cache_eviction_config = ov::genai::CacheEvictionConfig(32, 32, 128, ov::genai::AggregationMode::NORM_SUM); + scheduler_config.cache_eviction_config = ov::genai::CacheEvictionConfig(32, 32, 128, ov::genai::AggregationMode::NORM_SUM, false, 8, ov::genai::KVCrushConfig(0, ov::genai::KVCrushAnchorPointMode::MEAN)); } std::cout << "Benchmarking parameters: " << std::endl;