Commit 01e6b45

ch-wan authored and narutolhy committed

Save peak memory in logits processor (sgl-project#8343)
1 parent 8defc87 commit 01e6b45

File tree

1 file changed: +14, -3 lines

python/sglang/srt/layers/logits_processor.py

Lines changed: 14 additions & 3 deletions
@@ -174,8 +174,6 @@ def from_forward_batch(cls, forward_batch: ForwardBatch):
         )
 
     def compute_dp_attention_metadata(self):
-        # TODO(ch-wan): gathered_buffer here is larger than the actual required size in draft extend,
-        # we may use a smaller buffer in draft extend.
 
         cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
         dp_rank = get_attention_dp_rank()
@@ -190,6 +188,19 @@ def compute_dp_attention_metadata(self):
         self.dp_local_start_pos = dp_local_start_pos
         self.dp_local_num_tokens = dp_local_num_tokens
 
+        if self.global_num_tokens_for_logprob_cpu is not None:
+            # create a smaller buffer to reduce peak memory usage
+            self.gathered_buffer = torch.empty(
+                (
+                    sum(self.global_num_tokens_for_logprob_cpu),
+                    self.gathered_buffer.shape[1],
+                ),
+                dtype=self.gathered_buffer.dtype,
+                device=self.gathered_buffer.device,
+            )
+        else:
+            self.gathered_buffer = torch.empty_like(self.gathered_buffer)
+
 
 class LogitsProcessor(nn.Module):
     def __init__(
@@ -434,7 +445,7 @@ def _get_logits(
         if self.do_tensor_parallel_all_gather_dp_attn:
             logits_metadata.compute_dp_attention_metadata()
             hidden_states, local_hidden_states = (
-                torch.empty_like(logits_metadata.gathered_buffer),
+                logits_metadata.gathered_buffer,
                 hidden_states,
             )
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
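
For context, a rough sketch of the memory effect of the two changes: the gather buffer for the logprob all-gather is re-created at the size of sum(global_num_tokens_for_logprob) instead of the full-batch size, and _get_logits reuses that buffer directly rather than allocating a second one with torch.empty_like. The field names below mirror the ones touched in the diff, but the scenario, sizes, and helper functions are invented for illustration and are not the actual sglang code path.

```python
import torch

HIDDEN = 4096
DTYPE = torch.float16

# Hypothetical per-DP-rank token counts for the whole forward batch...
global_num_tokens = [2048, 2048, 2048, 2048]
# ...versus the tokens that actually need logits/logprobs (e.g. one per request).
global_num_tokens_for_logprob = [8, 8, 8, 8]


def nbytes(t: torch.Tensor) -> int:
    # Size of a tensor's data in bytes.
    return t.element_size() * t.nelement()


def peak_bytes_before() -> int:
    # Old path: the gather buffer was sized for the full batch, and _get_logits
    # allocated a second buffer of the same size with torch.empty_like().
    gathered = torch.empty(sum(global_num_tokens), HIDDEN, dtype=DTYPE)
    scratch = torch.empty_like(gathered)
    return nbytes(gathered) + nbytes(scratch)


def peak_bytes_after() -> int:
    # New path: the buffer is re-created at the logprob token count and reused
    # directly, so only one, much smaller, allocation is live during the gather.
    gathered = torch.empty(sum(global_num_tokens_for_logprob), HIDDEN, dtype=DTYPE)
    return nbytes(gathered)


print(f"before: {peak_bytes_before() / 2**20:.2f} MiB")  # ~128 MiB with these made-up sizes
print(f"after:  {peak_bytes_after() / 2**20:.2f} MiB")   # ~0.25 MiB
```

Note that the diff keeps a torch.empty_like fallback when global_num_tokens_for_logprob_cpu is not available, so behaviour in that case is unchanged.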
