
Commit 566d262

Don't use TP when tensor_parallel_degree is 1 (#3636)
Co-authored-by: Eitan Turok <[email protected]>
1 parent 82b9d1f commit 566d262

File tree

2 files changed: +32 −13 lines changed

composer/core/state.py
tests/trainer/test_tp.py

composer/core/state.py
Lines changed: 6 additions & 0 deletions

```diff
@@ -612,6 +612,12 @@ def _validate_parallelism_configs(self):
                 'Tensor parallelism (TP) currently requires FSDP with use_orig_params=True, '
                 'which is the default and recommended setting.',
             )
+        if self.tp_config.tensor_parallel_degree == 1:
+            warnings.warn(
+                'Received tensor_parallel_degree of 1, which is a no-op. Tensor parallelism will not be used.',
+                UserWarning,
+            )
+            self.tp_config = None
 
         # Load monolith rank0 only
         if self.load_monolith_rank0_only:
```
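For context, a minimal standalone sketch of the behavior this adds. The helper name `_maybe_drop_tp_config` is hypothetical; in the actual change the check lives inline in `State._validate_parallelism_configs`:

```python
import warnings


def _maybe_drop_tp_config(tp_config):
    """Hypothetical standalone helper mirroring the inline check added above.

    A tensor_parallel_degree of 1 shards each weight across a single rank,
    i.e. it shards nothing, so the config is dropped rather than silently
    spinning up the TP machinery.
    """
    if tp_config is not None and tp_config.tensor_parallel_degree == 1:
        warnings.warn(
            'Received tensor_parallel_degree of 1, which is a no-op. Tensor parallelism will not be used.',
            UserWarning,
        )
        return None  # downstream code only needs a `tp_config is None` check
    return tp_config
```

With the config nulled out here, the rest of the trainer never has to special-case a degree-1 TP setup.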

tests/trainer/test_tp.py
Lines changed: 26 additions & 13 deletions

```diff
@@ -1,11 +1,14 @@
 # Copyright 2022 MosaicML Composer authors
 # SPDX-License-Identifier: Apache-2.0
 
+import contextlib
+
 import pytest
 import torch
 from packaging import version
 from torch.utils.data import DataLoader
 
+from composer.optim import DecoupledSGDW
 from composer.trainer.trainer import Trainer
 from composer.utils import dist
 from tests.common import (
@@ -17,12 +20,14 @@
 
 @pytest.mark.gpu
 @world_size(4)
-@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='requires PyTorch 2.3+')
 @pytest.mark.filterwarnings(r'ignore:.*\(TP\) is experimental.*:FutureWarning')
-def test_tp_train(world_size: int):
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='requires PyTorch 2.3+')
+@pytest.mark.parametrize('tensor_parallel_degree', [1, 2])
+def test_tp_train(world_size: int, tensor_parallel_degree: int):
     from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
 
     model = SimpleModel()
+    optimizer = DecoupledSGDW(model.parameters(), lr=0.1)
     dataset = RandomClassificationDataset(size=8)
     dataloader = DataLoader(dataset, batch_size=2, sampler=dist.get_sampler(dataset))
 
@@ -31,18 +36,26 @@ def test_tp_train(world_size: int):
         'fc2': RowwiseParallel(),
     }
 
-    trainer = Trainer(
-        model=model,
-        train_dataloader=dataloader,
-        parallelism_config={
-            'tp': {
-                'layer_plan': layer_plan,
-                'tensor_parallel_degree': 2,
+    if tensor_parallel_degree == 1:
+        expected_warning = 'Received tensor_parallel_degree of 1, which is a no-op. Tensor parallelism will not be used.'
+        ctx = pytest.warns(UserWarning, match=expected_warning)
+    else:
+        ctx = contextlib.nullcontext()
+
+    with ctx:
+        trainer = Trainer(
+            model=model,
+            optimizers=optimizer,
+            train_dataloader=dataloader,
+            parallelism_config={
+                'tp': {
+                    'layer_plan': layer_plan,
+                    'tensor_parallel_degree': tensor_parallel_degree,
+                },
+                'fsdp': {},
             },
-            'fsdp': {},
-        },
-        max_duration='3ba',
-    )
+            max_duration='3ba',
+        )
 
     trainer.fit()
 
```
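One detail of the updated test worth calling out: it selects the assertion context up front, using `pytest.warns(...)` when the warning is expected and `contextlib.nullcontext()` otherwise, so a single parametrized body covers both paths. A minimal self-contained sketch of that pattern, with a toy `maybe_warn` function standing in for the `Trainer` construction:

```python
import contextlib
import warnings

import pytest


def maybe_warn(degree: int) -> None:
    # Toy stand-in for Trainer(...): warns only for the no-op degree.
    if degree == 1:
        warnings.warn('degree of 1 is a no-op', UserWarning)


@pytest.mark.parametrize('degree', [1, 2])
def test_maybe_warn(degree: int) -> None:
    # Pick the assertion context up front, exactly as test_tp_train does above.
    ctx = pytest.warns(UserWarning, match='no-op') if degree == 1 else contextlib.nullcontext()
    with ctx:
        maybe_warn(degree)
```

This keeps the warning assertion and the normal path in one test body instead of duplicating the `Trainer` setup across two tests.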