@@ -1406,7 +1406,7 @@ def init_forward_metadata_capture_cuda_graph(
1406
1406
)
1407
1407
metadata .page_table = self .decode_cuda_graph_metadata [
1408
1408
"page_table_draft_decode"
1409
- ][req_pool_indices , :]
1409
+ ][: bs , :]
1410
1410
self .decode_cuda_graph_metadata [bs ] = metadata
1411
1411
else :
1412
1412
# When top k > 1, we need two specific draft decode metadata, and then merge states
@@ -1424,7 +1424,7 @@ def init_forward_metadata_capture_cuda_graph(
1424
1424
][: bs + 1 ]
1425
1425
metadata .page_table = self .draft_decode_metadata_topk_normal [
1426
1426
"page_table"
1427
- ][req_pool_indices , :]
1427
+ ][: bs , :]
1428
1428
1429
1429
# 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
1430
1430
metadata_expand .cache_seqlens_int32 = (
@@ -1461,7 +1461,7 @@ def init_forward_metadata_capture_cuda_graph(
1461
1461
metadata .max_seq_len_k = seq_lens .max ().item ()
1462
1462
# Precompute page table
1463
1463
metadata .page_table = self .decode_cuda_graph_metadata ["page_table" ][
1464
- req_pool_indices , :
1464
+ : bs , :
1465
1465
]
1466
1466
# Precompute cumulative sequence lengths
1467
1467
metadata .cu_seqlens_q = torch .arange (
@@ -1498,9 +1498,7 @@ def init_forward_metadata_capture_cuda_graph(
1498
1498
: (bs + 1 )
1499
1499
]
1500
1500
1501
- metadata .page_table = self .target_verify_metadata ["page_table" ][
1502
- req_pool_indices , :
1503
- ]
1501
+ metadata .page_table = self .target_verify_metadata ["page_table" ][:bs , :]
1504
1502
1505
1503
self .target_verify_metadata [bs ] = metadata
1506
1504
else :
@@ -1519,7 +1517,7 @@ def init_forward_metadata_capture_cuda_graph(
1519
1517
][: bs + 1 ]
1520
1518
metadata .page_table = self .target_verify_metadata_topk_normal [
1521
1519
"page_table"
1522
- ][req_pool_indices , :]
1520
+ ][: bs , :]
1523
1521
1524
1522
# 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
1525
1523
metadata_expand .cache_seqlens_int32 = (
@@ -1562,9 +1560,7 @@ def init_forward_metadata_capture_cuda_graph(
1562
1560
metadata .cu_seqlens_k = self .draft_extend_metadata ["cu_seqlens_k" ][
1563
1561
: (bs + 1 )
1564
1562
]
1565
- metadata .page_table = self .draft_extend_metadata ["page_table" ][
1566
- req_pool_indices , :
1567
- ]
1563
+ metadata .page_table = self .draft_extend_metadata ["page_table" ][:bs , :]
1568
1564
1569
1565
self .draft_extend_metadata [bs ] = metadata
1570
1566
@@ -1578,7 +1574,7 @@ def init_forward_metadata_capture_cuda_graph(
1578
1574
][: (encoder_bs + 1 )]
1579
1575
1580
1576
metadata .encoder_page_table = self .encoder_metadata ["encoder_page_table" ][
1581
- req_pool_indices , :
1577
+ : bs , :
1582
1578
]
1583
1579
1584
1580
self .forward_metadata = metadata
0 commit comments