
Commit b5b0337

[Relax][PyTorch] support for index.Tensor (#17836)
New op for advanced indexing + unit tests
1 parent 111ddf7 commit b5b0337

10 files changed (+424, -6 lines)

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 5 additions & 0 deletions

@@ -1148,6 +1148,11 @@ def _gather(self, node: fx.Node) -> relax.Var:
         index = self.env[node.args[2]]
         return self.block_builder.emit(relax.op.gather_elements(x, index, axis=dim))
 
+    def _index_tensor(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        indices = args[1]
+        return self.block_builder.emit(relax.op.index_tensor(args[0], indices))
+
     def _permute(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 1 addition & 0 deletions

@@ -420,6 +420,7 @@ def create_convert_map(
             "flatten.using_ints": self._flatten,
             "flip.default": self._flip,
             "gather.default": self._gather,
+            "index.Tensor": self._index_tensor,
             "narrow.default": self._narrow,
             "permute.default": self._permute,
             "repeat.default": self._repeat,

python/tvm/relax/op/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -95,6 +95,7 @@
     flip,
     gather_elements,
     gather_nd,
+    index_tensor,
     layout_transform,
     one_hot,
     permute_dims,

python/tvm/relax/op/manipulate.py

Lines changed: 63 additions & 0 deletions

@@ -532,6 +532,69 @@ def gather_nd(data: Expr, indices: Expr, batch_dims: int = 0) -> Expr:
     return _ffi_api.gather_nd(data, indices, batch_dims)  # type: ignore
 
 
+def index_tensor(data: Expr, indices: Union[Expr, List[Expr]]) -> Expr:
+    """Advanced tensor indexing (NumPy/PyTorch style).
+
+    Given k index tensors ``indices = (I0, I1, …, Ik-1)`` this
+    operator selects elements from ``data`` as if one had written
+    ``data[I0, I1, …, Ik-1]`` in NumPy/PyTorch:
+
+    All index tensors must have an integer dtype.
+
+    Their shapes are broadcast together to a common shape ``B`` in
+    the usual NumPy way.
+
+    The result shape is ``B + data.shape[k:]`` (i.e. the broadcast
+    shape followed by the remaining axes of ``data`` that are *not*
+    indexed).
+
+    At compile time Relax checks that the number of index tensors
+    ``k`` does not exceed ``data.ndim``, that the dtypes are integer,
+    and that the shapes are consistent (broadcast-compatible).
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input tensor to be indexed.
+
+    indices : Union[relax.Expr, List[relax.Expr]]
+        A Tuple expression containing the index tensors,
+        or a Python ``list`` / ``tuple`` that will be promoted to a
+        tuple expression automatically. Each tensor must have an
+        integer dtype.
+
+    Returns
+    -------
+    result : relax.Expr
+        The tensor obtained after advanced indexing. Its dtype equals
+        ``data.dtype``.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        import numpy as np
+        import tvm.relax as R
+
+        x = R.const(np.arange(9).reshape(3, 3).astype("float32"))
+        row = R.const(np.array([0, 2]))   # shape (2,)
+        col = R.const(np.array([1, 0]))   # shape (2,)
+
+        y = R.index_tensor(x, [row, col])
+        # y.shape == (2,) ; y == [1., 6.]
+
+        # Broadcasting: row : (2, 1), col : (1, 3) -> B = (2, 3)
+        row = R.const(np.array([[0], [1]]))
+        col = R.const(np.array([[0, 1, 2]]))
+        z = R.index_tensor(x, [row, col])
+        # z.shape == (2, 3)
+
+    """
+    if isinstance(indices, (list, tuple)):
+        indices = RxTuple(indices)
+    return _ffi_api.index_tensor(data, indices)  # type: ignore
+
+
 def scatter_elements(
     data: Expr, indices: Expr, updates: Expr, axis: int = 0, reduction: str = "update"
 ):

python/tvm/relax/transform/legalize_ops/manipulate.py

Lines changed: 9 additions & 0 deletions

@@ -49,6 +49,7 @@ def reshape_call_te(bb: BlockBuilder, call: Call):
     "relax.collapse_sum_like",
     _reshape(topi.collapse_sum, "collapse_sum", is_collapse_sum_like=True),
 )
+
 register_legalize("relax.collapse_sum_to", _reshape(topi.collapse_sum, "collapse_sum"))
 
 
@@ -184,6 +185,14 @@ def te_gather_nd(data, indices, batch_dims):
     return bb.call_te(te_gather_nd, call.args[0], call.args[1], int(call.attrs.batch_dims))
 
 
+@register_legalize("relax.index_tensor")
+def _index_tensor(bb: BlockBuilder, call: Call) -> Expr:
+    t = call.args[1]
+    n_field = len(t.struct_info.fields)
+    fields = [bb.emit(TupleGetItem(t, i)) for i in range(n_field)]
+    return bb.call_te(topi.index_tensor, call.args[0], fields)
+
+
 @register_legalize("relax.scatter_elements")
 def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr:
     return bb.call_te(
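
For illustration (not part of the diff), here is a minimal sketch of how this legalization is exercised end to end, assuming an LLVM-enabled TVM build: a TVMScript module calls R.index_tensor, LegalizeOps rewrites the op into a call_tir of a TE kernel generated from topi.index_tensor, and the compiled module reproduces NumPy-style advanced indexing.

    import numpy as np
    import tvm
    from tvm import relax
    from tvm.script import ir as I, relax as R

    @I.ir_module
    class Module:
        @R.function
        def main(
            x: R.Tensor((3, 3), "float32"),
            row: R.Tensor((2,), "int64"),
            col: R.Tensor((2,), "int64"),
        ) -> R.Tensor((2,), "float32"):
            return R.index_tensor(x, (row, col))

    # LegalizeOps turns the high-level op into a call_tir of a generated PrimFunc.
    legalized = relax.transform.LegalizeOps()(Module)
    legalized.show()

    # relax.build applies the same legalization as part of its default pipeline.
    ex = relax.build(Module, target="llvm")
    vm = relax.VirtualMachine(ex, tvm.cpu())
    x_np = np.arange(9, dtype="float32").reshape(3, 3)
    out = vm["main"](
        tvm.nd.array(x_np),
        tvm.nd.array(np.array([0, 2], dtype="int64")),
        tvm.nd.array(np.array([1, 0], dtype="int64")),
    )
    print(out.numpy())  # expected [1. 6.], matching x_np[[0, 2], [1, 0]]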

python/tvm/script/ir_builder/relax/ir.py

Lines changed: 2 additions & 0 deletions

@@ -101,6 +101,7 @@
     greater_equal,
     hint_on_device,
     image,
+    index_tensor,
     invoke_closure,
     invoke_pure_closure,
     isfinite,
@@ -785,6 +786,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
     "hexagon",
     "hint_on_device",
     "image",
+    "index_tensor",
     "invoke_closure",
     "invoke_pure_closure",
     "isfinite",

python/tvm/topi/transform.py

Lines changed: 51 additions & 0 deletions

@@ -1054,3 +1054,54 @@ def _apply_trilu(*indices):
         return tvm.tir.Select(check_position, value, tvm.tir.const(0, data.dtype))
 
     return te.compute(data.shape, _apply_trilu, name="trilu", tag=topi.tag.ELEMWISE)
+
+
+def index_tensor(data, indices):
+    """Advanced tensor indexing (NumPy/PyTorch style).
+
+    Given k index tensors ``indices = (I0, I1, …, Ik-1)`` this
+    operator selects elements from ``data`` as if one had written
+    ``data[I0, I1, …, Ik-1]`` in NumPy/PyTorch:
+
+    * All index tensors must have an integer dtype.
+    * Their shapes are broadcast together to a common shape ``B`` in
+      the usual NumPy way.
+    * The result shape is ``B + data.shape[k:]`` (i.e. the broadcast
+      shape followed by the remaining axes of ``data`` that are *not*
+      indexed).
+    * ``k`` must not exceed ``data.ndim``; otherwise a compile-time
+      error is raised.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        The tensor to be indexed.
+
+    indices : Sequence[tvm.te.Tensor]
+        A Python ``list`` / ``tuple`` of **k** index tensors,
+        or a `tvm.te.Tensor` tuple expression. Each tensor must have an
+        integer dtype.
+
+    Returns
+    -------
+    result : tvm.te.Tensor
+        The tensor obtained after advanced indexing. Its dtype equals
+        ``data.dtype``.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        x = te.placeholder((3, 3), name="x")  # shape (3, 3)
+        row = te.placeholder((2,), name="row", dtype="int32")
+        col = te.placeholder((2,), name="col", dtype="int32")
+
+        # Equivalent to x[row, col] in NumPy / PyTorch
+        y = topi.index_tensor(x, [row, col])  # shape (2,)
+
+        # Broadcasting example:
+        row = te.placeholder((2, 1), name="row", dtype="int32")
+        col = te.placeholder((1, 3), name="col", dtype="int32")
+        z = topi.index_tensor(x, [row, col])  # shape (2, 3)
+    """
+    return topi.adv_index(data, indices)
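
As a quick sanity check (not part of the diff), the TE-level kernel can be built and compared against NumPy advanced indexing. This is a minimal sketch assuming an LLVM-enabled build; exact build entry points can vary between TVM versions.

    import numpy as np
    import tvm
    from tvm import te, topi

    x = te.placeholder((3, 3), name="x", dtype="float32")
    row = te.placeholder((2, 1), name="row", dtype="int32")
    col = te.placeholder((1, 3), name="col", dtype="int32")
    y = topi.index_tensor(x, [row, col])  # forwards to topi.adv_index, shape (2, 3)

    mod = tvm.IRModule({"main": te.create_prim_func([x, row, col, y])})
    lib = tvm.build(mod, target="llvm")

    x_np = np.arange(9, dtype="float32").reshape(3, 3)
    row_np = np.array([[0], [1]], dtype="int32")
    col_np = np.array([[0, 1, 2]], dtype="int32")
    out = tvm.nd.empty((2, 3), "float32")
    lib["main"](tvm.nd.array(x_np), tvm.nd.array(row_np), tvm.nd.array(col_np), out)
    np.testing.assert_allclose(out.numpy(), x_np[row_np, col_np])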

src/relax/op/tensor/manipulate.cc

Lines changed: 145 additions & 0 deletions

@@ -474,6 +474,151 @@ TVM_REGISTER_OP("relax.flatten")
     .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.index_tensor */
+
+Expr index_tensor(Expr first, Expr tensors) {
+  static const Op& op = Op::Get("relax.index_tensor");
+  return Call(op, {std::move(first), std::move(tensors)}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.index_tensor").set_body_typed(index_tensor);
+
+StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) {
+  if (call->args.size() != 2) {
+    ctx->ReportFatal(Diagnostic::Error(call) << "Index.Tensor op should have 2 arguments");
+  }
+
+  TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx);
+  Array<TensorStructInfo> indices_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[1]);
+
+  if (indices_sinfo.empty()) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "index_tensor expects a non-empty tuple of index tensors");
+  }
+
+  DataType output_dtype = data_sinfo->dtype;
+  int n_indices = static_cast<int>(indices_sinfo.size());
+  Optional<VDevice> vdev = data_sinfo->vdevice;
+
+  // Indices must be integers
+  for (int i = 0; i < n_indices; ++i) {
+    const auto& s = indices_sinfo[i];
+    if (!s->IsUnknownDtype() && !s->dtype.is_int()) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "index_tensor requires every index tensor to have an integer dtype; "
+                       << "index " << i << " has dtype " << s->dtype);
+    }
+  }
+
+  // Count of indices must be less than or equal to data.ndim
+  if (!data_sinfo->IsUnknownNdim() && n_indices > data_sinfo->ndim) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "index_tensor received " << n_indices
+                     << " index tensors, but data has only " << data_sinfo->ndim << " dimensions");
+  }
+
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  bool all_index_have_shape_value = true;
+  std::vector<Array<PrimExpr>> index_shapes;
+  int max_index_ndim = 0;
+
+  for (const auto& s : indices_sinfo) {
+    const auto* shp = s->shape.as<ShapeExprNode>();
+    if (!shp) {
+      all_index_have_shape_value = false;
+    } else {
+      index_shapes.push_back(shp->values);
+      max_index_ndim = std::max(max_index_ndim, static_cast<int>(shp->values.size()));
+    }
+    if (!s->IsUnknownNdim()) {
+      max_index_ndim = std::max(max_index_ndim, s->ndim);
+    }
+  }
+
+  Optional<Array<PrimExpr>> broadcast_shape;
+  bool shape_unknown = !all_index_have_shape_value;
+
+  if (all_index_have_shape_value) {
+    // Initialise the broadcast result with 1's
+    Array<PrimExpr> out_shape;
+    for (int i = 0; i < max_index_ndim; ++i) {
+      out_shape.push_back(IntImm(DataType::Int(64), 1));
+    }
+
+    for (const auto& ishape : index_shapes) {
+      int cur_ndim = ishape.size();
+      for (int axis = 0; axis < max_index_ndim; ++axis) {
+        int lhs_axis = max_index_ndim - 1 - axis;  // aligned from the right
+        int rhs_axis = cur_ndim - 1 - axis;
+        if (rhs_axis < 0) break;  // shorter rank, done
+
+        PrimExpr lhs_dim = out_shape[lhs_axis];
+        PrimExpr rhs_dim = ishape[rhs_axis];
+
+        const auto* lhs_int = lhs_dim.as<IntImmNode>();
+        const auto* rhs_int = rhs_dim.as<IntImmNode>();
+
+        // Case 1: current broadcast slot is 1 -> always replace
+        if (lhs_int && lhs_int->value == 1) {
+          out_shape.Set(lhs_axis, rhs_dim);
+          continue;
+        }
+        // Case 2: rhs is 1 -> keep lhs_dim unchanged
+        if (rhs_int && rhs_int->value == 1) {
+          continue;
+        }
+        // Both are non-one constants: they must be equal
+        if (lhs_int && rhs_int && lhs_int->value != rhs_int->value) {
+          ctx->ReportFatal(Diagnostic::Error(call)
+                           << "index_tensor: cannot broadcast index shapes. Mismatch at axis "
+                           << lhs_axis << ": " << lhs_dim << " vs " << rhs_dim);
+        }
+        // Give up if not provably equal
+        if (!analyzer->CanProveEqual(lhs_dim, rhs_dim)) {
+          shape_unknown = true;
+          break;
+        }
+      }
+      if (shape_unknown) break;
+    }
+
+    if (!shape_unknown) broadcast_shape = out_shape;
+  }
+
+  // Number of dimensions in the output
+  int out_ndim = kUnknownNDim;
+  if (!data_sinfo->IsUnknownNdim()) {
+    int tail_ndim = data_sinfo->ndim - n_indices;
+    if (broadcast_shape.defined()) {
+      out_ndim = static_cast<int>(broadcast_shape.value().size()) + tail_ndim;
+    } else if (!shape_unknown) {
+      out_ndim = max_index_ndim + tail_ndim;
+    }
+  }
+
+  // Derive the output shape
+  if (broadcast_shape.defined()) {
+    const auto* data_shape_expr = data_sinfo->shape.as<ShapeExprNode>();
+    if (data_shape_expr) {
+      Array<PrimExpr> result_shape = broadcast_shape.value();
+      for (int i = n_indices; i < data_sinfo->ndim; ++i) {
+        result_shape.push_back(data_shape_expr->values[i]);
+      }
+      return TensorStructInfo(ShapeExpr(result_shape), output_dtype, vdev);
+    }
+  }
+
+  // Unknown output shape
+  return TensorStructInfo(output_dtype, out_ndim, vdev);
+}
+
+TVM_REGISTER_OP("relax.index_tensor")
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input data.")
+    .add_argument("indices", "List of Tensors", "The indices used to index.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoIndexTensor)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 /* relax.layout_transform */
 TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs);
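To see the shape inference above in action (illustrative, not part of the diff): the broadcast of index shapes (2, 1) and (1, 3) is (2, 3), and with a rank-3 input and two index tensors one trailing data axis is appended. A minimal sketch using the standard struct-info test pattern:

    import tvm
    from tvm import relax
    from tvm.script import relax as R

    bb = relax.BlockBuilder()
    x = relax.Var("x", R.Tensor((3, 4, 5), "float32"))
    i0 = relax.Var("i0", R.Tensor((2, 1), "int64"))
    i1 = relax.Var("i1", R.Tensor((1, 3), "int64"))

    call = relax.op.index_tensor(x, [i0, i1])
    # Expected struct info: R.Tensor((2, 3, 5), "float32"),
    # i.e. broadcast shape (2, 3) followed by the un-indexed axis of size 5.
    print(bb.normalize(call).struct_info)
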
src/relax/op/tensor/manipulate.h

Lines changed: 12 additions & 0 deletions

@@ -206,6 +206,18 @@ Expr gather_elements(Expr data, Expr indices, int axis = 0);
  */
 Expr gather_nd(Expr data, Expr indices, int batch_dims = 0);
 
+/*!
+ * \brief NumPy/PyTorch-style advanced indexing with tensors.
+ * \param data The input tensor.
+ * \param indices A Tuple expression (or list) containing the index tensors.
+ * \return The indexed tensor.
+ *
+ * \note When all shapes are static, Relax checks that the index shapes are
+ *       broadcast-compatible. Bounds checking of the values in indices is
+ *       deferred to runtime.
+ */
+Expr index_tensor(Expr data, Expr indices);
+
 /*!
  * \brief Scatter updates into an array according to indices.
  * \param data The input tensor.
