Skip to content

Commit eefed45

Browse files
authored
whisper : add initial_prompt param (#645)
1 parent aac1710 commit eefed45

File tree

4 files changed

+13
-36
lines changed

4 files changed

+13
-36
lines changed

examples/addon.node/addon.cpp

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -160,22 +160,6 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
160160
return 3;
161161
}
162162

163-
// initial prompt
164-
std::vector<whisper_token> prompt_tokens;
165-
166-
if (!params.prompt.empty()) {
167-
prompt_tokens.resize(1024);
168-
prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
169-
170-
fprintf(stderr, "\n");
171-
fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str());
172-
fprintf(stderr, "initial tokens: [ ");
173-
for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
174-
fprintf(stderr, "%d ", prompt_tokens[i]);
175-
}
176-
fprintf(stderr, "]\n");
177-
}
178-
179163
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
180164
const auto fname_inp = params.fname_inp[f];
181165
const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
@@ -243,8 +227,7 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
243227
wparams.greedy.best_of = params.best_of;
244228
wparams.beam_search.beam_size = params.beam_size;
245229

246-
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
247-
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
230+
wparams.initial_prompt = params.prompt.c_str();
248231

249232
whisper_print_user_data user_data = { &params, &pcmf32s };
250233

examples/main/main.cpp

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -639,22 +639,6 @@ int main(int argc, char ** argv) {
639639
return 3;
640640
}
641641

642-
// initial prompt
643-
std::vector<whisper_token> prompt_tokens;
644-
645-
if (!params.prompt.empty()) {
646-
prompt_tokens.resize(1024);
647-
prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
648-
649-
fprintf(stderr, "\n");
650-
fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str());
651-
fprintf(stderr, "initial tokens: [ ");
652-
for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
653-
fprintf(stderr, "%d ", prompt_tokens[i]);
654-
}
655-
fprintf(stderr, "]\n");
656-
}
657-
658642
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
659643
const auto fname_inp = params.fname_inp[f];
660644
const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
@@ -718,8 +702,7 @@ int main(int argc, char ** argv) {
718702

719703
wparams.speed_up = params.speed_up;
720704

721-
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
722-
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
705+
wparams.initial_prompt = params.prompt.c_str();
723706

724707
wparams.greedy.best_of = params.best_of;
725708
wparams.beam_search.beam_size = params.beam_size;

whisper.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3121,6 +3121,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
31213121
/*.speed_up =*/ false,
31223122
/*.audio_ctx =*/ 0,
31233123

3124+
/*.initial_prompt =*/ nullptr,
31243125
/*.prompt_tokens =*/ nullptr,
31253126
/*.prompt_n_tokens =*/ 0,
31263127

@@ -3793,6 +3794,15 @@ int whisper_full_with_state(
37933794
prompt_past.clear();
37943795
}
37953796

3797+
// initial prompt
3798+
if (!params.prompt_tokens && params.initial_prompt) {
3799+
std::vector<whisper_token> prompt_tokens;
3800+
prompt_tokens.resize(1024);
3801+
prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()));
3802+
params.prompt_tokens = prompt_tokens.data();
3803+
params.prompt_n_tokens = prompt_tokens.size();
3804+
}
3805+
37963806
// prepend the prompt tokens to the prompt_past
37973807
if (params.prompt_tokens && params.prompt_n_tokens > 0) {
37983808
// parse tokens from the pointer

whisper.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,7 @@ extern "C" {
356356

357357
// tokens to provide to the whisper decoder as initial prompt
358358
// these are prepended to any existing text context from a previous call
359+
const char * initial_prompt;
359360
const whisper_token * prompt_tokens;
360361
int prompt_n_tokens;
361362

0 commit comments

Comments
 (0)