Commit 27fa5f3

Correct convolution state dimension calculations

1 parent e24c9df · commit 27fa5f3

2 files changed: 13 additions, 17 deletions

convert_hf_to_gguf.py

1 addition, 1 deletion

@@ -3759,7 +3759,7 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_ssm_state_size(self.find_hparam(["linear_key_head_dim"]))
         self.gguf_writer.add_ssm_group_count(self.find_hparam(["linear_num_key_heads"]))
         self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["linear_num_value_heads"]))
-        self.gguf_writer.add_ssm_inner_size(self.find_hparam(["hidden_size"]) * (self.find_hparam(["linear_num_value_heads"]) // self.find_hparam(["linear_num_key_heads"])))
+        self.gguf_writer.add_ssm_inner_size(self.find_hparam(['linear_value_head_dim']) * self.find_hparam(['linear_num_value_heads']))

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         if name.startswith("mtp"):
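The fix matters because the old expression, hidden_size * (num_v_heads // num_k_heads), only coincides with the total value-head width when hidden_size happens to equal linear_value_head_dim * linear_num_key_heads. A minimal sketch of the two formulas, with hypothetical hyperparameter values (not taken from any real checkpoint) chosen so they diverge:

```python
# Hypothetical hyperparameters, chosen only to illustrate the divergence.
hparams = {
    "hidden_size": 1024,
    "linear_value_head_dim": 128,
    "linear_num_key_heads": 16,
    "linear_num_value_heads": 32,
}

# Old formula: correct only when hidden_size == value_head_dim * num_key_heads.
old_inner = hparams["hidden_size"] * (
    hparams["linear_num_value_heads"] // hparams["linear_num_key_heads"]
)  # 1024 * 2 = 2048

# Corrected formula: the inner size is the total width of the value heads.
new_inner = hparams["linear_value_head_dim"] * hparams["linear_num_value_heads"]
# 128 * 32 = 4096

print(old_inner, new_inner)  # 2048 4096
```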

src/llama-model.cpp

12 additions, 16 deletions
@@ -19240,7 +19240,6 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         ggml_tensor * mixed_ba = build_lora_mm(model.layers[il].ssm_beta_alpha, cur);
         cb(mixed_ba, "linear_attn_mixed_ba", il);

-        // Reshape mixed_qkvz: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*head_k_dim + 2*head_v_dim*num_v_heads/num_k_heads]
         int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * num_v_heads / num_k_heads;
         ggml_tensor * mixed_qkvz_reshaped =
             ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_tokens, n_seqs);
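The deleted comment described the layout that qkvz_new_dim still encodes: per key head, one query and one key slice (head_k_dim each) plus the group's share of value and z slices (num_v_heads / num_k_heads of each, head_v_dim wide). A quick self-contained check of that packing arithmetic, using hypothetical head counts and dims:

```python
# Hypothetical head counts/dims, for illustration only.
head_k_dim, num_k_heads = 128, 16
head_v_dim, num_v_heads = 128, 32

# Per key head: q and k slices, plus this group's share of v and z slices.
qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * num_v_heads // num_k_heads

# Summed over all key heads, this must cover the full packed projection:
# every q and k head, plus every v and z head.
assert qkvz_new_dim * num_k_heads == (
    2 * head_k_dim * num_k_heads + 2 * head_v_dim * num_v_heads
)
```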
@@ -19327,23 +19326,20 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         // Build the convolution states tensor
         ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
         cb(conv_states, "conv_states", il);
+
+        // Combine query, key, value for convolution input
+        ggml_tensor * qkv_mixed = ggml_concat(ctx0, query, key, 1);
+        qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_reshaped, 1);
+
+        int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;

         // Calculate convolution kernel size
         const int64_t conv_kernel_size = model.layers[il].ssm_conv1d->ne[0];
-
-        // Calculate input dimensions for Qwen3Next
-        const int64_t input_dim = (head_k_dim * num_k_heads * 2) + (head_v_dim * num_v_heads);
-
-        // Reshape conv_states to [conv_kernel_size - 1, input_dim, n_seqs]
-        conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, input_dim, n_seqs);
+        conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state, n_seqs);
         cb(conv_states, "conv_states_reshaped", il);

-        // Combine query, key, value for convolution input
-        ggml_tensor * qkv_mixed = ggml_concat(ctx0, query, key, 1);
-        qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_reshaped, 1);
-
         // Reshape to [input_dim, n_seq_tokens, n_seqs] for concatenation
-        qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, input_dim, n_seq_tokens, n_seqs);
+        qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_dim, n_seq_tokens, n_seqs);
         cb(qkv_mixed, "qkv_mixed_for_conv", il);

         // Concatenate cached conv states with current input
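The reshape now sizes the cached conv state from the GGUF SSM parameters instead of recomputing input_dim locally; the two agree by construction, given how convert_hf_to_gguf.py above writes those parameters. A sketch of that identity, with the same kind of hypothetical dims:

```python
# Hypothetical dims (illustrative only).
head_k_dim, num_k_heads = 128, 16
head_v_dim, num_v_heads = 128, 32

# GGUF parameters, as written by convert_hf_to_gguf.py above:
ssm_d_state = head_k_dim                 # add_ssm_state_size
ssm_n_group = num_k_heads                # add_ssm_group_count
d_inner     = head_v_dim * num_v_heads   # add_ssm_inner_size (corrected)

# Width of the concatenated [query; key; value] convolution input:
qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads

# The reshape's second dimension matches qkv_dim exactly:
assert d_inner + 2 * ssm_n_group * ssm_d_state == qkv_dim
```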
@@ -19367,18 +19363,18 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         // Update convolution state cache
         // Extract the last (conv_kernel_size - 1) states from conv_input
         ggml_tensor * last_conv_states =
-            ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, input_dim, n_seqs, conv_input->nb[1],
+            ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, qkv_dim, n_seqs, conv_input->nb[1],
                          conv_input->nb[2], n_seq_tokens * conv_input->nb[0]);

         ggml_build_forward_expand(
             gf, ggml_cpy(ctx0, last_conv_states,
-                         ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * input_dim * n_seqs,
-                                      mctx_cur->get_head() * (conv_kernel_size - 1) * input_dim *
+                         ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * qkv_dim * n_seqs,
+                                      mctx_cur->get_head() * (conv_kernel_size - 1) * qkv_dim *
                                           ggml_element_size(conv_states_all))));
         cb(conv_states_all, "conv_states_updated", il);

         // Reshape conv_output back to proper dimensions
-        conv_output = ggml_reshape_4d(ctx0, conv_output, input_dim, n_seqs, n_seq_tokens, 1);
+        conv_output = ggml_reshape_4d(ctx0, conv_output, qkv_dim, n_seqs, n_seq_tokens, 1);
         cb(conv_output, "conv_output_reshaped", il);
         conv_output = ggml_permute(ctx0, conv_output, 0, 2, 1, 3);
         cb(conv_output, "conv_output_final", il);
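For the cache write-back, ggml_view_1d addresses conv_states_all as a flat buffer in which each sequence slot holds (conv_kernel_size - 1) * qkv_dim elements, and the head offset skips whole slots. A rough sketch of that offset arithmetic in element counts (the C++ code additionally scales the offset by ggml_element_size):

```python
# Continuing with hypothetical dims; kernel size and head index are illustrative.
qkv_dim = 128 * 16 * 2 + 128 * 32   # 8192, as in the sketch above
conv_kernel_size = 4
n_seqs = 2
head = 3  # first cache slot used by this batch (hypothetical)

elems_per_slot = (conv_kernel_size - 1) * qkv_dim

view_length = elems_per_slot * n_seqs   # elements copied by ggml_cpy
view_offset = head * elems_per_slot     # start of this batch's slots

print(view_offset, view_length)
```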
