Merged
2 changes: 1 addition & 1 deletion common/common.h
@@ -738,7 +738,7 @@ const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
// MoE utils
//

const char * const LLM_FFN_EXPS_REGEX = "\\.ffn_(up|down|gate)_exps";
const char * const LLM_FFN_EXPS_REGEX = "\\.ffn_(up|down|gate)_(ch|)exps";

static std::string llm_ffn_exps_block_regex(int idx) {
    return string_format("blk\\.%d%s", idx, LLM_FFN_EXPS_REGEX);
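As a quick sanity check, the widened pattern now matches both the regular and the chunk ("adjugate") expert tensors. A minimal standalone sketch (Python, hypothetical snippet, not part of the PR):

import re

# Mirrors the updated LLM_FFN_EXPS_REGEX: the "(ch|)" alternation makes the
# "ch" prefix optional, so both expert tensor families match.
FFN_EXPS_REGEX = r"\.ffn_(up|down|gate)_(ch|)exps"

for name in ["blk.0.ffn_up_exps", "blk.0.ffn_gate_chexps", "blk.0.ffn_norm"]:
    print(name, "->", bool(re.search(FFN_EXPS_REGEX, name)))
# blk.0.ffn_up_exps -> True
# blk.0.ffn_gate_chexps -> True
# blk.0.ffn_norm -> False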
115 changes: 115 additions & 0 deletions convert_hf_to_gguf.py
@@ -7930,6 +7930,121 @@ def prepare_tensors(self):
                raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register("GroveMoeForCausalLM", "modeling_grove_moe.GroveMoeForCausalLM")
class GroveMoeModel(TextModel):
model_arch = gguf.MODEL_ARCH.GROVEMOE

def set_gguf_parameters(self):
super().set_gguf_parameters()
if (n_experts := self.hparams.get("num_experts")) is not None:
self.gguf_writer.add_expert_count(n_experts)
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}")
# FIXME?: Hardcoded https://huggingface.co/inclusionAI/GroveMoE-Inst/blob/c4c69e5970d18907b5e6ddccdfd55176fe292df1/modeling_grove_moe.py#L299
self.gguf_writer.add_expert_chunk_feed_forward_length(self.hparams.get("head_dim") or 128)
# FIXME?: Hardcoded https://huggingface.co/inclusionAI/GroveMoE-Inst/blob/c4c69e5970d18907b5e6ddccdfd55176fe292df1/modeling_grove_moe.py#L298
self.gguf_writer.add_experts_per_group(2)
# FIXME?: Hardcoded https://huggingface.co/inclusionAI/GroveMoE-Inst/blob/c4c69e5970d18907b5e6ddccdfd55176fe292df1/modeling_grove_moe.py#L376
self.gguf_writer.add_expert_group_scale(0.05)
# YaRN is not enabled by default
# To enable it, please refer to this guide: https://huggingface.co/Qwen/Qwen3-30B-A3B#processing-long-texts
rope_scaling = self.hparams.get("rope_scaling") or {}
if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])

_experts: list[dict[str, Tensor]] | None = None
_chunk_experts: list[dict[str, Tensor]] | None = None

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name.endswith(".expert_bias"):
# FIXME?: Unused https://huggingface.co/inclusionAI/GroveMoE-Inst/blob/c4c69e5970d18907b5e6ddccdfd55176fe292df1/modeling_grove_moe.py#L303
return []

# process the experts separately
if name.find("chunk_experts") != -1:
n_experts = self.hparams["num_experts"] // 2 # see add_experts_per_group
assert bid is not None

if self._chunk_experts is None:
self._chunk_experts = [{} for _ in range(self.block_count)]

self._chunk_experts[bid][name] = data_torch

if len(self._chunk_experts[bid]) >= n_experts * 3:
tensors: list[tuple[str, Tensor]] = []

# merge the experts into a single 3d tensor
for w_name in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []

for xid in range(n_experts):
ename = f"model.layers.{bid}.mlp.chunk_experts.{xid}.{w_name}.weight"
datas.append(self._chunk_experts[bid][ename])
del self._chunk_experts[bid][ename]

data_torch = torch.stack(datas, dim=0)

merged_name = f"model.layers.{bid}.mlp.chunk_experts.{w_name}.weight"

new_name = self.map_tensor_name(merged_name)

tensors.append((new_name, data_torch))
return tensors
else:
return []
elif name.find("experts") != -1:
n_experts = self.hparams["num_experts"]
assert bid is not None

if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]

self._experts[bid][name] = data_torch

if len(self._experts[bid]) >= n_experts * 3:
tensors: list[tuple[str, Tensor]] = []

# merge the experts into a single 3d tensor
for w_name in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []

for xid in range(n_experts):
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]

data_torch = torch.stack(datas, dim=0)

merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"

new_name = self.map_tensor_name(merged_name)

tensors.append((new_name, data_torch))
return tensors
else:
return []

return [(self.map_tensor_name(name), data_torch)]

def prepare_tensors(self):
super().prepare_tensors()

if self._chunk_experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
chunk_experts = [k for d in self._chunk_experts for k in d.keys()]
if len(chunk_experts) > 0:
raise ValueError(f"Unprocessed adjugate experts: {chunk_experts}")

if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register("ChameleonForConditionalGeneration")
@ModelBase.register("ChameleonForCausalLM") # obsolete
class ChameleonModel(TextModel):
31 changes: 31 additions & 0 deletions gguf-py/gguf/constants.py
@@ -96,6 +96,7 @@ class LLM:
        FEED_FORWARD_LENGTH               = "{arch}.feed_forward_length"
        EXPERT_FEED_FORWARD_LENGTH        = "{arch}.expert_feed_forward_length"
        EXPERT_SHARED_FEED_FORWARD_LENGTH = "{arch}.expert_shared_feed_forward_length"
        EXPERT_CHUNK_FEED_FORWARD_LENGTH  = "{arch}.expert_chunk_feed_forward_length"
        USE_PARALLEL_RESIDUAL             = "{arch}.use_parallel_residual"
        TENSOR_DATA_LAYOUT                = "{arch}.tensor_data_layout"
        EXPERT_COUNT                      = "{arch}.expert_count"
@@ -104,6 +105,8 @@ class LLM:
        EXPERT_WEIGHTS_SCALE              = "{arch}.expert_weights_scale"
        EXPERT_WEIGHTS_NORM               = "{arch}.expert_weights_norm"
        EXPERT_GATING_FUNC                = "{arch}.expert_gating_func"
        EXPERT_GROUP_SCALE                = "{arch}.expert_group_scale"
        EXPERTS_PER_GROUP                 = "{arch}.experts_per_group"
        MOE_EVERY_N_LAYERS                = "{arch}.moe_every_n_layers"
        NEXTN_PREDICT_LAYERS              = "{arch}.nextn_predict_layers"
        POOLING_TYPE                      = "{arch}.pooling_type"
@@ -401,6 +404,7 @@ class MODEL_ARCH(IntEnum):
    LLADA     = auto()
    LLADA_MOE = auto()
    SEED_OSS  = auto()
    GROVEMOE  = auto()


class VISION_PROJECTOR_TYPE(IntEnum):
@@ -450,6 +454,9 @@ class MODEL_TENSOR(IntEnum):
    FFN_GATE_SHEXP  = auto()
    FFN_DOWN_SHEXP  = auto()
    FFN_UP_SHEXP    = auto()
    FFN_GATE_CHEXP  = auto()
    FFN_DOWN_CHEXP  = auto()
    FFN_UP_CHEXP    = auto()
    FFN_EXP_PROBS_B = auto()
    ATTN_Q_NORM     = auto()
    ATTN_K_NORM     = auto()
@@ -738,6 +745,7 @@ class MODEL_TENSOR(IntEnum):
    MODEL_ARCH.LLADA:     "llada",
    MODEL_ARCH.LLADA_MOE: "llada-moe",
    MODEL_ARCH.SEED_OSS:  "seed_oss",
    MODEL_ARCH.GROVEMOE:  "grovemoe",
}

VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -784,6 +792,9 @@ class MODEL_TENSOR(IntEnum):
    MODEL_TENSOR.FFN_GATE_SHEXP: "blk.{bid}.ffn_gate_shexp",
    MODEL_TENSOR.FFN_DOWN_SHEXP: "blk.{bid}.ffn_down_shexp",
    MODEL_TENSOR.FFN_UP_SHEXP:   "blk.{bid}.ffn_up_shexp",
    MODEL_TENSOR.FFN_GATE_CHEXP: "blk.{bid}.ffn_gate_chexps",
    MODEL_TENSOR.FFN_DOWN_CHEXP: "blk.{bid}.ffn_down_chexps",
    MODEL_TENSOR.FFN_UP_CHEXP:   "blk.{bid}.ffn_up_chexps",
    MODEL_TENSOR.FFN_ACT:        "blk.{bid}.ffn",
    MODEL_TENSOR.FFN_NORM_EXP:   "blk.{bid}.ffn_norm_exps",
    MODEL_TENSOR.FFN_GATE_EXP:   "blk.{bid}.ffn_gate_exps",
@@ -2712,6 +2723,26 @@ class MODEL_TENSOR(IntEnum):
        MODEL_TENSOR.FFN_UP_EXP,
        MODEL_TENSOR.FFN_DOWN_EXP,
    ],
    MODEL_ARCH.GROVEMOE: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT_NORM,
        MODEL_TENSOR.OUTPUT,
        MODEL_TENSOR.ATTN_NORM,
        MODEL_TENSOR.ATTN_Q,
        MODEL_TENSOR.ATTN_Q_NORM,
        MODEL_TENSOR.ATTN_K,
        MODEL_TENSOR.ATTN_K_NORM,
        MODEL_TENSOR.ATTN_V,
        MODEL_TENSOR.ATTN_OUT,
        MODEL_TENSOR.FFN_NORM,
        MODEL_TENSOR.FFN_GATE_INP,
        MODEL_TENSOR.FFN_GATE_EXP,
        MODEL_TENSOR.FFN_DOWN_EXP,
        MODEL_TENSOR.FFN_UP_EXP,
        MODEL_TENSOR.FFN_GATE_CHEXP,
        MODEL_TENSOR.FFN_DOWN_CHEXP,
        MODEL_TENSOR.FFN_UP_CHEXP,
    ],
    # TODO
}

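Taken together, the two tables fix the concrete tensor names a GroveMoE GGUF will contain: MODEL_TENSORS lists which tensor kinds the arch uses, and TENSOR_NAMES supplies the blk.{bid} template for each kind. A small sketch of that composition (hypothetical, simplified stand-ins for the real tables):

# simplified stand-ins for gguf.constants.TENSOR_NAMES / MODEL_TENSORS
TENSOR_NAMES = {
    "FFN_GATE_CHEXP": "blk.{bid}.ffn_gate_chexps",
    "FFN_DOWN_CHEXP": "blk.{bid}.ffn_down_chexps",
    "FFN_UP_CHEXP":   "blk.{bid}.ffn_up_chexps",
}
GROVEMOE_TENSORS = ["FFN_GATE_CHEXP", "FFN_DOWN_CHEXP", "FFN_UP_CHEXP"]

for bid in range(2):  # pretend a 2-block model
    for t in GROVEMOE_TENSORS:
        print(TENSOR_NAMES[t].format(bid=bid))
# blk.0.ffn_gate_chexps ... blk.1.ffn_up_chexps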
9 changes: 9 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -670,6 +670,9 @@ def add_expert_feed_forward_length(self, length: int) -> None:
    def add_expert_shared_feed_forward_length(self, length: int) -> None:
        self.add_uint32(Keys.LLM.EXPERT_SHARED_FEED_FORWARD_LENGTH.format(arch=self.arch), length)

    def add_expert_chunk_feed_forward_length(self, length: int) -> None:
        self.add_uint32(Keys.LLM.EXPERT_CHUNK_FEED_FORWARD_LENGTH.format(arch=self.arch), length)

    def add_parallel_residual(self, use: bool) -> None:
        self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)

@@ -757,6 +760,12 @@ def add_expert_weights_norm(self, value: bool) -> None:
    def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None:
        self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value)

    def add_expert_group_scale(self, value: float) -> None:
        self.add_float32(Keys.LLM.EXPERT_GROUP_SCALE.format(arch=self.arch), value)

    def add_experts_per_group(self, count: int) -> None:
        self.add_uint32(Keys.LLM.EXPERTS_PER_GROUP.format(arch=self.arch), count)

    def add_moe_every_n_layers(self, value: int) -> None:
        self.add_uint32(Keys.LLM.MOE_EVERY_N_LAYERS.format(arch=self.arch), value)
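Each new helper emits a single typed KV pair under the architecture's namespace. For arch == "grovemoe", the converter calls above therefore produce the following keys (hypothetical illustration of the key expansion only):

# how the Keys.LLM templates expand for arch == "grovemoe"
templates = [
    ("uint32",  "{arch}.expert_chunk_feed_forward_length"),
    ("float32", "{arch}.expert_group_scale"),
    ("uint32",  "{arch}.experts_per_group"),
]
for gguf_type, key in templates:
    print(gguf_type, key.format(arch="grovemoe"))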
12 changes: 12 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
@@ -427,6 +427,10 @@ class TensorNameMap:
            "model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan
        ),

        MODEL_TENSOR.FFN_UP_CHEXP: (
            "model.layers.{bid}.mlp.chunk_experts.up_proj", # grovemoe
        ),

        # AWQ-activation gate
        MODEL_TENSOR.FFN_ACT: (
            "transformer.blocks.{bid}.ffn.act", # mpt
@@ -468,6 +472,10 @@ class TensorNameMap:
            "model.layers.{bid}.mlp.shared_mlp.gate_proj", # hunyuan
        ),

        MODEL_TENSOR.FFN_GATE_CHEXP: (
            "model.layers.{bid}.mlp.chunk_experts.gate_proj", # grovemoe
        ),

        # Feed-forward down
        MODEL_TENSOR.FFN_DOWN: (
            "gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox
@@ -524,6 +532,10 @@ class TensorNameMap:
            "model.layers.{bid}.mlp.shared_mlp.down_proj", # hunyuan
        ),

        MODEL_TENSOR.FFN_DOWN_CHEXP: (
            "model.layers.{bid}.mlp.chunk_experts.down_proj", # grovemoe
        ),

        MODEL_TENSOR.ATTN_Q_NORM: (
            "language_model.encoder.layers.{bid}.self_attention.q_layernorm",
            "model.layers.{bid}.self_attn.q_layernorm", # persimmon
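These entries let TensorNameMap translate the merged chunk-expert projections (the per-expert index is already gone by the time modify_tensors emits them) into the GGUF names registered in constants.py. Roughly (hypothetical sketch of the lookup):

chexp_map = {
    "model.layers.{bid}.mlp.chunk_experts.up_proj":   "blk.{bid}.ffn_up_chexps",
    "model.layers.{bid}.mlp.chunk_experts.gate_proj": "blk.{bid}.ffn_gate_chexps",
    "model.layers.{bid}.mlp.chunk_experts.down_proj": "blk.{bid}.ffn_down_chexps",
}
for src, dst in chexp_map.items():
    print(src.format(bid=5), "->", dst.format(bid=5))
# e.g. model.layers.5.mlp.chunk_experts.up_proj -> blk.5.ffn_up_chexps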
30 changes: 30 additions & 0 deletions src/llama-arch.cpp
@@ -98,6 +98,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
    { LLM_ARCH_LLADA,     "llada"     },
    { LLM_ARCH_LLADA_MOE, "llada-moe" },
    { LLM_ARCH_SEED_OSS,  "seed_oss"  },
    { LLM_ARCH_GROVEMOE,  "grovemoe"  },
    { LLM_ARCH_UNKNOWN,   "(unknown)" },
};

@@ -125,6 +126,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
    { LLM_KV_FEED_FORWARD_LENGTH,               "%s.feed_forward_length"               },
    { LLM_KV_EXPERT_FEED_FORWARD_LENGTH,        "%s.expert_feed_forward_length"        },
    { LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, "%s.expert_shared_feed_forward_length" },
    { LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH,  "%s.expert_chunk_feed_forward_length"  },
    { LLM_KV_USE_PARALLEL_RESIDUAL,             "%s.use_parallel_residual"             },
    { LLM_KV_TENSOR_DATA_LAYOUT,                "%s.tensor_data_layout"                },
    { LLM_KV_EXPERT_COUNT,                      "%s.expert_count"                      },
@@ -133,6 +135,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
    { LLM_KV_EXPERT_WEIGHTS_SCALE,              "%s.expert_weights_scale"              },
    { LLM_KV_EXPERT_WEIGHTS_NORM,               "%s.expert_weights_norm"               },
    { LLM_KV_EXPERT_GATING_FUNC,                "%s.expert_gating_func"                },
    { LLM_KV_EXPERT_GROUP_SCALE,                "%s.expert_group_scale"                },
    { LLM_KV_EXPERTS_PER_GROUP,                 "%s.experts_per_group"                 },
    { LLM_KV_MOE_EVERY_N_LAYERS,                "%s.moe_every_n_layers"                },
    { LLM_KV_NEXTN_PREDICT_LAYERS,              "%s.nextn_predict_layers"              },
    { LLM_KV_POOLING_TYPE,                      "%s.pooling_type"                      },
@@ -2185,6 +2189,29 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
        },
    },
    {
        LLM_ARCH_GROVEMOE,
        {
            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
            { LLM_TENSOR_OUTPUT,          "output" },
            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
            { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },
            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
            { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
            { LLM_TENSOR_FFN_GATE_CHEXPS, "blk.%d.ffn_gate_chexps" },
            { LLM_TENSOR_FFN_DOWN_CHEXPS, "blk.%d.ffn_down_chexps" },
            { LLM_TENSOR_FFN_UP_CHEXPS,   "blk.%d.ffn_up_chexps" },
        },
    },
    {
        LLM_ARCH_UNKNOWN,
        {
@@ -2317,6 +2344,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
    {LLM_TENSOR_FFN_DOWN_EXPS,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
    {LLM_TENSOR_FFN_GATE_EXPS,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
    {LLM_TENSOR_FFN_UP_EXPS,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
    {LLM_TENSOR_FFN_DOWN_CHEXPS,      {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
    {LLM_TENSOR_FFN_GATE_CHEXPS,      {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
    {LLM_TENSOR_FFN_UP_CHEXPS,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
    {LLM_TENSOR_FFN_EXP_PROBS_B,      {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
    // altup / laurel (gemma 3n)
    {LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
7 changes: 7 additions & 0 deletions src/llama-arch.h
@@ -102,6 +102,7 @@ enum llm_arch {
    LLM_ARCH_LLADA,
    LLM_ARCH_LLADA_MOE,
    LLM_ARCH_SEED_OSS,
    LLM_ARCH_GROVEMOE,
    LLM_ARCH_UNKNOWN,
};

@@ -129,6 +130,7 @@ enum llm_kv {
    LLM_KV_FEED_FORWARD_LENGTH,
    LLM_KV_EXPERT_FEED_FORWARD_LENGTH,
    LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH,
    LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH,
    LLM_KV_USE_PARALLEL_RESIDUAL,
    LLM_KV_TENSOR_DATA_LAYOUT,
    LLM_KV_EXPERT_COUNT,
@@ -137,6 +139,8 @@ enum llm_kv {
    LLM_KV_EXPERT_WEIGHTS_SCALE,
    LLM_KV_EXPERT_WEIGHTS_NORM,
    LLM_KV_EXPERT_GATING_FUNC,
    LLM_KV_EXPERT_GROUP_SCALE,
    LLM_KV_EXPERTS_PER_GROUP,
    LLM_KV_MOE_EVERY_N_LAYERS,
    LLM_KV_NEXTN_PREDICT_LAYERS,
    LLM_KV_POOLING_TYPE,
@@ -301,6 +305,9 @@ enum llm_tensor {
    LLM_TENSOR_FFN_DOWN_SHEXP,
    LLM_TENSOR_FFN_GATE_SHEXP,
    LLM_TENSOR_FFN_UP_SHEXP,
    LLM_TENSOR_FFN_DOWN_CHEXPS,
    LLM_TENSOR_FFN_GATE_CHEXPS,
    LLM_TENSOR_FFN_UP_CHEXPS,
    LLM_TENSOR_FFN_EXP_PROBS_B,
    LLM_TENSOR_ATTN_Q_NORM,
    LLM_TENSOR_ATTN_K_NORM,
17 changes: 15 additions & 2 deletions src/llama-graph.cpp
@@ -920,13 +920,26 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
        selection_probs = logits;
    }

    if (arch == LLM_ARCH_GROVEMOE) {
        selection_probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
        cb(selection_probs, "ffn_moe_probs_biased", il);
    }

    // select experts
    ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
    cb(selected_experts->src[0], "ffn_moe_argsort", il);
    cb(selected_experts, "ffn_moe_topk", il);

    ggml_tensor * weights = ggml_get_rows(ctx0,
            ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
    if (arch == LLM_ARCH_GROVEMOE && n_expert != hparams.n_expert) {
Review comment (Member): When is n_expert != hparams.n_expert?

Reply (Collaborator, author): When doing the adjugate experts pass; see src/llama-model.cpp, lines 19025 to 19038 at commit ee51669:
    // TODO: Only do the expert selection and weights once
    moe_out =
            build_moe_ffn(cur,
                nullptr,
                model.layers[il].ffn_up_chexps,
                model.layers[il].ffn_gate_chexps,
                model.layers[il].ffn_down_chexps,
                nullptr,
                n_chunk_expert, n_expert_used > n_chunk_expert ? n_chunk_expert : n_expert_used,
                LLM_FFN_SILU, true,
                false, 0.0,
                LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                il, probs);
    cb(moe_out, "ffn_adj_moe_out", il);

        // TODO: Use scalar div instead when/if implemented
        ggml_tensor * f_sel = ggml_cast(ctx0, selected_experts, GGML_TYPE_F32);
        selected_experts = ggml_cast(ctx0, ggml_scale(ctx0, f_sel, 1.0f / float(hparams.n_group_experts)), GGML_TYPE_I32);
        probs = ggml_reshape_3d(ctx0, probs, 1, hparams.n_expert, n_tokens);
    } else {
        probs = ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens);
    }

    ggml_tensor * weights = ggml_get_rows(ctx0, probs, selected_experts); // [1, n_expert_used, n_tokens]
    cb(weights, "ffn_moe_weights", il);

    if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
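The cast/scale/cast sequence in the GroveMoE branch is an integer division in disguise: each selected regular-expert index is divided by n_group_experts (2 for GroveMoE), so experts 2k and 2k+1 both map onto chunk expert k and the adjugate pass can reuse the router's selection. A toy sketch of the remapping (Python, assumed semantics):

n_group_experts = 2                 # experts_per_group from the GGUF metadata
selected_experts = [5, 2, 7, 0]     # router picks among the regular experts

# cast -> scale by 1/n -> cast back truncates, i.e. floor division for non-negative indices
chunk_ids = [int(e * (1.0 / n_group_experts)) for e in selected_experts]
print(chunk_ids)  # [2, 1, 3, 0] -> indices into the n_expert/2 chunk experts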