
Commit 99ec675

ch-wan authored and ssssnow committed
Save peak memory in logits processor (#8343)
1 parent 2de6d42 commit 99ec675

File tree

1 file changed: +14 -3 lines changed


python/sglang/srt/layers/logits_processor.py

Lines changed: 14 additions & 3 deletions
@@ -170,8 +170,6 @@ def from_forward_batch(cls, forward_batch: ForwardBatch):
         )
 
     def compute_dp_attention_metadata(self):
-        # TODO(ch-wan): gathered_buffer here is larger than the actual required size in draft extend,
-        # we may use a smaller buffer in draft extend.
 
         cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
         dp_rank = get_attention_dp_rank()
@@ -186,6 +184,19 @@ def compute_dp_attention_metadata(self):
         self.dp_local_start_pos = dp_local_start_pos
         self.dp_local_num_tokens = dp_local_num_tokens
 
+        if self.global_num_tokens_for_logprob_cpu is not None:
+            # create a smaller buffer to reduce peak memory usage
+            self.gathered_buffer = torch.empty(
+                (
+                    sum(self.global_num_tokens_for_logprob_cpu),
+                    self.gathered_buffer.shape[1],
+                ),
+                dtype=self.gathered_buffer.dtype,
+                device=self.gathered_buffer.device,
+            )
+        else:
+            self.gathered_buffer = torch.empty_like(self.gathered_buffer)
+
 
 
 class LogitsProcessor(nn.Module):
     def __init__(
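
The key change in this hunk is sizing gathered_buffer from the actual per-rank logprob token counts instead of keeping the worst-case allocation flagged in the removed TODO. A minimal standalone sketch of the effect (not part of the commit; all sizes and the float16 dtype below are made up for illustration):

# Standalone sketch, not from the commit: every number here is hypothetical.
import torch

hidden_size = 4096                      # hypothetical hidden dimension
max_tokens_per_rank = 8192              # hypothetical worst-case preallocation
num_tokens_for_logprob = [7, 3, 5, 1]   # hypothetical per-rank counts, kept on CPU

# Worst-case buffer: sized as if every rank contributed its maximum.
worst_case = torch.empty(
    (max_tokens_per_rank * len(num_tokens_for_logprob), hidden_size),
    dtype=torch.float16,
)

# Right-sized buffer: only the rows that will actually be gathered,
# mirroring sum(self.global_num_tokens_for_logprob_cpu) in the diff above.
right_sized = torch.empty((sum(num_tokens_for_logprob), hidden_size), dtype=torch.float16)

print(worst_case.nelement() * worst_case.element_size())    # 268435456 bytes (256 MiB)
print(right_sized.nelement() * right_sized.element_size())  # 131072 bytes (128 KiB)
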
@@ -430,7 +441,7 @@ def _get_logits(
         if self.do_tensor_parallel_all_gather_dp_attn:
             logits_metadata.compute_dp_attention_metadata()
             hidden_states, local_hidden_states = (
-                torch.empty_like(logits_metadata.gathered_buffer),
+                logits_metadata.gathered_buffer,
                 hidden_states,
             )
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
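
The _get_logits hunk applies the same idea downstream: the gather target is the buffer that compute_dp_attention_metadata just allocated, rather than a fresh torch.empty_like copy, so a second full-size tensor is never live at peak. A rough standalone sketch of that pattern (Metadata and gather_into are hypothetical stand-ins, not the sglang API):

# Standalone sketch with hypothetical names, not sglang code.
import torch

class Metadata:
    def __init__(self, rows: int, hidden_size: int):
        # Allocated once up front, analogous to logits_metadata.gathered_buffer.
        self.gathered_buffer = torch.empty((rows, hidden_size), dtype=torch.float16)

def gather_into(target: torch.Tensor, local: torch.Tensor) -> torch.Tensor:
    # Stand-in for dp_gather_replicate: copy the local rows into the shared target.
    target[: local.shape[0]].copy_(local)
    return target

meta = Metadata(rows=16, hidden_size=4096)
local_hidden_states = torch.randn(4, 4096).to(torch.float16)

# Before: gather_target = torch.empty_like(meta.gathered_buffer)  # second full-size tensor
# After: the existing buffer is the target, so only one copy is ever allocated.
gather_target = meta.gathered_buffer
hidden_states = gather_into(gather_target, local_hidden_states)
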
