@@ -174,8 +174,6 @@ def from_forward_batch(cls, forward_batch: ForwardBatch):
         )

     def compute_dp_attention_metadata(self):
-        # TODO(ch-wan): gathered_buffer here is larger than the actual required size in draft extend,
-        # we may use a smaller buffer in draft extend.

         cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
         dp_rank = get_attention_dp_rank()
@@ -190,6 +188,19 @@ def compute_dp_attention_metadata(self):
         self.dp_local_start_pos = dp_local_start_pos
         self.dp_local_num_tokens = dp_local_num_tokens

+        if self.global_num_tokens_for_logprob_cpu is not None:
+            # create a smaller buffer to reduce peak memory usage
+            self.gathered_buffer = torch.empty(
+                (
+                    sum(self.global_num_tokens_for_logprob_cpu),
+                    self.gathered_buffer.shape[1],
+                ),
+                dtype=self.gathered_buffer.dtype,
+                device=self.gathered_buffer.device,
+            )
+        else:
+            self.gathered_buffer = torch.empty_like(self.gathered_buffer)
+

 class LogitsProcessor(nn.Module):
     def __init__(
@@ -434,7 +445,7 @@ def _get_logits(
         if self.do_tensor_parallel_all_gather_dp_attn:
             logits_metadata.compute_dp_attention_metadata()
             hidden_states, local_hidden_states = (
-                torch.empty_like(logits_metadata.gathered_buffer),
+                logits_metadata.gathered_buffer,
                 hidden_states,
             )
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
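
A minimal standalone sketch (not part of this diff) of the memory effect: sizing the gather destination from the summed per-rank logprob token counts and reusing it directly avoids both the padded allocation and the extra empty_like() copy previously made in _get_logits. All names and sizes below are hypothetical.

import torch

# Hypothetical shapes for illustration only.
hidden_dim = 4096
num_tokens_for_logprob_per_rank = [3, 1, 2, 5]   # tokens that need logprobs on each DP rank
padded_tokens_per_rank = 8                        # worst-case padded size used by the old buffer

# Old behavior: a padded gather buffer, plus an extra empty_like() copy at gather time.
padded_buffer = torch.empty(
    padded_tokens_per_rank * len(num_tokens_for_logprob_per_rank), hidden_dim
)
old_gather_dst = torch.empty_like(padded_buffer)  # second full-size allocation

# New behavior: one buffer sized to exactly the tokens that will be gathered, reused directly.
new_gather_dst = torch.empty(sum(num_tokens_for_logprob_per_rank), hidden_dim)

print(old_gather_dst.numel(), "vs", new_gather_dst.numel())  # 131072 vs 45056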