55
55
)
56
56
57
57
58
+ IGNORE_EOS_RESERVE_TOKENS = 1
59
+
60
+
58
61
class CacheAwarePolicy (Enum ):
59
62
"""Scheduling policies that are aware of the tree cache."""
60
63
@@ -293,6 +296,7 @@ def __init__(
293
296
self .can_run_list = []
294
297
self .new_chunked_req = None
295
298
self .log_hit_tokens = 0
299
+ # TODO(lsyin): report the real input tokens excluding page alignment
296
300
self .log_input_tokens = 0
297
301
298
302
if running_batch is not None :
@@ -323,6 +327,9 @@ def cur_rem_tokens(self):
323
327
- self .cur_rem_token_offset
324
328
)
325
329
330
+ def ceil_paged_tokens (self , tokens : int ) -> int :
331
+ return - (- tokens // self .page_size ) * self .page_size
332
+
326
333
def budget_state (self ):
327
334
if self .rem_total_tokens <= 0 or self .cur_rem_tokens <= 0 :
328
335
return AddReqResult .NO_TOKEN
@@ -334,9 +341,12 @@ def budget_state(self):
334
341
335
342
return AddReqResult .CONTINUE
336
343
337
- def _prefill_one_req (
344
+ def _update_prefill_budget (
338
345
self , prefix_len : int , extend_input_len : int , max_new_tokens : int
339
346
):
347
+ # TODO(lsyin): check this workaround logic, which only ensures the prefill will not out of memory, and may be too conservative
348
+ extend_input_len = self .ceil_paged_tokens (extend_input_len )
349
+
340
350
self .rem_total_token_offset += extend_input_len + max_new_tokens
341
351
self .cur_rem_token_offset += extend_input_len
342
352
self .rem_input_tokens -= extend_input_len
@@ -351,7 +361,7 @@ def add_chunked_req(self, req: Req):
351
361
req .extend_input_len = min (req .extend_input_len , self .rem_chunk_tokens )
352
362
req .fill_ids = req .fill_ids [: len (req .prefix_indices ) + req .extend_input_len ]
353
363
self .can_run_list .append (req )
354
- self ._prefill_one_req (
364
+ self ._update_prefill_budget (
355
365
0 ,
356
366
req .extend_input_len ,
357
367
(
@@ -373,6 +383,12 @@ def _lock_node(self, last_node: TreeNode):
373
383
self .tree_cache .dec_lock_ref (last_node )
374
384
375
385
def add_one_req_ignore_eos (self , req : Req , has_chunked_req : bool ):
386
+ # Early exit if no enough tokens for the input tokens
387
+ if self .ceil_paged_tokens (req .extend_input_len ) > min (
388
+ self .cur_rem_tokens , self .rem_total_tokens
389
+ ):
390
+ return AddReqResult .NO_TOKEN
391
+
376
392
def add_req_state (r , insert_sort = False ):
377
393
new_token_ratio = (
378
394
1.0 if r .sampling_params .ignore_eos else self .new_token_ratio
@@ -382,15 +398,17 @@ def add_req_state(r, insert_sort=False):
382
398
)
383
399
tokens_occupied = len (r .origin_input_ids ) + len (r .output_ids )
384
400
385
- if tokens_left > 0 :
386
- if not insert_sort :
387
- self .req_states .append ((tokens_left , tokens_occupied ))
388
- else :
389
- i = 0
390
- for i in range (len (self .req_states )):
391
- if tokens_left <= self .req_states [i ][0 ]:
392
- break
393
- self .req_states .insert (i , (tokens_left , tokens_occupied ))
401
+ if tokens_left <= 0 :
402
+ return
403
+
404
+ if not insert_sort :
405
+ self .req_states .append ((tokens_left , tokens_occupied ))
406
+ else :
407
+ i = 0
408
+ for i in range (len (self .req_states )):
409
+ if tokens_left <= self .req_states [i ][0 ]:
410
+ break
411
+ self .req_states .insert (i , (tokens_left , tokens_occupied ))
394
412
395
413
if self .req_states is None :
396
414
self .req_states = []
@@ -407,13 +425,11 @@ def add_req_state(r, insert_sort=False):
407
425
cur_rem_tokens = self .cur_rem_tokens - len (req .origin_input_ids )
408
426
tokens_freed = 0
409
427
for i , (tokens_left , tokens_occupied ) in enumerate (self .req_states ):
410
- decode_steps = (
411
- self .req_states [i + 1 ][0 ]
412
- if i + 1 < len (self .req_states )
413
- else tokens_left
414
- )
428
+ # tokens_left gives a reservative calculation as the last token is not stored
415
429
bs = len (self .req_states ) - i
416
- if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0 :
430
+ min_free_tokens = cur_rem_tokens + tokens_freed - tokens_left * bs
431
+ # reserve tokens for corner cases
432
+ if min_free_tokens <= IGNORE_EOS_RESERVE_TOKENS * bs :
417
433
return AddReqResult .NO_TOKEN
418
434
tokens_freed += tokens_occupied
419
435
@@ -423,7 +439,7 @@ def add_req_state(r, insert_sort=False):
423
439
):
424
440
# Non-chunked prefill
425
441
self .can_run_list .append (req )
426
- self ._prefill_one_req (
442
+ self ._update_prefill_budget (
427
443
0 ,
428
444
req .extend_input_len ,
429
445
min (req .sampling_params .max_new_tokens , CLIP_MAX_NEW_TOKENS_ESTIMATION ),
@@ -439,7 +455,7 @@ def add_req_state(r, insert_sort=False):
439
455
req .fill_ids = req .fill_ids [:trunc_len ]
440
456
self .can_run_list .append (req )
441
457
self .new_chunked_req = req
442
- self ._prefill_one_req (0 , trunc_len , 0 )
458
+ self ._update_prefill_budget (0 , trunc_len , 0 )
443
459
444
460
return self .budget_state ()
445
461
@@ -453,7 +469,7 @@ def add_one_req(self, req: Req, has_chunked_req: bool):
453
469
454
470
# adjusting the input_tokens based on host_hit_length and page_size
455
471
real_input_tokens = req .extend_input_len - req .host_hit_length
456
- real_input_tokens = - ( - real_input_tokens // self .page_size ) * self . page_size
472
+ real_input_tokens = self .ceil_paged_tokens ( real_input_tokens )
457
473
prefix_len = len (req .prefix_indices )
458
474
459
475
if total_tokens >= self .rem_total_tokens :
@@ -475,7 +491,7 @@ def add_one_req(self, req: Req, has_chunked_req: bool):
475
491
req .extend_input_len = len (req .fill_ids ) - len (req .prefix_indices )
476
492
prefix_len = len (req .prefix_indices )
477
493
478
- input_tokens = - ( - req .extend_input_len // self . page_size ) * self . page_size
494
+ input_tokens = self . ceil_paged_tokens ( req .extend_input_len )
479
495
480
496
if input_tokens >= self .rem_input_tokens and len (self .can_run_list ) != 0 :
481
497
return AddReqResult .OTHER
@@ -484,7 +500,7 @@ def add_one_req(self, req: Req, has_chunked_req: bool):
484
500
# Non-chunked prefill
485
501
self .can_run_list .append (req )
486
502
self .tree_cache .inc_lock_ref (req .last_node )
487
- self ._prefill_one_req (
503
+ self ._update_prefill_budget (
488
504
prefix_len ,
489
505
input_tokens ,
490
506
min (
@@ -505,6 +521,6 @@ def add_one_req(self, req: Req, has_chunked_req: bool):
505
521
self .can_run_list .append (req )
506
522
self .new_chunked_req = req
507
523
self .tree_cache .inc_lock_ref (req .last_node )
508
- self ._prefill_one_req (prefix_len , trunc_len , 0 )
524
+ self ._update_prefill_budget (prefix_len , trunc_len , 0 )
509
525
510
526
return self .budget_state ()
0 commit comments