
Conversation

@CISC (Collaborator) commented Aug 22, 2025

Adds support for inclusionAI/GroveMoE, a novel architecture that groups adjugate experts with ordinary experts (paper).

The PR is in a fully working state, but I'm submitting it as a draft because it requires a scalar div implementation that was quickly hacked together just to get the model running. Only div is (very crudely) implemented, and only for CPU (which doesn't matter much, since little computation is spent here), and I'm not satisfied that the API makes sense. In short, this requires more thought!
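As a rough aside, for float tensors the division by a constant could also be expressed with the existing ggml_scale op instead of a dedicated scalar-div op (a hedged sketch with hypothetical names, not necessarily what this PR will settle on; a later commit note mentions reverting the crude scalar div and using a cast):

// hypothetical sketch: divide every element of a float tensor t by a constant divisor
// by reusing ggml_scale with the reciprocal, avoiding a new scalar-div op
struct ggml_tensor * t_div = ggml_scale(ctx, t, 1.0f / divisor);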

@github-actions github-actions bot added the Nvidia GPU, Vulkan, python, ggml, SYCL, Apple Metal, Ascend NPU and OpenCL labels Aug 22, 2025
@CISC CISC removed the Nvidia GPU, Vulkan, SYCL, Apple Metal, Ascend NPU and OpenCL labels Aug 22, 2025
@github-actions github-actions bot added the Nvidia GPU, Vulkan, SYCL, Apple Metal, Ascend NPU and OpenCL labels Aug 22, 2025
@CISC CISC removed the Nvidia GPU, Vulkan, SYCL, Apple Metal, Ascend NPU and OpenCL labels Aug 22, 2025
@CISC (Collaborator, Author) commented Aug 22, 2025

Looks like ccache breaks the build (using cached files newer than this branch), not important right now though...

@CISC CISC marked this pull request as ready for review September 8, 2025 21:01
@CISC CISC requested a review from slaren September 8, 2025 21:03
@CISC (Collaborator, Author) commented Sep 12, 2025

@slaren gentle ping

common/arg.cpp Outdated
"keep all Mixture of Experts (MoE) weights in the CPU",
[](common_params & params) {
params.tensor_buft_overrides.push_back({"\\.ffn_(up|down|gate)_exps", ggml_backend_cpu_buffer_type()});
params.tensor_buft_overrides.push_back({"\\.ffn_(up|down|gate)_(ch|)exps", ggml_backend_cpu_buffer_type()});
Member

The regexes are now in common.h
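For context, a hedged sketch of the equivalent manual override (assuming the usual --override-tensor regex=buffer-type syntax and a hypothetical model path); the added (ch|) alternation makes the pattern also match GroveMoE's adjugate "chunk" expert tensors:

# keep both ordinary and adjugate expert tensors on the CPU buffer type
llama-cli -m grovemoe.gguf -ngl 99 --override-tensor "\.ffn_(up|down|gate)_(ch|)exps=CPU"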

@CISC CISC requested a review from ggerganov as a code owner September 24, 2025 06:55

ggml_tensor * weights = ggml_get_rows(ctx0,
ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
if (arch == LLM_ARCH_GROVEMOE && n_expert != hparams.n_expert) {
Member

When is n_expert != hparams.n_expert?

Collaborator Author

When doing the adjugate experts pass:

llama.cpp/src/llama-model.cpp

Lines 19025 to 19038 in ee51669

// TODO: Only do the expert selection and weights once
moe_out =
build_moe_ffn(cur,
nullptr,
model.layers[il].ffn_up_chexps,
model.layers[il].ffn_gate_chexps,
model.layers[il].ffn_down_chexps,
nullptr,
n_chunk_expert, n_expert_used > n_chunk_expert ? n_chunk_expert : n_expert_used,
LLM_FFN_SILU, true,
false, 0.0,
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
il, probs);
cb(moe_out, "ffn_adj_moe_out", il);
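Note that build_moe_ffn is given n_chunk_expert as its expert count there, so inside it n_expert no longer matches hparams.n_expert. Conceptually (a rough illustrative sketch of the grouping, not the actual graph code in this PR), each fixed-size group of ordinary experts shares one adjugate ("chunk") expert, so an ordinary expert index maps onto the smaller adjugate set by integer division:

// hypothetical illustration of the grouping, not the code in this PR
const int group_size      = n_expert / n_chunk_expert; // ordinary experts per adjugate ("chunk") expert
const int chunk_expert_id = expert_id / group_size;    // ordinary expert index -> adjugate expert index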

@CISC (Collaborator, Author) commented Sep 24, 2025

Oh, I just noticed it breaks (it just outputs an endless stream of ?) if you're unlucky with the distribution of chexps/exps tensors with -ncmoe (and not always on the first pass); I'm guessing it's because some buffer is not on the same backend? It works fine with -cmoe...
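(For reference, a minimal sketch of the two configurations being compared, with a hypothetical model path, assuming --cpu-moe/-cmoe keeps all expert tensors on the CPU and --n-cpu-moe/-ncmoe only those of the first N layers:)

# -cmoe: all MoE expert tensors on the CPU - works fine
llama-cli -m grovemoe-q8_0.gguf -ngl 99 --cpu-moe

# -ncmoe N: only the first N layers' expert tensors on the CPU - can trigger the endless '?' output
llama-cli -m grovemoe-q8_0.gguf -ngl 99 --n-cpu-moe 20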

@ggerganov (Member) commented Sep 24, 2025

Are you running F16 weights? If yes, there is a chance you are hitting this assert:

// if you hit this, you are likely running outside the FP range
assert(!isnan(sumf) && !isinf(sumf));

Build in Debug to confirm that.
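(Something along these lines, assuming a plain CMake build:)

cmake -B build -DCMAKE_BUILD_TYPE=Debug
cmake --build build --config Debug -j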

@CISC (Collaborator, Author) commented Sep 24, 2025

Are you running F16 weights? If yes, there is a chance you are hitting this assert:

Nope, Q8_0.

Build in Debug to confirm that.

I will do that and try to figure out the issue later.

@CISC (Collaborator, Author) commented Sep 25, 2025

Build in Debug to confirm that.

I will do that and try to figure out the issue later.

Didn't catch anything; however, when I run it through llama-eval-callback I get "encountered NaN - aborting" whenever I'm using -ncmoe, but -cmoe works fine, so something odd is definitely going on...
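(Roughly the kind of run used; a hedged sketch with a hypothetical model path. llama-eval-callback dumps every tensor as the graph is evaluated:)

# partial expert offload reproduces the NaN; --cpu-moe does not
llama-eval-callback -m grovemoe-q8_0.gguf -ngl 99 --n-cpu-moe 20 -p "Hello"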

@CISC (Collaborator, Author) commented Sep 25, 2025

OK, full offloading works fine too, so this is unlikely to be a model issue; it just seems to trigger some problem with partial offloading of experts.

Merging.

@CISC CISC merged commit 835b2b9 into master Sep 25, 2025
67 of 69 checks passed
@CISC CISC deleted the cisc/grovemoe branch September 25, 2025 17:50
@ggerganov (Member) commented:

Btw, which was the first op that produced the NaN when you ran the eval-callback?

@CISC (Collaborator, Author) commented Sep 25, 2025

Btw, which was the first op that produced the NaN when you ran the eval-callback?

ffn_moe_weighted-35 = (f32) SWIGLU(ffn_moe_gate-35, contains -inf, but the two previous ops have very suspicious values:

ggml_debug:          ffn_moe_gate-35 = (f32) MUL_MAT_ID(blk.35.ffn_gate_chexps.weight{2048, 128, 64, 1}, ffn_moe_out-35 (reshaped){2048, 1, 11, 1}}) = {128, 8, 11, 1}
                                     [
                                      [
                                       ..., 
                                      ],
                                      [
                                       ..., 
                                      ],
                                      [
                                       ..., 
                                       [692609189888710725876244691447447552.0000, 249074347992884282519164127694290944.0000, 1018055346873854459523945924594237440.0000, ..., 51586575710833024190924974872068096.0000, -2067560154401421119326973933186449408.0000, -1679141552374039747815821723257798656.0000],
                                       [      0.0699,       0.0112,      -0.0616, ...,       0.0889,       0.1895,      -0.0447],
                                       [      0.0000,       0.0000,       0.0000, ...,       0.0000,       0.0000,       0.0000],
                                      ],
                                      ..., 
                                      [
                                       ..., 
                                      ],
                                      [
                                       ..., 
                                      ],
                                      [
                                       ..., 
                                      ],
                                     ]
                                     sum = -34555508860415417087805366059915018240.000000
ggml_debug:            ffn_moe_up-35 = (f32) MUL_MAT_ID(blk.35.ffn_up_chexps.weight{2048, 128, 64, 1}, ffn_moe_out-35 (reshaped){2048, 1, 11, 1}}) = {128, 8, 11, 1}
                                     [
                                      [
                                       ..., 
                                      ],
                                      [
                                       ..., 
                                      ],
                                      [
                                       ..., 
                                       [946507641661885421081057718347759616.0000, -2441153741115458376516297386216128512.0000, -407195320016530705331302955210506240.0000, ..., 764541706286011546206824740660707328.0000, -81787806567653304114262973125492736.0000, -1564830842585809104478506137537216512.0000],
                                       [     -0.0081,       0.0035,       0.0455, ...,       0.0669,       0.1087,      -0.0017],
                                       [     -0.0052,      -0.0083,       0.0079, ...,       0.0068,       0.0106,      -0.0006],
                                      ],
                                      ..., 
                                      [
                                       ..., 
                                      ],
                                      [
                                       ..., 
                                      ],
                                      [
                                       ..., 
                                      ],
                                     ]
                                     sum = 16824229610265255367388010517576548352.000000

Edit: probs and topk look normal.

@ggerganov (Member) commented:

Ok, I'll likely take a look when the GGUFs appear and if I don't forget.

pwilkin pushed a commit to pwilkin/llama.cpp that referenced this pull request Sep 25, 2025
* add GroveMoE support

* remove constexpr that fails on certain compilers

* revert crude scalar div implementation, use cast

* build_attn_inp_kv_unified -> build_attn_inp_kv

* fix build_attn

* re-apply ffn_exps regex changes
@gabriellarson (Contributor) commented:

I'm attempting to make an imatrix and I'm getting this error:

/workspace/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu:2163: GGML_ASSERT(ids_to_sorted_host.size() == size_t(ne_get_rows)) failed
[New LWP 12890]
[New LWP 12892]
[New LWP 12893]
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/usr/lib/x86_64-linux-gnu/libthread_db.so.1".
0x00007fe90a96f42f in wait4 () from /usr/lib/x86_64-linux-gnu/libc.so.6
#0  0x00007fe90a96f42f in wait4 () from /usr/lib/x86_64-linux-gnu/libc.so.6
#1  0x00007fe90ae0784b in ggml_print_backtrace () from /workspace/llama.cpp/build/bin/libggml-base.so
#2  0x00007fe90ae079e2 in ggml_abort () from /workspace/llama.cpp/build/bin/libggml-base.so
#3  0x00007fe8f6a5a1d1 in ggml_cuda_mul_mat_id(ggml_backend_cuda_context&, ggml_tensor*) () from /workspace/llama.cpp/build/bin/libggml-cuda.so
#4  0x00007fe8f6a5bc6b in ggml_backend_cuda_graph_compute(ggml_backend*, ggml_cgraph*) () from /workspace/llama.cpp/build/bin/libggml-cuda.so
#5  0x00007fe90ae22f9a in ggml_backend_sched_graph_compute_async () from /workspace/llama.cpp/build/bin/libggml-base.so
#6  0x00007fe90af3fdd1 in llama_context::graph_compute(ggml_cgraph*, bool) () from /workspace/llama.cpp/build/bin/libllama.so
#7  0x00007fe90af40165 in llama_context::process_ubatch(llama_ubatch const&, llm_graph_type, llama_memory_context_i*, ggml_status&) () from /workspace/llama.cpp/build/bin/libllama.so
#8  0x00007fe90af469a7 in llama_context::decode(llama_batch const&) () from /workspace/llama.cpp/build/bin/libllama.so
#9  0x00007fe90af47840 in llama_decode () from /workspace/llama.cpp/build/bin/libllama.so
#10 0x0000555ac1cedbc1 in main ()
[Inferior 1 (process 12889) detached]
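(For reference, the kind of invocation involved; a hedged sketch with hypothetical file names:)

# compute an importance matrix over a calibration text, fully offloaded
llama-imatrix -m GroveMoE-Inst-F16.gguf -f calibration.txt -ngl 999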

@CISC (Collaborator, Author) commented Sep 26, 2025

I'm attempting to make an imatrix and I'm getting this error:

Interesting, is it fully offloaded?

@k3d3 commented Sep 26, 2025

I'm also receiving a similar error:

/opt/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu:2163: GGML_ASSERT(ids_to_sorted_host.size() == size_t(ne_get_rows)) failed
/usr/local/lib64/libggml-base.so(+0x3525) [0x7f985d6b4525]
/usr/local/lib64/libggml-base.so(ggml_print_backtrace+0x1eb) [0x7f985d6b48eb]
/usr/local/lib64/libggml-base.so(ggml_abort+0x11f) [0x7f985d6b4a6f]
/usr/local/lib64/libggml-hip.so(+0x25207d9) [0x7f985fc747d9]
/usr/local/lib64/libggml-base.so(ggml_backend_sched_graph_compute_async+0x7f3) [0x7f985d6ce203]
/usr/local/lib64/libllama.so(_ZN13llama_context13graph_computeEP11ggml_cgraphb+0xa0) [0x7f986019ff10]
/usr/local/lib64/libllama.so(_ZN13llama_context14process_ubatchERK12llama_ubatch14llm_graph_typeP22llama_memory_context_iR11ggml_status+0xe2) [0x7f98601a1b72]
/usr/local/lib64/libllama.so(_ZN13llama_context6decodeERK11llama_batch+0x3af) [0x7f98601a688f]
/usr/local/lib64/libllama.so(llama_decode+0xe) [0x7f98601a76ce]
llama-server() [0x5c1058]
llama-server() [0x496bfe]
llama-server() [0x43731a]
/lib64/libc.so.6(+0x35b5) [0x7f985d1295b5]
/lib64/libc.so.6(__libc_start_main+0x88) [0x7f985d129668]
llama-server() [0x439525]
Aborted                    (core dumped) llama-server --no-mmap -ngl 999 -fa on -c 128000 --jinja -m models/grovemoe-inst/GroveMoE-Inst-128x4.2B-F16.gguf

This is running Gabriel Larson's F16 GGUF, fully offloaded and using a rocm rc7-rocwmma docker/toolbox (I'm using a Strix Halo APU), running llama.cpp version: 6588 (a86a580a).

Unfortunately it appears the core dump is getting snatched up by Fedora (the joys of Bazzite), but if it's useful, I could try to finagle my settings to get the file.

If I use this Q8 GGUF instead, then it loads and runs without issue.

@CISC (Collaborator, Author) commented Sep 26, 2025

Even more interesting, so it seems to be an issue with that GGUF.

@gabriellarson Can you try creating a BF16 GGUF instead?
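(Assuming the stock conversion script, something along these lines with hypothetical paths:)

# convert the HF checkpoint straight to BF16 instead of F16
python convert_hf_to_gguf.py /path/to/GroveMoE-Inst --outtype bf16 --outfile GroveMoE-Inst-BF16.gguf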

struct pushed a commit to struct/llama.cpp that referenced this pull request Sep 26, 2025
* add GroveMoE support

* remove constexpr that fails on certain compilers

* revert crude scalar div implementation, use cast

* build_attn_inp_kv_unified -> build_attn_inp_kv

* fix build_attn

* re-apply ffn_exps regex changes
Labels: ggml (changes relating to the ggml tensor library for machine learning), python (python script changes)

6 participants