Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
269a5f6
Add onnx Huggingface export test
nik-mosaic Sep 23, 2022
64d403a
Add helper method for input conversion
nik-mosaic Sep 23, 2022
187a946
update example, other calls
nik-mosaic Sep 23, 2022
b8af02a
Merge branch 'dev' into nikhil/onnx-export
nik-mosaic Sep 23, 2022
6cedc0f
Add brackets to callback tests
Sep 23, 2022
22d92f7
Merge branch 'nikhil/onnx-export' of github.com:nik-mosaic/composer i…
Sep 23, 2022
7a3d053
Add input names to export
Sep 23, 2022
5debe57
remove hf conversion function
nik-mosaic Sep 23, 2022
e554535
Remove unnecessary sample_input in test
Sep 23, 2022
f2f1f97
Merge branch 'nikhil/onnx-export' of github.com:nik-mosaic/composer i…
Sep 23, 2022
6f695e7
rerun tests
Sep 23, 2022
52c94bb
Merge branch 'dev' into nikhil/onnx-export
nik-mosaic Sep 26, 2022
f6342ad
Bump onnx and onnxruntime version
Sep 27, 2022
1b59a71
Merge branch 'nikhil/onnx-export' of github.com:nik-mosaic/composer i…
Sep 27, 2022
863c744
Merge branch 'dev' into nikhil/onnx-export
nik-mosaic Sep 27, 2022
150d630
Merge branch 'dev' into nikhil/onnx-export
nik-mosaic Sep 28, 2022
c9ccb3d
Merge branch 'dev' into nikhil/onnx-export
nik-mosaic Sep 30, 2022
643d8a7
Add device change
Sep 30, 2022
b1b83d2
Update test, update inference
Sep 30, 2022
192b8eb
Add gpu huggingface export onnx test
nik-mosaic Oct 1, 2022
a7dde16
Merge branch 'dev' into nikhil/onnx-export
nik-mosaic Oct 1, 2022
7360fac
Add doctstring for dynamic_axes input
nik-mosaic Oct 1, 2022
4c6d3db
Merge branch 'nikhil/onnx-export' of https://github.com/nik-mosaic/co…
nik-mosaic Oct 1, 2022
4309b48
remove comment
nik-mosaic Oct 1, 2022
f1eed85
Ensure tuple before moving to cpu
Oct 1, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion composer/callbacks/export_for_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,5 +86,5 @@ def export_model(self, state: State, logger: Logger):
save_path=self.save_path,
logger=logger,
save_object_store=self.save_object_store,
sample_input=(self.sample_input,),
sample_input=(self.sample_input, {}),
transforms=self.transforms)
2 changes: 1 addition & 1 deletion composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2581,5 +2581,5 @@ def export_for_inference(
save_path=save_path,
logger=self.logger,
save_object_store=save_object_store,
sample_input=(sample_input,),
sample_input=(sample_input, {}),
transforms=transforms)
8 changes: 7 additions & 1 deletion composer/utils/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,11 +180,17 @@ def export_for_inference(
raise ValueError(f'sample_input argument is required for onnx export')
sample_input = ensure_tuple(sample_input)

input_names = ['input']

# Extract input names from sample_input if it is a dict
if isinstance(sample_input[0], dict):
input_names = list(sample_input[0].keys())

torch.onnx.export(
model,
sample_input,
local_save_path,
input_names=['input'],
input_names=input_names,
output_names=['output'],
)

Expand Down
10 changes: 9 additions & 1 deletion examples/exporting_for_inference.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,14 @@
"print(f\"The predicted classes are {np.argmax(outputs[0], axis=1)}\")"
]
},
{
"cell_type": "markdown",
"id": "0bc52f62",
"metadata": {},
"source": [
"If our input is a dictionary, as if often the case when using a Composer [HuggingFaceModel](https://docs.mosaicml.com/en/stable/examples/huggingface_models.html), we'll need to make sure all the elements of our input dictionary are numpy arrays before calling `ort_session.run()`."
]
},
{
"cell_type": "markdown",
"id": "ca091f8e",
Expand Down Expand Up @@ -454,7 +462,7 @@
"export_for_inference(model=model, \n",
" save_format=save_format, \n",
" save_path=model_save_path, \n",
" sample_input=(input,),\n",
" sample_input=(input, {}),\n",
" surgery_algs=[cf.apply_squeeze_excite],\n",
" load_path=checkpoint_path)"
]
Expand Down
4 changes: 2 additions & 2 deletions tests/callbacks/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_inference_callback_torchscript(model_cls):
save_path=save_path,
logger=trainer.logger,
save_object_store=None,
sample_input=(exp_for_inf_callback.sample_input,),
sample_input=(exp_for_inf_callback.sample_input, {}),
transforms=None)


Expand Down Expand Up @@ -78,5 +78,5 @@ def test_inference_callback_onnx(model_cls):
save_path=save_path,
logger=trainer.logger,
save_object_store=None,
sample_input=(exp_for_inf_callback.sample_input,),
sample_input=(exp_for_inf_callback.sample_input, {}),
transforms=None)
69 changes: 65 additions & 4 deletions tests/utils/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,68 @@ def test_export_for_inference_torchscript(model_cls, sample_input):
)


def test_huggingface_export_for_inference_onnx():
pytest.importorskip('onnx')
pytest.importorskip('onnxruntime')
pytest.importorskip('transformers')

import onnx
import onnx.checker
import onnxruntime as ort
import transformers

from composer.models import HuggingFaceModel

# HuggingFace Bert Model
# dummy sequence batch with 2 labels, 32 sequence length, and 30522 (bert) vocab size).
input_ids = torch.randint(low=0, high=30522, size=(2, 32))
labels = torch.randint(low=0, high=1, size=(2,))
token_type_ids = torch.zeros(size=(2, 32), dtype=torch.int64)
attention_mask = torch.randint(low=0, high=1, size=(2, 32))
sample_input = {
'input_ids': input_ids,
'labels': labels,
'token_type_ids': token_type_ids,
'attention_mask': attention_mask,
}

# non pretrained model to avoid a slow test that downloads the weights.
config = transformers.AutoConfig.from_pretrained('bert-base-uncased', num_labels=2, hidden_act='gelu_new')
hf_model = transformers.AutoModelForSequenceClassification.from_config(config) # type: ignore (thirdparty)

model = HuggingFaceModel(hf_model)
model.eval()
orig_out = model(sample_input)

save_format = 'onnx'
with tempfile.TemporaryDirectory() as tempdir:
save_path = os.path.join(tempdir, f'model.{save_format}')
inference.export_for_inference(
model=model,
save_format=save_format,
save_path=save_path,
sample_input=(sample_input, {}),
)
loaded_model = onnx.load(save_path)

onnx.checker.check_model(loaded_model)

ort_session = ort.InferenceSession(save_path)

for key, value in sample_input.items():
sample_input[key] = value.numpy()

loaded_model_out = ort_session.run(None, sample_input)

torch.testing.assert_close(
orig_out['logits'].detach().numpy(),
loaded_model_out[1],
rtol=1e-4, # lower tolerance for ONNX
atol=1e-3, # lower tolerance for ONNX
msg=f'output mismatch with {save_format}',
)


@pytest.mark.parametrize(
'model_cls, sample_input',
[
Expand All @@ -87,7 +149,7 @@ def test_export_for_inference_onnx(model_cls, sample_input):
model=model,
save_format=save_format,
save_path=save_path,
sample_input=(sample_input,),
sample_input=(sample_input, {}),
)
loaded_model = onnx.load(save_path)
onnx.checker.check_model(loaded_model)
Expand Down Expand Up @@ -152,7 +214,7 @@ def test_export_for_inference_onnx_ddp(model_cls, sample_input):
model=state.model.module,
save_format=save_format,
save_path=save_path,
sample_input=(sample_input,),
sample_input=(sample_input, {}),
)

loaded_model = onnx.load(save_path)
Expand Down Expand Up @@ -247,7 +309,7 @@ def test_export_with_file_artifact_logger(model_cls, sample_input):
model=model,
save_format=save_format,
save_path=save_path,
sample_input=(sample_input,),
sample_input=(sample_input, {}),
logger=mock_logger,
)

Expand Down Expand Up @@ -292,7 +354,6 @@ def test_export_with_other_logger(model_cls, sample_input):
model=model,
save_format=save_format,
save_path=save_path,
sample_input=(sample_input,),
logger=mock_logger,
)

Expand Down