3
3
4
4
"""Load a checkpoint."""
5
5
import logging
6
- from typing import List , Optional , Union
6
+ from typing import Optional , Union
7
7
8
- from composer .utils .checkpoint import load_checkpoint
9
8
from composer .core import Callback , State
10
9
from composer .core .event import Event
11
10
from composer .loggers import Logger
11
+ from composer .models .huggingface import HuggingFaceModel
12
+ from composer .utils .checkpoint import load_checkpoint
12
13
13
14
log = logging .getLogger (__name__ )
14
15
@@ -27,8 +28,8 @@ def __init__(
27
28
load_path : str ,
28
29
load_weights_only : bool = False ,
29
30
strict_model_weights : bool = True ,
30
- ignore_keys : Optional [List [str ]] = None ,
31
- event : Union [str , Event ] = Event .AFTER_LOAD ,
31
+ ignore_keys : Optional [list [str ]] = None ,
32
+ event : Union [str , Event ] = Event .BEFORE_LOAD ,
32
33
):
33
34
super ().__init__ ()
34
35
self .load_path = load_path
@@ -40,14 +41,22 @@ def __init__(
40
41
41
42
def run_event (self , event : Event , state : State , logger : Logger ) -> None :
42
43
if event == self .event :
43
- log .info (f'Loading checkpoint from { self .load_path } at event { self .event } .' )
44
+ log .info (f'Loading checkpoint from { self .load_path } at { self .event } .' )
44
45
self ._load (state , logger )
45
- log .info (f'Finished loading checkpoint from { self .load_path } at event { self .event } .' )
46
+ log .info (f'Finished loading checkpoint from { self .load_path } at { self .event } .' )
46
47
47
48
return super ().run_event (event , state , logger )
48
49
49
50
def _load (self , state : State , logger : Logger ) -> None :
50
- print ('state state dict' , state .state_dict ()['model' ].keys ())
51
+
52
+ # We need to temporarily disable the `should_save_peft_only` flag on the model
53
+ # so that we can have access to the full model weights if needed for loading.
54
+ model = state .model
55
+ original_should_save_peft_only = False
56
+ if isinstance (model , HuggingFaceModel ):
57
+ original_should_save_peft_only = model .should_save_peft_only
58
+ model .should_save_peft_only = False
59
+
51
60
load_checkpoint (
52
61
path = self .load_path ,
53
62
state = state ,
@@ -57,3 +66,6 @@ def _load(self, state: State, logger: Logger) -> None:
57
66
load_weights_only = self .load_weights_only ,
58
67
)
59
68
69
+ # Restore the original `should_save_peft_only` flag on the model
70
+ if isinstance (model , HuggingFaceModel ):
71
+ model .should_save_peft_only = original_should_save_peft_only
0 commit comments