@@ -1185,23 +1185,28 @@ def handle_generate_request(
1185
1185
def _add_request_to_queue (self , req : Req ):
1186
1186
req .queue_time_start = time .perf_counter ()
1187
1187
if self .disaggregation_mode == DisaggregationMode .PREFILL :
1188
+ self ._prefetch_kvcache (req )
1188
1189
self .disagg_prefill_bootstrap_queue .add (
1189
1190
req , self .model_config .num_key_value_heads
1190
1191
)
1191
1192
elif self .disaggregation_mode == DisaggregationMode .DECODE :
1192
1193
self .disagg_decode_prealloc_queue .add (req )
1193
1194
else :
1194
- if self .enable_hicache_storage :
1195
- req .init_next_round_input (self .tree_cache )
1196
- last_hash = req .last_host_node .get_last_hash_value ()
1197
- matched_len = len (req .prefix_indices ) + req .host_hit_length
1198
- if (matched_len > 0 and last_hash is not None ) or matched_len == 0 :
1199
- new_input_tokens = req .fill_ids [matched_len :]
1200
- self .tree_cache .prefetch_from_storage (
1201
- req .rid , req .last_host_node , new_input_tokens , last_hash
1202
- )
1195
+ self ._prefetch_kvcache (req )
1203
1196
self .waiting_queue .append (req )
1204
1197
1198
+ def _prefetch_kvcache (self , req : Req ):
1199
+ if self .enable_hicache_storage :
1200
+ req .init_next_round_input (self .tree_cache )
1201
+ last_hash = req .last_host_node .get_last_hash_value ()
1202
+ matched_len = len (req .prefix_indices ) + req .host_hit_length
1203
+ # todo, free-form fetching, calculating hash keys on the fly
1204
+ if (matched_len > 0 and last_hash is not None ) or matched_len == 0 :
1205
+ new_input_tokens = req .fill_ids [matched_len :]
1206
+ self .tree_cache .prefetch_from_storage (
1207
+ req .rid , req .last_host_node , new_input_tokens , last_hash
1208
+ )
1209
+
1205
1210
def _extend_requests_to_queue (self , reqs : List [Req ], is_retracted : bool = False ):
1206
1211
if self .disaggregation_mode == DisaggregationMode .PREFILL :
1207
1212
self .disagg_prefill_bootstrap_queue .extend (
0 commit comments