1
1
#include < ruby.h>
2
+ #include < ruby/memory_view.h>
2
3
#include " ruby_whisper.h"
3
4
#define DR_WAV_IMPLEMENTATION
4
5
#include " dr_wav.h"
@@ -35,11 +36,15 @@ extern "C" {
35
36
VALUE mWhisper ;
36
37
VALUE cContext;
37
38
VALUE cParams;
39
+ VALUE eError;
38
40
39
41
static ID id_to_s;
40
42
static ID id_call;
41
43
static ID id___method__;
42
44
static ID id_to_enum;
45
+ static ID id_length;
46
+ static ID id_next;
47
+ static ID id_new;
43
48
44
49
static bool is_log_callback_finalized = false ;
45
50
@@ -100,13 +105,13 @@ static VALUE ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id) {
100
105
* log_set ->(level, buffer, user_data) { ... }, user_data -> nil
101
106
*/
102
107
static VALUE ruby_whisper_s_log_set (VALUE self, VALUE log_callback, VALUE user_data) {
103
- VALUE old_callback = rb_iv_get (self, " @ log_callback" );
108
+ VALUE old_callback = rb_iv_get (self, " log_callback" );
104
109
if (!NIL_P (old_callback)) {
105
110
rb_undefine_finalizer (old_callback);
106
111
}
107
112
108
- rb_iv_set (self, " @ log_callback" , log_callback);
109
- rb_iv_set (self, " @ user_data" , user_data);
113
+ rb_iv_set (self, " log_callback" , log_callback);
114
+ rb_iv_set (self, " user_data" , user_data);
110
115
111
116
VALUE finalize_log_callback = rb_funcall (mWhisper , rb_intern (" method" ), 1 , rb_str_new2 (" finalize_log_callback" ));
112
117
rb_define_finalizer (log_callback, finalize_log_callback);
@@ -115,8 +120,8 @@ static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_d
115
120
if (is_log_callback_finalized) {
116
121
return ;
117
122
}
118
- VALUE log_callback = rb_iv_get (mWhisper , " @ log_callback" );
119
- VALUE udata = rb_iv_get (mWhisper , " @ user_data" );
123
+ VALUE log_callback = rb_iv_get (mWhisper , " log_callback" );
124
+ VALUE udata = rb_iv_get (mWhisper , " user_data" );
120
125
rb_funcall (log_callback, id_call, 3 , INT2NUM (level), rb_str_new2 (buffer), udata);
121
126
}, nullptr );
122
127
@@ -544,6 +549,168 @@ VALUE ruby_whisper_model_type(VALUE self) {
544
549
return rb_str_new2 (whisper_model_type_readable (rw->context ));
545
550
}
546
551
552
+ /*
553
+ * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
554
+ * Not thread safe for same context
555
+ * Uses the specified decoding strategy to obtain the text.
556
+ *
557
+ * call-seq:
558
+ * full(params, samples, n_samples) -> nil
559
+ * full(params, samples) -> nil
560
+ *
561
+ * The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data.
562
+ */
563
+ VALUE ruby_whisper_full (int argc, VALUE *argv, VALUE self) {
564
+ if (argc < 2 || argc > 3 ) {
565
+ rb_raise (rb_eArgError, " wrong number of arguments (given %d, expected 2..3)" , argc);
566
+ }
567
+
568
+ ruby_whisper *rw;
569
+ ruby_whisper_params *rwp;
570
+ Data_Get_Struct (self, ruby_whisper, rw);
571
+ VALUE params = argv[0 ];
572
+ Data_Get_Struct (params, ruby_whisper_params, rwp);
573
+ VALUE samples = argv[1 ];
574
+ int n_samples;
575
+ rb_memory_view_t view;
576
+ const bool memory_view_available_p = rb_memory_view_available_p (samples);
577
+ if (argc == 3 ) {
578
+ n_samples = NUM2INT (argv[2 ]);
579
+ if (TYPE (samples) == T_ARRAY) {
580
+ if (RARRAY_LEN (samples) < n_samples) {
581
+ rb_raise (rb_eArgError, " samples length %ld is less than n_samples %d" , RARRAY_LEN (samples), n_samples);
582
+ }
583
+ }
584
+ // Should check when samples.respond_to?(:length)?
585
+ } else {
586
+ if (TYPE (samples) == T_ARRAY) {
587
+ n_samples = RARRAY_LEN (samples);
588
+ } else if (memory_view_available_p) {
589
+ if (!rb_memory_view_get (samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
590
+ view.obj = Qnil;
591
+ rb_raise (rb_eArgError, " unable to get a memory view" );
592
+ }
593
+ n_samples = view.byte_size / view.item_size ;
594
+ } else if (rb_respond_to (samples, id_length)) {
595
+ n_samples = NUM2INT (rb_funcall (samples, id_length, 0 ));
596
+ } else {
597
+ rb_raise (rb_eArgError, " samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given" );
598
+ }
599
+ }
600
+ float * c_samples = (float *)malloc (n_samples * sizeof (float ));
601
+ if (memory_view_available_p) {
602
+ c_samples = (float *)view.data ;
603
+ } else {
604
+ if (TYPE (samples) == T_ARRAY) {
605
+ for (int i = 0 ; i < n_samples; i++) {
606
+ c_samples[i] = RFLOAT_VALUE (rb_ary_entry (samples, i));
607
+ }
608
+ } else {
609
+ // TODO: use rb_block_call
610
+ VALUE iter = rb_funcall (samples, id_to_enum, 1 , rb_str_new2 (" each" ));
611
+ for (int i = 0 ; i < n_samples; i++) {
612
+ // TODO: check if iter is exhausted and raise ArgumentError appropriately
613
+ VALUE sample = rb_funcall (iter, id_next, 0 );
614
+ c_samples[i] = RFLOAT_VALUE (sample);
615
+ }
616
+ }
617
+ }
618
+ const int result = whisper_full (rw->context , rwp->params , c_samples, n_samples);
619
+ if (0 == result) {
620
+ return Qnil;
621
+ } else {
622
+ rb_exc_raise (rb_funcall (eError, id_new, 1 , result));
623
+ }
624
+ }
625
+
626
+ /*
627
+ * Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
628
+ * Result is stored in the default state of the context
629
+ * Not thread safe if executed in parallel on the same context.
630
+ * It seems this approach can offer some speedup in some cases.
631
+ * However, the transcription accuracy can be worse at the beginning and end of each chunk.
632
+ *
633
+ * call-seq:
634
+ * full_parallel(params, samples) -> nil
635
+ * full_parallel(params, samples, n_samples) -> nil
636
+ * full_parallel(params, samples, n_samples, n_processors) -> nil
637
+ * full_parallel(params, samples, nil, n_processors) -> nil
638
+ */
639
+ static VALUE ruby_whisper_full_parallel (int argc, VALUE *argv,VALUE self) {
640
+ if (argc < 2 || argc > 4 ) {
641
+ rb_raise (rb_eArgError, " wrong number of arguments (given %d, expected 2..3)" , argc);
642
+ }
643
+
644
+ ruby_whisper *rw;
645
+ ruby_whisper_params *rwp;
646
+ Data_Get_Struct (self, ruby_whisper, rw);
647
+ VALUE params = argv[0 ];
648
+ Data_Get_Struct (params, ruby_whisper_params, rwp);
649
+ VALUE samples = argv[1 ];
650
+ int n_samples;
651
+ int n_processors;
652
+ rb_memory_view_t view;
653
+ const bool memory_view_available_p = rb_memory_view_available_p (samples);
654
+ switch (argc) {
655
+ case 2 :
656
+ n_processors = 1 ;
657
+ break ;
658
+ case 3 :
659
+ n_processors = 1 ;
660
+ break ;
661
+ case 4 :
662
+ n_processors = NUM2INT (argv[3 ]);
663
+ break ;
664
+ }
665
+ if (argc >= 3 && !NIL_P (argv[2 ])) {
666
+ n_samples = NUM2INT (argv[2 ]);
667
+ if (TYPE (samples) == T_ARRAY) {
668
+ if (RARRAY_LEN (samples) < n_samples) {
669
+ rb_raise (rb_eArgError, " samples length %ld is less than n_samples %d" , RARRAY_LEN (samples), n_samples);
670
+ }
671
+ }
672
+ // Should check when samples.respond_to?(:length)?
673
+ } else if (memory_view_available_p) {
674
+ if (!rb_memory_view_get (samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
675
+ view.obj = Qnil;
676
+ rb_raise (rb_eArgError, " unable to get a memory view" );
677
+ }
678
+ n_samples = view.byte_size / view.item_size ;
679
+ } else {
680
+ if (TYPE (samples) == T_ARRAY) {
681
+ n_samples = RARRAY_LEN (samples);
682
+ } else if (rb_respond_to (samples, id_length)) {
683
+ n_samples = NUM2INT (rb_funcall (samples, id_length, 0 ));
684
+ } else {
685
+ rb_raise (rb_eArgError, " samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given" );
686
+ }
687
+ }
688
+ float * c_samples = (float *)malloc (n_samples * sizeof (float ));
689
+ if (memory_view_available_p) {
690
+ c_samples = (float *)view.data ;
691
+ } else {
692
+ if (TYPE (samples) == T_ARRAY) {
693
+ for (int i = 0 ; i < n_samples; i++) {
694
+ c_samples[i] = RFLOAT_VALUE (rb_ary_entry (samples, i));
695
+ }
696
+ } else {
697
+ // FIXME: use rb_block_call
698
+ VALUE iter = rb_funcall (samples, id_to_enum, 1 , rb_str_new2 (" each" ));
699
+ for (int i = 0 ; i < n_samples; i++) {
700
+ // TODO: check if iter is exhausted and raise ArgumentError
701
+ VALUE sample = rb_funcall (iter, id_next, 0 );
702
+ c_samples[i] = RFLOAT_VALUE (sample);
703
+ }
704
+ }
705
+ }
706
+ const int result = whisper_full_parallel (rw->context , rwp->params , c_samples, n_samples, n_processors);
707
+ if (0 == result) {
708
+ return Qnil;
709
+ } else {
710
+ rb_exc_raise (rb_funcall (eError, id_new, 1 , result));
711
+ }
712
+ }
713
+
547
714
/*
548
715
* Number of segments.
549
716
*
@@ -1518,15 +1685,59 @@ static VALUE ruby_whisper_c_model_type(VALUE self) {
1518
1685
return rb_str_new2 (whisper_model_type_readable (rw->context ));
1519
1686
}
1520
1687
1688
+ static VALUE ruby_whisper_error_initialize (VALUE self, VALUE code) {
1689
+ const int c_code = NUM2INT (code);
1690
+ char *raw_message;
1691
+ switch (c_code) {
1692
+ case -2 :
1693
+ raw_message = " failed to compute log mel spectrogram" ;
1694
+ break ;
1695
+ case -3 :
1696
+ raw_message = " failed to auto-detect language" ;
1697
+ break ;
1698
+ case -4 :
1699
+ raw_message = " too many decoders requested" ;
1700
+ break ;
1701
+ case -5 :
1702
+ raw_message = " audio_ctx is larger than the maximum allowed" ;
1703
+ break ;
1704
+ case -6 :
1705
+ raw_message = " failed to encode" ;
1706
+ break ;
1707
+ case -7 :
1708
+ raw_message = " whisper_kv_cache_init() failed for self-attention cache" ;
1709
+ break ;
1710
+ case -8 :
1711
+ raw_message = " failed to decode" ;
1712
+ break ;
1713
+ case -9 :
1714
+ raw_message = " failed to decode" ;
1715
+ break ;
1716
+ default :
1717
+ raw_message = " unknown error" ;
1718
+ break ;
1719
+ }
1720
+ const VALUE message = rb_str_new2 (raw_message);
1721
+ rb_call_super (1 , &message);
1722
+ rb_iv_set (self, " @code" , code);
1723
+
1724
+ return self;
1725
+ }
1726
+
1727
+
1521
1728
void Init_whisper () {
1522
1729
id_to_s = rb_intern (" to_s" );
1523
1730
id_call = rb_intern (" call" );
1524
1731
id___method__ = rb_intern (" __method__" );
1525
1732
id_to_enum = rb_intern (" to_enum" );
1733
+ id_length = rb_intern (" length" );
1734
+ id_next = rb_intern (" next" );
1735
+ id_new = rb_intern (" new" );
1526
1736
1527
1737
mWhisper = rb_define_module (" Whisper" );
1528
1738
cContext = rb_define_class_under (mWhisper , " Context" , rb_cObject);
1529
1739
cParams = rb_define_class_under (mWhisper , " Params" , rb_cObject);
1740
+ eError = rb_define_class_under (mWhisper , " Error" , rb_eStandardError);
1530
1741
1531
1742
rb_define_const (mWhisper , " LOG_LEVEL_NONE" , INT2NUM (GGML_LOG_LEVEL_NONE));
1532
1743
rb_define_const (mWhisper , " LOG_LEVEL_INFO" , INT2NUM (GGML_LOG_LEVEL_INFO));
@@ -1564,6 +1775,8 @@ void Init_whisper() {
1564
1775
rb_define_method (cContext, " full_get_segment_t1" , ruby_whisper_full_get_segment_t1, 1 );
1565
1776
rb_define_method (cContext, " full_get_segment_speaker_turn_next" , ruby_whisper_full_get_segment_speaker_turn_next, 1 );
1566
1777
rb_define_method (cContext, " full_get_segment_text" , ruby_whisper_full_get_segment_text, 1 );
1778
+ rb_define_method (cContext, " full" , ruby_whisper_full, -1 );
1779
+ rb_define_method (cContext, " full_parallel" , ruby_whisper_full_parallel, -1 );
1567
1780
1568
1781
rb_define_alloc_func (cParams, ruby_whisper_params_allocate);
1569
1782
@@ -1623,6 +1836,9 @@ void Init_whisper() {
1623
1836
rb_define_method (cParams, " abort_callback=" , ruby_whisper_params_set_abort_callback, 1 );
1624
1837
rb_define_method (cParams, " abort_callback_user_data=" , ruby_whisper_params_set_abort_callback_user_data, 1 );
1625
1838
1839
+ rb_define_attr (eError, " code" , true , false );
1840
+ rb_define_method (eError, " initialize" , ruby_whisper_error_initialize, 1 );
1841
+
1626
1842
// High leve
1627
1843
cSegment = rb_define_class_under (mWhisper , " Segment" , rb_cObject);
1628
1844
0 commit comments