Skip to content

Commit 62d0148

Browse files
ch-wanssssnow
authored andcommitted
Save cuda graph memory for fa3 (#8567)
1 parent a0cf784 commit 62d0148

File tree

1 file changed

+7
-11
lines changed

1 file changed

+7
-11
lines changed

python/sglang/srt/layers/attention/flashattention_backend.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1406,7 +1406,7 @@ def init_forward_metadata_capture_cuda_graph(
14061406
)
14071407
metadata.page_table = self.decode_cuda_graph_metadata[
14081408
"page_table_draft_decode"
1409-
][req_pool_indices, :]
1409+
][:bs, :]
14101410
self.decode_cuda_graph_metadata[bs] = metadata
14111411
else:
14121412
# When top k > 1, we need two specific draft decode metadata, and then merge states
@@ -1424,7 +1424,7 @@ def init_forward_metadata_capture_cuda_graph(
14241424
][: bs + 1]
14251425
metadata.page_table = self.draft_decode_metadata_topk_normal[
14261426
"page_table"
1427-
][req_pool_indices, :]
1427+
][:bs, :]
14281428

14291429
# 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
14301430
metadata_expand.cache_seqlens_int32 = (
@@ -1461,7 +1461,7 @@ def init_forward_metadata_capture_cuda_graph(
14611461
metadata.max_seq_len_k = seq_lens.max().item()
14621462
# Precompute page table
14631463
metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
1464-
req_pool_indices, :
1464+
:bs, :
14651465
]
14661466
# Precompute cumulative sequence lengths
14671467
metadata.cu_seqlens_q = torch.arange(
@@ -1498,9 +1498,7 @@ def init_forward_metadata_capture_cuda_graph(
14981498
: (bs + 1)
14991499
]
15001500

1501-
metadata.page_table = self.target_verify_metadata["page_table"][
1502-
req_pool_indices, :
1503-
]
1501+
metadata.page_table = self.target_verify_metadata["page_table"][:bs, :]
15041502

15051503
self.target_verify_metadata[bs] = metadata
15061504
else:
@@ -1519,7 +1517,7 @@ def init_forward_metadata_capture_cuda_graph(
15191517
][: bs + 1]
15201518
metadata.page_table = self.target_verify_metadata_topk_normal[
15211519
"page_table"
1522-
][req_pool_indices, :]
1520+
][:bs, :]
15231521

15241522
# 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
15251523
metadata_expand.cache_seqlens_int32 = (
@@ -1562,9 +1560,7 @@ def init_forward_metadata_capture_cuda_graph(
15621560
metadata.cu_seqlens_k = self.draft_extend_metadata["cu_seqlens_k"][
15631561
: (bs + 1)
15641562
]
1565-
metadata.page_table = self.draft_extend_metadata["page_table"][
1566-
req_pool_indices, :
1567-
]
1563+
metadata.page_table = self.draft_extend_metadata["page_table"][:bs, :]
15681564

15691565
self.draft_extend_metadata[bs] = metadata
15701566

@@ -1578,7 +1574,7 @@ def init_forward_metadata_capture_cuda_graph(
15781574
][: (encoder_bs + 1)]
15791575

15801576
metadata.encoder_page_table = self.encoder_metadata["encoder_page_table"][
1581-
req_pool_indices, :
1577+
:bs, :
15821578
]
15831579

15841580
self.forward_metadata = metadata

0 commit comments

Comments
 (0)