Skip to content

Commit 021eef1

Browse files
ruby : Add low-level methods to transcribe (#2585)
* Add tests for Whisper::Context#full * Add Whisper::Context#full * Add tests for Whisper::Error * Add document of Whisper::Context#full [skip ci] * Add additional signature for Whisper::Context#full * Add description to Whisper::Context#full * Add test for Whisper::Context#full_parallel * Add Whisper::Context#full_parallel * Hide Whisper's instance methods from Ruby code * Add class to test MemoryView * Build test class before running test * Add test for MemoryView * Make Whisper::Context#full and #full_parallel accept MemoryView * Use Ruby 3.1 on CI * Add comment on samples data type * Update README * Update README * Remove unused code
1 parent a9d06ce commit 021eef1

File tree

10 files changed

+486
-6
lines changed

10 files changed

+486
-6
lines changed

.github/workflows/bindings-ruby.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,6 @@ jobs:
5050
steps:
5151
- uses: ruby/setup-ruby@v1
5252
with:
53-
ruby-version: '3.0'
53+
ruby-version: '3.1'
5454
- uses: actions/checkout@v4
5555
- run: rake test

bindings/ruby/README.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,24 @@ Whisper.log_set ->(level, buffer, user_data) {
160160
Whisper::Context.new(MODEL)
161161
```
162162

163+
You can also call `Whisper::Context#full` and `#full_parallel` with a Ruby array as samples. Although `#transcribe` with an audio file path is recommended because it extracts PCM samples in C++ and is fast, `#full` and `#full_parallel` give you more flexibility.
164+
165+
```ruby
166+
require "whisper"
167+
require "wavefile"
168+
169+
reader = WaveFile::Reader.new("path/to/audio.wav", WaveFile::Format.new(:mono, :float, 16000))
170+
samples = reader.enum_for(:each_buffer).map(&:samples).flatten
171+
172+
whisper = Whisper::Context.new("path/to/model.bin")
173+
whisper.full(Whisper::Params.new, samples)
174+
whisper.each_segment do |segment|
175+
puts segment.text
176+
end
177+
```
178+
179+
The second argument `samples` may be an array, an object that responds to `length` and `each`, or a MemoryView. If you can prepare the audio data as a C array and export it as a MemoryView, whispercpp accepts it and works on it with zero copy.
180+
163181
License
164182
-------
165183

bindings/ruby/Rakefile

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,13 @@ file TEST_MODEL do
6868
sh "./models/download-ggml-model.sh base.en"
6969
end
7070
end
71+
72+
# Path of the shared library built from tests/jfk_reader/jfk_reader.c.
# The extension (DLEXT) is taken from the running Ruby's build configuration.
TEST_MEMORY_VIEW = "tests/jfk_reader/jfk_reader.#{RbConfig::CONFIG['DLEXT']}"

# Rebuild the library whenever its C source changes: generate the Makefile
# with extconf.rb, then run make, both from inside the extension directory.
file(TEST_MEMORY_VIEW => "tests/jfk_reader/jfk_reader.c") do
  Dir.chdir("tests/jfk_reader") do
    ruby "extconf.rb"
    sh "make"
  end
end

# Remove the object file and the shared library on `rake clean`.
CLEAN.include("tests/jfk_reader/jfk_reader.{o,#{RbConfig::CONFIG['DLEXT']}}")

# The library must exist before the test task runs.
task :test => TEST_MEMORY_VIEW

bindings/ruby/ext/ruby_whisper.cpp

Lines changed: 221 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <ruby.h>
2+
#include <ruby/memory_view.h>
23
#include "ruby_whisper.h"
34
#define DR_WAV_IMPLEMENTATION
45
#include "dr_wav.h"
@@ -35,11 +36,15 @@ extern "C" {
3536
VALUE mWhisper;
3637
VALUE cContext;
3738
VALUE cParams;
39+
VALUE eError;
3840

3941
static ID id_to_s;
4042
static ID id_call;
4143
static ID id___method__;
4244
static ID id_to_enum;
45+
static ID id_length;
46+
static ID id_next;
47+
static ID id_new;
4348

4449
static bool is_log_callback_finalized = false;
4550

@@ -100,13 +105,13 @@ static VALUE ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id) {
100105
* log_set ->(level, buffer, user_data) { ... }, user_data -> nil
101106
*/
102107
static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) {
103-
VALUE old_callback = rb_iv_get(self, "@log_callback");
108+
VALUE old_callback = rb_iv_get(self, "log_callback");
104109
if (!NIL_P(old_callback)) {
105110
rb_undefine_finalizer(old_callback);
106111
}
107112

108-
rb_iv_set(self, "@log_callback", log_callback);
109-
rb_iv_set(self, "@user_data", user_data);
113+
rb_iv_set(self, "log_callback", log_callback);
114+
rb_iv_set(self, "user_data", user_data);
110115

111116
VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback"));
112117
rb_define_finalizer(log_callback, finalize_log_callback);
@@ -115,8 +120,8 @@ static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_d
115120
if (is_log_callback_finalized) {
116121
return;
117122
}
118-
VALUE log_callback = rb_iv_get(mWhisper, "@log_callback");
119-
VALUE udata = rb_iv_get(mWhisper, "@user_data");
123+
VALUE log_callback = rb_iv_get(mWhisper, "log_callback");
124+
VALUE udata = rb_iv_get(mWhisper, "user_data");
120125
rb_funcall(log_callback, id_call, 3, INT2NUM(level), rb_str_new2(buffer), udata);
121126
}, nullptr);
122127

@@ -544,6 +549,168 @@ VALUE ruby_whisper_model_type(VALUE self) {
544549
return rb_str_new2(whisper_model_type_readable(rw->context));
545550
}
546551

552+
/*
553+
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
554+
* Not thread safe for same context
555+
* Uses the specified decoding strategy to obtain the text.
556+
*
557+
* call-seq:
558+
* full(params, samples, n_samples) -> nil
559+
* full(params, samples) -> nil
560+
*
561+
* The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data.
562+
*/
563+
VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) {
564+
if (argc < 2 || argc > 3) {
565+
rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
566+
}
567+
568+
ruby_whisper *rw;
569+
ruby_whisper_params *rwp;
570+
Data_Get_Struct(self, ruby_whisper, rw);
571+
VALUE params = argv[0];
572+
Data_Get_Struct(params, ruby_whisper_params, rwp);
573+
VALUE samples = argv[1];
574+
int n_samples;
575+
rb_memory_view_t view;
576+
const bool memory_view_available_p = rb_memory_view_available_p(samples);
577+
if (argc == 3) {
578+
n_samples = NUM2INT(argv[2]);
579+
if (TYPE(samples) == T_ARRAY) {
580+
if (RARRAY_LEN(samples) < n_samples) {
581+
rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples);
582+
}
583+
}
584+
// Should check when samples.respond_to?(:length)?
585+
} else {
586+
if (TYPE(samples) == T_ARRAY) {
587+
n_samples = RARRAY_LEN(samples);
588+
} else if (memory_view_available_p) {
589+
if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
590+
view.obj = Qnil;
591+
rb_raise(rb_eArgError, "unable to get a memory view");
592+
}
593+
n_samples = view.byte_size / view.item_size;
594+
} else if (rb_respond_to(samples, id_length)) {
595+
n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
596+
} else {
597+
rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given");
598+
}
599+
}
600+
float * c_samples = (float *)malloc(n_samples * sizeof(float));
601+
if (memory_view_available_p) {
602+
c_samples = (float *)view.data;
603+
} else {
604+
if (TYPE(samples) == T_ARRAY) {
605+
for (int i = 0; i < n_samples; i++) {
606+
c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i));
607+
}
608+
} else {
609+
// TODO: use rb_block_call
610+
VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each"));
611+
for (int i = 0; i < n_samples; i++) {
612+
// TODO: check if iter is exhausted and raise ArgumentError appropriately
613+
VALUE sample = rb_funcall(iter, id_next, 0);
614+
c_samples[i] = RFLOAT_VALUE(sample);
615+
}
616+
}
617+
}
618+
const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples);
619+
if (0 == result) {
620+
return Qnil;
621+
} else {
622+
rb_exc_raise(rb_funcall(eError, id_new, 1, result));
623+
}
624+
}
625+
626+
/*
627+
* Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
628+
* Result is stored in the default state of the context
629+
* Not thread safe if executed in parallel on the same context.
630+
* It seems this approach can offer some speedup in some cases.
631+
* However, the transcription accuracy can be worse at the beginning and end of each chunk.
632+
*
633+
* call-seq:
634+
* full_parallel(params, samples) -> nil
635+
* full_parallel(params, samples, n_samples) -> nil
636+
* full_parallel(params, samples, n_samples, n_processors) -> nil
637+
* full_parallel(params, samples, nil, n_processors) -> nil
638+
*/
639+
static VALUE ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self) {
640+
if (argc < 2 || argc > 4) {
641+
rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
642+
}
643+
644+
ruby_whisper *rw;
645+
ruby_whisper_params *rwp;
646+
Data_Get_Struct(self, ruby_whisper, rw);
647+
VALUE params = argv[0];
648+
Data_Get_Struct(params, ruby_whisper_params, rwp);
649+
VALUE samples = argv[1];
650+
int n_samples;
651+
int n_processors;
652+
rb_memory_view_t view;
653+
const bool memory_view_available_p = rb_memory_view_available_p(samples);
654+
switch (argc) {
655+
case 2:
656+
n_processors = 1;
657+
break;
658+
case 3:
659+
n_processors = 1;
660+
break;
661+
case 4:
662+
n_processors = NUM2INT(argv[3]);
663+
break;
664+
}
665+
if (argc >= 3 && !NIL_P(argv[2])) {
666+
n_samples = NUM2INT(argv[2]);
667+
if (TYPE(samples) == T_ARRAY) {
668+
if (RARRAY_LEN(samples) < n_samples) {
669+
rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples);
670+
}
671+
}
672+
// Should check when samples.respond_to?(:length)?
673+
} else if (memory_view_available_p) {
674+
if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
675+
view.obj = Qnil;
676+
rb_raise(rb_eArgError, "unable to get a memory view");
677+
}
678+
n_samples = view.byte_size / view.item_size;
679+
} else {
680+
if (TYPE(samples) == T_ARRAY) {
681+
n_samples = RARRAY_LEN(samples);
682+
} else if (rb_respond_to(samples, id_length)) {
683+
n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
684+
} else {
685+
rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given");
686+
}
687+
}
688+
float * c_samples = (float *)malloc(n_samples * sizeof(float));
689+
if (memory_view_available_p) {
690+
c_samples = (float *)view.data;
691+
} else {
692+
if (TYPE(samples) == T_ARRAY) {
693+
for (int i = 0; i < n_samples; i++) {
694+
c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i));
695+
}
696+
} else {
697+
// FIXME: use rb_block_call
698+
VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each"));
699+
for (int i = 0; i < n_samples; i++) {
700+
// TODO: check if iter is exhausted and raise ArgumentError
701+
VALUE sample = rb_funcall(iter, id_next, 0);
702+
c_samples[i] = RFLOAT_VALUE(sample);
703+
}
704+
}
705+
}
706+
const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors);
707+
if (0 == result) {
708+
return Qnil;
709+
} else {
710+
rb_exc_raise(rb_funcall(eError, id_new, 1, result));
711+
}
712+
}
713+
547714
/*
548715
* Number of segments.
549716
*
@@ -1518,15 +1685,59 @@ static VALUE ruby_whisper_c_model_type(VALUE self) {
15181685
return rb_str_new2(whisper_model_type_readable(rw->context));
15191686
}
15201687

1688+
/*
 * Whisper::Error#initialize
 *
 * call-seq:
 *   new(code) -> Whisper::Error
 *
 * Translates the integer +code+ into a human readable message, passes that
 * message to the superclass initializer and stores +code+ in @code.
 * Codes without a dedicated message are reported as "unknown error".
 */
static VALUE ruby_whisper_error_initialize(VALUE self, VALUE code) {
  const int c_code = NUM2INT(code);
  // String literals are const in C++, so the pointer must be const as well.
  const char *raw_message;
  switch (c_code) {
  case -2:
    raw_message = "failed to compute log mel spectrogram";
    break;
  case -3:
    raw_message = "failed to auto-detect language";
    break;
  case -4:
    raw_message = "too many decoders requested";
    break;
  case -5:
    raw_message = "audio_ctx is larger than the maximum allowed";
    break;
  case -6:
    raw_message = "failed to encode";
    break;
  case -7:
    raw_message = "whisper_kv_cache_init() failed for self-attention cache";
    break;
  case -8:
  case -9:
    // Both codes share the same message.
    raw_message = "failed to decode";
    break;
  default:
    raw_message = "unknown error";
    break;
  }
  const VALUE message = rb_str_new2(raw_message);
  rb_call_super(1, &message);
  rb_iv_set(self, "@code", code);

  return self;
}
1726+
1727+
15211728
void Init_whisper() {
15221729
id_to_s = rb_intern("to_s");
15231730
id_call = rb_intern("call");
15241731
id___method__ = rb_intern("__method__");
15251732
id_to_enum = rb_intern("to_enum");
1733+
id_length = rb_intern("length");
1734+
id_next = rb_intern("next");
1735+
id_new = rb_intern("new");
15261736

15271737
mWhisper = rb_define_module("Whisper");
15281738
cContext = rb_define_class_under(mWhisper, "Context", rb_cObject);
15291739
cParams = rb_define_class_under(mWhisper, "Params", rb_cObject);
1740+
eError = rb_define_class_under(mWhisper, "Error", rb_eStandardError);
15301741

15311742
rb_define_const(mWhisper, "LOG_LEVEL_NONE", INT2NUM(GGML_LOG_LEVEL_NONE));
15321743
rb_define_const(mWhisper, "LOG_LEVEL_INFO", INT2NUM(GGML_LOG_LEVEL_INFO));
@@ -1564,6 +1775,8 @@ void Init_whisper() {
15641775
rb_define_method(cContext, "full_get_segment_t1", ruby_whisper_full_get_segment_t1, 1);
15651776
rb_define_method(cContext, "full_get_segment_speaker_turn_next", ruby_whisper_full_get_segment_speaker_turn_next, 1);
15661777
rb_define_method(cContext, "full_get_segment_text", ruby_whisper_full_get_segment_text, 1);
1778+
rb_define_method(cContext, "full", ruby_whisper_full, -1);
1779+
rb_define_method(cContext, "full_parallel", ruby_whisper_full_parallel, -1);
15671780

15681781
rb_define_alloc_func(cParams, ruby_whisper_params_allocate);
15691782

@@ -1623,6 +1836,9 @@ void Init_whisper() {
16231836
rb_define_method(cParams, "abort_callback=", ruby_whisper_params_set_abort_callback, 1);
16241837
rb_define_method(cParams, "abort_callback_user_data=", ruby_whisper_params_set_abort_callback_user_data, 1);
16251838

1839+
rb_define_attr(eError, "code", true, false);
1840+
rb_define_method(eError, "initialize", ruby_whisper_error_initialize, 1);
1841+
16261842
// High level
16271843
cSegment = rb_define_class_under(mWhisper, "Segment", rb_cObject);
16281844

bindings/ruby/tests/helper.rb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
require "test/unit"
22
require "whisper"
3+
require_relative "jfk_reader/jfk_reader"
34

45
class TestBase < Test::Unit::TestCase
56
MODEL = File.join(__dir__, "..", "..", "..", "models", "ggml-base.en.bin")
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
Makefile
2+
jfk_reader.o
3+
jfk_reader.so
4+
jfk_reader.bundle
5+
jfk_reader.dll
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
require "mkmf"
2+
3+
# Generate the Makefile that builds the "jfk_reader" extension with mkmf.
create_makefile("jfk_reader")

0 commit comments

Comments
 (0)