Add the ability to load a checkpoint without restoring state #169
return CheckpointLoader(checkpoint_filepath=self.checkpoint_filepath,
                        load_weights_only=self.load_weights_only,
                        strict=self.strict)
return None
This `return None` is probably causing a pyright bug.
Taken care of!
Looks pretty good overall, just some minor changes and comments.
Cool, addressed all feedback! re: the PyRight issues, they're not problems in practice, should we add a manual ignore? In more detail:
The argument can't be None, because the method signature enforces a default. It seems as if PyRight isn't catching onto this?
Before this line, we check if the |
For the pyright issues where it's complaining about optional variables, you need to add one of these options before the line it's complaining about:
|
Code path looks great! Please add a test to verify it. Thinking something like this:
- Train, save checkpoint
- Load checkpoint with a different optimizer and scheduler with weights_only=True
- Assert that the weights are the same as the first trainer, but that the optimizer is the new one
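The outline above could be sketched roughly as follows. This is a self-contained toy, not Composer's API: `ToyTrainer`, `save_checkpoint`, and `load_checkpoint` are hypothetical stand-ins, and the checkpoint is a plain JSON file. It only illustrates the shape of the assertions: weights are restored, while optimizer, scheduler, and epoch state are not.

```python
import json
import os
import tempfile


class ToyTrainer:
    """Hypothetical stand-in for a trainer; not Composer's Trainer."""

    def __init__(self, optimizer: str, scheduler: str):
        self.weights = {"layer.weight": 0.0, "layer.bias": 0.0}
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.epoch = 0

    def fit(self, epochs: int) -> None:
        # Fake "training": nudge every weight once per epoch.
        for _ in range(epochs):
            for name in self.weights:
                self.weights[name] += 0.5
            self.epoch += 1

    def save_checkpoint(self, path: str) -> None:
        state = {
            "weights": self.weights,
            "optimizer": self.optimizer,
            "scheduler": self.scheduler,
            "epoch": self.epoch,
        }
        with open(path, "w") as f:
            json.dump(state, f)

    def load_checkpoint(self, path: str, load_weights_only: bool = False) -> None:
        with open(path) as f:
            state = json.load(f)
        self.weights = state["weights"]
        if not load_weights_only:
            # Full restore: also recover the training state.
            self.optimizer = state["optimizer"]
            self.scheduler = state["scheduler"]
            self.epoch = state["epoch"]


def test_load_weights_only() -> None:
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "checkpoint.json")

        # 1. Train, save checkpoint.
        first = ToyTrainer(optimizer="sgd", scheduler="constant")
        first.fit(epochs=2)
        first.save_checkpoint(path)

        # 2. Load with a different optimizer and scheduler, weights only.
        second = ToyTrainer(optimizer="adam", scheduler="cosine")
        second.load_checkpoint(path, load_weights_only=True)

    # 3. Weights match the first trainer; everything else is the new trainer's.
    assert second.weights == first.weights
    assert second.optimizer == "adam"
    assert second.scheduler == "cosine"
    assert second.epoch == 0


test_load_weights_only()
```

In the real test the comparison would presumably run over the model's parameters rather than a dict of floats, but the three steps stay the same.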
composer/trainer/checkpoint.py
Outdated
def __init__(self,
             checkpoint_filepath: str,
             load_weights_only: Optional[bool] = False,
             strict: Optional[bool] = False):
- strict: Optional[bool] = False):
+ strict: bool = False):
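For context on the suggestion above: a default value does not stop a caller from passing `None` explicitly, so `Optional[bool] = False` still tells a type checker that `None` is a possible value. A minimal sketch using only the standard `typing` module (the function names are made up for illustration):

```python
from typing import Optional, get_args, get_type_hints


def load_optional(strict: Optional[bool] = False) -> bool:
    # Optional[bool] means "bool or None", so None must be handled
    # before the value can be used as a plain bool.
    if strict is None:
        strict = False
    return strict


def load_plain(strict: bool = False) -> bool:
    # With a plain bool annotation, no None handling is needed.
    return strict


# Optional[bool] is Union[bool, None]: NoneType is part of the annotation.
optional_hint = get_type_hints(load_optional)["strict"]
assert type(None) in get_args(optional_hint)

# The plain annotation is just bool.
assert get_type_hints(load_plain)["strict"] is bool

# A default does not prevent an explicit None at the call site.
assert load_optional(None) is False
```

Dropping `Optional` from the annotation, as suggested, removes the `None` case from the type rather than silencing the checker.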
Cool, I've addressed all feedback and added the tests that Ravi requested.

Hanlin also requested that, instead of passing the Checkpoint{Loader, Saver} object to the Trainer, I pass the hparams directly to make it more BYOT friendly. I agree there, so that has also been reflected.

I've addressed all PyRight issues on my changed files, but it seems as if PyRight is complaining about a few extra files that are outside of the scope of this PR. Namely,
Can you merge in the latest from dev and see if that fixes those files?
15c5695 to 6f71a71
@ravi-mosaicml Just did -- didn't help for some reason. Any clue why this is happening?
Probably a pyright update... can you fix what it's complaining about in those files if they're small changes?
LGTM
…l#169)
* fixing checkpoint bug
* finalizing fine-tuning a checkpointed model
* address PR feedback
* adding save_checkpoint and load_checkpoint hparams interface
* yapf & pyright
* changing interface
* everyone always asks 'what is yapf', but never 'how is yapf'?
* renaming Checkpointer -> CheckpointSaver
* renaming Checkpointer -> CheckpointSaver
* addressing feedback & friendly renaming
* addressing pyright
* yapf
* adding tests
* moving commits to BERT branch
* changing folder to be relative to run dir
* adding tests
* pyright part 1
* pyright on trainer file
* moving restoring RNG & random seed to else clause
* Fix tests
* Addressed comments

Co-authored-by: Moin Nadeem <[email protected]>
Co-authored-by: Ravi Rahman <[email protected]>
# setup a new LR scheduler
scheduler_options = [ConstantLRHparams(), CosineAnnealingLRHparams(T_max=f"{second_trainer_hparams.max_epochs}ep")]
second_trainer_hparams.schedulers = [random.choice(scheduler_options)]
@moinnadeem just ran into this now - it's very strange to use randomness in a test in this way since it can potentially cause flaky tests. Was there a reason you wanted this randomness here?
Ah, I see the rationale. I just wanted to make sure that we test several schedulers, but didn't think it was worth the time to test all of them. In hindsight, we should either pick the more difficult one, or do both. I agree with you.
I'm modifying this line in a PR for a different purpose anyways, so I'll fix it here.
Thanks Jamie!
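One way to realize the "do both" option from this thread is to iterate over every scheduler option instead of sampling one. A rough sketch with the standard `unittest` module; `SCHEDULER_OPTIONS` and `run_weights_only_check` are hypothetical placeholders for the scheduler hparams and the actual checkpoint test body:

```python
import unittest

# Hypothetical stand-ins for the scheduler hparams objects in the real test.
SCHEDULER_OPTIONS = ["constant", "cosine_annealing"]


def run_weights_only_check(scheduler_name: str) -> bool:
    # Placeholder for the real test body, which would build the second
    # trainer with this scheduler and compare weights.
    return scheduler_name in SCHEDULER_OPTIONS


class TestLoadWeightsOnly(unittest.TestCase):

    def test_every_scheduler(self):
        # Every option runs on every invocation, so a failure cannot
        # depend on which option random.choice happened to pick.
        for scheduler_name in SCHEDULER_OPTIONS:
            with self.subTest(scheduler=scheduler_name):
                self.assertTrue(run_weights_only_check(scheduler_name))


suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestLoadWeightsOnly)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```

With pytest, `pytest.mark.parametrize` over the same list would give one reported test per scheduler.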
Motivation
This pull request enables the ability to load a checkpoint without loading the associated state. It does so via the following YAML changes:
- A `CheckpointLoaderHparams` object, with three fields: `checkpoint_filepath`, `load_weights_only`, and `strict`.
- If `load_weights_only = False`, then nothing changes and the previous codepaths are used. The `strict` value isn't considered if `load_weights_only = False`, since restoring a checkpoint with state should ensure that the model exactly matches up. YAML validation ensures that the `strict` value cannot be set without `load_weights_only = True`.
- If `load_weights_only = True`, then it loads the checkpoint and avoids recovering the state via a new codepath. If `strict = False`, it also prints the keys that did not match up for user safety.

It creates the `CheckpointLoader` object when the `Trainer` is created via `create_from_hparams`, and passes the `CheckpointLoader` in as well.

Discussion Points