
Commit 2b0673c

Cleanup ggml_delta_net
1 parent: 72c98b0

File tree: 1 file changed (+33, -118)

ggml/src/ggml-delta.c: 33 additions & 118 deletions
@@ -52,10 +52,8 @@ struct ggml_tensor * ggml_delta_net(
     GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[3] == n_tokens);
 
     GGML_ASSERT(g->ne[0] == S_v && g->ne[1] == H_v && g->ne[3] == n_tokens && g->ne[2] == batch_size);
-
-    struct ggml_tensor * beta_sigmoid = ggml_sigmoid(ctx, beta);
-    report_tensor_size("beta_sigmoid", beta_sigmoid);
-
+
+    // Merge q, k, v into qkv
     struct ggml_tensor * mixed_qkv = ggml_concat(ctx, q, k, 1);
     report_tensor_size("mixed_qkv_qk", mixed_qkv);
     mixed_qkv = ggml_concat(ctx, mixed_qkv, v, 1);
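The merge step exists so that one convolution pass can cover the query, key and value projections instead of three. In ggml, ggml_concat with dim = 1 stacks tensors along their second dimension while requiring the other dimensions to match; a minimal shape sketch (sizes are hypothetical, not taken from this commit):

    struct ggml_tensor * q   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 8);  // [64, 8]
    struct ggml_tensor * k   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 8);  // [64, 8]
    struct ggml_tensor * v   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 16); // [64, 16]
    struct ggml_tensor * qk  = ggml_concat(ctx, q,  k, 1);  // [64, 8 + 8]   = [64, 16]
    struct ggml_tensor * qkv = ggml_concat(ctx, qk, v, 1);  // [64, 16 + 16] = [64, 32]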
@@ -68,6 +66,7 @@ struct ggml_tensor * ggml_delta_net(
     struct ggml_tensor * mixed_qkv_padded = ggml_pad(ctx, mixed_qkv, conv_weight->ne[0] - 1, 0, 0, 0);
     report_tensor_size("mixed_qkv_padded", mixed_qkv_padded);
 
+    // Apply convolution
     struct ggml_tensor * conv_out = ggml_ssm_conv(ctx, mixed_qkv_padded, conv_weight);
     report_tensor_size("conv_out", conv_out);
 
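The pad amount compensates for the shrinkage of the convolution: for an input of length n and a kernel of width d_conv, ggml_ssm_conv produces n - (d_conv - 1) outputs, so padding by conv_weight->ne[0] - 1 keeps one output per token. With hypothetical sizes d_conv = 4 and n_tokens = 7, the padded length is 7 + 3 = 10 and the convolution returns 10 - 3 = 7 outputs.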
@@ -85,68 +84,36 @@ struct ggml_tensor * ggml_delta_net(
     conv_out = ggml_permute(ctx, conv_out, 0, 2, 1, 3);
     report_tensor_size("conv_out_transposed", conv_out);
 
-    struct ggml_tensor * q_conv = ggml_view_4d(ctx, conv_out,
-        S_k,                  // ne0
-        H_k,                  // ne1
-        conv_out->ne[1],      // ne2 = sequence length (1)
-        conv_out->ne[2],      // ne3 = batch (1)
-        H_k * sizeof(float),  // nb1 = stride along H_k
-        conv_out->nb[1],      // nb2 = stride along sequence dim
-        conv_out->nb[2],      // nb3 = stride along batch dim
-        0                     // offset in bytes
-    );
+    // Beta sigmoid
+    struct ggml_tensor * beta_sigmoid = ggml_sigmoid(ctx, beta);
+    report_tensor_size("beta_sigmoid", beta_sigmoid);
+
+    // Gate calculations are done elsewhere in llama-model.cpp
+
+    // Re-split the qkv tensors
+    struct ggml_tensor * q_conv = ggml_view_4d(ctx, conv_out, S_k, H_k, conv_out->ne[1], conv_out->ne[2],
+        H_k * sizeof(float), conv_out->nb[1], conv_out->nb[2], 0);
     report_tensor_size("q_conv_view", q_conv);
 
-    // k projection view
-    struct ggml_tensor * k_conv = ggml_view_4d(ctx, conv_out,
-        S_k,                  // ne0
-        H_k,                  // ne1
-        conv_out->ne[1],      // ne2
-        conv_out->ne[2],      // ne3
-        H_k * sizeof(float),  // nb1
-        conv_out->nb[1],      // nb2
-        conv_out->nb[2],      // nb3
-        S_k * H_k * sizeof(q->type) // offset = skip q_out
-    );
+    struct ggml_tensor * k_conv = ggml_view_4d(ctx, conv_out, S_k, H_k, conv_out->ne[1], conv_out->ne[2],
+        H_k * sizeof(float), conv_out->nb[1], conv_out->nb[2], S_k * H_k * sizeof(q->type));
     report_tensor_size("k_conv_view", k_conv);
 
-    // v projection view
-    struct ggml_tensor * v_conv = ggml_view_4d(ctx, conv_out,
-        S_v,                  // ne0
-        H_v,                  // ne1
-        conv_out->ne[1],      // ne2
-        conv_out->ne[2],      // ne3
-        H_v * sizeof(float),  // nb1
-        conv_out->nb[1],      // nb2
-        conv_out->nb[2],      // nb3
-        (2 * S_k * H_k) * sizeof(q->type) // offset = skip q_out + k_out
-    );
+    struct ggml_tensor * v_conv = ggml_view_4d(ctx, conv_out, S_v, H_v, conv_out->ne[1], conv_out->ne[2], H_v * sizeof(float),
+        conv_out->nb[1], conv_out->nb[2], (2 * S_k * H_k) * sizeof(q->type));
     report_tensor_size("v_conv_view", v_conv);
 
-    q_conv = ggml_permute(ctx, q_conv, 0, 2, 1, 3);
-    report_tensor_size("q_conv_permuted", q_conv);
-    k_conv = ggml_permute(ctx, k_conv, 0, 2, 1, 3);
-    report_tensor_size("k_conv_permuted", k_conv);
-    v_conv = ggml_permute(ctx, v_conv, 0, 2, 1, 3);
-    report_tensor_size("v_conv_permuted", v_conv);
-
-    q_conv = ggml_reshape_3d(ctx, ggml_cont(ctx, q_conv), S_k * H_k, batch_size, n_tokens);
-    report_tensor_size("q_conv_reshaped", q_conv);
-    k_conv = ggml_reshape_3d(ctx, ggml_cont(ctx, k_conv), S_k * H_k, batch_size, n_tokens);
-    report_tensor_size("k_conv_reshaped", k_conv);
-    v_conv = ggml_reshape_3d(ctx, ggml_cont(ctx, v_conv), S_v * H_v, batch_size, n_tokens);
-    report_tensor_size("v_conv_reshaped", v_conv);
-
     struct ggml_tensor * q_broadcast = q_conv;
     struct ggml_tensor * k_broadcast = k_conv;
 
+    // if head keys and value keys are different, repeat to force tensors into matching shapes
     if (H_k != H_v) {
         GGML_ASSERT(H_v % H_k == 0);
         int64_t repeat_factor = H_v / H_k;
 
-        q_broadcast = ggml_reshape_4d(ctx, q_conv, S_k, batch_size, H_k, n_tokens);
+        q_broadcast = ggml_cont_4d(ctx, q_conv, S_k, batch_size, H_k, n_tokens);
         report_tensor_size("q_broadcast_reshape1", q_broadcast);
-        k_broadcast = ggml_reshape_4d(ctx, k_conv, S_k, batch_size, H_k, n_tokens);
+        k_broadcast = ggml_cont_4d(ctx, k_conv, S_k, batch_size, H_k, n_tokens);
         report_tensor_size("k_broadcast_reshape1", k_broadcast);
 
         q_broadcast = ggml_repeat_4d(ctx, q_broadcast, S_k, batch_size * repeat_factor, H_k, n_tokens);
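The compacted ggml_view_4d calls re-split the fused tensor purely by byte offset: all three views alias conv_out's buffer, with q at the start, k one q-block later, and v after both. A minimal sketch of the offset arithmetic (sizes are hypothetical; the commit itself writes sizeof(q->type) where this sketch uses sizeof(float)):

    const int64_t S_k = 64, H_k = 8;
    const size_t q_off = 0;                                       // q starts the buffer
    const size_t k_off = (size_t)(S_k * H_k)     * sizeof(float); // skip the q block
    const size_t v_off = (size_t)(2 * S_k * H_k) * sizeof(float); // skip the q and k blocks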
@@ -160,24 +127,14 @@ struct ggml_tensor * ggml_delta_net(
         report_tensor_size("k_broadcast_reshape2", k_broadcast);
     }
 
-    struct ggml_tensor * v_reshape = ggml_reshape_4d(ctx, v_conv, S_v, H_v, n_tokens, batch_size);
+    struct ggml_tensor * v_reshape = ggml_cont_4d(ctx, v_conv, S_v, H_v, n_tokens, batch_size);
     report_tensor_size("v_reshape", v_reshape);
-    struct ggml_tensor * v_broadcast = ggml_repeat_4d(ctx, v_reshape, S_v, H_v, n_tokens, batch_size);
-    report_tensor_size("v_broadcast", v_broadcast);
-    struct ggml_tensor * g_reshape = g;
-    report_tensor_size("g_reshape", g_reshape);
-    q_broadcast = ggml_repeat_4d(ctx, q_broadcast, S_k, H_v, n_tokens, batch_size);
-    report_tensor_size("q_broadcast_final", q_broadcast);
-    k_broadcast = ggml_repeat_4d(ctx, k_broadcast, S_k, H_v, n_tokens, batch_size);
-    report_tensor_size("k_broadcast_final", k_broadcast);
-    struct ggml_tensor * beta_reshape = ggml_reshape_4d(ctx, beta_sigmoid, 1, H_v, n_tokens, batch_size);
-    report_tensor_size("beta_reshape", beta_reshape);
-    struct ggml_tensor * beta_broadcast = ggml_repeat_4d(ctx, beta_reshape, 1, H_v, n_tokens, batch_size);
+    struct ggml_tensor * beta_broadcast = ggml_cont_4d(ctx, beta, 1, H_v, n_tokens, batch_size);
     report_tensor_size("beta_broadcast", beta_broadcast);
     struct ggml_tensor * state_broadcast = ggml_cont(ctx, state);
     report_tensor_size("state_broadcast", state_broadcast);
 
-    return ggml_delta_net_op(ctx, q_broadcast, k_broadcast, v_broadcast, g_reshape, beta_broadcast, state_broadcast, use_qk_l2norm, scale);
+    return ggml_delta_net_op(ctx, q_broadcast, k_broadcast, v_reshape, g, beta_broadcast, state_broadcast, use_qk_l2norm, scale);
 }
 
 struct ggml_tensor * ggml_delta_net_op(
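Several reshape-then-cont pairs collapse into single ggml_cont_4d calls here. ggml_reshape_4d requires a contiguous input, which the permuted views above are not, whereas ggml_cont_4d copies into a fresh contiguous tensor with the requested shape in one graph node. A minimal sketch of the equivalence (t and the ne* values are placeholders):

    // the two expressions build equivalent results; the first uses a single node
    struct ggml_tensor * a = ggml_cont_4d(ctx, t, ne0, ne1, ne2, ne3);
    struct ggml_tensor * b = ggml_reshape_4d(ctx, ggml_cont(ctx, t), ne0, ne1, ne2, ne3);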
@@ -212,9 +169,10 @@ struct ggml_tensor * ggml_delta_net_op(
     const int64_t batch_size = q->ne[3];
 
     const int64_t S_v = v->ne[0];
-    const int64_t H_v = v->ne[1];
+    const int64_t H_v = v->ne[1];
 
     GGML_LOG_INFO("S_k = %ld, S_v = %ld, H_k = %ld, H_v = %ld\n", S_k, S_v, H_k, H_v);
+    GGML_ASSERT(H_k == H_v); // we broadcasted the tensors in the main function to guarantee this
 
     GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_v && k->ne[2] == n_tokens && k->ne[3] == batch_size);
     GGML_ASSERT(v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == batch_size);
@@ -289,71 +247,28 @@ struct ggml_tensor * ggml_delta_net_op(
     struct ggml_tensor * state_t = state_2d;
     report_tensor_size("state_t", state_t);
 
-    struct ggml_tensor * state_t_transposed = ggml_transpose(ctx, state_t);
+    struct ggml_tensor * state_t_transposed = ggml_cont(ctx, ggml_transpose(ctx, state_t));
     report_tensor_size("state_t_transposed", state_t_transposed);
-
+
     struct ggml_tensor * k_t_final_reshaped = ggml_reshape_4d(ctx, k_t_final, H_v, S_k, batch_size, 1);
     report_tensor_size("k_t_final_reshaped", k_t_final_reshaped);
 
-    struct ggml_tensor * kv_mem = ggml_mul_mat(ctx, k_t_final_reshaped, state_t_transposed);
+    struct ggml_tensor * kv_mem = ggml_mul_mat(ctx, state_t_transposed, k_t_final_reshaped);
     report_tensor_size("kv_mem", kv_mem);
 
     struct ggml_tensor * v_t_final = v_t_reshaped;
     struct ggml_tensor * beta_t_final = beta_t_reshaped;
-
-    if (H_k != H_v) {
-        struct ggml_tensor * v_t_4d = ggml_reshape_4d(ctx, v_t_reshaped, S_v, H_k, 1, batch_size);
-        struct ggml_tensor * v_t_repeated = ggml_repeat_4d(ctx, v_t_4d, S_v, H_v, 1, batch_size);
-        v_t_final = ggml_reshape_2d(ctx, v_t_repeated, S_v, H_v * batch_size);
-
-        struct ggml_tensor * beta_t_4d = ggml_reshape_4d(ctx, beta_t_reshaped, 1, H_k, 1, batch_size);
-        struct ggml_tensor * beta_t_repeated = ggml_repeat_4d(ctx, beta_t_4d, 1, H_v, 1, batch_size);
-        beta_t_final = ggml_reshape_2d(ctx, beta_t_repeated, 1, H_v * batch_size);
-    }
-
-    struct ggml_tensor * kv_mem_reshaped;
-    if (kv_mem->ne[0] == S_v && kv_mem->ne[1] == H_v * batch_size) {
-        kv_mem_reshaped = kv_mem;
-    } else if (kv_mem->ne[0] == S_v) {
-        kv_mem_reshaped = ggml_view_2d(ctx, kv_mem, S_v, H_v * batch_size, kv_mem->nb[1], 0);
-    } else {
-        report_tensor_size("kv_mem_before_reshape", kv_mem);
-        kv_mem_reshaped = ggml_reshape_2d(ctx, kv_mem, S_v, H_v * batch_size);
-    }
-    kv_mem_reshaped = ggml_cont(ctx, kv_mem_reshaped);
+
+    struct ggml_tensor * kv_mem_reshaped = ggml_transpose(ctx, kv_mem);
     report_tensor_size("kv_mem_reshaped", kv_mem_reshaped);
-
-    struct ggml_tensor * kv_mem_final;
-    if (kv_mem_reshaped->ne[0] == v_t_final->ne[0] && kv_mem_reshaped->ne[1] == v_t_final->ne[1]) {
-        kv_mem_final = kv_mem_reshaped;
-    } else {
-        kv_mem_final = ggml_repeat(ctx, kv_mem_reshaped, v_t_final);
-    }
-    report_tensor_size("kv_mem_final", kv_mem_final);
-
-    struct ggml_tensor * delta = ggml_mul(ctx, ggml_sub(ctx, v_t_final, kv_mem_final), beta_t_final);
+
+    struct ggml_tensor * delta = ggml_mul(ctx, ggml_sub(ctx, v_t_final, kv_mem_reshaped), beta_t_final);
     report_tensor_size("delta", delta);
 
     struct ggml_tensor * delta_reshaped = ggml_reshape_2d(ctx, delta, S_v, H_v * batch_size);
     report_tensor_size("delta_reshaped", delta_reshaped);
-
-    if (H_k == H_v) {
-        k_t_final = k_t_reshaped;
-    } else {
-        int64_t repeat_factor = H_v / H_k;
-        GGML_ASSERT(H_v % H_k == 0);
-
-        k_t_final = ggml_reshape_3d(ctx, k_t_reshaped, S_k, 1, H_k * batch_size);
-        report_tensor_size("k_t_final_reshape1", k_t_final);
-
-        k_t_final = ggml_repeat_4d(ctx, k_t_final, S_k, repeat_factor, H_k, batch_size);
-        report_tensor_size("k_t_final_repeat", k_t_final);
-
-        k_t_final = ggml_reshape_2d(ctx, k_t_final, S_k, H_v * batch_size);
-        report_tensor_size("k_t_final_reshape2", k_t_final);
-    }
-
-    k_t_final = ggml_cont(ctx, k_t_final);
+
+    k_t_final = ggml_cont(ctx, k_t_reshaped);
     report_tensor_size("k_t_final_cont", k_t_final);
 
     struct ggml_tensor * k_t_for_outer;
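Two conventions help in reading this last hunk. First, ggml_mul_mat(ctx, a, b) contracts over dimension 0 of both operands and returns a tensor of shape [a->ne[1], b->ne[1], ...], so swapping the operands, as the kv_mem line now does, yields the transposed product; the single ggml_transpose that follows replaces the removed chain of view/reshape fallbacks. A minimal shape sketch (hypothetical sizes):

    struct ggml_tensor * a  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 16, 4); // [K=16, M=4]
    struct ggml_tensor * b  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 16, 6); // [K=16, N=6]
    struct ggml_tensor * ab = ggml_mul_mat(ctx, a, b); // [M, N] = [4, 6]
    struct ggml_tensor * ba = ggml_mul_mat(ctx, b, a); // [N, M] = [6, 4], the transpose of ab

Second, the delta computation follows the delta-rule write that DeltaNet-style layers are built on: with running state S, key k_t, value v_t and write strength beta_t, the update is S <- S + beta_t * (v_t - S k_t) k_t^T. Here kv_mem corresponds to S k_t, the value the state currently recalls for k_t, and delta to beta_t * (v_t - S k_t).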
