Commit 3e9f991

Use FlashAttention for multi_query_kv_attention (#4)
1 parent 0deacbc commit 3e9f991

File tree

4 files changed, +108 -34 lines changed


README.md

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
 
 ```bash
 pip install cmake torch transformers
+pip install flash-attn # This may take up to 10 mins.
 pip install -e .
 ```
 
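The flash-attn package builds CUDA extensions during installation, which is why it can take several minutes. As a quick sanity check afterwards, the snippet below imports and constructs the module exactly as the new attention code does; it is illustrative only and not part of the commit.

```python
# Sanity check for the flash-attn install (illustrative, not part of the commit).
# The import path matches the one used in cacheflow/models/attention.py.
from flash_attn.flash_attention import FlashAttention

# FlashAttention is constructed with a softmax scale, as in the new attention code.
attn = FlashAttention(softmax_scale=1.0)
print('flash-attn imported and constructed successfully:', type(attn).__name__)
```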

cacheflow/models/attention.py

Lines changed: 37 additions & 30 deletions
@@ -1,5 +1,6 @@
 from typing import List, Optional
 
+from flash_attn.flash_attention import FlashAttention
 import torch
 import torch.nn as nn
 
@@ -14,20 +15,7 @@ def __init__(self, scale: float) -> None:
         super().__init__()
         self.scale = float(scale)
 
-    def _masked_attention(
-        self,
-        query: torch.Tensor,                       # [num_queries, num_heads, head_size]
-        key: torch.Tensor,                         # [num_keys, num_heads, head_size]
-        value: torch.Tensor,                       # [num_keys, num_heads, head_size]
-        attn_mask: Optional[torch.Tensor] = None,  # [num_queries, num_keys]
-    ) -> torch.Tensor:                             # [num_queries, num_heads, head_size]
-        query = query * self.scale
-        attn = torch.einsum('qhd,khd->hqk', query, key)
-        if attn_mask is not None:
-            attn = attn + attn_mask
-        attn = torch.softmax(attn, dim=-1)
-        out = torch.einsum('hqk,khd->qhd', attn, value)
-        return out
+        self.flash_attn = FlashAttention(softmax_scale=self.scale)
 
     def multi_query_kv_attention(
         self,
@@ -37,21 +25,31 @@ def multi_query_kv_attention(
         value: torch.Tensor,        # [num_prompt_tokens, num_heads, head_size]
         prompt_lens: List[int],
     ) -> None:
-        # FIXME(woosuk): Replace the following with a custom op.
-        start_idx = 0
+        if query.dtype == torch.float:
+            raise ValueError('The float data type is not supported by '
+                             'FlashAttention. Use the half data type instead.')
+        head_size = query.shape[2]
+        if head_size > 128:
+            raise ValueError('FlashAttention does not support head_size > 128.')
+
+        device = query.device
+        prefix_sum = [0]
         for prompt_len in prompt_lens:
-            out = output[start_idx:start_idx + prompt_len]
-            q = query[start_idx:start_idx + prompt_len]
-            k = key[start_idx:start_idx + prompt_len]
-            v = value[start_idx:start_idx + prompt_len]
-
-            attention_mask = torch.triu(
-                torch.ones(q.shape[0], k.shape[0]), diagonal=1) * -1e5
-            attention_mask = attention_mask.to(dtype=q.dtype, device=q.device)
-            attention_out = self._masked_attention(q, k, v, attention_mask)
-            out.copy_(attention_out, non_blocking=True)
-
-            start_idx += prompt_len
+            prefix_sum.append(prefix_sum[-1] + prompt_len)
+        prefix_sum = torch.tensor(prefix_sum, dtype=torch.int, device=device)
+        max_prompt_len = max(prompt_lens)
+
+        # FIXME(woosuk): Unnecessary copy. Optimize this.
+        qkv = torch.stack([query, key, value], dim=1)
+        out = self.flash_attn(
+            qkv,
+            cu_seqlens=prefix_sum,
+            max_s=max_prompt_len,
+            causal=True,
+        )[0]
+        num_tokens = prefix_sum[-1]
+        # FIXME(woosuk): Unnecessary copy. Optimize this.
+        output[:num_tokens].copy_(out, non_blocking=True)
 
     def single_query_cached_kv_attention(
         self,
@@ -61,6 +59,14 @@ def single_query_cached_kv_attention(
         value_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
         input_metadata: InputMetadata,
     ) -> None:
+        head_size = value_cache.shape[2]
+        supported_head_sizes = [32, 64, 80, 96, 128, 160, 192, 256]
+        if head_size not in supported_head_sizes:
+            raise ValueError(f'head_size ({head_size}) is not supported by '
+                             'the single_query_cached_kv_attention kernel. '
+                             'Use one of the following head sizes: '
+                             f'{supported_head_sizes}.')
+
         block_size = value_cache.shape[3]
         attention_ops.single_query_cached_kv_attention(
             output,
@@ -101,8 +107,9 @@ def forward(
         output = output.view(-1, num_heads, head_size)
 
         # Compute the attention op for prompts.
-        self.multi_query_kv_attention(
-            output, query, key, value, input_metadata.prompt_lens)
+        if input_metadata.num_prompts > 0:
+            self.multi_query_kv_attention(
+                output, query, key, value, input_metadata.prompt_lens)
 
         # Wait until the cache op is done.
         if cache_event is not None:
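For context on the new code path: FlashAttention's variable-length interface avoids padding by packing every prompt's tokens into one long token dimension, stacking Q/K/V into a single [num_tokens, 3, num_heads, head_size] tensor, and using a cumulative-length tensor (cu_seqlens) plus the longest prompt length (max_s) to mark sequence boundaries. The sketch below shows that call pattern in isolation, using the same FlashAttention module and arguments as the diff; the sizes are illustrative, and it assumes a CUDA device with half-precision inputs (the two constraints the new checks enforce).

```python
import torch
from flash_attn.flash_attention import FlashAttention

# Illustrative prompt lengths and shapes (not taken from the commit).
prompt_lens = [5, 3, 7]
num_heads, head_size = 4, 64
num_tokens = sum(prompt_lens)

# Cumulative prompt lengths: [0, 5, 8, 15]. FlashAttention uses these offsets,
# instead of padding, to locate each prompt inside the packed token dimension.
cu_seqlens = [0]
for prompt_len in prompt_lens:
    cu_seqlens.append(cu_seqlens[-1] + prompt_len)
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int, device='cuda')

# Packed Q/K/V: all prompts concatenated along dim 0, then stacked into
# [num_tokens, 3, num_heads, head_size]. Half precision and head_size <= 128
# are required, matching the checks added in multi_query_kv_attention.
query = torch.randn(num_tokens, num_heads, head_size,
                    dtype=torch.half, device='cuda')
key = torch.rand_like(query)
value = torch.rand_like(query)
qkv = torch.stack([query, key, value], dim=1)

flash_attn = FlashAttention(softmax_scale=head_size ** -0.5)
out = flash_attn(
    qkv,
    cu_seqlens=cu_seqlens,    # sequence boundaries in the packed batch
    max_s=max(prompt_lens),   # longest prompt
    causal=True,              # lower-triangular (causal) masking per prompt
)[0]
print(out.shape)  # torch.Size([15, 4, 64]): one output row per packed token
```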

server.py

Lines changed: 5 additions & 2 deletions
@@ -9,10 +9,12 @@
 parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
 parser.add_argument('--num-nodes', type=int, default=1, help='number of nodes')
 parser.add_argument('--num-workers', type=int, default=1, help='number of workers per node')
-parser.add_argument('--block-size', type=int, default=8, help='token block size')
+parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size')
 # TODO(woosuk): Add an analytical model to determine the maximum number of GPU/CPU blocks.
 parser.add_argument('--num-gpu-blocks', type=int, default=1024, help='number of GPU blocks (per GPU)')
-parser.add_argument('--num-cpu-blocks', type=int, default=256, help='number of CPU blocks (per GPU)')
+parser.add_argument('--num-cpu-blocks', type=int, default=32, help='number of CPU blocks (per GPU)')
+# NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
+parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
 args = parser.parse_args()
 
 
@@ -27,6 +29,7 @@ def main():
            block_size=args.block_size,
            num_gpu_blocks=args.num_gpu_blocks,
            num_cpu_blocks=args.num_cpu_blocks,
+           dtype=args.dtype,
        )
        controllers.append(controller)
 
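The new --dtype flag only threads the string through to the controller; how it is turned into a torch dtype is not part of this diff. The helper below is a hypothetical sketch of such a mapping, with the FlashAttention restriction from the NOTE made explicit.

```python
import torch

# Hypothetical helper, not part of this commit: map the --dtype string to a
# torch dtype and enforce the FlashAttention restriction noted in server.py.
_STR_TO_DTYPE = {'half': torch.half, 'float': torch.float}

def get_torch_dtype(dtype: str, use_flash_attn: bool = True) -> torch.dtype:
    if use_flash_attn and dtype == 'float':
        raise ValueError('The float data type is not supported by FlashAttention. '
                         'Use the half data type instead.')
    return _STR_TO_DTYPE[dtype]

print(get_torch_dtype('half'))  # torch.float16
```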

tests/kernels/attention.py

Lines changed: 65 additions & 2 deletions
@@ -1,10 +1,13 @@
 import random
 from typing import Optional
 
+from flash_attn.flash_attention import FlashAttention
 import torch
 
 from cacheflow import attention_ops
 
+MAX_SEQ_LEN = 4096
+
 
 def ref_masked_attention(
     query: torch.Tensor,
@@ -79,7 +82,7 @@ def test_single_query_cached_kv_attention(
     value_cache = torch.randn(
         size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
 
-    context_lens = [random.randint(1, 4096) for _ in range(num_tokens)]
+    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
     max_context_len = max(context_lens)
     context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
 
@@ -123,11 +126,60 @@ def test_single_query_cached_kv_attention(
     assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
 
 
+def test_multi_query_kv_attention(
+    num_seqs: int,
+    num_heads: int,
+    head_size: int,
+    dtype: torch.dtype,
+) -> None:
+    seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
+    max_seq_len = max(seq_lens)
+    num_tokens = sum(seq_lens)
+
+    cu_seq_lens = [0]
+    for seq_len in seq_lens:
+        cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
+    cu_seq_lens = torch.tensor(cu_seq_lens, dtype=torch.int, device='cuda')
+
+    scale = float(1.0 / (head_size ** 0.5))
+    query = torch.randn(
+        num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
+    key = torch.rand_like(query)
+    value = torch.rand_like(query)
+
+    qkv = torch.stack([query, key, value], dim=1)
+    flash_attn = FlashAttention(softmax_scale=scale)
+    output = flash_attn(
+        qkv,
+        cu_seqlens=cu_seq_lens,
+        max_s=max_seq_len,
+        causal=True,
+    )[0]
+
+    ref_outputs = []
+    for i, seq_len in enumerate(seq_lens):
+        attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
+        attn_mask = attn_mask.to(dtype=dtype, device='cuda')
+        start_idx = cu_seq_lens[i]
+        end_idx = cu_seq_lens[i + 1]
+        ref_output = ref_masked_attention(
+            query[start_idx:end_idx],
+            key[start_idx:end_idx],
+            value[start_idx:end_idx],
+            scale,
+            attn_mask=attn_mask,
+        )
+        ref_outputs.append(ref_output)
+    ref_output = torch.cat(ref_outputs, dim=0)
+
+    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
+
+
 @torch.inference_mode()
 def test_attention() -> None:
     for dtype in [torch.half, torch.float]:
         for block_size in [8, 16]:
-            for head_size in [64, 80, 96, 128, 256]:
+            for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
                 test_single_query_cached_kv_attention(
                     num_tokens=37,
                     num_heads=3,
@@ -137,6 +189,17 @@ def test_attention() -> None:
                     dtype=dtype,
                 )
 
+    # NOTE(woosuk): FlashAttention does not support FP32.
+    for dtype in [torch.half]:
+        # NOTE(woosuk): FlashAttention does not support head_size > 128.
+        for head_size in [64, 80, 96, 128]:
+            test_multi_query_kv_attention(
+                num_seqs=11,
+                num_heads=3,
+                head_size=head_size,
+                dtype=dtype,
+            )
+
 
 if __name__ == '__main__':
     test_attention()
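The new test compares the packed FlashAttention output against ref_masked_attention, whose body lies outside this diff. For orientation, it computes plain masked attention per prompt, along the lines of the _masked_attention method removed from cacheflow/models/attention.py; the sketch below follows that removed code (with the scale passed explicitly, as the test does) and may differ in detail from the actual helper in the test file.

```python
from typing import Optional

import torch

def ref_masked_attention(
    query: torch.Tensor,   # [num_queries, num_heads, head_size]
    key: torch.Tensor,     # [num_keys, num_heads, head_size]
    value: torch.Tensor,   # [num_keys, num_heads, head_size]
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,  # [num_queries, num_keys]
) -> torch.Tensor:         # [num_queries, num_heads, head_size]
    # Scaled dot-product scores per head, additive mask, softmax, weighted sum.
    query = query * scale
    attn = torch.einsum('qhd,khd->hqk', query, key)
    if attn_mask is not None:
        attn = attn + attn_mask
    attn = torch.softmax(attn, dim=-1)
    out = torch.einsum('hqk,khd->qhd', attn, value)
    return out
```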
