📽 Multi image support for GRPO/RLOO #4113
Conversation
…_thw` in GRPO and RLOO trainers; update `split_pixel_values_by_grid` to use `image_grid_thw`
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
)
trainer = GRPOTrainer(
    model=model_id,
    reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
We don't support visual reward models, so it doesn't really make sense to test this case, where the image is dropped and a warning is raised.
trl/trainer/grpo_trainer.py
Outdated
# VLM reward models aren't supported yet, so we drop the image and raise a warning if needed
for prompt in prompts:
    for turn in prompt:
        if isinstance(turn["content"], list):
            logger.warning_once("Visual reward models aren't supported yet; dropping image.")
            turn["content"] = " ".join(e["text"] for e in turn["content"] if e["type"] == "text")
The change goes from

[{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]

to

[{"role": "user", "content": "What color is the sky?"}]

and raises a warning.
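
As a standalone illustration, here is a minimal sketch of that flattening (the input is a made-up example, and `warnings.warn` stands in for the trainer's `logger.warning_once`):

import warnings

prompts = [
    [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
]

for prompt in prompts:
    for turn in prompt:
        if isinstance(turn["content"], list):
            # Keep only the text entries and join them into a plain string
            warnings.warn("Visual reward models aren't supported yet; dropping image.")
            turn["content"] = " ".join(e["text"] for e in turn["content"] if e["type"] == "text")

print(prompts)  # [[{'role': 'user', 'content': 'What color is the sky?'}]]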
# We don't yet support visual reward models/functions, so we keep a copy of the original text-only prompts for
# later use in the reward computation. If images are present, we insert {"type": "image"} as required by the
# VLM chat template.
original_prompts = copy.deepcopy(prompts)
Instead of keeping the original prompt, we just drop the image later and raise a warning; see https://github.com/huggingface/trl/pull/4113/files#r2364899902
# important because rewards will be normalized per group, and completions are distributed. We will later slice
# rewards_per_func to extract each process's subset.
- rewards_per_func = self._calculate_rewards(inputs, original_prompts, completions, completion_ids_list)
+ rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
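
To illustrate the slicing mentioned in the comment, a toy sketch (the process count and batch size are made up; in the trainer the tensor is gathered across processes first):

import torch

num_processes, local_batch = 2, 4
# Pretend this was gathered from all processes: one reward column, 8 rows total
rewards_per_func = torch.arange(num_processes * local_batch, dtype=torch.float32).reshape(-1, 1)

process_index = 1
local_slice = rewards_per_func[process_index * local_batch : (process_index + 1) * local_batch]
print(local_slice.squeeze(-1))  # tensor([4., 5., 6., 7.])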
if self._logs["images"]:
    table["images"] = []
    for image_list in self._logs["images"]:
        # Convert images to wandb Image objects for proper visualization
        table["images"].append([wandb.Image(image) for image in image_list])
boundaries = [0, *accumulate(batch["num_images"])]  # [3, 4, 5] -> [0, 3, 7, 12]
sections = [sum(lengths[boundaries[i] : boundaries[i + 1]]) for i in range(len(batch["num_images"]))]
split_values = list(torch.split(batch["pixel_values"], sections, dim=0))
image_grid_thw = list(torch.split(batch["image_grid_thw"], batch["num_images"], dim=0))
return {**batch, "pixel_values": split_values, "image_grid_thw": image_grid_thw}
Instead of keeping image_grid_thw as is, we need to split it depending on the number of images per example. It gets concatenated later in _get_per_token_logps_and_entropies (see line 807).
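
A toy illustration of that splitting, with assumed shapes (three examples containing 3, 4, and 5 images; each row of image_grid_thw is one image's (t, h, w) grid):

from itertools import accumulate

import torch

num_images = [3, 4, 5]
image_grid_thw = torch.ones(sum(num_images), 3, dtype=torch.long)  # one (t, h, w) row per image

boundaries = [0, *accumulate(num_images)]  # [0, 3, 7, 12]
per_example = list(torch.split(image_grid_thw, num_images, dim=0))
print([t.shape for t in per_example])  # [torch.Size([3, 3]), torch.Size([4, 3]), torch.Size([5, 3])]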
trl/trainer/grpo_trainer.py
Outdated
model_inputs["image_grid_thw"] = torch.cat(image_grid_thw[start : start + batch_size])
start_pixel_idx = 0 if start == 0 else torch.cat(image_grid_thw[:start]).prod(-1).sum().item()
end_pixel_idx = torch.cat(image_grid_thw[: start + batch_size]).prod(-1).sum().item()
See https://github.com/huggingface/trl/pull/4113/files#r2364904060: image_grid_thw is no longer a single tensor but a list of tensors, so it must be concatenated before computing the slice indices.
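
The prod(-1).sum() indexing works because each image contributes t*h*w rows to pixel_values; a small sketch with made-up grids:

import torch

# One tensor per example; each row is an image's (t, h, w) grid
image_grid_thw = [
    torch.tensor([[1, 2, 2]]),             # example 0: one image -> 4 rows of pixel_values
    torch.tensor([[1, 2, 2], [1, 4, 2]]),  # example 1: two images -> 4 + 8 rows
]
start, batch_size = 1, 1

start_pixel_idx = 0 if start == 0 else torch.cat(image_grid_thw[:start]).prod(-1).sum().item()
end_pixel_idx = torch.cat(image_grid_thw[: start + batch_size]).prod(-1).sum().item()
print(start_pixel_idx, end_pixel_idx)  # 4 16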
LGTM, with a question about whether raising an error or a warning is best when images + text are being passed to the reward function.
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

for n, param in previous_trainable_params.items():
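
The loop body is cut off in this view; the usual pattern in trl's trainer tests continues roughly as follows (a reconstruction, not verbatim from this PR):

for n, param in previous_trainable_params.items():
    new_param = trainer.model.get_parameter(n)
    # Training should have updated every trainable parameter
    self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed")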
Does the same comment for GRPO apply here? https://github.com/huggingface/trl/pull/4113/files#diff-96dca172e696190fc3e1469166e88aface95ebae959284c6806f2e25d2217c16R1587
Answered here: #4113 (comment)
trl/trainer/grpo_trainer.py
Outdated
for prompt in prompts:
    for turn in prompt:
        if isinstance(turn["content"], list):
            logger.warning_once("Visual reward models aren't supported yet; dropping image.")
Would raising an error be better than a warning? Otherwise I could imagine the warning being missed and the training "failing silently" because the reward is only computed on the text part.
Yes, I see. I wonder if anyone would want to train a VLM with a standard LM reward model (i.e., not a visual reward model); so far, I've never seen that. We could always support it in the future if there is demand. I'll remove this warning: if the user tries it, the rendering of the chat template will fail, which prevents the silent-failure scenario you describe.
table["images"] = [] | ||
for image_list in self._logs["images"]: | ||
# Convert images to wandb Image objects for proper visualization | ||
table["images"].append([wandb.Image(image) for image in image_list]) |
At some point it would be nice to also add the trackio variant for table images.
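
Since trackio advertises a wandb-compatible API, that variant might look like the sketch below; whether the installed trackio version exposes an Image wrapper with this signature is an assumption to verify:

import trackio

if self._logs["images"]:
    table["images"] = []
    for image_list in self._logs["images"]:
        # Assumption: trackio mirrors wandb's Image wrapper for table cells
        table["images"].append([trackio.Image(image) for image in image_list])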
This PR belongs to a sequence of PRs that aim to refactor the generation part of GRPO/RLOO to allow for easier customization and, ultimately, tool calling.

Previous:
- `image_split_sizes` in favour of `image_grid_thw` #4111

Next:
- `_generate` #4114
- `_generate` in GRPO/RLOO: list of ints instead of tensors #4146
- `_generate` in GRPO/RLOO: Use `prompt_ids` from generation #4152
- `_generate` in GRPO/RLOO: Rely on generator for prompt truncation #4153
- `_generate` in GRPO/RLOO: Move `forward_kwargs` outside generation method #4154
- `_generate` in GRPO/RLOO: Insert images in the prompt #4155

While refactoring, I realized that clean multi-image support helps maintain a cleaner separation between functions.