
Commit 0e1cd88
[FFI][FEAT] AutoDLPack to enable external tensor args.
This PR introduces the AutoDLPack feature to the tvm ffi. When an ffi Function takes a Tensor argument that conforms to DLPack, the argument is automatically imported into an NDArray and passed through. This allows a compiled function to take, for example, a torch.Tensor as an input argument directly, without an extra set of conversions on the caller side. When a function returns an NDArray, the return value still needs to be converted back via torch.from_dlpack. However, a common use case is destination passing, where all inputs and outputs are pre-allocated and passed into the function; there AutoDLPack effectively enables zero-overhead support for a wide range of python arrays. We also added a benchmark script to measure the overall ffi overhead. One thing to note is that the underlying DSL compiler still imposes contiguity and alignment requirements; as of now we use a global value for the required alignment. So x.contiguous() is still needed before passing the argument if a transpose or other view-producing op was performed.
1 parent bcb68b1 commit 0e1cd88
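To make the destination-passing pattern concrete, here is a minimal sketch. It assumes a built tvm ffi module and reuses the testing.nop C++ stub from this PR's benchmark as a stand-in for a real compiled kernel:

    import torch
    from tvm import ffi as tvm_ffi

    # stand-in for a compiled kernel; testing.nop is a C++ no-op
    # registered in the ffi (see the benchmark script below)
    func = tvm_ffi.get_global_func("testing.nop")

    x = torch.arange(256, device="cuda")
    y = torch.arange(256, device="cuda")
    out = torch.empty(256, device="cuda")  # pre-allocated destination

    # AutoDLPack imports each torch.Tensor into an NDArray automatically;
    # no explicit tvm_ffi.from_dlpack call is needed on the caller side.
    func(x, y, out)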

File tree

5 files changed: +395 -0 lines changed


ffi/scripts/benchmark_dlpack.py

Lines changed: 345 additions & 0 deletions
@@ -0,0 +1,345 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This script benchmarks the per-call overhead of different python FFI
calling conventions built on top of the DLPack API.

Specifically, we would like to understand the overall overhead of
python/C++ API calls, to map out the design space and get a sense of
what operations are possible.

We pick a function f(x, y, z) where x, y, z are length-1 tensors.
The benchmark runs in eager mode so we can see what is possible there.
It is orthogonal to other optimizations: cudagraph, for example, can
eliminate these overheads completely. The goal is only to get a sense
of what is possible under eager mode.

Summary of some takeaways:

- numpy.add takes roughly 0.36 us per call, which indicates roughly
  what can be done in a python environment.
- torch.add on gpu takes about 3.7 us per call, giving us an idea of
  what we roughly need to get to in eager mode.
"""
import time

import numpy as np
import torch

from tvm import ffi as tvm_ffi


def print_speed(name, speed):
    print(f"{name:<40} {speed} sec/call")


def print_error(name, error):
    print(f"{name:<40} {error}")


def baseline_torch_add(repeat):
    """Run torch.add with one element"""

    def run_bench(device):
        x = torch.arange(1, device=device)
        y = torch.arange(1, device=device)
        z = torch.arange(1, device=device)

        torch.add(x, y, out=z)
        if device == "cuda":
            torch.cuda.synchronize()
        start = time.time()
        for i in range(repeat):
            torch.add(x, y, out=z)
        # note: we deliberately do not call torch.cuda.synchronize() after
        # the loop because we want to measure only the call overhead,
        # not kernel completion.
        end = time.time()
        print_speed(f"torch.add[{device}]", (end - start) / repeat)

    # rough takeaway: add on cuda takes roughly 3e-6 sec/call
    run_bench("cpu")
    run_bench("cuda")


def baseline_numpy_add(repeat):
    """Run numpy.add with one element"""
    x = np.arange(1)
    y = np.arange(1)
    z = np.arange(1)

    np.add(x, y, out=z)
    start = time.time()
    for i in range(repeat):
        np.add(x, y, out=z)
    end = time.time()
    speed = (end - start) / repeat
    print_speed("numpy.add", speed)


def baseline_cupy_add(repeat):
    """Run cupy.add with one element"""
    try:
        import cupy
    except ImportError:
        # skip if cupy is not installed
        return
    x = cupy.arange(1)
    y = cupy.arange(1)
    z = cupy.arange(1)

    cupy.add(x, y, out=z)
    start = time.time()
    for i in range(repeat):
        cupy.add(x, y, out=z)
    end = time.time()
    speed = (end - start) / repeat
    print_speed("cupy.add", speed)


def tvm_ffi_nop(repeat):
    """Overhead of a tvm FFI python call via a NOP.

    testing.nop is defined in c++ and does nothing.
    """
    nop = tvm_ffi.get_global_func("testing.nop")
    x = tvm_ffi.from_dlpack(torch.arange(1))
    y = tvm_ffi.from_dlpack(torch.arange(1))
    z = tvm_ffi.from_dlpack(torch.arange(1))
    nop(x, y, z)
    start = time.time()
    for i in range(repeat):
        nop(x, y, z)
    end = time.time()
    print_speed("tvm.ffi.nop", (end - start) / repeat)


def bench_ffi_nop_from_dlpack(name, x, y, z, repeat):
    """Run dlpack conversion + tvm.ffi.nop.

    Measures the overhead of running a dlpack conversion for each arg,
    then invoking the function.
    """
    nop = tvm_ffi.get_global_func("testing.nop")
    tx = tvm_ffi.from_dlpack(x)
    ty = tvm_ffi.from_dlpack(y)
    tz = tvm_ffi.from_dlpack(z)
    nop(tx, ty, tz)

    start = time.time()
    for i in range(repeat):
        tx = tvm_ffi.from_dlpack(x)
        ty = tvm_ffi.from_dlpack(y)
        tz = tvm_ffi.from_dlpack(z)
        nop(tx, ty, tz)
    end = time.time()
    print_speed(name, (end - start) / repeat)


def tvm_ffi_nop_from_torch_dlpack(repeat):
    """Run dlpack conversion + tvm.ffi.nop on torch tensors."""
    x = torch.arange(1)
    y = torch.arange(1)
    z = torch.arange(1)
    bench_ffi_nop_from_dlpack("tvm.ffi.nop+from_dlpack(torch)", x, y, z, repeat)


def tvm_ffi_nop_from_numpy_dlpack(repeat):
    """Run dlpack conversion + tvm.ffi.nop on numpy arrays."""
    x = np.arange(1)
    y = np.arange(1)
    z = np.arange(1)
    bench_ffi_nop_from_dlpack("tvm.ffi.nop+from_dlpack(numpy)", x, y, z, repeat)


def tvm_ffi_self_dlpack_nop(repeat):
    """Run dlpack conversion + tvm.ffi.nop on tvm NDArrays."""
    x = tvm_ffi.from_dlpack(torch.arange(1))
    y = tvm_ffi.from_dlpack(torch.arange(1))
    z = tvm_ffi.from_dlpack(torch.arange(1))
    bench_ffi_nop_from_dlpack("tvm.ffi.nop+from_dlpack(tvm)", x, y, z, repeat)


def tvm_ffi_nop_from_torch_utils_to_dlpack(repeat):
    """
    Measures the overhead of running a dlpack conversion for each arg
    and then invoking, but via the legacy torch.utils.dlpack.to_dlpack API.

    This helps to measure the possible implementation overhead of torch.
    """
    nop = tvm_ffi.get_global_func("testing.nop")
    x = torch.arange(1)
    y = torch.arange(1)
    z = torch.arange(1)

    tx = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(x))
    ty = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(y))
    tz = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(z))
    nop(tx, ty, tz)

    start = time.time()
    for i in range(repeat):
        tx = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(x))
        ty = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(y))
        tz = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(z))
        nop(tx, ty, tz)
    end = time.time()
    speed = (end - start) / repeat
    print_speed("tvm.ffi.nop+from_dlpack(torch.utils)", speed)


def bench_tvm_ffi_nop_autodlpack(name, x, y, z, repeat):
    """
    Measures the overhead of running dlpack via auto conversion by
    directly passing tensors as inputs.
    """
    nop = tvm_ffi.get_global_func("testing.nop")
    nop(x, y, z)
    start = time.time()
    for i in range(repeat):
        nop(x, y, z)
    end = time.time()
    speed = (end - start) / repeat
    print_speed(name, speed)


def tvm_ffi_nop_autodlpack_from_torch(repeat, device="cpu"):
    """
    Measures the overhead of auto dlpack conversion by directly passing
    torch.Tensor as inputs.
    """
    # use a larger array to ensure the alignment requirement is met
    x = torch.arange(256, device=device)
    y = torch.arange(256, device=device)
    z = torch.arange(256, device=device)
    bench_tvm_ffi_nop_autodlpack(f"tvm.ffi.nop.autodlpack(torch[{device}])", x, y, z, repeat)


def tvm_ffi_nop_autodlpack_from_numpy(repeat):
    """
    Measures the overhead of auto dlpack conversion by directly passing
    numpy.ndarray as inputs.
    """
    # use a larger array to ensure the alignment requirement is met
    x = np.arange(256)
    y = np.arange(256)
    z = np.arange(256)
    bench_tvm_ffi_nop_autodlpack("tvm.ffi.nop.autodlpack(numpy)", x, y, z, repeat)


def bench_to_dlpack(x, name, repeat):
    """Measures the overhead of x.__dlpack__()."""
    x.__dlpack__()
    start = time.time()
    for i in range(repeat):
        x.__dlpack__()
    end = time.time()
    speed = (end - start) / repeat
    print_speed(name, speed)


def bench_to_dlpack_versioned(x, name, repeat, max_version=(1, 1)):
    """
    Measures the overhead of x.__dlpack__ with the latest DLPack 1.1 protocol.
    """
    try:
        x.__dlpack__(max_version=max_version)
        start = time.time()
        for i in range(repeat):
            x.__dlpack__(max_version=max_version)
        end = time.time()
        speed = (end - start) / repeat
        print_speed(name, speed)
    except Exception as e:
        print_error(name, e)


def bench_torch_utils_to_dlpack(repeat):
    """
    Measures the overhead of running torch.utils.dlpack.to_dlpack.
    """
    x = torch.arange(1)
    torch.utils.dlpack.to_dlpack(x)
    start = time.time()
    for i in range(repeat):
        torch.utils.dlpack.to_dlpack(x)
    end = time.time()
    speed = (end - start) / repeat
    print_speed("torch.utils.dlpack.to_dlpack", speed)


def main():
    repeat = 10000
    print("-----------------------------")
    print("Benchmark f(x, y, z) overhead")
    print("-----------------------------")
    baseline_numpy_add(repeat)
    baseline_torch_add(repeat)
    baseline_cupy_add(repeat)
    tvm_ffi_nop(repeat)
    tvm_ffi_nop_from_torch_dlpack(repeat)
    tvm_ffi_nop_from_numpy_dlpack(repeat)
    tvm_ffi_self_dlpack_nop(repeat)
    tvm_ffi_nop_from_torch_utils_to_dlpack(repeat)
    tvm_ffi_nop_autodlpack_from_torch(repeat, "cpu")
    tvm_ffi_nop_autodlpack_from_torch(repeat, "cuda")
    tvm_ffi_nop_autodlpack_from_numpy(repeat)
    print("-------------------------------")
    print("Benchmark x.__dlpack__ overhead")
    print("-------------------------------")
    bench_torch_utils_to_dlpack(repeat)
    bench_to_dlpack(torch.arange(1), "torch.__dlpack__", repeat)
    bench_to_dlpack(np.arange(1), "numpy.__dlpack__", repeat)
    bench_to_dlpack(tvm_ffi.from_dlpack(torch.arange(1)), "tvm.__dlpack__", repeat)
    print("---------------------------------------------------")
    print("Benchmark x.__dlpack__(max_version=(1,1)) overhead")
    print("---------------------------------------------------")
    bench_to_dlpack_versioned(torch.arange(1), "torch.__dlpack__(max_version=(1,1))", repeat)
    bench_to_dlpack_versioned(np.arange(1), "numpy.__dlpack__(max_version=(1,1))", repeat)
    bench_to_dlpack_versioned(
        tvm_ffi.from_dlpack(torch.arange(1)), "tvm.__dlpack__(max_version=(1,1))", repeat
    )


if __name__ == "__main__":
    main()
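Usage note: the script is intended to be run directly (python ffi/scripts/benchmark_dlpack.py). As written it assumes torch with a working CUDA device, since baseline_torch_add and the torch autodlpack benchmark both request device="cuda"; the cupy baseline is skipped automatically when cupy is not installed.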

python/tvm/ffi/convert.py

Lines changed: 5 additions & 0 deletions
@@ -54,6 +54,11 @@ def convert(value: Any) -> Any:
         return core._convert_to_ffi_func(value)
     elif value is None:
         return None
+    elif hasattr(value, "__dlpack__"):
+        return core.from_dlpack(
+            value,
+            required_alignment=core.__dlpack_auto_import_required_alignment__,
+        )
     elif isinstance(value, Exception):
         return core._convert_to_ffi_error(value)
     else:
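A minimal sketch of what this hunk enables, assuming convert is re-exported from tvm.ffi and that the numpy in use (>= 1.22) exposes __dlpack__ on its arrays:

    import numpy as np
    from tvm.ffi import convert  # assumes convert is re-exported here

    arr = np.arange(8)
    nd = convert(arr)
    # arr exposes __dlpack__, so convert now auto-imports it into an
    # NDArray (zero-copy), subject to the required_alignment check.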

python/tvm/ffi/cython/function.pxi

Lines changed: 16 additions & 0 deletions
@@ -17,6 +17,11 @@
 import ctypes
 from numbers import Real, Integral
 
+try:
+    import torch
+except ImportError:
+    torch = None
+
 
 cdef inline object make_ret(TVMFFIAny result):
     """convert result to return value."""
@@ -71,6 +76,17 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args) except
     elif isinstance(arg, Object):
         out[i].type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
         out[i].v_ptr = (<Object>arg).chandle
+    elif torch is not None and isinstance(arg, torch.Tensor):
+        arg = from_dlpack(torch.utils.dlpack.to_dlpack(arg),
+                          required_alignment=__dlpack_auto_import_required_alignment__)
+        out[i].type_index = kTVMFFINDArray
+        out[i].v_ptr = (<NDArray>arg).chandle
+        temp_args.append(arg)
+    elif hasattr(arg, "__dlpack__"):
+        arg = from_dlpack(arg, required_alignment=__dlpack_auto_import_required_alignment__)
+        out[i].type_index = kTVMFFINDArray
+        out[i].v_ptr = (<NDArray>arg).chandle
+        temp_args.append(arg)
     elif isinstance(arg, PyNativeObject):
         arg = arg.__tvm_ffi_object__
         out[i].type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
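Read as pure Python, the new dispatch amounts to the sketch below. The helper name _auto_import_tensor is hypothetical; from_dlpack and the alignment global are the pxi-level names used above. torch gets a dedicated branch, presumably because the legacy capsule API avoids the cost of the versioned __dlpack__ protocol, which the benchmark script measures separately:

    def _auto_import_tensor(arg):
        # hypothetical helper: a pure-Python restatement of the Cython dispatch
        if torch is not None and isinstance(arg, torch.Tensor):
            # fast path: legacy capsule API for torch tensors
            return from_dlpack(
                torch.utils.dlpack.to_dlpack(arg),
                required_alignment=__dlpack_auto_import_required_alignment__,
            )
        if hasattr(arg, "__dlpack__"):
            # generic path: any DLPack-exporting array (numpy, cupy, jax, ...)
            return from_dlpack(
                arg,
                required_alignment=__dlpack_auto_import_required_alignment__,
            )
        raise TypeError(f"unsupported tensor argument of type {type(arg)}")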

python/tvm/ffi/cython/ndarray.pxi

Lines changed: 2 additions & 0 deletions
@@ -16,8 +16,10 @@
 # under the License.
 
 __dlpack_version__ = (1, 1)
+__dlpack_auto_import_required_alignment__ = 8
 _CLASS_NDARRAY = None
 
+
 def _set_class_ndarray(cls):
     global _CLASS_NDARRAY
     _CLASS_NDARRAY = cls
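The required alignment is a single global value for now, and together with the contiguity requirement noted in the commit message it puts the burden on the caller. A sketch of the caller-side workaround, with testing.nop again standing in for a real compiled kernel:

    import torch
    from tvm import ffi as tvm_ffi

    func = tvm_ffi.get_global_func("testing.nop")

    a = torch.arange(16, dtype=torch.float32).reshape(4, 4)
    t = a.t()  # transposed view: strided, not contiguous

    # the underlying DSL compiler expects contiguous, aligned buffers,
    # so materialize a contiguous copy before passing the view
    func(t.contiguous(), t.contiguous(), t.contiguous())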
