Cleaned Up State #223
Merged
Conversation
1. Remove the dataspec from state! Instead, the trainer sets `batch_num_samples`, `batch_num_tokens`, `microbatches`, and `microbatch_idx`, so these fields are accessible to algorithms that need the information the data spec previously provided.
2. Removed `last_batch_size` from state. Replaced it with `dist.all_reduce(state.batch_num_samples)` where it was used (see the sketch after this list).
3. Removed `train_batch_size` and `eval_batch_size` from state, as algorithms should not depend on constant batch sizing. Replaced them with `state.train_dataloader.batch_size * dist.get_world_size()` in the few places where they were used.
4. Added a helper function to get the device of the batch, which is required for part 2 (since `dist.all_reduce` requires tensors to be placed on the device torch.distributed was initialized with).
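A minimal sketch of points 2–4, assuming the `state` fields named above (`batch_num_samples`, `batch`, `train_dataloader`) and using plain `torch.distributed` in place of Composer's `dist` wrapper; the function bodies are illustrative, not the PR's actual code:

```python
import torch
import torch.distributed as dist


def get_device_of_batch(batch) -> torch.device:
    """Return the device of the first tensor in a (assumed flat) batch."""
    items = batch if isinstance(batch, (list, tuple)) else [batch]
    for item in items:
        if isinstance(item, torch.Tensor):
            return item.device
    raise ValueError("batch does not contain a tensor")


def global_num_samples(state) -> int:
    """Sum per-rank sample counts across ranks, replacing the old last_batch_size."""
    # all_reduce requires a tensor on the device torch.distributed was initialized with.
    count = torch.tensor(state.batch_num_samples, device=get_device_of_batch(state.batch))
    dist.all_reduce(count)  # SUM by default
    return int(count.item())


def global_train_batch_size(state) -> int:
    """Per-rank dataloader batch size times world size, replacing state.train_batch_size."""
    return state.train_dataloader.batch_size * dist.get_world_size()
```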
jbloxham approved these changes on Jan 14, 2022:
Looks good to me! Glad to see the state getting simplified.
ravi-mosaicml added a commit that referenced this pull request on Jan 20, 2022:
1. #223 introduced a bug where algorithms that ran on the AFTER_DATALOADER event and made a not-in-place modification of state.batch did not also update state.microbatches (which was used for training), so these algorithms were effectively ignored. Fixed this bug by computing the microbatches AFTER the Event.AFTER_DATALOADER event.
2. Removed `microbatches` and `microbatch_idx` from the state. Instead, algorithms that need to run on smaller batch sizes should use the Event.BATCH_START event instead of Event.AFTER_DATALOADER, since Event.BATCH_START will get the forward-pass-sized batch (see the sketch after this list).
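As a rough illustration of point 2, a hypothetical algorithm that modifies `state.batch` on `Event.BATCH_START` rather than `Event.AFTER_DATALOADER`. The class name and the `(inputs, targets)` batch structure are assumptions, and the import path reflects the Composer API of that era and may differ in later versions:

```python
from composer.core import Algorithm, Event


class TruncateBatch(Algorithm):
    """Hypothetical example: shrink the batch right before the forward pass."""

    def match(self, event, state):
        # Run on BATCH_START so the (possibly not-in-place) modification is
        # applied to the batch the trainer actually trains on.
        return event == Event.BATCH_START

    def apply(self, event, state, logger):
        inputs, targets = state.batch  # assumes an (inputs, targets) pair
        state.batch = (inputs[: len(inputs) // 2], targets[: len(targets) // 2])
```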
coryMosaicML pushed a commit to coryMosaicML/composer that referenced this pull request on Feb 23, 2022:
1. Remove the dataspec from state! Instead, the trainer sets `batch_num_samples`, `batch_num_tokens`, `microbatches`, and `microbatch_idx`, so these fields are accessible to algorithms that need the information the data spec previously provided.
2. Removed `last_batch_size` from state. Replaced it with `dist.all_reduce(state.batch_num_samples)` where it was used.
3. Removed `train_batch_size` and `eval_batch_size` from state, as algorithms should not depend on constant batch sizing. Replaced them with `state.train_dataloader.batch_size * dist.get_world_size()` in the few places where they were used.
4. Added a helper function, `get_device_of_batch`, which is required for part 2 (since `dist.all_reduce` requires tensors to be placed on the device torch.distributed was initialized with).
5. Fixed the type annotations for `ensure_tuple` to support `get_device_of_batch` (see the sketch after this list).
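A sketch of what point 5 refers to, under assumed signatures rather than the library's exact definition: a generically typed `ensure_tuple` that normalizes a scalar, list, tuple, or dict batch into a tuple so a helper like `get_device_of_batch` can iterate over it uniformly:

```python
from typing import Any, Dict, List, Tuple, TypeVar, Union

T = TypeVar("T")


def ensure_tuple(x: Union[T, Tuple[T, ...], List[T], Dict[Any, T]]) -> Tuple[T, ...]:
    """Normalize a scalar, tuple, list, or dict (its values) into a tuple."""
    if isinstance(x, tuple):
        return x
    if isinstance(x, list):
        return tuple(x)
    if isinstance(x, dict):
        return tuple(x.values())
    return (x,)
```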