
Commit 5435e8a (parent: 7eac525)
Author: luoyuan.luo

Add torch_symm_mem server arg

5 files changed: +22 additions, -8 deletions

benchmark/lora/launch_server.py

Lines changed: 7 additions & 0 deletions
@@ -28,6 +28,8 @@ def launch_server(args):
         cmd += "--disable-custom-all-reduce"
     if args.enable_mscclpp:
         cmd += "--enable-mscclpp"
+    if args.enable_torch_symm_mem:
+        cmd += "--enable-torch-symm-mem"
     print(cmd)
     os.system(cmd)

@@ -70,6 +72,11 @@ def launch_server(args):
         action="store_true",
         help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.",
     )
+    parser.add_argument(
+        "--enable-torch-symm-mem",
+        action="store_true",
+        help="Enable using torch symm mem for all-reduce kernel and fall back to NCCL.",
+    )
     args = parser.parse_args()

     launch_server(args)
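With this change the benchmark launcher accepts the new flag and appends it to the server command it shells out to. A minimal invocation might look like the following; any other options the script requires (for example, which model to benchmark) are omitted here and depend on the script's remaining arguments:

    python benchmark/lora/launch_server.py --enable-torch-symm-mem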

python/sglang/srt/distributed/parallel_state.py

Lines changed: 7 additions & 6 deletions
@@ -208,7 +208,7 @@ class GroupCoordinator:
     use_pynccl: bool  # a hint of whether to use PyNccl
     use_pymscclpp: bool  # a hint of whether to use PyMsccl
     use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce
-    use_symm_mem: bool  # a hint of whether to use SymmMemAllReduce
+    use_torch_symm_mem: bool  # a hint of whether to use SymmMemAllReduce
     use_message_queue_broadcaster: (
         bool  # a hint of whether to use message queue broadcaster
     )
@@ -226,7 +226,7 @@ def __init__(
         use_pynccl: bool,
         use_pymscclpp: bool,
         use_custom_allreduce: bool,
-        use_symm_mem: bool,
+        use_torch_symm_mem: bool,
         use_hpu_communicator: bool,
         use_xpu_communicator: bool,
         use_npu_communicator: bool,
@@ -275,7 +275,7 @@ def __init__(
         self.use_pynccl = use_pynccl
         self.use_pymscclpp = use_pymscclpp
         self.use_custom_allreduce = use_custom_allreduce
-        self.use_symm_mem = use_symm_mem
+        self.use_torch_symm_mem = use_torch_symm_mem
         self.use_hpu_communicator = use_hpu_communicator
         self.use_xpu_communicator = use_xpu_communicator
         self.use_npu_communicator = use_npu_communicator
@@ -343,7 +343,7 @@ def __init__(
             logger.warning(f"Failed to initialize QuickAllReduce: {e}")

         self.symm_mem_comm: Optional[SymmMemCommunicator] = None
-        if self.use_symm_mem and self.world_size > 1:
+        if self.use_torch_symm_mem and self.world_size > 1:
             self.symm_mem_comm = SymmMemCommunicator(
                 group=self.cpu_group,
                 device=self.device,
@@ -453,6 +453,7 @@ def graph_capture(
         # custom allreduce  | enabled  | enabled  |
         # PyNccl            | disabled | enabled  |
         # PyMscclpp         | disabled | enabled  |
+        # TorchSymmMem      | disabled | enabled  |
         # torch.distributed | enabled  | disabled |
         #
         # Note: When custom quick allreduce is enabled, a runtime check
@@ -1223,7 +1224,7 @@ def init_world_group(
         use_pynccl=False,
         use_pymscclpp=False,
         use_custom_allreduce=False,
-        use_symm_mem=False,
+        use_torch_symm_mem=False,
         use_hpu_communicator=False,
         use_xpu_communicator=False,
         use_npu_communicator=False,
@@ -1254,7 +1255,7 @@ def init_model_parallel_group(
         use_pynccl=not _is_npu,
         use_pymscclpp=use_mscclpp_allreduce,
         use_custom_allreduce=use_custom_allreduce,
-        use_symm_mem=use_symm_mem_allreduce,
+        use_torch_symm_mem=use_symm_mem_allreduce,
         use_hpu_communicator=True,
         use_xpu_communicator=True,
         use_npu_communicator=True,
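For context, the renamed hint feeds the backend selection that the graph-capture table above describes: specialized kernels are tried first and torch.distributed (NCCL) is the fallback. The sketch below is illustrative only, not the real GroupCoordinator method; the communicator helper names (should_custom_ar, should_use_symm_mem) are assumptions and are not taken from this diff.

import torch
import torch.distributed as dist


def _all_reduce_dispatch(coordinator, tensor: torch.Tensor) -> torch.Tensor:
    # Illustrative fallback order: custom all-reduce kernel, then the torch
    # symmetric-memory path, and finally plain torch.distributed (NCCL).
    ca = getattr(coordinator, "ca_comm", None)
    if ca is not None and ca.should_custom_ar(tensor):  # assumed helper API
        return ca.custom_all_reduce(tensor)
    symm = getattr(coordinator, "symm_mem_comm", None)
    if symm is not None and symm.should_use_symm_mem(tensor):  # assumed helper API
        return symm.all_reduce(tensor)
    dist.all_reduce(tensor, group=coordinator.device_group)
    return tensor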

python/sglang/srt/layers/dp_attention.py

Lines changed: 1 addition & 1 deletion
@@ -263,7 +263,7 @@ def initialize_dp_attention(
         use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
         use_pymscclpp=False,
         use_custom_allreduce=False,
-        use_symm_mem=False,
+        use_torch_symm_mem=False,
         use_hpu_communicator=False,
         use_xpu_communicator=False,
         use_npu_communicator=False,

python/sglang/srt/model_executor/model_runner.py

Lines changed: 1 addition & 1 deletion
@@ -625,7 +625,7 @@ def init_torch_distributed(self):
         dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
         set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
         set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
-        set_symm_mem_all_reduce(self.server_args.enable_symm_mem)
+        set_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem)

         if not self.is_draft_worker:
             if self.device == "cpu":
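The call above follows the same pattern as set_custom_all_reduce and set_mscclpp_all_reduce: the server argument is stored as a process-wide hint that parallel_state later reads (as use_symm_mem_allreduce) when it builds the model-parallel GroupCoordinator. A minimal sketch of that pattern, with the module-level variable name assumed rather than taken from this diff:

_ENABLE_SYMM_MEM_ALL_REDUCE = False  # hypothetical module-level hint


def set_symm_mem_all_reduce(enable: bool) -> None:
    # Record the user's choice; init_model_parallel_group() consults it when
    # constructing the GroupCoordinator with use_torch_symm_mem=...
    global _ENABLE_SYMM_MEM_ALL_REDUCE
    _ENABLE_SYMM_MEM_ALL_REDUCE = enable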

python/sglang/srt/server_args.py

Lines changed: 6 additions & 0 deletions
@@ -358,6 +358,7 @@ class ServerArgs:
     disable_outlines_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
     enable_mscclpp: bool = False
+    enable_torch_symm_mem: bool = False
     disable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
@@ -2090,6 +2091,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
            action="store_true",
            help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.",
        )
+       parser.add_argument(
+           "--enable-torch-symm-mem",
+           action="store_true",
+           help="Enable using torch symm mem for all-reduce kernel and fall back to NCCL.",
+       )
        parser.add_argument(
            "--disable-overlap-schedule",
            action="store_true",
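Once registered in ServerArgs, the feature is toggled from the command line at launch time. A hypothetical example, assuming the standard sglang.launch_server entry point (model path and tensor-parallel size are placeholders, not part of this commit):

    python -m sglang.launch_server --model-path <your-model> --tp 8 --enable-torch-symm-mem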
