Skip to content

Commit 26b2cb6

Browse files
authored
Fix SnapKV scoring (#2591)
Adds an option to disable SnapKV entirely and fixes token occurrence counts in the SnapKV case. Also adds unit tests for the associated logic.
1 parent e725053 commit 26b2cb6

File tree

6 files changed

+280
-64
lines changed

6 files changed

+280
-64
lines changed

src/cpp/include/openvino/genai/cache_eviction.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ class CacheEvictionConfig {
3030
OPENVINO_ASSERT(start_size, "CacheEvictionConfig.start_size must be non-zero");
3131
OPENVINO_ASSERT(recent_size, "CacheEvictionConfig.recent_size must be non-zero");
3232
OPENVINO_ASSERT(max_cache_size, "CacheEvictionConfig.max_cache_size must be non-zero");
33-
OPENVINO_ASSERT(snapkv_window_size, "CacheEvictionConfig.snapkv_window_size must be non-zero");
3433

3534
OPENVINO_ASSERT(max_cache_size > (start_size + recent_size),
3635
"CacheEvictionConfig.max_cache_size must be larger than CacheEvictionConfig.start_size + CacheEvictionConfig.recent_size");
@@ -73,7 +72,8 @@ class CacheEvictionConfig {
7372

7473
/** The size of the importance score aggregation window (in token positions from the end of the prompt) for
7574
* computing initial importance scores at the beginning of the generation phase for purposes of eviction,
76-
* following the SnapKV article approach (https://arxiv.org/abs/2404.14469). **/
75+
* following the SnapKV article approach (https://arxiv.org/abs/2404.14469). Setting this to 0 disables the SnapKV
76+
* score aggregation. **/
7777
size_t snapkv_window_size = 8;
7878

7979
private:

src/cpp/src/continuous_batching/cache_eviction.cpp

Lines changed: 72 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,15 @@ namespace ov::genai {
5050

5151
void EvictionScoreManager::register_new_token_scores(
5252
const AttentionScoresForEachDecoderLayer &attention_scores_for_all_decoder_layers,
53-
const std::set<size_t>& skipped_logical_block_ids) {
53+
const std::set<size_t>& skipped_logical_block_ids, size_t num_snapkv_scores) {
54+
55+
if (m_num_registered_snapkv_aggregated_scores < m_snapkv_window_size) {
56+
OPENVINO_ASSERT(num_snapkv_scores + m_num_registered_snapkv_aggregated_scores <= m_snapkv_window_size, "Total number of aggregated SnapKV scores during prefill phase may not be larger than the configured SnapKV window size");
57+
m_num_registered_snapkv_aggregated_scores += num_snapkv_scores;
58+
}
59+
60+
// FIXME (vshampor): currently in terms of counters we do not discern between the cases when the last chunk has been prefill-only
61+
// or last-prefill-chunk-plus-one-generation_token
5462
for (size_t decoder_layer_idx = 0; decoder_layer_idx < m_num_decoder_layers; decoder_layer_idx++) {
5563

5664
const auto &attention_scores = attention_scores_for_all_decoder_layers[decoder_layer_idx];
@@ -105,9 +113,14 @@ namespace ov::genai {
105113
max_pooled_hh_scores[idx] = max_val;
106114
}
107115

108-
auto &accumulated_scores_for_current_decoder_layer = m_scores[decoder_layer_idx];
116+
auto& accumulated_scores_for_current_decoder_layer = m_scores[decoder_layer_idx];
109117

110118
if (accumulated_scores_for_current_decoder_layer.empty()) {
119+
if (m_snapkv_window_size != 0 && num_snapkv_scores == 0) {
120+
// SnapKV window not yet reached, no meaningful scores to accumulate
121+
continue;
122+
}
123+
// New sequence to track
111124
if (skipped_logical_block_ids.empty()) {
112125
accumulated_scores_for_current_decoder_layer = max_pooled_hh_scores;
113126
}
@@ -127,12 +140,20 @@ namespace ov::genai {
127140
}
128141

129142
if (m_aggregation_mode == AggregationMode::NORM_SUM) {
130-
// New sequence to track - will simulate that the tokens comprising the sequence were added one-by-one
131-
// from the standpoint of the occurrence tracker
132143
std::size_t new_scores_size = num_hh_scores;
133144
std::vector<std::size_t> counter(new_scores_size);
134-
std::generate(counter.begin(), counter.begin() + new_scores_size,
135-
[&new_scores_size] { return new_scores_size--; });
145+
if (m_snapkv_window_size == 0) {
146+
// Will simulate that the tokens comprising the sequence were added one-by-one
147+
// from the standpoint of the occurrence tracker
148+
std::generate(counter.begin(), counter.begin() + new_scores_size,
149+
[&new_scores_size] { return new_scores_size--; });
150+
}
151+
else {
152+
OPENVINO_ASSERT(num_snapkv_scores > 0);
153+
OPENVINO_ASSERT(new_scores_size >= num_snapkv_scores);
154+
std::fill(counter.begin(), counter.end() - num_snapkv_scores, num_snapkv_scores);
155+
std::iota(counter.rbegin(), counter.rbegin() + num_snapkv_scores, 1);
156+
}
136157
m_cache_counter[decoder_layer_idx] = counter;
137158
}
138159
} else {
@@ -142,18 +163,26 @@ namespace ov::genai {
142163
OPENVINO_ASSERT(new_size_in_tokens >= old_size_in_tokens);
143164
size_t num_new_tokens = new_size_in_tokens - old_size_in_tokens;
144165
if (m_aggregation_mode == AggregationMode::NORM_SUM) {
145-
// Increment occurrence counts of all currently tracked cache blocks
146166
auto &counter_for_current_decoder_layer = m_cache_counter[decoder_layer_idx];
147-
for (auto it = counter_for_current_decoder_layer.begin();
148-
it != counter_for_current_decoder_layer.end(); it++) {
149-
*it += num_new_tokens;
150-
}
151-
// Add occurrence counts for new tokens like above
152167
counter_for_current_decoder_layer.resize(new_size_in_tokens);
153-
for (size_t i = 0; i < num_new_tokens; i++) {
154-
auto idx = old_size_in_tokens + i;
155-
counter_for_current_decoder_layer[idx] = num_new_tokens - i;
168+
if (m_snapkv_window_size == 0 || m_num_registered_snapkv_aggregated_scores == m_snapkv_window_size) {
169+
// Increment occurrence counts of all currently tracked cache blocks
170+
for (auto it = counter_for_current_decoder_layer.begin();
171+
it != counter_for_current_decoder_layer.end(); it++) {
172+
*it += num_new_tokens;
173+
}
174+
// Add occurrence counts for new tokens like above
175+
for (size_t i = 0; i < num_new_tokens; i++) {
176+
auto idx = old_size_in_tokens + i;
177+
counter_for_current_decoder_layer[idx] = num_new_tokens - i;
178+
}
156179
}
180+
else {
181+
OPENVINO_ASSERT(new_size_in_tokens >= m_num_registered_snapkv_aggregated_scores);
182+
std::fill(counter_for_current_decoder_layer.begin(), counter_for_current_decoder_layer.end() - m_num_registered_snapkv_aggregated_scores, m_num_registered_snapkv_aggregated_scores);
183+
std::iota(counter_for_current_decoder_layer.rbegin(), counter_for_current_decoder_layer.rbegin() + m_num_registered_snapkv_aggregated_scores, 1);
184+
}
185+
157186
}
158187
accumulated_scores_for_current_decoder_layer.resize(new_size_in_tokens);
159188
add_with_skips(accumulated_scores_for_current_decoder_layer, max_pooled_hh_scores, skip_set_adjusted);
@@ -191,7 +220,7 @@ namespace ov::genai {
191220
CacheEvictionAlgorithm::CacheEvictionAlgorithm(const CacheEvictionConfig &eviction_config, size_t block_size,
192221
size_t num_decoder_layers, size_t max_pool_window_size) :
193222
m_eviction_config(eviction_config), m_block_size(block_size), m_num_decoder_layers(num_decoder_layers),
194-
m_score_manager(block_size, num_decoder_layers, max_pool_window_size, eviction_config.aggregation_mode, eviction_config.get_start_size() / block_size)
223+
m_score_manager(block_size, num_decoder_layers, max_pool_window_size, eviction_config.aggregation_mode, eviction_config.get_start_size() / block_size, eviction_config.snapkv_window_size)
195224
{
196225
OPENVINO_ASSERT(!(m_eviction_config.get_start_size() % m_block_size),
197226
"CacheEvictionConfig.start_size in tokens must be a multiple of block size ", m_block_size);
@@ -263,14 +292,15 @@ namespace ov::genai {
263292
}
264293

265294
void CacheEvictionAlgorithm::register_new_token_scores(
266-
const AttentionScoresForEachDecoderLayer &attention_scores_for_all_decoder_layers) {
267-
register_new_token_scores(attention_scores_for_all_decoder_layers, {});
295+
const AttentionScoresForEachDecoderLayer &attention_scores_for_all_decoder_layers, size_t num_snapkv_scores_aggregated) {
296+
register_new_token_scores(attention_scores_for_all_decoder_layers, {}, num_snapkv_scores_aggregated);
268297
}
269298

270299
void CacheEvictionAlgorithm::register_new_token_scores(
271300
const AttentionScoresForEachDecoderLayer &attention_scores_for_all_decoder_layers,
272-
const std::set<size_t>& skipped_logical_block_ids) {
273-
m_score_manager.register_new_token_scores(attention_scores_for_all_decoder_layers, skipped_logical_block_ids);
301+
const std::set<size_t>& skipped_logical_block_ids,
302+
size_t num_snapkv_scores_aggregated) {
303+
m_score_manager.register_new_token_scores(attention_scores_for_all_decoder_layers, skipped_logical_block_ids, num_snapkv_scores_aggregated);
274304
}
275305

276306

@@ -457,4 +487,26 @@ namespace ov::genai {
457487
return retval;
458488
}
459489

490+
size_t SnapKVScoreAggregationCalculator::get_num_token_scores_to_aggregate(size_t prompt_len, size_t num_scheduled_tokens, size_t num_processed_tokens) {
491+
if (m_snapkv_window_size == 0) {
492+
// If SnapKV is disabled, aggregate all available scores in this chunk
493+
return num_scheduled_tokens;
494+
}
495+
size_t first_scored_token_position = m_snapkv_window_size > prompt_len ? 0 : prompt_len - m_snapkv_window_size;
496+
size_t num_scored_token_positions_in_this_chunk = 0;
497+
size_t num_processed_tokens_before_this_chunk = num_processed_tokens;
498+
size_t num_processed_tokens_after_this_chunk = num_processed_tokens_before_this_chunk + num_scheduled_tokens;
499+
if (num_processed_tokens_after_this_chunk > first_scored_token_position) {
500+
if (num_processed_tokens_before_this_chunk > first_scored_token_position) {
501+
num_scored_token_positions_in_this_chunk = num_scheduled_tokens;
502+
}
503+
else {
504+
num_scored_token_positions_in_this_chunk = num_processed_tokens_after_this_chunk - first_scored_token_position;
505+
}
506+
507+
}
508+
return num_scored_token_positions_in_this_chunk;
509+
}
510+
460511
}
512+

src/cpp/src/continuous_batching/cache_eviction.hpp

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,17 @@ class EvictionScoreManager {
2929
* @param num_decoder_layers Number of independent KV caches (each corresponding to a single attention layer) in the underlying LLM.
3030
* @param max_pool_window_size Window size for the max pooling step applied to the newly registered scores before aggregation.
3131
* @param aggregation_mode Aggregation mode for the scores across register calls.
32-
* @param ignore_first_n_blocks Number of blocks from the beginning of the per-token score vector, the scores for which will be disregarded and never aggregated.
32+
* @param ignore_first_n_blocks Number of blocks from the beginning of the per-token score vector, the scores for which will
33+
* be disregarded and never aggregated.
34+
* @param snapkv_window_size Window size for the SnapKV algorithm in effect. If non-zero, then by the start of the generation phase
35+
* for the tracked sequence (when the total number of `num_snapkv_scores` passed to each `register_new_token_scores` call reaches
36+
* the `snapkv_window_size`) the internal occurrence counters will be:
37+
* `| S | S | ... | S | S - 1 | S - 2 | ... | 2 | 1 |`,
38+
* where `S` is equal to `snapkv_window_size`. In contrast, if this is set to 0, then the initial counter state would be
39+
* `| L | L - 1 | ... | 2 | 1 |`,
40+
* where L is the prompt size of the sequence in tokens.
3341
*/
34-
explicit EvictionScoreManager(size_t block_size, size_t num_decoder_layers, size_t max_pool_window_size, AggregationMode aggregation_mode, size_t ignore_first_n_blocks = 0) : m_block_size(block_size), m_num_decoder_layers(num_decoder_layers), m_scores(num_decoder_layers), m_cache_counter(num_decoder_layers), m_max_pool_window_size(max_pool_window_size), m_aggregation_mode(aggregation_mode), m_ignore_first_n_blocks(ignore_first_n_blocks) {}
42+
explicit EvictionScoreManager(size_t block_size, size_t num_decoder_layers, size_t max_pool_window_size, AggregationMode aggregation_mode, size_t ignore_first_n_blocks = 0, size_t snapkv_window_size = 0) : m_block_size(block_size), m_num_decoder_layers(num_decoder_layers), m_scores(num_decoder_layers), m_cache_counter(num_decoder_layers), m_max_pool_window_size(max_pool_window_size), m_aggregation_mode(aggregation_mode), m_ignore_first_n_blocks(ignore_first_n_blocks), m_snapkv_window_size(snapkv_window_size), m_num_registered_snapkv_aggregated_scores(0) {}
3543

3644
/**
3745
* Registers new token scores and aggregates them internally as necessary. The token scores provided may be corresponding not to all
@@ -42,8 +50,9 @@ class EvictionScoreManager {
4250
* scores in a corresponding decoder layer.
4351
* @param skipped_logical_block_ids Logical block indices which had been skipped during inference call that produced the new scores, and
4452
* which are missing from the new scores.
53+
* @param num_snapkv_scores Number of latest token scores that were aggregated together when computing the registered score. If SnapKV is not used, this should be set to 0.
4554
*/
46-
void register_new_token_scores(const AttentionScoresForEachDecoderLayer& attention_scores_for_all_decoder_layers, const std::set<size_t>& skipped_logical_block_ids);
55+
void register_new_token_scores(const AttentionScoresForEachDecoderLayer& attention_scores_for_all_decoder_layers, const std::set<size_t>& skipped_logical_block_ids, size_t num_snapkv_scores = 0);
4756

4857
/**
4958
* Removes the scores from tracking for given block indices and given decoder layer.
@@ -88,6 +97,22 @@ class EvictionScoreManager {
8897
std::size_t m_max_pool_window_size;
8998
AggregationMode m_aggregation_mode;
9099
std::size_t m_ignore_first_n_blocks;
100+
std::size_t m_snapkv_window_size;
101+
std::size_t m_num_registered_snapkv_aggregated_scores;
102+
};
103+
104+
class SnapKVScoreAggregationCalculator {
105+
public:
106+
SnapKVScoreAggregationCalculator() = default;
107+
SnapKVScoreAggregationCalculator(const SnapKVScoreAggregationCalculator& rhs) = default;
108+
SnapKVScoreAggregationCalculator& operator=(const SnapKVScoreAggregationCalculator& rhs) = default;
109+
SnapKVScoreAggregationCalculator(size_t snapkv_window_size) : m_snapkv_window_size(snapkv_window_size) {}
110+
111+
size_t get_num_token_scores_to_aggregate(size_t prompt_len, size_t num_scheduled_tokens, size_t num_processed_tokens);
112+
113+
private:
114+
size_t m_snapkv_window_size;
115+
91116
};
92117

93118
/**
@@ -160,10 +185,12 @@ class CacheEvictionAlgorithm {
160185
* @param attention_scores_for_all_decoder_layers A vector with a size equal to the configured num_decoder_layers, where each entry is a
161186
* vector of per-token attention scores calculated within this layer.
162187
* @param skipped_logical_block_ids The set of logical indices that have been skipped from the scores as part of the sparse attention prefill process
188+
* @param num_snapkv_scores The number of SnapKV-aggregated scores in this score chunk. Set to 0 if SnapKV is not used
189+
* (i.e. eviction_config.snapkv_window_size == 0)
163190
*/
164-
void register_new_token_scores(const AttentionScoresForEachDecoderLayer& attention_scores_for_all_decoder_layers, const std::set<size_t>& skipped_logical_block_ids);
191+
void register_new_token_scores(const AttentionScoresForEachDecoderLayer& attention_scores_for_all_decoder_layers, const std::set<size_t>& skipped_logical_block_ids, size_t num_snapkv_scores = 0);
165192

166-
void register_new_token_scores(const AttentionScoresForEachDecoderLayer& attention_scores_across_decoder_layers_for_current_sequence);
193+
void register_new_token_scores(const AttentionScoresForEachDecoderLayer& attention_scores_across_decoder_layers_for_current_sequence, size_t num_snapkv_scores = 0);
167194
/**
168195
* Returns the per-layer sets of logical block indices that should be evicted according to the internally computed importance scores
169196
* and removes the corresponding blocks from the internal algorithm tracking.

src/cpp/src/continuous_batching/pipeline_impl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -673,7 +673,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_maybe_evict_cache_bloc
673673

674674
if (skip_set.empty()) {
675675
// For now, will only register token scores from the dense attention stages
676-
cache_eviction_algo.register_new_token_scores(attention_scores_for_all_decoder_layers, skip_set);
676+
cache_eviction_algo.register_new_token_scores(attention_scores_for_all_decoder_layers, skip_set, scheduler_output.m_score_aggregation_windows.at(seq_id));
677677
}
678678

679679
auto seq_group_ptr_it = std::find_if(m_requests.begin(), m_requests.end(), [seq_id](const SequenceGroup::Ptr& val) { return val->has_sequence_with_id(seq_id); });

src/cpp/src/continuous_batching/scheduler.hpp

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "continuous_batching/timer.hpp"
1616
#include "continuous_batching/sparse_attention.hpp"
1717
#include "utils.hpp"
18+
#include "continuous_batching/cache_eviction.hpp"
1819

1920
namespace ov::genai {
2021
class Scheduler {
@@ -394,10 +395,7 @@ class Scheduler {
394395
// block tables for each running sequence within a group
395396
scheduler_output.m_block_tables[seq_id] = m_block_manager->get_block_tables(seq_id);
396397

397-
if (seq->get_generated_len() == 0) {
398-
// full prompt or its remaining tail part fit completely into the next inference
399-
scheduler_output.m_score_aggregation_windows[seq_id] = _schedule_scores_to_aggregate(sequence_group);
400-
}
398+
scheduler_output.m_score_aggregation_windows[seq_id] = _schedule_scores_to_aggregate(sequence_group);
401399
scheduler_output.m_xattention_block_size = m_config.sparse_attention_config.xattention_block_size;
402400
scheduler_output.m_xattention_stride = m_config.sparse_attention_config.xattention_stride;
403401
}
@@ -585,22 +583,13 @@ class Scheduler {
585583
}
586584

587585
size_t _schedule_scores_to_aggregate(SequenceGroup::Ptr sequence_group) {
588-
size_t prompt_len = sequence_group->get_prompt_len();
589-
size_t first_scored_token_position = m_snapkv_window_size > prompt_len ? 0 : prompt_len - m_snapkv_window_size;
590-
size_t num_scored_token_positions_in_this_chunk = 0;
591-
size_t num_processed_tokens_before_this_chunk = sequence_group->get_num_processed_tokens();
586+
auto calculator = SnapKVScoreAggregationCalculator(m_snapkv_window_size);
587+
592588
size_t num_scheduled_tokens = sequence_group->get_num_scheduled_tokens();
593-
size_t num_processed_tokens_after_this_chunk = num_processed_tokens_before_this_chunk + num_scheduled_tokens;
594-
if (num_processed_tokens_after_this_chunk > first_scored_token_position) {
595-
if (num_processed_tokens_before_this_chunk > first_scored_token_position) {
596-
num_scored_token_positions_in_this_chunk = num_scheduled_tokens;
597-
}
598-
else {
599-
num_scored_token_positions_in_this_chunk = num_processed_tokens_after_this_chunk - first_scored_token_position;
600-
}
589+
size_t num_processed_tokens = sequence_group->get_num_processed_tokens();
590+
size_t prompt_len = sequence_group->get_prompt_len();
601591

602-
}
603-
return num_scored_token_positions_in_this_chunk;
592+
return calculator.get_num_token_scores_to_aggregate(prompt_len, num_scheduled_tokens, num_processed_tokens);
604593
}
605594

606595
float _schedule_xattention_threshold(SequenceGroup::Ptr sequence_group) {

0 commit comments

Comments
 (0)