@@ -52,10 +52,8 @@ struct ggml_tensor * ggml_delta_net(
     GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[3] == n_tokens);

     GGML_ASSERT(g->ne[0] == S_v && g->ne[1] == H_v && g->ne[3] == n_tokens && g->ne[2] == batch_size);
-
-    struct ggml_tensor * beta_sigmoid = ggml_sigmoid(ctx, beta);
-    report_tensor_size("beta_sigmoid", beta_sigmoid);
-
+
+    // Merge q, k, v into qkv
     struct ggml_tensor * mixed_qkv = ggml_concat(ctx, q, k, 1);
     report_tensor_size("mixed_qkv_qk", mixed_qkv);
     mixed_qkv = ggml_concat(ctx, mixed_qkv, v, 1);
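
Note on the merge step above: ggml_concat(ctx, a, b, dim) stacks two tensors along the given dimension, so chaining it twice packs the q, k and v projections into one buffer that a single convolution can process. A minimal, stand-alone sketch of the shape arithmetic (hypothetical sizes, not the model's real ones):

#include "ggml.h"
#include <stdio.h>

int main(void) {
    struct ggml_init_params params = { .mem_size = 16*1024*1024, .mem_buffer = NULL, .no_alloc = true };
    struct ggml_context * ctx = ggml_init(params);

    // hypothetical projection shapes: [head_dim, n_heads, 1, 1]
    struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 128, 16, 1, 1);
    struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 128, 16, 1, 1);
    struct ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 128, 32, 1, 1);

    // concatenate along dim 1, mirroring how mixed_qkv is built in the patch
    struct ggml_tensor * qkv = ggml_concat(ctx, ggml_concat(ctx, q, k, 1), v, 1);
    printf("qkv ne = [%lld, %lld, %lld, %lld]\n",
           (long long)qkv->ne[0], (long long)qkv->ne[1],
           (long long)qkv->ne[2], (long long)qkv->ne[3]);   // [128, 64, 1, 1]

    ggml_free(ctx);
    return 0;
}

The ggml_view_4d calls later in the patch re-split this packed layout into its q, k and v parts.
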
@@ -68,6 +66,7 @@ struct ggml_tensor * ggml_delta_net(
     struct ggml_tensor * mixed_qkv_padded = ggml_pad(ctx, mixed_qkv, conv_weight->ne[0] - 1, 0, 0, 0);
     report_tensor_size("mixed_qkv_padded", mixed_qkv_padded);

+    // Apply convolution
     struct ggml_tensor * conv_out = ggml_ssm_conv(ctx, mixed_qkv_padded, conv_weight);
     report_tensor_size("conv_out", conv_out);

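
A short aside on the padding above (a sketch, not part of the patch): as I read ggml_pad, it grows dimension i by the given p_i and fills the new slots with zeros, so padding dim 0 by conv_weight->ne[0] - 1 appears intended to give ggml_ssm_conv one window position of width conv_weight->ne[0] per original token, keeping the output length equal to the input length. Hypothetical sizes:

#include "ggml.h"
#include <stdio.h>

int main(void) {
    struct ggml_init_params params = { .mem_size = 16*1024*1024, .mem_buffer = NULL, .no_alloc = true };
    struct ggml_context * ctx = ggml_init(params);

    const int d_conv = 4;                                                        // hypothetical kernel width (conv_weight->ne[0])
    struct ggml_tensor * x  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 7, 64, 1);  // 7 tokens, 64 channels, 1 sequence
    struct ggml_tensor * xp = ggml_pad(ctx, x, d_conv - 1, 0, 0, 0);             // zero-pad dim 0 only

    printf("padded ne[0] = %lld (= tokens + d_conv - 1)\n", (long long)xp->ne[0]);  // 10
    ggml_free(ctx);
    return 0;
}
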
@@ -85,68 +84,36 @@ struct ggml_tensor * ggml_delta_net(
     conv_out = ggml_permute(ctx, conv_out, 0, 2, 1, 3);
     report_tensor_size("conv_out_transposed", conv_out);

-    struct ggml_tensor * q_conv = ggml_view_4d(ctx, conv_out,
-        S_k,                  // ne0
-        H_k,                  // ne1
-        conv_out->ne[1],      // ne2 = sequence length (1)
-        conv_out->ne[2],      // ne3 = batch (1)
-        H_k * sizeof(float),  // nb1 = stride along H_k
-        conv_out->nb[1],      // nb2 = stride along sequence dim
-        conv_out->nb[2],      // nb3 = stride along batch dim
-        0                     // offset in bytes
-    );
+    // Beta sigmoid
+    struct ggml_tensor * beta_sigmoid = ggml_sigmoid(ctx, beta);
+    report_tensor_size("beta_sigmoid", beta_sigmoid);
+
+    // Gate calculations are done elsewhere in llama-model.cpp
+
+    // Re-split the qkv tensors
+    struct ggml_tensor * q_conv = ggml_view_4d(ctx, conv_out, S_k, H_k, conv_out->ne[1], conv_out->ne[2],
+        H_k * sizeof(float), conv_out->nb[1], conv_out->nb[2], 0);
     report_tensor_size("q_conv_view", q_conv);

-    // k projection view
-    struct ggml_tensor * k_conv = ggml_view_4d(ctx, conv_out,
-        S_k,                  // ne0
-        H_k,                  // ne1
-        conv_out->ne[1],      // ne2
-        conv_out->ne[2],      // ne3
-        H_k * sizeof(float),  // nb1
-        conv_out->nb[1],      // nb2
-        conv_out->nb[2],      // nb3
-        S_k * H_k * sizeof(q->type)  // offset = skip q_out
-    );
+    struct ggml_tensor * k_conv = ggml_view_4d(ctx, conv_out, S_k, H_k, conv_out->ne[1], conv_out->ne[2],
+        H_k * sizeof(float), conv_out->nb[1], conv_out->nb[2], S_k * H_k * sizeof(q->type));
     report_tensor_size("k_conv_view", k_conv);

-    // v projection view
-    struct ggml_tensor * v_conv = ggml_view_4d(ctx, conv_out,
-        S_v,                  // ne0
-        H_v,                  // ne1
-        conv_out->ne[1],      // ne2
-        conv_out->ne[2],      // ne3
-        H_v * sizeof(float),  // nb1
-        conv_out->nb[1],      // nb2
-        conv_out->nb[2],      // nb3
-        (2 * S_k * H_k) * sizeof(q->type)  // offset = skip q_out + k_out
-    );
+    struct ggml_tensor * v_conv = ggml_view_4d(ctx, conv_out, S_v, H_v, conv_out->ne[1], conv_out->ne[2], H_v * sizeof(float),
+        conv_out->nb[1], conv_out->nb[2], (2 * S_k * H_k) * sizeof(q->type));
     report_tensor_size("v_conv_view", v_conv);

-    q_conv = ggml_permute(ctx, q_conv, 0, 2, 1, 3);
-    report_tensor_size("q_conv_permuted", q_conv);
-    k_conv = ggml_permute(ctx, k_conv, 0, 2, 1, 3);
-    report_tensor_size("k_conv_permuted", k_conv);
-    v_conv = ggml_permute(ctx, v_conv, 0, 2, 1, 3);
-    report_tensor_size("v_conv_permuted", v_conv);
-
-    q_conv = ggml_reshape_3d(ctx, ggml_cont(ctx, q_conv), S_k * H_k, batch_size, n_tokens);
-    report_tensor_size("q_conv_reshaped", q_conv);
-    k_conv = ggml_reshape_3d(ctx, ggml_cont(ctx, k_conv), S_k * H_k, batch_size, n_tokens);
-    report_tensor_size("k_conv_reshaped", k_conv);
-    v_conv = ggml_reshape_3d(ctx, ggml_cont(ctx, v_conv), S_v * H_v, batch_size, n_tokens);
-    report_tensor_size("v_conv_reshaped", v_conv);
-
     struct ggml_tensor * q_broadcast = q_conv;
     struct ggml_tensor * k_broadcast = k_conv;

+    // If the number of key heads and value heads differs, repeat to force the tensors into matching shapes
     if (H_k != H_v) {
         GGML_ASSERT(H_v % H_k == 0);
         int64_t repeat_factor = H_v / H_k;

-        q_broadcast = ggml_reshape_4d(ctx, q_conv, S_k, batch_size, H_k, n_tokens);
+        q_broadcast = ggml_cont_4d(ctx, q_conv, S_k, batch_size, H_k, n_tokens);
         report_tensor_size("q_broadcast_reshape1", q_broadcast);
-        k_broadcast = ggml_reshape_4d(ctx, k_conv, S_k, batch_size, H_k, n_tokens);
+        k_broadcast = ggml_cont_4d(ctx, k_conv, S_k, batch_size, H_k, n_tokens);
         report_tensor_size("k_broadcast_reshape1", k_broadcast);

         q_broadcast = ggml_repeat_4d(ctx, q_broadcast, S_k, batch_size * repeat_factor, H_k, n_tokens);
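
The re-split above works by reinterpreting conv_out's buffer: ggml_view_4d takes explicit byte strides (nb1..nb3) and a byte offset, so q, k and v are carved out of the same storage without a copy. Such views are generally not contiguous, which is presumably why the reshape calls that followed were replaced by ggml_cont_4d, which copies into a fresh contiguous tensor of the requested shape. A minimal sketch with hypothetical shapes (not the model's real sizes):

#include "ggml.h"
#include <stdio.h>

int main(void) {
    struct ggml_init_params params = { .mem_size = 16*1024*1024, .mem_buffer = NULL, .no_alloc = true };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t S = 4, H = 3;
    // two projections packed along dim 0, analogous to the concatenated conv output
    struct ggml_tensor * buf = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2 * S, H);

    // first slice: shape [S, H], reuse the packed row stride, offset 0
    struct ggml_tensor * a = ggml_view_4d(ctx, buf, S, H, 1, 1,
                                          buf->nb[1], buf->nb[2], buf->nb[3], 0);
    // second slice: same shape, byte offset skips the first S floats of every row
    struct ggml_tensor * b = ggml_view_4d(ctx, buf, S, H, 1, 1,
                                          buf->nb[1], buf->nb[2], buf->nb[3], S * sizeof(float));

    // b shares storage with buf and is not contiguous, so ggml_reshape_* would assert;
    // ggml_cont_4d materializes a contiguous copy with the new shape instead
    struct ggml_tensor * b_cont = ggml_cont_4d(ctx, b, S * H, 1, 1, 1);

    printf("a = [%lld, %lld], b_cont = [%lld]\n",
           (long long)a->ne[0], (long long)a->ne[1], (long long)b_cont->ne[0]);
    ggml_free(ctx);
    return 0;
}
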
@@ -160,24 +127,14 @@ struct ggml_tensor * ggml_delta_net(
         report_tensor_size("k_broadcast_reshape2", k_broadcast);
     }

-    struct ggml_tensor * v_reshape = ggml_reshape_4d(ctx, v_conv, S_v, H_v, n_tokens, batch_size);
+    struct ggml_tensor * v_reshape = ggml_cont_4d(ctx, v_conv, S_v, H_v, n_tokens, batch_size);
     report_tensor_size("v_reshape", v_reshape);
-    struct ggml_tensor * v_broadcast = ggml_repeat_4d(ctx, v_reshape, S_v, H_v, n_tokens, batch_size);
-    report_tensor_size("v_broadcast", v_broadcast);
-    struct ggml_tensor * g_reshape = g;
-    report_tensor_size("g_reshape", g_reshape);
-    q_broadcast = ggml_repeat_4d(ctx, q_broadcast, S_k, H_v, n_tokens, batch_size);
-    report_tensor_size("q_broadcast_final", q_broadcast);
-    k_broadcast = ggml_repeat_4d(ctx, k_broadcast, S_k, H_v, n_tokens, batch_size);
-    report_tensor_size("k_broadcast_final", k_broadcast);
-    struct ggml_tensor * beta_reshape = ggml_reshape_4d(ctx, beta_sigmoid, 1, H_v, n_tokens, batch_size);
-    report_tensor_size("beta_reshape", beta_reshape);
-    struct ggml_tensor * beta_broadcast = ggml_repeat_4d(ctx, beta_reshape, 1, H_v, n_tokens, batch_size);
+    struct ggml_tensor * beta_broadcast = ggml_cont_4d(ctx, beta, 1, H_v, n_tokens, batch_size);
     report_tensor_size("beta_broadcast", beta_broadcast);
     struct ggml_tensor * state_broadcast = ggml_cont(ctx, state);
     report_tensor_size("state_broadcast", state_broadcast);

-    return ggml_delta_net_op(ctx, q_broadcast, k_broadcast, v_broadcast, g_reshape, beta_broadcast, state_broadcast, use_qk_l2norm, scale);
+    return ggml_delta_net_op(ctx, q_broadcast, k_broadcast, v_reshape, g, beta_broadcast, state_broadcast, use_qk_l2norm, scale);
 }

 struct ggml_tensor * ggml_delta_net_op(
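
Several of the deleted ggml_repeat_4d calls above were no-ops (for example, v_reshape was repeated to [S_v, H_v, n_tokens, batch_size], its own shape), so the patch drops them and relies on the H_k -> H_v repeat kept inside the `if (H_k != H_v)` branch earlier in the function. For reference, a minimal sketch of how ggml_repeat_4d tiles a tensor to a larger, element-wise compatible shape (hypothetical sizes):

#include "ggml.h"
#include <stdio.h>

int main(void) {
    struct ggml_init_params params = { .mem_size = 16*1024*1024, .mem_buffer = NULL, .no_alloc = true };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t S_k = 128, H_k = 8, repeat_factor = 4;   // hypothetical: H_v = H_k * repeat_factor
    struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S_k, 1, H_k, 1);

    // each target dimension must be an integer multiple of the source dimension
    struct ggml_tensor * q_rep = ggml_repeat_4d(ctx, q, S_k, repeat_factor, H_k, 1);

    printf("q_rep ne = [%lld, %lld, %lld, %lld]\n",
           (long long)q_rep->ne[0], (long long)q_rep->ne[1],
           (long long)q_rep->ne[2], (long long)q_rep->ne[3]);   // [128, 4, 8, 1]
    ggml_free(ctx);
    return 0;
}
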
@@ -212,9 +169,10 @@ struct ggml_tensor * ggml_delta_net_op(
     const int64_t batch_size = q->ne[3];

     const int64_t S_v = v->ne[0];
-    const int64_t H_v = v->ne[1];
+    const int64_t H_v = v->ne[1];

     GGML_LOG_INFO("S_k = %ld, S_v = %ld, H_k = %ld, H_v = %ld\n", S_k, S_v, H_k, H_v);
+    GGML_ASSERT(H_k == H_v); // we broadcast the tensors in ggml_delta_net to guarantee this

     GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_v && k->ne[2] == n_tokens && k->ne[3] == batch_size);
     GGML_ASSERT(v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == batch_size);
@@ -289,71 +247,28 @@ struct ggml_tensor * ggml_delta_net_op(
     struct ggml_tensor * state_t = state_2d;
     report_tensor_size("state_t", state_t);

-    struct ggml_tensor * state_t_transposed = ggml_transpose(ctx, state_t);
+    struct ggml_tensor * state_t_transposed = ggml_cont(ctx, ggml_transpose(ctx, state_t));
     report_tensor_size("state_t_transposed", state_t_transposed);
-
+
     struct ggml_tensor * k_t_final_reshaped = ggml_reshape_4d(ctx, k_t_final, H_v, S_k, batch_size, 1);
     report_tensor_size("k_t_final_reshaped", k_t_final_reshaped);

-    struct ggml_tensor * kv_mem = ggml_mul_mat(ctx, k_t_final_reshaped, state_t_transposed);
+    struct ggml_tensor * kv_mem = ggml_mul_mat(ctx, state_t_transposed, k_t_final_reshaped);
     report_tensor_size("kv_mem", kv_mem);

     struct ggml_tensor * v_t_final = v_t_reshaped;
     struct ggml_tensor * beta_t_final = beta_t_reshaped;
-
-    if (H_k != H_v) {
-        struct ggml_tensor * v_t_4d = ggml_reshape_4d(ctx, v_t_reshaped, S_v, H_k, 1, batch_size);
-        struct ggml_tensor * v_t_repeated = ggml_repeat_4d(ctx, v_t_4d, S_v, H_v, 1, batch_size);
-        v_t_final = ggml_reshape_2d(ctx, v_t_repeated, S_v, H_v * batch_size);
-
-        struct ggml_tensor * beta_t_4d = ggml_reshape_4d(ctx, beta_t_reshaped, 1, H_k, 1, batch_size);
-        struct ggml_tensor * beta_t_repeated = ggml_repeat_4d(ctx, beta_t_4d, 1, H_v, 1, batch_size);
-        beta_t_final = ggml_reshape_2d(ctx, beta_t_repeated, 1, H_v * batch_size);
-    }
-
-    struct ggml_tensor * kv_mem_reshaped;
-    if (kv_mem->ne[0] == S_v && kv_mem->ne[1] == H_v * batch_size) {
-        kv_mem_reshaped = kv_mem;
-    } else if (kv_mem->ne[0] == S_v) {
-        kv_mem_reshaped = ggml_view_2d(ctx, kv_mem, S_v, H_v * batch_size, kv_mem->nb[1], 0);
-    } else {
-        report_tensor_size("kv_mem_before_reshape", kv_mem);
-        kv_mem_reshaped = ggml_reshape_2d(ctx, kv_mem, S_v, H_v * batch_size);
-    }
-    kv_mem_reshaped = ggml_cont(ctx, kv_mem_reshaped);
+
+    struct ggml_tensor * kv_mem_reshaped = ggml_transpose(ctx, kv_mem);
     report_tensor_size("kv_mem_reshaped", kv_mem_reshaped);
-
-    struct ggml_tensor * kv_mem_final;
-    if (kv_mem_reshaped->ne[0] == v_t_final->ne[0] && kv_mem_reshaped->ne[1] == v_t_final->ne[1]) {
-        kv_mem_final = kv_mem_reshaped;
-    } else {
-        kv_mem_final = ggml_repeat(ctx, kv_mem_reshaped, v_t_final);
-    }
-    report_tensor_size("kv_mem_final", kv_mem_final);
-
-    struct ggml_tensor * delta = ggml_mul(ctx, ggml_sub(ctx, v_t_final, kv_mem_final), beta_t_final);
+
+    struct ggml_tensor * delta = ggml_mul(ctx, ggml_sub(ctx, v_t_final, kv_mem_reshaped), beta_t_final);
     report_tensor_size("delta", delta);

     struct ggml_tensor * delta_reshaped = ggml_reshape_2d(ctx, delta, S_v, H_v * batch_size);
     report_tensor_size("delta_reshaped", delta_reshaped);
-
-    if (H_k == H_v) {
-        k_t_final = k_t_reshaped;
-    } else {
-        int64_t repeat_factor = H_v / H_k;
-        GGML_ASSERT(H_v % H_k == 0);
-
-        k_t_final = ggml_reshape_3d(ctx, k_t_reshaped, S_k, 1, H_k * batch_size);
-        report_tensor_size("k_t_final_reshape1", k_t_final);
-
-        k_t_final = ggml_repeat_4d(ctx, k_t_final, S_k, repeat_factor, H_k, batch_size);
-        report_tensor_size("k_t_final_repeat", k_t_final);
-
-        k_t_final = ggml_reshape_2d(ctx, k_t_final, S_k, H_v * batch_size);
-        report_tensor_size("k_t_final_reshape2", k_t_final);
-    }
-
-    k_t_final = ggml_cont(ctx, k_t_final);
+
+    k_t_final = ggml_cont(ctx, k_t_reshaped);
     report_tensor_size("k_t_final_cont", k_t_final);

     struct ggml_tensor * k_t_for_outer;
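
On the ggml_mul_mat swap above: in ggml, both operands of ggml_mul_mat share ne[0] as the reduced dimension and the result has shape [a->ne[1], b->ne[1], ...], so exchanging the arguments transposes the product, which is presumably why kv_mem is now transposed right afterwards. The added ggml_cont around ggml_transpose(ctx, state_t) likely exists to give the first operand the contiguous row layout the mat-mul kernels expect. A minimal sketch of the shape rule (hypothetical sizes):

#include "ggml.h"
#include <stdio.h>

int main(void) {
    struct ggml_init_params params = { .mem_size = 16*1024*1024, .mem_buffer = NULL, .no_alloc = true };
    struct ggml_context * ctx = ggml_init(params);

    // both operands must agree on ne[0]; the result is [a->ne[1], b->ne[1]]
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 5);   // 8 x 5
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 3);   // 8 x 3
    struct ggml_tensor * r = ggml_mul_mat(ctx, a, b);                        // -> 5 x 3

    printf("r ne = [%lld, %lld]\n", (long long)r->ne[0], (long long)r->ne[1]);
    // swapping the operands, ggml_mul_mat(ctx, b, a), would give [3, 5] instead
    ggml_free(ctx);
    return 0;
}
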