Skip to content

Conversation

pwilkin
Copy link
Collaborator

@pwilkin pwilkin commented Sep 7, 2025

Help!

This seems like an easy enough model. I've done the XIELU implementation, but for some reason, the tensor outputs diverge in the QKV calculations (it's either the RoPE or the norm). I've tried everything, spent a whole day debugging it, maybe I'm missing something really obvious?

@foldl you did the implementation in chatllm.cpp and it works fine, maybe you can see something obvious here?

@pwilkin pwilkin marked this pull request as draft September 7, 2025 08:14
@github-actions github-actions bot added python python script changes ggml changes relating to the ggml tensor library for machine learning labels Sep 7, 2025
@pwilkin
Copy link
Collaborator Author

pwilkin commented Sep 7, 2025

Would solve #15748

@pwilkin
Copy link
Collaborator Author

pwilkin commented Sep 7, 2025

@CISC would be grateful for some hints if you see something obvious that could be wrong with the QK norm / RoPE.

@foldl
Copy link
Contributor

foldl commented Sep 7, 2025

@pwilkin you can check rope-type and the soft-plus thing.

@pwilkin
Copy link
Collaborator Author

pwilkin commented Sep 7, 2025

Rope type is fine I thing (normal, not NeoX), softplus doesn't matter as the problem is before the xIELU activation (first tensor matches reference through all 30 layers, next ones diverge starting from first kqv-out).

@pwilkin
Copy link
Collaborator Author

pwilkin commented Sep 7, 2025

Norm application is wrong, but I totally can't figure out how exactly :/

@ggerganov
Copy link
Member

Which norm?

@pwilkin
Copy link
Collaborator Author

pwilkin commented Sep 7, 2025

@ggerganov KQ norm.

@ggerganov
Copy link
Member

I think forgot to apply the norm bias?

@pwilkin
Copy link
Collaborator Author

pwilkin commented Sep 7, 2025

@ggerganov That's the thing, it doesn't have a norm bias:

  "weight_map": {
    "lm_head.weight": "model-00004-of-00004.safetensors",
    "model.embed_tokens.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.attention_layernorm.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.feedforward_layernorm.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.mlp.act_fn.alpha_n": "model-00001-of-00004.safetensors",
    "model.layers.0.mlp.act_fn.alpha_p": "model-00001-of-00004.safetensors",
    "model.layers.0.mlp.act_fn.beta": "model-00001-of-00004.safetensors",
    "model.layers.0.mlp.act_fn.eps": "model-00001-of-00004.safetensors",
    "model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",

Looks exactly like the QK norm in Qwen3, a 128 flat tensor for both. But it somehow diverges from the Transformers implementation. Here's a dump from Transformers:

ggml_debug: model.layers.0.self_attn.k_norm_in = (f32)  ... = {torch.Size([1, 8, 6, 128])}
                                     [
                                      [
                                       [      0.0260,       0.0009,      -0.0400, ...,       0.0216,       0.0396,      -0.0355]
                                       [     -0.0006,       0.0012,      -0.0057, ...,      -0.0108,      -0.0005,      -0.0019]
                                       [      0.0505,      -0.0432,       0.0101, ...,       0.0056,      -0.0203,       0.0132]
                                       [     -0.0055,      -0.0051,      -0.0081, ...,      -0.3099,      -0.0374,      -0.0569]
                                       [     -0.0056,      -0.0075,      -0.0123, ...,       0.1299,      -0.1645,      -0.0148]
                                       [      0.0052,      -0.0125,       0.0299, ...,       0.0775,      -0.1068,      -0.0833]
                                      ],
                                     ]
                                     sum = -4.458256

ggml_debug: model.layers.0.self_attn.k_norm_out = (f32)  ... = {torch.Size([1, 8, 6, 128])}
                                     [
                                      [
                                       [      1.6487,       0.0560,      -1.1489, ...,       0.8809,       1.1181,      -1.0577]
                                       [     -0.0343,       0.0733,      -0.1573, ...,      -0.5669,      -0.0163,      -0.0744]
                                       [      3.6375,      -3.0737,       0.3281, ...,       0.3351,      -0.8343,       0.5715]
                                       [     -0.2226,      -0.2069,      -0.1493, ...,     -10.3275,      -0.8633,      -1.3851]
                                       [     -0.1887,      -0.2521,      -0.1895, ...,       3.9971,      -3.5061,      -0.3333]
                                       [      0.3068,      -0.7309,       0.7994, ...,       3.6865,      -3.5230,      -2.8940]
                                      ],
                                     ]
                                     sum = 518.159729

and from llama-eval-callback:

ggml_debug:                   Kcur-0 = (f32)    MUL_MAT(blk.0.attn_k.weight{4096, 1024, 1, 1}, attn_norm-0{4096, 6, 1, 1}}) = {1024, 6, 1, 1}
                                     [
                                      [
                                       [      0.0260,       0.0299,       0.0009, ...,       0.0173,       0.0058,      -0.0157],
                                       [      0.0602,       0.0272,      -0.0239, ...,       0.0578,       0.0108,      -0.0360],
                                       [      0.1024,       0.0396,       0.0892, ...,       0.0203,       0.0063,      -0.0092],
                                       [      0.0371,       0.0511,      -0.0486, ...,       0.0725,       0.0086,      -0.0277],
                                       [      0.0441,       0.0391,      -0.0433, ...,       0.0447,      -0.0085,      -0.0458],
                                       [      0.0566,       0.0338,      -0.0308, ...,       0.0555,       0.0202,      -0.0168],
                                      ],
                                     ]
                                     sum = -4.452065
ggml_debug:            Kcur_normed-0 = (f32)        MUL(CUDA0#norm-0#0{128, 8, 6, 1}, CUDA0#blk.0.attn_k_norm.weight#0{128, 1, 1, 1}}) = {128, 8, 6, 1}
                                     [
                                      [
                                       [      1.6491,       1.8751,       0.0248, ...,       3.8993,       0.2504,      -0.7591],
                                       [     -0.0332,      -0.1835,       0.0339, ...,       0.7072,       0.2774,       0.4394],
                                       [      3.6366,       2.1278,      -1.4065, ...,       0.4759,      -0.0021,      -0.0724],
                                       ..., 
                                       [      0.3081,      -0.1988,      -0.3355, ...,      -1.0101,      -0.1090,      -2.8183],
                                       [      2.0318,       0.6510,      -0.9407, ...,      -3.0910,      -0.1242,       1.1163],
                                       [     -1.1660,      -1.4152,       0.5272, ...,       1.8761,       0.4326,      -1.2444],
                                      ],
                                      [
                                       [      1.5329,       0.6841,      -0.2749, ...,       2.7403,      -0.5124,      -1.3517],
                                       [     -2.0542,      -1.4692,      -0.0407, ...,       1.1136,       0.5337,       0.1236],
                                       [      3.9727,       6.0690,      -2.1473, ...,      -2.2552,      -1.5008,      -0.9382],
                                       ..., 
                                       [     -0.2568,      -0.4579,      -0.3575, ...,       1.8921,       1.0777,      -3.7264],
                                       [      0.5295,       0.1762,      -0.1126, ...,      -3.3283,       0.6000,       1.7422],
                                       [     -5.0935,      -4.9426,       1.0877, ...,       2.7902,       0.3603,      -1.2684],
                                      ],
                                      [
                                       [      4.1699,       1.5921,       1.6440, ...,       1.3991,      -0.0073,      -1.1039],
                                       [     -0.0246,      -0.0408,      -0.0693, ...,      -0.2578,      -0.3472,       0.1711],
                                       [      0.2052,       0.0656,      -0.4119, ...,       0.1593,       0.0395,      -0.5277],
                                       ..., 
                                       [      1.1948,       0.1338,      -0.5707, ...,       5.8187,       0.8847,      -2.7834],
                                       [      0.4890,      -0.2239,      -0.0053, ...,      -2.8752,      -0.2403,       1.6677],
                                       [     -0.2667,      -0.4034,       0.6344, ...,       2.1922,       0.4676,      -0.7278],
                                      ],
                                      [
                                       [      0.9241,       1.2575,      -0.5471, ...,       3.1444,       0.3413,      -0.1515],
                                       [     -2.0858,      -0.2047,      -0.6143, ...,       0.7987,       1.1992,       0.1377],
                                       [      4.4787,       6.5776,      -1.9080, ...,      -1.2320,      -0.4378,       0.1753],
                                       ..., 
                                       [     -0.6235,      -0.1112,       0.2428, ...,       0.5462,       0.9830,      -2.6701],
                                       [      0.4527,       0.4643,      -0.3325, ...,      -2.3723,       0.1144,       1.5826],
                                       [     -5.4386,      -4.6627,       1.1660, ...,       3.1313,       0.2574,      -0.8740],
                                      ],
                                      [
                                       [      1.0395,       0.9112,      -0.4622, ...,       1.5698,      -0.1920,      -2.1213],
                                       [     -4.7884,      -2.8707,       0.0744, ...,       1.5661,       0.1393,      -0.9221],
                                       [      3.9333,       5.8780,      -1.7432, ...,       1.1649,       0.4673,      -0.8790],
                                       ..., 
                                       [     -0.8365,      -0.5343,       0.1528, ...,       1.2459,       0.5092,      -2.3725],
                                       [      0.1358,       0.4386,      -0.0570, ...,      -3.0326,       0.8687,       1.6394],
                                       [     -4.0204,      -3.5265,       0.9346, ...,       2.2506,      -0.2951,      -1.6851],
                                      ],
                                      [
                                       [      1.3533,       0.7988,      -0.3331, ...,       1.6144,       0.0773,      -1.0566],
                                       [      0.0273,      -0.4740,       0.0589, ...,      -0.0224,       0.2123,      -0.0747],
                                       [      4.2151,       5.8705,      -1.8385, ...,      -1.2067,       0.8555,       0.5714],
                                       ..., 
                                       [     -0.4960,       0.1881,       0.2486, ...,      -5.0892,      -0.0901,      -2.8932],
                                       [     -0.2326,       0.3233,       0.0688, ...,      -2.0252,      -0.3004,       1.8363],
                                       [     -5.1900,      -3.7914,       1.1631, ...,       2.5679,       0.6468,      -0.5669],
                                      ],
                                     ]
                                     sum = -38.386154

@ggerganov
Copy link
Member

Hm, not sure. Only guess is probably the Q and K tensors are permuted and need to be unpermuted during conversion?

@pwilkin
Copy link
Collaborator Author

pwilkin commented Sep 7, 2025

@ggerganov yeah, looks like it, now I'm trying to figure how exactly.

@pwilkin
Copy link
Collaborator Author

pwilkin commented Sep 7, 2025

Looks like some interleave is going on, because the first three elements of the first vector in GGML are: 0.0260, 0.0299, 0.0009, which look like the first two elements of the first vector in Transformers, but with a value inserted between them. Not exactly sure what the operation is or why a totally standard op is suddenly misbehaving.

@ggerganov
Copy link
Member

Some models have permuted tensors which we reverse during the conversion. Take a look at this code for example:

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
head_count = self.hparams["num_attention_heads"]
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
tensors: list[tuple[str, Tensor]] = []
if bid is not None and name == f"model.layers.{bid}.self_attn.W_pack.weight":
logger.info(f"Unpacking and permuting layer {bid}")
tensors = [
(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid),
self._reverse_hf_permute_part(data_torch, 0, head_count, head_count)),
(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid),
self._reverse_hf_permute_part(data_torch, 1, head_count, head_count_kv)),
(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid),
self._reverse_hf_part(data_torch, 2)),
]
else:
tensors = [(self.map_tensor_name(name), data_torch)]
return tensors
def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
if n_kv_head is not None and n_head != n_kv_head:
n_head //= n_kv_head
return (
weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
.swapaxes(1, 2)
.reshape(weights.shape)
)
def _reverse_hf_permute_part(
self, weights: Tensor, n_part: int, n_head: int, n_head_kv: int | None = None,
) -> Tensor:
r = weights.shape[0] // 3
return self._reverse_hf_permute(weights[r * n_part:r * n_part + r, ...], n_head, n_head_kv)
def _reverse_hf_part(self, weights: Tensor, n_part: int) -> Tensor:
r = weights.shape[0] // 3
return weights[r * n_part:r * n_part + r, ...]

Maybe something similar has to be applied here too.

@pwilkin
Copy link
Collaborator Author

pwilkin commented Sep 7, 2025

My only worry is that there's nothing either in the Transformers implementation or the ChatLLM implementation (based on GGML after all) which suggests that there's any special treatment to be done. It seems like a normal extension of Llama3 with RMS norm. Even the tensor sizes look pretty standard. The conversion code is basically "extend Llama and add the extra tensors for the xIELU activation".

Comment on lines 18889 to 18897
ggml_tensor * alpha_n = model.layers[il].ffn_act_alpha_n;
ggml_tensor * alpha_p = model.layers[il].ffn_act_alpha_p;
ggml_tensor * beta = model.layers[il].ffn_act_beta;
ggml_tensor * eps = model.layers[il].ffn_act_eps;

float alpha_n_val = get_scalar_f32_val(alpha_n);
float alpha_p_val = get_scalar_f32_val(alpha_p);
float beta_val = get_scalar_f32_val(beta);
float eps_val = get_scalar_f32_val(eps);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will end up being a very big performance penalty as there will be 4 extra round-trip CPU <--> backend each time a cgraph is constructed

A much simpler way is to read this as scalar value and write it as GGUF metadata upon conversion. This should also allow implementing xielu with existing GGML ops (no need to create new op)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, was going to optimize the xIELU act as soon as I got a correct implementation working. You mean store them as vectors and load them as something like n_head_arr?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, exactly. Also on the weight I noticed some tensors are empty on all layers, so probably you can just hard-code it to 0.0 on the first prototype.

A small tips when working with these kind of models is to start with a tiny random weight. For example, this is the code to generate tiny random llama 4

@mukel
Copy link

mukel commented Sep 7, 2025

I also ported the model, note that Apertus does a transpose before the QK norm, while Qwen3 does it after; seems like a copy-paste blunder on Apertus.

https://github.com/huggingface/transformers/blob/bb45d3631ec7026db04a77d33a52b31766372160/src/transformers/models/apertus/modular_apertus.py#L254

https://github.com/huggingface/transformers/blob/bb45d3631ec7026db04a77d33a52b31766372160/src/transformers/models/apertus/modular_apertus.py#L254

@pwilkin
Copy link
Collaborator Author

pwilkin commented Sep 7, 2025

Actually, the Q projection already diverges and I have no idea why... the attention norms are the same, but then I get this weird interleaving:

[ **0.0287**, 0.0177, **-0.0021**, -0.0019, **-0.0123**, -0.0065, 0.0100... for first tensor of Q (attn_q * attn_norm) in llama.cpp

[ 0.0287, -0.0021, -0.0123, ... for first tensor of Q in Transformers

What could cause this sort of interleaving?

@ngxson
Copy link
Collaborator

ngxson commented Sep 8, 2025

You're extending the conversion class from LlamaModel, which undo a permutation for llama models. It's best to have a look at the base class

@theo77186
Copy link

I also ported the model, note that Apertus does a transpose before the QK norm, while Qwen3 does it after; seems like a copy-paste blunder on Apertus.

https://github.com/huggingface/transformers/blob/bb45d3631ec7026db04a77d33a52b31766372160/src/transformers/models/apertus/modular_apertus.py#L254

https://github.com/huggingface/transformers/blob/bb45d3631ec7026db04a77d33a52b31766372160/src/transformers/models/apertus/modular_apertus.py#L254

I don't think the order of transpose is important as the RMSNorm is applied only on the last dimension of the tensor. In both cases the dimension is something like (..., head_dim) which should result to the same thing.
Unrelated: how do you print the intermediate activations for llama.cpp? I'd like to take a shot at this.

@pwilkin
Copy link
Collaborator Author

pwilkin commented Sep 8, 2025

@theo77186

Unrelated: how do you print the intermediate activations for llama.cpp? I'd like to take a shot at this.

llama-eval-callback

@ngxson Thanks! That fixed the qproj problem, adding undo_permute = False makes them align correctly with the Transformers step.

Now onto the norm application (which is still wrong)...

@pwilkin
Copy link
Collaborator Author

pwilkin commented Sep 8, 2025

Update: norm application is good after doing reshape pre-norm, now of course RoPE is wrong...

@ggerganov
Copy link
Member

Update: norm application is good after doing reshape pre-norm, now of course RoPE is wrong...

Try to change the RoPE type to NEoX

@pwilkin
Copy link
Collaborator Author

pwilkin commented Sep 8, 2025

@ggerganov yup, that fixed RoPE. On to xIELU :>

Copy link
Collaborator

@JohannesGaessler JohannesGaessler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I think I may have looked at the wrong version. The CUDA changes as they are right now are fine.

@JohannesGaessler
Copy link
Collaborator

(To be clear, the CI still needs to be fixed.)

@pwilkin
Copy link
Collaborator Author

pwilkin commented Sep 24, 2025

@JohannesGaessler aight, I'm flailing in the dark a bit here, but I think I fixed the CUDA code now.

Co-authored-by: Johannes Gäßler <[email protected]>
@ggerganov
Copy link
Member

ggerganov commented Sep 25, 2025

I think we can merge after @slaren takes a look at the ggml changes and @CISC at the model changes, assuming the CUDA is all good now.

Generally speaking, for model support my preferred approach is to prioritize a working CPU implementation in the initial PR and to add support for the other backends in separate PRs.

Yes, going forward we should try to separate the implementation of new models better in order to simplify the review process.

@pwilkin
Copy link
Collaborator Author

pwilkin commented Sep 25, 2025

Yes, going forward we should try to separate the implementation of new models better in order to simplify the review process.

Ye, will try to stick to the "CPU first, everything else later as a PR" model for the future, seems much easier to handle.

EDIT: I already regret the moment I decided to do any CUDA stuff here ;)

Copy link
Member

@slaren slaren left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please document much better w

Comment on lines +105 to +107
static inline float softplus(float input) {
return (input > 20.0f) ? input : logf(1 + expf(input));
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function doesn't belong here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where should I put it? It's used in two more places (see discussion above), which is why I put it here.

Comment on lines +2667 to +2668
ggml_set_op_params_f32(result, 1, beta + softplus(alpha_n));
ggml_set_op_params_f32(result, 2, softplus(alpha_p));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the changes to alpha_n and alpha_p?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quoting from the paper:

• To constrain αp > 0, we apply the softplus function
to αp. This ensures a linearly increasing gradient for
positive inputs.
• To constrain αn > βn, we add βn to the softplus of αn.
This ensures the presence of negative-valued gradients
for negative inputs.

Comment on lines 199 to 200
float alpha_n = ggml_get_op_params_f32(dst, 1);
float alpha_p = ggml_get_op_params_f32(dst, 2);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you read these values as alpha_n and alpha_p here when what is stored is something else?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's alpha_n and alpha_p, only constrained (see above comment).

@pwilkin
Copy link
Collaborator Author

pwilkin commented Sep 25, 2025

I didn't manage to implement it the way @ngxson suggested, so I just added a new functor-based abstraction in unary-ops and reverted all the other operations back to how they were.

@pwilkin pwilkin requested a review from danbev as a code owner September 25, 2025 20:00
@github-actions github-actions bot added devops improvements to build systems and github actions server Apple Metal https://en.wikipedia.org/wiki/Metal_(API) labels Sep 25, 2025
@pwilkin pwilkin force-pushed the apertus-implementation branch from c982ba7 to d29cb45 Compare September 25, 2025 20:05
@pwilkin
Copy link
Collaborator Author

pwilkin commented Sep 25, 2025

Should be gtg.

dtype = cls._dtype_str_map[st_slice.get_dtype()]
shape: tuple[int, ...] = tuple(st_slice.get_shape())
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[...] if len(s.get_shape()) == 0 else s[:])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why add this now, didn't you work around this already?

Probably belongs in a separate PR.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it turns out the workaround that I had works if you use --remote, but not if you convert from a local repo.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bugfix and I think it's too tiny for a separate PR anyway.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No such thing. :) It's better for visibility to have separate PRs, but since it's an issue specifically for this model it's probably ok to have it here.

Co-authored-by: Johannes Gäßler <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Apple Metal https://en.wikipedia.org/wiki/Metal_(API) devops improvements to build systems and github actions examples ggml changes relating to the ggml tensor library for machine learning Nvidia GPU Issues specific to Nvidia GPUs python python script changes server testing Everything test related
Projects
None yet