Skip to content

Commit e17a745

Browse files
hnyls2002shuaills
authored andcommitted
Fix prefill OOM due to wrong token calculation when page > 1 (sgl-project#7397)
1 parent 509db3c commit e17a745

File tree

1 file changed

+39
-23
lines changed

1 file changed

+39
-23
lines changed

python/sglang/srt/managers/schedule_policy.py

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@
5555
)
5656

5757

58+
IGNORE_EOS_RESERVE_TOKENS = 1
59+
60+
5861
class CacheAwarePolicy(Enum):
5962
"""Scheduling policies that are aware of the tree cache."""
6063

@@ -293,6 +296,7 @@ def __init__(
293296
self.can_run_list = []
294297
self.new_chunked_req = None
295298
self.log_hit_tokens = 0
299+
# TODO(lsyin): report the real input tokens excluding page alignment
296300
self.log_input_tokens = 0
297301

298302
if running_batch is not None:
@@ -323,6 +327,9 @@ def cur_rem_tokens(self):
323327
- self.cur_rem_token_offset
324328
)
325329

330+
def ceil_paged_tokens(self, tokens: int) -> int:
331+
return -(-tokens // self.page_size) * self.page_size
332+
326333
def budget_state(self):
327334
if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0:
328335
return AddReqResult.NO_TOKEN
@@ -334,9 +341,12 @@ def budget_state(self):
334341

335342
return AddReqResult.CONTINUE
336343

337-
def _prefill_one_req(
344+
def _update_prefill_budget(
338345
self, prefix_len: int, extend_input_len: int, max_new_tokens: int
339346
):
347+
# TODO(lsyin): check this workaround logic, which only ensures the prefill will not out of memory, and may be too conservative
348+
extend_input_len = self.ceil_paged_tokens(extend_input_len)
349+
340350
self.rem_total_token_offset += extend_input_len + max_new_tokens
341351
self.cur_rem_token_offset += extend_input_len
342352
self.rem_input_tokens -= extend_input_len
@@ -351,7 +361,7 @@ def add_chunked_req(self, req: Req):
351361
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
352362
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
353363
self.can_run_list.append(req)
354-
self._prefill_one_req(
364+
self._update_prefill_budget(
355365
0,
356366
req.extend_input_len,
357367
(
@@ -373,6 +383,12 @@ def _lock_node(self, last_node: TreeNode):
373383
self.tree_cache.dec_lock_ref(last_node)
374384

375385
def add_one_req_ignore_eos(self, req: Req, has_chunked_req: bool):
386+
# Early exit if no enough tokens for the input tokens
387+
if self.ceil_paged_tokens(req.extend_input_len) > min(
388+
self.cur_rem_tokens, self.rem_total_tokens
389+
):
390+
return AddReqResult.NO_TOKEN
391+
376392
def add_req_state(r, insert_sort=False):
377393
new_token_ratio = (
378394
1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
@@ -382,15 +398,17 @@ def add_req_state(r, insert_sort=False):
382398
)
383399
tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)
384400

385-
if tokens_left > 0:
386-
if not insert_sort:
387-
self.req_states.append((tokens_left, tokens_occupied))
388-
else:
389-
i = 0
390-
for i in range(len(self.req_states)):
391-
if tokens_left <= self.req_states[i][0]:
392-
break
393-
self.req_states.insert(i, (tokens_left, tokens_occupied))
401+
if tokens_left <= 0:
402+
return
403+
404+
if not insert_sort:
405+
self.req_states.append((tokens_left, tokens_occupied))
406+
else:
407+
i = 0
408+
for i in range(len(self.req_states)):
409+
if tokens_left <= self.req_states[i][0]:
410+
break
411+
self.req_states.insert(i, (tokens_left, tokens_occupied))
394412

395413
if self.req_states is None:
396414
self.req_states = []
@@ -407,13 +425,11 @@ def add_req_state(r, insert_sort=False):
407425
cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
408426
tokens_freed = 0
409427
for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
410-
decode_steps = (
411-
self.req_states[i + 1][0]
412-
if i + 1 < len(self.req_states)
413-
else tokens_left
414-
)
428+
# tokens_left gives a reservative calculation as the last token is not stored
415429
bs = len(self.req_states) - i
416-
if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
430+
min_free_tokens = cur_rem_tokens + tokens_freed - tokens_left * bs
431+
# reserve tokens for corner cases
432+
if min_free_tokens <= IGNORE_EOS_RESERVE_TOKENS * bs:
417433
return AddReqResult.NO_TOKEN
418434
tokens_freed += tokens_occupied
419435

@@ -423,7 +439,7 @@ def add_req_state(r, insert_sort=False):
423439
):
424440
# Non-chunked prefill
425441
self.can_run_list.append(req)
426-
self._prefill_one_req(
442+
self._update_prefill_budget(
427443
0,
428444
req.extend_input_len,
429445
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
@@ -439,7 +455,7 @@ def add_req_state(r, insert_sort=False):
439455
req.fill_ids = req.fill_ids[:trunc_len]
440456
self.can_run_list.append(req)
441457
self.new_chunked_req = req
442-
self._prefill_one_req(0, trunc_len, 0)
458+
self._update_prefill_budget(0, trunc_len, 0)
443459

444460
return self.budget_state()
445461

@@ -453,7 +469,7 @@ def add_one_req(self, req: Req, has_chunked_req: bool):
453469

454470
# adjusting the input_tokens based on host_hit_length and page_size
455471
real_input_tokens = req.extend_input_len - req.host_hit_length
456-
real_input_tokens = -(-real_input_tokens // self.page_size) * self.page_size
472+
real_input_tokens = self.ceil_paged_tokens(real_input_tokens)
457473
prefix_len = len(req.prefix_indices)
458474

459475
if total_tokens >= self.rem_total_tokens:
@@ -475,7 +491,7 @@ def add_one_req(self, req: Req, has_chunked_req: bool):
475491
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
476492
prefix_len = len(req.prefix_indices)
477493

478-
input_tokens = -(-req.extend_input_len // self.page_size) * self.page_size
494+
input_tokens = self.ceil_paged_tokens(req.extend_input_len)
479495

480496
if input_tokens >= self.rem_input_tokens and len(self.can_run_list) != 0:
481497
return AddReqResult.OTHER
@@ -484,7 +500,7 @@ def add_one_req(self, req: Req, has_chunked_req: bool):
484500
# Non-chunked prefill
485501
self.can_run_list.append(req)
486502
self.tree_cache.inc_lock_ref(req.last_node)
487-
self._prefill_one_req(
503+
self._update_prefill_budget(
488504
prefix_len,
489505
input_tokens,
490506
min(
@@ -505,6 +521,6 @@ def add_one_req(self, req: Req, has_chunked_req: bool):
505521
self.can_run_list.append(req)
506522
self.new_chunked_req = req
507523
self.tree_cache.inc_lock_ref(req.last_node)
508-
self._prefill_one_req(prefix_len, trunc_len, 0)
524+
self._update_prefill_budget(prefix_len, trunc_len, 0)
509525

510526
return self.budget_state()

0 commit comments

Comments
 (0)