Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,13 +432,9 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
state.device,
) and self.final_register_only:
log.error(
'An error occurred in one or more registration processes. Fallback to saving the HuggingFace checkpoint.',
)
self._save_checkpoint(
state,
logger,
upload_to_save_folder=True,
register_to_mlflow=False,
'An error occurred in one or more registration processes. The model should still be logged to'
+
'the Mlflow artifacts, but will need to be registered manually',
)

# Clean up temporary save directory; all processes are done with it.
Expand Down
29 changes: 6 additions & 23 deletions tests/a_scripts/inference/test_convert_composer_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,30 +521,13 @@ def test_final_register_only(
trainer.fit()

if mlflow_registered_model_name is not None:
# We should always attempt to register the model once
assert mlflow_logger_mock.log_model.call_count == 1
if mlflow_registry_error:
# If the registry fails, we should still save the model
assert mlflow_logger_mock.log_model.call_count == 1
assert checkpointer_callback._save_checkpoint.call_count == 2
assert checkpointer_callback._save_checkpoint.call_args_list[
0].kwargs == {
'register_to_mlflow': True,
'upload_to_save_folder': False,
}
assert checkpointer_callback._save_checkpoint.call_args_list[
1].kwargs == {
'register_to_mlflow': False,
'upload_to_save_folder': True,
}
else:
# No mlflow_registry_error, so we should only register the model
assert checkpointer_callback._save_checkpoint.call_count == 1
assert checkpointer_callback._save_checkpoint.call_args_list[
0].kwargs == {
'register_to_mlflow': True,
'upload_to_save_folder': False,
}
assert checkpointer_callback._save_checkpoint.call_count == 1
assert checkpointer_callback._save_checkpoint.call_args_list[
0].kwargs == {
'register_to_mlflow': True,
'upload_to_save_folder': False,
}
else:
# No mlflow_registered_model_name, so we should only save the checkpoint
assert mlflow_logger_mock.log_model.call_count == 0
Expand Down
Loading