@@ -928,24 +928,24 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
928
928
cb (selection_probs, " ffn_moe_probs_biased" , il);
929
929
}
930
930
931
- // select top n_group_exp expert groups
931
+ // select top n_group_used expert groups
932
932
if (arch == LLM_ARCH_BAILINGMOE2) {
933
933
const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups ;
934
934
935
935
// organize experts into n_expert_groups
936
- ggml_tensor * selection_groups = ggml_view_2d (ctx0, ggml_cont (ctx0, ggml_transpose (ctx0, selection_probs)), n_tokens * n_exp_per_group, hparams.n_expert_groups , n_tokens * n_exp_per_group * sizeof (float ), 0 ); // [n_tokens, n_expert_groups]
936
+ ggml_tensor * selection_groups = ggml_view_2d (ctx0, ggml_cont (ctx0, ggml_transpose (ctx0, selection_probs)), n_tokens * n_exp_per_group, hparams.n_expert_groups , n_tokens * n_exp_per_group * sizeof (float ), 0 ); // [n_tokens * n_exp_per_group , n_expert_groups]
937
937
ggml_tensor * group_scores = ggml_top_k (ctx0, selection_groups, 2 ); // [2, n_expert_groups]
938
+ group_scores = ggml_get_rows (ctx0, ggml_reshape_3d (ctx0, selection_groups, 1 , selection_groups->ne [0 ], selection_groups->ne [1 ]), group_scores); // [1, 2, n_expert_groups]
938
939
939
- // get top n_group_exp expert groups
940
- group_scores = ggml_transpose (ctx0, ggml_sum_rows (ctx0, ggml_cast (ctx0, group_scores, GGML_TYPE_F32 ))); // [n_expert_groups, 1]
941
- ggml_tensor * expert_groups = ggml_top_k (ctx0, ggml_cont (ctx0, group_scores), hparams.n_group_exp ); // [n_group_exp , 1]
940
+ // get top n_group_used expert groups
941
+ group_scores = ggml_transpose (ctx0, ggml_sum_rows (ctx0, ggml_reshape_2d (ctx0, group_scores, group_scores-> ne [ 1 ], group_scores-> ne [ 2 ] ))); // [n_expert_groups, 1]
942
+ ggml_tensor * expert_groups = ggml_top_k (ctx0, ggml_cont (ctx0, group_scores), hparams.n_group_used ); // [n_group_used , 1]
942
943
cb (expert_groups->src [0 ], " ffn_moe_group_argsort" , il);
943
944
cb (expert_groups, " ffn_moe_group_topk" , il);
944
945
945
946
// mask out the other groups
946
- selection_probs = ggml_scale_bias (ctx0, selection_groups, 0 .0f , -INFINITY);
947
- group_scores = ggml_repeat_4d (ctx0, expert_groups, selection_probs->ne [1 ], 1 , 1 , 1 ); // [n_expert_groups, 1]
948
- selection_probs = ggml_set_rows (ctx0, selection_probs, selection_groups, group_scores); // [n_tokens, n_expert_groups]
947
+ selection_probs = ggml_get_rows (ctx0, selection_groups, expert_groups); // [n_tokens * n_exp_per_group, n_group_used]
948
+ selection_probs = ggml_set_rows (ctx0, ggml_scale_bias (ctx0, selection_groups, 0 .0f , -INFINITY), selection_probs, expert_groups); // [n_tokens * n_exp_per_group, n_expert_groups]
949
949
selection_probs = ggml_view_2d (ctx0, selection_probs, n_tokens, n_expert, n_tokens * sizeof (float ), 0 ); // [n_tokens, n_expert]
950
950
selection_probs = ggml_cont (ctx0, ggml_transpose (ctx0, selection_probs)); // [n_expert, n_tokens]
951
951
cb (selection_probs, " ffn_moe_probs_masked" , il);
0 commit comments