@@ -208,7 +208,7 @@ class GroupCoordinator:
208
208
use_pynccl : bool # a hint of whether to use PyNccl
209
209
use_pymscclpp : bool # a hint of whether to use PyMsccl
210
210
use_custom_allreduce : bool # a hint of whether to use CustomAllreduce
211
- use_symm_mem : bool # a hint of whether to use SymmMemAllReduce
211
+ use_torch_symm_mem : bool # a hint of whether to use SymmMemAllReduce
212
212
use_message_queue_broadcaster : (
213
213
bool # a hint of whether to use message queue broadcaster
214
214
)
@@ -226,7 +226,7 @@ def __init__(
226
226
use_pynccl : bool ,
227
227
use_pymscclpp : bool ,
228
228
use_custom_allreduce : bool ,
229
- use_symm_mem : bool ,
229
+ use_torch_symm_mem : bool ,
230
230
use_hpu_communicator : bool ,
231
231
use_xpu_communicator : bool ,
232
232
use_npu_communicator : bool ,
@@ -275,7 +275,7 @@ def __init__(
275
275
self .use_pynccl = use_pynccl
276
276
self .use_pymscclpp = use_pymscclpp
277
277
self .use_custom_allreduce = use_custom_allreduce
278
- self .use_symm_mem = use_symm_mem
278
+ self .use_torch_symm_mem = use_torch_symm_mem
279
279
self .use_hpu_communicator = use_hpu_communicator
280
280
self .use_xpu_communicator = use_xpu_communicator
281
281
self .use_npu_communicator = use_npu_communicator
@@ -343,7 +343,7 @@ def __init__(
343
343
logger .warning (f"Failed to initialize QuickAllReduce: { e } " )
344
344
345
345
self .symm_mem_comm : Optional [SymmMemCommunicator ] = None
346
- if self .use_symm_mem and self .world_size > 1 :
346
+ if self .use_torch_symm_mem and self .world_size > 1 :
347
347
self .symm_mem_comm = SymmMemCommunicator (
348
348
group = self .cpu_group ,
349
349
device = self .device ,
@@ -453,6 +453,7 @@ def graph_capture(
453
453
# custom allreduce | enabled | enabled |
454
454
# PyNccl | disabled| enabled |
455
455
# PyMscclpp | disabled| enabled |
456
+ # TorchSymmMem | disabled| enabled |
456
457
# torch.distributed | enabled | disabled|
457
458
#
458
459
# Note: When custom quick allreduce is enabled, a runtime check
@@ -1223,7 +1224,7 @@ def init_world_group(
1223
1224
use_pynccl = False ,
1224
1225
use_pymscclpp = False ,
1225
1226
use_custom_allreduce = False ,
1226
- use_symm_mem = False ,
1227
+ use_torch_symm_mem = False ,
1227
1228
use_hpu_communicator = False ,
1228
1229
use_xpu_communicator = False ,
1229
1230
use_npu_communicator = False ,
@@ -1254,7 +1255,7 @@ def init_model_parallel_group(
1254
1255
use_pynccl = not _is_npu ,
1255
1256
use_pymscclpp = use_mscclpp_allreduce ,
1256
1257
use_custom_allreduce = use_custom_allreduce ,
1257
- use_symm_mem = use_symm_mem_allreduce ,
1258
+ use_torch_symm_mem = use_symm_mem_allreduce ,
1258
1259
use_hpu_communicator = True ,
1259
1260
use_xpu_communicator = True ,
1260
1261
use_npu_communicator = True ,
0 commit comments