Skip to content

Conversation

moinnadeem
Copy link
Contributor

Motivation

This pull request enables the ability to load a checkpoint without loading the associated state. It does so via the following YAML changes:

  1. Introduces a CheckpointLoaderHparams object, with three fields: checkpoint_filepath, load_weights_only, and strict.
  2. If load_weights_only = False, then nothing changes and the previous codepaths are used. The strict value isn't considered if load_weights_only = False, since restoring a checkpoint with state should ensure that the model exactly matches up. YAML validation ensures that the strict value cannot be set without load_weights_only = True.
  3. If load_weights_only = True, then it loads the checkpoint and avoids recovering the state via a new codepath. If strict = False, it also prints the keys that did not match up for user safety.

It creates the CheckpointLoader object when the Trainer is created via create_from_hparams, and passes the CheckpointLoader in as well.

Discussion Points

  1. Are we happy with the YAML API change?
  2. Should we add any tests for these new codepaths?

return CheckpointLoader(checkpoint_filepath=self.checkpoint_filepath,
load_weights_only=self.load_weights_only,
strict=self.strict)
return None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This return None is probably causing a pyright bug.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Taken care of!

Copy link
Contributor

@ravi-mosaicml ravi-mosaicml left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks pretty good overall, just some minor changes and comments.

@moinnadeem
Copy link
Contributor Author

moinnadeem commented Dec 18, 2021

Cool, addressed all feedback!

re: the PyRight issues, they're not problems in practice, should we add a manual ignore? In more detail:

 /home/runner/work/composer/composer/composer/trainer/checkpoint.py:47:60 - error: Argument of type "bool | None" cannot be assigned to parameter "strict" of type "bool" in function "load_model_state"
    Type "bool | None" cannot be assigned to type "bool"
      Type "None" cannot be assigned to type "bool" (reportGeneralTypeIssues)

The argument can't be None, because the method signature enforces a default. It seems as if PyRight isn't catching onto this?

/home/runner/work/composer/composer/composer/trainer/trainer.py
  /home/runner/work/composer/composer/composer/trainer/trainer.py:304:52 - error: "load_checkpoint" is not a known member of "None" (reportOptionalMemberAccess)

Before this line, we check if the checkpoint_loader is not None, so load_checkpoint can never be run on an object of NoneType. Should we manually ignore this?

@ravi-mosaicml
Copy link
Contributor

ravi-mosaicml commented Dec 21, 2021

For the pyright issues where it's complaining about optional variables, you need to add one of these options before the line it's complaining about:

  1. assert x is not None (if it's an invariant violation)
  2. if x is None: raise ValueError(f"x is None, but it shouldn't be because ...") (if it's a user error)

Copy link
Contributor

@ravi-mosaicml ravi-mosaicml left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code path looks great! Please a some test to verify it. Thinking something like this:

  1. Train, save checkpoint
  2. Load checkpoint with a different optimizer and scheduler with weights_only=True
  3. Assert that the weights are the same as the first trainer, but that the optimizer is the new one

def __init__(self,
checkpoint_filepath: str,
load_weights_only: Optional[bool] = False,
strict: Optional[bool] = False):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
strict: Optional[bool] = False):
strict: bool = False):

@moinnadeem
Copy link
Contributor Author

Cool, I've addressed all feedback and added the tests that Ravi requested. Hanlin also requested that, instead of passing the Checkpoint{Loader, Saver} object to the Trainer, I pass the hparams directly to make it more BYOT friendly. I agree there, so that has also been reflected.

I've addressed all PyRight issues on my changed files, but it seems as if PyRight is complaining about a few extra files that are outside of the scope of this PR. Namely, composer/algorithms/augmix/augmix.py, composer/algorithms/randaugment/randaugment.py, and composer/datasets/brats.py. What should we do about these?

@ravi-mosaicml
Copy link
Contributor

Can you merge in the latest from dev and see if that fixes those files?

@moinnadeem moinnadeem force-pushed the moin/finetune_checkpoints branch from 15c5695 to 6f71a71 Compare January 3, 2022 15:15
@moinnadeem
Copy link
Contributor Author

@ravi-mosaicml Just did -- didn't help for some reason. Any clue why this is happening?

@ravi-mosaicml
Copy link
Contributor

ravi-mosaicml commented Jan 3, 2022 via email

Copy link
Contributor

@ravi-mosaicml ravi-mosaicml left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@moinnadeem moinnadeem merged commit 2b25192 into dev Jan 3, 2022
@moinnadeem moinnadeem deleted the moin/finetune_checkpoints branch January 3, 2022 23:31
coryMosaicML pushed a commit to coryMosaicML/composer that referenced this pull request Feb 23, 2022
…l#169)

* fixing checkpoint bug

* finalizing fine-tuning a checkpointed model

* address PR feedback

* adding save_checkpoint and load_checkpoint hparams interface

* yapf & pyright

* changing interface

* everyone always asks 'what is yapf', but never 'how is yapf'?

* renaming Checkpointer -> CheckpointSaver

* renaming Checkpointer -> CheckpointSaver

* addressing feedback & friendly renaming

* addressing pyright

* yapf

* adding tests

* moving commits to BERT branch

* changing folder to be relative to run dir

* adding tests

* pyright part 1

* pyright on trainer file

* moving restoring RNG & random seed to else clause

* Fix tests

* Addressed comments

Co-authored-by: Moin Nadeem <[email protected]>
Co-authored-by: Ravi Rahman <[email protected]>
Comment on lines +217 to +219
# setup a new LR scheduler
scheduler_options = [ConstantLRHparams(), CosineAnnealingLRHparams(T_max=f"{second_trainer_hparams.max_epochs}ep")]
second_trainer_hparams.schedulers = [random.choice(scheduler_options)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@moinnadeem just ran into this now - it's very strange to use randomness in a test in this way since it can potentially cause flaky tests. Was there a reason you wanted this randomness here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see the rationale. I just wanted to make sure that we test several schedulers, but didn't think it was worth the time to test all of them. In hindsight, we should either pick the more difficult one, or do both. I agree with you.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm modifying this line in a PR for a different purpose anyways, so I'll fix it here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Jamie!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants