@@ -122,6 +122,7 @@ def _compute_attention(
         v,
         attention_mask,
         training=False,
+        cache_update_index=0,
     ):
         if self.query_head_dim_normalize:
             query_normalization = 1 / np.sqrt(self.head_dim)
@@ -152,29 +153,10 @@ def _compute_attention(
         )

         if self.use_sliding_window_attention:
-            all_ones = ops.ones_like(attention_mask)
-            if keras.config.backend() == "tensorflow":
-                import tensorflow as tf
-
-                sliding_window_size = ops.minimum(
-                    self.sliding_window_size - 1, q_len
-                )
-                sliding_window_size = ops.cast(
-                    sliding_window_size, dtype="int32"
-                )
-                sliding_mask = tf.linalg.band_part(
-                    all_ones, sliding_window_size - 1, sliding_window_size - 1
-                )
-                sliding_mask = ops.cast(sliding_mask, dtype="bool")
-                bool_attention_mask = ops.cast(attention_mask, dtype="bool")
-                attention_mask = tf.math.logical_and(
-                    sliding_mask, bool_attention_mask
-                )
-            else:
-                sliding_mask = ops.triu(
-                    all_ones, -1 * self.sliding_window_size + 1
-                ) * ops.tril(all_ones, self.sliding_window_size - 1)
-                attention_mask = sliding_mask * attention_mask
+            attention_mask = self._mask_sliding_window(
+                attention_mask,
+                cache_update_index=cache_update_index,
+            )

         attention_mask = attention_mask[:, None, None, :, :]
         orig_dtype = attention_logits.dtype
@@ -189,6 +171,32 @@ def _compute_attention(
         results = ops.einsum("bkgts,bskh->btkgh", attention_softmax, v)
         return ops.reshape(results, (b, q_len, self.num_query_heads, h))

+    def _mask_sliding_window(
+        self,
+        attention_mask,
+        cache_update_index=0,
+    ):
+        batch_size, query_len, key_len = ops.shape(attention_mask)
+        # Compute the sliding window for square attention.
+        all_ones = ops.ones((key_len, key_len), "bool")
+        if keras.config.backend() == "tensorflow":
+            # TODO: triu/tril has issues with dynamic shape on the tensorflow
+            # backend. We should fix, but use `band_part` for now.
+            import tensorflow as tf
+
+            band_size = ops.minimum(key_len, self.sliding_window_size - 1)
+            band_size = ops.cast(band_size, "int32")
+            sliding_mask = tf.linalg.band_part(all_ones, band_size, band_size)
+        else:
+            sliding_mask = ops.triu(
+                all_ones, -1 * self.sliding_window_size + 1
+            ) * ops.tril(all_ones, self.sliding_window_size - 1)
+        # Slice the window for short queries during generation.
+        start = (cache_update_index, 0)
+        sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len))
+        sliding_mask = ops.expand_dims(sliding_mask, 0)
+        return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool"))
+
     def call(
         self,
         x,
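
A minimal NumPy sketch of what the new `_mask_sliding_window` helper computes, for reference only (the function name `sliding_window_band` and the printed example are illustrative assumptions, not part of this commit): the square band keeps positions (i, j) with |i - j| < sliding_window_size, mirroring the triu/tril construction above, and during cached generation only the query rows starting at cache_update_index are sliced out.

import numpy as np

def sliding_window_band(query_len, key_len, window_size, cache_update_index=0):
    # Illustrative only: square band where |i - j| < window_size, then sliced
    # down to the rows covered by the current query during generation.
    all_ones = np.ones((key_len, key_len), dtype=bool)
    band = np.triu(all_ones, -window_size + 1) & np.tril(all_ones, window_size - 1)
    return band[cache_update_index : cache_update_index + query_len, :]

# Full prefill over 5 tokens with a window of 3.
print(sliding_window_band(5, 5, 3).astype(int))
# Single-token decode step writing into cache position 4.
print(sliding_window_band(1, 5, 3, cache_update_index=4).astype(int))
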
@@ -216,7 +224,12 @@ def call(
         value = self.value_dense(x)

         attention_vec = self._compute_attention(
-            query, key, value, attention_mask, training=training
+            query,
+            key,
+            value,
+            attention_mask,
+            training=training,
+            cache_update_index=cache_update_index,
         )

         # Wipe attn vec if there are no attended tokens.
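
On the TensorFlow backend the helper builds the same band with `tf.linalg.band_part`, since `triu`/`tril` have dynamic-shape issues there (see the TODO in the diff). A standalone sanity check of that equivalence, assuming NumPy and TensorFlow are available; sizes and variable names here are illustrative:

import numpy as np
import tensorflow as tf

key_len, window_size = 6, 3
ones = np.ones((key_len, key_len), dtype=np.float32)

# band_part keeps entries with -num_lower <= (j - i) <= num_upper.
via_band_part = tf.linalg.band_part(ones, window_size - 1, window_size - 1)
via_band_part = via_band_part.numpy().astype(bool)

# Equivalent triu/tril construction used on the other backends.
via_triu_tril = np.triu(ones, -window_size + 1).astype(bool) & np.tril(
    ones, window_size - 1
).astype(bool)

assert (via_band_part == via_triu_tril).all()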