@@ -1334,28 +1334,31 @@ static void sched_compute_splits(ggml_backend_sched_t sched) {
1334
1334
//ggml_backend_synchronize(split_backend); // necessary to measure compute time
1335
1335
} else {
1336
1336
// similar to ggml_backend_compare_graph_backend
1337
- for (int j = 0 ; j < split -> graph .n_nodes ; j ++ ) {
1338
- struct ggml_tensor * t = split -> graph .nodes [j ];
1337
+ for (int j0 = 0 ; j0 < split -> graph .n_nodes ; j0 ++ ) {
1338
+ struct ggml_tensor * t = split -> graph .nodes [j0 ];
1339
1339
1340
- int k = j ;
1340
+ int j1 = j0 ;
1341
1341
1342
- // check if the user needs data from this node
1343
- while (!sched -> callback_eval (k , t , true, sched -> callback_eval_user_data ) && k < split -> graph .n_nodes - 1 ) {
1344
- t = split -> graph .nodes [++ k ];
1342
+ // determine the range [j0, j1] of nodes that can be computed together
1343
+ while (j1 < split -> graph .n_nodes - 1 ) {
1344
+ // check if the user needs data from this node
1345
+ if (sched -> callback_eval (t , true, sched -> callback_eval_user_data )) {
1346
+ break ;
1347
+ }
1348
+
1349
+ t = split -> graph .nodes [++ j1 ];
1345
1350
}
1346
1351
1347
- struct ggml_cgraph gv = ggml_graph_view (& split -> graph , j , k + 1 );
1352
+ struct ggml_cgraph gv = ggml_graph_view (& split -> graph , j0 , j1 + 1 );
1348
1353
1349
1354
ggml_backend_graph_compute (split_backend , & gv );
1350
1355
1351
- // TODO: k is node index in the split, not in the original graph
1352
- // TODO: avoid the ask == true call here
1353
- if (sched -> callback_eval (k , t , true, sched -> callback_eval_user_data ) &&
1354
- !sched -> callback_eval (k , t , false, sched -> callback_eval_user_data )) {
1356
+ if (sched -> callback_eval (t , true, sched -> callback_eval_user_data ) && // ask
1357
+ !sched -> callback_eval (t , false, sched -> callback_eval_user_data )) { // eval
1355
1358
break ;
1356
1359
}
1357
1360
1358
- j = k ;
1361
+ j0 = j1 ;
1359
1362
}
1360
1363
}
1361
1364
uint64_t compute_end_us = ggml_time_us ();
0 commit comments