Skip to content

Commit aa037a6

Browse files
authored
ggml : alloc ggml_contexts on the heap (#2525)
* whisper : reduce ggml_context usage
* ggml : allocate contexts on the heap (v2)
* ggml : aligned malloc -> malloc
1 parent 19dca2b commit aa037a6

File tree

4 files changed

+37
-67
lines changed

4 files changed

+37
-67
lines changed

ggml/include/ggml.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,6 @@
217217

218218
#define GGML_MAX_DIMS 4
219219
#define GGML_MAX_PARAMS 2048
220-
#define GGML_MAX_CONTEXTS 64
221220
#define GGML_MAX_SRC 10
222221
#define GGML_MAX_N_THREADS 512
223222
#define GGML_MAX_OP_PARAMS 64
@@ -657,6 +656,7 @@ extern "C" {
657656
};
658657

659658
// scratch buffer
659+
// TODO: deprecate and remove
660660
struct ggml_scratch {
661661
size_t offs;
662662
size_t size;
@@ -760,8 +760,9 @@ extern "C" {
760760

761761
// main
762762

763-
GGML_API struct ggml_context * ggml_init(struct ggml_init_params params);
764-
GGML_API void ggml_free(struct ggml_context * ctx);
763+
GGML_API struct ggml_context * ggml_init (struct ggml_init_params params);
764+
GGML_API void ggml_reset(struct ggml_context * ctx);
765+
GGML_API void ggml_free (struct ggml_context * ctx);
765766

766767
GGML_API size_t ggml_used_mem(const struct ggml_context * ctx);
767768

ggml/src/ggml-metal.m

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3129,7 +3129,7 @@ static enum ggml_status ggml_metal_graph_compute(
31293129

31303130
// default buffer
31313131
static id<MTLDevice> g_backend_device = nil;
3132-
static int g_backend_device_ref_count = 0;
3132+
static int g_backend_device_ref_count = 0; // TODO: make thread-safe
31333133

31343134
static id<MTLDevice> ggml_backend_metal_get_device(void) {
31353135
if (g_backend_device == nil) {

ggml/src/ggml.c

Lines changed: 17 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) {
308308
}
309309

310310
#define GGML_DEBUG 0
311+
311312
#define GGML_GELU_FP16
312313
#define GGML_GELU_QUICK_FP16
313314

@@ -1985,7 +1986,7 @@ static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
19851986

19861987
struct ggml_context {
19871988
size_t mem_size;
1988-
void* mem_buffer;
1989+
void * mem_buffer;
19891990
bool mem_buffer_owned;
19901991
bool no_alloc;
19911992
bool no_alloc_save; // this is used to save the no_alloc state when using scratch buffers
@@ -3234,7 +3235,6 @@ struct ggml_numa_nodes {
32343235
//
32353236

32363237
struct ggml_state {
3237-
struct ggml_context_container contexts[GGML_MAX_CONTEXTS];
32383238
struct ggml_numa_nodes numa;
32393239
};
32403240

@@ -3816,17 +3816,12 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
38163816
const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
38173817

38183818
g_state = (struct ggml_state) {
3819-
/*.contexts =*/ { { 0 } },
38203819
/*.numa =*/ {
38213820
.n_nodes = 0,
38223821
.total_cpus = 0,
38233822
},
38243823
};
38253824

3826-
for (int i = 0; i < GGML_MAX_CONTEXTS; ++i) {
3827-
g_state.contexts[i].used = false;
3828-
}
3829-
38303825
const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
38313826

38323827
GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
@@ -3839,26 +3834,9 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
38393834
is_first_call = false;
38403835
}
38413836

3842-
// find non-used context in g_state
3843-
struct ggml_context * ctx = NULL;
3844-
3845-
for (int i = 0; i < GGML_MAX_CONTEXTS; i++) {
3846-
if (!g_state.contexts[i].used) {
3847-
g_state.contexts[i].used = true;
3848-
ctx = &g_state.contexts[i].context;
3849-
3850-
GGML_PRINT_DEBUG("%s: found unused context %d\n", __func__, i);
3851-
break;
3852-
}
3853-
}
3854-
3855-
if (ctx == NULL) {
3856-
GGML_PRINT_DEBUG("%s: no unused context found\n", __func__);
3857-
3858-
ggml_critical_section_end();
3837+
ggml_critical_section_end();
38593838

3860-
return NULL;
3861-
}
3839+
struct ggml_context * ctx = GGML_MALLOC(sizeof(struct ggml_context));
38623840

38633841
// allow to call ggml_init with 0 size
38643842
if (params.mem_size == 0) {
@@ -3886,42 +3864,31 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
38863864

38873865
GGML_PRINT_DEBUG("%s: context initialized\n", __func__);
38883866

3889-
ggml_critical_section_end();
3890-
38913867
return ctx;
38923868
}
38933869

3894-
void ggml_free(struct ggml_context * ctx) {
3870+
void ggml_reset(struct ggml_context * ctx) {
38953871
if (ctx == NULL) {
38963872
return;
38973873
}
38983874

3899-
// make this function thread safe
3900-
ggml_critical_section_start();
3901-
3902-
bool found = false;
3903-
3904-
for (int i = 0; i < GGML_MAX_CONTEXTS; i++) {
3905-
if (&g_state.contexts[i].context == ctx) {
3906-
g_state.contexts[i].used = false;
3907-
3908-
GGML_PRINT_DEBUG("%s: context %d has been freed. memory used = %zu\n",
3909-
__func__, i, ggml_used_mem(ctx));
3910-
3911-
if (ctx->mem_buffer_owned) {
3912-
GGML_ALIGNED_FREE(ctx->mem_buffer);
3913-
}
3875+
ctx->n_objects = 0;
3876+
ctx->objects_begin = NULL;
3877+
ctx->objects_end = NULL;
3878+
ctx->scratch = (struct ggml_scratch) { 0, 0, NULL, };
3879+
ctx->scratch_save = (struct ggml_scratch) { 0, 0, NULL, };
3880+
}
39143881

3915-
found = true;
3916-
break;
3917-
}
3882+
void ggml_free(struct ggml_context * ctx) {
3883+
if (ctx == NULL) {
3884+
return;
39183885
}
39193886

3920-
if (!found) {
3921-
GGML_PRINT_DEBUG("%s: context not found\n", __func__);
3887+
if (ctx->mem_buffer_owned) {
3888+
GGML_ALIGNED_FREE(ctx->mem_buffer);
39223889
}
39233890

3924-
ggml_critical_section_end();
3891+
GGML_FREE(ctx);
39253892
}
39263893

39273894
size_t ggml_used_mem(const struct ggml_context * ctx) {

src/whisper.cpp

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -699,9 +699,9 @@ struct whisper_kv_cache {
699699
struct ggml_tensor * k;
700700
struct ggml_tensor * v;
701701

702-
struct ggml_context * ctx = nullptr;
703-
704702
ggml_backend_buffer_t buffer = nullptr;
703+
704+
std::vector<uint8_t> ctx_buf;
705705
};
706706

707707
struct whisper_model {
@@ -941,9 +941,11 @@ static bool whisper_kv_cache_init(
941941
const int64_t n_mem = n_text_layer*n_ctx;
942942
const int64_t n_elements = n_text_state*n_mem;
943943

944+
cache.ctx_buf.resize(2*ggml_tensor_overhead());
945+
944946
struct ggml_init_params params = {
945-
/*.mem_size =*/ 2*ggml_tensor_overhead(),
946-
/*.mem_buffer =*/ nullptr,
947+
/*.mem_size =*/ cache.ctx_buf.size(),
948+
/*.mem_buffer =*/ cache.ctx_buf.data(),
947949
/*.no_alloc =*/ true,
948950
};
949951

@@ -953,31 +955,31 @@ static bool whisper_kv_cache_init(
953955
cache.cells.clear();
954956
cache.cells.resize(n_ctx);
955957

956-
cache.ctx = ggml_init(params);
958+
struct ggml_context * ctx = ggml_init(params);
957959

958-
if (!cache.ctx) {
960+
if (!ctx) {
959961
WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache context\n", __func__);
960962
return false;
961963
}
962964

963-
cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
964-
cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
965+
cache.k = ggml_new_tensor_1d(ctx, wtype, n_elements);
966+
cache.v = ggml_new_tensor_1d(ctx, wtype, n_elements);
965967

966-
cache.buffer = ggml_backend_alloc_ctx_tensors(cache.ctx, backend);
968+
cache.buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
967969
if (!cache.buffer) {
968970
WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache\n", __func__);
969971
return false;
970972
}
971973

972974
ggml_backend_buffer_clear(cache.buffer, 0);
973975

976+
ggml_free(ctx);
977+
974978
return true;
975979
}
976980

977981
static void whisper_kv_cache_free(struct whisper_kv_cache & cache) {
978-
ggml_free(cache.ctx);
979982
ggml_backend_buffer_free(cache.buffer);
980-
cache.ctx = nullptr;
981983
}
982984

983985
static bool whisper_kv_cache_find_slot(
@@ -2002,7 +2004,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
20022004

20032005
auto & kv_pad = wstate.kv_pad;
20042006

2005-
WHISPER_ASSERT(!!kv_pad.ctx);
2007+
WHISPER_ASSERT(!!kv_pad.buffer);
20062008

20072009
const int n_ctx_pad = GGML_PAD(n_ctx, 256);
20082010

@@ -2416,7 +2418,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
24162418

24172419
auto & kv_self = wstate.kv_self;
24182420

2419-
WHISPER_ASSERT(!!kv_self.ctx);
2421+
WHISPER_ASSERT(!!kv_self.buffer);
24202422

24212423
const int n_ctx = kv_self.size;
24222424
const int n_state = hparams.n_text_state;

0 commit comments

Comments (0)