@@ -404,14 +404,24 @@ def _gather_hidden_states_and_residual(
     if context.attn_dp_size != 1:
         if context.attn_tp_rank == 0:
             hidden_states += residual
-        residual.copy_(hidden_states)
-        if hidden_states.shape[0] != 0:
-            hidden_states = layernorm(hidden_states)
+
+        # Perform layernorm on smaller data before comm. Only valid when attn_tp_size is 1 (tp_size == dp_size)
+        use_layer_norm_before_gather = context.attn_tp_size == 1
+        if use_layer_norm_before_gather:
+            residual.copy_(hidden_states)
+            if hidden_states.shape[0] != 0:
+                hidden_states = layernorm(hidden_states)
+
         hidden_states, local_hidden_states = (
             forward_batch.gathered_buffer,
             hidden_states,
         )
         dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
+
+        if not use_layer_norm_before_gather:
+            dp_scatter(residual, hidden_states, forward_batch)
+            if hidden_states.shape[0] != 0:
+                hidden_states = layernorm(hidden_states)
     else:
         # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
         # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
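Why the reordering is safe: layernorm normalizes each token (row) independently over the hidden dimension, so it commutes with gathering along the token dimension. Below is a minimal sketch of that equivalence, assuming plain torch.nn.LayerNorm in place of the model's fused norm and a list of tensors in place of the actual DP gather; the shard sizes and names are illustrative, not taken from the commit.

# Minimal sketch (illustrative, not the SGLang implementation): layernorm is a
# per-token, row-wise op, so normalizing each DP rank's local shard before the
# gather matches gathering first and normalizing the full buffer afterwards.
import torch

torch.manual_seed(0)
hidden_size = 64  # hypothetical hidden dimension
layernorm = torch.nn.LayerNorm(hidden_size)

# Stand-ins for per-rank local hidden states (the attn_tp_size == 1 case);
# the real code gathers into forward_batch.gathered_buffer, not torch.cat.
rank_shards = [torch.randn(3, hidden_size), torch.randn(5, hidden_size)]

# Old order: gather all tokens first, then layernorm the full gathered buffer.
gather_then_norm = layernorm(torch.cat(rank_shards, dim=0))

# New order: layernorm each (smaller) local shard, then gather.
norm_then_gather = torch.cat([layernorm(s) for s in rank_shards], dim=0)

torch.testing.assert_close(gather_then_norm, norm_then_gather)

The payoff, per the diff, is that in the attn_tp_size == 1 fast path each rank normalizes only its own tokens before dp_gather_partial and can also copy the residual locally, skipping the dp_scatter; for attn_tp_size > 1 the original gather-then-normalize order is kept.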