
Commit c982ba7

CISC authored and pwilkin committed
model : add GroveMoE support (#15510)
* add GroveMoE support
* remove constexpr that fails on certain compilers
* revert crude scalar div implementation, use cast
* build_attn_inp_kv_unified -> build_attn_inp_kv
* fix build_attn
* re-apply ffn_exps regex changes
1 parent c638b8b commit c982ba7

File tree

11 files changed: +451 −4 lines changed

common/common.h

Lines changed: 1 addition & 1 deletion
@@ -738,7 +738,7 @@ const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
 // MoE utils
 //
 
-const char * const LLM_FFN_EXPS_REGEX = "\\.ffn_(up|down|gate)_exps";
+const char * const LLM_FFN_EXPS_REGEX = "\\.ffn_(up|down|gate)_(ch|)exps";
 
 static std::string llm_ffn_exps_block_regex(int idx) {
     return string_format("blk\\.%d%s", idx, LLM_FFN_EXPS_REGEX);
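
Illustration only (not part of the commit): a Python transcription of the widened pattern, showing that per-block expert matching now also picks up GroveMoE's chunk-expert ("chexps") tensors.

import re

# Python stand-in for the C++ constant above
LLM_FFN_EXPS_REGEX = r"\.ffn_(up|down|gate)_(ch|)exps"

def llm_ffn_exps_block_regex(idx: int) -> str:
    return rf"blk\.{idx}{LLM_FFN_EXPS_REGEX}"

for name in ["blk.0.ffn_up_exps", "blk.0.ffn_gate_chexps", "blk.0.ffn_norm"]:
    print(name, bool(re.search(llm_ffn_exps_block_regex(0), name)))
# blk.0.ffn_up_exps True
# blk.0.ffn_gate_chexps True
# blk.0.ffn_norm False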

convert_hf_to_gguf.py

Lines changed: 115 additions & 0 deletions
@@ -7995,6 +7995,121 @@ def prepare_tensors(self):
             raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@ModelBase.register("GroveMoeForCausalLM", "modeling_grove_moe.GroveMoeForCausalLM")
+class GroveMoeModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.GROVEMOE
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        if (n_experts := self.hparams.get("num_experts")) is not None:
+            self.gguf_writer.add_expert_count(n_experts)
+        if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
+            self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+            logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}")
+        # FIXME?: Hardcoded https://huggingface.co/inclusionAI/GroveMoE-Inst/blob/c4c69e5970d18907b5e6ddccdfd55176fe292df1/modeling_grove_moe.py#L299
+        self.gguf_writer.add_expert_chunk_feed_forward_length(self.hparams.get("head_dim") or 128)
+        # FIXME?: Hardcoded https://huggingface.co/inclusionAI/GroveMoE-Inst/blob/c4c69e5970d18907b5e6ddccdfd55176fe292df1/modeling_grove_moe.py#L298
+        self.gguf_writer.add_experts_per_group(2)
+        # FIXME?: Hardcoded https://huggingface.co/inclusionAI/GroveMoE-Inst/blob/c4c69e5970d18907b5e6ddccdfd55176fe292df1/modeling_grove_moe.py#L376
+        self.gguf_writer.add_expert_group_scale(0.05)
+        # YaRN is not enabled by default
+        # To enable it, please refer to this guide: https://huggingface.co/Qwen/Qwen3-30B-A3B#processing-long-texts
+        rope_scaling = self.hparams.get("rope_scaling") or {}
+        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
+
+    _experts: list[dict[str, Tensor]] | None = None
+    _chunk_experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name.endswith(".expert_bias"):
+            # FIXME?: Unused https://huggingface.co/inclusionAI/GroveMoE-Inst/blob/c4c69e5970d18907b5e6ddccdfd55176fe292df1/modeling_grove_moe.py#L303
+            return []
+
+        # process the experts separately
+        if name.find("chunk_experts") != -1:
+            n_experts = self.hparams["num_experts"] // 2  # see add_experts_per_group
+            assert bid is not None
+
+            if self._chunk_experts is None:
+                self._chunk_experts = [{} for _ in range(self.block_count)]
+
+            self._chunk_experts[bid][name] = data_torch
+
+            if len(self._chunk_experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.mlp.chunk_experts.{xid}.{w_name}.weight"
+                        datas.append(self._chunk_experts[bid][ename])
+                        del self._chunk_experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    merged_name = f"model.layers.{bid}.mlp.chunk_experts.{w_name}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    tensors.append((new_name, data_torch))
+                return tensors
+            else:
+                return []
+        elif name.find("experts") != -1:
+            n_experts = self.hparams["num_experts"]
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    tensors.append((new_name, data_torch))
+                return tensors
+            else:
+                return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+        if self._chunk_experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            chunk_experts = [k for d in self._chunk_experts for k in d.keys()]
+            if len(chunk_experts) > 0:
+                raise ValueError(f"Unprocessed adjugate experts: {chunk_experts}")
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 @ModelBase.register("ChameleonForConditionalGeneration")
 @ModelBase.register("ChameleonForCausalLM")  # obsolete
 class ChameleonModel(TextModel):

gguf-py/gguf/constants.py

Lines changed: 31 additions & 0 deletions
@@ -96,6 +96,7 @@ class LLM:
         FEED_FORWARD_LENGTH               = "{arch}.feed_forward_length"
         EXPERT_FEED_FORWARD_LENGTH        = "{arch}.expert_feed_forward_length"
         EXPERT_SHARED_FEED_FORWARD_LENGTH = "{arch}.expert_shared_feed_forward_length"
+        EXPERT_CHUNK_FEED_FORWARD_LENGTH  = "{arch}.expert_chunk_feed_forward_length"
         USE_PARALLEL_RESIDUAL             = "{arch}.use_parallel_residual"
         TENSOR_DATA_LAYOUT                = "{arch}.tensor_data_layout"
         EXPERT_COUNT                      = "{arch}.expert_count"
@@ -104,6 +105,8 @@ class LLM:
         EXPERT_WEIGHTS_SCALE              = "{arch}.expert_weights_scale"
         EXPERT_WEIGHTS_NORM               = "{arch}.expert_weights_norm"
         EXPERT_GATING_FUNC                = "{arch}.expert_gating_func"
+        EXPERT_GROUP_SCALE                = "{arch}.expert_group_scale"
+        EXPERTS_PER_GROUP                 = "{arch}.experts_per_group"
         MOE_EVERY_N_LAYERS                = "{arch}.moe_every_n_layers"
         NEXTN_PREDICT_LAYERS              = "{arch}.nextn_predict_layers"
         POOLING_TYPE                      = "{arch}.pooling_type"
@@ -409,6 +412,7 @@ class MODEL_ARCH(IntEnum):
     LLADA_MOE = auto()
     SEED_OSS = auto()
     APERTUS = auto()
+    GROVEMOE = auto()
 
 
 class VISION_PROJECTOR_TYPE(IntEnum):
@@ -458,6 +462,9 @@ class MODEL_TENSOR(IntEnum):
     FFN_GATE_SHEXP = auto()
     FFN_DOWN_SHEXP = auto()
     FFN_UP_SHEXP = auto()
+    FFN_GATE_CHEXP = auto()
+    FFN_DOWN_CHEXP = auto()
+    FFN_UP_CHEXP = auto()
     FFN_EXP_PROBS_B = auto()
     ATTN_Q_NORM = auto()
     ATTN_K_NORM = auto()
@@ -747,6 +754,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.LLADA_MOE: "llada-moe",
     MODEL_ARCH.SEED_OSS: "seed_oss",
     MODEL_ARCH.APERTUS: "apertus",
+    MODEL_ARCH.GROVEMOE: "grovemoe",
 }
 
 VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -793,6 +801,9 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.FFN_GATE_SHEXP: "blk.{bid}.ffn_gate_shexp",
     MODEL_TENSOR.FFN_DOWN_SHEXP: "blk.{bid}.ffn_down_shexp",
     MODEL_TENSOR.FFN_UP_SHEXP: "blk.{bid}.ffn_up_shexp",
+    MODEL_TENSOR.FFN_GATE_CHEXP: "blk.{bid}.ffn_gate_chexps",
+    MODEL_TENSOR.FFN_DOWN_CHEXP: "blk.{bid}.ffn_down_chexps",
+    MODEL_TENSOR.FFN_UP_CHEXP: "blk.{bid}.ffn_up_chexps",
     MODEL_TENSOR.FFN_ACT: "blk.{bid}.ffn",
     MODEL_TENSOR.FFN_NORM_EXP: "blk.{bid}.ffn_norm_exps",
     MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps",
@@ -2739,6 +2750,26 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_UP_EXP,
         MODEL_TENSOR.FFN_DOWN_EXP,
     ],
+    MODEL_ARCH.GROVEMOE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_GATE_CHEXP,
+        MODEL_TENSOR.FFN_DOWN_CHEXP,
+        MODEL_TENSOR.FFN_UP_CHEXP,
+    ],
     # TODO
 }

gguf-py/gguf/gguf_writer.py

Lines changed: 9 additions & 0 deletions
@@ -670,6 +670,9 @@ def add_expert_feed_forward_length(self, length: int) -> None:
     def add_expert_shared_feed_forward_length(self, length: int) -> None:
         self.add_uint32(Keys.LLM.EXPERT_SHARED_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
 
+    def add_expert_chunk_feed_forward_length(self, length: int) -> None:
+        self.add_uint32(Keys.LLM.EXPERT_CHUNK_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
+
     def add_parallel_residual(self, use: bool) -> None:
         self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
 
@@ -757,6 +760,12 @@ def add_expert_weights_norm(self, value: bool) -> None:
     def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None:
         self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value)
 
+    def add_expert_group_scale(self, value: float) -> None:
+        self.add_float32(Keys.LLM.EXPERT_GROUP_SCALE.format(arch=self.arch), value)
+
+    def add_experts_per_group(self, count: int) -> None:
+        self.add_uint32(Keys.LLM.EXPERTS_PER_GROUP.format(arch=self.arch), count)
+
     def add_moe_every_n_layers(self, value: int) -> None:
         self.add_uint32(Keys.LLM.MOE_EVERY_N_LAYERS.format(arch=self.arch), value)
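
For reference, the calls GroveMoeModel.set_gguf_parameters makes above resolve to the following metadata keys; `writer` here stands in for an assumed gguf.GGUFWriter instance, and the values are the hardcoded GroveMoE-Inst defaults noted in the converter's FIXME comments:

# sketch only; assumes writer was constructed for arch "grovemoe"
writer.add_expert_chunk_feed_forward_length(128)  # grovemoe.expert_chunk_feed_forward_length (u32)
writer.add_experts_per_group(2)                   # grovemoe.experts_per_group (u32)
writer.add_expert_group_scale(0.05)               # grovemoe.expert_group_scale (f32)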

gguf-py/gguf/tensor_mapping.py

Lines changed: 12 additions & 0 deletions
@@ -429,6 +429,10 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan
         ),
 
+        MODEL_TENSOR.FFN_UP_CHEXP: (
+            "model.layers.{bid}.mlp.chunk_experts.up_proj", # grovemoe
+        ),
+
         # AWQ-activation gate
         MODEL_TENSOR.FFN_ACT: (
             "transformer.blocks.{bid}.ffn.act", # mpt
@@ -470,6 +474,10 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.shared_mlp.gate_proj", # hunyuan
         ),
 
+        MODEL_TENSOR.FFN_GATE_CHEXP: (
+            "model.layers.{bid}.mlp.chunk_experts.gate_proj", # grovemoe
+        ),
+
         # Feed-forward down
         MODEL_TENSOR.FFN_DOWN: (
             "gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox
@@ -526,6 +534,10 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.shared_mlp.down_proj", # hunyuan
         ),
 
+        MODEL_TENSOR.FFN_DOWN_CHEXP: (
+            "model.layers.{bid}.mlp.chunk_experts.down_proj", # grovemoe
+        ),
+
         MODEL_TENSOR.ATTN_Q_NORM: (
             "language_model.encoder.layers.{bid}.self_attention.q_layernorm",
             "model.layers.{bid}.self_attn.q_layernorm", # persimmon

src/llama-arch.cpp

Lines changed: 30 additions & 0 deletions
@@ -99,6 +99,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_LLADA_MOE,        "llada-moe"  },
     { LLM_ARCH_SEED_OSS,         "seed_oss"   },
     { LLM_ARCH_APERTUS,          "apertus"    },
+    { LLM_ARCH_GROVEMOE,         "grovemoe"   },
     { LLM_ARCH_UNKNOWN,          "(unknown)"  },
 };
 
@@ -126,6 +127,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_FEED_FORWARD_LENGTH,               "%s.feed_forward_length"               },
     { LLM_KV_EXPERT_FEED_FORWARD_LENGTH,        "%s.expert_feed_forward_length"        },
     { LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, "%s.expert_shared_feed_forward_length" },
+    { LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH,  "%s.expert_chunk_feed_forward_length"  },
     { LLM_KV_USE_PARALLEL_RESIDUAL,             "%s.use_parallel_residual"             },
     { LLM_KV_TENSOR_DATA_LAYOUT,                "%s.tensor_data_layout"                },
     { LLM_KV_EXPERT_COUNT,                      "%s.expert_count"                      },
@@ -134,6 +136,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_EXPERT_WEIGHTS_SCALE,              "%s.expert_weights_scale"              },
     { LLM_KV_EXPERT_WEIGHTS_NORM,               "%s.expert_weights_norm"               },
     { LLM_KV_EXPERT_GATING_FUNC,                "%s.expert_gating_func"                },
+    { LLM_KV_EXPERT_GROUP_SCALE,                "%s.expert_group_scale"                },
+    { LLM_KV_EXPERTS_PER_GROUP,                 "%s.experts_per_group"                 },
     { LLM_KV_MOE_EVERY_N_LAYERS,                "%s.moe_every_n_layers"                },
     { LLM_KV_NEXTN_PREDICT_LAYERS,              "%s.nextn_predict_layers"              },
     { LLM_KV_POOLING_TYPE,                      "%s.pooling_type"                      },
@@ -2211,6 +2215,29 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_GROVEMOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_CHEXPS, "blk.%d.ffn_gate_chexps" },
+            { LLM_TENSOR_FFN_DOWN_CHEXPS, "blk.%d.ffn_down_chexps" },
+            { LLM_TENSOR_FFN_UP_CHEXPS,   "blk.%d.ffn_up_chexps" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -2343,6 +2370,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_FFN_DOWN_EXPS,   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_GATE_EXPS,   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_UP_EXPS,     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
+    {LLM_TENSOR_FFN_DOWN_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
+    {LLM_TENSOR_FFN_GATE_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
+    {LLM_TENSOR_FFN_UP_CHEXPS,   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     // altup / laurel (gemma 3n)
     {LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},

src/llama-arch.h

Lines changed: 7 additions & 0 deletions
@@ -103,6 +103,7 @@ enum llm_arch {
     LLM_ARCH_LLADA_MOE,
     LLM_ARCH_SEED_OSS,
     LLM_ARCH_APERTUS,
+    LLM_ARCH_GROVEMOE,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -130,6 +131,7 @@ enum llm_kv {
     LLM_KV_FEED_FORWARD_LENGTH,
     LLM_KV_EXPERT_FEED_FORWARD_LENGTH,
     LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH,
+    LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH,
     LLM_KV_USE_PARALLEL_RESIDUAL,
     LLM_KV_TENSOR_DATA_LAYOUT,
     LLM_KV_EXPERT_COUNT,
@@ -138,6 +140,8 @@ enum llm_kv {
     LLM_KV_EXPERT_WEIGHTS_SCALE,
     LLM_KV_EXPERT_WEIGHTS_NORM,
     LLM_KV_EXPERT_GATING_FUNC,
+    LLM_KV_EXPERT_GROUP_SCALE,
+    LLM_KV_EXPERTS_PER_GROUP,
     LLM_KV_MOE_EVERY_N_LAYERS,
     LLM_KV_NEXTN_PREDICT_LAYERS,
     LLM_KV_POOLING_TYPE,
@@ -307,6 +311,9 @@ enum llm_tensor {
     LLM_TENSOR_FFN_DOWN_SHEXP,
     LLM_TENSOR_FFN_GATE_SHEXP,
     LLM_TENSOR_FFN_UP_SHEXP,
+    LLM_TENSOR_FFN_DOWN_CHEXPS,
+    LLM_TENSOR_FFN_GATE_CHEXPS,
+    LLM_TENSOR_FFN_UP_CHEXPS,
     LLM_TENSOR_FFN_EXP_PROBS_B,
     LLM_TENSOR_ATTN_Q_NORM,
     LLM_TENSOR_ATTN_K_NORM,

src/llama-graph.cpp

Lines changed: 15 additions & 2 deletions
@@ -923,13 +923,26 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         selection_probs = logits;
     }
 
+    if (arch == LLM_ARCH_GROVEMOE) {
+        selection_probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
+        cb(selection_probs, "ffn_moe_probs_biased", il);
+    }
+
     // select experts
     ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
     cb(selected_experts->src[0], "ffn_moe_argsort", il);
     cb(selected_experts, "ffn_moe_topk", il);
 
-    ggml_tensor * weights = ggml_get_rows(ctx0,
-            ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
+    if (arch == LLM_ARCH_GROVEMOE && n_expert != hparams.n_expert) {
+        // TODO: Use scalar div instead when/if implemented
+        ggml_tensor * f_sel = ggml_cast(ctx0, selected_experts, GGML_TYPE_F32);
+        selected_experts = ggml_cast(ctx0, ggml_scale(ctx0, f_sel, 1.0f / float(hparams.n_group_experts)), GGML_TYPE_I32);
+        probs = ggml_reshape_3d(ctx0, probs, 1, hparams.n_expert, n_tokens);
+    } else {
+        probs = ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens);
+    }
+
+    ggml_tensor * weights = ggml_get_rows(ctx0, probs, selected_experts); // [1, n_expert_used, n_tokens]
     cb(weights, "ffn_moe_weights", il);
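
Two things happen in the GROVEMOE branches above: gating uses a sigmoid over the router logits, and when build_moe_ffn runs over the chunk experts (n_expert != hparams.n_expert) the expert ids selected against the full expert set are mapped to chunk-expert ids by dividing by the group size. Since ggml has no scalar integer division op, the graph casts to F32, scales, and casts back (the "revert crude scalar div implementation, use cast" item in the commit message). A numpy stand-in for the remap (illustration only, not ggml):

import numpy as np

n_group_experts = 2                            # experts per group (2 in GroveMoE-Inst)
selected_experts = np.array([5, 2, 7], dtype=np.int32)  # ids chosen by the router

# cast to f32, scale by 1/n, cast back to i32 (truncating) -- same as //
f_sel = selected_experts.astype(np.float32)
remapped = (f_sel * (1.0 / n_group_experts)).astype(np.int32)

assert (remapped == selected_experts // n_group_experts).all()
print(remapped)  # [2 1 3]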
