7 changes: 4 additions & 3 deletions ggml/include/ggml.h
@@ -217,7 +217,6 @@

#define GGML_MAX_DIMS 4
#define GGML_MAX_PARAMS 2048
#define GGML_MAX_CONTEXTS 64
#define GGML_MAX_SRC 10
#define GGML_MAX_N_THREADS 512
#define GGML_MAX_OP_PARAMS 64
Expand Down Expand Up @@ -657,6 +656,7 @@ extern "C" {
};

// scratch buffer
// TODO: deprecate and remove
struct ggml_scratch {
size_t offs;
size_t size;
Expand Down Expand Up @@ -760,8 +760,9 @@ extern "C" {

// main

GGML_API struct ggml_context * ggml_init(struct ggml_init_params params);
GGML_API void ggml_free(struct ggml_context * ctx);
GGML_API struct ggml_context * ggml_init (struct ggml_init_params params);
GGML_API void ggml_reset(struct ggml_context * ctx);
GGML_API void ggml_free (struct ggml_context * ctx);

GGML_API size_t ggml_used_mem(const struct ggml_context * ctx);

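Taken together, the header changes replace the fixed context pool (GGML_MAX_CONTEXTS) with heap-allocated contexts and introduce ggml_reset() so a context's memory can be reused without a free/init round trip. The following is a minimal sketch of the intended lifecycle, not code from the PR; the buffer size is an arbitrary example value:

    #include "ggml.h"

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024, // example working-buffer size
            /*.mem_buffer =*/ NULL,         // let the context allocate and own the buffer
            /*.no_alloc   =*/ false,
        };

        struct ggml_context * ctx = ggml_init(params);
        if (ctx == NULL) {
            return 1;
        }

        for (int iter = 0; iter < 4; ++iter) {
            // build tensors / graphs as usual
            struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 128);
            (void) a;

            // rewind the context so the same buffer is reused on the next iteration
            ggml_reset(ctx);
        }

        // with this PR, ggml_free() also releases the heap-allocated ggml_context itself
        ggml_free(ctx);

        return 0;
    }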
2 changes: 1 addition & 1 deletion ggml/src/ggml-metal.m
@@ -3129,7 +3129,7 @@ static enum ggml_status ggml_metal_graph_compute(

// default buffer
static id<MTLDevice> g_backend_device = nil;
static int g_backend_device_ref_count = 0;
static int g_backend_device_ref_count = 0; // TODO: make thread-safe

static id<MTLDevice> ggml_backend_metal_get_device(void) {
if (g_backend_device == nil) {
67 changes: 17 additions & 50 deletions ggml/src/ggml.c
@@ -308,6 +308,7 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) {
}

#define GGML_DEBUG 0

#define GGML_GELU_FP16
#define GGML_GELU_QUICK_FP16

@@ -1985,7 +1986,7 @@ static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);

struct ggml_context {
size_t mem_size;
void* mem_buffer;
void * mem_buffer;
bool mem_buffer_owned;
bool no_alloc;
bool no_alloc_save; // this is used to save the no_alloc state when using scratch buffers
@@ -3234,7 +3235,6 @@ struct ggml_numa_nodes {
//

struct ggml_state {
struct ggml_context_container contexts[GGML_MAX_CONTEXTS];
struct ggml_numa_nodes numa;
};

@@ -3816,17 +3816,12 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
const uint64_t t_start = ggml_time_us(); UNUSED(t_start);

g_state = (struct ggml_state) {
/*.contexts =*/ { { 0 } },
/*.numa =*/ {
.n_nodes = 0,
.total_cpus = 0,
},
};

for (int i = 0; i < GGML_MAX_CONTEXTS; ++i) {
g_state.contexts[i].used = false;
}

const uint64_t t_end = ggml_time_us(); UNUSED(t_end);

GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
@@ -3839,26 +3834,9 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
is_first_call = false;
}

// find non-used context in g_state
struct ggml_context * ctx = NULL;

for (int i = 0; i < GGML_MAX_CONTEXTS; i++) {
if (!g_state.contexts[i].used) {
g_state.contexts[i].used = true;
ctx = &g_state.contexts[i].context;

GGML_PRINT_DEBUG("%s: found unused context %d\n", __func__, i);
break;
}
}

if (ctx == NULL) {
GGML_PRINT_DEBUG("%s: no unused context found\n", __func__);

ggml_critical_section_end();
ggml_critical_section_end();

return NULL;
}
struct ggml_context * ctx = GGML_MALLOC(sizeof(struct ggml_context));

// allow to call ggml_init with 0 size
if (params.mem_size == 0) {
@@ -3886,42 +3864,31 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {

GGML_PRINT_DEBUG("%s: context initialized\n", __func__);

ggml_critical_section_end();

return ctx;
}

void ggml_free(struct ggml_context * ctx) {
void ggml_reset(struct ggml_context * ctx) {
if (ctx == NULL) {
return;
}

// make this function thread safe
ggml_critical_section_start();

bool found = false;

for (int i = 0; i < GGML_MAX_CONTEXTS; i++) {
if (&g_state.contexts[i].context == ctx) {
g_state.contexts[i].used = false;

GGML_PRINT_DEBUG("%s: context %d has been freed. memory used = %zu\n",
__func__, i, ggml_used_mem(ctx));

if (ctx->mem_buffer_owned) {
GGML_ALIGNED_FREE(ctx->mem_buffer);
}
ctx->n_objects = 0;
ctx->objects_begin = NULL;
ctx->objects_end = NULL;
ctx->scratch = (struct ggml_scratch) { 0, 0, NULL, };
ctx->scratch_save = (struct ggml_scratch) { 0, 0, NULL, };
}

found = true;
break;
}
void ggml_free(struct ggml_context * ctx) {
if (ctx == NULL) {
return;
}

if (!found) {
GGML_PRINT_DEBUG("%s: context not found\n", __func__);
if (ctx->mem_buffer_owned) {
GGML_ALIGNED_FREE(ctx->mem_buffer);
}

ggml_critical_section_end();
GGML_FREE(ctx);
}

size_t ggml_used_mem(const struct ggml_context * ctx) {
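A consequence of the ggml.c changes worth calling out: ggml_init() now heap-allocates each struct ggml_context with GGML_MALLOC and no longer searches a fixed g_state.contexts pool, so the old hard limit of 64 simultaneous contexts is gone, and ggml_free() must still be called on every context to release it. A small illustrative sketch, not from the PR; the count of 128 is arbitrary, chosen only to exceed the old limit:

    #include "ggml.h"

    #define N_CTX 128 // more than the former GGML_MAX_CONTEXTS of 64

    int main(void) {
        struct ggml_context * ctxs[N_CTX];

        for (int i = 0; i < N_CTX; ++i) {
            struct ggml_init_params params = {
                /*.mem_size   =*/ ggml_tensor_overhead(), // metadata-only context
                /*.mem_buffer =*/ NULL,
                /*.no_alloc   =*/ true,
            };

            // previously this would start returning NULL once 64 contexts were live
            ctxs[i] = ggml_init(params);
            if (ctxs[i] == NULL) {
                return 1;
            }
        }

        for (int i = 0; i < N_CTX; ++i) {
            ggml_free(ctxs[i]); // each context is individually heap-allocated now
        }

        return 0;
    }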
28 changes: 15 additions & 13 deletions src/whisper.cpp
@@ -699,9 +699,9 @@ struct whisper_kv_cache {
struct ggml_tensor * k;
struct ggml_tensor * v;

struct ggml_context * ctx = nullptr;

ggml_backend_buffer_t buffer = nullptr;

std::vector<uint8_t> ctx_buf;
};

struct whisper_model {
@@ -941,9 +941,11 @@ static bool whisper_kv_cache_init(
const int64_t n_mem = n_text_layer*n_ctx;
const int64_t n_elements = n_text_state*n_mem;

cache.ctx_buf.resize(2*ggml_tensor_overhead());

struct ggml_init_params params = {
/*.mem_size =*/ 2*ggml_tensor_overhead(),
/*.mem_buffer =*/ nullptr,
/*.mem_size =*/ cache.ctx_buf.size(),
/*.mem_buffer =*/ cache.ctx_buf.data(),
/*.no_alloc =*/ true,
};

@@ -953,31 +955,31 @@
cache.cells.clear();
cache.cells.resize(n_ctx);

cache.ctx = ggml_init(params);
struct ggml_context * ctx = ggml_init(params);

if (!cache.ctx) {
if (!ctx) {
WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache context\n", __func__);
return false;
}

cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
cache.k = ggml_new_tensor_1d(ctx, wtype, n_elements);
cache.v = ggml_new_tensor_1d(ctx, wtype, n_elements);

cache.buffer = ggml_backend_alloc_ctx_tensors(cache.ctx, backend);
cache.buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
if (!cache.buffer) {
WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache\n", __func__);
return false;
}

ggml_backend_buffer_clear(cache.buffer, 0);

ggml_free(ctx);

return true;
}

static void whisper_kv_cache_free(struct whisper_kv_cache & cache) {
ggml_free(cache.ctx);
ggml_backend_buffer_free(cache.buffer);
cache.ctx = nullptr;
}

static bool whisper_kv_cache_find_slot(
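The whisper_kv_cache change also shows the ownership model this PR encourages: tensor metadata lives in a buffer owned by the cache (ctx_buf), tensor data lives in a backend buffer, and the ggml_context is only a temporary used to describe the tensors, so it can be freed as soon as ggml_backend_alloc_ctx_tensors() has bound them to the buffer. Below is a condensed C sketch of that pattern; the my_* names are illustrative, not from the PR, and it assumes malloc returns memory aligned well enough for ggml's GGML_MEM_ALIGN requirement:

    #include <stdbool.h>
    #include <stdint.h>
    #include <stdlib.h>

    #include "ggml.h"
    #include "ggml-alloc.h"
    #include "ggml-backend.h"

    struct my_cache {
        struct ggml_tensor  * k;
        struct ggml_tensor  * v;
        ggml_backend_buffer_t buffer;  // owns the tensor data
        uint8_t             * ctx_buf; // owns the tensor metadata; must outlive k and v
    };

    static bool my_cache_init(struct my_cache * cache, ggml_backend_t backend, int64_t n_elements) {
        const size_t ctx_buf_size = 2*ggml_tensor_overhead(); // room for two tensor structs

        cache->ctx_buf = malloc(ctx_buf_size);
        if (cache->ctx_buf == NULL) {
            return false;
        }

        struct ggml_init_params params = {
            /*.mem_size   =*/ ctx_buf_size,
            /*.mem_buffer =*/ cache->ctx_buf, // caller-owned, so metadata survives ggml_free()
            /*.no_alloc   =*/ true,           // tensor data is allocated by the backend below
        };

        struct ggml_context * ctx = ggml_init(params);
        if (ctx == NULL) {
            return false;
        }

        cache->k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
        cache->v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);

        // place all tensors of ctx into a single backend buffer
        cache->buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);

        // only the context struct is released here; ctx_buf and the backend buffer stay valid
        ggml_free(ctx);

        if (cache->buffer == NULL) {
            return false;
        }

        ggml_backend_buffer_clear(cache->buffer, 0);

        return true;
    }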
@@ -2002,7 +2004,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(

auto & kv_pad = wstate.kv_pad;

WHISPER_ASSERT(!!kv_pad.ctx);
WHISPER_ASSERT(!!kv_pad.buffer);

const int n_ctx_pad = GGML_PAD(n_ctx, 256);

@@ -2416,7 +2418,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(

auto & kv_self = wstate.kv_self;

WHISPER_ASSERT(!!kv_self.ctx);
WHISPER_ASSERT(!!kv_self.buffer);

const int n_ctx = kv_self.size;
const int n_state = hparams.n_text_state;