
Commit b248532

Align PromptLookupDecoding with greedy when dynamic_split_fuse works (#2360)
[CVS-169291](https://jira.devtools.intel.com/browse/CVS-169291) When `dynamic_split_fuse` is true and `max_num_batched_tokens` is less than the prompt length, the PromptLookupDecoding pipeline starts generating candidates before the full prompt has been put into the KV cache, which makes its output diverge from the greedy pipeline. This change avoids that.
1 parent cb3f1de commit b248532
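
The scenario described in the commit message can also be reproduced outside the test suite. Below is a minimal sketch (not part of the commit) that compares plain greedy decoding against prompt lookup decoding under a scheduler that chunks the prompt, using the public `openvino_genai` Python API as shown in the GenAI prompt-lookup sample; the model directory path is a placeholder, and the exact knob values (`max_num_batched_tokens = 5`, `num_assistant_tokens`, `max_ngram_size`) are illustrative assumptions.

```python
# Sketch only: assumes `models_path` points to a locally exported
# TinyLlama-1.1B-Chat-v1.0 directory (e.g. via optimum-cli) and a CPU device.
import openvino_genai as ov_genai

models_path = "TinyLlama-1.1B-Chat-v1.0"  # placeholder path

# Scheduler that forces chunked prefill: dynamic_split_fuse enabled and
# max_num_batched_tokens smaller than the prompt length.
scheduler_config = ov_genai.SchedulerConfig()
scheduler_config.dynamic_split_fuse = True
scheduler_config.max_num_batched_tokens = 5

prompt = "Why is the Sun yellow?"

# Greedy reference without prompt lookup.
greedy_pipe = ov_genai.LLMPipeline(models_path, "CPU", scheduler_config=scheduler_config)
reference = greedy_pipe.generate(prompt, max_new_tokens=20, do_sample=False)

# Prompt lookup decoding with the same scheduler config; with this fix the
# generated text should match the greedy reference.
pl_config = ov_genai.GenerationConfig()
pl_config.max_new_tokens = 20
pl_config.num_assistant_tokens = 5  # enables candidate generation
pl_config.max_ngram_size = 3
pl_pipe = ov_genai.LLMPipeline(models_path, "CPU", scheduler_config=scheduler_config, prompt_lookup=True)
candidate = pl_pipe.generate(prompt, pl_config)

assert str(candidate) == str(reference)
```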

File tree

4 files changed: +48, -15 lines


src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.cpp

Lines changed: 3 additions & 0 deletions
@@ -60,6 +60,9 @@ void ContinuousBatchingPipeline::ContinuousBatchingForPromptLookupImpl::generate
     size_t max_validation_len = 0;
     for (auto& running_sequence : request->get_running_sequences()) {
         const auto generated_tokens = running_sequence->get_generated_ids();
+        if (generated_tokens.empty()) {
+            continue;
+        }
         TokenIds full_input_ids = prompt;
         full_input_ids.insert(full_input_ids.end(), generated_tokens.begin(), generated_tokens.end());
6568

src/cpp/src/sampling/sampler.cpp

Lines changed: 17 additions & 15 deletions
@@ -764,13 +764,14 @@ SequenceGroupSamplingInfo Sampler::sample_from_sequence_group(SequenceGroup::Ptr
     const size_t output_seq_len = sequence_group->get_output_seq_len();
     // get number of tokens to be validated
     size_t num_tokens_to_process = sequence_group->get_num_tokens_to_validate();
+    size_t num_generated_tokens_to_validate = num_tokens_to_process;
 
     if (num_tokens_to_process > output_seq_len - 1) {
         auto delta = num_tokens_to_process - (output_seq_len - 1);
         assisting_pipeline_info.updated_validation_len = std::max(assisting_pipeline_info.updated_validation_len, delta);
         num_tokens_to_process -= delta;
     }
-
+
     if (sampling_params.is_greedy_decoding() || sampling_params.is_multinomial()) {
         std::vector<Sequence::Ptr> running_sequences = sequence_group->get_running_sequences();
         size_t num_running_sequences = sequence_group->num_running_seqs();
@@ -786,25 +787,25 @@ SequenceGroupSamplingInfo Sampler::sample_from_sequence_group(SequenceGroup::Ptr
                     break;
                 sg_sampling_info.sampler_output.num_generated_tokens++;
                 // calculate token offset from the end of logit
-                size_t token_offset = num_tokens_to_process - i;
+                size_t logit_token_offset = num_tokens_to_process - i;
+                size_t generated_seq_token_offset = num_generated_tokens_to_validate - i;
                 // max counter of needed to be sampled tokens
-                OPENVINO_ASSERT(running_sequence->get_generated_len() >= token_offset);
-                size_t generated_and_verified_len = running_sequence->get_generated_len() - token_offset;
+                OPENVINO_ASSERT(running_sequence->get_generated_len() >= generated_seq_token_offset);
+                size_t generated_and_verified_len = running_sequence->get_generated_len() - generated_seq_token_offset;
                 OPENVINO_ASSERT(sequence_group->get_max_new_tokens() >= generated_and_verified_len);
                 size_t max_num_sampled_token = sequence_group->get_max_new_tokens() - generated_and_verified_len;
                 if (max_num_sampled_token == 0) {
-                    stop_sample_tokens(running_sequence, token_offset, max_num_sampled_token, assisting_pipeline_info.max_removed_tokens_per_request);
+                    stop_sample_tokens(running_sequence, generated_seq_token_offset, max_num_sampled_token, assisting_pipeline_info.max_removed_tokens_per_request);
                     break;
                 }
-
                 // do sampling only for token validation/generation.
                 // continue in case of extending draft model sequences by main model generated tokens which
                 // should be taken to KV cache without validation
-                if (!is_validation_mode_enabled && token_offset > 0) {
+                if (!is_validation_mode_enabled && generated_seq_token_offset > 0) {
                     continue;
                 }
 
-                auto logit_vector = _get_logit_vector(sequence_group_logits, running_sequence_id, token_offset);
+                auto logit_vector = _get_logit_vector(sequence_group_logits, running_sequence_id, logit_token_offset);
                 logit_processor.apply(logit_vector);
 
                 Token sampled_token;
@@ -826,8 +827,8 @@ SequenceGroupSamplingInfo Sampler::sample_from_sequence_group(SequenceGroup::Ptr
                     sampled_token = sampled_token_ids.front();
                     // make `_speculative_sampling` in case of previous token was not accepted in speculative decoding
                     if (!is_validation_passed) {
-                        float p_prime = get_p_prime(running_sequence, sampled_token, token_offset + 1);
-                        assisting_pipeline_info.max_removed_tokens_per_request = std::max(assisting_pipeline_info.max_removed_tokens_per_request, token_offset);
+                        float p_prime = get_p_prime(running_sequence, sampled_token, generated_seq_token_offset + 1);
+                        assisting_pipeline_info.max_removed_tokens_per_request = std::max(assisting_pipeline_info.max_removed_tokens_per_request, generated_seq_token_offset);
                         // update prob only in case candidate prob > sampled token prob
                         if (p_prime > 0.f) {
                             auto prob = std::exp(sampled_token.m_log_prob);
@@ -837,12 +838,13 @@ SequenceGroupSamplingInfo Sampler::sample_from_sequence_group(SequenceGroup::Ptr
                     }
                 }
                 // flag to add sampled token to generated sequence or extend logit processors only
-                bool is_extend_sequence = token_offset == 0 || is_generate_n_tokens || !is_validation_passed;
+                bool is_extend_sequence = logit_token_offset == 0 || is_generate_n_tokens || !is_validation_passed;
                 if (is_validation_mode_enabled && !is_extend_sequence) {
-                    is_validation_passed = validate_candidate(running_sequences[running_sequence_id], token_offset, sampled_token,
-                                                              is_extend_sequence, assisting_pipeline_info.max_removed_tokens_per_request,
+                    is_validation_passed = validate_candidate(running_sequences[running_sequence_id], generated_seq_token_offset,
+                                                              sampled_token, is_extend_sequence, assisting_pipeline_info.max_removed_tokens_per_request,
                                                               sampling_params.do_sample, !sampling_params.is_prompt_lookup());
-                    // doing resample in case of non accepted tokens in specualtive sampling, if candidates have real logits
+
+                    // doing resample in case of non accepted tokens in speculative sampling
                     if (!is_validation_passed && sampling_params.do_sample && !sampling_params.is_prompt_lookup()) {
                         continue;
                     }
@@ -897,7 +899,7 @@ SequenceGroupSamplingInfo Sampler::sample_from_sequence_group(SequenceGroup::Ptr
     }
     // Notify handle after sampling is done.
     // For non-streaming this is effective only when the generation is finished.
-    OPENVINO_ASSERT(num_tokens_to_process >= assisting_pipeline_info.max_removed_tokens_per_request);
+    OPENVINO_ASSERT(num_generated_tokens_to_validate >= assisting_pipeline_info.max_removed_tokens_per_request);
     sequence_group->notify_handle();
     return sg_sampling_info;
 }
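
To see why the single `token_offset` in `sampler.cpp` had to be split into `logit_token_offset` and `generated_seq_token_offset`: when the prompt is still being prefilled in chunks, `output_seq_len` can be smaller than the number of scheduled candidate tokens, so `num_tokens_to_process` is clipped by `delta` while the generated sequence still holds the full candidate count. Offsets into the logits tensor and offsets into the generated sequence then differ by `delta`. A small illustration with hypothetical numbers, mirroring the arithmetic in the diff above:

```python
# Hypothetical numbers only, to trace the two offsets from the sampler.cpp change.
num_tokens_to_validate = 4   # candidate tokens scheduled for validation
output_seq_len = 3           # logits actually produced in this step (prompt still chunked)

num_generated_tokens_to_validate = num_tokens_to_validate
num_tokens_to_process = num_tokens_to_validate
if num_tokens_to_process > output_seq_len - 1:
    delta = num_tokens_to_process - (output_seq_len - 1)  # 2
    num_tokens_to_process -= delta                        # 2, valid offset range into logits

for i in range(num_tokens_to_process + 1):
    logit_token_offset = num_tokens_to_process - i                      # offset from the end of the logits
    generated_seq_token_offset = num_generated_tokens_to_validate - i   # offset within the generated sequence
    print(i, logit_token_offset, generated_seq_token_offset)
# i=0 -> (2, 4), i=1 -> (1, 3), i=2 -> (0, 2): the two offsets differ by delta.
```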

src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp

Lines changed: 3 additions & 0 deletions
@@ -74,6 +74,9 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(con
 
         main_scheduler_config_updated.cache_size = main_cache_size;
         draft_scheduler_config.cache_size = draft_cache_size;
+    } else {
+        draft_scheduler_config.dynamic_split_fuse = main_scheduler_config_updated.dynamic_split_fuse;
+        draft_scheduler_config.max_num_batched_tokens = main_scheduler_config_updated.max_num_batched_tokens;
     }
 
     ov::AnyMap draft_properties = draft_model_desc.properties.empty() ? main_model_desc.properties : draft_model_desc.properties;

tests/python_tests/test_continuous_batching.py

Lines changed: 25 additions & 0 deletions
@@ -4,6 +4,7 @@
 import os
 import pytest
 import math
+import sys
 
 from pathlib import Path
 from shutil import rmtree
@@ -414,6 +415,30 @@ def test_preemption_with_multinomial_n_seq(dynamic_split_fuse):
                                      scheduler_config=scheduler_config)
 
 
+@pytest.mark.parametrize("pipeline_type", [PipelineType.PROMPT_LOOKUP_DECODING])
+@pytest.mark.precommit
+def test_dynamic_split_fuse_doesnt_affect_generated_text(pipeline_type):
+    model_id : str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+    _, _, models_path = download_and_convert_model(model_id)
+
+    scheduler_config_ref = dict_to_scheduler_config({"dynamic_split_fuse": False, "max_num_batched_tokens": sys.maxsize})
+    cb_pipe_ref = create_ov_pipeline(models_path, scheduler_config=scheduler_config_ref, pipeline_type=pipeline_type)
+
+    scheduler_config_target = dict_to_scheduler_config({"dynamic_split_fuse": True, "max_num_batched_tokens": 5})
+    cb_pipe_target = create_ov_pipeline(models_path, scheduler_config=scheduler_config_target, pipeline_type=pipeline_type)
+
+    generation_config = GenerationConfig(do_sample=False, max_new_tokens=20, eos_token_id=cb_pipe_ref.get_tokenizer().get_eos_token_id())
+
+    generation_config = prepare_generation_config_by_pipe_type(generation_config=generation_config, pipeline_type=pipeline_type)
+    cb_pipe_ref.set_generation_config(generation_config)
+    cb_pipe_target.set_generation_config(generation_config)
+
+    question = "Why is the Sun yellow?"
+    reference = cb_pipe_ref.generate(question, generation_config=generation_config)
+    generated = cb_pipe_target.generate(question, generation_config=generation_config)
+    assert generated == reference
+
+
 def get_data_by_pipeline_type(model_path: Path, pipeline_type: str, generation_config: GenerationConfig):
     device = "CPU"
     prompt = "Prompt example is"
