Skip to content

Commit 9eb8b30

Browse files
Add support for bucketize (#18040)
* add support for bucketize * fix lint issue * Fix lint issue * Add GPU code for bucketize * Resolve merge conflict * Fix lint issue
1 parent b6db2ec commit 9eb8b30

File tree

14 files changed

+279
-3
lines changed

14 files changed

+279
-3
lines changed

include/tvm/relax/attrs/search.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,24 @@ struct ArgmaxArgminAttrs : public AttrsNodeReflAdapter<ArgmaxArgminAttrs> {
4949
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(ArgmaxArgminAttrs, BaseAttrsNode);
5050
}; // struct ArgmaxArgminAttrs
5151

52+
/*! \brief Attributes for bucketize operator */
53+
struct BucketizeAttrs : public tvm::AttrsNodeReflAdapter<BucketizeAttrs> {
54+
bool out_int32;
55+
bool right;
56+
57+
static void RegisterReflection() {
58+
namespace refl = tvm::ffi::reflection;
59+
refl::ObjectDef<BucketizeAttrs>()
60+
.def_ro("out_int32", &BucketizeAttrs::out_int32,
61+
"Indicate the output datatype, int32 if True, int64 otherwise.")
62+
.def_ro("right", &BucketizeAttrs::right,
63+
"Determines the behavior for values in boundaries");
64+
}
65+
66+
static constexpr const char* _type_key = "relax.attrs.BucketizeAttrs";
67+
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(BucketizeAttrs, BaseAttrsNode);
68+
}; // struct BucketizeAttrs
69+
5270
} // namespace relax
5371
} // namespace tvm
5472

python/tvm/relax/backend/dispatch_sort_scan.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,18 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
7575
if not isinstance(call.op, Op):
7676
return super().visit_call_(call)
7777

78+
if call.op.name == "relax.bucketize":
79+
input_tensor = call.args[0]
80+
boundaries = call.args[1]
81+
right = call.attrs.right
82+
tgt = self._get_target(call.struct_info)
83+
te_func = topi.searchsorted
84+
with tgt:
85+
if self.is_gpu_target(tgt):
86+
te_func = topi.gpu.searchsorted
87+
return self.builder_.call_te(
88+
te_func, boundaries, input_tensor, right, input_tensor.struct_info.dtype
89+
)
7890
if call.op.name == "relax.sort":
7991
tgt = self._get_target(call.struct_info)
8092
te_func = topi.sort

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1376,6 +1376,18 @@ def _where(self, node: fx.Node) -> relax.Var:
13761376
y = self.env[node.args[2]]
13771377
return self.block_builder.emit(relax.op.where(condition, x, y))
13781378

1379+
def _bucketize(self, node: fx.Node) -> relax.Var:
    """Convert a bucketize node into an emitted ``relax.op.bucketize`` call.

    The first two retrieved arguments are the tensor of search values and
    the boundaries tensor. The ``out_int32`` and ``right`` flags are read
    from the node's keyword arguments and default to ``False`` when absent.
    """
    positional = self.retrieve_args(node)
    options = node.kwargs
    bucketize_call = relax.op.bucketize(
        positional[0],
        positional[1],
        options.get("out_int32", False),
        options.get("right", False),
    )
    return self.block_builder.emit(bucketize_call)
1390+
13791391
########## Manipulation ##########
13801392

13811393
def _argsort(self, node: fx.Node) -> relax.Var:

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,7 @@ def create_convert_map(
507507
"argmax.default": self._argmax_argmin(relax.op.argmax),
508508
"argmin.default": self._argmax_argmin(relax.op.argmin),
509509
"where.self": self._where,
510+
"bucketize.Tensor": self._bucketize,
510511
# tensor manipulation
511512
"argsort.default": self._argsort,
512513
"broadcast_to.default": self._broadcast_to,

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -917,6 +917,7 @@ def create_convert_map(
917917
"argmax": self._argmax_argmin(relax.op.argmax),
918918
"argmin": self._argmax_argmin(relax.op.argmin),
919919
"where": self._where,
920+
"bucketize": self._bucketize,
920921
# tensor manipulation
921922
"argsort": self._argsort,
922923
"broadcast_to": self._broadcast_to,

python/tvm/relax/op/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@
115115
from .mask import masked_fill
116116
from .qdq import dequantize, quantize
117117
from .sampling import multinomial_from_uniform
118-
from .search import argmax, argmin, where
118+
from .search import argmax, argmin, where, bucketize
119119
from .set import nonzero, unique
120120
from .sorting import argsort, sort, topk
121121
from .statistical import cumprod, cumsum, max, mean, min, prod, std, sum, variance

python/tvm/relax/op/search.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,28 @@ def argmin(x: Expr, axis: Optional[int] = None, keepdims: bool = False) -> Expr:
102102
The computed result.
103103
"""
104104
return _ffi_api.argmin(x, axis, keepdims) # type: ignore
105+
106+
107+
def bucketize(
    input_tensor: Expr, boundaries: Expr, out_int32: bool = False, right: bool = False
) -> Expr:
    """Returns the indices of the buckets to which each value in the input belongs.

    Parameters
    ----------
    input_tensor : relax.Expr
        N-D tensor containing the search values.

    boundaries : relax.Expr
        1-D tensor, must contain a strictly increasing sequence, or the return value is undefined.

    out_int32 : bool
        Indicate the output data type. int32 if True, int64 otherwise. Default=False

    right : bool
        Determines the behavior for values that are equal to a boundary, following
        torch.bucketize: if False, the returned index ``i`` satisfies
        ``boundaries[i - 1] < v <= boundaries[i]``; if True, it satisfies
        ``boundaries[i - 1] <= v < boundaries[i]``. Default=False

    Returns
    -------
    result : relax.Expr
        The computed result with same shape as input_tensor.
    """
    return _ffi_api.bucketize(input_tensor, boundaries, out_int32, right)  # type: ignore

python/tvm/relax/transform/legalize_ops/search.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,13 @@ def argmax_argmin_call_te(bb: BlockBuilder, call: Call) -> Expr:
3939

4040
register_legalize("relax.argmax", _argmax_argmin(topi.argmax))
4141
register_legalize("relax.argmin", _argmax_argmin(topi.argmin))
42+
43+
44+
@register_legalize("relax.bucketize")
def _bucketize(bb, call):
    """Legalize ``relax.bucketize`` into a TE call of ``topi.searchsorted``.

    Parameters
    ----------
    bb : BlockBuilder
        Builder used to emit the TE call.

    call : Call
        The ``relax.bucketize`` call. ``call.args`` holds the search values
        followed by the boundaries; ``call.attrs`` carries ``right`` and
        ``out_int32``.

    Returns
    -------
    result : Expr
        The ``call_te`` expression that replaces the original call.
    """
    input_tensor = call.args[0]
    boundaries = call.args[1]
    right = call.attrs.right
    # The result holds bucket indices, so its dtype is chosen by `out_int32`
    # (int32 if set, int64 otherwise) and not by the dtype of the searched
    # values, which may be floating point.
    out_dtype = "int32" if call.attrs.out_int32 else "int64"
    # searchsorted takes the sorted sequence first, hence boundaries precede
    # the search values.
    return bb.call_te(topi.searchsorted, boundaries, input_tensor, right, out_dtype)

python/tvm/script/ir_builder/relax/ir.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
bitwise_or,
5959
bitwise_xor,
6060
broadcast_to,
61+
bucketize,
6162
builtin,
6263
call_builtin_with_ctx,
6364
call_dps_packed,
@@ -731,6 +732,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
731732
"bitwise_or",
732733
"bitwise_xor",
733734
"broadcast_to",
735+
"bucketize",
734736
"builtin",
735737
"call_inplace_packed",
736738
"call_packed",

python/tvm/topi/gpu/sort.py

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@
2020
from tvm import te
2121

2222
from ..transform import strided_slice, transpose
23-
from ..utils import ceil_div, swap
23+
from ..utils import ceil_div, swap, prod
2424
from ..math import cast, ceil_log2
25+
from ..searchsorted import binary_search
2526

2627

2728
def _get_threads(ib, nthread_tx, nthread_bx, nthread_by):
@@ -937,3 +938,89 @@ def f_compute(ins, outs):
937938
out = out[1]
938939

939940
return out
941+
942+
943+
def searchsorted(sorted_sequence, values, right=False, out_dtype="int64"):
    """Compute insertion indices that keep `sorted_sequence` ordered (GPU version).

    When `sorted_sequence` is N-dimensional, each innermost row of `values`
    is searched in the matching innermost row of `sorted_sequence`. The
    generated kernel assigns one thread to every search value.

    Parameters
    ----------
    sorted_sequence : te.Tensor
        1-D or N-D Tensor whose innermost dimension holds a monotonically
        increasing sequence.

    values : te.Tensor
        N-D Tensor with the values to look up. Any shape is accepted when
        `sorted_sequence` is 1-D. Otherwise both tensors must have the same
        rank and identical sizes on the outer N-1 axes.

    right : bool, optional
        Tie-breaking rule for a value that equals an element of the sequence.
        False (side='left') returns the first suitable index, True
        (side='right') returns the last one.

    out_dtype : string, optional
        Data type of the returned indices.

    Returns
    -------
    indices : te.Tensor
        Tensor shaped like `values` holding, for every element, the position
        at which it would be inserted into `sorted_sequence`.
    """
    if len(sorted_sequence.shape) > 1:
        for axis in range(len(values.shape) - 1):
            assert (
                values.shape[axis] == sorted_sequence.shape[axis]
            ), "Outer dimensions of sorted_sequence and values must match for N-D searchsorted"

    def _emit_search_kernel(ins, outs):
        # Build the TIR body: ins = (sorted sequence, values), outs = (indices,).
        seq_buf = ins[0]
        val_buf = ins[1]
        idx_buf = outs[0]

        builder = tvm.tir.ir_builder.create()
        seq_shape = seq_buf.shape
        val_shape = val_buf.shape
        total_queries = prod(val_shape)
        row_length = seq_shape[-1]

        seq_data = builder.buffer_ptr(seq_buf)
        val_data = builder.buffer_ptr(val_buf)
        idx_data = builder.buffer_ptr(idx_buf)

        # Block size comes from the current target; enough blocks are
        # launched to cover every search value.
        threads_per_block = int(tvm.target.Target.current(allow_none=False).max_num_threads)
        num_blocks = ceil_div(total_queries, threads_per_block)
        thread_x = te.thread_axis("threadIdx.x")
        block_x = te.thread_axis("blockIdx.x")
        builder.scope_attr(thread_x, "thread_extent", threads_per_block)
        builder.scope_attr(block_x, "thread_extent", num_blocks)
        flat_id = block_x * threads_per_block + thread_x

        # Threads beyond the number of search values do nothing.
        with builder.if_scope(flat_id < total_queries):
            # A 1-D sequence is shared by all values; otherwise each row of
            # values searches the sequence row with the same row index.
            if len(seq_shape) == 1:
                row_start = 0
            else:
                row_start = (flat_id // val_shape[-1]) * row_length

            idx_data[flat_id] = binary_search(
                builder,
                row_start,
                row_length,
                seq_data,
                val_data[flat_id],
                right,
                out_dtype,
            )

        return builder.get()

    return te.extern(
        values.shape,
        [sorted_sequence, values],
        _emit_search_kernel,
        name="searchsorted_gpu",
        dtype=out_dtype,
    )

0 commit comments

Comments
 (0)