Feat: Apertus model implementation #15852
base: master
Conversation
Would solve #15748
@CISC would be grateful for some hints if you see something obvious that could be wrong with the QK norm / RoPE.
@pwilkin you can check rope-type and the soft-plus thing.
Rope type is fine I think (normal, not NeoX), and softplus doesn't matter since the problem appears before the xIELU activation (the first tensor matches the reference through all 30 layers, the next ones diverge starting from the first kqv-out).
Norm application is wrong, but I totally can't figure out how exactly :/
Which norm?
@ggerganov KQ norm.
I think you forgot to apply the norm bias?
@ggerganov That's the thing, it doesn't have a norm bias:

"weight_map": {
"lm_head.weight": "model-00004-of-00004.safetensors",
"model.embed_tokens.weight": "model-00001-of-00004.safetensors",
"model.layers.0.attention_layernorm.weight": "model-00001-of-00004.safetensors",
"model.layers.0.feedforward_layernorm.weight": "model-00001-of-00004.safetensors",
"model.layers.0.mlp.act_fn.alpha_n": "model-00001-of-00004.safetensors",
"model.layers.0.mlp.act_fn.alpha_p": "model-00001-of-00004.safetensors",
"model.layers.0.mlp.act_fn.beta": "model-00001-of-00004.safetensors",
"model.layers.0.mlp.act_fn.eps": "model-00001-of-00004.safetensors",
"model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
"model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
"model.layers.0.self_attn.k_norm.weight": "model-00001-of-00004.safetensors",
"model.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
"model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
"model.layers.0.self_attn.q_norm.weight": "model-00001-of-00004.safetensors",
"model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
"model.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors", Looks exactly like the QK norm in Qwen3, a 128 flat tensor for both. But it somehow diverges from the Transformers implementation. Here's a dump from Transformers: ggml_debug: model.layers.0.self_attn.k_norm_in = (f32) ... = {torch.Size([1, 8, 6, 128])}
[
[
[ 0.0260, 0.0009, -0.0400, ..., 0.0216, 0.0396, -0.0355]
[ -0.0006, 0.0012, -0.0057, ..., -0.0108, -0.0005, -0.0019]
[ 0.0505, -0.0432, 0.0101, ..., 0.0056, -0.0203, 0.0132]
[ -0.0055, -0.0051, -0.0081, ..., -0.3099, -0.0374, -0.0569]
[ -0.0056, -0.0075, -0.0123, ..., 0.1299, -0.1645, -0.0148]
[ 0.0052, -0.0125, 0.0299, ..., 0.0775, -0.1068, -0.0833]
],
]
sum = -4.458256
ggml_debug: model.layers.0.self_attn.k_norm_out = (f32) ... = {torch.Size([1, 8, 6, 128])}
[
[
[ 1.6487, 0.0560, -1.1489, ..., 0.8809, 1.1181, -1.0577]
[ -0.0343, 0.0733, -0.1573, ..., -0.5669, -0.0163, -0.0744]
[ 3.6375, -3.0737, 0.3281, ..., 0.3351, -0.8343, 0.5715]
[ -0.2226, -0.2069, -0.1493, ..., -10.3275, -0.8633, -1.3851]
[ -0.1887, -0.2521, -0.1895, ..., 3.9971, -3.5061, -0.3333]
[ 0.3068, -0.7309, 0.7994, ..., 3.6865, -3.5230, -2.8940]
],
]
sum = 518.159729

and from llama-eval-callback:

ggml_debug: Kcur-0 = (f32) MUL_MAT(blk.0.attn_k.weight{4096, 1024, 1, 1}, attn_norm-0{4096, 6, 1, 1}}) = {1024, 6, 1, 1}
[
[
[ 0.0260, 0.0299, 0.0009, ..., 0.0173, 0.0058, -0.0157],
[ 0.0602, 0.0272, -0.0239, ..., 0.0578, 0.0108, -0.0360],
[ 0.1024, 0.0396, 0.0892, ..., 0.0203, 0.0063, -0.0092],
[ 0.0371, 0.0511, -0.0486, ..., 0.0725, 0.0086, -0.0277],
[ 0.0441, 0.0391, -0.0433, ..., 0.0447, -0.0085, -0.0458],
[ 0.0566, 0.0338, -0.0308, ..., 0.0555, 0.0202, -0.0168],
],
]
sum = -4.452065
ggml_debug: Kcur_normed-0 = (f32) MUL(CUDA0#norm-0#0{128, 8, 6, 1}, CUDA0#blk.0.attn_k_norm.weight#0{128, 1, 1, 1}}) = {128, 8, 6, 1}
[
[
[ 1.6491, 1.8751, 0.0248, ..., 3.8993, 0.2504, -0.7591],
[ -0.0332, -0.1835, 0.0339, ..., 0.7072, 0.2774, 0.4394],
[ 3.6366, 2.1278, -1.4065, ..., 0.4759, -0.0021, -0.0724],
...,
[ 0.3081, -0.1988, -0.3355, ..., -1.0101, -0.1090, -2.8183],
[ 2.0318, 0.6510, -0.9407, ..., -3.0910, -0.1242, 1.1163],
[ -1.1660, -1.4152, 0.5272, ..., 1.8761, 0.4326, -1.2444],
],
[
[ 1.5329, 0.6841, -0.2749, ..., 2.7403, -0.5124, -1.3517],
[ -2.0542, -1.4692, -0.0407, ..., 1.1136, 0.5337, 0.1236],
[ 3.9727, 6.0690, -2.1473, ..., -2.2552, -1.5008, -0.9382],
...,
[ -0.2568, -0.4579, -0.3575, ..., 1.8921, 1.0777, -3.7264],
[ 0.5295, 0.1762, -0.1126, ..., -3.3283, 0.6000, 1.7422],
[ -5.0935, -4.9426, 1.0877, ..., 2.7902, 0.3603, -1.2684],
],
[
[ 4.1699, 1.5921, 1.6440, ..., 1.3991, -0.0073, -1.1039],
[ -0.0246, -0.0408, -0.0693, ..., -0.2578, -0.3472, 0.1711],
[ 0.2052, 0.0656, -0.4119, ..., 0.1593, 0.0395, -0.5277],
...,
[ 1.1948, 0.1338, -0.5707, ..., 5.8187, 0.8847, -2.7834],
[ 0.4890, -0.2239, -0.0053, ..., -2.8752, -0.2403, 1.6677],
[ -0.2667, -0.4034, 0.6344, ..., 2.1922, 0.4676, -0.7278],
],
[
[ 0.9241, 1.2575, -0.5471, ..., 3.1444, 0.3413, -0.1515],
[ -2.0858, -0.2047, -0.6143, ..., 0.7987, 1.1992, 0.1377],
[ 4.4787, 6.5776, -1.9080, ..., -1.2320, -0.4378, 0.1753],
...,
[ -0.6235, -0.1112, 0.2428, ..., 0.5462, 0.9830, -2.6701],
[ 0.4527, 0.4643, -0.3325, ..., -2.3723, 0.1144, 1.5826],
[ -5.4386, -4.6627, 1.1660, ..., 3.1313, 0.2574, -0.8740],
],
[
[ 1.0395, 0.9112, -0.4622, ..., 1.5698, -0.1920, -2.1213],
[ -4.7884, -2.8707, 0.0744, ..., 1.5661, 0.1393, -0.9221],
[ 3.9333, 5.8780, -1.7432, ..., 1.1649, 0.4673, -0.8790],
...,
[ -0.8365, -0.5343, 0.1528, ..., 1.2459, 0.5092, -2.3725],
[ 0.1358, 0.4386, -0.0570, ..., -3.0326, 0.8687, 1.6394],
[ -4.0204, -3.5265, 0.9346, ..., 2.2506, -0.2951, -1.6851],
],
[
[ 1.3533, 0.7988, -0.3331, ..., 1.6144, 0.0773, -1.0566],
[ 0.0273, -0.4740, 0.0589, ..., -0.0224, 0.2123, -0.0747],
[ 4.2151, 5.8705, -1.8385, ..., -1.2067, 0.8555, 0.5714],
...,
[ -0.4960, 0.1881, 0.2486, ..., -5.0892, -0.0901, -2.8932],
[ -0.2326, 0.3233, 0.0688, ..., -2.0252, -0.3004, 1.8363],
[ -5.1900, -3.7914, 1.1631, ..., 2.5679, 0.6468, -0.5669],
],
]
sum = -38.386154
Hm, not sure. My only guess is that the Q and K tensors are permuted and need to be unpermuted during conversion?
@ggerganov yeah, looks like it, now I'm trying to figure out how exactly.
Looks like some interleave is going on, because the first three elements of the first vector in GGML are:
Some models have permuted tensors which we reverse during the conversion. Take a look at this code for example: llama.cpp/convert_hf_to_gguf.py, lines 1627 to 1668 at 3976dfb.
Maybe something similar has to be applied here too.
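For reference, a minimal sketch (numpy, not the actual converter code) of the kind of reordering that permute helper performs: per head, the half-split rotary layout of the HF checkpoint is turned into interleaved pairs. The toy sizes below are made up for illustration.

import numpy as np

def permute(weights: np.ndarray, n_head: int) -> np.ndarray:
    # per head: view the rows as (2, head_dim / 2), swap the two axes, flatten back
    head_dim = weights.shape[0] // n_head
    return (weights.reshape(n_head, 2, head_dim // 2, *weights.shape[1:])
                   .swapaxes(1, 2)
                   .reshape(weights.shape))

# toy example: one head, head_dim = 8, rows labelled by their original index
w = np.arange(8, dtype=np.float32).reshape(8, 1)
print(permute(w, n_head=1).ravel())   # [0. 4. 1. 5. 2. 6. 3. 7.]

The Kcur-0 dump above looks consistent with this pattern relative to the Transformers dump (element 2 of the GGML row matches element 1 of the reference), which is what points at a permutation being applied during conversion.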
My only worry is that there's nothing either in the Transformers implementation or the ChatLLM implementation (based on GGML after all) which suggests that there's any special treatment to be done. It seems like a normal extension of Llama3 with RMS norm. Even the tensor sizes look pretty standard. The conversion code is basically "extend Llama and add the extra tensors for the xIELU activation".
src/llama-model.cpp (outdated)
ggml_tensor * alpha_n = model.layers[il].ffn_act_alpha_n;
ggml_tensor * alpha_p = model.layers[il].ffn_act_alpha_p;
ggml_tensor * beta = model.layers[il].ffn_act_beta;
ggml_tensor * eps = model.layers[il].ffn_act_eps;

float alpha_n_val = get_scalar_f32_val(alpha_n);
float alpha_p_val = get_scalar_f32_val(alpha_p);
float beta_val = get_scalar_f32_val(beta);
float eps_val = get_scalar_f32_val(eps);
This will end up being a very big performance penalty, as there will be 4 extra CPU <--> backend round-trips each time a cgraph is constructed.
A much simpler way is to read these as scalar values and write them as GGUF metadata upon conversion. This should also allow implementing xielu with existing GGML ops (no need to create a new op).
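A hedged sketch of that suggestion (the key names here are hypothetical, not ones defined by the project): read the per-layer scalars once during conversion and emit them as GGUF metadata arrays, so nothing has to be fetched from the backend at graph-build time.

import gguf

def write_xielu_params(writer: gguf.GGUFWriter,
                       alpha_n: list[float], alpha_p: list[float],
                       beta: list[float], eps: list[float]) -> None:
    # hypothetical per-layer arrays, one value per block
    writer.add_array("apertus.xielu.alpha_n", alpha_n)
    writer.add_array("apertus.xielu.alpha_p", alpha_p)
    writer.add_array("apertus.xielu.beta", beta)
    writer.add_array("apertus.xielu.eps", eps)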
Yeah, I was going to optimize the xIELU act as soon as I got a correct implementation working. You mean store them as vectors and load them as something like n_head_arr?
Yes, exactly. Also, on the weights I noticed some tensors are empty on all layers, so you can probably just hard-code that to 0.0 in the first prototype.
A small tip when working with these kinds of models is to start from tiny random weights. For example, this is the code to generate a tiny random Llama 4
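A minimal sketch of that tip (not the linked script; shown with plain Llama classes since the exact Apertus classes depend on the transformers version): build the target architecture with tiny dimensions, save it, and iterate on conversion and graph debugging against that.

from transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(
    hidden_size=64, intermediate_size=128, num_hidden_layers=2,
    num_attention_heads=4, num_key_value_heads=2, vocab_size=256,
)
model = LlamaForCausalLM(config)
model.save_pretrained("tiny-random-llama")
# a tokenizer still has to be copied into the folder before running convert_hf_to_gguf.py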
I also ported the model. Note that Apertus does the transpose before the QK norm, while Qwen3 does it after; it seems like a copy-paste blunder on Apertus.
Actually, the Q projection already diverges and I have no idea why... the attention norms are the same, but then I get this weird interleaving:
What could cause this sort of interleaving?
You're extending the conversion class from
I don't think the order of the transpose is important, as the RMSNorm is applied only on the last dimension of the tensor. In both cases the dimension is something like (..., head_dim), which should result in the same thing.
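A quick numerical check of that point (a throwaway sketch, not repo code): RMSNorm over the last dimension commutes with transposing the leading dimensions.

import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # normalize over the last dimension only (no learned weight for this check)
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

x = torch.randn(1, 6, 8, 128)            # (batch, tokens, heads, head_dim)
a = rms_norm(x).transpose(1, 2)          # norm first, transpose after
b = rms_norm(x.transpose(1, 2))          # transpose first, norm after
print(torch.allclose(a, b))              # True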
@ngxson Thanks! That fixed the qproj problem. Now onto the norm application (which is still wrong)...
Update: the norm application is good after doing the reshape pre-norm; now of course RoPE is wrong...
Try to change the RoPE type to NeoX
@ggerganov yup, that fixed RoPE. On to xIELU :>
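For context, a sketch of why the rope type matters (illustrative only, not the GGML kernels): "normal" RoPE rotates adjacent interleaved pairs (x[2i], x[2i+1]), while NeoX-style RoPE rotates half-split pairs (x[i], x[i + d/2]). With the wrong type, unrelated components get rotated together and the outputs diverge.

import torch

def rope_norm(x: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    # rotate interleaved pairs (x[2i], x[2i+1]); theta has shape (d/2,)
    p = x.view(*x.shape[:-1], -1, 2)
    cos, sin = torch.cos(theta), torch.sin(theta)
    out = torch.stack([p[..., 0] * cos - p[..., 1] * sin,
                       p[..., 0] * sin + p[..., 1] * cos], dim=-1)
    return out.reshape(x.shape)

def rope_neox(x: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    # rotate half-split pairs (x[i], x[i + d/2]); theta has shape (d/2,)
    d = x.shape[-1]
    x1, x2 = x[..., : d // 2], x[..., d // 2 :]
    cos, sin = torch.cos(theta), torch.sin(theta)
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)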
Sorry, I think I may have looked at the wrong version. The CUDA changes as they are right now are fine.
(To be clear, the CI still needs to be fixed.)
@JohannesGaessler aight, I'm flailing in the dark a bit here, but I think I fixed the CUDA code now.
Co-authored-by: Johannes Gäßler <[email protected]>
I think we can merge after @slaren takes a look at the
Yes, going forward we should try to separate the implementation of new models better in order to simplify the review process.
Ye, will try to stick to the "CPU first, everything else later as a PR" model in the future, seems much easier to handle. EDIT: I already regret the moment I decided to do any CUDA stuff here ;)
Please document much better w
// softplus(x) = log(1 + exp(x)); for large x it is ~= x, which avoids overflowing expf
static inline float softplus(float input) {
    return (input > 20.0f) ? input : logf(1 + expf(input));
}
This function doesn't belong here.
Where should I put it? It's used in two more places (see discussion above), which is why I put it here.
ggml_set_op_params_f32(result, 1, beta + softplus(alpha_n));
ggml_set_op_params_f32(result, 2, softplus(alpha_p));
Why the changes to alpha_n and alpha_p?
Quoting from the paper:

• To constrain αp > 0, we apply the softplus function to αp. This ensures a linearly increasing gradient for positive inputs.
• To constrain αn > βn, we add βn to the softplus of αn. This ensures the presence of negative-valued gradients for negative inputs.
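A small numeric sketch of those two constraints (the raw values below are made up for illustration, not taken from the model):

import math

def softplus(x: float) -> float:
    return x if x > 20.0 else math.log1p(math.exp(x))

alpha_n_raw, alpha_p_raw, beta = 0.8, 0.8, 0.5   # illustrative values
alpha_p = softplus(alpha_p_raw)                  # constrained: alpha_p > 0
alpha_n = beta + softplus(alpha_n_raw)           # constrained: alpha_n > beta
print(alpha_p, alpha_n)

These constrained values are what end up in the op params above.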
ggml/src/ggml-cpu/unary-ops.cpp (outdated)
float alpha_n = ggml_get_op_params_f32(dst, 1);
float alpha_p = ggml_get_op_params_f32(dst, 2);
Why do you read these values as alpha_n and alpha_p here when what is stored is something else?
It's alpha_n and alpha_p, only constrained (see above comment).
I didn't manage to implement it the way @ngxson suggested, so I just added a new functor-based abstraction in unary-ops and reverted all the other operations back to how they were.
Co-authored-by: Sigbjørn Skjæret <[email protected]>
(force-pushed from c982ba7 to d29cb45)
Should be gtg.
  dtype = cls._dtype_str_map[st_slice.get_dtype()]
  shape: tuple[int, ...] = tuple(st_slice.get_shape())
- lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
+ lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[...] if len(s.get_shape()) == 0 else s[:])
Why add this now, didn't you work around this already?
Probably belongs in a separate PR.
No, it turns out the workaround that I had works if you use --remote, but not if you convert from a local repo.
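For context, the failure mode is the usual zero-dimensional slicing behaviour: the xIELU scalars are 0-dim tensors, for which [:] is not a valid index while [...] is. A numpy analogue (not the safetensors internals):

import numpy as np

scalar = np.array(3.0)     # 0-dim, like alpha_n / alpha_p / beta / eps
print(scalar[...])         # array(3.) -- works
try:
    scalar[:]              # raises: too many indices for a 0-dimensional array
except IndexError as e:
    print(e)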
It's a bugfix and I think it's too tiny for a separate PR anyway.
No such thing. :) It's better for visibility to have separate PRs, but since it's an issue specifically for this model it's probably ok to have it here.
Co-authored-by: Johannes Gäßler <[email protected]>
Help!
This seems like an easy enough model. I've done the xIELU implementation, but for some reason the tensor outputs diverge in the QKV calculations (it's either the RoPE or the norm). I've tried everything and spent a whole day debugging it; maybe I'm missing something really obvious?
@foldl you did the implementation in chatllm.cpp and it works fine, maybe you can see something obvious here?