@@ -1,5 +1,6 @@
from typing import List, Optional

+from flash_attn.flash_attention import FlashAttention
import torch
import torch.nn as nn

@@ -14,20 +15,7 @@ def __init__(self, scale: float) -> None:
        super().__init__()
        self.scale = float(scale)

-    def _masked_attention(
-        self,
-        query: torch.Tensor,                        # [num_queries, num_heads, head_size]
-        key: torch.Tensor,                          # [num_keys, num_heads, head_size]
-        value: torch.Tensor,                        # [num_keys, num_heads, head_size]
-        attn_mask: Optional[torch.Tensor] = None,   # [num_queries, num_keys]
-    ) -> torch.Tensor:                              # [num_queries, num_heads, head_size]
-        query = query * self.scale
-        attn = torch.einsum('qhd,khd->hqk', query, key)
-        if attn_mask is not None:
-            attn = attn + attn_mask
-        attn = torch.softmax(attn, dim=-1)
-        out = torch.einsum('hqk,khd->qhd', attn, value)
-        return out
+        self.flash_attn = FlashAttention(softmax_scale=self.scale)

    def multi_query_kv_attention(
        self,
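
Note (reviewer sketch, not part of this commit): the rewritten multi_query_kv_attention in the next hunk packs all prompt tokens back to back and describes their boundaries with a cumulative-length tensor instead of looping over prompts. A minimal, self-contained illustration of that layout, assuming three prompts of lengths 2, 3 and 4 and made-up head dimensions:

import torch

# Hypothetical sizes for illustration only.
prompt_lens = [2, 3, 4]
num_heads, head_size = 12, 64
num_tokens = sum(prompt_lens)   # 9 tokens packed back to back

# Cumulative prompt lengths; entries i and i + 1 bracket prompt i.
prefix_sum = [0]
for prompt_len in prompt_lens:
    prefix_sum.append(prefix_sum[-1] + prompt_len)
cu_seqlens = torch.tensor(prefix_sum, dtype=torch.int)   # tensor([0, 2, 5, 9])

# Packed QKV of shape [num_tokens, 3, num_heads, head_size], as stacked in the diff.
query = torch.randn(num_tokens, num_heads, head_size)
key = torch.randn(num_tokens, num_heads, head_size)
value = torch.randn(num_tokens, num_heads, head_size)
qkv = torch.stack([query, key, value], dim=1)

assert cu_seqlens.tolist() == [0, 2, 5, 9]
assert qkv.shape == (num_tokens, 3, num_heads, head_size)

With this layout, one kernel launch can process every prompt in the batch rather than one slice per prompt.
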
@@ -37,21 +25,31 @@ def multi_query_kv_attention(
        value: torch.Tensor,    # [num_prompt_tokens, num_heads, head_size]
        prompt_lens: List[int],
    ) -> None:
-        # FIXME(woosuk): Replace the following with a custom op.
-        start_idx = 0
+        if query.dtype == torch.float:
+            raise ValueError('The float data type is not supported by '
+                             'FlashAttention. Use the half data type instead.')
+        head_size = query.shape[2]
+        if head_size > 128:
+            raise ValueError('FlashAttention does not support head_size > 128.')
+
+        device = query.device
+        prefix_sum = [0]
        for prompt_len in prompt_lens:
-            out = output[start_idx:start_idx + prompt_len]
-            q = query[start_idx:start_idx + prompt_len]
-            k = key[start_idx:start_idx + prompt_len]
-            v = value[start_idx:start_idx + prompt_len]
-
-            attention_mask = torch.triu(
-                torch.ones(q.shape[0], k.shape[0]), diagonal=1) * -1e5
-            attention_mask = attention_mask.to(dtype=q.dtype, device=q.device)
-            attention_out = self._masked_attention(q, k, v, attention_mask)
-            out.copy_(attention_out, non_blocking=True)
-
-            start_idx += prompt_len
+            prefix_sum.append(prefix_sum[-1] + prompt_len)
+        prefix_sum = torch.tensor(prefix_sum, dtype=torch.int, device=device)
+        max_prompt_len = max(prompt_lens)
+
+        # FIXME(woosuk): Unnecessary copy. Optimize this.
+        qkv = torch.stack([query, key, value], dim=1)
+        out = self.flash_attn(
+            qkv,
+            cu_seqlens=prefix_sum,
+            max_s=max_prompt_len,
+            causal=True,
+        )[0]
+        num_tokens = prefix_sum[-1]
+        # FIXME(woosuk): Unnecessary copy. Optimize this.
+        output[:num_tokens].copy_(out, non_blocking=True)

    def single_query_cached_kv_attention(
        self,
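
Note (reviewer sketch, not part of this commit): up to numerical precision, the packed FlashAttention call above should reproduce the per-prompt causal attention that the removed loop computed. A plain-PyTorch reference of that computation, useful for testing; the function name is hypothetical:

from typing import List

import torch


def naive_multi_query_kv_attention(
    query: torch.Tensor,    # [num_prompt_tokens, num_heads, head_size]
    key: torch.Tensor,      # [num_prompt_tokens, num_heads, head_size]
    value: torch.Tensor,    # [num_prompt_tokens, num_heads, head_size]
    prompt_lens: List[int],
    scale: float,
) -> torch.Tensor:
    outputs = []
    start = 0
    for prompt_len in prompt_lens:
        end = start + prompt_len
        q = query[start:end] * scale
        k = key[start:end]
        v = value[start:end]
        # Causal mask: token i may only attend to tokens 0..i of its own prompt.
        mask = torch.triu(
            torch.full((prompt_len, prompt_len), float('-inf'),
                       dtype=q.dtype, device=q.device),
            diagonal=1)
        attn = torch.einsum('qhd,khd->hqk', q, k) + mask
        attn = torch.softmax(attn, dim=-1)
        outputs.append(torch.einsum('hqk,khd->qhd', attn, v))
        start = end
    return torch.cat(outputs, dim=0)

Comparing its output (in half precision, on GPU) against the FlashAttention path gives a quick sanity check of the packing logic.
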
@@ -61,6 +59,14 @@ def single_query_cached_kv_attention(
        value_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
    ) -> None:
+        head_size = value_cache.shape[2]
+        supported_head_sizes = [32, 64, 80, 96, 128, 160, 192, 256]
+        if head_size not in supported_head_sizes:
+            raise ValueError(f'head_size ({head_size}) is not supported by '
+                             'the single_query_cached_kv_attention kernel. '
+                             'Use one of the following head sizes: '
+                             f'{supported_head_sizes}.')
+
        block_size = value_cache.shape[3]
        attention_ops.single_query_cached_kv_attention(
            output,
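
Note (reviewer sketch, not part of this commit): the new check reads head_size from the value cache, whose layout the comment above gives as [num_blocks, num_heads, head_size, block_size]. A tiny illustration with made-up sizes:

import torch

# Hypothetical cache sizes for illustration only.
num_blocks, num_heads, head_size, block_size = 128, 12, 64, 16
value_cache = torch.empty(num_blocks, num_heads, head_size, block_size)

supported_head_sizes = [32, 64, 80, 96, 128, 160, 192, 256]
assert value_cache.shape[2] in supported_head_sizes   # head_size check in the diff
assert value_cache.shape[3] == block_size             # block_size passed to the kernel
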
@@ -101,8 +107,9 @@ def forward(
        output = output.view(-1, num_heads, head_size)

        # Compute the attention op for prompts.
-        self.multi_query_kv_attention(
-            output, query, key, value, input_metadata.prompt_lens)
+        if input_metadata.num_prompts > 0:
+            self.multi_query_kv_attention(
+                output, query, key, value, input_metadata.prompt_lens)

        # Wait until the cache op is done.
        if cache_event is not None: