Skip to content

Commit 2ca6ec8

Browse files
[Relax][PyTorch] Sort.default (#17852)
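Add support for sort.default in the exported program translator. There was an existing _sort() function in base_fx_graph_translator.py, but it returned only the sorted values. PyTorch's torch.sort() returns a tuple of (values, indices), so _sort() now computes the indices with argsort, gathers the values with gather_elements, and returns both as a tuple.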
1 parent fe1b228 commit 2ca6ec8

File tree

4 files changed

+45
-5
lines changed

4 files changed

+45
-5
lines changed

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1278,10 +1278,15 @@ def _scatter(self, node: fx.Node) -> relax.Var:
12781278
return self.block_builder.emit(relax.op.scatter_elements(x, index, src, axis=dim))
12791279

12801280
def _sort(self, node: fx.Node) -> relax.Var:
1281+
# torch.sort() returns a tuple of values and indices
1282+
# we use argsort to get indices and gather_elements to get values
12811283
x = self.env[node.args[0]]
12821284
dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
12831285
descending = node.args[2] if len(node.args) > 2 else node.kwargs.get("descending", False)
1284-
return self.block_builder.emit(relax.op.sort(x, dim, descending))
1286+
1287+
indices = self.block_builder.emit(relax.op.argsort(x, dim, descending))
1288+
values = self.block_builder.emit(relax.op.gather_elements(x, indices, axis=dim))
1289+
return self.block_builder.emit(relax.Tuple([values, indices]))
12851290

12861291
def _split(self, node: fx.Node) -> relax.Var:
12871292
x = self.env[node.args[0]]

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,7 @@ def create_convert_map(
431431
"roll.default": self._roll,
432432
"select.int": self._select,
433433
"slice.Tensor": self._slice,
434+
"sort.default": self._sort,
434435
"split.Tensor": self._split,
435436
"split_with_sizes.default": self._split,
436437
"squeeze.default": self._squeeze,

tests/python/relax/test_from_exported_to_cuda.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,31 @@ def forward(self, x):
208208
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
209209

210210

211+
@tvm.testing.parametrize_targets("cuda")
212+
def test_sort(target, dev):
213+
raw_data = np.array([[4, 1, 13], [-30, 1, 3], [4, 0, 10]]).astype("float32")
214+
215+
# Test values
216+
class SortModelValues(nn.Module):
217+
def forward(self, x):
218+
A, _ = torch.sort(x, dim=0, descending=True)
219+
B, _ = torch.sort(x, dim=1, descending=False)
220+
return A + B
221+
222+
torch_module = SortModelValues().eval()
223+
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
224+
225+
# Test indices
226+
class SortModelIndices(nn.Module):
227+
def forward(self, x):
228+
_, A = torch.sort(x, dim=0, descending=True)
229+
_, B = torch.sort(x, dim=1, descending=False)
230+
return A + B
231+
232+
torch_module = SortModelIndices().eval()
233+
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
234+
235+
211236
@tvm.testing.parametrize_targets("cuda")
212237
def test_tensor_clamp(target, dev):
213238
class ClampBothTensor(torch.nn.Module):

tests/python/relax/test_frontend_from_fx.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4749,11 +4749,20 @@ def forward(self, x):
47494749
class Expected:
47504750
@R.function
47514751
def main(
4752-
inp_0: R.Tensor((5, 3), dtype="float32"),
4753-
) -> R.Tensor((5, 3), dtype="float32"):
4752+
inp_0: R.Tensor((5, 3), dtype="float32")
4753+
) -> R.Tuple(R.Tensor((5, 3), dtype="float32"), R.Tensor((5, 3), dtype="int32")):
47544754
with R.dataflow():
4755-
lv: R.Tensor((5, 3), dtype="float32") = R.sort(inp_0, axis=1, descending=True)
4756-
gv: R.Tensor((5, 3), dtype="float32") = lv
4755+
lv: R.Tensor((5, 3), dtype="int32") = R.argsort(
4756+
inp_0, axis=1, descending=True, dtype="int32"
4757+
)
4758+
lv1: R.Tensor((5, 3), dtype="float32") = R.gather_elements(inp_0, lv, axis=1)
4759+
lv2: R.Tuple(R.Tensor((5, 3), dtype="float32"), R.Tensor((5, 3), dtype="int32")) = (
4760+
lv1,
4761+
lv,
4762+
)
4763+
gv: R.Tuple(
4764+
R.Tensor((5, 3), dtype="float32"), R.Tensor((5, 3), dtype="int32")
4765+
) = lv2
47574766
R.output(gv)
47584767
return gv
47594768

0 commit comments

Comments
 (0)