Commit f885540

correct group selection and rename n_group_exp

1 parent a3904c0 commit f885540

3 files changed: +11 -11 lines changed

src/llama-graph.cpp

Lines changed: 8 additions & 8 deletions

@@ -928,24 +928,24 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         cb(selection_probs, "ffn_moe_probs_biased", il);
     }

-    // select top n_group_exp expert groups
+    // select top n_group_used expert groups
     if (arch == LLM_ARCH_BAILINGMOE2) {
         const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups;

         // organize experts into n_expert_groups
-        ggml_tensor * selection_groups = ggml_view_2d(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, selection_probs)), n_tokens * n_exp_per_group, hparams.n_expert_groups, n_tokens * n_exp_per_group * sizeof(float), 0); // [n_tokens, n_expert_groups]
+        ggml_tensor * selection_groups = ggml_view_2d(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, selection_probs)), n_tokens * n_exp_per_group, hparams.n_expert_groups, n_tokens * n_exp_per_group * sizeof(float), 0); // [n_tokens * n_exp_per_group, n_expert_groups]

         ggml_tensor * group_scores = ggml_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups]
+        group_scores = ggml_get_rows(ctx0, ggml_reshape_3d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1]), group_scores); // [1, 2, n_expert_groups]

-        // get top n_group_exp expert groups
-        group_scores = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cast(ctx0, group_scores, GGML_TYPE_F32))); // [n_expert_groups, 1]
-        ggml_tensor * expert_groups = ggml_top_k(ctx0, ggml_cont(ctx0, group_scores), hparams.n_group_exp); // [n_group_exp, 1]
+        // get top n_group_used expert groups
+        group_scores = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]))); // [n_expert_groups, 1]
+        ggml_tensor * expert_groups = ggml_top_k(ctx0, ggml_cont(ctx0, group_scores), hparams.n_group_used); // [n_group_used, 1]
         cb(expert_groups->src[0], "ffn_moe_group_argsort", il);
         cb(expert_groups, "ffn_moe_group_topk", il);

         // mask out the other groups
-        selection_probs = ggml_scale_bias(ctx0, selection_groups, 0.0f, -INFINITY);
-        group_scores = ggml_repeat_4d(ctx0, expert_groups, selection_probs->ne[1], 1, 1, 1); // [n_expert_groups, 1]
-        selection_probs = ggml_set_rows(ctx0, selection_probs, selection_groups, group_scores); // [n_tokens, n_expert_groups]
+        selection_probs = ggml_get_rows(ctx0, selection_groups, expert_groups); // [n_tokens * n_exp_per_group, n_group_used]
+        selection_probs = ggml_set_rows(ctx0, ggml_scale_bias(ctx0, selection_groups, 0.0f, -INFINITY), selection_probs, expert_groups); // [n_tokens * n_exp_per_group, n_expert_groups]
         selection_probs = ggml_view_2d(ctx0, selection_probs, n_tokens, n_expert, n_tokens * sizeof(float), 0); // [n_tokens, n_expert]
         selection_probs = ggml_cont(ctx0, ggml_transpose(ctx0, selection_probs)); // [n_expert, n_tokens]
         cb(selection_probs, "ffn_moe_probs_masked", il);
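For intuition, here is a minimal, self-contained C++ sketch (plain STL, not ggml) of the grouped selection scheme the corrected graph implements: score each expert group by the sum of its top-2 expert probabilities, keep the best n_group_used groups, and mask the experts of all other groups to -inf. The sizes and probability values are toy assumptions for illustration, not model values.

    // hypothetical sketch of grouped expert selection (single token),
    // mirroring the ggml graph above; toy sizes are assumptions
    #include <algorithm>
    #include <cstdio>
    #include <functional>
    #include <limits>
    #include <utility>
    #include <vector>

    int main() {
        const int n_expert        = 8;  // assumed toy sizes
        const int n_expert_groups = 4;
        const int n_group_used    = 2;
        const int n_exp_per_group = n_expert / n_expert_groups;

        // selection probabilities for one token (assumed values)
        std::vector<float> probs = {0.9f, 0.1f, 0.2f, 0.3f, 0.8f, 0.7f, 0.05f, 0.15f};

        // score each group by the sum of its top-2 expert probabilities
        // (this is what ggml_top_k(..., 2) + ggml_get_rows + ggml_sum_rows compute)
        std::vector<std::pair<float, int>> group_scores; // (score, group index)
        for (int g = 0; g < n_expert_groups; ++g) {
            std::vector<float> grp(probs.begin() +  g      * n_exp_per_group,
                                   probs.begin() + (g + 1) * n_exp_per_group);
            const int k = std::min(2, n_exp_per_group);
            std::partial_sort(grp.begin(), grp.begin() + k, grp.end(), std::greater<float>());
            float score = 0.0f;
            for (int i = 0; i < k; ++i) score += grp[i];
            group_scores.push_back({score, g});
        }

        // keep the top n_group_used groups (the second ggml_top_k)
        std::partial_sort(group_scores.begin(), group_scores.begin() + n_group_used,
                          group_scores.end(), std::greater<>());
        std::vector<bool> keep(n_expert_groups, false);
        for (int i = 0; i < n_group_used; ++i) keep[group_scores[i].second] = true;

        // mask out experts in discarded groups (the scale_bias + set_rows step)
        for (int e = 0; e < n_expert; ++e) {
            if (!keep[e / n_exp_per_group]) probs[e] = -std::numeric_limits<float>::infinity();
        }

        for (int e = 0; e < n_expert; ++e) printf("expert %d: %f\n", e, probs[e]);
        return 0;
    }

With the toy values above, groups 0 and 2 score highest (1.0 and 1.5), so experts 2, 3, 6, and 7 are masked to -inf before the per-expert top-k that follows in the graph.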

src/llama-hparams.h

Lines changed: 1 addition & 1 deletion

@@ -73,7 +73,7 @@ struct llama_hparams {
     uint32_t n_expert_shared = 0;
     uint32_t n_norm_groups = 0;
     uint32_t n_expert_groups = 0;
-    uint32_t n_group_exp = 0;
+    uint32_t n_group_used = 0;
     uint32_t n_group_experts = 0;

     float expert_group_scale = 0.05f;

src/llama-model.cpp

Lines changed: 2 additions & 2 deletions

@@ -1877,7 +1877,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp);
                 ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
                 ml.get_key(LLM_KV_EXPERT_GROUP_COUNT, hparams.n_expert_groups);
-                ml.get_key(LLM_KV_EXPERT_GROUP_USED_COUNT, hparams.n_group_exp);
+                ml.get_key(LLM_KV_EXPERT_GROUP_USED_COUNT, hparams.n_group_used);
                 ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
                 ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
                 ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func);
@@ -6338,7 +6338,7 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
         LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared);
         LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups);
-        LLAMA_LOG_INFO("%s: n_group_exp = %d\n", __func__, hparams.n_group_exp);
+        LLAMA_LOG_INFO("%s: n_group_used = %d\n", __func__, hparams.n_group_used);
         LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
         LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
         LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
