From 64add8224607afe0de6cbd5c74f27ce9c83e92d8 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Sun, 7 Sep 2025 10:10:51 +0200 Subject: [PATCH 01/32] First attempt --- convert_hf_to_gguf.py | 11 ++ ggml/include/ggml.h | 11 ++ ggml/src/ggml-cpu/ggml-cpu.c | 5 + ggml/src/ggml-cpu/unary-ops.cpp | 89 ++++++--- ggml/src/ggml-cpu/unary-ops.h | 1 + ggml/src/ggml.c | 27 ++- gguf-py/gguf/constants.py | 36 ++++ gguf-py/gguf/tensor_mapping.py | 44 +++++ src/llama-arch.cpp | 28 +++ src/llama-arch.h | 5 + src/llama-model.cpp | 319 ++++++++++++++++++++++++++++++++ src/llama-model.h | 6 + 12 files changed, 551 insertions(+), 31 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 717b8a6595010..83936bc5e72c2 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -8566,6 +8566,17 @@ def prepare_tensors(self): raise ValueError(f"Unprocessed experts: {experts}") +@ModelBase.register("ApertusForCausalLM") +class ApertusModel(LlamaModel): + model_arch = gguf.MODEL_ARCH.APERTUS + + def modify_tensors(self, data_torch, name, bid): + # Handle xIELU activation parameters + if name.endswith(".act_fn.alpha_n") or name.endswith(".act_fn.alpha_p") or name.endswith(".act_fn.beta") or name.endswith(".act_fn.eps"): + return [(self.map_tensor_name(name), data_torch)] + + return super().modify_tensors(data_torch, name, bid) + class MistralModel(LlamaModel): model_arch = gguf.MODEL_ARCH.LLAMA model_name = "Mistral" diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index c01b98ac78f5a..4df020db9372b 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -554,6 +554,7 @@ extern "C" { GGML_OP_OPT_STEP_SGD, GGML_OP_GLU, + GGML_OP_XIELU, GGML_OP_COUNT, }; @@ -1148,6 +1149,16 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + // xIELU activation function + // x = x * (alpha_n + alpha_p * sigmoid(beta * x)) + eps * (x > 0) + GGML_API struct ggml_tensor * ggml_xielu( + struct ggml_context * ctx, + struct ggml_tensor * a, + float alpha_n, + float alpha_p, + float beta, + float eps); + // gated linear unit ops // A: n columns, r rows, // result is n / 2 columns, r rows, diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 0d35d9333e3f5..98f2c90b25532 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1971,6 +1971,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_unary(params, tensor); } break; + case GGML_OP_XIELU: + { + ggml_compute_forward_xielu(params, tensor); + } break; case GGML_OP_GLU: { ggml_compute_forward_glu(params, tensor); @@ -2160,6 +2164,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_REPEAT: case GGML_OP_REPEAT_BACK: case GGML_OP_LEAKY_RELU: + case GGML_OP_XIELU: { n_tasks = 1; } break; diff --git a/ggml/src/ggml-cpu/unary-ops.cpp b/ggml/src/ggml-cpu/unary-ops.cpp index 4fce569b3bfc8..1d3f0cef9400e 100644 --- a/ggml/src/ggml-cpu/unary-ops.cpp +++ b/ggml/src/ggml-cpu/unary-ops.cpp @@ -52,6 +52,15 @@ static inline float op_sqrt(float x) { return sqrtf(x); } +static inline float op_xielu(float x, float alpha_n, float alpha_p, float beta, float eps) { + if (x > 0.0f) { + return alpha_p * x * x + beta * x; + } else { + const float min_x_eps = fminf(x, eps); + return (expm1f(min_x_eps) - x) * alpha_n + beta * x; + } +} + static inline float op_sin(float x) { return sinf(x); } @@ -64,8 +73,8 @@ static inline float op_log(float x) { return logf(x); } -template -static inline void vec_unary_op(int64_t n, 
dst_t * y, const src0_t * x) { +template +static inline void vec_unary_op(const Op& op, int64_t n, dst_t * y, const src0_t * x) { constexpr auto src0_to_f32 = type_conversion_table::to_f32; constexpr auto f32_to_dst = type_conversion_table::from_f32; @@ -74,8 +83,8 @@ static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) { } } -template -static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst) { +template +static void apply_unary_op(const Op& op, const ggml_compute_params * params, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst)); @@ -95,25 +104,25 @@ static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst dst_t * dst_ptr = (dst_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - vec_unary_op(ne0, dst_ptr, src0_ptr); + vec_unary_op(op, ne0, dst_ptr, src0_ptr); } } // TODO: Use the 'traits' lookup table (for type conversion fns), instead of a mass of 'if' conditions with long templates -template -static void unary_op(const ggml_compute_params * params, ggml_tensor * dst) { +template +static void unary_op(const Op& op, const ggml_compute_params * params, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; /* */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32 - apply_unary_op(params, dst); + apply_unary_op(op, params, dst); } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16 - apply_unary_op(params, dst); + apply_unary_op(op, params, dst); } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16 - apply_unary_op(params, dst); + apply_unary_op(op, params, dst); } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) { - apply_unary_op(params, dst); + apply_unary_op(op, params, dst); } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { - apply_unary_op(params, dst); + apply_unary_op(op, params, dst); } else { fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__, ggml_type_name(dst->type), ggml_type_name(src0->type)); @@ -122,65 +131,89 @@ static void unary_op(const ggml_compute_params * params, ggml_tensor * dst) { } void ggml_compute_forward_abs(const ggml_compute_params * params, ggml_tensor * dst) { - unary_op(params, dst); + unary_op(op_abs, params, dst); } void ggml_compute_forward_sgn(const ggml_compute_params * params, ggml_tensor * dst) { - unary_op(params, dst); + unary_op(op_sgn, params, dst); } void ggml_compute_forward_neg(const ggml_compute_params * params, ggml_tensor * dst) { - unary_op(params, dst); + unary_op(op_neg, params, dst); } void ggml_compute_forward_step(const ggml_compute_params * params, ggml_tensor * dst) { - unary_op(params, dst); + unary_op(op_step, params, dst); } void ggml_compute_forward_tanh(const ggml_compute_params * params, ggml_tensor * dst) { - unary_op(params, dst); + unary_op(op_tanh, params, dst); } void ggml_compute_forward_elu(const ggml_compute_params * params, ggml_tensor * dst) { - unary_op(params, dst); + unary_op(op_elu, params, dst); } void ggml_compute_forward_relu(const ggml_compute_params * params, ggml_tensor * dst) { - unary_op(params, dst); + unary_op(op_relu, params, dst); } void ggml_compute_forward_sigmoid(const ggml_compute_params * params, ggml_tensor * dst) { - unary_op(params, dst); + 
unary_op(op_sigmoid, params, dst); } void ggml_compute_forward_hardsigmoid(const ggml_compute_params * params, ggml_tensor * dst) { - unary_op(params, dst); + unary_op(op_hardsigmoid, params, dst); } void ggml_compute_forward_exp(const ggml_compute_params * params, ggml_tensor * dst) { - unary_op(params, dst); + unary_op(op_exp, params, dst); } void ggml_compute_forward_hardswish(const ggml_compute_params * params, ggml_tensor * dst) { - unary_op(params, dst); + unary_op(op_hardswish, params, dst); } void ggml_compute_forward_sqr(const ggml_compute_params * params, ggml_tensor * dst) { - unary_op(params, dst); + unary_op(op_sqr, params, dst); } void ggml_compute_forward_sqrt(const ggml_compute_params * params, ggml_tensor * dst) { - unary_op(params, dst); + unary_op(op_sqrt, params, dst); } void ggml_compute_forward_sin(const ggml_compute_params * params, ggml_tensor * dst) { - unary_op(params, dst); + unary_op(op_sin, params, dst); } void ggml_compute_forward_cos(const ggml_compute_params * params, ggml_tensor * dst) { - unary_op(params, dst); + unary_op(op_cos, params, dst); } void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor * dst) { - unary_op(params, dst); + unary_op(op_log, params, dst); +} + +static float softplus(float input, float beta=1.0f, float threshold=20.0f) { + if (input * beta > threshold) return input; + return (1/beta) * logf(1 + expf(beta * input)); } + +void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) { + // Get the XIELU parameters from the operation + const float * op_params = (const float*)dst->op_params; + float alpha_n = op_params[0]; + float alpha_p = op_params[1]; + const float beta = op_params[2]; + const float eps = op_params[3]; + +// alpha_p = softplus(alpha_p); +// alpha_n = beta + softplus(alpha_n); + + const auto xielu_op_params = [alpha_n, alpha_p, beta, eps](float f) { + return op_xielu(f, alpha_n, alpha_p, beta, eps); + }; + + unary_op(xielu_op_params, params, dst); +} + diff --git a/ggml/src/ggml-cpu/unary-ops.h b/ggml/src/ggml-cpu/unary-ops.h index b1ade2c8e341f..697c1e0da0ace 100644 --- a/ggml/src/ggml-cpu/unary-ops.h +++ b/ggml/src/ggml-cpu/unary-ops.h @@ -22,6 +22,7 @@ void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_xielu(const struct ggml_compute_params * params, struct ggml_tensor * dst); #ifdef __cplusplus } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index f35c337952ec3..ab7c628d7db97 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1017,9 +1017,10 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "OPT_STEP_SGD", "GLU", + "XIELU", }; -static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90"); +static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 90"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1121,9 +1122,10 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "sgd(x)", "glu(x)", + "xielu(x)", }; -static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90"); +static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 90"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -1147,7 +1149,6 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { 
static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15"); - static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = { "REGLU", "GEGLU", @@ -2646,6 +2647,26 @@ struct ggml_tensor * ggml_silu( return ggml_unary(ctx, a, GGML_UNARY_OP_SILU); } +// ggml_xielu +struct ggml_tensor * ggml_xielu( + struct ggml_context * ctx, + struct ggml_tensor * a, + float alpha_n, + float alpha_p, + float beta, + float eps) { + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + + // Store the parameters as operation parameters + float params[] = { alpha_n, alpha_p, beta, eps }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_XIELU; + result->src[0] = a; + + return result; +} + struct ggml_tensor * ggml_silu_inplace( struct ggml_context * ctx, struct ggml_tensor * a) { diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index c922bacf6eb97..55a44920d2225 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -391,6 +391,7 @@ class MODEL_ARCH(IntEnum): SMALLTHINKER = auto() LLADA = auto() SEED_OSS = auto() + APERTUS = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -440,6 +441,10 @@ class MODEL_TENSOR(IntEnum): FFN_GATE_SHEXP = auto() FFN_DOWN_SHEXP = auto() FFN_UP_SHEXP = auto() + FFN_ACT_ALPHA_N = auto() + FFN_ACT_ALPHA_P = auto() + FFN_ACT_BETA = auto() + FFN_ACT_EPS = auto() FFN_EXP_PROBS_B = auto() ATTN_Q_NORM = auto() ATTN_K_NORM = auto() @@ -727,6 +732,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.SMALLTHINKER: "smallthinker", MODEL_ARCH.LLADA: "llada", MODEL_ARCH.SEED_OSS: "seed_oss", + MODEL_ARCH.APERTUS: "apertus", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -773,12 +779,20 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_GATE_SHEXP: "blk.{bid}.ffn_gate_shexp", MODEL_TENSOR.FFN_DOWN_SHEXP: "blk.{bid}.ffn_down_shexp", MODEL_TENSOR.FFN_UP_SHEXP: "blk.{bid}.ffn_up_shexp", + MODEL_TENSOR.FFN_ACT_ALPHA_N: "blk.{bid}.ffn_act_alpha_n", + MODEL_TENSOR.FFN_ACT_ALPHA_P: "blk.{bid}.ffn_act_alpha_p", + MODEL_TENSOR.FFN_ACT_BETA: "blk.{bid}.ffn_act_beta", + MODEL_TENSOR.FFN_ACT_EPS: "blk.{bid}.ffn_act_eps", MODEL_TENSOR.FFN_ACT: "blk.{bid}.ffn", MODEL_TENSOR.FFN_NORM_EXP: "blk.{bid}.ffn_norm_exps", MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps", MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps", MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps", MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b", + MODEL_TENSOR.FFN_ACT_ALPHA_N: "blk.{bid}.ffn_act_alpha_n", + MODEL_TENSOR.FFN_ACT_ALPHA_P: "blk.{bid}.ffn_act_alpha_p", + MODEL_TENSOR.FFN_ACT_BETA: "blk.{bid}.ffn_act_beta", + MODEL_TENSOR.FFN_ACT_EPS: "blk.{bid}.ffn_act_eps", MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm", MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: "per_layer_token_embd", # gemma3n MODEL_TENSOR.PER_LAYER_MODEL_PROJ: "per_layer_model_proj", # gemma3n @@ -2683,6 +2697,28 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], + MODEL_ARCH.APERTUS: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_ACT_ALPHA_N, + MODEL_TENSOR.FFN_ACT_ALPHA_P, + MODEL_TENSOR.FFN_ACT_BETA, + MODEL_TENSOR.FFN_ACT_EPS, + ], # TODO } diff --git 
a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index b0c3d65e95847..cf0dfefecfe64 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -447,6 +447,22 @@ class TensorNameMap: "layers.{bid}.mlp.gate_proj", # qwen3-embedding ), + MODEL_TENSOR.FFN_ACT_ALPHA_N: ( + "model.layers.{bid}.mlp.act_fn.alpha_n", # apertus xIELU + ), + + MODEL_TENSOR.FFN_ACT_ALPHA_P: ( + "model.layers.{bid}.mlp.act_fn.alpha_p", # apertus xIELU + ), + + MODEL_TENSOR.FFN_ACT_BETA: ( + "model.layers.{bid}.mlp.act_fn.beta", # apertus xIELU + ), + + MODEL_TENSOR.FFN_ACT_EPS: ( + "model.layers.{bid}.mlp.act_fn.eps", # apertus xIELU + ), + MODEL_TENSOR.FFN_GATE_EXP: ( "layers.{bid}.feed_forward.experts.w1", # mixtral (merged) "transformer.decoder_layer.{bid}.moe.linear", # Grok (merged) @@ -1471,6 +1487,34 @@ class TensorNameMap: "model.layers.{bid}.post_attention_layernorm", ), }, + MODEL_ARCH.APERTUS: { + MODEL_TENSOR.ATTN_NORM: ( + "model.layers.{bid}.attention_layernorm", + ), + MODEL_TENSOR.ATTN_Q_NORM: ( + "model.layers.{bid}.attention.query_layernorm", + "model.layers.{bid}.self_attn.q_norm", + ), + MODEL_TENSOR.ATTN_K_NORM: ( + "model.layers.{bid}.attention.key_layernorm", + "model.layers.{bid}.self_attn.k_norm", + ), + MODEL_TENSOR.FFN_NORM: ( + "model.layers.{bid}.feedforward_layernorm", + ), + MODEL_TENSOR.FFN_ACT_ALPHA_N: ( + "model.layers.{bid}.mlp.act_fn.alpha_n", + ), + MODEL_TENSOR.FFN_ACT_ALPHA_P: ( + "model.layers.{bid}.mlp.act_fn.alpha_p", + ), + MODEL_TENSOR.FFN_ACT_BETA: ( + "model.layers.{bid}.mlp.act_fn.beta", + ), + MODEL_TENSOR.FFN_ACT_EPS: ( + "model.layers.{bid}.mlp.act_fn.eps", + ), + }, } mapping: dict[str, tuple[MODEL_TENSOR, str]] diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index d92bf2dd8515f..90ee4c707dbe5 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -97,6 +97,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_SMALLTHINKER, "smallthinker" }, { LLM_ARCH_LLADA, "llada" }, { LLM_ARCH_SEED_OSS, "seed_oss" }, + { LLM_ARCH_APERTUS, "apertus" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -2100,6 +2101,29 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" } }, }, + { + LLM_ARCH_APERTUS, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_ACT_ALPHA_N, "blk.%d.ffn_act_alpha_n" }, + { LLM_TENSOR_FFN_ACT_ALPHA_P, "blk.%d.ffn_act_alpha_p" }, + { LLM_TENSOR_FFN_ACT_BETA, "blk.%d.ffn_act_beta" }, + { LLM_TENSOR_FFN_ACT_EPS, "blk.%d.ffn_act_eps" }, + }, + }, { LLM_ARCH_DREAM, { @@ -2283,6 +2307,10 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_FFN_DOWN_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, + {LLM_TENSOR_FFN_ACT_ALPHA_N, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_FFN_ACT_ALPHA_P, {LLM_TENSOR_LAYER_REPEATING, 
GGML_OP_GET_ROWS}}, + {LLM_TENSOR_FFN_ACT_BETA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_FFN_ACT_EPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, {LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, // altup / laurel (gemma 3n) {LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, diff --git a/src/llama-arch.h b/src/llama-arch.h index 41f377923f50f..08a5a1c50f6d3 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -101,6 +101,7 @@ enum llm_arch { LLM_ARCH_SMALLTHINKER, LLM_ARCH_LLADA, LLM_ARCH_SEED_OSS, + LLM_ARCH_APERTUS, LLM_ARCH_UNKNOWN, }; @@ -292,6 +293,10 @@ enum llm_tensor { LLM_TENSOR_FFN_GATE_SHEXP, LLM_TENSOR_FFN_UP_SHEXP, LLM_TENSOR_FFN_EXP_PROBS_B, + LLM_TENSOR_FFN_ACT_ALPHA_N, + LLM_TENSOR_FFN_ACT_ALPHA_P, + LLM_TENSOR_FFN_ACT_BETA, + LLM_TENSOR_FFN_ACT_EPS, LLM_TENSOR_ATTN_Q_NORM, LLM_TENSOR_ATTN_K_NORM, LLM_TENSOR_LAYER_OUT_NORM, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 1813f06d7b308..f501c1392868f 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1945,6 +1945,15 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_APERTUS: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + hparams.n_ctx_orig_yarn = 8192; + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -5719,6 +5728,55 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); } } break; + case LLM_ARCH_APERTUS: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? 
TENSOR_DUPLICATED : 0)); + } + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // optional bias tensors + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_gqa }, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_gqa }, TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + + // Q and K layernorms for Apertus + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); + + // xIELU parameters for Apertus + layer.ffn_act_alpha_n = create_tensor(tn(LLM_TENSOR_FFN_ACT_ALPHA_N, i), { 1 }, 0); + layer.ffn_act_alpha_p = create_tensor(tn(LLM_TENSOR_FFN_ACT_ALPHA_P, i), { 1 }, 0); + layer.ffn_act_beta = create_tensor(tn(LLM_TENSOR_FFN_ACT_BETA, i), { 1 }, 0); + layer.ffn_act_eps = create_tensor(tn(LLM_TENSOR_FFN_ACT_EPS, i), { 1 }, 0); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -18620,6 +18678,262 @@ struct llm_build_smallthinker : public llm_graph_context{ } }; +// TODO: maybe put this as a general helper in ggml.c? 
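+// Reads the first F32 element of a (scalar) parameter tensor: through the backend
+// API when the tensor is backed by a buffer (it may live in device memory),
+// otherwise directly from host data.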
+static float get_scalar_f32_val(const ggml_tensor *t) { + float onef; + if (t->buffer) { + ggml_backend_tensor_get(t, &onef, 0, sizeof(float)); + } else { + GGML_ASSERT(t->data); + onef = *((float *) t->data); + } + return onef; +} + +static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) { + size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0]; + float v; + if (type == GGML_TYPE_F16) { + v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]); + } else if (type == GGML_TYPE_F32) { + v = *(float *) &data[i]; + } else if (type == GGML_TYPE_I64) { + v = (float) *(int64_t *) &data[i]; + } else if (type == GGML_TYPE_I32) { + v = (float) *(int32_t *) &data[i]; + } else if (type == GGML_TYPE_I16) { + v = (float) *(int16_t *) &data[i]; + } else if (type == GGML_TYPE_I8) { + v = (float) *(int8_t *) &data[i]; + } else { + GGML_ABORT("fatal error"); + } + return v; +} + +static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) { + GGML_ASSERT(n > 0); + float sum = 0; + for (int64_t i3 = 0; i3 < ne[3]; i3++) { + for (int64_t i2 = 0; i2 < ne[2]; i2++) { + for (int64_t i1 = 0; i1 < ne[1]; i1++) { + for (int64_t i0 = 0; i0 < ne[0]; i0++) { + const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3); + sum += v; + } + } + } + } + for (int64_t i3 = 0; i3 < ne[3]; i3++) { + LLAMA_LOG_DEBUG(" [\n"); + for (int64_t i2 = 0; i2 < ne[2]; i2++) { + if (i2 == n && ne[2] > 2*n) { + LLAMA_LOG_DEBUG(" ..., \n"); + i2 = ne[2] - n; + } + LLAMA_LOG_DEBUG(" [\n"); + for (int64_t i1 = 0; i1 < ne[1]; i1++) { + if (i1 == n && ne[1] > 2*n) { + LLAMA_LOG_DEBUG(" ..., \n"); + i1 = ne[1] - n; + } + LLAMA_LOG_DEBUG(" ["); + for (int64_t i0 = 0; i0 < ne[0]; i0++) { + if (i0 == n && ne[0] > 2*n) { + LLAMA_LOG_DEBUG("..., "); + i0 = ne[0] - n; + } + const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3); + LLAMA_LOG_DEBUG("%12.4f", v); + if (i0 < ne[0] - 1) LLAMA_LOG_DEBUG(", "); + } + LLAMA_LOG_DEBUG("],\n"); + } + LLAMA_LOG_DEBUG(" ],\n"); + } + LLAMA_LOG_DEBUG(" ]\n"); + LLAMA_LOG_DEBUG(" sum = %f\n", sum); + } + + // TODO: make this abort configurable/optional? + if (std::isnan(sum)) { + LLAMA_LOG_ERROR("encountered NaN - aborting\n"); + exit(0); + } +} + +// Apertus model graph builder with xIELU activation +struct llm_build_apertus : public llm_graph_context { + llm_build_apertus(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 
1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + ggml_tensor * Q2d = ggml_reshape_2d(ctx0, Qcur, n_embd_head, n_head * n_tokens); + ggml_tensor * K2d = ggml_reshape_2d(ctx0, Kcur, n_embd_head, n_head_kv * n_tokens); + + cb(Q2d, "Q2D", il); + cb(K2d, "K2D", il); + + // apply existing rms-norm which was originally written for 2D + Q2d = build_norm(Q2d, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + K2d = build_norm(K2d, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + + cb(Q2d, "Q2D_normed", il); + cb(K2d, "K2D_normed", il); + + // reshape back to 3D + Qcur = ggml_reshape_3d(ctx0, Q2d, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, K2d, n_embd_head, n_head_kv, n_tokens); + + cb(Qcur, "Qcur_normed", il); + cb(Kcur, "Kcur_normed", il); + + // // copy the data from the GPU memory if needed + // const bool is_host = ggml_backend_buffer_is_host(rope_factors->buffer); + + // auto n_bytes = ggml_nbytes(rope_factors); + // uint8_t loaded_data[n_bytes]; + // if (!is_host) { + // ggml_backend_tensor_get(rope_factors, &loaded_data, 0, n_bytes); + // } + + // if (!ggml_is_quantized(rope_factors->type)) { + // uint8_t * data = is_host ? 
(uint8_t *) rope_factors->data : &loaded_data[0]; + // ggml_print_tensor(data, rope_factors->type, rope_factors->ne, rope_factors->nb, 64); + // } + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur_pos", il); + cb(Kcur, "Kcur_pos", il); + cb(Vcur, "Vcur_pos", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network with xIELU activation + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // Up projection + ggml_tensor * up = build_lora_mm(model.layers[il].ffn_up, cur); + cb(up, "ffn_up", il); + + // xIELU activation + // Get the xIELU parameters from the model layers + ggml_tensor * alpha_n = model.layers[il].ffn_act_alpha_n; + ggml_tensor * alpha_p = model.layers[il].ffn_act_alpha_p; + ggml_tensor * beta = model.layers[il].ffn_act_beta; + ggml_tensor * eps = model.layers[il].ffn_act_eps; + + float alpha_n_val = get_scalar_f32_val(alpha_n); + float alpha_p_val = get_scalar_f32_val(alpha_p); + float beta_val = get_scalar_f32_val(beta); + float eps_val = get_scalar_f32_val(eps); + + // Apply xIELU activation + ggml_tensor * activated = ggml_xielu(ctx0, up, alpha_n_val, alpha_p_val, beta_val, eps_val); + cb(activated, "ffn_xielu", il); + + // Down projection + cur = build_lora_mm(model.layers[il].ffn_down, activated); + cb(cur, "ffn_down", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const { llama_memory_i * res; @@ -19132,6 +19446,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { llm = std::make_unique>(*this, params); } } break; + case LLM_ARCH_APERTUS: + { + llm = std::make_unique(*this, params); + } break; default: GGML_ABORT("fatal error"); } @@ -19288,6 +19606,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ARCEE: case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5_MOE: + case LLM_ARCH_APERTUS: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 diff --git a/src/llama-model.h b/src/llama-model.h index 10b1767f27228..9a263edaf41c0 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -370,6 +370,12 @@ struct llama_layer { // openai-moe struct ggml_tensor * attn_sinks = nullptr; + // xIELU activation parameters for Apertus + struct ggml_tensor * ffn_act_alpha_n = nullptr; + struct ggml_tensor * ffn_act_alpha_p = 
nullptr; + struct ggml_tensor * ffn_act_beta = nullptr; + struct ggml_tensor * ffn_act_eps = nullptr; + struct llama_layer_posnet posnet; struct llama_layer_convnext convnext; From 86cfc18430ccc659af8c5f7133ecd7526590ba38 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Mon, 8 Sep 2025 12:46:11 +0200 Subject: [PATCH 02/32] No permute during convert (fixes qk tensors), proper norm application. --- convert_hf_to_gguf.py | 2 + src/llama-model.cpp | 117 +++--------------------------------------- 2 files changed, 10 insertions(+), 109 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 83936bc5e72c2..66ae220be22c5 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -8569,6 +8569,7 @@ def prepare_tensors(self): @ModelBase.register("ApertusForCausalLM") class ApertusModel(LlamaModel): model_arch = gguf.MODEL_ARCH.APERTUS + undo_permute = False def modify_tensors(self, data_torch, name, bid): # Handle xIELU activation parameters @@ -8577,6 +8578,7 @@ def modify_tensors(self, data_torch, name, bid): return super().modify_tensors(data_torch, name, bid) + class MistralModel(LlamaModel): model_arch = gguf.MODEL_ARCH.LLAMA model_name = "Mistral" diff --git a/src/llama-model.cpp b/src/llama-model.cpp index f501c1392868f..aba93d669f2af 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -18690,78 +18690,6 @@ static float get_scalar_f32_val(const ggml_tensor *t) { return onef; } -static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) { - size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0]; - float v; - if (type == GGML_TYPE_F16) { - v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]); - } else if (type == GGML_TYPE_F32) { - v = *(float *) &data[i]; - } else if (type == GGML_TYPE_I64) { - v = (float) *(int64_t *) &data[i]; - } else if (type == GGML_TYPE_I32) { - v = (float) *(int32_t *) &data[i]; - } else if (type == GGML_TYPE_I16) { - v = (float) *(int16_t *) &data[i]; - } else if (type == GGML_TYPE_I8) { - v = (float) *(int8_t *) &data[i]; - } else { - GGML_ABORT("fatal error"); - } - return v; -} - -static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) { - GGML_ASSERT(n > 0); - float sum = 0; - for (int64_t i3 = 0; i3 < ne[3]; i3++) { - for (int64_t i2 = 0; i2 < ne[2]; i2++) { - for (int64_t i1 = 0; i1 < ne[1]; i1++) { - for (int64_t i0 = 0; i0 < ne[0]; i0++) { - const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3); - sum += v; - } - } - } - } - for (int64_t i3 = 0; i3 < ne[3]; i3++) { - LLAMA_LOG_DEBUG(" [\n"); - for (int64_t i2 = 0; i2 < ne[2]; i2++) { - if (i2 == n && ne[2] > 2*n) { - LLAMA_LOG_DEBUG(" ..., \n"); - i2 = ne[2] - n; - } - LLAMA_LOG_DEBUG(" [\n"); - for (int64_t i1 = 0; i1 < ne[1]; i1++) { - if (i1 == n && ne[1] > 2*n) { - LLAMA_LOG_DEBUG(" ..., \n"); - i1 = ne[1] - n; - } - LLAMA_LOG_DEBUG(" ["); - for (int64_t i0 = 0; i0 < ne[0]; i0++) { - if (i0 == n && ne[0] > 2*n) { - LLAMA_LOG_DEBUG("..., "); - i0 = ne[0] - n; - } - const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3); - LLAMA_LOG_DEBUG("%12.4f", v); - if (i0 < ne[0] - 1) LLAMA_LOG_DEBUG(", "); - } - LLAMA_LOG_DEBUG("],\n"); - } - LLAMA_LOG_DEBUG(" ],\n"); - } - LLAMA_LOG_DEBUG(" ]\n"); - LLAMA_LOG_DEBUG(" sum = %f\n", sum); - } - - // TODO: make this abort configurable/optional? 
- if (std::isnan(sum)) { - LLAMA_LOG_ERROR("encountered NaN - aborting\n"); - exit(0); - } -} - // Apertus model graph builder with xIELU activation struct llm_build_apertus : public llm_graph_context { llm_build_apertus(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { @@ -18785,9 +18713,8 @@ struct llm_build_apertus : public llm_graph_context { for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; - // norm cur = build_norm(inpL, - model.layers[il].attn_norm, NULL, + model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); @@ -18806,42 +18733,14 @@ struct llm_build_apertus : public llm_graph_context { cb(Vcur, "Vcur", il); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - - ggml_tensor * Q2d = ggml_reshape_2d(ctx0, Qcur, n_embd_head, n_head * n_tokens); - ggml_tensor * K2d = ggml_reshape_2d(ctx0, Kcur, n_embd_head, n_head_kv * n_tokens); - - cb(Q2d, "Q2D", il); - cb(K2d, "K2D", il); - - // apply existing rms-norm which was originally written for 2D - Q2d = build_norm(Q2d, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); - K2d = build_norm(K2d, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); - - cb(Q2d, "Q2D_normed", il); - cb(K2d, "K2D_normed", il); - - // reshape back to 3D - Qcur = ggml_reshape_3d(ctx0, Q2d, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, K2d, n_embd_head, n_head_kv, n_tokens); - + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); + + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); cb(Kcur, "Kcur_normed", il); - - // // copy the data from the GPU memory if needed - // const bool is_host = ggml_backend_buffer_is_host(rope_factors->buffer); - - // auto n_bytes = ggml_nbytes(rope_factors); - // uint8_t loaded_data[n_bytes]; - // if (!is_host) { - // ggml_backend_tensor_get(rope_factors, &loaded_data, 0, n_bytes); - // } - // if (!ggml_is_quantized(rope_factors->type)) { - // uint8_t * data = is_host ? 
(uint8_t *) rope_factors->data : &loaded_data[0]; - // ggml_print_tensor(data, rope_factors->type, rope_factors->ne, rope_factors->nb, 64); - // } + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, @@ -18876,7 +18775,7 @@ struct llm_build_apertus : public llm_graph_context { // feed-forward network with xIELU activation { cur = build_norm(ffn_inp, - model.layers[il].ffn_norm, NULL, + model.layers[il].ffn_norm, nullptr, LLM_NORM_RMS, il); cb(cur, "ffn_norm", il); @@ -18918,7 +18817,7 @@ struct llm_build_apertus : public llm_graph_context { cur = inpL; cur = build_norm(cur, - model.output_norm, NULL, + model.output_norm, nullptr, LLM_NORM_RMS, -1); cb(cur, "result_norm", -1); From 8c762a33c4b4a5454835d681f6a8d6554ae63db9 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Mon, 8 Sep 2025 12:53:21 +0200 Subject: [PATCH 03/32] RoPE = NeoX --- src/llama-model.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index aba93d669f2af..40563367295ce 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -19505,7 +19505,6 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ARCEE: case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5_MOE: - case LLM_ARCH_APERTUS: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 @@ -19554,6 +19553,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_SMALLTHINKER: case LLM_ARCH_GLM4_MOE: case LLM_ARCH_SEED_OSS: + case LLM_ARCH_APERTUS: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: From ffdfd1df9fe1aaebc249c67e72c24d0746cba256 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Mon, 8 Sep 2025 12:58:39 +0200 Subject: [PATCH 04/32] Coherence! 
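
Fold the softplus reparameterization of the xIELU parameters into ggml_xielu()
so that op_params always carry the already-transformed values

    alpha_n' = beta + softplus(alpha_n)
    alpha_p' = softplus(alpha_p)

where softplus(x) = log(1 + exp(x)), short-circuited to x above a threshold to
avoid overflow. Backend kernels then only evaluate the piecewise xIELU
expression, and the commented-out transform in the CPU path becomes
unnecessary.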
--- ggml/src/ggml.c | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index ab7c628d7db97..4352e193398a1 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2648,6 +2648,11 @@ struct ggml_tensor * ggml_silu( } // ggml_xielu +static float softplus(float input) { + if (input > 20.0f) return input; + return logf(1 + expf(input)); +} + struct ggml_tensor * ggml_xielu( struct ggml_context * ctx, struct ggml_tensor * a, @@ -2658,7 +2663,7 @@ struct ggml_tensor * ggml_xielu( struct ggml_tensor * result = ggml_dup_tensor(ctx, a); // Store the parameters as operation parameters - float params[] = { alpha_n, alpha_p, beta, eps }; + float params[] = { beta + softplus(alpha_n), softplus(alpha_p), beta, eps }; ggml_set_op_params(result, params, sizeof(params)); result->op = GGML_OP_XIELU; From ab11d9467e8b689291284d063d0c86fd5b41c2a2 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Tue, 9 Sep 2025 19:21:31 +0200 Subject: [PATCH 05/32] Migrate xielu params from tensors to hyperparameters --- convert_hf_to_gguf.py | 28 ++++++++++++++++++++++++++-- gguf-py/gguf/constants.py | 19 +++++++------------ gguf-py/gguf/gguf_writer.py | 12 ++++++++++++ src/llama-arch.cpp | 13 +++++-------- src/llama-arch.h | 9 +++++---- src/llama-hparams.h | 6 ++++++ src/llama-model-loader.cpp | 2 ++ src/llama-model.cpp | 33 ++++++++++++++------------------- 8 files changed, 77 insertions(+), 45 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index fcd8bd6b7b00c..59e893038c6fd 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -8585,10 +8585,34 @@ class ApertusModel(LlamaModel): model_arch = gguf.MODEL_ARCH.APERTUS undo_permute = False + _alpha_n = {} + _alpha_p = {} + _beta = {} + _eps = {} + def modify_tensors(self, data_torch, name, bid): # Handle xIELU activation parameters - if name.endswith(".act_fn.alpha_n") or name.endswith(".act_fn.alpha_p") or name.endswith(".act_fn.beta") or name.endswith(".act_fn.eps"): - return [(self.map_tensor_name(name), data_torch)] + n_layers = self.hparams.get("num_hidden_layers") + if name.endswith(".act_fn.alpha_n"): + self._alpha_n[bid] = data_torch.to("cpu").float().item() + if (len(self._alpha_n) == n_layers): + self.gguf_writer.add_xielu_alpha_n([self._alpha_n[k] for k in sorted(self._alpha_n)]) + return [] + if name.endswith(".act_fn.alpha_p"): + self._alpha_p[bid] = data_torch.to("cpu").float().item() + if (len(self._alpha_p) == n_layers): + self.gguf_writer.add_xielu_alpha_p([self._alpha_p[k] for k in sorted(self._alpha_p)]) + return [] + if name.endswith(".act_fn.beta"): + self._beta[bid] = data_torch.to("cpu").float().item() + if (len(self._beta) == n_layers): + self.gguf_writer.add_xielu_beta([self._beta[k] for k in sorted(self._beta)]) + return [] + if name.endswith(".act_fn.eps"): + self._eps[bid] = data_torch.to("cpu").float().item() + if (len(self._eps) == n_layers): + self.gguf_writer.add_xielu_eps([self._eps[k] for k in sorted(self._eps)]) + return [] return super().modify_tensors(data_torch, name, bid) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index a0fe2f1f09f28..5ee1fd8e04176 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -286,6 +286,13 @@ class Projector: class Diffusion: SHIFT_LOGITS = "diffusion.shift_logits" + class xIELU: + XIELU_ALPHA_P = "xielu.alpha_p" + XIELU_ALPHA_N = "xielu.alpha_n" + XIELU_BETA = "xielu.beta" + XIELU_EPS = "xielu.eps" + + # # recommended mapping of model tensor names for storage in gguf # @@ 
-780,20 +787,12 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_GATE_SHEXP: "blk.{bid}.ffn_gate_shexp", MODEL_TENSOR.FFN_DOWN_SHEXP: "blk.{bid}.ffn_down_shexp", MODEL_TENSOR.FFN_UP_SHEXP: "blk.{bid}.ffn_up_shexp", - MODEL_TENSOR.FFN_ACT_ALPHA_N: "blk.{bid}.ffn_act_alpha_n", - MODEL_TENSOR.FFN_ACT_ALPHA_P: "blk.{bid}.ffn_act_alpha_p", - MODEL_TENSOR.FFN_ACT_BETA: "blk.{bid}.ffn_act_beta", - MODEL_TENSOR.FFN_ACT_EPS: "blk.{bid}.ffn_act_eps", MODEL_TENSOR.FFN_ACT: "blk.{bid}.ffn", MODEL_TENSOR.FFN_NORM_EXP: "blk.{bid}.ffn_norm_exps", MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps", MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps", MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps", MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b", - MODEL_TENSOR.FFN_ACT_ALPHA_N: "blk.{bid}.ffn_act_alpha_n", - MODEL_TENSOR.FFN_ACT_ALPHA_P: "blk.{bid}.ffn_act_alpha_p", - MODEL_TENSOR.FFN_ACT_BETA: "blk.{bid}.ffn_act_beta", - MODEL_TENSOR.FFN_ACT_EPS: "blk.{bid}.ffn_act_eps", MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm", MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: "per_layer_token_embd", # gemma3n MODEL_TENSOR.PER_LAYER_MODEL_PROJ: "per_layer_model_proj", # gemma3n @@ -2715,10 +2714,6 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_GATE, MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, - MODEL_TENSOR.FFN_ACT_ALPHA_N, - MODEL_TENSOR.FFN_ACT_ALPHA_P, - MODEL_TENSOR.FFN_ACT_BETA, - MODEL_TENSOR.FFN_ACT_EPS, ], # TODO } diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index a6cc8a931eb27..a431e94befbb8 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -1051,6 +1051,18 @@ def add_audio_num_mel_bins(self, value: int) -> None: def add_audio_stack_factor(self, value: int) -> None: self.add_uint32(Keys.ClipAudio.Projector.STACK_FACTOR, value) + def add_xielu_alpha_p(self, value: Sequence[float]): + self.add_array(Keys.xIELU.XIELU_ALPHA_P, value) + + def add_xielu_alpha_n(self, value: Sequence[float]): + self.add_array(Keys.xIELU.XIELU_ALPHA_N, value) + + def add_xielu_beta(self, value: Sequence[float]): + self.add_array(Keys.xIELU.XIELU_BETA, value) + + def add_xielu_eps(self, value: Sequence[float]): + self.add_array(Keys.xIELU.XIELU_EPS, value) + # diffusion models def add_diffusion_shift_logits(self, value: bool) -> None: diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index bd4732ab4ff7b..aece7ff4c1107 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -244,6 +244,11 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ADAPTER_LORA_PROMPT_PREFIX, "adapter.lora.prompt_prefix" }, { LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS, "adapter.alora.invocation_tokens" }, + { LLM_KV_XIELU_ALPHA_N, "xielu.alpha_n" }, + { LLM_KV_XIELU_ALPHA_P, "xielu.alpha_p" }, + { LLM_KV_XIELU_BETA, "xielu.beta" }, + { LLM_KV_XIELU_EPS, "xielu.eps" }, + // deprecated { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" }, { LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" }, @@ -2119,10 +2124,6 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_ACT_ALPHA_N, "blk.%d.ffn_act_alpha_n" }, - { LLM_TENSOR_FFN_ACT_ALPHA_P, "blk.%d.ffn_act_alpha_p" }, - { LLM_TENSOR_FFN_ACT_BETA, "blk.%d.ffn_act_beta" }, - { LLM_TENSOR_FFN_ACT_EPS, "blk.%d.ffn_act_eps" }, }, }, { @@ -2308,10 +2309,6 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_FFN_DOWN_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, 
{LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, - {LLM_TENSOR_FFN_ACT_ALPHA_N, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_FFN_ACT_ALPHA_P, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_FFN_ACT_BETA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_FFN_ACT_EPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, {LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, // altup / laurel (gemma 3n) {LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, diff --git a/src/llama-arch.h b/src/llama-arch.h index 03e80a48b83e0..b998dcf466754 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -248,6 +248,11 @@ enum llm_kv { LLM_KV_SHORTCONV_L_CACHE, + LLM_KV_XIELU_ALPHA_N, + LLM_KV_XIELU_ALPHA_P, + LLM_KV_XIELU_BETA, + LLM_KV_XIELU_EPS, + // deprecated: LLM_KV_TOKENIZER_PREFIX_ID, LLM_KV_TOKENIZER_SUFFIX_ID, @@ -294,10 +299,6 @@ enum llm_tensor { LLM_TENSOR_FFN_GATE_SHEXP, LLM_TENSOR_FFN_UP_SHEXP, LLM_TENSOR_FFN_EXP_PROBS_B, - LLM_TENSOR_FFN_ACT_ALPHA_N, - LLM_TENSOR_FFN_ACT_ALPHA_P, - LLM_TENSOR_FFN_ACT_BETA, - LLM_TENSOR_FFN_ACT_EPS, LLM_TENSOR_ATTN_Q_NORM, LLM_TENSOR_ATTN_K_NORM, LLM_TENSOR_LAYER_OUT_NORM, diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 89f5c7ab65dce..bb094ae68dd11 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -156,6 +156,12 @@ struct llama_hparams { uint32_t laurel_rank = 64; uint32_t n_embd_altup = 256; + // xIELU + std::array xielu_alpha_n; + std::array xielu_alpha_p; + std::array xielu_beta; + std::array xielu_eps; + // needed by encoder-decoder models (e.g. T5, FLAN-T5) // ref: https://github.com/ggerganov/llama.cpp/pull/8141 llama_token dec_start_token_id = LLAMA_TOKEN_NULL; diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index 8182a9adf53a6..aa3a65f87a542 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -465,6 +465,8 @@ namespace GGUFMeta { // TODO: this is not very clever - figure out something better template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); + template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); + llama_model_loader::llama_model_loader( const std::string & fname, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 314ade6b3be83..46b844250ef91 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -508,9 +508,13 @@ void llama_model::load_hparams(llama_model_loader & ml) { llm_arch_is_recurrent(ml.get_arch())); std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0); - std::fill(hparams.swa_layers.begin(), hparams.swa_layers.end(), 0); + std::fill(hparams.xielu_alpha_n.begin(), hparams.xielu_alpha_n.end(), 0); + std::fill(hparams.xielu_alpha_p.begin(), hparams.xielu_alpha_p.end(), 0); + std::fill(hparams.xielu_beta.begin(), hparams.xielu_beta.end(), 0); + std::fill(hparams.xielu_eps.begin(), hparams.xielu_eps.end(), 0); + ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false); ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false); @@ -1948,7 +1952,11 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_APERTUS: { 
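+                // Apertus: RMS-norm eps plus per-layer xIELU constants (now stored as GGUF metadata, xielu.*)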
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - hparams.n_ctx_orig_yarn = 8192; + ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_N, hparams.xielu_alpha_n, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_P, hparams.xielu_alpha_p, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_XIELU_BETA, hparams.xielu_beta, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_XIELU_EPS, hparams.xielu_eps, hparams.n_layer); + switch (hparams.n_layer) { case 32: type = LLM_TYPE_8B; break; default: type = LLM_TYPE_UNKNOWN; @@ -5769,12 +5777,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); - - // xIELU parameters for Apertus - layer.ffn_act_alpha_n = create_tensor(tn(LLM_TENSOR_FFN_ACT_ALPHA_N, i), { 1 }, 0); - layer.ffn_act_alpha_p = create_tensor(tn(LLM_TENSOR_FFN_ACT_ALPHA_P, i), { 1 }, 0); - layer.ffn_act_beta = create_tensor(tn(LLM_TENSOR_FFN_ACT_BETA, i), { 1 }, 0); - layer.ffn_act_eps = create_tensor(tn(LLM_TENSOR_FFN_ACT_EPS, i), { 1 }, 0); } } break; default: @@ -18727,17 +18729,10 @@ struct llm_build_apertus : public llm_graph_context { ggml_tensor * up = build_lora_mm(model.layers[il].ffn_up, cur); cb(up, "ffn_up", il); - // xIELU activation - // Get the xIELU parameters from the model layers - ggml_tensor * alpha_n = model.layers[il].ffn_act_alpha_n; - ggml_tensor * alpha_p = model.layers[il].ffn_act_alpha_p; - ggml_tensor * beta = model.layers[il].ffn_act_beta; - ggml_tensor * eps = model.layers[il].ffn_act_eps; - - float alpha_n_val = get_scalar_f32_val(alpha_n); - float alpha_p_val = get_scalar_f32_val(alpha_p); - float beta_val = get_scalar_f32_val(beta); - float eps_val = get_scalar_f32_val(eps); + float alpha_n_val = hparams.xielu_alpha_n[il]; + float alpha_p_val = hparams.xielu_alpha_p[il]; + float beta_val = hparams.xielu_beta[il]; + float eps_val = hparams.xielu_eps[il]; // Apply xIELU activation ggml_tensor * activated = ggml_xielu(ctx0, up, alpha_n_val, alpha_p_val, beta_val, eps_val); From 1b184721502cfdd6e90ea537b0a7e023116553ff Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Sat, 13 Sep 2025 21:17:30 +0200 Subject: [PATCH 06/32] Simple CUDA kernel --- ggml/src/ggml-cpu/ggml-cpu.c | 2 +- ggml/src/ggml-cpu/unary-ops.cpp | 3 - ggml/src/ggml-cuda/ggml-cuda.cu | 3 + ggml/src/ggml-cuda/unary.cu | 282 ++++++++++-------- ggml/src/ggml-cuda/unary.cuh | 2 + models/templates/Apertus-8B-Instruct.jinja | 327 +++++++++++++++++++++ 6 files changed, 486 insertions(+), 133 deletions(-) create mode 100644 models/templates/Apertus-8B-Instruct.jinja diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 6f3291fc6f99e..f3922f988c97d 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2144,6 +2144,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_ADD_ID: case GGML_OP_ADD1: case GGML_OP_ACC: + case GGML_OP_XIELU: { n_tasks = n_threads; } break; @@ -2167,7 +2168,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_REPEAT: case GGML_OP_REPEAT_BACK: case GGML_OP_LEAKY_RELU: - case GGML_OP_XIELU: { n_tasks = 1; } break; diff --git a/ggml/src/ggml-cpu/unary-ops.cpp b/ggml/src/ggml-cpu/unary-ops.cpp index 
1d3f0cef9400e..80ac5f7f5588b 100644 --- a/ggml/src/ggml-cpu/unary-ops.cpp +++ b/ggml/src/ggml-cpu/unary-ops.cpp @@ -207,9 +207,6 @@ void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor const float beta = op_params[2]; const float eps = op_params[3]; -// alpha_p = softplus(alpha_p); -// alpha_n = beta + softplus(alpha_n); - const auto xielu_op_params = [alpha_n, alpha_p, beta, eps](float f) { return op_xielu(f, alpha_n, alpha_p, beta, eps); }; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 0f68d685348d2..c09584bb042ea 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2517,6 +2517,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_OPT_STEP_SGD: ggml_cuda_opt_step_sgd(ctx, dst); break; + case GGML_OP_XIELU: + ggml_cuda_op_xielu(ctx, dst); + break; default: return false; } diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 5aff8a876af2c..ca34524b5c3f8 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -1,94 +1,8 @@ #include "unary.cuh" -static __device__ __forceinline__ float op_abs(float x) { - return fabsf(x); -} - -static __device__ __forceinline__ float op_sgn(float x) { - return (x > 0.f ? 1.f : ((x < 0.f ? -1.f : 0.f))); -} - -static __device__ __forceinline__ float op_neg(float x) { - return -x; -} - -static __device__ __forceinline__ float op_step(float x) { - return x > 0.0f; -} - -static __device__ __forceinline__ float op_gelu(float x) { - const float GELU_COEF_A = 0.044715f; - const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; - - return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); -} - -static __device__ __forceinline__ float op_gelu_erf(float x) { - const float SQRT_2_INV = 0.70710678118654752440084436210484f; - - return 0.5f*x*(1.0f + erff(x*SQRT_2_INV)); -} - -static __device__ __forceinline__ float op_gelu_quick(float x) { - const float GELU_QUICK_COEF = -1.702f; - - return x * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x))); -} - -static __device__ __forceinline__ float op_silu(float x) { - return x / (1.0f + expf(-x)); -} - -static __device__ __forceinline__ float op_tanh(float x) { - return tanhf(x); -} - -static __device__ __forceinline__ float op_relu(float x) { - return fmaxf(x, 0); -} - -static __device__ __forceinline__ float op_sigmoid(float x) { - return 1.0f / (1.0f + expf(-x)); -} - -static __device__ __forceinline__ float op_hardsigmoid(float x) { - return fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f)); -} - -static __device__ __forceinline__ float op_hardswish(float x) { - return x * fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f)); -} - -static __device__ __forceinline__ float op_exp(float x) { - return expf(x); -} - -static __device__ __forceinline__ float op_sqr(float x) { - return x * x; -} - -static __device__ __forceinline__ float op_sqrt(float x) { - return sqrtf(x); -} - -static __device__ __forceinline__ float op_sin(float x) { - return sinf(x); -} - -static __device__ __forceinline__ float op_cos(float x) { - return cosf(x); -} - -static __device__ __forceinline__ float op_log(float x) { - return logf(x); -} - -static __device__ __forceinline__ float op_elu(float x) { - return (x > 0.f) ? 
x : expm1f(x); -} -template -static __global__ void unary_op_kernel(const T * x, T * dst, const int k) { +template +static __global__ void unary_op_kernel(const T * x, T * dst, const int k, const Op op) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -98,14 +12,14 @@ static __global__ void unary_op_kernel(const T * x, T * dst, const int k) { dst[i] = (T)op((float)x[i]); } -template -static void unary_cuda(const T * x, T * dst, const int k, cudaStream_t stream) { +template +static void unary_cuda(const T * x, T * dst, const int k, const Op op, cudaStream_t stream) { const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE; - unary_op_kernel<<>>(x, dst, k); + unary_op_kernel<<>>(x, dst, k, op); } -template -void ggml_cuda_op_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { +template +void ggml_cuda_op_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst, const Op op) { const ggml_tensor * src0 = dst->src[0]; const void * src0_d = src0->data; void * dst_d = dst->data; @@ -118,95 +32,159 @@ void ggml_cuda_op_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT(src0->type == dst->type); if (src0->type == GGML_TYPE_F16) { - unary_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream); + unary_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), op, stream); } else { - unary_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream); + unary_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), op, stream); } } void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst); + auto op = [] __device__ (float x) -> float { + return fabsf(x); + }; + ggml_cuda_op_unary(ctx, dst, op); } void ggml_cuda_op_sgn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst); + auto op = [] __device__ (float x) -> float { + return (x > 0.f ? 1.f : ((x < 0.f ? 
-1.f : 0.f))); + }; + ggml_cuda_op_unary(ctx, dst, op); } void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst); + auto op = [] __device__ (float x) -> float { + return -x; + }; + ggml_cuda_op_unary(ctx, dst, op); } void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst); + auto op = [] __device__ (float x) -> float { + return x > 0.0f; + }; + ggml_cuda_op_unary(ctx, dst, op); } void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst); + auto op = [] __device__ (float x) -> float { + const float GELU_COEF_A = 0.044715f; + const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); + }; + ggml_cuda_op_unary(ctx, dst, op); } void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst); + auto op = [] __device__ (float x) -> float { + const float SQRT_2_INV = 0.70710678118654752440084436210484f; + return 0.5f*x*(1.0f + erff(x*SQRT_2_INV)); + }; + ggml_cuda_op_unary(ctx, dst, op); } void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst); + auto op = [] __device__ (float x) -> float { + const float GELU_QUICK_COEF = -1.702f; + return x * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x))); + }; + ggml_cuda_op_unary(ctx, dst, op); } void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst); + auto op = [] __device__ (float x) -> float { + return x / (1.0f + expf(-x)); + }; + ggml_cuda_op_unary(ctx, dst, op); } void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst); + auto op = [] __device__ (float x) -> float { + return tanhf(x); + }; + ggml_cuda_op_unary(ctx, dst, op); } void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst); + auto op = [] __device__ (float x) -> float { + return fmaxf(x, 0); + }; + ggml_cuda_op_unary(ctx, dst, op); } void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst); + auto op = [] __device__ (float x) -> float { + return 1.0f / (1.0f + expf(-x)); + }; + ggml_cuda_op_unary(ctx, dst, op); } void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst); + auto op = [] __device__ (float x) -> float { + return fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f)); + }; + ggml_cuda_op_unary(ctx, dst, op); } void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst); + auto op = [] __device__ (float x) -> float { + return x * fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f)); + }; + ggml_cuda_op_unary(ctx, dst, op); } void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst); + auto op = [] __device__ (float x) -> float { + return expf(x); + }; + ggml_cuda_op_unary(ctx, dst, op); } void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst); + auto op = [] __device__ (float x) -> float { + return x * x; + }; + ggml_cuda_op_unary(ctx, dst, op); } void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst); + auto op = [] __device__ (float x) -> float { + return sqrtf(x); + }; + ggml_cuda_op_unary(ctx, dst, 
op); } void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst); + auto op = [] __device__ (float x) -> float { + return sinf(x); + }; + ggml_cuda_op_unary(ctx, dst, op); } void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst); + auto op = [] __device__ (float x) -> float { + return cosf(x); + }; + ggml_cuda_op_unary(ctx, dst, op); } void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst); + auto op = [] __device__ (float x) -> float { + return logf(x); + }; + ggml_cuda_op_unary(ctx, dst, op); } void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst); + auto op = [] __device__ (float x) -> float { + return (x > 0.f) ? x : expm1f(x); + }; + ggml_cuda_op_unary(ctx, dst, op); } /* gated ops */ -template -static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1) { +template +static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, const Op op) { const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x; if (i >= k) { @@ -220,14 +198,14 @@ static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, dst[i] = (T)(op((float)x[j0]) * (float)g[j1]); } -template -static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, cudaStream_t stream) { +template +static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, const Op op, cudaStream_t stream) { const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE; - unary_gated_op_kernel<<>>(x, g, dst, k, n, o0, o1); + unary_gated_op_kernel<<>>(x, g, dst, k, n, o0, o1, op); } -template -void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { +template +void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst, const Op op) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; void * src0_d = src0->data; @@ -266,7 +244,7 @@ void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst src1_p += swapped ? 0 : nc; } - unary_gated_cuda(src0_p, src1_p, (half *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(half), src1_o / sizeof(half), stream); + unary_gated_cuda(src0_p, src1_p, (half *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(half), src1_o / sizeof(half), op, stream); } else { float * src0_p = (float *) src0_d; float * src1_p = (float *) src1_d; @@ -276,28 +254,74 @@ void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst src1_p += swapped ? 
0 : nc; } - unary_gated_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), stream); + unary_gated_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), op, stream); } } void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary_gated(ctx, dst); + auto op = [] __device__ (float x) -> float { + return fmaxf(x, 0); + }; + ggml_cuda_op_unary_gated(ctx, dst, op); } void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary_gated(ctx, dst); + auto op = [] __device__ (float x) -> float { + const float GELU_COEF_A = 0.044715f; + const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); + }; + ggml_cuda_op_unary_gated(ctx, dst, op); } void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary_gated(ctx, dst); + auto op = [] __device__ (float x) -> float { + return x / (1.0f + expf(-x)); + }; + ggml_cuda_op_unary_gated(ctx, dst, op); } void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary_gated(ctx, dst); + auto op = [] __device__ (float x) -> float { + const float SQRT_2_INV = 0.70710678118654752440084436210484f; + return 0.5f*x*(1.0f + erff(x*SQRT_2_INV)); + }; + ggml_cuda_op_unary_gated(ctx, dst, op); } void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary_gated(ctx, dst); + auto op = [] __device__ (float x) -> float { + const float GELU_QUICK_COEF = -1.702f; + return x * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x))); + }; + ggml_cuda_op_unary_gated(ctx, dst, op); +} + +void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + // Get the XIELU parameters from the operation + const float * op_params = (const float*)dst->op_params; + float alpha_n = op_params[0]; + float alpha_p = op_params[1]; + const float beta = op_params[2]; + const float eps = op_params[3]; + + const auto op = [alpha_n, alpha_p, beta, eps] __device__ (float x) -> float { + float out; + float gate_pos = (x > 0.0f); // positive branch gate + float gate_neg = 1.0f - gate_pos; // negative branch gate + + // Positive branch: alpha_p * v^2 + beta * v + float y_pos = alpha_p * x * x + beta * x; + + // Negative branch: + float min_v_eps = fminf(x, eps); // works fine even if eps < 0 + float y_neg = (expm1f(min_v_eps) - x) * alpha_n + beta * x; + out = y_pos * gate_pos + y_neg * gate_neg; + + return out; + }; + + ggml_cuda_op_unary_gated(ctx, dst, op); } // swiglu_oai diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index da3caf1d8962e..9213c23d76bf2 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -72,3 +72,5 @@ void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/models/templates/Apertus-8B-Instruct.jinja b/models/templates/Apertus-8B-Instruct.jinja new file mode 100644 index 0000000000000..a71677eee94ea --- /dev/null +++ b/models/templates/Apertus-8B-Instruct.jinja @@ -0,0 +1,327 @@ +{%- macro render_typescript_type(param_spec, required_params, is_nullable=false) -%} + 
{%- if param_spec.type == "array" -%} + {%- if param_spec['items'] -%} + {%- if param_spec['items']['type'] == "string" -%} + {{- "string[]" }} + {%- elif param_spec['items']['type'] == "number" -%} + {{- "number[]" }} + {%- elif param_spec['items']['type'] == "integer" -%} + {{- "number[]" }} + {%- elif param_spec['items']['type'] == "boolean" -%} + {{- "boolean[]" }} + {%- else -%} + {%- set inner_type = render_typescript_type(param_spec['items'], required_params) -%} + {%- if inner_type == "object | object" or inner_type|length > 50 -%} + {{- "any[]" }} + {%- else -%} + {{- inner_type + "[]" }} + {%- endif -%} + {%- endif -%} + {%- if param_spec.nullable -%} + {{- " | null" }} + {%- endif -%} + {%- else -%} + {{- "any[]" }} + {%- if param_spec.nullable -%} + {{- " | null" }} + {%- endif -%} + {%- endif -%} + {%- elif param_spec.type is defined and param_spec.type is iterable and param_spec.type is not string and param_spec.type is not mapping and param_spec.type[0] is defined -%} + {#- Handle array of types like ["object", "object"] from Union[dict, list] #} + {%- if param_spec.type | length > 1 -%} + {{- param_spec.type | join(" | ") }} + {%- else -%} + {{- param_spec.type[0] }} + {%- endif -%} + {%- elif param_spec.oneOf -%} + {#- Handle oneOf schemas - check for complex unions and fallback to any #} + {%- set has_object_variants = false -%} + {%- for variant in param_spec.oneOf -%} + {%- if variant.type == "object" -%} + {%- set has_object_variants = true -%} + {%- endif -%} + {%- endfor -%} + {%- if has_object_variants and param_spec.oneOf|length > 1 -%} + {{- "any" }} + {%- else -%} + {%- for variant in param_spec.oneOf -%} + {{- render_typescript_type(variant, required_params) -}} + {%- if variant.description %} + {{- "// " + variant.description }} + {%- endif -%} + {%- if variant.default is defined %} + {{ "// default: " + variant.default|tojson }} + {%- endif -%} + {%- if not loop.last %} + {{- " | " }} + {% endif -%} + {%- endfor -%} + {%- endif -%} + {%- elif param_spec.type == "string" -%} + {%- if param_spec.enum -%} + {{- '"' + param_spec.enum|join('" | "') + '"' -}} + {%- else -%} + {{- "string" }} + {%- if param_spec.nullable %} + {{- " | null" }} + {%- endif -%} + {%- endif -%} + {%- elif param_spec.type == "number" -%} + {{- "number" }} + {%- elif param_spec.type == "integer" -%} + {{- "number" }} + {%- elif param_spec.type == "boolean" -%} + {{- "boolean" }} + {%- elif param_spec.type == "object" -%} + {%- if param_spec.properties -%} + {{- "{\n" }} + {%- for prop_name, prop_spec in param_spec.properties.items() -%} + {{- prop_name -}} + {%- if prop_name not in (param_spec.required or []) -%} + {{- "?" }} + {%- endif -%} + {{- ": " }} + {{ render_typescript_type(prop_spec, param_spec.required or []) }} + {%- if not loop.last -%} + {{-", " }} + {%- endif -%} + {%- endfor -%} + {{- "}" }} + {%- else -%} + {{- "object" }} + {%- endif -%} + {%- else -%} + {{- "any" }} + {%- endif -%} +{%- endmacro -%} + +{%- macro render_tools(tools) -%} + {%- for tool in tools %} + {{- "// " + tool.description + "\n" }} + {{- "type "+ tool.name + " = " }} + {%- if tool.parameters and tool.parameters.properties %} + {{- "(_: {\n" }} + {%- for param_name, param_spec in tool.parameters.properties.items() %} + {%- if param_spec.description %} + {{- "// " + param_spec.description + "\n" }} + {%- endif %} + {{- param_name }} + {%- if param_name not in (tool.parameters.required or []) -%} + {{- "?" 
}} + {%- endif -%} + {{- ": " }} + {{- render_typescript_type(param_spec, tool.parameters.required or []) }} + {%- if param_spec.default is defined -%} + {%- if param_spec.enum %} + {{- ", // default: " + param_spec.default }} + {%- elif param_spec.oneOf %} + {{- "// default: " + param_spec.default }} + {%- else %} + {{- ", // default: " + param_spec.default|tojson }} + {%- endif -%} + {%- endif -%} + {%- if not loop.last %} + {{- ",\n" }} + {%- else %} + {{- "\n" }} + {%- endif -%} + {%- endfor %} + {{- "}) => any;" }} + {%- else -%} + {{- "() => any;" }} + {%- endif -%} + {%- if not loop.last -%} + {{- "\n" }} + {%- endif -%} + {%- endfor %} +{%- endmacro -%} + +{{ bos_token }} + +{%- set system_token = '<|system_start|>' -%} +{%- set end_system_token = '<|system_end|>' -%} +{%- set developer_token = '<|developer_start|>' -%} +{%- set end_developer_token = '<|developer_end|>' -%} +{%- set user_token = '<|user_start|>' -%} +{%- set end_user_token = '<|user_end|>' -%} +{%- set assistant_token = '<|assistant_start|>' -%} +{%- set end_assistant_token = '<|assistant_end|>' -%} +{%- set inner_token = '<|inner_prefix|>' -%} +{%- set outer_token = '<|inner_suffix|>' -%} +{%- set tool_calls_token = '<|tools_prefix|>' -%} +{%- set end_tool_calls_token = '<|tools_suffix|>' -%} + +{%- set ns = namespace(in_assistant=false, in_tool=false, in_inner=false, assistant_format=none) -%} + +{%- if messages and messages[0].role == 'system' -%} + {%- if "content" in messages[0] -%} + {%- if messages[0].content is string -%} + {{ system_token + messages[0].content + end_system_token }} + {%- elif messages[0].content is mapping and "text" in messages[0].content -%} + {{ system_token + messages[0].content.text + end_system_token }} + {%- else -%} + {{- raise_exception("Invalid system message") -}} + {%- endif -%} + {%- else -%} + {{- raise_exception("Invalid system message") -}} + {%- endif -%} + {%- set loop_messages = messages[1:] -%} +{%- else -%} + {{ system_token + 'You are Apertus, a helpful assistant created by the SwissAI initiative.\nKnowledge cutoff: 2024-04\nCurrent date: ' + strftime_now('%Y-%m-%d') + end_system_token }} + {%- set loop_messages = messages -%} +{%- endif -%} + +{{ developer_token + 'Deliberation: ' }} +{%- if enable_thinking is defined and enable_thinking -%} + {{ 'enabled\n' }} +{%- else -%} + {{ 'disabled\n' }} +{%- endif -%} +{%- if tools is defined and tools -%} + {{ 'Tool Capabilities:\n' + render_tools(tools) }} +{%- else -%} + {{ 'Tool Capabilities: disabled' }} +{%- endif -%} +{{ end_developer_token }} + +{%- for message in loop_messages -%} + {%- if message.role == 'user' -%} + {%- set ns.in_inner = false -%} + {%- if ns.in_tool -%} + {{ ']' }} + {%- set ns.in_tool = false -%} + {%- endif -%} + {%- if ns.in_assistant -%} + {{ end_assistant_token }} + {%- set ns.in_assistant = false -%} + {%- endif -%} + {%- if "content" in message -%} + {{ user_token }} + {%- if message.content is string -%} + {{ message.content }} + {%- elif message.content is mapping and "parts" in message.content -%} + {%- set parts = message.content.parts -%} + {%- for part in parts -%} + {%- if part.type == "text" -%} + {{ part.text }} + {%- else -%} + {{- raise_exception("Invalid user part: " + part.type) -}} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{- raise_exception("Invalid user message: " + message.role) -}} + {%- endif -%} + {{ end_user_token }} + {%- endif -%} + {%- elif message.role == 'assistant' -%} + {%- if not ns.in_assistant -%} + {{ assistant_token }} + {%- set ns.in_assistant = 
true -%} + {%- endif -%} + {%- if "content" in message -%} + {%- if message.content is string and (ns.assistant_format is none or ns.assistant_format == "string") -%} + {%- if ns.in_tool -%} + {{ ']' }} + {%- set ns.in_tool = false -%} + {%- endif -%} + {%- set ns.assistant_format = "string" -%} + {{ message.content }} + {%- elif message.content is mapping and "blocks" in message.content and (ns.assistant_format is none or ns.assistant_format == "mapping") -%} + {%- set ns.assistant_format = "mapping" -%} + {%- set blocks = message.content.blocks -%} + {%- for block in blocks -%} + {%- if block.type == 'thoughts' -%} + {%- if ns.in_tool -%} + {{ ']' }} + {%- set ns.in_tool = false -%} + {%- endif -%} + {%- if not ns.in_inner -%} + {%- set ns.in_inner = true -%} + {{ inner_token }} + {%- endif -%} + {{ block.text }} + {%- elif block.type == 'tool_calls' -%} + {%- if ns.in_tool -%} + {{ ']' }} + {%- set ns.in_tool = false -%} + {%- endif -%} + {%- if ns.in_inner and not loop.first and block.calls|length == 1 and block.calls[0].name == 'display_answers' -%} + {%- set ns.in_inner = false -%} + {{ outer_token }} + {%- endif -%} + {{ tool_calls_token + '[' }} + {%- for tool_call in block.calls -%} + {{- '{"' + tool_call.name + '": ' + tool_call.arguments + '}' }} + {%- if not loop.last -%} + {{- ", " }} + {%- endif -%} + {%- endfor -%} + {{ ']' + end_tool_calls_token }} + {%- elif block.type == 'tool_outputs' -%} + {%- if ns.in_tool -%} + {{- raise_exception("Cannot have both tool outputs as separate messages and tool outputs as blocks") -}} + {%- endif -%} + {{ '[' }} + {%- for tool_output in block.outputs -%} + {{- tool_output.output }} + {%- if not loop.last -%} + {{- ", " }} + {%- endif -%} + {%- endfor -%} + {{- ']' }} + {%- elif block.type == 'response' -%} + {%- if ns.in_tool -%} + {{ ']' }} + {%- set ns.in_tool = false -%} + {%- endif -%} + {%- if (not loop.first and ns.in_inner) or (ns.in_assistant and ns.in_inner) -%} + {%- set ns.in_inner = false -%} + {{ outer_token }} + {%- endif -%} + {{ block.text }} + {%- else -%} + {{- raise_exception("Invalid assistant block type: " + block.type) -}} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{- raise_exception("Invalid assistant content") -}} + {%- endif -%} + {%- else -%} + {{- raise_exception("Invalid assistant message") -}} + {%- endif -%} + {%- if "tool_calls" in message and message.tool_calls -%} + {{ tool_calls_token + '[' }} + {%- for tool_call in message.tool_calls -%} + {%- if tool_call.type == 'function' -%} + {%- set function = tool_call.function -%} + {{- '{"' + function.name + '": ' + function.arguments + '}' }} + {%- if not loop.last -%} + {{- ", " }} + {%- endif -%} + {%- else -%} + {{- raise_exception("Invalid tool call type: " + tool_call.type) -}} + {%- endif -%} + {%- endfor -%} + {{ ']' + end_tool_calls_token }} + {%- endif -%} + {%- elif message.role == 'tool' -%} + {%- if not ns.in_assistant -%} + {{- raise_exception("Tool message outside of assistant") -}} + {%- endif -%} + {%- if not ns.in_tool -%} + {{ '[' }} + {%- set ns.in_tool = true -%} + {%- else -%} + {{ ", "}} + {%- endif -%} + {{ message.content }} + {%- else -%} + {{- raise_exception("Invalid message role") -}} + {%- endif -%} +{%- endfor -%} +{%- if ns.in_tool -%} + {{ ']' }} +{%- endif -%} +{%- if add_generation_prompt -%} + {{ assistant_token }} +{%- endif -%} \ No newline at end of file From 1606a3c7600689e1b6523a6f6cd69069f49b8f06 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Sun, 14 Sep 2025 16:49:22 +0200 Subject: [PATCH 07/32] Revert stupid 
LLM refactorings --- ggml/src/ggml-cuda/unary.cu | 278 ++++++++++++++++++++---------------- 1 file changed, 155 insertions(+), 123 deletions(-) diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index ca34524b5c3f8..52f39fffdc0a6 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -38,148 +38,168 @@ void ggml_cuda_op_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst, cons } } +static __device__ __forceinline__ float op_abs(float x) { + return fabsf(x); +} + void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - auto op = [] __device__ (float x) -> float { - return fabsf(x); - }; - ggml_cuda_op_unary(ctx, dst, op); + ggml_cuda_op_unary(ctx, dst, op_abs); +} + +static __device__ __forceinline__ float op_sgn(float x) { + return (x > 0.f ? 1.f : ((x < 0.f ? -1.f : 0.f))); } void ggml_cuda_op_sgn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - auto op = [] __device__ (float x) -> float { - return (x > 0.f ? 1.f : ((x < 0.f ? -1.f : 0.f))); - }; - ggml_cuda_op_unary(ctx, dst, op); + ggml_cuda_op_unary(ctx, dst, op_sgn); +} + +static __device__ __forceinline__ float op_neg(float x) { + return -x; } void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - auto op = [] __device__ (float x) -> float { - return -x; - }; - ggml_cuda_op_unary(ctx, dst, op); + ggml_cuda_op_unary(ctx, dst, op_neg); +} + +static __device__ __forceinline__ float op_step(float x) { + return x > 0.0f; } void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - auto op = [] __device__ (float x) -> float { - return x > 0.0f; - }; - ggml_cuda_op_unary(ctx, dst, op); + ggml_cuda_op_unary(ctx, dst, op_step); +} + +static __device__ __forceinline__ float op_gelu(float x) { + const float GELU_COEF_A = 0.044715f; + const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); } void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - auto op = [] __device__ (float x) -> float { - const float GELU_COEF_A = 0.044715f; - const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; - return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); - }; - ggml_cuda_op_unary(ctx, dst, op); + ggml_cuda_op_unary(ctx, dst, op_gelu); +} + +static __device__ __forceinline__ float op_gelu_erf(float x) { + const float SQRT_2_INV = 0.70710678118654752440084436210484f; + return 0.5f*x*(1.0f + erff(x*SQRT_2_INV)); } void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - auto op = [] __device__ (float x) -> float { - const float SQRT_2_INV = 0.70710678118654752440084436210484f; - return 0.5f*x*(1.0f + erff(x*SQRT_2_INV)); - }; - ggml_cuda_op_unary(ctx, dst, op); + ggml_cuda_op_unary(ctx, dst, op_gelu_erf); +} + +static __device__ __forceinline__ float op_gelu_quick(float x) { + const float GELU_QUICK_COEF = -1.702f; + return x * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x))); } void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - auto op = [] __device__ (float x) -> float { - const float GELU_QUICK_COEF = -1.702f; - return x * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x))); - }; - ggml_cuda_op_unary(ctx, dst, op); + ggml_cuda_op_unary(ctx, dst, op_gelu_quick); +} + +static __device__ __forceinline__ float op_silu(float x) { + return x / (1.0f + expf(-x)); } void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - auto op = [] __device__ 
(float x) -> float { - return x / (1.0f + expf(-x)); - }; - ggml_cuda_op_unary(ctx, dst, op); + ggml_cuda_op_unary(ctx, dst, op_silu); +} + +static __device__ __forceinline__ float op_tanh(float x) { + return tanhf(x); } void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - auto op = [] __device__ (float x) -> float { - return tanhf(x); - }; - ggml_cuda_op_unary(ctx, dst, op); + ggml_cuda_op_unary(ctx, dst, op_tanh); +} + +static __device__ __forceinline__ float op_relu(float x) { + return fmaxf(x, 0); } void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - auto op = [] __device__ (float x) -> float { - return fmaxf(x, 0); - }; - ggml_cuda_op_unary(ctx, dst, op); + ggml_cuda_op_unary(ctx, dst, op_relu); +} + +static __device__ __forceinline__ float op_sigmoid(float x) { + return 1.0f / (1.0f + expf(-x)); } void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - auto op = [] __device__ (float x) -> float { - return 1.0f / (1.0f + expf(-x)); - }; - ggml_cuda_op_unary(ctx, dst, op); + ggml_cuda_op_unary(ctx, dst, op_sigmoid); +} + +static __device__ __forceinline__ float op_hardsigmoid(float x) { + return fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f)); } void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - auto op = [] __device__ (float x) -> float { - return fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f)); - }; - ggml_cuda_op_unary(ctx, dst, op); + ggml_cuda_op_unary(ctx, dst, op_hardsigmoid); +} + +static __device__ __forceinline__ float op_hardswish(float x) { + return x * fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f)); } void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - auto op = [] __device__ (float x) -> float { - return x * fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f)); - }; - ggml_cuda_op_unary(ctx, dst, op); + ggml_cuda_op_unary(ctx, dst, op_hardswish); +} + +static __device__ __forceinline__ float op_exp(float x) { + return expf(x); } void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - auto op = [] __device__ (float x) -> float { - return expf(x); - }; - ggml_cuda_op_unary(ctx, dst, op); + ggml_cuda_op_unary(ctx, dst, op_exp); +} + +static __device__ __forceinline__ float op_sqr(float x) { + return x * x; } void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - auto op = [] __device__ (float x) -> float { - return x * x; - }; - ggml_cuda_op_unary(ctx, dst, op); + ggml_cuda_op_unary(ctx, dst, op_sqr); +} + +static __device__ __forceinline__ float op_sqrt(float x) { + return sqrtf(x); } void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - auto op = [] __device__ (float x) -> float { - return sqrtf(x); - }; - ggml_cuda_op_unary(ctx, dst, op); + ggml_cuda_op_unary(ctx, dst, op_sqrt); +} + +static __device__ __forceinline__ float op_sin(float x) { + return sinf(x); } void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - auto op = [] __device__ (float x) -> float { - return sinf(x); - }; - ggml_cuda_op_unary(ctx, dst, op); + ggml_cuda_op_unary(ctx, dst, op_sin); +} + +static __device__ __forceinline__ float op_cos(float x) { + return cosf(x); } void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - auto op = [] __device__ (float x) -> float { - return cosf(x); - }; - ggml_cuda_op_unary(ctx, dst, op); + ggml_cuda_op_unary(ctx, dst, op_cos); +} + +static __device__ __forceinline__ float op_log(float x) { + return logf(x); } void 
ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - auto op = [] __device__ (float x) -> float { - return logf(x); - }; - ggml_cuda_op_unary(ctx, dst, op); + ggml_cuda_op_unary(ctx, dst, op_log); +} + +static __device__ __forceinline__ float op_elu(float x) { + return (x > 0.f) ? x : expm1f(x); } void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - auto op = [] __device__ (float x) -> float { - return (x > 0.f) ? x : expm1f(x); - }; - ggml_cuda_op_unary(ctx, dst, op); + ggml_cuda_op_unary(ctx, dst, op_elu); } /* gated ops */ @@ -258,57 +278,59 @@ void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst } } +static __device__ __forceinline__ float op_reglu(float x) { + return fmaxf(x, 0); +} + void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - auto op = [] __device__ (float x) -> float { - return fmaxf(x, 0); - }; - ggml_cuda_op_unary_gated(ctx, dst, op); + ggml_cuda_op_unary_gated(ctx, dst, op_reglu); +} + +static __device__ __forceinline__ float op_geglu(float x) { + const float GELU_COEF_A = 0.044715f; + const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); } void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - auto op = [] __device__ (float x) -> float { - const float GELU_COEF_A = 0.044715f; - const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; - return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); - }; - ggml_cuda_op_unary_gated(ctx, dst, op); + ggml_cuda_op_unary_gated(ctx, dst, op_geglu); +} + +static __device__ __forceinline__ float op_swiglu(float x) { + return x / (1.0f + expf(-x)); } void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - auto op = [] __device__ (float x) -> float { - return x / (1.0f + expf(-x)); - }; - ggml_cuda_op_unary_gated(ctx, dst, op); + ggml_cuda_op_unary_gated(ctx, dst, op_swiglu); +} + +static __device__ __forceinline__ float op_geglu_erf(float x) { + const float SQRT_2_INV = 0.70710678118654752440084436210484f; + return 0.5f*x*(1.0f + erff(x*SQRT_2_INV)); } void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - auto op = [] __device__ (float x) -> float { - const float SQRT_2_INV = 0.70710678118654752440084436210484f; - return 0.5f*x*(1.0f + erff(x*SQRT_2_INV)); - }; - ggml_cuda_op_unary_gated(ctx, dst, op); + ggml_cuda_op_unary_gated(ctx, dst, op_geglu_erf); } -void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - auto op = [] __device__ (float x) -> float { - const float GELU_QUICK_COEF = -1.702f; - return x * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x))); - }; - ggml_cuda_op_unary_gated(ctx, dst, op); +static __device__ __forceinline__ float op_geglu_quick(float x) { + const float GELU_QUICK_COEF = -1.702f; + return x * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x))); } -void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - // Get the XIELU parameters from the operation - const float * op_params = (const float*)dst->op_params; - float alpha_n = op_params[0]; - float alpha_p = op_params[1]; - const float beta = op_params[2]; - const float eps = op_params[3]; +void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary_gated(ctx, dst, op_geglu_quick); +} - const auto op = [alpha_n, alpha_p, beta, eps] __device__ (float x) -> float { - float out; +// Functor 
for XIELU operation with parameters +struct op_xielu_functor { + float alpha_n, alpha_p, beta, eps; + + __host__ __device__ __forceinline__ op_xielu_functor(float a_n, float a_p, float b, float e) + : alpha_n(a_n), alpha_p(a_p), beta(b), eps(e) {} + + __device__ __forceinline__ float operator()(float x) const { float gate_pos = (x > 0.0f); // positive branch gate - float gate_neg = 1.0f - gate_pos; // negative branch gate // Positive branch: alpha_p * v^2 + beta * v float y_pos = alpha_p * x * x + beta * x; @@ -316,12 +338,22 @@ void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { // Negative branch: float min_v_eps = fminf(x, eps); // works fine even if eps < 0 float y_neg = (expm1f(min_v_eps) - x) * alpha_n + beta * x; - out = y_pos * gate_pos + y_neg * gate_neg; - return out; - }; + // Select the appropriate branch based on the gate + return gate_pos * y_pos + (1.0f - gate_pos) * y_neg; + } +}; - ggml_cuda_op_unary_gated(ctx, dst, op); +void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + // Get the XIELU parameters from the operation + const float * op_params = (const float*)dst->op_params; + float alpha_n = op_params[0]; + float alpha_p = op_params[1]; + const float beta = op_params[2]; + const float eps = op_params[3]; + + op_xielu_functor op(alpha_n, alpha_p, beta, eps); + ggml_cuda_op_unary(ctx, dst, op); } // swiglu_oai From eec384fe88e4ccd42cac3607825ca2c4ed3db254 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Mon, 15 Sep 2025 00:12:01 +0200 Subject: [PATCH 08/32] Chat template support --- common/chat-parser.cpp | 29 ++++++ common/chat-parser.h | 3 + common/chat.cpp | 110 +++++++++++++++++++++ common/chat.h | 1 + models/templates/Apertus-8B-Instruct.jinja | 8 +- tests/test-chat.cpp | 73 ++++++++++++++ 6 files changed, 220 insertions(+), 4 deletions(-) diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp index 96ba8f533ef1b..1f69fe783a84b 100644 --- a/common/chat-parser.cpp +++ b/common/chat-parser.cpp @@ -75,6 +75,35 @@ bool common_chat_msg_parser::add_tool_calls(const json & arr) { } return true; } + +bool common_chat_msg_parser::add_tool_call_short_form(const json & tool_call) { + if (!tool_call.is_object() || tool_call.size() != 1) { + return false; + } + + // Get the tool name (the single key in the object) + auto it = tool_call.begin(); + std::string name = it.key(); + + if (name.empty()) { + return false; + } + + // Get the arguments (the nested object) + const json & args_json = it.value(); + std::string arguments = ""; + + if (args_json.is_object()) { + arguments = args_json.dump(); + } else if (args_json.is_string()) { + arguments = args_json; + } else if (!args_json.is_null()) { + // For other types, convert to string representation + arguments = args_json.dump(); + } + + return add_tool_call(name, "", arguments); +} void common_chat_msg_parser::finish() { if (!is_partial_ && pos_ != input_.size()) { throw std::runtime_error("Unexpected content at end of input");// + input_.substr(pos_)); diff --git a/common/chat-parser.h b/common/chat-parser.h index 0e64c341a50aa..c8cdc63fb50f6 100644 --- a/common/chat-parser.h +++ b/common/chat-parser.h @@ -64,6 +64,9 @@ class common_chat_msg_parser { // Adds an array of tool calls using their "name", "id" and "arguments" fields. 
bool add_tool_calls(const nlohmann::ordered_json & arr); + // Adds a tool call using the short form: { "tool_name": { "arg1": val, "arg2": val } } + bool add_tool_call_short_form(const nlohmann::ordered_json & tool_call); + void finish(); bool consume_spaces(); diff --git a/common/chat.cpp b/common/chat.cpp index 4707c4fef4070..69cd51fa646c4 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -638,6 +638,7 @@ const char * common_chat_format_name(common_chat_format format) { case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS"; case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS"; case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2"; + case COMMON_CHAT_FORMAT_APERTUS: return "Apertus"; default: throw std::runtime_error("Unknown chat format"); } @@ -801,6 +802,7 @@ static std::string apply( } tmpl_inputs.add_generation_prompt = inputs.add_generation_prompt; tmpl_inputs.extra_context = inputs.extra_context; + tmpl_inputs.extra_context["enable_thinking"] = inputs.enable_thinking; if (additional_context) { tmpl_inputs.extra_context.merge_patch(*additional_context); } @@ -1264,6 +1266,75 @@ static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_ } return data; } + +static common_chat_params common_chat_params_init_apertus(const common_chat_template & tmpl, const struct templates_params & inputs) { + common_chat_params data; + + // Generate the prompt using the apply() function with the template + data.prompt = apply(tmpl, inputs); + data.format = COMMON_CHAT_FORMAT_APERTUS; + + // Handle thinking tags appropriately based on inputs.enable_thinking + if (string_ends_with(data.prompt, "<|inner_prefix|>")) { + if (!inputs.enable_thinking) { + data.prompt += "<|inner_suffix|>"; + } else { + data.thinking_forced_open = true; + } + } + + // When tools are present, build grammar for the <|tools_prefix|> format + if (!inputs.tools.is_null() && inputs.tools.is_array() && !inputs.tools.empty()) { + data.grammar_lazy = true; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + auto schemas = json::array(); + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + schemas.push_back({ + { "type", "object" }, + { "properties", + { + { function.at("name"), function.at("parameters") } + } }, + { "required", json::array({ function.at("name") }) }, + }); + }); + auto schema = json{ + { "type", "array" }, + { "items", schemas.size() == 1 ? schemas[0] : json{ { "anyOf", schemas } } }, + { "minItems", 1 }, + }; + if (!inputs.parallel_tool_calls) { + schema["maxItems"] = 1; + } + builder.add_rule("root", + std::string(data.thinking_forced_open ? "( \"<|inner_suffix|>\" space )? " : "") + + "\"<|tools_prefix|>\"" + builder.add_schema("tool_calls", schema) + "\"<|tools_suffix|>\""); + }); + data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, + // If thinking_forced_open, then we capture the <|inner_suffix|> tag in the grammar, + // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) + std::string(data.thinking_forced_open ? 
+ "[\\s\\S]*?(<\\|inner_suffix\\|>\\s*)" : + "(?:<\\|inner_prefix\\|>[\\s\\S]*?<\\|inner_suffix\\|>\\s*)?") + + "(<\\|tools_prefix\\|>)[\\s\\S]*" }); + data.preserved_tokens = { + "<|system_start|>", + "<|system_end|>", + "<|developer_start|>", + "<|developer_end|>", + "<|user_start|>", + "<|user_end|>", + "<|assistant_start|>", + "<|assistant_end|>", + "<|inner_prefix|>", + "<|inner_suffix|>", + "<|tools_prefix|>", + "<|tools_suffix|>", + }; + } + return data; +} static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) { if (!builder.syntax().parse_tool_calls) { builder.add_content(builder.consume_rest()); @@ -2292,6 +2363,37 @@ static void common_chat_parse_nemotron_v2(common_chat_msg_parser & builder) { builder.add_content(builder.consume_rest()); } +static void common_chat_parse_apertus(common_chat_msg_parser & builder) { + // Parse thinking tags + builder.try_parse_reasoning("<|inner_prefix|>", "<|inner_suffix|>"); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + // Look for tool calls + static const common_regex tool_call_regex(regex_escape("<|tools_prefix|>")); + if (auto res = builder.try_find_regex(tool_call_regex)) { + builder.move_to(res->groups[0].end); + + auto tool_calls_data = builder.consume_json(); + if (tool_calls_data.json.is_array()) { + builder.consume_spaces(); + if (!builder.try_consume_literal("<|tools_suffix|>")) { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + for (const auto & value : tool_calls_data.json) { + if (value.is_object()) { + builder.add_tool_call_short_form(value); + } + } + } else { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + } + builder.add_content(builder.consume_rest()); +} + static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) { // Parse thinking tags first - this handles the main reasoning content builder.try_parse_reasoning("", ""); @@ -2536,6 +2638,11 @@ static common_chat_params common_chat_templates_apply_jinja( return common_chat_params_init_nemotron_v2(tmpl, params); } + // Apertus format detection + if (src.find("<|system_start|>") != std::string::npos && src.find("<|tools_prefix|>") != std::string::npos) { + return common_chat_params_init_apertus(tmpl, params); + } + // Use generic handler when mixing tools + JSON schema. // TODO: support that mix in handlers below. 
if ((params.tools.is_array() && params.json_schema.is_object())) { @@ -2703,6 +2810,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) { case COMMON_CHAT_FORMAT_NEMOTRON_V2: common_chat_parse_nemotron_v2(builder); break; + case COMMON_CHAT_FORMAT_APERTUS: + common_chat_parse_apertus(builder); + break; default: throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); } diff --git a/common/chat.h b/common/chat.h index 5170fc14f4e63..3c277e15eba7f 100644 --- a/common/chat.h +++ b/common/chat.h @@ -114,6 +114,7 @@ enum common_chat_format { COMMON_CHAT_FORMAT_GPT_OSS, COMMON_CHAT_FORMAT_SEED_OSS, COMMON_CHAT_FORMAT_NEMOTRON_V2, + COMMON_CHAT_FORMAT_APERTUS, COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats }; diff --git a/models/templates/Apertus-8B-Instruct.jinja b/models/templates/Apertus-8B-Instruct.jinja index a71677eee94ea..10826ff6901ae 100644 --- a/models/templates/Apertus-8B-Instruct.jinja +++ b/models/templates/Apertus-8B-Instruct.jinja @@ -218,7 +218,7 @@ {{ assistant_token }} {%- set ns.in_assistant = true -%} {%- endif -%} - {%- if "content" in message -%} + {%- if "content" in message and message.content is not none -%} {%- if message.content is string and (ns.assistant_format is none or ns.assistant_format == "string") -%} {%- if ns.in_tool -%} {{ ']' }} @@ -284,10 +284,10 @@ {%- endif -%} {%- endfor -%} {%- else -%} - {{- raise_exception("Invalid assistant content") -}} + {{- raise_exception("Invalid assistant content '" + message.content + "', expected " + ns.assistant_format) -}} {%- endif -%} - {%- else -%} - {{- raise_exception("Invalid assistant message") -}} + {%- elif "tool_calls" not in message -%} + {{- raise_exception("Invalid assistant message " + message) -}} {%- endif -%} {%- if "tool_calls" in message and message.tool_calls -%} {{ tool_calls_token + '[' }} diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index ac8a0ade1f6e2..cf729db42df17 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -1963,6 +1963,79 @@ static void test_template_output_parsers() { /* .parse_tool_calls = */ true, })); } + { + auto tmpls = read_templates("models/templates/Apertus-8B-Instruct.jinja"); + std::vector end_tokens{ "<|assistant_end|>" }; + + assert_equals(COMMON_CHAT_FORMAT_APERTUS, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); + assert_equals(COMMON_CHAT_FORMAT_APERTUS, common_chat_templates_apply(tmpls.get(), inputs_tools).format); + + // Test parsing regular content + assert_msg_equals(message_assist, + common_chat_parse( + "Hello, world!\nWhat's up?", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_APERTUS})); + + // Test parsing content with thinking + assert_msg_equals(message_assist_thoughts, + common_chat_parse( + "<|inner_prefix|>I'm\nthinking<|inner_suffix|>Hello, world!\nWhat's up?", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_APERTUS, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + })); + + // Test parsing tool calls + assert_msg_equals(message_assist_call, + common_chat_parse( + "<|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_APERTUS})); + + // Test parsing tool calls with thinking + assert_msg_equals(message_assist_call_thoughts, + common_chat_parse( + "<|inner_prefix|>I'm\nthinking<|inner_suffix|><|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>", + /* is_partial= */ false, + { + /* .format = */ 
COMMON_CHAT_FORMAT_APERTUS, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK + })); + + // Test tool calls with extra content + assert_msg_equals(message_assist_call_content, + common_chat_parse( + "<|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>Hello, world!\nWhat's up?", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_APERTUS} + )); + + // Test tool calls with extra content AND thinking + assert_msg_equals(message_assist_call_thoughts_content, + common_chat_parse( + "<|inner_prefix|>I'm\nthinking<|inner_suffix|><|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>Hello, world!\nWhat's up?", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_APERTUS, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK + })); + + // Test template generation for regular content + test_templates(tmpls.get(), end_tokens, message_assist, tools, + "Hello, world!\nWhat's up?", + /* expect_grammar_triggered= */ false); + + // Test template generation for tool calls + test_templates(tmpls.get(), end_tokens, message_assist_call, tools, + "<|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>", + /* expect_grammar_triggered= */ true + ); + + assert_equals(true, common_chat_templates_support_enable_thinking(tmpls.get())); + } + } static void test_msg_diffs_compute() { From b2a92d0d9b4c1e99448198fd78b023b44b573290 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Mon, 15 Sep 2025 00:26:32 +0200 Subject: [PATCH 09/32] configchecker / flake8 errors --- common/chat-parser.cpp | 4 ++-- common/chat.cpp | 2 +- convert_hf_to_gguf.py | 2 +- ggml/src/ggml-cuda/unary.cu | 4 ++-- gguf-py/gguf/gguf_writer.py | 2 +- src/llama-model.cpp | 2 +- tools/tts/convert_pt_to_hf.py | 2 +- 7 files changed, 9 insertions(+), 9 deletions(-) diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp index 1f69fe783a84b..b3362519a68f3 100644 --- a/common/chat-parser.cpp +++ b/common/chat-parser.cpp @@ -84,7 +84,7 @@ bool common_chat_msg_parser::add_tool_call_short_form(const json & tool_call) { // Get the tool name (the single key in the object) auto it = tool_call.begin(); std::string name = it.key(); - + if (name.empty()) { return false; } @@ -92,7 +92,7 @@ bool common_chat_msg_parser::add_tool_call_short_form(const json & tool_call) { // Get the arguments (the nested object) const json & args_json = it.value(); std::string arguments = ""; - + if (args_json.is_object()) { arguments = args_json.dump(); } else if (args_json.is_string()) { diff --git a/common/chat.cpp b/common/chat.cpp index 69cd51fa646c4..e831f0cd6a466 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -2375,7 +2375,7 @@ static void common_chat_parse_apertus(common_chat_msg_parser & builder) { static const common_regex tool_call_regex(regex_escape("<|tools_prefix|>")); if (auto res = builder.try_find_regex(tool_call_regex)) { builder.move_to(res->groups[0].end); - + auto tool_calls_data = builder.consume_json(); if (tool_calls_data.json.is_array()) { builder.consume_spaces(); diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 59e893038c6fd..208c027cb4cea 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -8613,7 +8613,7 @@ def modify_tensors(self, data_torch, name, bid): if (len(self._eps) == n_layers): self.gguf_writer.add_xielu_eps([self._eps[k] for k in sorted(self._eps)]) return [] - + return super().modify_tensors(data_torch, name, bid) diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 52f39fffdc0a6..e80e2351d78a2 
100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -325,10 +325,10 @@ void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst // Functor for XIELU operation with parameters struct op_xielu_functor { float alpha_n, alpha_p, beta, eps; - + __host__ __device__ __forceinline__ op_xielu_functor(float a_n, float a_p, float b, float e) : alpha_n(a_n), alpha_p(a_p), beta(b), eps(e) {} - + __device__ __forceinline__ float operator()(float x) const { float gate_pos = (x > 0.0f); // positive branch gate diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index a431e94befbb8..8ace552569f15 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -1059,7 +1059,7 @@ def add_xielu_alpha_n(self, value: Sequence[float]): def add_xielu_beta(self, value: Sequence[float]): self.add_array(Keys.xIELU.XIELU_BETA, value) - + def add_xielu_eps(self, value: Sequence[float]): self.add_array(Keys.xIELU.XIELU_EPS, value) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 46b844250ef91..3d696f375aa32 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -18686,7 +18686,7 @@ struct llm_build_apertus : public llm_graph_context { Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); cb(Kcur, "Kcur_normed", il); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, diff --git a/tools/tts/convert_pt_to_hf.py b/tools/tts/convert_pt_to_hf.py index 8909a65fd1e13..ebd55d9657b24 100644 --- a/tools/tts/convert_pt_to_hf.py +++ b/tools/tts/convert_pt_to_hf.py @@ -12,7 +12,7 @@ from safetensors.torch import save_file # default -model_path = './model.pt'; +model_path = './model.pt' # read from CLI if len(sys.argv) > 1: From ef9ef66870732c06552cf2d752bf11b3f6eb20e3 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Mon, 15 Sep 2025 10:49:01 +0200 Subject: [PATCH 10/32] Reorder unary.cu --- ggml/src/ggml-cuda/unary.cu | 443 ++++++++++++++++++------------------ 1 file changed, 224 insertions(+), 219 deletions(-) diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index e80e2351d78a2..fd6d6684c7538 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -1,206 +1,163 @@ #include "unary.cuh" -template -static __global__ void unary_op_kernel(const T * x, T * dst, const int k, const Op op) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= k) { - return; - } - - dst[i] = (T)op((float)x[i]); -} - -template -static void unary_cuda(const T * x, T * dst, const int k, const Op op, cudaStream_t stream) { - const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE; - unary_op_kernel<<>>(x, dst, k, op); -} - -template -void ggml_cuda_op_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst, const Op op) { - const ggml_tensor * src0 = dst->src[0]; - const void * src0_d = src0->data; - void * dst_d = dst->data; - cudaStream_t stream = ctx.stream(); - - GGML_ASSERT(ggml_is_contiguous(src0)); - - GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); - GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); - GGML_ASSERT(src0->type == dst->type); - - if (src0->type == GGML_TYPE_F16) { - unary_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), op, stream); - } else { - unary_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), op, stream); - 
} -} - +// Simple inline operation functions (moved to top) static __device__ __forceinline__ float op_abs(float x) { return fabsf(x); } -void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_abs); -} - static __device__ __forceinline__ float op_sgn(float x) { return (x > 0.f ? 1.f : ((x < 0.f ? -1.f : 0.f))); } -void ggml_cuda_op_sgn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_sgn); -} - static __device__ __forceinline__ float op_neg(float x) { return -x; } -void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_neg); -} - static __device__ __forceinline__ float op_step(float x) { return x > 0.0f; } -void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_step); -} - static __device__ __forceinline__ float op_gelu(float x) { const float GELU_COEF_A = 0.044715f; const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); } -void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_gelu); -} - static __device__ __forceinline__ float op_gelu_erf(float x) { const float SQRT_2_INV = 0.70710678118654752440084436210484f; return 0.5f*x*(1.0f + erff(x*SQRT_2_INV)); } -void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_gelu_erf); -} - static __device__ __forceinline__ float op_gelu_quick(float x) { const float GELU_QUICK_COEF = -1.702f; return x * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x))); } -void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_gelu_quick); -} - static __device__ __forceinline__ float op_silu(float x) { return x / (1.0f + expf(-x)); } -void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_silu); -} - static __device__ __forceinline__ float op_tanh(float x) { return tanhf(x); } -void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_tanh); -} - static __device__ __forceinline__ float op_relu(float x) { return fmaxf(x, 0); } -void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_relu); -} - static __device__ __forceinline__ float op_sigmoid(float x) { return 1.0f / (1.0f + expf(-x)); } -void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_sigmoid); -} - static __device__ __forceinline__ float op_hardsigmoid(float x) { return fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f)); } -void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_hardsigmoid); -} - static __device__ __forceinline__ float op_hardswish(float x) { return x * fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f)); } -void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_hardswish); -} - static __device__ __forceinline__ float op_exp(float x) { return expf(x); } -void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_exp); -} - static __device__ __forceinline__ float op_sqr(float x) { return x * x; } -void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * 
dst) { - ggml_cuda_op_unary(ctx, dst, op_sqr); -} - static __device__ __forceinline__ float op_sqrt(float x) { return sqrtf(x); } -void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_sqrt); -} - static __device__ __forceinline__ float op_sin(float x) { return sinf(x); } -void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_sin); -} - static __device__ __forceinline__ float op_cos(float x) { return cosf(x); } -void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_cos); -} - static __device__ __forceinline__ float op_log(float x) { return logf(x); } -void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_log); -} - static __device__ __forceinline__ float op_elu(float x) { return (x > 0.f) ? x : expm1f(x); } -void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_elu); +static __device__ __forceinline__ float op_reglu(float x) { + return fmaxf(x, 0); +} + +static __device__ __forceinline__ float op_geglu(float x) { + const float GELU_COEF_A = 0.044715f; + const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +static __device__ __forceinline__ float op_swiglu(float x) { + return x / (1.0f + expf(-x)); +} + +static __device__ __forceinline__ float op_geglu_erf(float x) { + const float SQRT_2_INV = 0.70710678118654752440084436210484f; + return 0.5f*x*(1.0f + erff(x*SQRT_2_INV)); } + +static __device__ __forceinline__ float op_geglu_quick(float x) { + const float GELU_QUICK_COEF = -1.702f; + return x * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x))); +} + +// Special operation functions (stay in original locations) +static __device__ __forceinline__ float op_silu_back(float grad, float x) { + const float s = 1.0f / (1.0f + expf(-x)); + return grad * s * (1.0f + x * (1.0f - s)); +} + +static __device__ __forceinline__ float op_leaky_relu(float x, const float negative_slope) { + return fmaxf(x, 0) + fminf(x, 0.0f) * negative_slope; +} + +// Template functions and kernels +template +static __global__ void unary_op_kernel(const T * x, T * dst, const int k, const Op op) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + dst[i] = (T)op((float)x[i]); +} + +template +static void unary_cuda(const T * x, T * dst, const int k, const Op op, cudaStream_t stream) { + const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE; + unary_op_kernel<<>>(x, dst, k, op); +} + +template +void ggml_cuda_op_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst, const Op op) { + const ggml_tensor * src0 = dst->src[0]; + const void * src0_d = src0->data; + void * dst_d = dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == dst->type); + + if (src0->type == GGML_TYPE_F16) { + unary_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), op, stream); + } else { + unary_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), op, stream); + } +} + /* gated ops */ template @@ -278,50 +235,6 @@ void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst } 
} -static __device__ __forceinline__ float op_reglu(float x) { - return fmaxf(x, 0); -} - -void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary_gated(ctx, dst, op_reglu); -} - -static __device__ __forceinline__ float op_geglu(float x) { - const float GELU_COEF_A = 0.044715f; - const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; - return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); -} - -void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary_gated(ctx, dst, op_geglu); -} - -static __device__ __forceinline__ float op_swiglu(float x) { - return x / (1.0f + expf(-x)); -} - -void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary_gated(ctx, dst, op_swiglu); -} - -static __device__ __forceinline__ float op_geglu_erf(float x) { - const float SQRT_2_INV = 0.70710678118654752440084436210484f; - return 0.5f*x*(1.0f + erff(x*SQRT_2_INV)); -} - -void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary_gated(ctx, dst, op_geglu_erf); -} - -static __device__ __forceinline__ float op_geglu_quick(float x) { - const float GELU_QUICK_COEF = -1.702f; - return x * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x))); -} - -void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary_gated(ctx, dst, op_geglu_quick); -} - // Functor for XIELU operation with parameters struct op_xielu_functor { float alpha_n, alpha_p, beta, eps; @@ -344,18 +257,6 @@ struct op_xielu_functor { } }; -void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - // Get the XIELU parameters from the operation - const float * op_params = (const float*)dst->op_params; - float alpha_n = op_params[0]; - float alpha_p = op_params[1]; - const float beta = op_params[2]; - const float eps = op_params[3]; - - op_xielu_functor op(alpha_n, alpha_p, beta, eps); - ggml_cuda_op_unary(ctx, dst, op); -} - // swiglu_oai template @@ -387,6 +288,157 @@ static void swiglu_oai_cuda(const T * x, const T * g, T * dst, const int64_t k, swiglu_oai_kernel<<>>(x, g, dst, k, n, o0, o1, alpha, limit); } +/* silu_back */ + +template +static __global__ void silu_back_kernel(const T * grad, const T * xf, T * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + dst[i] = (T)op_silu_back((float)grad[i], (float)xf[i]); +} + +template +static void silu_back_cuda(const T * grad, const T * x, T * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_SILU_BACK_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE; + silu_back_kernel<<>>(grad, x, dst, k); +} + +/* leaky relu */ + +template +static __global__ void leaky_relu_kernel(const T * x, T * dst, const int k, const float negative_slope) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + dst[i] = (T)op_leaky_relu((float)x[i], negative_slope); +} + +template +static void leaky_relu_cuda(const T * x, T * dst, const int k, const float negative_slope, cudaStream_t stream) { + const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE; + leaky_relu_kernel<<>>(x, dst, k, negative_slope); +} + +// All ggml_cuda_op_* functions moved to bottom +void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst, op_abs); +} + +void ggml_cuda_op_sgn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + 
ggml_cuda_op_unary(ctx, dst, op_sgn); +} + +void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst, op_neg); +} + +void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst, op_step); +} + +void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst, op_gelu); +} + +void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst, op_gelu_erf); +} + +void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst, op_gelu_quick); +} + +void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst, op_silu); +} + +void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst, op_tanh); +} + +void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst, op_relu); +} + +void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst, op_sigmoid); +} + +void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst, op_hardsigmoid); +} + +void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst, op_hardswish); +} + +void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst, op_exp); +} + +void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst, op_sqr); +} + +void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst, op_sqrt); +} + +void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst, op_sin); +} + +void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst, op_cos); +} + +void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst, op_log); +} + +void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst, op_elu); +} + +void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary_gated(ctx, dst, op_reglu); +} + +void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary_gated(ctx, dst, op_geglu); +} + +void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary_gated(ctx, dst, op_swiglu); +} + +void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary_gated(ctx, dst, op_geglu_erf); +} + +void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary_gated(ctx, dst, op_geglu_quick); +} + +void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + // Get the XIELU parameters from the operation + const float * op_params = (const float*)dst->op_params; + float alpha_n = op_params[0]; + float alpha_p = op_params[1]; + const float beta = op_params[2]; + const float eps = op_params[3]; + + op_xielu_functor op(alpha_n, alpha_p, beta, eps); + ggml_cuda_op_unary(ctx, dst, op); +} + void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 
= dst->src[1]; @@ -431,30 +483,6 @@ void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) swiglu_oai_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream); } -/* silu_back */ - -static __device__ __forceinline__ float op_silu_back(float grad, float x) { - const float s = 1.0f / (1.0f + expf(-x)); - return grad * s * (1.0f + x * (1.0f - s)); -} - -template -static __global__ void silu_back_kernel(const T * grad, const T * xf, T * dst, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= k) { - return; - } - - dst[i] = (T)op_silu_back((float)grad[i], (float)xf[i]); -} - -template -static void silu_back_cuda(const T * grad, const T * x, T * dst, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_SILU_BACK_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE; - silu_back_kernel<<>>(grad, x, dst, k); -} - void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; // input from forward pass const ggml_tensor * src1 = dst->src[1]; // grads of forward pass output @@ -478,29 +506,6 @@ void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) } } -/* leaky relu */ - -static __device__ __forceinline__ float op_leaky_relu(float x, const float negative_slope) { - return fmaxf(x, 0) + fminf(x, 0.0f) * negative_slope; -} - -template -static __global__ void leaky_relu_kernel(const T * x, T * dst, const int k, const float negative_slope) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= k) { - return; - } - - dst[i] = (T)op_leaky_relu((float)x[i], negative_slope); -} - -template -static void leaky_relu_cuda(const T * x, T * dst, const int k, const float negative_slope, cudaStream_t stream) { - const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE; - leaky_relu_kernel<<>>(x, dst, k, negative_slope); -} - void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const void * src0_d = src0->data; @@ -521,4 +526,4 @@ void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) } else { leaky_relu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), negative_slope, stream); } -} +} \ No newline at end of file From d009194e92563bbbc17af48071804e720e7a7eaa Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Mon, 15 Sep 2025 10:56:17 +0200 Subject: [PATCH 11/32] I do conclude that LLMs are, in fact, stupid. --- ggml/src/ggml-cuda/unary.cu | 36 +++++++----------------------------- 1 file changed, 7 insertions(+), 29 deletions(-) diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index fd6d6684c7538..37a8111eb01e3 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -86,30 +86,6 @@ static __device__ __forceinline__ float op_elu(float x) { return (x > 0.f) ? 
x : expm1f(x);
 }
 
-static __device__ __forceinline__ float op_reglu(float x) {
-    return fmaxf(x, 0);
-}
-
-static __device__ __forceinline__ float op_geglu(float x) {
-    const float GELU_COEF_A = 0.044715f;
-    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
-    return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
-}
-
-static __device__ __forceinline__ float op_swiglu(float x) {
-    return x / (1.0f + expf(-x));
-}
-
-static __device__ __forceinline__ float op_geglu_erf(float x) {
-    const float SQRT_2_INV = 0.70710678118654752440084436210484f;
-    return 0.5f*x*(1.0f + erff(x*SQRT_2_INV));
-}
-
-static __device__ __forceinline__ float op_geglu_quick(float x) {
-    const float GELU_QUICK_COEF = -1.702f;
-    return x * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x)));
-}
-
 // Special operation functions (stay in original locations)
 static __device__ __forceinline__ float op_silu_back(float grad, float x) {
     const float s = 1.0f / (1.0f + expf(-x));
@@ -407,26 +383,28 @@ void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_unary(ctx, dst, op_elu);
 }
 
+// GLU
 void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    ggml_cuda_op_unary_gated(ctx, dst, op_reglu);
+    ggml_cuda_op_unary_gated(ctx, dst, op_relu);
 }
 
 void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    ggml_cuda_op_unary_gated(ctx, dst, op_geglu);
+    ggml_cuda_op_unary_gated(ctx, dst, op_gelu);
 }
 
 void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    ggml_cuda_op_unary_gated(ctx, dst, op_swiglu);
+    ggml_cuda_op_unary_gated(ctx, dst, op_silu);
 }
 
 void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    ggml_cuda_op_unary_gated(ctx, dst, op_geglu_erf);
+    ggml_cuda_op_unary_gated(ctx, dst, op_gelu_erf);
 }
 
 void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    ggml_cuda_op_unary_gated(ctx, dst, op_geglu_quick);
+    ggml_cuda_op_unary_gated(ctx, dst, op_gelu_quick);
 }
 
+// xIELU
 void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     // Get the XIELU parameters from the operation
     const float * op_params = (const float*)dst->op_params;

From 28f9086f0ae2a4f6b272505b6c72fdc61d13d4bc Mon Sep 17 00:00:00 2001
From: "Piotr Wilkin (ilintar)"
Date: Tue, 16 Sep 2025 07:59:13 +0000
Subject: [PATCH 12/32] Fix after merge

---
 gguf-py/gguf/constants.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 1f21678815f0a..dfa3e8d841b27 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -2725,6 +2725,7 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_GATE,
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
+    ],
     MODEL_ARCH.LLADA_MOE: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,

From 2f68c03c55dabe5e7ab15c81c77ac9aa75753df7 Mon Sep 17 00:00:00 2001
From: "Piotr Wilkin (ilintar)"
Date: Tue, 16 Sep 2025 10:14:51 +0000
Subject: [PATCH 13/32] Final newline

---
 ggml/src/ggml-cuda/unary.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu
index 37a8111eb01e3..a949ca5d20733 100644
--- a/ggml/src/ggml-cuda/unary.cu
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -504,4 +504,4 @@ void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     } else {
         leaky_relu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), negative_slope, stream);
     }
-}
\ No newline at end of file
+}

From
4294dbf5c3f6a74892111aba43a78617d5dae3f2 Mon Sep 17 00:00:00 2001 From: "Piotr Wilkin (ilintar)" Date: Wed, 17 Sep 2025 19:42:05 +0000 Subject: [PATCH 14/32] Make xIELU an UNARY_OP --- ggml/include/ggml.h | 2 +- ggml/src/ggml-cpu/ggml-cpu.c | 6 +----- ggml/src/ggml-cpu/ops.cpp | 4 ++++ ggml/src/ggml-cuda/ggml-cuda.cu | 6 +++--- ggml/src/ggml-cuda/unary.cu | 9 +++++---- ggml/src/ggml.c | 21 +++++++++++---------- 6 files changed, 25 insertions(+), 23 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 1cf9bdfd3986f..05a2655bd72f2 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -554,7 +554,6 @@ extern "C" { GGML_OP_OPT_STEP_SGD, GGML_OP_GLU, - GGML_OP_XIELU, GGML_OP_COUNT, }; @@ -575,6 +574,7 @@ extern "C" { GGML_UNARY_OP_HARDSIGMOID, GGML_UNARY_OP_EXP, GGML_UNARY_OP_GELU_ERF, + GGML_UNARY_OP_XIELU, GGML_UNARY_OP_COUNT, }; diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index f3922f988c97d..37d504853eee3 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1974,10 +1974,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_unary(params, tensor); } break; - case GGML_OP_XIELU: - { - ggml_compute_forward_xielu(params, tensor); - } break; case GGML_OP_GLU: { ggml_compute_forward_glu(params, tensor); @@ -2144,7 +2140,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_ADD_ID: case GGML_OP_ADD1: case GGML_OP_ACC: - case GGML_OP_XIELU: { n_tasks = n_threads; } break; @@ -2192,6 +2187,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_XIELU: { n_tasks = n_threads; } break; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 212e52ef6a1c8..cfd5e69a04541 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -9767,6 +9767,10 @@ void ggml_compute_forward_unary( { ggml_compute_forward_exp(params, dst); } break; + case GGML_UNARY_OP_XIELU: + { + ggml_compute_forward_xielu(params, dst); + } break; default: { GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index c34e330c6fff1..8afa4a100f7fe 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2333,6 +2333,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_UNARY_OP_ELU: ggml_cuda_op_elu(ctx, dst); break; + case GGML_UNARY_OP_XIELU: + ggml_cuda_op_xielu(ctx, dst); + break; default: return false; } @@ -2517,9 +2520,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_OPT_STEP_SGD: ggml_cuda_opt_step_sgd(ctx, dst); break; - case GGML_OP_XIELU: - ggml_cuda_op_xielu(ctx, dst); - break; default: return false; } diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index a949ca5d20733..99c4fb54caba4 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -407,11 +407,12 @@ void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst // xIELU void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { // Get the XIELU parameters from the operation + const float * op_params = (const float*)dst->op_params; - float alpha_n = op_params[0]; - float alpha_p = op_params[1]; - const float beta = op_params[2]; - const float eps = op_params[3]; + float alpha_n = 
ggml_get_op_params_f32(dst, 1);
+    float alpha_p = ggml_get_op_params_f32(dst, 2);
+    const float beta = ggml_get_op_params_f32(dst, 3);
+    const float eps = ggml_get_op_params_f32(dst, 4);
 
     op_xielu_functor op(alpha_n, alpha_p, beta, eps);
     ggml_cuda_op_unary(ctx, dst, op);
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 900e614be0788..3415deffd0d52 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1017,10 +1017,9 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "OPT_STEP_SGD",
 
     "GLU",
-    "XIELU",
 };
 
-static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 90");
+static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1122,10 +1121,9 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "sgd(x)",
 
     "glu(x)",
-    "xielu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 90");
+static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -1145,9 +1143,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
     "HARDSIGMOID",
     "EXP",
     "GELU_ERF",
+    "XIELU",
 };
 
-static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
+static_assert(GGML_UNARY_OP_COUNT == 16, "GGML_UNARY_OP_COUNT != 16");
 
 static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
     "REGLU",
@@ -2662,11 +2661,13 @@ struct ggml_tensor * ggml_xielu(
         float eps) {
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
 
-    // Store the parameters as operation parameters
-    float params[] = { beta + softplus(alpha_n), softplus(alpha_p), beta, eps };
-    ggml_set_op_params(result, params, sizeof(params));
+    ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_XIELU);
+    ggml_set_op_params_f32(result, 1, beta + softplus(alpha_n));
+    ggml_set_op_params_f32(result, 2, softplus(alpha_p));
+    ggml_set_op_params_f32(result, 3, beta);
+    ggml_set_op_params_f32(result, 4, eps);
 
-    result->op = GGML_OP_XIELU;
+    result->op = GGML_OP_UNARY;
     result->src[0] = a;
 
     return result;
@@ -7236,4 +7237,4 @@ bool ggml_threadpool_params_match(const struct ggml_threadpool_params * p0, cons
     if (p0->poll != p1->poll ) return false;
     if (p0->strict_cpu != p1->strict_cpu ) return false;
     return memcmp(p0->cpumask, p1->cpumask, GGML_MAX_N_THREADS) == 0;
-}
+}
\ No newline at end of file

From b2aa4fb8d575b9f10d651557c7e92e6b9f0dd16a Mon Sep 17 00:00:00 2001
From: "Piotr Wilkin (ilintar)"
Date: Wed, 17 Sep 2025 20:08:38 +0000
Subject: [PATCH 15/32] Final newline

---
 ggml/src/ggml.c | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 3415deffd0d52..737cc22f2af2e 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -7237,4 +7237,4 @@ bool ggml_threadpool_params_match(const struct ggml_threadpool_params * p0, cons
     if (p0->poll != p1->poll ) return false;
     if (p0->strict_cpu != p1->strict_cpu ) return false;
     return memcmp(p0->cpumask, p1->cpumask, GGML_MAX_N_THREADS) == 0;
-}
\ No newline at end of file
+}

From 86a239cf2179208f31f9734b7120daff9aa2ced7 Mon Sep 17 00:00:00 2001
From: Piotr Wilkin
Date: Thu, 18 Sep 2025 00:24:30 +0200
Subject: [PATCH 16/32] Correctly account for parameter shift

---
 ggml/src/ggml-cpu/unary-ops.cpp | 11 +++++------
 ggml/src/ggml-cuda/unary.cu     |  2 --
 2 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/ggml/src/ggml-cpu/unary-ops.cpp b/ggml/src/ggml-cpu/unary-ops.cpp
index 80ac5f7f5588b..3c9f3b3201906 100644
--- a/ggml/src/ggml-cpu/unary-ops.cpp
+++ b/ggml/src/ggml-cpu/unary-ops.cpp
@@ -201,12 +201,11 @@ static
float softplus(float input, float beta=1.0f, float threshold=20.0f) { void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) { // Get the XIELU parameters from the operation - const float * op_params = (const float*)dst->op_params; - float alpha_n = op_params[0]; - float alpha_p = op_params[1]; - const float beta = op_params[2]; - const float eps = op_params[3]; - + float alpha_n = ggml_get_op_params_f32(dst, 1); + float alpha_p = ggml_get_op_params_f32(dst, 2); + const float beta = ggml_get_op_params_f32(dst, 3); + const float eps = ggml_get_op_params_f32(dst, 4); + const auto xielu_op_params = [alpha_n, alpha_p, beta, eps](float f) { return op_xielu(f, alpha_n, alpha_p, beta, eps); }; diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 99c4fb54caba4..3574b298d4bad 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -407,8 +407,6 @@ void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst // xIELU void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { // Get the XIELU parameters from the operation - - const float * op_params = (const float*)dst->op_params; float alpha_n = ggml_get_op_params_f32(dst, 1); float alpha_p = ggml_get_op_params_f32(dst, 2); const float beta = ggml_get_op_params_f32(dst, 3); From bd1902619a08392ba1ea9b7ee76b38b9e7ca9ab6 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Thu, 18 Sep 2025 00:26:20 +0200 Subject: [PATCH 17/32] Argh. --- ggml/src/ggml-cpu/unary-ops.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/unary-ops.cpp b/ggml/src/ggml-cpu/unary-ops.cpp index 3c9f3b3201906..262a5bd2f9fa7 100644 --- a/ggml/src/ggml-cpu/unary-ops.cpp +++ b/ggml/src/ggml-cpu/unary-ops.cpp @@ -205,7 +205,7 @@ void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor float alpha_p = ggml_get_op_params_f32(dst, 2); const float beta = ggml_get_op_params_f32(dst, 3); const float eps = ggml_get_op_params_f32(dst, 4); - + const auto xielu_op_params = [alpha_n, alpha_p, beta, eps](float f) { return op_xielu(f, alpha_n, alpha_p, beta, eps); }; From 13cc3be2ba00673f3dedda3fb7deb6a03f9d18d6 Mon Sep 17 00:00:00 2001 From: "Piotr Wilkin (ilintar)" Date: Thu, 18 Sep 2025 09:58:53 +0200 Subject: [PATCH 18/32] Update ggml/src/ggml-cpu/unary-ops.cpp Co-authored-by: Georgi Gerganov --- ggml/src/ggml-cpu/unary-ops.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/unary-ops.cpp b/ggml/src/ggml-cpu/unary-ops.cpp index 262a5bd2f9fa7..d7c395f15f5b2 100644 --- a/ggml/src/ggml-cpu/unary-ops.cpp +++ b/ggml/src/ggml-cpu/unary-ops.cpp @@ -74,7 +74,7 @@ static inline float op_log(float x) { } template -static inline void vec_unary_op(const Op& op, int64_t n, dst_t * y, const src0_t * x) { +static inline void vec_unary_op(const Op & op, int64_t n, dst_t * y, const src0_t * x) { constexpr auto src0_to_f32 = type_conversion_table::to_f32; constexpr auto f32_to_dst = type_conversion_table::from_f32; From 40f2f807d7e4656cf810dd7fd857b5ba417cedf1 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Thu, 18 Sep 2025 12:51:18 +0200 Subject: [PATCH 19/32] Refactor: remove unused methods, inline and factorize softplus, add const modifiers --- ggml/src/ggml-cpu/ops.cpp | 4 ++-- ggml/src/ggml-cpu/unary-ops.cpp | 5 ----- ggml/src/ggml-cuda/unary.cu | 10 ++++------ ggml/src/ggml-impl.h | 4 ++++ ggml/src/ggml.c | 16 ++++++---------- src/llama-model.cpp | 13 ------------- 6 files changed, 16 
insertions(+), 36 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index cfd5e69a04541..0776051fedce6 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -9407,7 +9407,7 @@ static void ggml_compute_forward_ssm_scan_f32( // n_head for (int h = ih0; h < ih1; ++h) { // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 - const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h]; + const float dt_soft_plus = softplus(dt[h]); const float dA = expf(dt_soft_plus * A[h]); const int g = h / (nh / ng); // repeat_interleave @@ -9504,7 +9504,7 @@ static void ggml_compute_forward_ssm_scan_f32( // n_head for (int h = ih0; h < ih1; ++h) { // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 - const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h]; + const float dt_soft_plus = softplus(dt[h]); const int g = h / (nh / ng); // repeat_interleave // dim diff --git a/ggml/src/ggml-cpu/unary-ops.cpp b/ggml/src/ggml-cpu/unary-ops.cpp index d7c395f15f5b2..dfd7ffc3e7e71 100644 --- a/ggml/src/ggml-cpu/unary-ops.cpp +++ b/ggml/src/ggml-cpu/unary-ops.cpp @@ -194,11 +194,6 @@ void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor * unary_op(op_log, params, dst); } -static float softplus(float input, float beta=1.0f, float threshold=20.0f) { - if (input * beta > threshold) return input; - return (1/beta) * logf(1 + expf(beta * input)); -} - void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) { // Get the XIELU parameters from the operation float alpha_n = ggml_get_op_params_f32(dst, 1); diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 3574b298d4bad..7d859e6559d23 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -211,7 +211,6 @@ void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst } } -// Functor for XIELU operation with parameters struct op_xielu_functor { float alpha_n, alpha_p, beta, eps; @@ -219,14 +218,14 @@ struct op_xielu_functor { : alpha_n(a_n), alpha_p(a_p), beta(b), eps(e) {} __device__ __forceinline__ float operator()(float x) const { - float gate_pos = (x > 0.0f); // positive branch gate + const float gate_pos = (x > 0.0f); // positive branch gate // Positive branch: alpha_p * v^2 + beta * v - float y_pos = alpha_p * x * x + beta * x; + const float y_pos = alpha_p * x * x + beta * x; // Negative branch: - float min_v_eps = fminf(x, eps); // works fine even if eps < 0 - float y_neg = (expm1f(min_v_eps) - x) * alpha_n + beta * x; + const float min_v_eps = fminf(x, eps); // works fine even if eps < 0 + const float y_neg = (expm1f(min_v_eps) - x) * alpha_n + beta * x; // Select the appropriate branch based on the gate return gate_pos * y_pos + (1.0f - gate_pos) * y_neg; @@ -234,7 +233,6 @@ struct op_xielu_functor { }; // swiglu_oai - template static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, float alpha, float limit) { const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x; diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 19a7adb2d101b..0ff434f45ca25 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -89,6 +89,10 @@ static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml return 
true; } +static inline float softplus(float input) { + return (input > 20.0f) ? input : logf(1 + expf(input)); +} + // // logging // diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 737cc22f2af2e..ee99100a83937 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2646,12 +2646,14 @@ struct ggml_tensor * ggml_silu( return ggml_unary(ctx, a, GGML_UNARY_OP_SILU); } -// ggml_xielu -static float softplus(float input) { - if (input > 20.0f) return input; - return logf(1 + expf(input)); +struct ggml_tensor * ggml_silu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SILU); } +// ggml_xielu + struct ggml_tensor * ggml_xielu( struct ggml_context * ctx, struct ggml_tensor * a, @@ -2673,12 +2675,6 @@ struct ggml_tensor * ggml_xielu( return result; } -struct ggml_tensor * ggml_silu_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SILU); -} - // ggml_silu_back struct ggml_tensor * ggml_silu_back( diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 14ec518322c45..2a0dd2dac2742 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -18852,19 +18852,6 @@ struct llm_build_smallthinker : public llm_graph_context{ } }; -// TODO: maybe put this as a general helper in ggml.c? -static float get_scalar_f32_val(const ggml_tensor *t) { - float onef; - if (t->buffer) { - ggml_backend_tensor_get(t, &onef, 0, sizeof(float)); - } else { - GGML_ASSERT(t->data); - onef = *((float *) t->data); - } - return onef; -} - -// Apertus model graph builder with xIELU activation struct llm_build_apertus : public llm_graph_context { llm_build_apertus(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; From 0d36139c1615f589f318ec25fde0515e709411e4 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Tue, 23 Sep 2025 12:08:43 +0200 Subject: [PATCH 20/32] Revert CUDA changes, implement xIELU as a separate OP --- ggml/src/ggml-cuda/unary.cu | 439 +++++++++++++++++++---------------- ggml/src/ggml-cuda/unary.cuh | 1 + 2 files changed, 241 insertions(+), 199 deletions(-) diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 7d859e6559d23..0bd3de3c40ef6 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -1,7 +1,5 @@ #include "unary.cuh" - -// Simple inline operation functions (moved to top) static __device__ __forceinline__ float op_abs(float x) { return fabsf(x); } @@ -21,16 +19,19 @@ static __device__ __forceinline__ float op_step(float x) { static __device__ __forceinline__ float op_gelu(float x) { const float GELU_COEF_A = 0.044715f; const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); } static __device__ __forceinline__ float op_gelu_erf(float x) { const float SQRT_2_INV = 0.70710678118654752440084436210484f; + return 0.5f*x*(1.0f + erff(x*SQRT_2_INV)); } static __device__ __forceinline__ float op_gelu_quick(float x) { const float GELU_QUICK_COEF = -1.702f; + return x * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x))); } @@ -86,19 +87,8 @@ static __device__ __forceinline__ float op_elu(float x) { return (x > 0.f) ? 
x : expm1f(x); } -// Special operation functions (stay in original locations) -static __device__ __forceinline__ float op_silu_back(float grad, float x) { - const float s = 1.0f / (1.0f + expf(-x)); - return grad * s * (1.0f + x * (1.0f - s)); -} - -static __device__ __forceinline__ float op_leaky_relu(float x, const float negative_slope) { - return fmaxf(x, 0) + fminf(x, 0.0f) * negative_slope; -} - -// Template functions and kernels -template -static __global__ void unary_op_kernel(const T * x, T * dst, const int k, const Op op) { +template +static __global__ void unary_op_kernel(const T * x, T * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -108,14 +98,14 @@ static __global__ void unary_op_kernel(const T * x, T * dst, const int k, const dst[i] = (T)op((float)x[i]); } -template -static void unary_cuda(const T * x, T * dst, const int k, const Op op, cudaStream_t stream) { +template +static void unary_cuda(const T * x, T * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE; - unary_op_kernel<<>>(x, dst, k, op); + unary_op_kernel<<>>(x, dst, k); } -template -void ggml_cuda_op_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst, const Op op) { +template +void ggml_cuda_op_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const void * src0_d = src0->data; void * dst_d = dst->data; @@ -128,16 +118,95 @@ void ggml_cuda_op_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst, cons GGML_ASSERT(src0->type == dst->type); if (src0->type == GGML_TYPE_F16) { - unary_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), op, stream); + unary_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream); } else { - unary_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), op, stream); + unary_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream); } } +void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + +void ggml_cuda_op_sgn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + +void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + +void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + +void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + +void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + +void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + +void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + +void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + +void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + +void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + +void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + +void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + +void ggml_cuda_op_exp(ggml_backend_cuda_context & 
ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + +void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + +void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + +void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + +void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + +void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + +void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} /* gated ops */ -template -static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, const Op op) { +template +static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1) { const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x; if (i >= k) { @@ -151,14 +220,14 @@ static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, dst[i] = (T)(op((float)x[j0]) * (float)g[j1]); } -template -static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, const Op op, cudaStream_t stream) { +template +static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, cudaStream_t stream) { const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE; - unary_gated_op_kernel<<>>(x, g, dst, k, n, o0, o1, op); + unary_gated_op_kernel<<>>(x, g, dst, k, n, o0, o1); } -template -void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst, const Op op) { +template +void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; void * src0_d = src0->data; @@ -197,7 +266,7 @@ void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst src1_p += swapped ? 0 : nc; } - unary_gated_cuda(src0_p, src1_p, (half *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(half), src1_o / sizeof(half), op, stream); + unary_gated_cuda(src0_p, src1_p, (half *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(half), src1_o / sizeof(half), stream); } else { float * src0_p = (float *) src0_d; float * src1_p = (float *) src1_d; @@ -207,32 +276,32 @@ void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst src1_p += swapped ? 
0 : nc; } - unary_gated_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), op, stream); + unary_gated_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), stream); } } -struct op_xielu_functor { - float alpha_n, alpha_p, beta, eps; - - __host__ __device__ __forceinline__ op_xielu_functor(float a_n, float a_p, float b, float e) - : alpha_n(a_n), alpha_p(a_p), beta(b), eps(e) {} +void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary_gated(ctx, dst); +} - __device__ __forceinline__ float operator()(float x) const { - const float gate_pos = (x > 0.0f); // positive branch gate +void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary_gated(ctx, dst); +} - // Positive branch: alpha_p * v^2 + beta * v - const float y_pos = alpha_p * x * x + beta * x; +void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary_gated(ctx, dst); +} - // Negative branch: - const float min_v_eps = fminf(x, eps); // works fine even if eps < 0 - const float y_neg = (expm1f(min_v_eps) - x) * alpha_n + beta * x; +void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary_gated(ctx, dst); +} - // Select the appropriate branch based on the gate - return gate_pos * y_pos + (1.0f - gate_pos) * y_neg; - } -}; +void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary_gated(ctx, dst); +} // swiglu_oai + template static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, float alpha, float limit) { const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x; @@ -262,158 +331,6 @@ static void swiglu_oai_cuda(const T * x, const T * g, T * dst, const int64_t k, swiglu_oai_kernel<<>>(x, g, dst, k, n, o0, o1, alpha, limit); } -/* silu_back */ - -template -static __global__ void silu_back_kernel(const T * grad, const T * xf, T * dst, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= k) { - return; - } - - dst[i] = (T)op_silu_back((float)grad[i], (float)xf[i]); -} - -template -static void silu_back_cuda(const T * grad, const T * x, T * dst, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_SILU_BACK_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE; - silu_back_kernel<<>>(grad, x, dst, k); -} - -/* leaky relu */ - -template -static __global__ void leaky_relu_kernel(const T * x, T * dst, const int k, const float negative_slope) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= k) { - return; - } - - dst[i] = (T)op_leaky_relu((float)x[i], negative_slope); -} - -template -static void leaky_relu_cuda(const T * x, T * dst, const int k, const float negative_slope, cudaStream_t stream) { - const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE; - leaky_relu_kernel<<>>(x, dst, k, negative_slope); -} - -// All ggml_cuda_op_* functions moved to bottom -void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_abs); -} - -void ggml_cuda_op_sgn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_sgn); -} - -void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_neg); -} - -void 
ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_step); -} - -void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_gelu); -} - -void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_gelu_erf); -} - -void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_gelu_quick); -} - -void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_silu); -} - -void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_tanh); -} - -void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_relu); -} - -void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_sigmoid); -} - -void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_hardsigmoid); -} - -void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_hardswish); -} - -void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_exp); -} - -void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_sqr); -} - -void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_sqrt); -} - -void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_sin); -} - -void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_cos); -} - -void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_log); -} - -void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary(ctx, dst, op_elu); -} - -// GLU -void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary_gated(ctx, dst, op_relu); -} - -void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary_gated(ctx, dst, op_gelu); -} - -void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary_gated(ctx, dst, op_silu); -} - -void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary_gated(ctx, dst, op_gelu_erf); -} - -void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_unary_gated(ctx, dst, op_gelu_quick); -} - -// xIELU -void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - // Get the XIELU parameters from the operation - float alpha_n = ggml_get_op_params_f32(dst, 1); - float alpha_p = ggml_get_op_params_f32(dst, 2); - const float beta = ggml_get_op_params_f32(dst, 3); - const float eps = ggml_get_op_params_f32(dst, 4); - - op_xielu_functor op(alpha_n, alpha_p, beta, eps); - ggml_cuda_op_unary(ctx, dst, op); -} - void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; @@ -458,6 +375,107 @@ void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) swiglu_oai_cuda(src0_p, 
src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream); } +/* xIELU */ +struct op_xielu_functor { + float alpha_n, alpha_p, beta, eps; + + __host__ __device__ __forceinline__ op_xielu_functor(float a_n, float a_p, float b, float e) + : alpha_n(a_n), alpha_p(a_p), beta(b), eps(e) {} + + __device__ __forceinline__ float operator()(float x) const { + const float gate_pos = (x > 0.0f); // positive branch gate + + // Positive branch: alpha_p * v^2 + beta * v + const float y_pos = alpha_p * x * x + beta * x; + + // Negative branch: + const float min_v_eps = fminf(x, eps); // works fine even if eps < 0 + const float y_neg = (expm1f(min_v_eps) - x) * alpha_n + beta * x; + + // Select the appropriate branch based on the gate + return gate_pos * y_pos + (1.0f - gate_pos) * y_neg; + } +}; + +/* CUDA kernel + launcher for xIELU */ + +template +static __global__ void xielu_kernel(const T * x, T * dst, const int k, float alpha_n, float alpha_p, float beta, float eps) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + const float xi = (float)x[i]; + const float gate_pos = (xi > 0.0f); + + const float y_pos = alpha_p * xi * xi + beta * xi; + + const float min_v_eps = fminf(xi, eps); + const float y_neg = (expm1f(min_v_eps) - xi) * alpha_n + beta * xi; + + const float out = gate_pos * y_pos + (1.0f - gate_pos) * y_neg; + + dst[i] = (T)out; +} + +template +static void xielu_cuda(const T * x, T * dst, const int k, float alpha_n, float alpha_p, float beta, float eps, cudaStream_t stream) { + const int num_blocks = (k + CUDA_XIELU_BLOCK_SIZE) / CUDA_XIELU_BLOCK_SIZE; + xielu_kernel<<>>(x, dst, k, alpha_n, alpha_p, beta, eps); +} + +void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const void * src0_d = src0->data; + void * dst_d = dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == dst->type); + + const float alpha_n = ggml_get_op_params_f32(dst, 1); + const float alpha_p = ggml_get_op_params_f32(dst, 2); + const float beta = ggml_get_op_params_f32(dst, 3); + const float eps = ggml_get_op_params_f32(dst, 4); + + if (src0->type == GGML_TYPE_F16) { + xielu_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), alpha_n, alpha_p, beta, eps, stream); + } else { + xielu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), alpha_n, alpha_p, beta, eps, stream); + } +} + + + +/* silu_back */ + +static __device__ __forceinline__ float op_silu_back(float grad, float x) { + const float s = 1.0f / (1.0f + expf(-x)); + return grad * s * (1.0f + x * (1.0f - s)); +} + +template +static __global__ void silu_back_kernel(const T * grad, const T * xf, T * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + dst[i] = (T)op_silu_back((float)grad[i], (float)xf[i]); +} + +template +static void silu_back_cuda(const T * grad, const T * x, T * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_SILU_BACK_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE; + silu_back_kernel<<>>(grad, x, dst, k); +} + void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; // input from forward pass const 
ggml_tensor * src1 = dst->src[1]; // grads of forward pass output @@ -481,6 +499,29 @@ void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) } } +/* leaky relu */ + +static __device__ __forceinline__ float op_leaky_relu(float x, const float negative_slope) { + return fmaxf(x, 0) + fminf(x, 0.0f) * negative_slope; +} + +template +static __global__ void leaky_relu_kernel(const T * x, T * dst, const int k, const float negative_slope) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + dst[i] = (T)op_leaky_relu((float)x[i], negative_slope); +} + +template +static void leaky_relu_cuda(const T * x, T * dst, const int k, const float negative_slope, cudaStream_t stream) { + const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE; + leaky_relu_kernel<<>>(x, dst, k, negative_slope); +} + void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const void * src0_d = src0->data; diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index 9213c23d76bf2..8e7644fcd9a48 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -16,6 +16,7 @@ #define CUDA_SIN_BLOCK_SIZE 256 #define CUDA_COS_BLOCK_SIZE 256 #define CUDA_GLU_BLOCK_SIZE 256 +#define CUDA_XIELU_BLOCK_SIZE 256 void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst); From fc58d3b3ac33165aaf1b1ecf3f41e713cf0179f2 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Tue, 23 Sep 2025 12:09:15 +0200 Subject: [PATCH 21/32] Pesky newline --- ggml/src/ggml-impl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 044e3f1032fe3..5b54f98f516e2 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -101,7 +101,7 @@ static bool ggml_op_is_empty(enum ggml_op op) { return false; } } - + static inline float softplus(float input) { return (input > 20.0f) ? input : logf(1 + expf(input)); } From db9eb295cbdfcb9401ade624aca17bae1cd4280b Mon Sep 17 00:00:00 2001 From: "Piotr Wilkin (ilintar)" Date: Tue, 23 Sep 2025 22:06:19 +0000 Subject: [PATCH 22/32] Add float2half / half2float for F16 inputs/outputs --- ggml/src/ggml-cuda/unary.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 0bd3de3c40ef6..226240c0c5472 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -407,7 +407,7 @@ static __global__ void xielu_kernel(const T * x, T * dst, const int k, float alp return; } - const float xi = (float)x[i]; + const float xi = x->type == GGML_TYPE_F32 ? (float) x[i] : __half2float(x[i]); const float gate_pos = (xi > 0.0f); const float y_pos = alpha_p * xi * xi + beta * xi; @@ -417,7 +417,7 @@ static __global__ void xielu_kernel(const T * x, T * dst, const int k, float alp const float out = gate_pos * y_pos + (1.0f - gate_pos) * y_neg; - dst[i] = (T)out; + dst[i] = (T) (dst->type == GGML_TYPE_F32 ? 
out : __float2half(out)); } template From dc1e4d5f879287fcff348546e52fd7da4820f531 Mon Sep 17 00:00:00 2001 From: "Piotr Wilkin (ilintar)" Date: Wed, 24 Sep 2025 18:43:26 +0000 Subject: [PATCH 23/32] CUDA variants, attempt 2 --- ggml/src/ggml-cuda/unary.cu | 26 ++------------------------ 1 file changed, 2 insertions(+), 24 deletions(-) diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 226240c0c5472..62bdba7b17b75 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -375,28 +375,6 @@ void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) swiglu_oai_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream); } -/* xIELU */ -struct op_xielu_functor { - float alpha_n, alpha_p, beta, eps; - - __host__ __device__ __forceinline__ op_xielu_functor(float a_n, float a_p, float b, float e) - : alpha_n(a_n), alpha_p(a_p), beta(b), eps(e) {} - - __device__ __forceinline__ float operator()(float x) const { - const float gate_pos = (x > 0.0f); // positive branch gate - - // Positive branch: alpha_p * v^2 + beta * v - const float y_pos = alpha_p * x * x + beta * x; - - // Negative branch: - const float min_v_eps = fminf(x, eps); // works fine even if eps < 0 - const float y_neg = (expm1f(min_v_eps) - x) * alpha_n + beta * x; - - // Select the appropriate branch based on the gate - return gate_pos * y_pos + (1.0f - gate_pos) * y_neg; - } -}; - /* CUDA kernel + launcher for xIELU */ template @@ -407,7 +385,7 @@ static __global__ void xielu_kernel(const T * x, T * dst, const int k, float alp return; } - const float xi = x->type == GGML_TYPE_F32 ? (float) x[i] : __half2float(x[i]); + const float xi = sizeof(x[i]) == sizeof(half) ? __half2float(x[i]) : (float) x[i]; const float gate_pos = (xi > 0.0f); const float y_pos = alpha_p * xi * xi + beta * xi; @@ -417,7 +395,7 @@ static __global__ void xielu_kernel(const T * x, T * dst, const int k, float alp const float out = gate_pos * y_pos + (1.0f - gate_pos) * y_neg; - dst[i] = (T) (dst->type == GGML_TYPE_F32 ? out : __float2half(out)); + dst[i] = (T) (sizeof(dst[i]) == sizeof(float)) ? out : __float2half(out)); } template From 6c843cec13d736468eefaff93a091f1eb0c04161 Mon Sep 17 00:00:00 2001 From: "Piotr Wilkin (ilintar)" Date: Wed, 24 Sep 2025 18:44:25 +0000 Subject: [PATCH 24/32] Actually, attempt 3 --- ggml/src/ggml-cuda/unary.cu | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 62bdba7b17b75..6c4c3a0d9e595 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -385,17 +385,30 @@ static __global__ void xielu_kernel(const T * x, T * dst, const int k, float alp return; } - const float xi = sizeof(x[i]) == sizeof(half) ? __half2float(x[i]) : (float) x[i]; - const float gate_pos = (xi > 0.0f); + float xi; +#if __CUDA_ARCH__ >= 530 + if constexpr (std::is_same::value) { + xi = __half2float(x[i]); + } else +#endif + { + xi = (float)x[i]; + } + const float gate_pos = (xi > 0.0f); const float y_pos = alpha_p * xi * xi + beta * xi; - const float min_v_eps = fminf(xi, eps); const float y_neg = (expm1f(min_v_eps) - xi) * alpha_n + beta * xi; - const float out = gate_pos * y_pos + (1.0f - gate_pos) * y_neg; - dst[i] = (T) (sizeof(dst[i]) == sizeof(float)) ? 
out : __float2half(out)); +#if __CUDA_ARCH__ >= 530 + if constexpr (std::is_same::value) { + dst[i] = __float2half(out); + } else +#endif + { + dst[i] = (T)out; + } } template From f11ab3ca2da5beb436e838b98e365ae835708606 Mon Sep 17 00:00:00 2001 From: "Piotr Wilkin (ilintar)" Date: Thu, 25 Sep 2025 10:58:57 +0200 Subject: [PATCH 25/32] Update ggml/src/ggml-cuda/unary.cu MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/unary.cu | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 6c4c3a0d9e595..d67a40c446334 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -385,15 +385,7 @@ static __global__ void xielu_kernel(const T * x, T * dst, const int k, float alp return; } - float xi; -#if __CUDA_ARCH__ >= 530 - if constexpr (std::is_same::value) { - xi = __half2float(x[i]); - } else -#endif - { - xi = (float)x[i]; - } + const float xi = ggml_cuda_cast(x[i]); const float gate_pos = (xi > 0.0f); const float y_pos = alpha_p * xi * xi + beta * xi; From 57e226302cb7246f87d0e43de08d7809ef7ff4e5 Mon Sep 17 00:00:00 2001 From: "Piotr Wilkin (ilintar)" Date: Thu, 25 Sep 2025 10:24:30 +0000 Subject: [PATCH 26/32] Missing convert header --- ggml/src/ggml-cuda/unary.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index d67a40c446334..17a946bc8715d 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -1,4 +1,5 @@ #include "unary.cuh" +#include "convert.cuh" static __device__ __forceinline__ float op_abs(float x) { return fabsf(x); From 62401d8835076cafa2efbeef6e911a663820ea9e Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Thu, 25 Sep 2025 19:14:47 +0200 Subject: [PATCH 27/32] Proper formula and reference for xIELU in the comments. --- ggml/include/ggml.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 6ebe0bfb6c34c..bf40486c8518f 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1150,7 +1150,8 @@ extern "C" { struct ggml_tensor * a); // xIELU activation function - // x = x * (alpha_n + alpha_p * sigmoid(beta * x)) + eps * (x > 0) + // https://arxiv.org/abs/2411.13010 + // x = (x > 0) ? 
a_p * x^2 + b * x : (expm1(min(x, eps)) - x) * a_n + b * x
     GGML_API struct ggml_tensor * ggml_xielu(
             struct ggml_context * ctx,
             struct ggml_tensor * a,

From 58e6e0f8e5eec3c39ff733a269bae430ffd5a6a2 Mon Sep 17 00:00:00 2001
From: Piotr Wilkin
Date: Thu, 25 Sep 2025 21:21:02 +0200
Subject: [PATCH 28/32] Modify unary-ops.cpp to add the functor-based logic
 besides the template system to retain optimizations

---
 ggml/src/ggml-cpu/unary-ops.cpp | 139 +++++++++++++++++++++++++-------
 1 file changed, 109 insertions(+), 30 deletions(-)

diff --git a/ggml/src/ggml-cpu/unary-ops.cpp b/ggml/src/ggml-cpu/unary-ops.cpp
index dfd7ffc3e7e71..cf1a4615d042c 100644
--- a/ggml/src/ggml-cpu/unary-ops.cpp
+++ b/ggml/src/ggml-cpu/unary-ops.cpp
@@ -73,8 +73,87 @@ static inline float op_log(float x) {
     return logf(x);
 }
 
+template
+static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) {
+    constexpr auto src0_to_f32 = type_conversion_table::to_f32;
+    constexpr auto f32_to_dst = type_conversion_table::from_f32;
+
+    for (int i = 0; i < n; i++) {
+        y[i] = f32_to_dst(op(src0_to_f32(x[i])));
+    }
+}
+
+template
+static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(dst_t));
+    GGML_ASSERT(nb00 == sizeof(src0_t));
+
+    const auto [ir0, ir1] = get_thread_range(params, src0);
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne02*ne01);
+        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        dst_t * dst_ptr = (dst_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
+        const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+        vec_unary_op(ne0, dst_ptr, src0_ptr);
+    }
+}
+
+// TODO: Use the 'traits' lookup table (for type conversion fns), instead of a mass of 'if' conditions with long templates
+template
+static void unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    /* */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32
+        apply_unary_op(params, dst);
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16
+        apply_unary_op(params, dst);
+    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
+        apply_unary_op(params, dst);
+    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {
+        apply_unary_op(params, dst);
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+        apply_unary_op(params, dst);
+    } else {
+        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
+            ggml_type_name(dst->type), ggml_type_name(src0->type));
+        GGML_ABORT("fatal error");
+    }
+}
+
+template
+static void unary_op_params(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    /* */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32
+        apply_unary_op(params, dst);
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16
+        apply_unary_op(params, dst);
+    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
+        apply_unary_op(params, dst);
+    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {
+        apply_unary_op(params, dst);
+    } else if
(src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { + apply_unary_op(params, dst); + } else { + fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__, + ggml_type_name(dst->type), ggml_type_name(src0->type)); + GGML_ABORT("fatal error"); + } +} + +// Extend vec_unary_op to support functors template -static inline void vec_unary_op(const Op & op, int64_t n, dst_t * y, const src0_t * x) { +static inline void vec_unary_op_functor(int64_t n, dst_t * y, const src0_t * x, Op op) { constexpr auto src0_to_f32 = type_conversion_table::to_f32; constexpr auto f32_to_dst = type_conversion_table::from_f32; @@ -83,8 +162,9 @@ static inline void vec_unary_op(const Op & op, int64_t n, dst_t * y, const src0_ } } +// Extend apply_unary_op to support functors template -static void apply_unary_op(const Op& op, const ggml_compute_params * params, ggml_tensor * dst) { +static void apply_unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op) { const ggml_tensor * src0 = dst->src[0]; GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst)); @@ -104,25 +184,25 @@ static void apply_unary_op(const Op& op, const ggml_compute_params * params, ggm dst_t * dst_ptr = (dst_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - vec_unary_op(op, ne0, dst_ptr, src0_ptr); + vec_unary_op_functor(ne0, dst_ptr, src0_ptr, op); } } -// TODO: Use the 'traits' lookup table (for type conversion fns), instead of a mass of 'if' conditions with long templates +// Generic dispatcher for functors template -static void unary_op(const Op& op, const ggml_compute_params * params, ggml_tensor * dst) { +static void unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op) { const ggml_tensor * src0 = dst->src[0]; /* */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32 - apply_unary_op(op, params, dst); + apply_unary_op_functor(params, dst, op); } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16 - apply_unary_op(op, params, dst); + apply_unary_op_functor(params, dst, op); } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16 - apply_unary_op(op, params, dst); + apply_unary_op_functor(params, dst, op); } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) { - apply_unary_op(op, params, dst); + apply_unary_op_functor(params, dst, op); } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { - apply_unary_op(op, params, dst); + apply_unary_op_functor(params, dst, op); } else { fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__, ggml_type_name(dst->type), ggml_type_name(src0->type)); @@ -131,73 +211,72 @@ static void unary_op(const Op& op, const ggml_compute_params * params, ggml_tens } void ggml_compute_forward_abs(const ggml_compute_params * params, ggml_tensor * dst) { - unary_op(op_abs, params, dst); + unary_op(params, dst); } void ggml_compute_forward_sgn(const ggml_compute_params * params, ggml_tensor * dst) { - unary_op(op_sgn, params, dst); + unary_op(params, dst); } void ggml_compute_forward_neg(const ggml_compute_params * params, ggml_tensor * dst) { - unary_op(op_neg, params, dst); + unary_op(params, dst); } void ggml_compute_forward_step(const ggml_compute_params * params, ggml_tensor * dst) { - unary_op(op_step, params, dst); + unary_op(params, dst); } void 
ggml_compute_forward_tanh(const ggml_compute_params * params, ggml_tensor * dst) { - unary_op(op_tanh, params, dst); + unary_op(params, dst); } void ggml_compute_forward_elu(const ggml_compute_params * params, ggml_tensor * dst) { - unary_op(op_elu, params, dst); + unary_op(params, dst); } void ggml_compute_forward_relu(const ggml_compute_params * params, ggml_tensor * dst) { - unary_op(op_relu, params, dst); + unary_op(params, dst); } void ggml_compute_forward_sigmoid(const ggml_compute_params * params, ggml_tensor * dst) { - unary_op(op_sigmoid, params, dst); + unary_op(params, dst); } void ggml_compute_forward_hardsigmoid(const ggml_compute_params * params, ggml_tensor * dst) { - unary_op(op_hardsigmoid, params, dst); + unary_op(params, dst); } void ggml_compute_forward_exp(const ggml_compute_params * params, ggml_tensor * dst) { - unary_op(op_exp, params, dst); + unary_op(params, dst); } void ggml_compute_forward_hardswish(const ggml_compute_params * params, ggml_tensor * dst) { - unary_op(op_hardswish, params, dst); + unary_op(params, dst); } void ggml_compute_forward_sqr(const ggml_compute_params * params, ggml_tensor * dst) { - unary_op(op_sqr, params, dst); + unary_op(params, dst); } void ggml_compute_forward_sqrt(const ggml_compute_params * params, ggml_tensor * dst) { - unary_op(op_sqrt, params, dst); + unary_op(params, dst); } void ggml_compute_forward_sin(const ggml_compute_params * params, ggml_tensor * dst) { - unary_op(op_sin, params, dst); + unary_op(params, dst); } void ggml_compute_forward_cos(const ggml_compute_params * params, ggml_tensor * dst) { - unary_op(op_cos, params, dst); + unary_op(params, dst); } void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor * dst) { - unary_op(op_log, params, dst); + unary_op(params, dst); } void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) { - // Get the XIELU parameters from the operation - float alpha_n = ggml_get_op_params_f32(dst, 1); - float alpha_p = ggml_get_op_params_f32(dst, 2); + const float alpha_n = ggml_get_op_params_f32(dst, 1); + const float alpha_p = ggml_get_op_params_f32(dst, 2); const float beta = ggml_get_op_params_f32(dst, 3); const float eps = ggml_get_op_params_f32(dst, 4); @@ -205,6 +284,6 @@ void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor return op_xielu(f, alpha_n, alpha_p, beta, eps); }; - unary_op(xielu_op_params, params, dst); + unary_op_functor(params, dst, xielu_op_params); } From 5a02bd486ba7c6c458ddf06e90c9bebb5a3ed6ca Mon Sep 17 00:00:00 2001 From: "Piotr Wilkin (ilintar)" Date: Thu, 25 Sep 2025 21:27:41 +0200 Subject: [PATCH 29/32] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- convert_hf_to_gguf.py | 2 +- gguf-py/gguf/constants.py | 12 ++++------ gguf-py/gguf/gguf_writer.py | 16 ++++++------- gguf-py/gguf/tensor_mapping.py | 44 ---------------------------------- src/llama-model.cpp | 27 ++++++++++----------- 5 files changed, 26 insertions(+), 75 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index c64c8a8166ef3..911400d1b82f2 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -8772,7 +8772,7 @@ class ApertusModel(LlamaModel): def modify_tensors(self, data_torch, name, bid): # Handle xIELU activation parameters - n_layers = self.hparams.get("num_hidden_layers") + n_layers = self.hparams["num_hidden_layers"] if name.endswith(".act_fn.alpha_n"): 
self._alpha_n[bid] = data_torch.to("cpu").float().item() if (len(self._alpha_n) == n_layers): diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index dfa3e8d841b27..e0ba8d0b6c5c6 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -295,10 +295,10 @@ class Diffusion: SHIFT_LOGITS = "diffusion.shift_logits" class xIELU: - XIELU_ALPHA_P = "xielu.alpha_p" - XIELU_ALPHA_N = "xielu.alpha_n" - XIELU_BETA = "xielu.beta" - XIELU_EPS = "xielu.eps" + ALPHA_P = "xielu.alpha_p" + ALPHA_N = "xielu.alpha_n" + BETA = "xielu.beta" + EPS = "xielu.eps" # @@ -458,10 +458,6 @@ class MODEL_TENSOR(IntEnum): FFN_GATE_SHEXP = auto() FFN_DOWN_SHEXP = auto() FFN_UP_SHEXP = auto() - FFN_ACT_ALPHA_N = auto() - FFN_ACT_ALPHA_P = auto() - FFN_ACT_BETA = auto() - FFN_ACT_EPS = auto() FFN_EXP_PROBS_B = auto() ATTN_Q_NORM = auto() ATTN_K_NORM = auto() diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 1a18bfbf75604..724c21d2c14db 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -1075,17 +1075,17 @@ def add_audio_num_mel_bins(self, value: int) -> None: def add_audio_stack_factor(self, value: int) -> None: self.add_uint32(Keys.ClipAudio.Projector.STACK_FACTOR, value) - def add_xielu_alpha_p(self, value: Sequence[float]): - self.add_array(Keys.xIELU.XIELU_ALPHA_P, value) + def add_xielu_alpha_p(self, values: Sequence[float]): + self.add_array(Keys.xIELU.ALPHA_P, values) - def add_xielu_alpha_n(self, value: Sequence[float]): - self.add_array(Keys.xIELU.XIELU_ALPHA_N, value) + def add_xielu_alpha_n(self, values: Sequence[float]): + self.add_array(Keys.xIELU.ALPHA_N, values) - def add_xielu_beta(self, value: Sequence[float]): - self.add_array(Keys.xIELU.XIELU_BETA, value) + def add_xielu_beta(self, values: Sequence[float]): + self.add_array(Keys.xIELU.BETA, values) - def add_xielu_eps(self, value: Sequence[float]): - self.add_array(Keys.xIELU.XIELU_EPS, value) + def add_xielu_eps(self, values: Sequence[float]): + self.add_array(Keys.xIELU.EPS, values) # diffusion models diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index df264f1739c44..8fd9e454e0e81 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -451,22 +451,6 @@ class TensorNameMap: "layers.{bid}.mlp.gate_proj", # qwen3-embedding ), - MODEL_TENSOR.FFN_ACT_ALPHA_N: ( - "model.layers.{bid}.mlp.act_fn.alpha_n", # apertus xIELU - ), - - MODEL_TENSOR.FFN_ACT_ALPHA_P: ( - "model.layers.{bid}.mlp.act_fn.alpha_p", # apertus xIELU - ), - - MODEL_TENSOR.FFN_ACT_BETA: ( - "model.layers.{bid}.mlp.act_fn.beta", # apertus xIELU - ), - - MODEL_TENSOR.FFN_ACT_EPS: ( - "model.layers.{bid}.mlp.act_fn.eps", # apertus xIELU - ), - MODEL_TENSOR.FFN_GATE_EXP: ( "layers.{bid}.feed_forward.experts.w1", # mixtral (merged) "transformer.decoder_layer.{bid}.moe.linear", # Grok (merged) @@ -1491,34 +1475,6 @@ class TensorNameMap: "model.layers.{bid}.post_attention_layernorm", ), }, - MODEL_ARCH.APERTUS: { - MODEL_TENSOR.ATTN_NORM: ( - "model.layers.{bid}.attention_layernorm", - ), - MODEL_TENSOR.ATTN_Q_NORM: ( - "model.layers.{bid}.attention.query_layernorm", - "model.layers.{bid}.self_attn.q_norm", - ), - MODEL_TENSOR.ATTN_K_NORM: ( - "model.layers.{bid}.attention.key_layernorm", - "model.layers.{bid}.self_attn.k_norm", - ), - MODEL_TENSOR.FFN_NORM: ( - "model.layers.{bid}.feedforward_layernorm", - ), - MODEL_TENSOR.FFN_ACT_ALPHA_N: ( - "model.layers.{bid}.mlp.act_fn.alpha_n", - ), - MODEL_TENSOR.FFN_ACT_ALPHA_P: ( - 
"model.layers.{bid}.mlp.act_fn.alpha_p", - ), - MODEL_TENSOR.FFN_ACT_BETA: ( - "model.layers.{bid}.mlp.act_fn.beta", - ), - MODEL_TENSOR.FFN_ACT_EPS: ( - "model.layers.{bid}.mlp.act_fn.eps", - ), - }, } mapping: dict[str, tuple[MODEL_TENSOR, str]] diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 23e07cec3f896..a07288cb91007 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -513,10 +513,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0); std::fill(hparams.swa_layers.begin(), hparams.swa_layers.end(), 0); - std::fill(hparams.xielu_alpha_n.begin(), hparams.xielu_alpha_n.end(), 0); - std::fill(hparams.xielu_alpha_p.begin(), hparams.xielu_alpha_p.end(), 0); - std::fill(hparams.xielu_beta.begin(), hparams.xielu_beta.end(), 0); - std::fill(hparams.xielu_eps.begin(), hparams.xielu_eps.end(), 0); + std::fill(hparams.xielu_alpha_n.begin(), hparams.xielu_alpha_n.end(), 0.0f); + std::fill(hparams.xielu_alpha_p.begin(), hparams.xielu_alpha_p.end(), 0.0f); + std::fill(hparams.xielu_beta.begin(), hparams.xielu_beta.end(), 0.0f); + std::fill(hparams.xielu_eps.begin(), hparams.xielu_eps.end(), 0.0f); ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false); ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false); @@ -2014,10 +2014,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_APERTUS: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_N, hparams.xielu_alpha_n, hparams.n_layer); - ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_P, hparams.xielu_alpha_p, hparams.n_layer); - ml.get_key_or_arr(LLM_KV_XIELU_BETA, hparams.xielu_beta, hparams.n_layer); - ml.get_key_or_arr(LLM_KV_XIELU_EPS, hparams.xielu_eps, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_N, hparams.xielu_alpha_n, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_P, hparams.xielu_alpha_p, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_XIELU_BETA, hparams.xielu_beta, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_XIELU_EPS, hparams.xielu_eps, hparams.n_layer); switch (hparams.n_layer) { case 32: type = LLM_TYPE_8B; break; @@ -5858,7 +5858,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0); for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -5866,11 +5866,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { - layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } - else { - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? 
TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); } layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); From 90f052d363a3fc10fe2362f6d89a90bdad2571d2 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Thu, 25 Sep 2025 21:32:00 +0200 Subject: [PATCH 30/32] Add tensor mappings for Apertus to global list instead --- gguf-py/gguf/tensor_mapping.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 8fd9e454e0e81..8cfe7bac80a81 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -148,6 +148,7 @@ class TensorNameMap: "model.layers.{bid}.operator_norm", # lfm2 "model.transformer.blocks.{bid}.attn_norm", # llada "layers.{bid}.input_layernorm", # qwen3-embedding + "model.layers.{bid}.attention_layernorm" # apertus ), # Attention norm 2 @@ -325,6 +326,7 @@ class TensorNameMap: "model.layers.layers.{bid}.pre_mlp_norm", # plamo2 "model.transformer.blocks.{bid}.ff_norm", # llada "layers.{bid}.post_attention_layernorm", # qwen3-embedding + "model.layers.{bid}.feedforward_layernorm", # apertus ), # Post feed-forward norm @@ -535,6 +537,7 @@ class TensorNameMap: "transformer.layers.{bid}.attn.q_norm", # openelm "model.layers.layers.{bid}.mixer.q", # plamo2 "layers.{bid}.self_attn.q_norm", # qwen3-embedding + "model.layers.{bid}.attention.query_layernorm", # apertus ), MODEL_TENSOR.ATTN_K_NORM: ( @@ -548,6 +551,7 @@ class TensorNameMap: "transformer.layers.{bid}.attn.k_norm", # openelm "model.layers.layers.{bid}.mixer.k", # plamo2 "layers.{bid}.self_attn.k_norm", # qwen3-embedding + "model.layers.{bid}.attention.key_layernorm", # apertus ), MODEL_TENSOR.ROPE_FREQS: ( From 6972404f7efa9dca7f114a90a3442be427dedc29 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Thu, 25 Sep 2025 21:42:17 +0200 Subject: [PATCH 31/32] Fix lazy on scalars --- convert_hf_to_gguf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 911400d1b82f2..c8532782fc916 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -8964,7 +8964,7 @@ def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: tuple[int, ...]) - def from_safetensors_slice(cls, st_slice: Any) -> Tensor: dtype = cls._dtype_str_map[st_slice.get_dtype()] shape: tuple[int, ...] = tuple(st_slice.get_shape()) - lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:]) + lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[...] 
if len(s.get_shape()) == 0 else s[:])
         return cast(torch.Tensor, lazy)
 
     @classmethod

From 2ca353bedd8e3a77d00a519e54055be3e823cd4a Mon Sep 17 00:00:00 2001
From: "Piotr Wilkin (ilintar)"
Date: Thu, 25 Sep 2025 22:41:07 +0200
Subject: [PATCH 32/32] Update ggml/src/ggml-cuda/unary.cu
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Johannes Gäßler

---
 ggml/src/ggml-cuda/unary.cu | 9 +--------
 1 file changed, 1 insertion(+), 8 deletions(-)

diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu
index 17a946bc8715d..3c564566a51ff 100644
--- a/ggml/src/ggml-cuda/unary.cu
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -394,14 +394,7 @@ static __global__ void xielu_kernel(const T * x, T * dst, const int k, float alp
     const float y_neg = (expm1f(min_v_eps) - xi) * alpha_n + beta * xi;
 
     const float out = gate_pos * y_pos + (1.0f - gate_pos) * y_neg;
-#if __CUDA_ARCH__ >= 530
-    if constexpr (std::is_same<T, half>::value) {
-        dst[i] = __float2half(out);
-    } else
-#endif
-    {
-        dst[i] = (T)out;
-    }
+    dst[i] = ggml_cuda_cast<T>(out);
 }
 
 template
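
For reference, the xIELU activation that the patches above implement on CPU and CUDA follows the formula documented in ggml.h: for x > 0 it evaluates alpha_p * x^2 + beta * x, otherwise (expm1(min(x, eps)) - x) * alpha_n + beta * x. The snippet below is a minimal, self-contained host-side C++ sketch of that formula; it is not part of the patch series, and the name xielu_ref as well as the constant values are illustrative only, since real models carry per-layer alpha_n/alpha_p/beta/eps values in the GGUF metadata.

// Reference sketch (not part of the patch): host-side xIELU matching the formula above.
// The parameter values in main() are made up for illustration only.
#include <cmath>
#include <cstdio>
#include <initializer_list>

static float xielu_ref(float x, float alpha_n, float alpha_p, float beta, float eps) {
    const float gate_pos  = (x > 0.0f) ? 1.0f : 0.0f;                 // positive-branch gate
    const float y_pos     = alpha_p * x * x + beta * x;               // x > 0: quadratic branch
    const float min_x_eps = std::fmin(x, eps);                        // clamp the expm1 argument
    const float y_neg     = (std::expm1(min_x_eps) - x) * alpha_n + beta * x; // x <= 0 branch
    return gate_pos * y_pos + (1.0f - gate_pos) * y_neg;              // branchless select, as in the CUDA kernel
}

int main() {
    // Illustrative values; not taken from any real checkpoint.
    const float alpha_n = 0.8f, alpha_p = 0.8f, beta = 0.5f, eps = -1e-6f;
    for (float x : { -2.0f, -0.5f, 0.0f, 0.5f, 2.0f }) {
        std::printf("xielu(% .2f) = % .6f\n", x, xielu_ref(x, alpha_n, alpha_p, beta, eps));
    }
    return 0;
}

The branchless gate_pos select mirrors xielu_kernel above and is mathematically equivalent to an if/else on the sign of x.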