Skip to content

Commit aab764f

Browse files
committed
Only use layernorm-before-gather when tp_size == dp_size
1 parent 2d40945 commit aab764f

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

python/sglang/srt/layers/communicator.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -404,14 +404,24 @@ def _gather_hidden_states_and_residual(
404404
if context.attn_dp_size != 1:
405405
if context.attn_tp_rank == 0:
406406
hidden_states += residual
407-
residual.copy_(hidden_states)
408-
if hidden_states.shape[0] != 0:
409-
hidden_states = layernorm(hidden_states)
407+
408+
# Perform layernorm on smaller data before comm. Only valid when attn_tp_size is 1 (tp_size == dp_size)
409+
use_layer_norm_before_gather = context.attn_tp_size == 1
410+
if use_layer_norm_before_gather:
411+
residual.copy_(hidden_states)
412+
if hidden_states.shape[0] != 0:
413+
hidden_states = layernorm(hidden_states)
414+
410415
hidden_states, local_hidden_states = (
411416
forward_batch.gathered_buffer,
412417
hidden_states,
413418
)
414419
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
420+
421+
if not use_layer_norm_before_gather:
422+
dp_scatter(residual, hidden_states, forward_batch)
423+
if hidden_states.shape[0] != 0:
424+
hidden_states = layernorm(hidden_states)
415425
else:
416426
# According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
417427
# We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).

0 commit comments

Comments (0)