Merged
Changes from 1 commit

Commits (25)
269a5f6  Add onnx Huggingface export test (nik-mosaic, Sep 23, 2022)
64d403a  Add helper method for input conversion (nik-mosaic, Sep 23, 2022)
187a946  update example, other calls (nik-mosaic, Sep 23, 2022)
b8af02a  Merge branch 'dev' into nikhil/onnx-export (nik-mosaic, Sep 23, 2022)
6cedc0f  Add brackets to callback tests (Sep 23, 2022)
22d92f7  Merge branch 'nikhil/onnx-export' of github.com:nik-mosaic/composer i… (Sep 23, 2022)
7a3d053  Add input names to export (Sep 23, 2022)
5debe57  remove hf conversion function (nik-mosaic, Sep 23, 2022)
e554535  Remove unnecessary sample_input in test (Sep 23, 2022)
f2f1f97  Merge branch 'nikhil/onnx-export' of github.com:nik-mosaic/composer i… (Sep 23, 2022)
6f695e7  rerun tests (Sep 23, 2022)
52c94bb  Merge branch 'dev' into nikhil/onnx-export (nik-mosaic, Sep 26, 2022)
f6342ad  Bump onnx and onnxruntime version (Sep 27, 2022)
1b59a71  Merge branch 'nikhil/onnx-export' of github.com:nik-mosaic/composer i… (Sep 27, 2022)
863c744  Merge branch 'dev' into nikhil/onnx-export (nik-mosaic, Sep 27, 2022)
150d630  Merge branch 'dev' into nikhil/onnx-export (nik-mosaic, Sep 28, 2022)
c9ccb3d  Merge branch 'dev' into nikhil/onnx-export (nik-mosaic, Sep 30, 2022)
643d8a7  Add device change (Sep 30, 2022)
b1b83d2  Update test, update inference (Sep 30, 2022)
192b8eb  Add gpu huggingface export onnx test (nik-mosaic, Oct 1, 2022)
a7dde16  Merge branch 'dev' into nikhil/onnx-export (nik-mosaic, Oct 1, 2022)
7360fac  Add doctstring for dynamic_axes input (nik-mosaic, Oct 1, 2022)
4c6d3db  Merge branch 'nikhil/onnx-export' of https://github.com/nik-mosaic/co… (nik-mosaic, Oct 1, 2022)
4309b48  remove comment (nik-mosaic, Oct 1, 2022)
f1eed85  Ensure tuple before moving to cpu (Oct 1, 2022)
composer/utils/inference.py (31 changes: 25 additions & 6 deletions)
@@ -63,6 +63,7 @@ def export_for_inference(
     save_path: str,
     save_object_store: Optional[ObjectStore] = None,
     sample_input: Optional[Any] = None,
+    dynamic_axes: Optional[Any] = None,
     surgery_algs: Optional[Union[Callable[[nn.Module], nn.Module], Sequence[Callable[[nn.Module], nn.Module]]]] = None,
     transforms: Optional[Sequence[Transform]] = None,
     load_path: Optional[str] = None,
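
The new dynamic_axes argument is forwarded unchanged to torch.onnx.export (see the last hunk below), which expects a mapping from input/output names to the axes that should stay symbolic in the exported graph. A minimal illustration, assuming HuggingFace-style input names such as input_ids; the names and axis labels here are not taken from this PR:

# Hypothetical dynamic_axes mapping: keys are input/output names, values map
# axis index -> symbolic axis name in the exported ONNX graph.
dynamic_axes = {
    'input_ids': {0: 'batch_size', 1: 'sequence_length'},
    'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
    'output': {0: 'batch_size'},
}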
@@ -118,12 +119,24 @@ def export_for_inference(
     if dist.get_global_rank() != 0:
         return

-    # make a copy of the model so that we don't modify the original model
+    # Make a copy of the model so that we don't modify the original model
     model = copy.deepcopy(model)

-    # make a copy of the sample input so that we don't modify the original sample input
+    # Make a copy of the sample input so that we don't modify the original sample input
     sample_input = copy.deepcopy(sample_input)

+    # Move model and sample input to CPU for export
+    cpu = torch.device('cpu')
+    model.to(device=cpu)
+    if sample_input is not None:
+        for i in range(len(sample_input)):
+            if isinstance(sample_input[i], torch.Tensor):
+                sample_input[i] = sample_input[i].to(cpu)
+            elif isinstance(sample_input[i], dict):
+                for key, value in sample_input[i].items():
+                    if isinstance(value, torch.Tensor):
+                        sample_input[i][key] = value.to(cpu)
+
     # Apply surgery algorithms in the given order
     for alg in ensure_tuple(surgery_algs):
         model = alg(model)
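
The block added above moves the model and every tensor inside sample_input to CPU before export, covering both bare tensors and dict-style batches. A minimal sketch of the dict-style case it targets (the tensor names and shapes are assumptions, not values from this PR); the name-extraction change in the next hunk would read ['input_ids', 'attention_mask'] from this same structure:

import torch

# Hypothetical HuggingFace-style batch wrapped in a tuple; the loop above moves
# each tensor value to CPU in place and leaves non-tensor entries untouched.
sample_input = ({
    'input_ids': torch.randint(0, 30000, (2, 8)),
    'attention_mask': torch.ones(2, 8, dtype=torch.long),
},)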
@@ -180,18 +193,24 @@ def export_for_inference(
                 raise ValueError(f'sample_input argument is required for onnx export')
             sample_input = ensure_tuple(sample_input)

-            input_names = ['input']
+            input_names = []
+
+            # Extract input names from sample_input if it contains dicts
+            for i in range(len(sample_input)):
+                if isinstance(sample_input[i], dict):
+                    input_names += list(sample_input[i].keys())

-            # Extract input names from sample_input if it is a dict
-            if isinstance(sample_input[0], dict):
-                input_names = list(sample_input[0].keys())
+            # Default input name if no dict present
+            if input_names == []:
+                input_names = ['input']

             torch.onnx.export(
                 model,
                 sample_input,
                 local_save_path,
                 input_names=input_names,
                 output_names=['output'],
+                dynamic_axes=dynamic_axes,
             )

         # upload if required.
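
Taken together, the diff lets a HuggingFace-style model be exported to ONNX with named, dynamically shaped inputs. The sketch below is a hypothetical end-to-end call, not code from this PR: the stand-in module, tensor shapes, and dynamic_axes values are assumptions, and save_format='onnx' follows the documented Union[str, ExportFormat] signature of export_for_inference.

import torch
import torch.nn as nn

from composer.utils import export_for_inference


class TinyTextClassifier(nn.Module):
    """Stand-in for a HuggingFace-style model whose forward takes named tensor inputs."""

    def __init__(self, vocab_size: int = 30000, hidden_size: int = 16, num_classes: int = 2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.head = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids=None, attention_mask=None):
        # Mask the embeddings, pool over the sequence dimension, then classify.
        hidden = self.embed(input_ids) * attention_mask.unsqueeze(-1)
        return self.head(hidden.mean(dim=1))


model = TinyTextClassifier()

# Dict-style sample batch: its keys become the ONNX input names (see the diff above).
sample_batch = {
    'input_ids': torch.randint(0, 30000, (2, 8)),
    'attention_mask': torch.ones(2, 8),
}

export_for_inference(
    model=model,
    save_format='onnx',
    save_path='tiny_model.onnx',
    sample_input=(sample_batch,),
    dynamic_axes={
        'input_ids': {0: 'batch_size', 1: 'sequence_length'},
        'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
        'output': {0: 'batch_size'},
    },
)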