1
1
# Copyright 2022 MosaicML Composer authors
2
2
# SPDX-License-Identifier: Apache-2.0
3
3
4
+ import contextlib
5
+
4
6
import pytest
5
7
import torch
6
8
from packaging import version
7
9
from torch .utils .data import DataLoader
8
10
11
+ from composer .optim import DecoupledSGDW
9
12
from composer .trainer .trainer import Trainer
10
13
from composer .utils import dist
11
14
from tests .common import (
17
20
18
21
@pytest.mark.gpu
@world_size(4)
@pytest.mark.filterwarnings(r'ignore:.*\(TP\) is experimental.*:FutureWarning')
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='requires PyTorch 2.3+')
@pytest.mark.parametrize('tensor_parallel_degree', [1, 2])
def test_tp_train(world_size: int, tensor_parallel_degree: int):
    """Train a small model for three batches under tensor parallelism (TP) + FSDP.

    Args:
        world_size: number of distributed processes (fixed at 4 by the
            ``@world_size(4)`` marker).
        tensor_parallel_degree: TP degree under test. A degree of 1 is a
            no-op, and the ``Trainer`` is expected to emit a ``UserWarning``
            saying so; a degree of 2 should construct and train silently.
    """
    # Imported lazily: torch.distributed.tensor.parallel needs PyTorch 2.3+,
    # which the skipif marker above guarantees at this point.
    from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

    model = SimpleModel()
    optimizer = DecoupledSGDW(model.parameters(), lr=0.1)
    dataset = RandomClassificationDataset(size=8)
    dataloader = DataLoader(dataset, batch_size=2, sampler=dist.get_sampler(dataset))

    # Shard the first linear layer column-wise and the second row-wise — the
    # usual pairing for two consecutive linear layers under TP.
    # NOTE(review): the `layer_plan = {` and `'fc1': ...` lines were hidden by
    # a diff collapse in the source view and are reconstructed here — confirm
    # against the repository.
    layer_plan = {
        'fc1': ColwiseParallel(),
        'fc2': RowwiseParallel(),
    }

    # Degree 1 must warn that TP is a no-op; any higher degree should be quiet.
    if tensor_parallel_degree == 1:
        expected_warning = 'Received tensor_parallel_degree of 1, which is a no-op. Tensor parallelism will not be used.'
        ctx = pytest.warns(UserWarning, match=expected_warning)
    else:
        ctx = contextlib.nullcontext()

    with ctx:
        trainer = Trainer(
            model=model,
            optimizers=optimizer,
            train_dataloader=dataloader,
            parallelism_config={
                'tp': {
                    'layer_plan': layer_plan,
                    'tensor_parallel_degree': tensor_parallel_degree,
                },
                'fsdp': {},
            },
            max_duration='3ba',
        )

    trainer.fit()
48
61
0 commit comments