@@ -699,9 +699,9 @@ struct whisper_kv_cache {
699
699
struct ggml_tensor * k;
700
700
struct ggml_tensor * v;
701
701
702
- struct ggml_context * ctx = nullptr ;
703
-
704
702
ggml_backend_buffer_t buffer = nullptr ;
703
+
704
+ std::vector<uint8_t > ctx_buf;
705
705
};
706
706
707
707
struct whisper_model {
@@ -941,9 +941,11 @@ static bool whisper_kv_cache_init(
941
941
const int64_t n_mem = n_text_layer*n_ctx;
942
942
const int64_t n_elements = n_text_state*n_mem;
943
943
944
+ cache.ctx_buf .resize (2 *ggml_tensor_overhead ());
945
+
944
946
struct ggml_init_params params = {
945
- /* .mem_size =*/ 2 * ggml_tensor_overhead (),
946
- /* .mem_buffer =*/ nullptr ,
947
+ /* .mem_size =*/ cache. ctx_buf . size (),
948
+ /* .mem_buffer =*/ cache. ctx_buf . data () ,
947
949
/* .no_alloc =*/ true ,
948
950
};
949
951
@@ -953,31 +955,31 @@ static bool whisper_kv_cache_init(
953
955
cache.cells .clear ();
954
956
cache.cells .resize (n_ctx);
955
957
956
- cache. ctx = ggml_init (params);
958
+ struct ggml_context * ctx = ggml_init (params);
957
959
958
- if (!cache. ctx ) {
960
+ if (!ctx) {
959
961
WHISPER_LOG_ERROR (" %s: failed to allocate memory for the kv cache context\n " , __func__);
960
962
return false ;
961
963
}
962
964
963
- cache.k = ggml_new_tensor_1d (cache. ctx , wtype, n_elements);
964
- cache.v = ggml_new_tensor_1d (cache. ctx , wtype, n_elements);
965
+ cache.k = ggml_new_tensor_1d (ctx, wtype, n_elements);
966
+ cache.v = ggml_new_tensor_1d (ctx, wtype, n_elements);
965
967
966
- cache.buffer = ggml_backend_alloc_ctx_tensors (cache. ctx , backend);
968
+ cache.buffer = ggml_backend_alloc_ctx_tensors (ctx, backend);
967
969
if (!cache.buffer ) {
968
970
WHISPER_LOG_ERROR (" %s: failed to allocate memory for the kv cache\n " , __func__);
969
971
return false ;
970
972
}
971
973
972
974
ggml_backend_buffer_clear (cache.buffer , 0 );
973
975
976
+ ggml_free (ctx);
977
+
974
978
return true ;
975
979
}
976
980
977
981
static void whisper_kv_cache_free (struct whisper_kv_cache & cache) {
978
- ggml_free (cache.ctx );
979
982
ggml_backend_buffer_free (cache.buffer );
980
- cache.ctx = nullptr ;
981
983
}
982
984
983
985
static bool whisper_kv_cache_find_slot (
@@ -2002,7 +2004,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
2002
2004
2003
2005
auto & kv_pad = wstate.kv_pad ;
2004
2006
2005
- WHISPER_ASSERT (!!kv_pad.ctx );
2007
+ WHISPER_ASSERT (!!kv_pad.buffer );
2006
2008
2007
2009
const int n_ctx_pad = GGML_PAD (n_ctx, 256 );
2008
2010
@@ -2416,7 +2418,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
2416
2418
2417
2419
auto & kv_self = wstate.kv_self ;
2418
2420
2419
- WHISPER_ASSERT (!!kv_self.ctx );
2421
+ WHISPER_ASSERT (!!kv_self.buffer );
2420
2422
2421
2423
const int n_ctx = kv_self.size ;
2422
2424
const int n_state = hparams.n_text_state ;
0 commit comments