13
13
import warnings
14
14
from multiprocessing .context import SpawnProcess
15
15
from pathlib import Path
16
- from typing import Any , Optional , Sequence , Union
16
+ from typing import TYPE_CHECKING , Any , Optional , Sequence , Union
17
17
18
18
import numpy as np
19
19
import torch
48
48
from llmfoundry .utils .huggingface_hub_utils import \
49
49
edit_files_for_hf_compatibility
50
50
51
+ if TYPE_CHECKING :
52
+ from peft import PeftModel
53
+
51
54
try :
52
55
import transformer_engine .pytorch as te
53
56
is_te_imported = True
@@ -487,9 +490,9 @@ def _any_register_processes_error(self, device: Device) -> bool:
487
490
488
491
def transform_model_and_tokenizer (
489
492
self ,
490
- model : PreTrainedModel ,
493
+ model : Union [ PreTrainedModel , 'PeftModel' ] ,
491
494
tokenizer : PreTrainedTokenizerBase ,
492
- ) -> tuple [PreTrainedModel , PreTrainedTokenizerBase ]:
495
+ ) -> tuple [Union [ PreTrainedModel , 'PeftModel' ] , PreTrainedTokenizerBase ]:
493
496
"""Transform the model and tokenizer before saving.
494
497
495
498
This allows a subclass to modify the model and tokenizer before saving. The base class implementation will
@@ -537,8 +540,8 @@ def pre_register_edit(self, local_save_path: str):
537
540
538
541
def transform_model_pre_registration (
539
542
self ,
540
- model : PreTrainedModel ,
541
- ) -> PreTrainedModel :
543
+ model : Union [ PreTrainedModel , 'PeftModel' ] ,
544
+ ) -> Union [ PreTrainedModel , 'PeftModel' ] :
542
545
"""Transform the model before registering with MLflow.
543
546
544
547
This allows a subclass to modify the model before registering with MLflow. The base class implementation will
@@ -565,17 +568,23 @@ def _get_hf_model(self, state: State):
565
568
log .debug ('Gathering state dict' )
566
569
567
570
if state .is_model_ddp :
568
- original_model : PreTrainedModel = state .model .module .model # type: ignore
571
+ original_model : Union [
572
+ PreTrainedModel ,
573
+ 'PeftModel' ] = state .model .module .model # type: ignore
569
574
state_dict_model = state .model .module .model # type: ignore
570
- original_tokenizer = state .model .module .tokenizer # type: ignore
575
+ original_tokenizer : PreTrainedTokenizerBase = state .model .module .tokenizer # type: ignore
571
576
elif isinstance (state .model .model , FSDP ):
572
- original_model : PreTrainedModel = state .model .model .module # type: ignore
577
+ original_model : Union [
578
+ PreTrainedModel ,
579
+ 'PeftModel' ] = state .model .model .module # type: ignore
573
580
state_dict_model = state .model .model # type: ignore
574
- original_tokenizer = state .model .tokenizer # type: ignore
581
+ original_tokenizer : PreTrainedTokenizerBase = state .model .tokenizer # type: ignore
575
582
else :
576
- original_model : PreTrainedModel = state .model .model # type: ignore
583
+ original_model : Union [
584
+ PreTrainedModel ,
585
+ 'PeftModel' ] = state .model .model # type: ignore
577
586
state_dict_model = state .model .model # type: ignore
578
- original_tokenizer = state .model .tokenizer # type: ignore
587
+ original_tokenizer : PreTrainedTokenizerBase = state .model .tokenizer # type: ignore
579
588
580
589
cpu_offload = True
581
590
@@ -631,7 +640,7 @@ def tensor_hook(
631
640
632
641
# Transform HF config before building 2nd model copy
633
642
new_config = self .transform_config (
634
- original_config = original_model .config ,
643
+ original_config = original_model .config , # type: ignore
635
644
)
636
645
637
646
log .debug (f'Creating new model instance' )
@@ -640,25 +649,33 @@ def tensor_hook(
640
649
# initialization cost.
641
650
with init_empty_weights ():
642
651
if self .using_peft :
643
- active_adapter = original_model .active_adapter
644
- base_model = original_model .get_base_model ()
652
+ from peft import PeftModel
653
+ assert isinstance (original_model , PeftModel )
654
+ active_adapter = original_model .active_adapter # type: ignore
655
+ base_model : PreTrainedModel = original_model .get_base_model ( # type: ignore
656
+ )
645
657
new_base_model_instance = type (base_model )(new_config )
646
658
647
659
new_model_instance = type (original_model )(
648
- new_base_model_instance ,
649
- original_model .peft_config [active_adapter ],
660
+ new_base_model_instance , # type: ignore
661
+ original_model .
662
+ peft_config [active_adapter ], # type: ignore
650
663
)
651
664
del new_base_model_instance
652
665
else :
666
+ assert isinstance (original_model , PreTrainedModel )
653
667
new_model_instance = type (original_model )(new_config )
654
668
if new_model_instance .generation_config is not None :
669
+ assert original_model .generation_config is not None
655
670
new_model_instance .generation_config .update (
656
671
** original_model .generation_config .to_dict (),
657
672
)
658
673
659
674
# Then load the state dict in with "assign" so that the state dict
660
675
# is loaded properly even though the model is initially on meta device.
661
- new_model_instance .load_state_dict (state_dict , assign = True )
676
+ new_model_instance .load_state_dict ( # type: ignore
677
+ state_dict , assign = True ,
678
+ )
662
679
del state_dict
663
680
664
681
# Transform the model and tokenizer before saving
@@ -671,11 +688,14 @@ def tensor_hook(
671
688
if self .pretrained_model_name is not None :
672
689
new_model_instance .name_or_path = self .pretrained_model_name
673
690
if self .using_peft :
691
+ from peft import PeftModel
692
+ assert isinstance (new_model_instance , PeftModel )
674
693
new_model_instance .base_model .name_or_path = self .pretrained_model_name
675
- for k in new_model_instance .peft_config .keys ():
676
- new_model_instance .peft_config [
694
+ for k in new_model_instance .peft_config .keys ( # type: ignore
695
+ ):
696
+ new_model_instance .peft_config [ # type: ignore
677
697
k
678
- ].base_model_name_or_path = self .pretrained_model_name
698
+ ].base_model_name_or_path = self .pretrained_model_name # type: ignore
679
699
680
700
log .debug ('Saving Hugging Face checkpoint to disk' )
681
701
@@ -686,7 +706,7 @@ def _register_hf_model(
686
706
temp_save_dir : str ,
687
707
original_tokenizer : PreTrainedTokenizerBase ,
688
708
use_temp_dir : bool ,
689
- new_model_instance : PreTrainedModel ,
709
+ new_model_instance : Union [ PreTrainedModel , 'PeftModel' ] ,
690
710
):
691
711
assert new_model_instance is not None
692
712
new_model_instance = self .transform_model_pre_registration (
@@ -802,7 +822,7 @@ def _save_checkpoint(
802
822
)
803
823
804
824
# Only need to edit files for MPT because it has custom code
805
- if new_model_instance .config .model_type == 'mpt' :
825
+ if new_model_instance .config .model_type == 'mpt' : # type: ignore
806
826
log .debug ('Editing MPT files for HuggingFace compatibility' )
807
827
edit_files_for_hf_compatibility (
808
828
temp_save_dir ,
@@ -837,6 +857,12 @@ def _save_checkpoint(
837
857
None ,
838
858
)
839
859
if model_name is not None :
860
+ from peft import PeftModel
861
+ assert isinstance (new_model_instance , PeftModel )
862
+ assert isinstance (
863
+ new_model_instance .model ,
864
+ PreTrainedModel ,
865
+ )
840
866
new_model_instance .name_or_path = model_name
841
867
new_model_instance .model .name_or_path = model_name
842
868
new_model_instance .base_model .name_or_path = model_name
0 commit comments