@@ -170,8 +170,6 @@ def from_forward_batch(cls, forward_batch: ForwardBatch):
         )

     def compute_dp_attention_metadata(self):
-        # TODO(ch-wan): gathered_buffer here is larger than the actual required size in draft extend,
-        # we may use a smaller buffer in draft extend.

         cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
         dp_rank = get_attention_dp_rank()
@@ -186,6 +184,19 @@ def compute_dp_attention_metadata(self):
         self.dp_local_start_pos = dp_local_start_pos
         self.dp_local_num_tokens = dp_local_num_tokens

+        if self.global_num_tokens_for_logprob_cpu is not None:
+            # create a smaller buffer to reduce peak memory usage
+            self.gathered_buffer = torch.empty(
+                (
+                    sum(self.global_num_tokens_for_logprob_cpu),
+                    self.gathered_buffer.shape[1],
+                ),
+                dtype=self.gathered_buffer.dtype,
+                device=self.gathered_buffer.device,
+            )
+        else:
+            self.gathered_buffer = torch.empty_like(self.gathered_buffer)
+

 class LogitsProcessor(nn.Module):
     def __init__(
@@ -430,7 +441,7 @@ def _get_logits(
         if self.do_tensor_parallel_all_gather_dp_attn:
             logits_metadata.compute_dp_attention_metadata()
             hidden_states, local_hidden_states = (
-                torch.empty_like(logits_metadata.gathered_buffer),
+                logits_metadata.gathered_buffer,
                 hidden_states,
             )
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
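
For context, the change above swaps a worst-case `gathered_buffer` allocation for one sized to the logprob tokens actually present across DP ranks (known ahead of the gather via `global_num_tokens_for_logprob_cpu`), and `_get_logits` then gathers directly into that buffer rather than into a fresh `empty_like` copy. Below is a minimal, self-contained sketch of the right-sizing idea only; the names and sizes (`hidden_size`, `max_tokens_per_rank`, the per-rank counts) are illustrative assumptions, not sglang's API.

```python
import torch

# Sketch: allocate a gather buffer for exactly the rows that will be gathered,
# instead of reusing a worst-case buffer, to reduce peak memory usage.

hidden_size = 8                                 # illustrative hidden dimension
max_tokens_per_rank = 16                        # hypothetical worst-case tokens per rank
global_num_tokens_for_logprob = [3, 5, 2, 4]    # hypothetical per-DP-rank token counts

# Worst-case allocation: rows for the maximum possible tokens on every rank.
worst_case_buffer = torch.empty(
    (max_tokens_per_rank * len(global_num_tokens_for_logprob), hidden_size),
    dtype=torch.float16,
)

# Right-sized allocation: rows only for the tokens that actually need logprobs,
# reusing the dtype and device of the original buffer as the commit does.
right_sized_buffer = torch.empty(
    (sum(global_num_tokens_for_logprob), hidden_size),
    dtype=worst_case_buffer.dtype,
    device=worst_case_buffer.device,
)

print(worst_case_buffer.shape)   # torch.Size([64, 8])
print(right_sized_buffer.shape)  # torch.Size([14, 8])
```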