@@ -404,14 +404,24 @@ def _gather_hidden_states_and_residual(
         if context.attn_dp_size != 1:
             if context.attn_tp_rank == 0:
                 hidden_states += residual
+
+            # Perform layernorm on smaller data before comm. Only valid when attn_tp_size is 1 (tp_size == dp_size)
+            use_layer_norm_before_gather = context.attn_tp_size == 1
+            if use_layer_norm_before_gather:
+                residual.copy_(hidden_states)
+                if hidden_states.shape[0] != 0:
+                    hidden_states = layernorm(hidden_states)
+
             hidden_states, local_hidden_states = (
                 forward_batch.gathered_buffer,
                 hidden_states,
             )
             dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
-            dp_scatter(residual, hidden_states, forward_batch)
-            if hidden_states.shape[0] != 0:
-                hidden_states = layernorm(hidden_states)
+
+            if not use_layer_norm_before_gather:
+                dp_scatter(residual, hidden_states, forward_batch)
+                if hidden_states.shape[0] != 0:
+                    hidden_states = layernorm(hidden_states)
         else:
             # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
             # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
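The reordering works because LayerNorm normalizes each token row over the hidden dimension independently, so it commutes with gathering along the token dimension; when attn_tp_size == 1 each DP rank owns a disjoint slice of the tokens, so normalizing the local shard touches fewer rows, and the post-gather dp_scatter of the residual can be skipped (the residual is instead captured with residual.copy_ before the gather). A minimal sketch in plain PyTorch (not the sglang helpers) that checks this equivalence:

import torch

hidden_size = 8
layernorm = torch.nn.LayerNorm(hidden_size)

# Stand-ins for per-DP-rank token shards; one is empty, mirroring the
# `hidden_states.shape[0] != 0` guard in the diff.
shards = [
    torch.randn(3, hidden_size),
    torch.randn(0, hidden_size),
    torch.randn(5, hidden_size),
]

# Baseline ordering: gather first, then layernorm the full token batch.
after_gather = layernorm(torch.cat(shards, dim=0))

# Optimized ordering: layernorm each local shard, then gather.
before_gather = torch.cat([layernorm(s) for s in shards], dim=0)

# Row-wise normalization makes the two orderings numerically identical.
torch.testing.assert_close(after_gather, before_gather)

In the real path the pre-gather variant also launches layernorm over only the local tokens per rank rather than the full gathered batch, which is the "layernorm on smaller data before comm" the added comment describes.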