Skip to content
62 changes: 39 additions & 23 deletions python/sglang/srt/managers/schedule_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@
)


IGNORE_EOS_RESERVE_TOKENS = 1


class CacheAwarePolicy(Enum):
"""Scheduling policies that are aware of the tree cache."""

Expand Down Expand Up @@ -293,6 +296,7 @@ def __init__(
self.can_run_list = []
self.new_chunked_req = None
self.log_hit_tokens = 0
# TODO(lsyin): report the real input tokens excluding page alignment
self.log_input_tokens = 0

if running_batch is not None:
Expand Down Expand Up @@ -323,6 +327,9 @@ def cur_rem_tokens(self):
- self.cur_rem_token_offset
)

def ceil_paged_tokens(self, tokens: int) -> int:
    """Round *tokens* up to the nearest multiple of ``self.page_size``.

    Used to account for page-aligned allocation: a request occupying any
    part of a page consumes the whole page.
    """
    page = self.page_size
    return (tokens + page - 1) // page * page

def budget_state(self):
if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0:
return AddReqResult.NO_TOKEN
Expand All @@ -334,9 +341,12 @@ def budget_state(self):

return AddReqResult.CONTINUE

def _prefill_one_req(
def _update_prefill_budget(
self, prefix_len: int, extend_input_len: int, max_new_tokens: int
):
# TODO(lsyin): check this workaround logic, which only ensures the prefill will not out of memory, and may be too conservative
extend_input_len = self.ceil_paged_tokens(extend_input_len)

self.rem_total_token_offset += extend_input_len + max_new_tokens
self.cur_rem_token_offset += extend_input_len
self.rem_input_tokens -= extend_input_len
Expand All @@ -351,7 +361,7 @@ def add_chunked_req(self, req: Req):
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
self.can_run_list.append(req)
self._prefill_one_req(
self._update_prefill_budget(
0,
req.extend_input_len,
(
Expand All @@ -373,6 +383,12 @@ def _lock_node(self, last_node: TreeNode):
self.tree_cache.dec_lock_ref(last_node)

def add_one_req_ignore_eos(self, req: Req, has_chunked_req: bool):
# Early exit if no enough tokens for the input tokens
if self.ceil_paged_tokens(req.extend_input_len) > min(
self.cur_rem_tokens, self.rem_total_tokens
):
return AddReqResult.NO_TOKEN

def add_req_state(r, insert_sort=False):
new_token_ratio = (
1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
Expand All @@ -382,15 +398,17 @@ def add_req_state(r, insert_sort=False):
)
tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)

if tokens_left > 0:
if not insert_sort:
self.req_states.append((tokens_left, tokens_occupied))
else:
i = 0
for i in range(len(self.req_states)):
if tokens_left <= self.req_states[i][0]:
break
self.req_states.insert(i, (tokens_left, tokens_occupied))
if tokens_left <= 0:
return

if not insert_sort:
self.req_states.append((tokens_left, tokens_occupied))
else:
i = 0
for i in range(len(self.req_states)):
if tokens_left <= self.req_states[i][0]:
break
self.req_states.insert(i, (tokens_left, tokens_occupied))

if self.req_states is None:
self.req_states = []
Expand All @@ -407,13 +425,11 @@ def add_req_state(r, insert_sort=False):
cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
tokens_freed = 0
for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
decode_steps = (
self.req_states[i + 1][0]
if i + 1 < len(self.req_states)
else tokens_left
)
# tokens_left gives a conservative calculation as the last token is not stored
bs = len(self.req_states) - i
if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
min_free_tokens = cur_rem_tokens + tokens_freed - tokens_left * bs
# reserve tokens for corner cases
if min_free_tokens <= IGNORE_EOS_RESERVE_TOKENS * bs:
return AddReqResult.NO_TOKEN
tokens_freed += tokens_occupied

Expand All @@ -423,7 +439,7 @@ def add_req_state(r, insert_sort=False):
):
# Non-chunked prefill
self.can_run_list.append(req)
self._prefill_one_req(
self._update_prefill_budget(
0,
req.extend_input_len,
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
Expand All @@ -439,7 +455,7 @@ def add_req_state(r, insert_sort=False):
req.fill_ids = req.fill_ids[:trunc_len]
self.can_run_list.append(req)
self.new_chunked_req = req
self._prefill_one_req(0, trunc_len, 0)
self._update_prefill_budget(0, trunc_len, 0)

return self.budget_state()

Expand All @@ -453,7 +469,7 @@ def add_one_req(self, req: Req, has_chunked_req: bool):

# adjusting the input_tokens based on host_hit_length and page_size
real_input_tokens = req.extend_input_len - req.host_hit_length
real_input_tokens = -(-real_input_tokens // self.page_size) * self.page_size
real_input_tokens = self.ceil_paged_tokens(real_input_tokens)
prefix_len = len(req.prefix_indices)

if total_tokens >= self.rem_total_tokens:
Expand All @@ -475,7 +491,7 @@ def add_one_req(self, req: Req, has_chunked_req: bool):
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
prefix_len = len(req.prefix_indices)

input_tokens = -(-req.extend_input_len // self.page_size) * self.page_size
input_tokens = self.ceil_paged_tokens(req.extend_input_len)

if input_tokens >= self.rem_input_tokens and len(self.can_run_list) != 0:
return AddReqResult.OTHER
Expand All @@ -484,7 +500,7 @@ def add_one_req(self, req: Req, has_chunked_req: bool):
# Non-chunked prefill
self.can_run_list.append(req)
self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req(
self._update_prefill_budget(
prefix_len,
input_tokens,
min(
Expand All @@ -505,6 +521,6 @@ def add_one_req(self, req: Req, has_chunked_req: bool):
self.can_run_list.append(req)
self.new_chunked_req = req
self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req(prefix_len, trunc_len, 0)
self._update_prefill_budget(prefix_len, trunc_len, 0)

return self.budget_state()
Loading