Skip to content

Commit 8624131

Browse files
committed
temp disable should save peft only
1 parent 524816b commit 8624131

File tree

1 file changed

+19
-7
lines changed

1 file changed

+19
-7
lines changed

composer/callbacks/load_checkpoint.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33

44
"""Load a checkpoint."""
55
import logging
6-
from typing import List, Optional, Union
6+
from typing import Optional, Union
77

8-
from composer.utils.checkpoint import load_checkpoint
98
from composer.core import Callback, State
109
from composer.core.event import Event
1110
from composer.loggers import Logger
11+
from composer.models.huggingface import HuggingFaceModel
12+
from composer.utils.checkpoint import load_checkpoint
1213

1314
log = logging.getLogger(__name__)
1415

@@ -27,8 +28,8 @@ def __init__(
2728
load_path: str,
2829
load_weights_only: bool = False,
2930
strict_model_weights: bool = True,
30-
ignore_keys: Optional[List[str]] = None,
31-
event: Union[str, Event] = Event.AFTER_LOAD,
31+
ignore_keys: Optional[list[str]] = None,
32+
event: Union[str, Event] = Event.BEFORE_LOAD,
3233
):
3334
super().__init__()
3435
self.load_path = load_path
@@ -40,14 +41,22 @@ def __init__(
4041

4142
def run_event(self, event: Event, state: State, logger: Logger) -> None:
4243
if event == self.event:
43-
log.info(f'Loading checkpoint from {self.load_path} at event {self.event}.')
44+
log.info(f'Loading checkpoint from {self.load_path} at {self.event}.')
4445
self._load(state, logger)
45-
log.info(f'Finished loading checkpoint from {self.load_path} at event {self.event}.')
46+
log.info(f'Finished loading checkpoint from {self.load_path} at {self.event}.')
4647

4748
return super().run_event(event, state, logger)
4849

4950
def _load(self, state: State, logger: Logger) -> None:
50-
print('state state dict', state.state_dict()['model'].keys())
51+
52+
# We need to temporarily disable the `should_save_peft_only` flag on the model
53+
# so that we can have access to the full model weights if needed for loading.
54+
model = state.model
55+
original_should_save_peft_only = False
56+
if isinstance(model, HuggingFaceModel):
57+
original_should_save_peft_only = model.should_save_peft_only
58+
model.should_save_peft_only = False
59+
5160
load_checkpoint(
5261
path=self.load_path,
5362
state=state,
@@ -57,3 +66,6 @@ def _load(self, state: State, logger: Logger) -> None:
5766
load_weights_only=self.load_weights_only,
5867
)
5968

69+
# Restore the original `should_save_peft_only` flag on the model
70+
if isinstance(model, HuggingFaceModel):
71+
model.should_save_peft_only = original_should_save_peft_only

0 commit comments

Comments
 (0)