
Commit ae5a7d4

sync : ggml (ggml_scale, ggml_row_size, etc.) (ggml-org#1677)
* sync : ggml
* sync : llama.cpp
* talk-llama : fix obsolete param
* ggml-alloc : fix ggml_tallocr_is_own
* talk.wasm : update to new ggml
* ggml : fix type punning in ggml_scale
* ggml : cuda jetson + arm quants warnings
1 parent 56b2abd commit ae5a7d4

18 files changed: +3520 −1578 lines

examples/talk-llama/llama.cpp

Lines changed: 2370 additions & 1092 deletions
Large diffs are not rendered by default.

examples/talk-llama/llama.h

Lines changed: 114 additions & 12 deletions
@@ -39,10 +39,11 @@

 #define LLAMA_MAX_RNG_STATE (64*1024)

+#define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
 #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'

 #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
-#define LLAMA_SESSION_VERSION 2
+#define LLAMA_SESSION_VERSION 3

 #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL)
 // Defined when llama.cpp is compiled with support for offloading model layers to GPU.
@@ -126,7 +127,7 @@ extern "C" {
         bool sorted;
     } llama_token_data_array;

-    typedef void (*llama_progress_callback)(float progress, void *ctx);
+    typedef bool (*llama_progress_callback)(float progress, void *ctx);

     // Input data for llama_decode
     // A llama_batch object can contain input about one or many sequences
@@ -158,16 +159,38 @@ extern "C" {
         llama_seq_id all_seq_id; // used if seq_id == NULL
     } llama_batch;

+    enum llama_model_kv_override_type {
+        LLAMA_KV_OVERRIDE_INT,
+        LLAMA_KV_OVERRIDE_FLOAT,
+        LLAMA_KV_OVERRIDE_BOOL,
+    };
+
+    struct llama_model_kv_override {
+        char key[128];
+        enum llama_model_kv_override_type tag;
+        union {
+            int64_t int_value;
+            double  float_value;
+            bool    bool_value;
+        };
+    };
+
     struct llama_model_params {
         int32_t n_gpu_layers; // number of layers to store in VRAM
         int32_t main_gpu;     // the GPU that is used for scratch and small tensors
         const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)

-        // called with a progress value between 0 and 1, pass NULL to disable
+        // Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
+        // If the provided progress_callback returns true, model loading continues.
+        // If it returns false, model loading is immediately aborted.
         llama_progress_callback progress_callback;
+
         // context pointer passed to the progress callback
         void * progress_callback_user_data;

+        // override key-value pairs of the model meta data
+        const struct llama_model_kv_override * kv_overrides;
+
         // Keep the booleans together to avoid misalignment during copy-by-value.
         bool vocab_only; // only load the vocabulary, no weights
         bool use_mmap;   // use mmap if possible
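
Note (usage sketch, not part of this diff): how a caller might use the new bool-returning progress_callback and the kv_overrides array above. It assumes the existing llama_model_default_params() and llama_load_model_from_file() entry points; the override key shown is a placeholder, and the array is assumed to be terminated by an entry with an empty key.

#include <cstdio>
#include <cstring>
#include "llama.h"

// Returning false from the callback aborts model loading.
static bool load_progress(float progress, void * ctx) {
    const bool * abort_requested = (const bool *) ctx;
    fprintf(stderr, "loading: %3.0f%%\r", progress*100.0f);
    return !*abort_requested;
}

static struct llama_model * load_model_sketch(const char * path, bool * abort_flag) {
    struct llama_model_kv_override overrides[2];
    memset(overrides, 0, sizeof(overrides));

    snprintf(overrides[0].key, sizeof(overrides[0].key), "%s", "example.metadata.key"); // placeholder key
    overrides[0].tag       = LLAMA_KV_OVERRIDE_INT;
    overrides[0].int_value = 8;
    // overrides[1] stays zeroed: the empty key terminates the array

    struct llama_model_params mparams = llama_model_default_params();
    mparams.progress_callback           = load_progress;
    mparams.progress_callback_user_data = abort_flag;
    mparams.kv_overrides                = overrides;

    return llama_load_model_from_file(path, mparams);
}
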
@@ -185,17 +208,20 @@ extern "C" {
         // ref: https://github.com/ggerganov/llama.cpp/pull/2054
         float rope_freq_base;   // RoPE base frequency, 0 = from model
         float rope_freq_scale;  // RoPE frequency scaling factor, 0 = from model
-        float yarn_ext_factor;  // YaRN extrapolation mix factor, NaN = from model
+        float yarn_ext_factor;  // YaRN extrapolation mix factor, negative = from model
         float yarn_attn_factor; // YaRN magnitude scaling factor
         float yarn_beta_fast;   // YaRN low correction dim
         float yarn_beta_slow;   // YaRN high correction dim
         uint32_t yarn_orig_ctx; // YaRN original context size

+        enum ggml_type type_k; // data type for K cache
+        enum ggml_type type_v; // data type for V cache
+
         // Keep the booleans together to avoid misalignment during copy-by-value.
-        bool mul_mat_q;  // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
-        bool f16_kv;     // use fp16 for KV cache, fp32 otherwise
-        bool logits_all; // the llama_eval() call computes all logits, not just the last one
-        bool embedding;  // embedding mode only
+        bool mul_mat_q;   // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
+        bool logits_all;  // the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
+        bool embedding;   // embedding mode only
+        bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
     };

     // model quantization parameters
@@ -290,7 +316,9 @@

     LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);

-    LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
+    // TODO: become more consistent with returned int types across the API
+    LLAMA_API uint32_t llama_n_ctx  (const struct llama_context * ctx);
+    LLAMA_API uint32_t llama_n_batch(const struct llama_context * ctx);

     LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);

@@ -301,6 +329,23 @@
     // Get the model's RoPE frequency scaling factor
    LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);

+    // Functions to access the model's GGUF metadata scalar values
+    //  - The functions return the length of the string on success, or -1 on failure
+    //  - The output string is always null-terminated and cleared on failure
+    //  - GGUF array values are not supported by these functions
+
+    // Get metadata value as a string by key name
+    LLAMA_API int llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size);
+
+    // Get the number of metadata key/value pairs
+    LLAMA_API int llama_model_meta_count(const struct llama_model * model);
+
+    // Get metadata key name by index
+    LLAMA_API int llama_model_meta_key_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size);
+
+    // Get metadata value as a string by index
+    LLAMA_API int llama_model_meta_val_str_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size);
+
     // Get a string describing the model type
     LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);

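Note (usage sketch, not part of this diff): enumerating the GGUF metadata of a loaded model with the accessors declared above. The buffer sizes are arbitrary; values longer than the buffer are assumed to be truncated but still null-terminated.

#include <cstdio>
#include "llama.h"

static void dump_model_meta(const struct llama_model * model) {
    char key[256];
    char val[512];
    const int n_kv = llama_model_meta_count(model);
    for (int i = 0; i < n_kv; i++) {
        // skip entries that cannot be read (e.g. GGUF arrays)
        if (llama_model_meta_key_by_index    (model, i, key, sizeof(key)) < 0) continue;
        if (llama_model_meta_val_str_by_index(model, i, val, sizeof(val)) < 0) continue;
        printf("%-40s = %s\n", key, val);
    }
}
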
@@ -344,9 +389,60 @@
     // KV cache
     //

-    // Returns the number of tokens in the KV cache
-    LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
-            "avoid using this, it will be removed in the future, instead - count the tokens in user code");
+    // Information associated with an individual cell in the KV cache view.
+    struct llama_kv_cache_view_cell {
+        // The position for this cell. Takes KV cache shifts into account.
+        // May be negative if the cell is not populated.
+        llama_pos pos;
+    };
+
+    // An updateable view of the KV cache.
+    struct llama_kv_cache_view {
+        // Number of KV cache cells. This will be the same as the context size.
+        int32_t n_cells;
+
+        // Maximum number of sequences that can exist in a cell. It's not an error
+        // if there are more sequences in a cell than this value, however they will
+        // not be visible in the view cells_sequences.
+        int32_t n_max_seq;
+
+        // Number of tokens in the cache. For example, if there are two populated
+        // cells, the first with 1 sequence id in it and the second with 2 sequence
+        // ids then you'll have 3 tokens.
+        int32_t token_count;
+
+        // Number of populated cache cells.
+        int32_t used_cells;
+
+        // Maximum contiguous empty slots in the cache.
+        int32_t max_contiguous;
+
+        // Index to the start of the max_contiguous slot range. Can be negative
+        // when cache is full.
+        int32_t max_contiguous_idx;
+
+        // Information for an individual cell.
+        struct llama_kv_cache_view_cell * cells;
+
+        // The sequences for each cell. There will be n_max_seq items per cell.
+        llama_seq_id * cells_sequences;
+    };
+
+    // Create an empty KV cache view. (use only for debugging purposes)
+    LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq);
+
+    // Free a KV cache view. (use only for debugging purposes)
+    LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);
+
+    // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
+    LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
+
+    // Returns the number of tokens in the KV cache (slow, use only for debug)
+    // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
+    LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
+
+    // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
+    LLAMA_API int llama_get_kv_cache_used_cells(const struct llama_context * ctx);

     // Clear the KV cache
     LLAMA_API void llama_kv_cache_clear(
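
Note (usage sketch, not part of this diff): the debug-only KV cache view above, initialized once, refreshed with llama_kv_cache_view_update(), and freed when done. The choice of 4 tracked sequence ids per cell is arbitrary.

#include <cstdio>
#include "llama.h"

static void print_kv_cache_stats(const struct llama_context * ctx) {
    // track up to 4 sequence ids per cell in the view
    struct llama_kv_cache_view view = llama_kv_cache_view_init(ctx, 4);

    llama_kv_cache_view_update(ctx, &view);
    printf("cells: %d, used: %d, tokens: %d, max contiguous free: %d (starting at %d)\n",
            view.n_cells, view.used_cells, view.token_count,
            view.max_contiguous, view.max_contiguous_idx);

    llama_kv_cache_view_free(&view);
}
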
@@ -517,6 +613,12 @@
     LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
     LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line

+    // Returns -1 if unknown, 1 for true or 0 for false.
+    LLAMA_API int llama_add_bos_token(const struct llama_model * model);
+
+    // Returns -1 if unknown, 1 for true or 0 for false.
+    LLAMA_API int llama_add_eos_token(const struct llama_model * model);
+
     // codellama infill tokens
     LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
     LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
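
Note (usage sketch, not part of this diff): consuming the tri-state result of llama_add_bos_token(). The SentencePiece fallback for the unknown case is an assumption of this sketch, not something the header mandates.

#include "llama.h"

static bool should_add_bos(const struct llama_model * model) {
    const int add_bos = llama_add_bos_token(model); // -1 = unknown, 0 = false, 1 = true
    if (add_bos != -1) {
        return add_bos == 1;
    }
    // unknown: assume SentencePiece-style vocabularies want a leading BOS
    return llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM;
}
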

examples/talk-llama/talk-llama.cpp

Lines changed: 0 additions & 1 deletion
@@ -282,7 +282,6 @@ int main(int argc, char ** argv) {
     // tune these to your liking
     lcparams.n_ctx      = 2048;
     lcparams.seed       = 1;
-    lcparams.f16_kv     = true;
     lcparams.n_threads  = params.n_threads;

     struct llama_context * ctx_llama = llama_new_context_with_model(model_llama, lcparams);
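
Note (usage sketch, not part of this diff): the removed f16_kv flag is superseded by the type_k/type_v fields of llama_context_params (see the llama.h changes above), which already default to fp16, so the assignment can simply be dropped. Setting them explicitly would look roughly like this:

#include "llama.h"

static struct llama_context_params make_ctx_params(uint32_t n_threads) {
    struct llama_context_params lcparams = llama_context_default_params();
    lcparams.n_ctx     = 2048;
    lcparams.seed      = 1;
    lcparams.n_threads = n_threads;
    // per-cache data types replace the old f16_kv boolean
    lcparams.type_k = GGML_TYPE_F16;
    lcparams.type_v = GGML_TYPE_F16;
    return lcparams;
}
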

examples/talk.wasm/gpt-2.cpp

Lines changed: 20 additions & 21 deletions
@@ -155,33 +155,33 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
     const int n_ctx   = hparams.n_ctx;
     const int n_vocab = hparams.n_vocab;

-    ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
-    ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b
+    ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_g
+    ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_b

-    ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype);       // wte
-    ctx_size += n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe
-    ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype);       // lm_head
+    ctx_size += n_vocab*ggml_row_size(wtype, n_embd);        // wte
+    ctx_size += n_ctx*ggml_row_size(GGML_TYPE_F32, n_embd);  // wpe
+    ctx_size += n_vocab*ggml_row_size(wtype, n_embd);        // lm_head

-    ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
-    ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
+    ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_g
+    ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_b

-    ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
-    ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b
+    ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_g
+    ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_b

-    ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype));         // c_attn_attn_w
-    ctx_size += n_layer*(       3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b
+    ctx_size += n_layer*(ggml_row_size(wtype,         3*n_embd*n_embd)); // c_attn_attn_w
+    ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 3*n_embd));        // c_attn_attn_b

-    ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype));         // c_attn_proj_w
-    ctx_size += n_layer*(       n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_proj_b
+    ctx_size += n_layer*(ggml_row_size(wtype,         n_embd*n_embd)); // c_attn_proj_w
+    ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd));        // c_attn_proj_b

-    ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype));         // c_mlp_fc_w
-    ctx_size += n_layer*(       4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
+    ctx_size += n_layer*(ggml_row_size(wtype,         4*n_embd*n_embd)); // c_mlp_fc_w
+    ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 4*n_embd));        // c_mlp_fc_b

-    ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype));         // c_mlp_proj_w
-    ctx_size += n_layer*(       n_embd*ggml_type_sizef(GGML_TYPE_F32));   // c_mlp_proj_b
+    ctx_size += n_layer*(ggml_row_size(wtype,         4*n_embd*n_embd)); // c_mlp_proj_w
+    ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd));          // c_mlp_proj_b

-    ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
-    ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
+    ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_k
+    ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_v

     ctx_size += (6 + 12*n_layer)*256; // object overhead

@@ -524,8 +524,7 @@ bool gpt2_eval(
         struct ggml_tensor * KQ_scaled =
             ggml_scale(ctx0,
                     KQ,
-                    ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
-                    );
+                    1.0f/sqrt(float(n_embd)/n_head));

         // KQ_masked = mask_past(KQ_scaled)
         // [n_past + N, N, 12]
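
Note (illustrative sketch, not part of this diff): the two API changes behind the edits above, shown in isolation. ggml_row_size(type, ne) returns the exact byte size of ne elements of a possibly block-quantized type and replaces the floating-point n*ggml_type_sizef(type) estimate, while ggml_scale() now takes the scale factor as a plain float rather than a scalar tensor created with ggml_new_f32(). The helper names here are made up for the example.

#include <cmath>
#include "ggml.h"

// size accounting, new style: exact and integer-only
static size_t layer_attn_bytes(enum ggml_type wtype, int n_embd, int n_layer) {
    size_t sz = 0;
    sz += n_layer*ggml_row_size(wtype,         3*n_embd*n_embd); // c_attn_attn_w
    sz += n_layer*ggml_row_size(GGML_TYPE_F32, 3*n_embd);        // c_attn_attn_b
    return sz;
}

// scaling, new style: no intermediate scalar tensor needed
static struct ggml_tensor * scale_kq(struct ggml_context * ctx0, struct ggml_tensor * KQ, int n_embd, int n_head) {
    return ggml_scale(ctx0, KQ, 1.0f/sqrtf(float(n_embd)/n_head));
}
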

examples/talk/gpt-2.cpp

Lines changed: 20 additions & 21 deletions
@@ -155,33 +155,33 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
     const int n_ctx   = hparams.n_ctx;
     const int n_vocab = hparams.n_vocab;

-    ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
-    ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b
+    ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_g
+    ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_b

-    ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype);       // wte
-    ctx_size += n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe
-    ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype);       // lm_head
+    ctx_size += n_vocab*ggml_row_size(wtype, n_embd);        // wte
+    ctx_size += n_ctx*ggml_row_size(GGML_TYPE_F32, n_embd);  // wpe
+    ctx_size += n_vocab*ggml_row_size(wtype, n_embd);        // lm_head

-    ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
-    ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
+    ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_g
+    ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_b

-    ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
-    ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b
+    ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_g
+    ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_b

-    ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype));         // c_attn_attn_w
-    ctx_size += n_layer*(       3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b
+    ctx_size += n_layer*(ggml_row_size(wtype,         3*n_embd*n_embd)); // c_attn_attn_w
+    ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 3*n_embd));        // c_attn_attn_b

-    ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype));         // c_attn_proj_w
-    ctx_size += n_layer*(       n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_proj_b
+    ctx_size += n_layer*(ggml_row_size(wtype,         n_embd*n_embd)); // c_attn_proj_w
+    ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd));        // c_attn_proj_b

-    ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype));         // c_mlp_fc_w
-    ctx_size += n_layer*(       4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
+    ctx_size += n_layer*(ggml_row_size(wtype,         4*n_embd*n_embd)); // c_mlp_fc_w
+    ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 4*n_embd));        // c_mlp_fc_b

-    ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype));         // c_mlp_proj_w
-    ctx_size += n_layer*(       n_embd*ggml_type_sizef(GGML_TYPE_F32));   // c_mlp_proj_b
+    ctx_size += n_layer*(ggml_row_size(wtype,         4*n_embd*n_embd)); // c_mlp_proj_w
+    ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd));          // c_mlp_proj_b

-    ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
-    ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
+    ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_k
+    ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_v

     ctx_size += (6 + 12*n_layer)*256; // object overhead

@@ -525,8 +525,7 @@ bool gpt2_eval(
         struct ggml_tensor * KQ_scaled =
             ggml_scale(ctx0,
                     KQ,
-                    ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
-                    );
+                    1.0f/sqrt(float(n_embd)/n_head));

         // KQ_masked = mask_past(KQ_scaled)
         // [n_past + N, N, 12]

extra/sync-llama.sh

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+#!/bin/bash
+
+cp -rpv ../llama.cpp/llama.h   ./examples/talk-llama/llama.h
+cp -rpv ../llama.cpp/llama.cpp ./examples/talk-llama/llama.cpp
+cp -rpv ../llama.cpp/unicode.h ./examples/talk-llama/unicode.h

ggml-alloc.c

Lines changed: 13 additions & 5 deletions
@@ -72,7 +72,7 @@ static void remove_allocated_tensor(ggml_tallocr_t alloc, struct ggml_tensor * t

 // check if a tensor is allocated by this buffer
 static bool ggml_tallocr_is_own(ggml_tallocr_t alloc, const struct ggml_tensor * tensor) {
-    return tensor->buffer == alloc->buffer;
+    return tensor->buffer == alloc->buffer && (!tensor->view_src || tensor->view_src->buffer == alloc->buffer);
 }

 static bool ggml_is_view(struct ggml_tensor * t) {

@@ -449,11 +449,10 @@ static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool upd
     if (update_backend) {
         view->backend = view->view_src->backend;
     }
-    view->buffer = view->view_src->buffer;
+    // views are initialized in the alloc buffer rather than the view_src buffer
+    view->buffer = alloc->buffer;
     view->data   = (char *)view->view_src->data + view->view_offs;

-    // FIXME: the view should be initialized by the owning buffer, but currently this breaks the CUDA backend
-    // due to the ggml_tensor_extra_gpu ring buffer overwriting the KV cache extras
     assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->buft == alloc->buffer->buft);

     if (!alloc->measure) {

@@ -736,6 +735,10 @@ void ggml_allocr_set_parse_seq(ggml_allocr_t alloc, const int * list, int n) {
 }

 void ggml_allocr_free(ggml_allocr_t alloc) {
+    if (alloc == NULL) {
+        return;
+    }
+
     ggml_gallocr_free(alloc->galloc);
     ggml_tallocr_free(alloc->talloc);
     free(alloc);

@@ -775,7 +778,7 @@ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_conte
     }

     if (nbytes == 0) {
-        fprintf(stderr, "%s: no tensors to allocate\n", __func__);
+        // all the tensors in the context are already allocated
         return NULL;
     }

@@ -789,6 +792,11 @@ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_conte
             } else {
                 ggml_backend_view_init(buffer, t);
             }
+        } else {
+            if (t->view_src != NULL) {
+                // view of a pre-allocated tensor
+                ggml_backend_view_init(buffer, t);
+            }
         }
     }

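Note (usage sketch, not part of this diff): with the NULL check added to ggml_allocr_free() above, cleanup paths can free unconditionally.

#include "ggml-alloc.h"

static void free_allocators(ggml_allocr_t alloc_a, ggml_allocr_t alloc_b) {
    // either argument may be NULL, e.g. when initialization failed part-way
    ggml_allocr_free(alloc_a);
    ggml_allocr_free(alloc_b);
}
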