
Commit ded67b9

mitmul and CISC authored
llama : parameter conversion and loading fixes for PLaMo2 variants (#16075)
* Fix to use hidden_size_per_head
* Fix num heads
* Fix array
* Fix loading weights
* Support old GGUF converted by the previous version of llama.cpp
* Update src/llama-model.cpp

  Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Move shared parameter definitions to the outside of loop
* Not calculating n_embd_head_k,v by n_embd / n_head

---------

Co-authored-by: Sigbjørn Skjæret <[email protected]>
1 parent 1fe4e38 commit ded67b9

File tree

3 files changed (+27, -15 lines)


convert_hf_to_gguf.py

Lines changed: 11 additions & 6 deletions

@@ -4250,7 +4250,8 @@ def set_gguf_parameters(self):
         # This logic matches modeling_plamo.py's is_mamba function
         mamba_step = hparams.get("mamba_step", 2)
         mamba_enabled = hparams.get("mamba_enabled", True)
-        mamba_layers = []
+        num_key_value_heads = []
+        num_attention_heads = []

         if mamba_enabled:
             for i in range(block_count):
@@ -4260,17 +4261,21 @@ def set_gguf_parameters(self):
                 else:
                     is_mamba = (i % mamba_step) != (mamba_step // 2)
                 if is_mamba:
-                    mamba_layers.append(0)
+                    num_key_value_heads.append(0)
+                    num_attention_heads.append(0)
                 else:
-                    mamba_layers.append(hparams.get("num_key_value_heads", 4))
+                    num_key_value_heads.append(hparams.get("num_key_value_heads", 4))
+                    num_attention_heads.append(hparams.get("num_attention_heads", 32))

-        if mamba_layers:
-            self.gguf_writer.add_head_count_kv(mamba_layers)
+        if num_key_value_heads and num_attention_heads:
+            self.gguf_writer.add_head_count_kv(num_key_value_heads)
+            self.gguf_writer.add_head_count(num_attention_heads)

         self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 2048))
         self.gguf_writer.add_embedding_length(hparams.get("hidden_size", 4096))
+        self.gguf_writer.add_key_length(hparams.get("hidden_size_per_head", 128))
+        self.gguf_writer.add_value_length(hparams.get("hidden_size_per_head", 128))
         self.gguf_writer.add_block_count(block_count)
-        self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 32))
         self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06))
         self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 10000))
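The key change in this file is that the converter now records a head count for every block, not just a KV head count: PLaMo2 interleaves Mamba and attention blocks, and the Mamba blocks get 0 for both counts. A minimal sketch of the resulting per-layer arrays, using invented values (`block_count = 8` plus the defaults from the hunk above) rather than any real PLaMo2 config:

```python
# Illustrative only: reproduces the layer-pattern logic from the hunk above
# with invented values (block_count = 8); real values come from hparams.
mamba_step  = 2
block_count = 8

num_attention_heads = []
num_key_value_heads = []
for i in range(block_count):
    # every mamba_step-th block (offset by mamba_step // 2) is an attention block
    is_mamba = (i % mamba_step) != (mamba_step // 2)
    if is_mamba:
        num_attention_heads.append(0)   # Mamba block: no attention heads
        num_key_value_heads.append(0)   # and no KV cache for this layer
    else:
        num_attention_heads.append(32)  # "num_attention_heads" default above
        num_key_value_heads.append(4)   # "num_key_value_heads" default above

print(num_attention_heads)  # [0, 32, 0, 32, 0, 32, 0, 32]
print(num_key_value_heads)  # [0, 4, 0, 4, 0, 4, 0, 4]
```

These two arrays are what `add_head_count()` and `add_head_count_kv()` now write as per-layer metadata, which is how the loader can tell attention blocks from Mamba blocks.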

src/llama-hparams.h

Lines changed: 1 addition & 1 deletion

@@ -42,7 +42,7 @@ struct llama_hparams {
     uint32_t n_embd;
     uint32_t n_embd_features = 0;
     uint32_t n_layer;
-    int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
+    int32_t  n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
     uint32_t n_rot;
     uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
     uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head

src/llama-model.cpp

Lines changed: 15 additions & 8 deletions

@@ -1084,7 +1084,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                        }
                        break;
                    default: type = LLM_TYPE_UNKNOWN;
-                }
+                }
+
+                // Load attention parameters
+                ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
+                ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
            } break;
        case LLM_ARCH_GPT2:
            {
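The `false` argument in the two `ml.get_key(...)` calls above marks the keys as optional, which is what keeps GGUF files produced by the previous converter loadable: if the key/value-length entries are absent, `n_embd_head_k` and `n_embd_head_v` simply keep the values they already had. A rough Python sketch of that lookup-with-fallback pattern (the `plamo2.attention.*` key names and the numeric values are assumptions for illustration, not taken from the commit):

```python
# Rough sketch of an optional-metadata lookup, mirroring the idea behind
# ml.get_key(..., /*required=*/false). Not llama.cpp code.
def get_key(metadata: dict, key: str, current, required: bool = True):
    if key in metadata:
        return metadata[key]
    if required:
        raise KeyError(f"missing required GGUF key: {key}")
    return current  # keep the previously set value when the key is absent

n_embd_head_k = 64  # illustrative pre-existing value

# New-style GGUF (written by the converter change above): key is present.
new_meta = {"plamo2.attention.key_length": 128}
# Old-style GGUF: key is absent, so the existing value is left untouched.
old_meta = {}

print(get_key(new_meta, "plamo2.attention.key_length", n_embd_head_k, required=False))  # 128
print(get_key(old_meta, "plamo2.attention.key_length", n_embd_head_k, required=False))  # 64
```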
@@ -3392,17 +3396,17 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                } break;
            case LLM_ARCH_PLAMO2:
                {
+                    // mamba parameters
                    const uint32_t d_conv = hparams.ssm_d_conv;
                    const uint32_t d_state = hparams.ssm_d_state;
                    const uint32_t num_heads = hparams.ssm_dt_rank;
                    const uint32_t intermediate_size = hparams.ssm_d_inner;
-                    const uint32_t head_dim = intermediate_size / num_heads;
-                    const uint32_t qk_dim = head_dim;
-                    const uint32_t v_dim = head_dim;
-                    const int64_t num_attention_heads = hparams.n_head();
-                    const int64_t q_num_heads = num_attention_heads;
                    const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));

+                    // attention parameters
+                    const uint32_t qk_dim = hparams.n_embd_head_k;
+                    const uint32_t v_dim = hparams.n_embd_head_v;
+
                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

                    // output
@@ -3436,6 +3440,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                            layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, i), {d_state}, 0);
                            layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, i), {d_state}, 0);
                        } else {
+                            const int64_t num_attention_heads = hparams.n_head(i);
+                            const int64_t q_num_heads = num_attention_heads;
                            const int64_t num_key_value_heads = hparams.n_head_kv(i);
                            const int64_t k_num_heads = num_key_value_heads;
                            const int64_t v_num_heads = num_key_value_heads;
@@ -3444,8 +3450,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                            const int64_t v_proj_dim = v_num_heads * v_dim;

                            layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, q_proj_dim + k_proj_dim + v_proj_dim}, 0);
-                            layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {head_dim, num_attention_heads}, 0);
-                            layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {head_dim, k_num_heads}, 0);
+                            layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {qk_dim, num_attention_heads}, 0);
+                            layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {qk_dim, k_num_heads}, 0);
                            layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {q_num_heads * v_dim, n_embd}, 0);
                        }

@@ -17611,6 +17617,7 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
        const int64_t n_embd_head_q = hparams.n_embd_head_k;
        const int64_t n_embd_head_k = hparams.n_embd_head_k;
        const int64_t n_embd_head_v = hparams.n_embd_head_v;
+       int32_t n_head = hparams.n_head(il);
        int32_t n_head_kv = hparams.n_head_kv(il);

        const int64_t q_offset = 0;
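Taken together, the `load_tensors` hunks decouple the attention head dimension from the Mamba branch: `qk_dim` and `v_dim` now come from `n_embd_head_k` / `n_embd_head_v` (i.e. the `hidden_size_per_head` value the converter stores as key/value length) instead of `ssm_d_inner / ssm_dt_rank`, and head counts are looked up per layer via `hparams.n_head(i)` / `hparams.n_head_kv(i)`. A small sketch of why that matters for the `attn_q_norm` shape, with all numbers invented rather than taken from a real PLaMo2 variant:

```python
# All numbers below are invented for illustration; they are not from a real
# PLaMo2 variant.
hidden_size_per_head = 128   # written to GGUF as attention key/value length
num_attention_heads  = 32
ssm_d_inner          = 4096  # Mamba intermediate size
ssm_dt_rank          = 64    # Mamba "head" count

# Old derivation: the attention head dim was inferred from the Mamba branch,
# which is only correct when ssm_d_inner / ssm_dt_rank happens to equal the
# real per-head hidden size.
old_head_dim = ssm_d_inner // ssm_dt_rank   # 64 for these invented numbers

# New derivation: the head dim is read directly from the GGUF metadata.
qk_dim = hidden_size_per_head               # 128, matching the checkpoint tensors

# Shape of attn_q_norm as created in load_tensors: {qk_dim, num_attention_heads}
print((old_head_dim, num_attention_heads))  # (64, 32)  -> wrong for such a variant
print((qk_dim, num_attention_heads))        # (128, 32) -> matches {qk_dim, n_head}
```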
