@@ -764,13 +764,14 @@ SequenceGroupSamplingInfo Sampler::sample_from_sequence_group(SequenceGroup::Ptr
764
764
const size_t output_seq_len = sequence_group->get_output_seq_len ();
765
765
// get number of tokens to be validated
766
766
size_t num_tokens_to_process = sequence_group->get_num_tokens_to_validate ();
767
+ size_t num_generated_tokens_to_validate = num_tokens_to_process;
767
768
768
769
if (num_tokens_to_process > output_seq_len - 1 ) {
769
770
auto delta = num_tokens_to_process - (output_seq_len - 1 );
770
771
assisting_pipeline_info.updated_validation_len = std::max (assisting_pipeline_info.updated_validation_len , delta);
771
772
num_tokens_to_process -= delta;
772
773
}
773
-
774
+
774
775
if (sampling_params.is_greedy_decoding () || sampling_params.is_multinomial ()) {
775
776
std::vector<Sequence::Ptr> running_sequences = sequence_group->get_running_sequences ();
776
777
size_t num_running_sequences = sequence_group->num_running_seqs ();
@@ -786,25 +787,25 @@ SequenceGroupSamplingInfo Sampler::sample_from_sequence_group(SequenceGroup::Ptr
786
787
break ;
787
788
sg_sampling_info.sampler_output .num_generated_tokens ++;
788
789
// calculate token offset from the end of logit
789
- size_t token_offset = num_tokens_to_process - i;
790
+ size_t logit_token_offset = num_tokens_to_process - i;
791
+ size_t generated_seq_token_offset = num_generated_tokens_to_validate - i;
790
792
// max counter of needed to be sampled tokens
791
- OPENVINO_ASSERT (running_sequence->get_generated_len () >= token_offset );
792
- size_t generated_and_verified_len = running_sequence->get_generated_len () - token_offset ;
793
+ OPENVINO_ASSERT (running_sequence->get_generated_len () >= generated_seq_token_offset );
794
+ size_t generated_and_verified_len = running_sequence->get_generated_len () - generated_seq_token_offset ;
793
795
OPENVINO_ASSERT (sequence_group->get_max_new_tokens () >= generated_and_verified_len);
794
796
size_t max_num_sampled_token = sequence_group->get_max_new_tokens () - generated_and_verified_len;
795
797
if (max_num_sampled_token == 0 ) {
796
- stop_sample_tokens (running_sequence, token_offset , max_num_sampled_token, assisting_pipeline_info.max_removed_tokens_per_request );
798
+ stop_sample_tokens (running_sequence, generated_seq_token_offset , max_num_sampled_token, assisting_pipeline_info.max_removed_tokens_per_request );
797
799
break ;
798
800
}
799
-
800
801
// do sampling only for token validation/generation.
801
802
// continue in case of extending draft model sequences by main model generated tokens which
802
803
// should be taken to KV cache without validation
803
- if (!is_validation_mode_enabled && token_offset > 0 ) {
804
+ if (!is_validation_mode_enabled && generated_seq_token_offset > 0 ) {
804
805
continue ;
805
806
}
806
807
807
- auto logit_vector = _get_logit_vector (sequence_group_logits, running_sequence_id, token_offset );
808
+ auto logit_vector = _get_logit_vector (sequence_group_logits, running_sequence_id, logit_token_offset );
808
809
logit_processor.apply (logit_vector);
809
810
810
811
Token sampled_token;
@@ -826,8 +827,8 @@ SequenceGroupSamplingInfo Sampler::sample_from_sequence_group(SequenceGroup::Ptr
826
827
sampled_token = sampled_token_ids.front ();
827
828
// make `_speculative_sampling` in case of previous token was not accepted in speculative decoding
828
829
if (!is_validation_passed) {
829
- float p_prime = get_p_prime (running_sequence, sampled_token, token_offset + 1 );
830
- assisting_pipeline_info.max_removed_tokens_per_request = std::max (assisting_pipeline_info.max_removed_tokens_per_request , token_offset );
830
+ float p_prime = get_p_prime (running_sequence, sampled_token, generated_seq_token_offset + 1 );
831
+ assisting_pipeline_info.max_removed_tokens_per_request = std::max (assisting_pipeline_info.max_removed_tokens_per_request , generated_seq_token_offset );
831
832
// update prob only in case candidate prob > sampled token prob
832
833
if (p_prime > 0 .f ) {
833
834
auto prob = std::exp (sampled_token.m_log_prob );
@@ -837,12 +838,13 @@ SequenceGroupSamplingInfo Sampler::sample_from_sequence_group(SequenceGroup::Ptr
837
838
}
838
839
}
839
840
// flag to add sampled token to generated sequence or extend logit processors only
840
- bool is_extend_sequence = token_offset == 0 || is_generate_n_tokens || !is_validation_passed;
841
+ bool is_extend_sequence = logit_token_offset == 0 || is_generate_n_tokens || !is_validation_passed;
841
842
if (is_validation_mode_enabled && !is_extend_sequence) {
842
- is_validation_passed = validate_candidate (running_sequences[running_sequence_id], token_offset, sampled_token ,
843
- is_extend_sequence, assisting_pipeline_info.max_removed_tokens_per_request ,
843
+ is_validation_passed = validate_candidate (running_sequences[running_sequence_id], generated_seq_token_offset ,
844
+ sampled_token, is_extend_sequence, assisting_pipeline_info.max_removed_tokens_per_request ,
844
845
sampling_params.do_sample , !sampling_params.is_prompt_lookup ());
845
- // doing resample in case of non accepted tokens in specualtive sampling, if candidates have real logits
846
+
847
+ // doing resample in case of non accepted tokens in speculative sampling
846
848
if (!is_validation_passed && sampling_params.do_sample && !sampling_params.is_prompt_lookup ()) {
847
849
continue ;
848
850
}
@@ -897,7 +899,7 @@ SequenceGroupSamplingInfo Sampler::sample_from_sequence_group(SequenceGroup::Ptr
897
899
}
898
900
// Notify handle after sampling is done.
899
901
// For non-streaming this is effective only when the generation is finished.
900
- OPENVINO_ASSERT (num_tokens_to_process >= assisting_pipeline_info.max_removed_tokens_per_request );
902
+ OPENVINO_ASSERT (num_generated_tokens_to_validate >= assisting_pipeline_info.max_removed_tokens_per_request );
901
903
sequence_group->notify_handle ();
902
904
return sg_sampling_info;
903
905
}
0 commit comments