20
20
from packaging import version
21
21
from torch .distributed ._shard .sharded_tensor import ShardedTensor
22
22
from torch .distributed ._tensor import DTensor
23
+ from torch .distributed .fsdp import FullyShardedDataParallel as FSDP
23
24
from torch .utils .data import DataLoader
24
25
from torchmetrics import Metric , MetricCollection
25
26
from torchmetrics .classification import MulticlassAccuracy
33
34
from composer .utils import FSDPConfig , TPConfig , dist , parse_uri
34
35
from composer .utils .checkpoint import dist_cp_load
35
36
from composer .utils .file_helpers import get_file
36
- from composer .utils .object_store import S3ObjectStore
37
+ from composer .utils .object_store import UCObjectStore
37
38
from composer .utils .reproducibility import get_rng_state
38
39
from tests .common import RandomClassificationDataset , deep_compare
39
40
from tests .common .markers import world_size
@@ -517,6 +518,7 @@ def test_fsdp_mixed_with_sync(
517
518
'0.28.0' ,
518
519
'0.29.0' ,
519
520
'0.30.0' ,
521
+ '0.31.0' ,
520
522
],
521
523
)
522
524
@pytest .mark .filterwarnings (r'ignore:.*metrics are not saved with sharded state dict.*:UserWarning' )
@@ -529,8 +531,7 @@ def test_fsdp_load_old_checkpoint(
529
531
precision : str ,
530
532
sharding_strategy : str ,
531
533
state_dict_type : str ,
532
- s3_bucket : str ,
533
- s3_read_only_prefix : str ,
534
+ uc_volume_path : str ,
534
535
composer_version : str ,
535
536
):
536
537
if composer_version == '0.18.1' and state_dict_type == 'full' and precision == 'amp_bf16' and sharding_strategy == 'FULL_SHARD' :
@@ -540,25 +541,27 @@ def test_fsdp_load_old_checkpoint(
540
541
if state_dict_type == 'sharded' :
541
542
pytest .skip ('Loading legacy sharded checkpoints are not supported after v0.25.0.' )
542
543
543
- load_path_dir = (
544
- f's3://{ s3_bucket } /{ s3_read_only_prefix } /backwards_compatibility/'
545
- f'{ composer_version } /{ sharding_strategy .lower ()} _{ state_dict_type } _'
546
- f'{ precision } /'
544
+ load_path_dir = os .path .join (
545
+ f'dbfs:/{ uc_volume_path } ' ,
546
+ 'backwards_compatibility' ,
547
+ composer_version ,
548
+ f'{ sharding_strategy .lower ()} _{ state_dict_type } _{ precision } ' ,
547
549
)
548
550
if ((version .parse (composer_version ) > version .parse ('0.15.0' )) and state_dict_type != 'full' ):
549
- load_path_dir = (load_path_dir + 'ep0-ba2/ ' )
551
+ load_path_dir = os . path . join (load_path_dir , 'ep0-ba2' )
550
552
551
- load_path = load_path_dir + f'ba2_rank0.pt'
553
+ load_path = os . path . join ( load_path_dir , f'ba2_rank0.pt' )
552
554
else :
553
- load_path = (
554
- f's3://{ s3_bucket } /{ s3_read_only_prefix } /backwards_compatibility/'
555
- f'{ composer_version } /{ sharding_strategy .lower ()} _{ state_dict_type } _'
556
- f'{ precision } /'
555
+ load_path = os .path .join (
556
+ f'dbfs:/{ uc_volume_path } ' ,
557
+ 'backwards_compatibility' ,
558
+ composer_version ,
559
+ f'{ sharding_strategy .lower ()} _{ state_dict_type } _{ precision } ' ,
557
560
)
558
561
if state_dict_type == 'full' :
559
- load_path += 'ba2_rank0.pt'
562
+ load_path = os . path . join ( load_path , 'ba2_rank0.pt' )
560
563
else :
561
- load_path += 'ep0-ba2/'
564
+ load_path = os . path . join ( load_path , 'ep0-ba2' )
562
565
563
566
if composer_version == '0.15.1' :
564
567
num_classes = 8 # This parameter setting is very important. Don't change or the test will fail.
@@ -619,7 +622,7 @@ def test_fsdp_load_old_checkpoint(
619
622
'rng' : get_rng_state (),
620
623
}
621
624
622
- object_store = S3ObjectStore ( bucket = f' { s3_bucket } ' )
625
+ object_store = UCObjectStore ( path = uc_volume_path )
623
626
storage_reader = DistCPObjectStoreReader (
624
627
source_path = parsed_load_path ,
625
628
destination_path = destination ,
0 commit comments