Skip to content

Commit 2fbb754

Browse files
hzh0425, ShangmingCai, and xiezhq-hermann
authored
feature(pd-hicache): Prefill instances support reusing the RemoteStorage Cache via HiCache. (#8516)
Co-authored-by: Shangming Cai <[email protected]> Co-authored-by: Zhiqiang Xie <[email protected]>
1 parent a85ebf5 commit 2fbb754

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

python/sglang/srt/managers/scheduler.py

Lines changed: 14 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1185,23 +1185,28 @@ def handle_generate_request(
11851185
def _add_request_to_queue(self, req: Req):
11861186
req.queue_time_start = time.perf_counter()
11871187
if self.disaggregation_mode == DisaggregationMode.PREFILL:
1188+
self._prefetch_kvcache(req)
11881189
self.disagg_prefill_bootstrap_queue.add(
11891190
req, self.model_config.num_key_value_heads
11901191
)
11911192
elif self.disaggregation_mode == DisaggregationMode.DECODE:
11921193
self.disagg_decode_prealloc_queue.add(req)
11931194
else:
1194-
if self.enable_hicache_storage:
1195-
req.init_next_round_input(self.tree_cache)
1196-
last_hash = req.last_host_node.get_last_hash_value()
1197-
matched_len = len(req.prefix_indices) + req.host_hit_length
1198-
if (matched_len > 0 and last_hash is not None) or matched_len == 0:
1199-
new_input_tokens = req.fill_ids[matched_len:]
1200-
self.tree_cache.prefetch_from_storage(
1201-
req.rid, req.last_host_node, new_input_tokens, last_hash
1202-
)
1195+
self._prefetch_kvcache(req)
12031196
self.waiting_queue.append(req)
12041197

1198+
def _prefetch_kvcache(self, req: Req):
1199+
if self.enable_hicache_storage:
1200+
req.init_next_round_input(self.tree_cache)
1201+
last_hash = req.last_host_node.get_last_hash_value()
1202+
matched_len = len(req.prefix_indices) + req.host_hit_length
1203+
# todo, free-form fetching, calculating hash keys on the fly
1204+
if (matched_len > 0 and last_hash is not None) or matched_len == 0:
1205+
new_input_tokens = req.fill_ids[matched_len:]
1206+
self.tree_cache.prefetch_from_storage(
1207+
req.rid, req.last_host_node, new_input_tokens, last_hash
1208+
)
1209+
12051210
def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
12061211
if self.disaggregation_mode == DisaggregationMode.PREFILL:
12071212
self.disagg_prefill_bootstrap_queue.extend(

0 commit comments

Comments
 (0)