 import torch
 from torch.backends.cuda import (
-    can_use_flash_attention,
-    can_use_efficient_attention,
     can_use_cudnn_attention,
+    can_use_efficient_attention,
+    can_use_flash_attention,
 )
-from torch.nn.attention import SDPBackend, SDPAParams
+from torch.nn.attention import SDPAParams, SDPBackend


 def filter_sdpa_kernels(
@@ -20,9 +20,7 @@ def filter_sdpa_kernels(
     enable_gqa: bool,
     **kwargs,
 ) -> List[SDPBackend]:
-    params = SDPAParams(
-        query, key, value, attn_mask, dropout_p, is_causal, enable_gqa
-    )
+    params = SDPAParams(query, key, value, attn_mask, dropout_p, is_causal, enable_gqa)
     new_kernels = []
     for kernel in sdpa_kernels:
         if kernel == SDPBackend.FLASH_ATTENTION and not can_use_flash_attention(params):
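For context, a minimal usage sketch of the reformatted filter_sdpa_kernels helper follows. The argument order, tensor shapes, and device handling are assumptions inferred from the diff, not part of this commit; torch.nn.attention.sdpa_kernel and torch.nn.functional.scaled_dot_product_attention are standard PyTorch APIs.

# Hypothetical usage sketch; filter_sdpa_kernels is the helper reformatted above,
# and its positional argument order here is an assumption.
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import scaled_dot_product_attention

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
q = torch.randn(2, 8, 128, 64, device=device, dtype=dtype)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Start from a preference list and keep only the backends whose constraints
# (dtype, head size, mask, dropout, ...) are satisfied for these inputs.
preferred = [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
usable = filter_sdpa_kernels(
    preferred, q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True, enable_gqa=False
)

# Restrict SDPA to the surviving backends for this call.
with sdpa_kernel(usable):
    out = scaled_dot_product_attention(q, k, v, is_causal=True)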
@@ -103,7 +101,10 @@ def mask_cache_bool(
 ) -> torch.Tensor:
     # Usual causal mask:
     mask = torch.ones(
-        max_seq_length, max_seq_length, device=device, dtype=dtype,
+        max_seq_length,
+        max_seq_length,
+        device=device,
+        dtype=dtype,
     ).triu(diagonal=1)
     if sliding_window_size is not None:
         mask += torch.ones_like(mask).tril(diagonal=-sliding_window_size)
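The mask_cache_bool change is purely cosmetic; the mask it builds is unchanged. A small standalone illustration of the same triu/tril pattern (not code from the commit) shows what the combination produces:

# Illustrative sketch of the causal + sliding-window pattern above.
# Nonzero entries mark positions that attention is NOT allowed to use.
import torch

max_seq_length = 6
sliding_window_size = 3

# Upper triangle above the diagonal: future positions are masked (causal).
causal = torch.ones(max_seq_length, max_seq_length).triu(diagonal=1)
# Lower band further back than the window: positions too far in the past are masked.
too_old = torch.ones(max_seq_length, max_seq_length).tril(diagonal=-sliding_window_size)
mask = causal + too_old
print(mask)
# Row i is zero only for columns j with i - sliding_window_size < j <= i,
# i.e. each token attends to itself and the previous (sliding_window_size - 1) tokens.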
@@ -147,25 +148,52 @@ def mask_slice_bool(
     tp_dtype = token_positions.dtype
     batch_size, n_query_groups, _ = token_positions.shape
     assert n_head % n_query_groups == 0 and n_head >= n_query_groups
-    token_positions = token_positions.to(device=device).unsqueeze(2).expand(
-        -1, -1, num, -1,
+    token_positions = (
+        token_positions.to(device=device)
+        .unsqueeze(2)
+        .expand(
+            -1,
+            -1,
+            num,
+            -1,
+        )
     )
     kwargs = dict(device=device, dtype=tp_dtype)
-    bool_mask = torch.arange(
-        input_pos, input_pos + num, **kwargs,
-    ).view(1, 1, -1, 1).expand_as(token_positions) < token_positions
-    if sliding_window_size is not None:
-        extra_mask = torch.arange(
-            input_pos - sliding_window_size,
-            input_pos + num - sliding_window_size,
+    bool_mask = (
+        torch.arange(
+            input_pos,
+            input_pos + num,
             **kwargs,
-        ).view(1, 1, -1, 1).expand_as(token_positions) >= token_positions
+        )
+        .view(1, 1, -1, 1)
+        .expand_as(token_positions)
+        < token_positions
+    )
+    if sliding_window_size is not None:
+        extra_mask = (
+            torch.arange(
+                input_pos - sliding_window_size,
+                input_pos + num - sliding_window_size,
+                **kwargs,
+            )
+            .view(1, 1, -1, 1)
+            .expand_as(token_positions)
+            >= token_positions
+        )
         bool_mask |= extra_mask
     if n_head != n_query_groups:
         q_per_kv = n_head // n_query_groups
-        bool_mask = bool_mask.unsqueeze(2).expand(
-            -1, -1, q_per_kv, -1, -1,
-        ).reshape(batch_size, n_head, num, -1)
+        bool_mask = (
+            bool_mask.unsqueeze(2)
+            .expand(
+                -1,
+                -1,
+                q_per_kv,
+                -1,
+                -1,
+            )
+            .reshape(batch_size, n_head, num, -1)
+        )
     return bool_mask

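A hypothetical call sketch for the reformatted mask_slice_bool follows. The keyword names mirror identifiers visible in the diff, but the exact signature, the module it lives in, and the KV-cache layout of token_positions are assumptions.

# Hypothetical usage sketch: decoding num=2 new tokens at positions 8 and 9
# against a KV cache of 8 slots whose stored token positions are 0..7.
import torch

batch_size, n_query_groups, n_head, cache_len = 1, 2, 4, 8
token_positions = torch.arange(cache_len).view(1, 1, -1).expand(batch_size, n_query_groups, -1)

bool_mask = mask_slice_bool(         # assumed importable from the module changed above
    input_pos=8,                     # first query position
    num=2,                           # number of new query tokens
    token_positions=token_positions,
    n_head=n_head,
    device=torch.device("cpu"),
    sliding_window_size=4,
)
print(bool_mask.shape)  # (batch_size, n_head, num, cache_len) -> (1, 4, 2, 8)
# True marks cache slots the new queries must not attend to: slots holding
# positions after the query, or further back than the sliding window.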
@@ -197,7 +225,12 @@ def build_mask_slice(

     """
     bool_mask = mask_slice_bool(
-        input_pos, num, token_positions, n_head, device, sliding_window_size,
+        input_pos,
+        num,
+        token_positions,
+        n_head,
+        device,
+        sliding_window_size,
     )
     mask = torch.zeros(bool_mask.shape, dtype=dtype, device=device)
     mask.masked_fill_(bool_mask, minus_infinity(dtype))
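Finally, a sketch of how an additive mask slice like the one build_mask_slice produces is typically consumed. The shapes are made up, and float("-inf") stands in for the repository's minus_infinity helper, whose definition is not shown in this diff.

# Sketch: a float mask of shape (batch, n_head, num, cache_len), 0 where a slot is
# visible and -inf where it is blocked, can be passed as attn_mask to SDPA.
import torch
from torch.nn.functional import scaled_dot_product_attention

batch, n_head, num, cache_len, head_dim = 1, 4, 2, 8, 16
dtype = torch.float32

bool_mask = torch.zeros(batch, n_head, num, cache_len, dtype=torch.bool)
bool_mask[..., -1] = True  # pretend the last cache slot is not visible
mask = torch.zeros(bool_mask.shape, dtype=dtype)
# Same pattern as build_mask_slice, with float("-inf") in place of minus_infinity(dtype).
mask.masked_fill_(bool_mask, float("-inf"))

q = torch.randn(batch, n_head, num, head_dim, dtype=dtype)
k = torch.randn(batch, n_head, cache_len, head_dim, dtype=dtype)
v = torch.randn(batch, n_head, cache_len, head_dim, dtype=dtype)
out = scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.shape)  # torch.Size([1, 4, 2, 16])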