Merged
Commits (38)
64add82
First attempt
pwilkin Sep 7, 2025
86cfc18
No permute during convert (fixes qk tensors), proper norm application.
pwilkin Sep 8, 2025
8c762a3
RoPE = NeoX
pwilkin Sep 8, 2025
ffdfd1d
Coherence!
pwilkin Sep 8, 2025
74dcf89
Merge branch 'ggml-org:master' into apertus-implementation
pwilkin Sep 9, 2025
ab11d94
Migrate xielu params from tensors to hyperparameters
pwilkin Sep 9, 2025
1b18472
Simple CUDA kernel
pwilkin Sep 13, 2025
1606a3c
Revert stupid LLM refactorings
pwilkin Sep 14, 2025
eec384f
Chat template support
pwilkin Sep 14, 2025
b2a92d0
configchecker / flake8 errors
pwilkin Sep 14, 2025
ef9ef66
Reorder unary.cu
pwilkin Sep 15, 2025
d009194
I do conclude that LLMs are, in fact, stupid.
pwilkin Sep 15, 2025
73bd64f
Merge branch 'master' into apertus-implementation
pwilkin Sep 16, 2025
28f9086
Fix after merge
pwilkin Sep 16, 2025
2f68c03
Final newline
pwilkin Sep 16, 2025
4294dbf
Make xIELU an UNARY_OP
pwilkin Sep 17, 2025
b2aa4fb
Final newline
pwilkin Sep 17, 2025
86a239c
Correctly account for parameter shift
pwilkin Sep 17, 2025
bd19026
Argh.
pwilkin Sep 17, 2025
13cc3be
Update ggml/src/ggml-cpu/unary-ops.cpp
pwilkin Sep 18, 2025
40f2f80
Refactor: remove unused methods, inline and factorize softplus, add c…
pwilkin Sep 18, 2025
dbe0ccb
Merge branch 'master' into apertus-implementation
pwilkin Sep 23, 2025
0d36139
Revert CUDA changes, implement xIELU as a separate OP
pwilkin Sep 23, 2025
fc58d3b
Pesky newline
pwilkin Sep 23, 2025
db9eb29
Add float2half / half2float for F16 inputs/outputs
pwilkin Sep 23, 2025
dc1e4d5
CUDA variants, attempt 2
pwilkin Sep 24, 2025
6c843ce
Actually, attempt 3
pwilkin Sep 24, 2025
f11ab3c
Update ggml/src/ggml-cuda/unary.cu
pwilkin Sep 25, 2025
57e2263
Missing convert header
pwilkin Sep 25, 2025
62401d8
Proper formula and reference for xIELU in the comments.
pwilkin Sep 25, 2025
58e6e0f
Modify unary-ops.cpp to add the functor-based logic besides the templ…
pwilkin Sep 25, 2025
5a02bd4
Apply suggestions from code review
pwilkin Sep 25, 2025
90f052d
Add tensor mappings for Apertus to global list instead
pwilkin Sep 25, 2025
6972404
Fix lazy on scalars
pwilkin Sep 25, 2025
d29cb45
Merge remote-tracking branch 'pwilkin/master' into apertus-implementa…
pwilkin Sep 25, 2025
2ca353b
Update ggml/src/ggml-cuda/unary.cu
pwilkin Sep 25, 2025
66f37c0
Add comment about the constraints on positive/negative alpha
pwilkin Oct 2, 2025
78ac06b
Change `softplus` to `ggml_softplus`
pwilkin Oct 2, 2025
11 changes: 11 additions & 0 deletions convert_hf_to_gguf.py
@@ -8566,6 +8566,17 @@ def prepare_tensors(self):
raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register("ApertusForCausalLM")
class ApertusModel(LlamaModel):
model_arch = gguf.MODEL_ARCH.APERTUS

def modify_tensors(self, data_torch, name, bid):
# Handle xIELU activation parameters
if name.endswith(".act_fn.alpha_n") or name.endswith(".act_fn.alpha_p") or name.endswith(".act_fn.beta") or name.endswith(".act_fn.eps"):
return [(self.map_tensor_name(name), data_torch)]

return super().modify_tensors(data_torch, name, bid)

class MistralModel(LlamaModel):
model_arch = gguf.MODEL_ARCH.LLAMA
model_name = "Mistral"
11 changes: 11 additions & 0 deletions ggml/include/ggml.h
@@ -554,6 +554,7 @@ extern "C" {
GGML_OP_OPT_STEP_SGD,

GGML_OP_GLU,
GGML_OP_XIELU,
Member: Any reason not to put this along the rest of the unary ops like RELU, GELU, etc.?

pwilkin (Collaborator, Author): Not really :>

pwilkin (Collaborator, Author): @ggerganov Done.


GGML_OP_COUNT,
};
@@ -1148,6 +1149,16 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);

// xIELU activation function
    // x >  0: alpha_p * x^2 + beta * x
    // x <= 0: alpha_n * (expm1(min(x, eps)) - x) + beta * x
GGML_API struct ggml_tensor * ggml_xielu(
struct ggml_context * ctx,
struct ggml_tensor * a,
float alpha_n,
float alpha_p,
float beta,
float eps);

// gated linear unit ops
// A: n columns, r rows,
// result is n / 2 columns, r rows,
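For orientation, a minimal usage sketch of the new API follows (not part of this PR). It only builds a graph node; the alpha_n/alpha_p/beta/eps values are illustrative placeholders rather than Apertus defaults, and graph execution is omitted since it depends on the backend setup.

// Hypothetical sketch of calling ggml_xielu when building a graph (not from this PR).
#include "ggml.h"

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // small input with both negative and positive activations
    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    for (int i = 0; i < 8; ++i) {
        ((float *) x->data)[i] = (float) i - 4.0f;
    }

    // the four scalars end up in op_params via ggml_set_op_params (see ggml.c below);
    // the values here are placeholders, not Apertus defaults
    struct ggml_tensor * y = ggml_xielu(ctx, x, /*alpha_n=*/0.8f, /*alpha_p=*/0.8f,
                                                /*beta=*/0.5f, /*eps=*/-1e-6f);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);
    // running gf goes through the CPU/CUDA backend machinery and is omitted here

    ggml_free(ctx);
    return 0;
}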
5 changes: 5 additions & 0 deletions ggml/src/ggml-cpu/ggml-cpu.c
@@ -1971,6 +1971,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_unary(params, tensor);
} break;
case GGML_OP_XIELU:
{
ggml_compute_forward_xielu(params, tensor);
} break;
case GGML_OP_GLU:
{
ggml_compute_forward_glu(params, tensor);
@@ -2160,6 +2164,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
case GGML_OP_LEAKY_RELU:
case GGML_OP_XIELU:
{
n_tasks = 1;
} break;
89 changes: 61 additions & 28 deletions ggml/src/ggml-cpu/unary-ops.cpp
@@ -52,6 +52,15 @@ static inline float op_sqrt(float x) {
return sqrtf(x);
}

static inline float op_xielu(float x, float alpha_n, float alpha_p, float beta, float eps) {
if (x > 0.0f) {
return alpha_p * x * x + beta * x;
} else {
const float min_x_eps = fminf(x, eps);
return (expm1f(min_x_eps) - x) * alpha_n + beta * x;
}
}

static inline float op_sin(float x) {
return sinf(x);
}
@@ -64,8 +73,8 @@ static inline float op_log(float x) {
return logf(x);
}

template <float (*op)(float), typename src0_t, typename dst_t>
static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) {
template <typename Op, typename src0_t, typename dst_t>
static inline void vec_unary_op(const Op& op, int64_t n, dst_t * y, const src0_t * x) {
constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
constexpr auto f32_to_dst = type_conversion_table<dst_t >::from_f32;

@@ -74,8 +83,8 @@ static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) {
}
}

template <float (*op)(float), typename src0_t, typename dst_t>
static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
template <typename Op, typename src0_t, typename dst_t>
static void apply_unary_op(const Op& op, const ggml_compute_params * params, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];

GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst));
@@ -95,25 +104,25 @@ static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst
dst_t * dst_ptr = (dst_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);

vec_unary_op<op>(ne0, dst_ptr, src0_ptr);
vec_unary_op<decltype(op), src0_t, dst_t>(op, ne0, dst_ptr, src0_ptr);
}
}

// TODO: Use the 'traits' lookup table (for type conversion fns), instead of a mass of 'if' conditions with long templates
template <float (*op)(float)>
static void unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
template <typename Op>
static void unary_op(const Op& op, const ggml_compute_params * params, ggml_tensor * dst) {
ngxson (Collaborator), Sep 19, 2025: Instead of passing the function pointer here (which comes with the risk of performance degradation), you can simply change the op signature from float (*op)(float) to float (*op)(float, ggml_tensor *).

Having both the template and decltype(op) here is pretty much redundant, as they will all be compiled into a single version of the function, which defeats the reason we use a template in the first place.

Member: @ngxson Is this still a concern? I'm not sure if the current version has addressed your comment or not.

const ggml_tensor * src0 = dst->src[0];

/* */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32
apply_unary_op<op, float, float>(params, dst);
apply_unary_op<decltype(op), float, float>(op, params, dst);
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16
apply_unary_op<op, ggml_fp16_t, ggml_fp16_t>(params, dst);
apply_unary_op<decltype(op), ggml_fp16_t, ggml_fp16_t>(op, params, dst);
} else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
apply_unary_op<op, ggml_bf16_t, ggml_bf16_t>(params, dst);
apply_unary_op<decltype(op), ggml_bf16_t, ggml_bf16_t>(op, params, dst);
} else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {
apply_unary_op<op, ggml_bf16_t, float>(params, dst);
apply_unary_op<decltype(op), ggml_bf16_t, float>(op, params, dst);
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
apply_unary_op<op, ggml_fp16_t, float>(params, dst);
apply_unary_op<decltype(op), ggml_fp16_t, float>(op, params, dst);
} else {
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
ggml_type_name(dst->type), ggml_type_name(src0->type));
@@ -122,65 +131,89 @@ static void unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
}

void ggml_compute_forward_abs(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_abs>(params, dst);
unary_op(op_abs, params, dst);
}

void ggml_compute_forward_sgn(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_sgn>(params, dst);
unary_op(op_sgn, params, dst);
}

void ggml_compute_forward_neg(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_neg>(params, dst);
unary_op(op_neg, params, dst);
}

void ggml_compute_forward_step(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_step>(params, dst);
unary_op(op_step, params, dst);
}

void ggml_compute_forward_tanh(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_tanh>(params, dst);
unary_op(op_tanh, params, dst);
}

void ggml_compute_forward_elu(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_elu>(params, dst);
unary_op(op_elu, params, dst);
}

void ggml_compute_forward_relu(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_relu>(params, dst);
unary_op(op_relu, params, dst);
}

void ggml_compute_forward_sigmoid(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_sigmoid>(params, dst);
unary_op(op_sigmoid, params, dst);
}

void ggml_compute_forward_hardsigmoid(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_hardsigmoid>(params, dst);
unary_op(op_hardsigmoid, params, dst);
}

void ggml_compute_forward_exp(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_exp>(params, dst);
unary_op(op_exp, params, dst);
}

void ggml_compute_forward_hardswish(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_hardswish>(params, dst);
unary_op(op_hardswish, params, dst);
}

void ggml_compute_forward_sqr(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_sqr>(params, dst);
unary_op(op_sqr, params, dst);
}

void ggml_compute_forward_sqrt(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_sqrt>(params, dst);
unary_op(op_sqrt, params, dst);
}

void ggml_compute_forward_sin(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_sin>(params, dst);
unary_op(op_sin, params, dst);
}

void ggml_compute_forward_cos(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_cos>(params, dst);
unary_op(op_cos, params, dst);
}

void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_log>(params, dst);
unary_op(op_log, params, dst);
}

static float softplus(float input, float beta=1.0f, float threshold=20.0f) {
if (input * beta > threshold) return input;
return (1/beta) * logf(1 + expf(beta * input));
}

void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) {
// Get the XIELU parameters from the operation
const float * op_params = (const float*)dst->op_params;
float alpha_n = op_params[0];
float alpha_p = op_params[1];
const float beta = op_params[2];
const float eps = op_params[3];

// alpha_p = softplus(alpha_p);
// alpha_n = beta + softplus(alpha_n);

const auto xielu_op_params = [alpha_n, alpha_p, beta, eps](float f) {
return op_xielu(f, alpha_n, alpha_p, beta, eps);
};

unary_op(xielu_op_params, params, dst);
}
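
To make the branch behavior and the parameter mapping easier to eyeball, here is a small self-contained sketch (illustration only, not part of the PR); it mirrors op_xielu and the softplus reparameterization hinted at by the commented-out lines above, with placeholder raw values rather than Apertus defaults.

// Standalone illustration of the xIELU math in op_xielu above (not part of the PR).
#include <math.h>
#include <stdio.h>

// mirrors the softplus helper in this file (beta = 1, threshold = 20)
static float softplus_ref(float x) {
    return x > 20.0f ? x : logf(1.0f + expf(x));
}

// mirrors op_xielu above
static float xielu_ref(float x, float alpha_n, float alpha_p, float beta, float eps) {
    if (x > 0.0f) {
        return alpha_p * x * x + beta * x;               // quadratic positive branch
    }
    const float min_x_eps = fminf(x, eps);
    return (expm1f(min_x_eps) - x) * alpha_n + beta * x; // saturating negative branch
}

int main(void) {
    // raw parameter values are placeholders, not Apertus defaults
    const float beta = 0.5f, eps = -1e-6f;
    const float alpha_p = softplus_ref(0.8f);        // keeps alpha_p positive
    const float alpha_n = beta + softplus_ref(0.8f); // keeps alpha_n above beta, per the commented-out mapping

    for (int i = -4; i <= 4; i += 2) {
        const float x = (float) i;
        printf("xielu(% .1f) = % .6f\n", x, xielu_ref(x, alpha_n, alpha_p, beta, eps));
    }
    return 0;
}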

1 change: 1 addition & 0 deletions ggml/src/ggml-cpu/unary-ops.h
@@ -22,6 +22,7 @@ void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct
void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_xielu(const struct ggml_compute_params * params, struct ggml_tensor * dst);

#ifdef __cplusplus
}
27 changes: 24 additions & 3 deletions ggml/src/ggml.c
@@ -1017,9 +1017,10 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"OPT_STEP_SGD",

"GLU",
"XIELU",
};

static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 91");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -1121,9 +1122,10 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"sgd(x)",

"glu(x)",
"xielu(x)",
};

static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 91");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -1147,7 +1149,6 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {

static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");


static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
"REGLU",
"GEGLU",
@@ -2646,6 +2647,26 @@ struct ggml_tensor * ggml_silu(
return ggml_unary(ctx, a, GGML_UNARY_OP_SILU);
}

// ggml_xielu
struct ggml_tensor * ggml_xielu(
struct ggml_context * ctx,
struct ggml_tensor * a,
float alpha_n,
float alpha_p,
float beta,
float eps) {
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);

// Store the parameters as operation parameters
float params[] = { alpha_n, alpha_p, beta, eps };
ggml_set_op_params(result, params, sizeof(params));

result->op = GGML_OP_XIELU;
result->src[0] = a;

return result;
}

struct ggml_tensor * ggml_silu_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a) {
36 changes: 36 additions & 0 deletions gguf-py/gguf/constants.py
@@ -391,6 +391,7 @@ class MODEL_ARCH(IntEnum):
SMALLTHINKER = auto()
LLADA = auto()
SEED_OSS = auto()
APERTUS = auto()


class VISION_PROJECTOR_TYPE(IntEnum):
@@ -440,6 +441,10 @@ class MODEL_TENSOR(IntEnum):
FFN_GATE_SHEXP = auto()
FFN_DOWN_SHEXP = auto()
FFN_UP_SHEXP = auto()
FFN_ACT_ALPHA_N = auto()
FFN_ACT_ALPHA_P = auto()
FFN_ACT_BETA = auto()
FFN_ACT_EPS = auto()
FFN_EXP_PROBS_B = auto()
ATTN_Q_NORM = auto()
ATTN_K_NORM = auto()
@@ -727,6 +732,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.SMALLTHINKER: "smallthinker",
MODEL_ARCH.LLADA: "llada",
MODEL_ARCH.SEED_OSS: "seed_oss",
MODEL_ARCH.APERTUS: "apertus",
}

VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -773,12 +779,20 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_GATE_SHEXP: "blk.{bid}.ffn_gate_shexp",
MODEL_TENSOR.FFN_DOWN_SHEXP: "blk.{bid}.ffn_down_shexp",
MODEL_TENSOR.FFN_UP_SHEXP: "blk.{bid}.ffn_up_shexp",
MODEL_TENSOR.FFN_ACT_ALPHA_N: "blk.{bid}.ffn_act_alpha_n",
MODEL_TENSOR.FFN_ACT_ALPHA_P: "blk.{bid}.ffn_act_alpha_p",
MODEL_TENSOR.FFN_ACT_BETA: "blk.{bid}.ffn_act_beta",
MODEL_TENSOR.FFN_ACT_EPS: "blk.{bid}.ffn_act_eps",
MODEL_TENSOR.FFN_ACT: "blk.{bid}.ffn",
MODEL_TENSOR.FFN_NORM_EXP: "blk.{bid}.ffn_norm_exps",
MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps",
MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps",
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b",
MODEL_TENSOR.FFN_ACT_ALPHA_N: "blk.{bid}.ffn_act_alpha_n",
MODEL_TENSOR.FFN_ACT_ALPHA_P: "blk.{bid}.ffn_act_alpha_p",
MODEL_TENSOR.FFN_ACT_BETA: "blk.{bid}.ffn_act_beta",
MODEL_TENSOR.FFN_ACT_EPS: "blk.{bid}.ffn_act_eps",
MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: "per_layer_token_embd", # gemma3n
MODEL_TENSOR.PER_LAYER_MODEL_PROJ: "per_layer_model_proj", # gemma3n
@@ -2683,6 +2697,28 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
MODEL_ARCH.APERTUS: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_ACT_ALPHA_N,
MODEL_TENSOR.FFN_ACT_ALPHA_P,
MODEL_TENSOR.FFN_ACT_BETA,
MODEL_TENSOR.FFN_ACT_EPS,
],
# TODO
}
