Skip to content

Commit 4480542

Browse files
baby-llama : allocate graphs in ggml_context (#5573)
* Fixed the baby-llama issue (see issue #4830)
* minor : fix whitespaces
---------
Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 11b12de commit 4480542

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

examples/baby-llama/baby-llama.cpp

Lines changed: 12 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1533,16 +1533,17 @@ int main(int argc, char ** argv) {
15331533

15341534
int n_past = 0;
15351535

1536-
ggml_cgraph gf = {};
1536+
struct ggml_cgraph * gf = NULL;
1537+
gf = ggml_new_graph_custom(ctx0, LLAMA_TRAIN_MAX_NODES, true);
15371538

15381539
get_example_targets_batch(ctx0, 64*ex+0, tokens_input, targets);
15391540

1540-
struct ggml_tensor * logits = forward_batch(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, n_past, n_batch);
1541+
struct ggml_tensor * logits = forward_batch(&model, &kv_self, ctx0, gf, tokens_input, n_tokens, n_past, n_batch);
15411542
// struct ggml_tensor * e = cross_entropy_loss(ctx0, targets, logits);
15421543
struct ggml_tensor * e = square_error_loss(ctx0, targets, logits);
15431544

1544-
ggml_build_forward_expand(&gf, e);
1545-
ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1);
1545+
ggml_build_forward_expand(gf, e);
1546+
ggml_graph_compute_helper(work_buffer, gf, /*n_threads*/ 1);
15461547

15471548
float error_before_opt = ggml_get_f32_1d(e, 0);
15481549

@@ -1552,8 +1553,8 @@ int main(int argc, char ** argv) {
15521553
opt_params_lbfgs.lbfgs.n_iter = 16;
15531554
ggml_opt(ctx0, opt_params_lbfgs, e);
15541555
//
1555-
ggml_build_forward_expand(&gf, e);
1556-
ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1);
1556+
ggml_build_forward_expand(gf, e);
1557+
ggml_graph_compute_helper(work_buffer, gf, /*n_threads*/ 1);
15571558

15581559
float error_after_opt = ggml_get_f32_1d(e, 0);
15591560

@@ -1600,13 +1601,14 @@ int main(int argc, char ** argv) {
16001601
};
16011602
struct ggml_context * ctx0 = ggml_init(params);
16021603

1603-
ggml_cgraph gf = {};
1604+
struct ggml_cgraph * gf = NULL;
1605+
gf = ggml_new_graph_custom(ctx0, LLAMA_TRAIN_MAX_NODES, true);
16041606

16051607
int n_past = 0;
1606-
struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, &gf, tokens_input, sample_ctx, n_past);
1608+
struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, gf, tokens_input, sample_ctx, n_past);
16071609

1608-
ggml_build_forward_expand(&gf, logits);
1609-
ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1);
1610+
ggml_build_forward_expand(gf, logits);
1611+
ggml_graph_compute_helper(work_buffer, gf, /*n_threads*/ 1);
16101612

16111613
struct ggml_tensor * best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sample_ctx);
16121614
struct ggml_tensor * probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, sample_ctx);

0 commit comments

Comments (0)