
Conversation

@mitmul (Contributor) commented Sep 18, 2025

I found some incorrect parameter usages when converting and loading the original PLaMo2 parameters into the llama.cpp model. This PR fixes them.

@CISC (Collaborator) commented Sep 18, 2025

This looks like it will break num_attention_heads on old GGUFs?

@github-actions bot added the python (python script changes) label Sep 18, 2025
@mitmul (Contributor, Author) commented Sep 18, 2025

> This looks like it will break num_attention_heads on old GGUFs?

You're right, thanks. I updated this PR in 10af12c to keep supporting GGUFs that were converted previously. I confirmed that the latest commit runs a GGUF (mmnga/plamo-2-translate-gguf) without any problems.

@CISC (Collaborator) left a comment

AFAICT there are no functional changes here besides binding n_embd_head_k/v to hidden_size_per_head instead of calculating it from hidden_size/num_attention_heads; n_head is not used in the mamba layers anyway, so what is the motivation for the PR?

Comment on lines 1080 to 1083

    // Load attention parameters
    ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
    ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
@CISC (Collaborator) commented:

Suggested change (the suggestion removes these lines):

    // Load attention parameters
    ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
    ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);

Already done here:

    // non-transformer models do not have attention heads
    if (hparams.n_head() > 0) {
        // gpt-neox n_rot = rotary_pct * (n_embd / n_head)
        // gpt-j n_rot = rotary_dim
        hparams.n_embd_head_k = hparams.n_embd / hparams.n_head();
        ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);

        hparams.n_embd_head_v = hparams.n_embd / hparams.n_head();
        ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);

        // sanity check for n_rot (optional)
        hparams.n_rot = hparams.n_embd_head_k;
        ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);

        if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON) {
            if (hparams.n_rot != hparams.n_embd_head_k) {
                throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
            }
        }
    } else {
        hparams.n_rot = 0;
        hparams.n_embd_head_k = 0;
        hparams.n_embd_head_v = 0;
    }

@mitmul (Contributor, Author) commented:

@CISC Ah, I applied your suggestion but noticed that the problem comes from the following two lines:

    hparams.n_embd_head_k = hparams.n_embd / hparams.n_head();
    hparams.n_embd_head_v = hparams.n_embd / hparams.n_head();

The main purpose of this PR is to support some PLaMo2 variants created by pruning larger models, which have n_embd_head_k larger than n_embd / n_head.

So let me roll back these changes to support such cases in those PLaMo2 variants.
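For context, here is a minimal sketch of the loading order being described: keep the old derivation as a fallback so previously converted GGUFs (which may lack the key/value-length metadata) still load, then let the explicit per-head size from the GGUF override it when present. This is an illustration only, not the exact PLaMo2 branch merged in this PR; the choice of layer index 1 as "some attention layer" is an assumption for the example.

    // Sketch only -- not the code merged in this PR.
    // 1) Fallback for old GGUFs: derive the head size from an attention layer's head count
    //    (layer index 1 is assumed to be an attention layer here; mamba layers report 0 heads).
    const uint32_t n_head_attn = hparams.n_head(1);
    if (n_head_attn > 0) {
        hparams.n_embd_head_k = hparams.n_embd / n_head_attn;
        hparams.n_embd_head_v = hparams.n_embd / n_head_attn;
    }
    // 2) Newer GGUFs store hidden_size_per_head explicitly; the trailing `false` marks the
    //    keys as optional, so old files simply keep the fallback computed above.
    ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH,   hparams.n_embd_head_k, false);
    ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);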

@CISC (Collaborator) commented:

> The main purpose of this PR is to support some PLaMo2 variants created by pruning larger models, which have n_embd_head_k larger than n_embd / n_head.
>
> So let me roll back these changes to support such cases in those PLaMo2 variants.

But they would have to have the metadata present to work then, so no issue.

@CISC (Collaborator) commented:

@mitmul gentle ping

@mitmul (Contributor, Author) commented:

@CISC Sorry for the late reply. Could you explain what you meant by this?

> But they would have to have the metadata present to work then, so no issue.

For example, pfnet/plamo-2.1-2b-cpt has

  • hidden_size = 2048 (for n_embd)
  • hidden_size_per_head = 128 (for n_embd_head_k and n_embd_head_v)
  • num_key_value_heads = 32 (for n_head_arr in the il % 2 = 0 layers, i.e. the attention layers, not the mamba layers),

so that n_embd_head_k/v != n_embd / n_head.
Then, with the current load_hparams() it goes through the else block here:

    } else {
        hparams.n_rot = 0;
        hparams.n_embd_head_k = 0;
        hparams.n_embd_head_v = 0;
    }

because hparams.n_head() will return 0, which comes from the first element of n_head_arr (a mamba layer).

So I thought we need to call the following functions to set hparams.n_embd_head_k/v to the right values, which come from hidden_size_per_head in config.json:

llama.cpp/src/llama-model.cpp, lines 1082 to 1083 at 1be2787:

    ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
    ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);

That's why I sent this PR, but if I've misunderstood something about model loading, please let me know, and I'd appreciate any advice on how we can load variants like pfnet/plamo-2.1-2b-cpt.
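To make the failure mode concrete, here is a small standalone illustration (not llama.cpp code; the alternation pattern, layer count, and head counts are assumptions for the example) of how a per-layer head-count array whose first entry belongs to a mamba layer sends the generic branch into zeroing the head sizes, while explicit metadata would give the right value:

    #include <array>
    #include <cstdint>
    #include <cstdio>

    int main() {
        // Assumed per-layer head counts for a hybrid model: mamba layers carry 0 heads.
        std::array<uint32_t, 4> n_head_arr = {0, 32, 0, 32};

        const uint32_t n_embd               = 2048; // hidden_size in config.json
        const uint32_t hidden_size_per_head = 128;  // what the GGUF metadata should provide

        // The generic branch keys off the head count of layer 0 only.
        const uint32_t n_head_layer0 = n_head_arr[0];

        uint32_t n_embd_head_k = 0;
        if (n_head_layer0 > 0) {
            n_embd_head_k = n_embd / n_head_layer0; // never reached: layer 0 is a mamba layer
        } else {
            n_embd_head_k = 0;                      // head size ends up zeroed
        }

        printf("generic branch:    n_embd_head_k = %u\n", n_embd_head_k);        // prints 0
        printf("explicit metadata: n_embd_head_k = %u\n", hidden_size_per_head); // prints 128
        return 0;
    }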

@CISC (Collaborator) commented:

Aha, that makes sense then, thank you for the explanation. :)

@mitmul (Contributor, Author) commented Sep 22, 2025

The core intention of this PR is to use n_embd_head_k/v as set by hidden_size_per_head in config.json, instead of calculating it as hidden_size / num_attention_heads, because some PLaMo2 variants use a larger n_embd_head_k/v than hidden_size / num_attention_heads.
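Putting concrete numbers on that (using the pfnet/plamo-2.1-2b-cpt values quoted earlier, and taking the 32 heads given there for the attention layers' n_head_arr):

    derived:  hidden_size / n_head  = 2048 / 32 = 64
    explicit: hidden_size_per_head  = 128

so deriving the head size would give half the real value, which is why the explicit hidden_size_per_head metadata has to take precedence.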

@CISC merged commit ded67b9 into ggml-org:master on Oct 1, 2025
62 of 72 checks passed