|
20 | 20 | from tvm import te
|
21 | 21 |
|
22 | 22 | from ..transform import strided_slice, transpose
|
23 |
| -from ..utils import ceil_div, swap |
| 23 | +from ..utils import ceil_div, swap, prod |
24 | 24 | from ..math import cast, ceil_log2
|
| 25 | +from ..searchsorted import binary_search |
25 | 26 |
|
26 | 27 |
|
27 | 28 | def _get_threads(ib, nthread_tx, nthread_bx, nthread_by):
|
@@ -937,3 +938,89 @@ def f_compute(ins, outs):
|
937 | 938 | out = out[1]
|
938 | 939 |
|
939 | 940 | return out
|
| 941 | + |
| 942 | + |
def searchsorted(sorted_sequence, values, right=False, out_dtype="int64"):
    """Find indices where elements should be inserted to maintain order.

    If `sorted_sequence` is N-dimensional, the innermost dimension of
    `values` is searched in the corresponding innermost dimension of
    `sorted_sequence`.

    This implementation is optimized for GPU execution.

    Parameters
    ----------
    sorted_sequence : te.Tensor
        N-D or 1-D Tensor, containing monotonically increasing sequence
        on the innermost dimension.

    values : te.Tensor
        N-D Tensor containing the search values. When `sorted_sequence` is 1-D,
        the shape of `values` can be arbitrary. Otherwise, ranks of `sorted_sequence`
        and `values` must be the same, and outer N-1 axes must have the same size.

    right : bool, optional
        Controls which index is returned if a value lands exactly on one of sorted values. If
        False (side='left'), the index of the first suitable location found is given. If true
        (side='right'), return the last such index.

    out_dtype : string, optional
        The data type of the output indices.

    Returns
    -------
    indices : te.Tensor
        Tensor with same shape as values, representing the indices of
        elements of `values` if they are inserted in `sorted_sequence`.

    Raises
    ------
    AssertionError
        If `sorted_sequence` is N-D (N > 1) and `values` has a different rank,
        or if any of the outer N-1 axes differ in size.
    """
    if len(sorted_sequence.shape) > 1:
        # The per-row offset computed in the kernel is only meaningful when both
        # tensors have the same rank; without this check a lower-rank `values`
        # would pass the per-axis comparison below and search the wrong rows.
        assert len(sorted_sequence.shape) == len(
            values.shape
        ), "Ranks of sorted_sequence and values must match for N-D searchsorted"
        for i in range(len(values.shape) - 1):
            assert (
                values.shape[i] == sorted_sequence.shape[i]
            ), "Outer dimensions of sorted_sequence and values must match for N-D searchsorted"

    def ir(sorted_sequence_buf, values_buf, indices_buf):
        ib = tvm.tir.ir_builder.create()
        sorted_sequence_shape = sorted_sequence_buf.shape
        values_shape = values_buf.shape
        # Total number of searches: one per element of `values`.
        num_search = prod(values_shape)
        # Each search scans one innermost row of `sorted_sequence`.
        search_range = sorted_sequence_shape[-1]

        sorted_sequence_ptr = ib.buffer_ptr(sorted_sequence_buf)
        values_ptr = ib.buffer_ptr(values_buf)
        indices_ptr = ib.buffer_ptr(indices_buf)

        # Launch enough blocks of `max_threads` threads to give every search
        # value its own thread.
        max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
        nthread_tx = max_threads
        nthread_bx = ceil_div(num_search, nthread_tx)
        tx = te.thread_axis("threadIdx.x")
        bx = te.thread_axis("blockIdx.x")
        ib.scope_attr(tx, "thread_extent", nthread_tx)
        ib.scope_attr(bx, "thread_extent", nthread_bx)
        tid = bx * nthread_tx + tx

        # The last block may contain threads past the end of `values`.
        with ib.if_scope(tid < num_search):
            if len(sorted_sequence_shape) == 1:
                # A single shared sequence: every value searches from offset 0.
                sequence_offset = 0
            else:
                # N-D case: values in the same innermost row share one sequence
                # row, located `search_range` elements after the previous one.
                sequence_id = tid // values_shape[-1]
                sequence_offset = sequence_id * search_range

            indices_ptr[tid] = binary_search(
                ib,
                sequence_offset,
                search_range,
                sorted_sequence_ptr,
                values_ptr[tid],
                right,
                out_dtype,
            )

        return ib.get()

    return te.extern(
        values.shape,
        [sorted_sequence, values],
        lambda ins, outs: ir(ins[0], ins[1], outs[0]),
        name="searchsorted_gpu",
        dtype=out_dtype,
    )
0 commit comments