Skip to content

Commit d012b5c

Browse files
authored
whisper : add "split_on_word" flag when using "max_len" option (#455)
* Update whisper.cpp * fix: trim function * feat: added flag to split on word * fix: arguments for main
1 parent b2083c5 commit d012b5c

File tree

3 files changed

+39
-5
lines changed

3 files changed

+39
-5
lines changed

examples/main/main.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ struct whisper_params {
6969
bool speed_up = false;
7070
bool translate = false;
7171
bool diarize = false;
72+
bool split_on_word = false;
7273
bool no_fallback = false;
7374
bool output_txt = false;
7475
bool output_vtt = false;
@@ -118,6 +119,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
118119
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
119120
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
120121
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
122+
else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; }
121123
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
122124
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
123125
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
@@ -156,6 +158,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
156158
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
157159
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
158160
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
161+
fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false");
159162
fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
160163
fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
161164
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
@@ -651,6 +654,7 @@ int main(int argc, char ** argv) {
651654
wparams.token_timestamps = params.output_wts || params.max_len > 0;
652655
wparams.thold_pt = params.word_thold;
653656
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
657+
wparams.split_on_word = params.split_on_word;
654658

655659
wparams.speed_up = params.speed_up;
656660

whisper.cpp

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2922,6 +2922,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
29222922
/*.thold_pt =*/ 0.01f,
29232923
/*.thold_ptsum =*/ 0.01f,
29242924
/*.max_len =*/ 0,
2925+
/*.split_on_word =*/ false,
29252926
/*.max_tokens =*/ 0,
29262927

29272928
/*.speed_up =*/ false,
@@ -2988,9 +2989,36 @@ static void whisper_exp_compute_token_level_timestamps(
29882989
float thold_pt,
29892990
float thold_ptsum);
29902991

2992+
// Strip leading whitespace from `s`, modifying the string in place.
// Whitespace is whatever std::isspace reports for the current C locale.
static inline void ltrim(std::string &s) {
    std::string::size_type n_lead = 0;

    // the cast to unsigned char keeps std::isspace well-defined for bytes >= 0x80
    while (n_lead < s.size() && std::isspace(static_cast<unsigned char>(s[n_lead]))) {
        ++n_lead;
    }

    s.erase(0, n_lead);
}
2998+
2999+
// Strip trailing whitespace from `s`, modifying the string in place.
// Whitespace is whatever std::isspace reports for the current C locale.
static inline void rtrim(std::string &s) {
    std::string::size_type n_keep = s.size();

    // walk back from the end until a non-space character (or the start) is reached;
    // the cast to unsigned char keeps std::isspace well-defined for bytes >= 0x80
    while (n_keep > 0 && std::isspace(static_cast<unsigned char>(s[n_keep - 1]))) {
        --n_keep;
    }

    s.erase(n_keep);
}
3005+
3006+
// trim from both ends (in place)
3007+
static inline void trim(std::string &s) {
3008+
rtrim(s);
3009+
ltrim(s);
3010+
}
3011+
3012+
// Decide whether a segment may be split right before the token whose text is `txt`.
//
//   txt           - NUL-terminated token text; may be null
//   split_on_word - when false, every token is an acceptable split point
//
// Returns true if splitting before this token is allowed. With split_on_word enabled
// that is only the case when the token text begins with a space character; an empty
// or null text never qualifies.
static inline bool should_split_on_word(const char * txt, bool split_on_word) {
    if (!split_on_word) return true;

    // inspect the first character directly: no temporary std::string, and a null
    // pointer is rejected instead of being passed to the std::string constructor
    return txt != nullptr && txt[0] == ' ';
}
3018+
29913019
// wrap the last segment to max_len characters
29923020
// returns the number of new segments
2993-
static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
3021+
static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool split_on_word) {
29943022
auto segment = ctx.result_all.back();
29953023

29963024
int res = 1;
@@ -3005,11 +3033,11 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
30053033
}
30063034

30073035
const auto txt = whisper_token_to_str(&ctx, token.id);
3008-
30093036
const int cur = strlen(txt);
30103037

3011-
if (acc + cur > max_len && i > 0) {
3038+
if (acc + cur > max_len && i > 0 && should_split_on_word(txt, split_on_word)) {
30123039
// split here
3040+
trim(text);
30133041
ctx.result_all.back().text = std::move(text);
30143042
ctx.result_all.back().t1 = token.t0;
30153043
ctx.result_all.back().tokens.resize(i);
@@ -3037,6 +3065,7 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
30373065
}
30383066
}
30393067

3068+
trim(text);
30403069
ctx.result_all.back().text = std::move(text);
30413070

30423071
return res;
@@ -4069,7 +4098,7 @@ int whisper_full(
40694098
*ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
40704099

40714100
if (params.max_len > 0) {
4072-
n_new = whisper_wrap_segment(*ctx, params.max_len);
4101+
n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
40734102
}
40744103
}
40754104
if (params.new_segment_callback) {
@@ -4113,7 +4142,7 @@ int whisper_full(
41134142
*ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
41144143

41154144
if (params.max_len > 0) {
4116-
n_new = whisper_wrap_segment(*ctx, params.max_len);
4145+
n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
41174146
}
41184147
}
41194148
if (params.new_segment_callback) {

whisper.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ extern "C" {
257257
float thold_pt; // timestamp token probability threshold (~0.01)
258258
float thold_ptsum; // timestamp token sum probability threshold (~0.01)
259259
int max_len; // max segment length in characters
260+
bool split_on_word; // split on word rather than on token (when used with max_len)
260261
int max_tokens; // max tokens per segment (0 = no limit)
261262

262263
// [EXPERIMENTAL] speed-up techniques

0 commit comments

Comments
 (0)