From ea87e0fff0b46c8085471c153171d8829192f9bc Mon Sep 17 00:00:00 2001
From: Marc Sun
Date: Thu, 25 Sep 2025 12:26:32 +0000
Subject: [PATCH 1/2] fix

---
 src/transformers/trainer.py | 104 ++++++++----------------------
 1 file changed, 22 insertions(+), 82 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 76d36327b308..e3dba869d167 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -208,13 +208,7 @@
 
 if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp
-    from smdistributed.modelparallel import __version__ as SMP_VERSION
-
-    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
-
     from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
-else:
-    IS_SAGEMAKER_MP_POST_1_10 = False
 
 
 if is_safetensors_available():
@@ -728,22 +722,14 @@ def __init__(
             if args.bf16:
                 raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ")
 
-            if IS_SAGEMAKER_MP_POST_1_10:
-                # When there's mismatch between SMP config and trainer argument, use SMP config as truth
-                if args.fp16 != smp.state.cfg.fp16:
-                    logger.warning(
-                        f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
-                        f"but FP16 provided in trainer argument is {args.fp16}, "
-                        f"setting to {smp.state.cfg.fp16}"
-                    )
-                    args.fp16 = smp.state.cfg.fp16
-            else:
-                # smp < 1.10 does not support fp16 in trainer.
-                if hasattr(smp.state.cfg, "fp16"):
-                    logger.warning(
-                        f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
-                        "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
-                    )
+            if args.fp16 != smp.state.cfg.fp16:
+                logger.warning(
+                    f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
+                    f"but FP16 provided in trainer argument is {args.fp16}, "
+                    f"setting to {smp.state.cfg.fp16}"
+                )
+                args.fp16 = smp.state.cfg.fp16
+
         if (args.fp16 or args.bf16) and args.half_precision_backend == "auto":
             if args.device == torch.device("cpu"):
                 if args.fp16:
@@ -1256,8 +1242,8 @@ def create_optimizer_and_scheduler(self, num_training_steps: int):
         `create_scheduler`) in a subclass.
         """
         self.create_optimizer()
-        if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16:
-            # If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer
+        if is_sagemaker_mp_enabled() and smp.state.cfg.fp16:
+            # If fp16 is enabled, we unwrap the optimizer
             optimizer = self.optimizer.optimizer
         else:
             optimizer = self.optimizer
@@ -2954,26 +2940,9 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
         if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file) or is_fsdp_ckpt:
             # If the model is on the GPU, it still works!
             if is_sagemaker_mp_enabled():
-                if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
-                    # If the 'user_content.pt' file exists, load with the new smp api.
-                    # Checkpoint must have been saved with the new smp api.
-                    smp.resume_from_checkpoint(
-                        path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False
-                    )
-                else:
-                    # If the 'user_content.pt' file does NOT exist, load with the old smp api.
-                    # Checkpoint must have been saved with the old smp api.
-                    if hasattr(self.args, "fp16") and self.args.fp16 is True:
-                        logger.warning(
-                            "Enabling FP16 and loading from smp < 1.10 checkpoint together is not supported."
-                        )
-                    check_torch_load_is_safe()
-                    state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)
-                    # Required for smp to not auto-translate state_dict from hf to smp (is already smp).
-                    state_dict["_smp_is_partial"] = False
-                    load_result = model.load_state_dict(state_dict, strict=True)
-                    # release memory
-                    del state_dict
+                smp.resume_from_checkpoint(
+                    path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False
+                )
             elif self.is_fsdp_enabled:
                 load_fsdp_model(
                     self.accelerator.state.fsdp_plugin,
@@ -3067,26 +3036,12 @@ def _load_best_model(self):
                 ):
                     has_been_loaded = True
                     if is_sagemaker_mp_enabled():
-                        if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
-                            # If the 'user_content.pt' file exists, load with the new smp api.
-                            # Checkpoint must have been saved with the new smp api.
-                            smp.resume_from_checkpoint(
-                                path=self.state.best_model_checkpoint,
-                                tag=WEIGHTS_NAME,
-                                partial=False,
-                                load_optimizer=False,
-                            )
-                        else:
-                            # If the 'user_content.pt' file does NOT exist, load with the old smp api.
-                            # Checkpoint must have been saved with the old smp api.
-                            if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
-                                state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
-                            else:
-                                check_torch_load_is_safe()
-                                state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)
-
-                            state_dict["_smp_is_partial"] = False
-                            load_result = model.load_state_dict(state_dict, strict=True)
+                        smp.resume_from_checkpoint(
+                            path=self.state.best_model_checkpoint,
+                            tag=WEIGHTS_NAME,
+                            partial=False,
+                            load_optimizer=False,
+                        )
                     else:
                         if _is_peft_model(model):
                             # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
@@ -3554,21 +3509,8 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 self.lr_scheduler.load_state_dict(lr_scheduler_state)
         else:
             if is_sagemaker_mp_enabled():
-                if os.path.isfile(os.path.join(checkpoint, "user_content.pt")):
-                    # Optimizer checkpoint was saved with smp >= 1.10
-                    def opt_load_hook(mod, opt):
-                        opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
-
-                else:
-                    # Optimizer checkpoint was saved with smp < 1.10
-                    def opt_load_hook(mod, opt):
-                        if IS_SAGEMAKER_MP_POST_1_10:
-                            opt.load_state_dict(
-                                smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True, back_compat=True)
-                            )
-                        else:
-                            opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
-
+                def opt_load_hook(mod, opt):
+                    opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
                 self.model_wrapped.register_post_step_hook(opt_load_hook)
             else:
                 # We use the CPU when training on one GPU to avoid OOM for GPU RAM when training big models.
@@ -4191,9 +4133,7 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
             state_dict = self.model_wrapped.state_dict()
             if self.args.should_save:
                 self._save(output_dir, state_dict=state_dict)
-            if IS_SAGEMAKER_MP_POST_1_10:
-                # 'user_content.pt' indicates model state_dict saved with smp >= 1.10
-                Path(os.path.join(output_dir, "user_content.pt")).touch()
+            Path(os.path.join(output_dir, "user_content.pt")).touch()
         # We are in N-D parallelism if we have parallelism_config set, so we check accelerate if we're on a to_save rank
         elif getattr(self.accelerator, "parallelism_config", None) is not None:
             if self.accelerator.should_save_model:

From ec0d52aba8e77832a5beda96965ef55fe4225313 Mon Sep 17 00:00:00 2001
From: Marc Sun
Date: Thu, 25 Sep 2025 12:29:03 +0000
Subject: [PATCH 2/2] fix

---
 src/transformers/trainer.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index e3dba869d167..23dd635a672a 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -208,6 +208,7 @@
 
 if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp
+
     from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
 
 
@@ -3509,8 +3510,10 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 self.lr_scheduler.load_state_dict(lr_scheduler_state)
         else:
             if is_sagemaker_mp_enabled():
+
                 def opt_load_hook(mod, opt):
                     opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
+
                 self.model_wrapped.register_post_step_hook(opt_load_hook)
             else:
                 # We use the CPU when training on one GPU to avoid OOM for GPU RAM when training big models.
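Note: after these two commits, the SageMaker Model Parallelism (SMP) paths in
trainer.py assume SMP >= 1.10 throughout: the IS_SAGEMAKER_MP_POST_1_10 flag and
the legacy torch.load fallback are gone, checkpoints are always loaded through the
new SMP API, and "user_content.pt" is always touched on save. Below is a minimal
sketch of the consolidated resume path the patch leaves behind; resume_smp_checkpoint
is a hypothetical helper for illustration only, while smp.resume_from_checkpoint,
smp.load, and register_post_step_hook are exactly the calls kept by the patch, and
the two filename constants mirror the values trainer.py defines.

    import os

    import smdistributed.modelparallel.torch as smp  # only importable inside a SageMaker MP job

    WEIGHTS_NAME = "pytorch_model.bin"  # same constant trainer.py uses for model weights
    OPTIMIZER_NAME = "optimizer.pt"     # same constant trainer.py uses for optimizer state


    def resume_smp_checkpoint(model_wrapped, checkpoint_dir):
        """Hypothetical wrapper mirroring the consolidated Trainer logic."""
        # Model weights: always the new (>= 1.10) SMP API, loading a full
        # (non-partial) checkpoint and skipping the optimizer state.
        smp.resume_from_checkpoint(
            path=checkpoint_dir, tag=WEIGHTS_NAME, partial=False, load_optimizer=False
        )

        # Optimizer state: SMP loads it through a post-step hook rather than an
        # eager load_state_dict call, matching the patched _load_optimizer_and_scheduler.
        def opt_load_hook(mod, opt):
            opt.load_state_dict(smp.load(os.path.join(checkpoint_dir, OPTIMIZER_NAME), partial=True))

        model_wrapped.register_post_step_hook(opt_load_hook)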