
Commit de592a5

trevor-mssssnow authored and committed
Do layernorm before allgather for DP attention (#8631)
1 parent: 0bdca82


python/sglang/srt/layers/communicator.py

Lines changed: 13 additions & 3 deletions
@@ -404,14 +404,24 @@ def _gather_hidden_states_and_residual(
     if context.attn_dp_size != 1:
         if context.attn_tp_rank == 0:
             hidden_states += residual
+
+        # Perform layernorm on smaller data before comm. Only valid when attn_tp_size is 1 (tp_size == dp_size)
+        use_layer_norm_before_gather = context.attn_tp_size == 1
+        if use_layer_norm_before_gather:
+            residual.copy_(hidden_states)
+            if hidden_states.shape[0] != 0:
+                hidden_states = layernorm(hidden_states)
+
         hidden_states, local_hidden_states = (
             forward_batch.gathered_buffer,
             hidden_states,
         )
         dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
-        dp_scatter(residual, hidden_states, forward_batch)
-        if hidden_states.shape[0] != 0:
-            hidden_states = layernorm(hidden_states)
+
+        if not use_layer_norm_before_gather:
+            dp_scatter(residual, hidden_states, forward_batch)
+            if hidden_states.shape[0] != 0:
+                hidden_states = layernorm(hidden_states)
     else:
         # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
         # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
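Why the reordering is valid: layernorm normalizes each token row independently, so applying it to each DP rank's local slice before dp_gather_partial produces the same gathered tensor as applying it to the full batch after the gather. Each rank then normalizes only its own tokens rather than the whole gathered batch, and residual can be filled by a local copy before the gather instead of being scattered back out of the gathered buffer afterwards. A minimal sketch of the row-wise equivalence (plain PyTorch, not SGLang code; torch.cat stands in for dp_gather_partial and nn.LayerNorm for the model's layernorm):

    # Illustrative only: layernorm commutes with concatenation because it
    # operates on each token row independently of all other rows.
    import torch

    torch.manual_seed(0)
    norm = torch.nn.LayerNorm(64)

    # Stand-ins for the per-rank token slices that the all-gather concatenates.
    local_slices = [torch.randn(3, 64), torch.randn(5, 64)]

    # Old order: gather first, then normalize the full gathered batch.
    gather_then_norm = norm(torch.cat(local_slices))

    # New order (attn_tp_size == 1): each rank normalizes only its local slice,
    # then the already-normalized slices are gathered.
    norm_then_gather = torch.cat([norm(s) for s in local_slices])

    assert torch.allclose(gather_then_norm, norm_then_gather, atol=1e-6)

Note the two guards in the patch: residual.copy_(hidden_states) preserves the pre-norm activations locally, which is why dp_scatter(residual, ...) can be skipped on this path, and the hidden_states.shape[0] != 0 check avoids calling layernorm on ranks that currently hold zero tokens.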
