Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion python/tvm/relax/frontend/torch/base_fx_graph_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1278,10 +1278,15 @@ def _scatter(self, node: fx.Node) -> relax.Var:
return self.block_builder.emit(relax.op.scatter_elements(x, index, src, axis=dim))

def _sort(self, node: fx.Node) -> relax.Var:
# torch.sort() returns a tuple of values and indices
# we use argsort to get indices and gather_elements to get values
x = self.env[node.args[0]]
dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
descending = node.args[2] if len(node.args) > 2 else node.kwargs.get("descending", False)
return self.block_builder.emit(relax.op.sort(x, dim, descending))

indices = self.block_builder.emit(relax.op.argsort(x, dim, descending))
values = self.block_builder.emit(relax.op.gather_elements(x, indices, axis=dim))
return self.block_builder.emit(relax.Tuple([values, indices]))

def _split(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,7 @@ def create_convert_map(
"roll.default": self._roll,
"select.int": self._select,
"slice.Tensor": self._slice,
"sort.default": self._sort,
"split.Tensor": self._split,
"split_with_sizes.default": self._split,
"squeeze.default": self._squeeze,
Expand Down
25 changes: 25 additions & 0 deletions tests/python/relax/test_from_exported_to_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,31 @@ def forward(self, x):
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)


@tvm.testing.parametrize_targets("cuda")
def test_sort(target, dev):
raw_data = np.array([[4, 1, 13], [-30, 1, 3], [4, 0, 10]]).astype("float32")

# Test values
class SortModelValues(nn.Module):
def forward(self, x):
A, _ = torch.sort(x, dim=0, descending=True)
B, _ = torch.sort(x, dim=1, descending=False)
return A + B

torch_module = SortModelValues().eval()
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)

# Test indices
class SortModelIndices(nn.Module):
def forward(self, x):
_, A = torch.sort(x, dim=0, descending=True)
_, B = torch.sort(x, dim=1, descending=False)
return A + B

torch_module = SortModelIndices().eval()
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)


@tvm.testing.parametrize_targets("cuda")
def test_tensor_clamp(target, dev):
class ClampBothTensor(torch.nn.Module):
Expand Down
17 changes: 13 additions & 4 deletions tests/python/relax/test_frontend_from_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4749,11 +4749,20 @@ def forward(self, x):
class Expected:
@R.function
def main(
inp_0: R.Tensor((5, 3), dtype="float32"),
) -> R.Tensor((5, 3), dtype="float32"):
inp_0: R.Tensor((5, 3), dtype="float32")
) -> R.Tuple(R.Tensor((5, 3), dtype="float32"), R.Tensor((5, 3), dtype="int32")):
with R.dataflow():
lv: R.Tensor((5, 3), dtype="float32") = R.sort(inp_0, axis=1, descending=True)
gv: R.Tensor((5, 3), dtype="float32") = lv
lv: R.Tensor((5, 3), dtype="int32") = R.argsort(
inp_0, axis=1, descending=True, dtype="int32"
)
lv1: R.Tensor((5, 3), dtype="float32") = R.gather_elements(inp_0, lv, axis=1)
lv2: R.Tuple(R.Tensor((5, 3), dtype="float32"), R.Tensor((5, 3), dtype="int32")) = (
lv1,
lv,
)
gv: R.Tuple(
R.Tensor((5, 3), dtype="float32"), R.Tensor((5, 3), dtype="int32")
) = lv2
R.output(gv)
return gv

Expand Down
Loading