
Conversation

Contributor

@ravi-mosaicml ravi-mosaicml commented Jan 13, 2022

1. Removed the data spec from state! Instead, the trainer sets `batch_num_samples`, `batch_num_tokens`, `microbatches`, and `microbatch_idx`, so these fields are accessible to algorithms that need the pertinent information the data spec previously provided.
2. Removed `last_batch_size` from state. Replaced it with `dist.all_reduce(state.batch_num_samples)` where it was used.
3. Removed `train_batch_size` and `eval_batch_size` from state, as algorithms should not depend on constant batch sizing. Replaced them with `state.train_dataloader.batch_size * dist.get_world_size()` in the few places where they were used.
4. Added a `get_device_of_batch` helper function, which is required for part 2, since `dist.all_reduce` requires tensors to be placed on the device that `torch.distributed` was initialized with; see the sketch after this list.
5. Fixed the type annotations for `ensure_tuple` to support `get_device_of_batch`.
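For reference, here is a minimal sketch of what items 2-4 look like in isolation. It assumes `torch.distributed` is used directly rather than Composer's `dist` helpers, and `global_last_batch_size` is a hypothetical name made up for this example, not something the PR adds:

```python
from typing import Any

import torch
import torch.distributed as dist


def get_device_of_batch(batch: Any) -> torch.device:
    """Return the device of the first tensor found in ``batch``.

    ``batch`` may be a tensor or a (possibly nested) tuple/list/dict of
    tensors. Simplified stand-in for the helper added in item 4.
    """
    if isinstance(batch, torch.Tensor):
        return batch.device
    if isinstance(batch, dict):
        batch = batch.values()
    for item in batch:
        return get_device_of_batch(item)
    raise ValueError("no tensors found in batch")


def global_last_batch_size(state: Any) -> int:
    """Item 2: replace the removed ``state.last_batch_size`` with an
    all-reduce of the per-rank sample count. The tensor must live on the
    device that the distributed backend was initialized with."""
    num_samples = torch.tensor(
        [state.batch_num_samples],
        dtype=torch.int64,
        device=get_device_of_batch(state.batch),
    )
    dist.all_reduce(num_samples, op=dist.ReduceOp.SUM)
    return int(num_samples.item())


# Item 3: where a constant global batch size is still needed, compute it as
# state.train_dataloader.batch_size * dist.get_world_size()
```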

Contributor

@jbloxham jbloxham left a comment


Looks good to me! Glad to see the state getting simplified.

@ravi-mosaicml ravi-mosaicml merged commit 0da611f into dev Jan 18, 2022
@ravi-mosaicml ravi-mosaicml deleted the ravi/cleanup_state branch January 18, 2022 17:29
ravi-mosaicml added a commit that referenced this pull request Jan 20, 2022
1. #223 introduced a bug where algorithms that ran on the `AFTER_DATALOADER` event and replaced `state.batch` (rather than modifying it in place) did not also update `state.microbatches` (which was what training actually used), so these algorithms were effectively ignored. Fixed this bug by computing the microbatches AFTER the `Event.AFTER_DATALOADER` event; see the sketch after this list.

2. Removed `microbatches` and `microbatch_idx` from the state. Instead, algorithms that need to run on smaller batch sizes should use the `Event.BATCH_START` event instead of `Event.AFTER_DATALOADER`, since `Event.BATCH_START` will see the forward-pass-sized batch.
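For reference, a minimal sketch of the corrected ordering in item 1. The `run_batch` helper, `engine`, `split_batch`, and `grad_accum` names below are simplified stand-ins for illustration, not Composer's actual trainer code:

```python
from typing import Any, Callable, Sequence


def run_batch(
    state: Any,
    engine: Any,
    batch: Any,
    split_batch: Callable[[Any, int], Sequence[Any]],
    grad_accum: int,
) -> Sequence[Any]:
    """Order of operations after the fix in item 1.

    The AFTER_DATALOADER algorithms run first, and only afterwards is
    state.batch split into microbatches, so algorithms that replace
    state.batch (rather than mutating it in place) are no longer ignored.
    """
    state.batch = batch

    # Algorithms may return a brand-new state.batch on this event.
    engine.run_event("AFTER_DATALOADER")

    # Split AFTER the event, so any replacement of state.batch is picked up.
    return split_batch(state.batch, grad_accum)
```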
coryMosaicML pushed a commit to coryMosaicML/composer that referenced this pull request Feb 23, 2022
coryMosaicML pushed a commit to coryMosaicML/composer that referenced this pull request Feb 23, 2022