From e13fa3d0718eb75ede16a60b138e62b48360d586 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 20:24:29 -0400 Subject: [PATCH 01/12] unit test --- tests/python/relax/test_from_exported_to_cuda.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 8405f48576d8..43107f015313 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -63,6 +63,21 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) +@tvm.testing.parametrize_targets("cuda") +def test_full(target, dev): + class FullModel(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.full((2, 3), 3.141592) + + torch_module = FullModel().eval() + + raw_data = np.random.rand(3,3).astype("float32") + + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + @tvm.testing.parametrize_targets("cuda") def test_tensor_clamp(target, dev): class ClampBothTensor(torch.nn.Module): From 5b23c30341fff3765c1226261d84f178531485b1 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 20:26:10 -0400 Subject: [PATCH 02/12] full.default --- .../frontend/torch/base_fx_graph_translator.py | 17 +++++++++++++++++ .../torch/exported_program_translator.py | 1 + .../tvm/relax/frontend/torch/fx_translator.py | 17 ----------------- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index c9c6afd71a64..55a603e20c60 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1271,6 +1271,23 @@ def _fill(self, node: fx.Node) -> relax.Var: value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype) return self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype)) + def _full(self, node: fx.Node) -> relax.Var: + import torch + + args = self.retrieve_args(node) + size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],)) + dtype = self._convert_data_type( + node.kwargs.get("dtype", torch.get_default_dtype()), self.env + ) + value = args[1] if isinstance(args[1], relax.expr.Constant) else relax.const(args[1], dtype) + return self.block_builder.emit( + relax.op.full( + size, + value, + dtype, + ) + ) + def _index_select(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] dim = node.args[1] diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 875ec3b83ea8..26e73dd6b84b 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -433,6 +433,7 @@ def create_convert_map( "empty.memory_format": self._empty, "empty_like.default": self._empty_like, "fill.Scalar": self._fill, + "full.default": self._full, "index_select.default": self._index_select, "lift_fresh_copy.default": self._to_copy, "new_ones.default": self._new_ones, diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index a5b50a7d1dce..80031cd7a403 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -468,23 +468,6 @@ def _inplace_fill(self, node: fx.Node) -> relax.Var: self.env[node.args[0]] = filled return filled - def _full(self, node: fx.Node) -> relax.Var: - import torch - - args = self.retrieve_args(node) - size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],)) - dtype = self._convert_data_type( - node.kwargs.get("dtype", torch.get_default_dtype()), self.env - ) - value = args[1] if isinstance(args[1], relax.expr.Constant) else relax.const(args[1], dtype) - return self.block_builder.emit( - relax.op.full( - size, - value, - dtype, - ) - ) - def _inplace_masked_fill(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] mask = self.env[node.args[1]] From 35aee297ba2ca01dbdf2695267cf1869b399a95b Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 20:27:32 -0400 Subject: [PATCH 03/12] linting --- tests/python/relax/test_from_exported_to_cuda.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 43107f015313..0a120aa8fb70 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -71,13 +71,14 @@ def __init__(self): def forward(self, x): return torch.full((2, 3), 3.141592) - + torch_module = FullModel().eval() - raw_data = np.random.rand(3,3).astype("float32") + raw_data = np.random.rand(3, 3).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + @tvm.testing.parametrize_targets("cuda") def test_tensor_clamp(target, dev): class ClampBothTensor(torch.nn.Module): From 5c0e18b7b419a8194c696d9e4c5f6194af7b251a Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 20:31:20 -0400 Subject: [PATCH 04/12] ones ok --- .../frontend/torch/base_fx_graph_translator.py | 16 ++++++++++++++++ .../torch/exported_program_translator.py | 1 + python/tvm/relax/frontend/torch/fx_translator.py | 16 ---------------- tests/python/relax/test_from_exported_to_cuda.py | 16 ++++++++++++++++ 4 files changed, 33 insertions(+), 16 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 55a603e20c60..2a811fd33e1e 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1308,6 +1308,22 @@ def _new_ones(self, node: fx.Node) -> relax.Var: self_var.struct_info.dtype, ) ) + + def _ones(self, node: fx.Node) -> relax.Var: + import torch + + args = self.retrieve_args(node) + size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],)) + dtype = self._convert_data_type( + node.kwargs.get("dtype", torch.get_default_dtype()), self.env + ) + return self.block_builder.emit( + relax.op.full( + size, + relax.const(1, dtype), + dtype, + ) + ) ########## DataType ########## diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 26e73dd6b84b..e962fbdbc696 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -438,6 +438,7 @@ def create_convert_map( "lift_fresh_copy.default": self._to_copy, "new_ones.default": self._new_ones, "one_hot.default": self._one_hot, + "ones.default": self._ones, # datatype "to.dtype": self._to, "to.dtype_layout": self._to, diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 80031cd7a403..f1b9a6d6e28c 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -510,22 +510,6 @@ def _masked_scatter(self, node: fx.Node) -> relax.Var: mask = self.block_builder.emit(relax.op.broadcast_to(mask, x.struct_info.shape)) return self.block_builder.emit(relax.op.where(mask, gathered_source, x)) - def _ones(self, node: fx.Node) -> relax.Var: - import torch - - args = self.retrieve_args(node) - size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],)) - dtype = self._convert_data_type( - node.kwargs.get("dtype", torch.get_default_dtype()), self.env - ) - return self.block_builder.emit( - relax.op.full( - size, - relax.const(1, dtype), - dtype, - ) - ) - def _one_hot(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] num_classes = node.args[1] if len(node.args) > 1 else node.kwargs.get("num_classes") diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 0a120aa8fb70..5a0435d44484 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -78,6 +78,22 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) +@tvm.testing.parametrize_targets("cuda") +def test_ones(target, dev): + class FullModel(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.ones((2, 3)) + + torch_module = FullModel().eval() + + raw_data = np.random.rand(1,1).astype("float32") + + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + @tvm.testing.parametrize_targets("cuda") def test_tensor_clamp(target, dev): From e5c025857c1604ec38a36cda899d21ec47c4fcbb Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 20:45:14 -0400 Subject: [PATCH 05/12] sort default - still as TODO --- .../torch/exported_program_translator.py | 1 + .../relax/test_from_exported_to_cuda.py | 14 ++++ .../relax/test_from_exported_to_cuda_NEW.py | 81 +++++++++++++++++++ 3 files changed, 96 insertions(+) create mode 100644 tests/python/relax/test_from_exported_to_cuda_NEW.py diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index e962fbdbc696..5aa06fa14ab8 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -406,6 +406,7 @@ def create_convert_map( "repeat.default": self._repeat, "select.int": self._select, "slice.Tensor": self._slice, + "sort.default": self._sort, "split.Tensor": self._split, "split_with_sizes.default": self._split, "squeeze.default": self._squeeze, diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 5a0435d44484..5ad85c458a3c 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -95,6 +95,20 @@ def forward(self, x): +@tvm.testing.parametrize_targets("cuda") +def test_sort(target, dev): + class SortModel(nn.Module): + + def forward(self, x): + A, _ = torch.sort(x, dim=0, descending=True) + B, _ = torch.sort(x, dim=1, descending=False) + return A + B + + torch_module = SortModel().eval() + raw_data = np.array([[4,1,13],[-30,1,3],[4,0,10]]).astype("float32") + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + @tvm.testing.parametrize_targets("cuda") def test_tensor_clamp(target, dev): class ClampBothTensor(torch.nn.Module): diff --git a/tests/python/relax/test_from_exported_to_cuda_NEW.py b/tests/python/relax/test_from_exported_to_cuda_NEW.py new file mode 100644 index 000000000000..d79701136713 --- /dev/null +++ b/tests/python/relax/test_from_exported_to_cuda_NEW.py @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm import relax +import tvm.testing +import numpy as np +import torch +from torch import nn +from torch.export import export +from tvm.relax.frontend.torch import from_exported_program +from torch.nn import Softmax, Upsample + + +def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev): + """ + This util ensures that a torch module can successfully be exported to TVM + using torch.export and that the resuling IR program gives the same result + as PyTorch when ran on CUDA. + """ + raw_data_for_tvm = raw_data.copy() # In case the data is modified + torch_data = torch.from_numpy(raw_data) + example_args = (torch_data,) + + with torch.no_grad(): + exported_program = export(torch_module, example_args) + mod_from_torch = from_exported_program(exported_program, keep_params_as_input=True) + + tvm_mod, tvm_params = relax.frontend.detach_params(mod_from_torch) + + relax_pipeline = relax.get_default_pipeline(tvm.target.Target.from_device(tvm.cuda())) + ex = relax.build(tvm_mod, target=target, relax_pipeline=relax_pipeline) + vm = relax.VirtualMachine(ex, dev) + + gpu_data = tvm.nd.array(raw_data_for_tvm, dev) + gpu_params = [tvm.nd.array(p, dev) for p in tvm_params["main"]] + gpu_out = vm["main"](gpu_data, *gpu_params) + + pytorch_out = torch_module(torch_data) + + if isinstance(pytorch_out, tuple): + for i in range(len(pytorch_out)): + actual = gpu_out[i].numpy() + desired = pytorch_out[i].detach().numpy() + np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) + else: + actual = gpu_out[0].numpy() + desired = pytorch_out.detach().numpy() + np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) + + +@tvm.testing.parametrize_targets("cuda") +def test_sort(target, dev): + class SortModel(nn.Module): + + def forward(self, x): + A, _ = torch.sort(x, dim=0, descending=True) + B, _ = torch.sort(x, dim=1, descending=False) + return A + B + + torch_module = SortModel().eval() + raw_data = np.array([[4,1,13],[-30,1,3],[4,0,10]]).astype("float32") + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +if __name__ == "__main__": + tvm.testing.main() From d8ec9a50d10746b8a75bc560924d724250dc132d Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Thu, 17 Apr 2025 08:53:51 -0400 Subject: [PATCH 06/12] sort test passes --- .lesshst | 1 - .../torch/base_fx_graph_translator.py | 30 ++++++++----------- .../torch/exported_program_translator.py | 1 + .../relax/test_from_exported_to_cuda.py | 10 ------- .../relax/test_from_exported_to_cuda_NEW.py | 14 +++++++-- 5 files changed, 24 insertions(+), 32 deletions(-) delete mode 100644 .lesshst diff --git a/.lesshst b/.lesshst deleted file mode 100644 index 4d1c30b7a584..000000000000 --- a/.lesshst +++ /dev/null @@ -1 +0,0 @@ -.less-history-file: diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 298aac550741..097a8f4fc2c2 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1155,8 +1155,18 @@ def _scatter(self, node: fx.Node) -> relax.Var: def _sort(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1) - descending = node.args[2] if len(node.args) > 2 else node.kwargs.get("descending", False) - return self.block_builder.emit(relax.op.sort(x, dim, descending)) + descending = ( + node.args[2] if len(node.args) > 2 else node.kwargs.get("descending", False) + ) + + # 1. indices: argsort already gives the permutation we need + indices = self.block_builder.emit(relax.op.argsort(x, dim, descending)) + + # 2. values: gather the tensor elements with those indices + values = self.block_builder.emit(relax.op.gather_elements(x, indices, axis=dim)) + + # 3. return the exact PyTorch ABI: (values, indices) + return self.block_builder.emit(relax.Tuple([values, indices])) def _split(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] @@ -1357,22 +1367,6 @@ def _new_ones(self, node: fx.Node) -> relax.Var: self_var.struct_info.dtype, ) ) - - def _ones(self, node: fx.Node) -> relax.Var: - import torch - - args = self.retrieve_args(node) - size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],)) - dtype = self._convert_data_type( - node.kwargs.get("dtype", torch.get_default_dtype()), self.env - ) - return self.block_builder.emit( - relax.op.full( - size, - relax.const(1, dtype), - dtype, - ) - ) def _ones(self, node: fx.Node) -> relax.Var: import torch diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index f20d7567dde0..6d34a6452c82 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -555,6 +555,7 @@ def from_exported_program( self.env[node] = getattr(exported_program.graph_module, node.target) elif node.op == "call_function": func_name = node.target.__name__ + print("about to grab and apply function !!!!!!!!!!!!!!!!!!!", func_name) self.env[node] = self.convert_map[func_name](node) else: raise ValueError(f"Unsupported op {node.op}") diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 9fe6a9fa8663..a00dab01e56c 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -105,18 +105,10 @@ def forward(self, x): return torch.ones((2, 3)) torch_module = FullModel().eval() - -<<<<<<< HEAD - raw_data = np.random.rand(1,1).astype("float32") -======= raw_data = np.random.rand(1, 1).astype("float32") ->>>>>>> main - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) -<<<<<<< HEAD - @tvm.testing.parametrize_targets("cuda") def test_sort(target, dev): class SortModel(nn.Module): @@ -131,8 +123,6 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) -======= ->>>>>>> main @tvm.testing.parametrize_targets("cuda") def test_tensor_clamp(target, dev): class ClampBothTensor(torch.nn.Module): diff --git a/tests/python/relax/test_from_exported_to_cuda_NEW.py b/tests/python/relax/test_from_exported_to_cuda_NEW.py index d79701136713..1fd3f79b620c 100644 --- a/tests/python/relax/test_from_exported_to_cuda_NEW.py +++ b/tests/python/relax/test_from_exported_to_cuda_NEW.py @@ -38,6 +38,10 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar with torch.no_grad(): exported_program = export(torch_module, example_args) + print("Exported program:", exported_program) # TODO remove + print("Exported program graph:", exported_program.graph) # TODO remove + print("Exported program graph signature:", exported_program.graph_signature) # TODO remove + print("Exported program graph module:", exported_program.graph_module) # TODO remove mod_from_torch = from_exported_program(exported_program, keep_params_as_input=True) tvm_mod, tvm_params = relax.frontend.detach_params(mod_from_torch) @@ -68,9 +72,13 @@ def test_sort(target, dev): class SortModel(nn.Module): def forward(self, x): - A, _ = torch.sort(x, dim=0, descending=True) - B, _ = torch.sort(x, dim=1, descending=False) - return A + B + A = torch.sort(x, dim=0, descending=True) + return A + + # TODO revert below + # A, _ = torch.sort(x, dim=0, descending=True) + # B, _ = torch.sort(x, dim=1, descending=False) + # return A + B torch_module = SortModel().eval() raw_data = np.array([[4,1,13],[-30,1,3],[4,0,10]]).astype("float32") From 23ee26a7d809f0cd4371f34f7afaa3b0de5bca99 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Thu, 17 Apr 2025 08:58:17 -0400 Subject: [PATCH 07/12] cleanup --- .../relax/frontend/torch/base_fx_graph_translator.py | 11 +++-------- tests/python/relax/test_from_exported_to_cuda_NEW.py | 4 ---- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 097a8f4fc2c2..d2698d6a1a2d 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1153,19 +1153,14 @@ def _scatter(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.scatter_elements(x, index, src, axis=dim)) def _sort(self, node: fx.Node) -> relax.Var: + # torch.sort() returns a tuple of values and indices + # we use argsort to get indices and gather_elements to get values x = self.env[node.args[0]] dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1) - descending = ( - node.args[2] if len(node.args) > 2 else node.kwargs.get("descending", False) - ) + descending = (node.args[2] if len(node.args) > 2 else node.kwargs.get("descending", False)) - # 1. indices: argsort already gives the permutation we need indices = self.block_builder.emit(relax.op.argsort(x, dim, descending)) - - # 2. values: gather the tensor elements with those indices values = self.block_builder.emit(relax.op.gather_elements(x, indices, axis=dim)) - - # 3. return the exact PyTorch ABI: (values, indices) return self.block_builder.emit(relax.Tuple([values, indices])) def _split(self, node: fx.Node) -> relax.Var: diff --git a/tests/python/relax/test_from_exported_to_cuda_NEW.py b/tests/python/relax/test_from_exported_to_cuda_NEW.py index 1fd3f79b620c..991dc93beba5 100644 --- a/tests/python/relax/test_from_exported_to_cuda_NEW.py +++ b/tests/python/relax/test_from_exported_to_cuda_NEW.py @@ -38,10 +38,6 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar with torch.no_grad(): exported_program = export(torch_module, example_args) - print("Exported program:", exported_program) # TODO remove - print("Exported program graph:", exported_program.graph) # TODO remove - print("Exported program graph signature:", exported_program.graph_signature) # TODO remove - print("Exported program graph module:", exported_program.graph_module) # TODO remove mod_from_torch = from_exported_program(exported_program, keep_params_as_input=True) tvm_mod, tvm_params = relax.frontend.detach_params(mod_from_torch) From fe53f9c4788c90ef10299fae3567eccf22a3133d Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Thu, 17 Apr 2025 09:00:40 -0400 Subject: [PATCH 08/12] cleanup test --- .../relax/test_from_exported_to_cuda_NEW.py | 85 ------------------- 1 file changed, 85 deletions(-) delete mode 100644 tests/python/relax/test_from_exported_to_cuda_NEW.py diff --git a/tests/python/relax/test_from_exported_to_cuda_NEW.py b/tests/python/relax/test_from_exported_to_cuda_NEW.py deleted file mode 100644 index 991dc93beba5..000000000000 --- a/tests/python/relax/test_from_exported_to_cuda_NEW.py +++ /dev/null @@ -1,85 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import tvm -from tvm import relax -import tvm.testing -import numpy as np -import torch -from torch import nn -from torch.export import export -from tvm.relax.frontend.torch import from_exported_program -from torch.nn import Softmax, Upsample - - -def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev): - """ - This util ensures that a torch module can successfully be exported to TVM - using torch.export and that the resuling IR program gives the same result - as PyTorch when ran on CUDA. - """ - raw_data_for_tvm = raw_data.copy() # In case the data is modified - torch_data = torch.from_numpy(raw_data) - example_args = (torch_data,) - - with torch.no_grad(): - exported_program = export(torch_module, example_args) - mod_from_torch = from_exported_program(exported_program, keep_params_as_input=True) - - tvm_mod, tvm_params = relax.frontend.detach_params(mod_from_torch) - - relax_pipeline = relax.get_default_pipeline(tvm.target.Target.from_device(tvm.cuda())) - ex = relax.build(tvm_mod, target=target, relax_pipeline=relax_pipeline) - vm = relax.VirtualMachine(ex, dev) - - gpu_data = tvm.nd.array(raw_data_for_tvm, dev) - gpu_params = [tvm.nd.array(p, dev) for p in tvm_params["main"]] - gpu_out = vm["main"](gpu_data, *gpu_params) - - pytorch_out = torch_module(torch_data) - - if isinstance(pytorch_out, tuple): - for i in range(len(pytorch_out)): - actual = gpu_out[i].numpy() - desired = pytorch_out[i].detach().numpy() - np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) - else: - actual = gpu_out[0].numpy() - desired = pytorch_out.detach().numpy() - np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) - - -@tvm.testing.parametrize_targets("cuda") -def test_sort(target, dev): - class SortModel(nn.Module): - - def forward(self, x): - A = torch.sort(x, dim=0, descending=True) - return A - - # TODO revert below - # A, _ = torch.sort(x, dim=0, descending=True) - # B, _ = torch.sort(x, dim=1, descending=False) - # return A + B - - torch_module = SortModel().eval() - raw_data = np.array([[4,1,13],[-30,1,3],[4,0,10]]).astype("float32") - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - - -if __name__ == "__main__": - tvm.testing.main() From 18bf9527c73c537de49aed8d2d5b483917b27735 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Thu, 17 Apr 2025 09:01:08 -0400 Subject: [PATCH 09/12] lint --- python/tvm/relax/frontend/torch/base_fx_graph_translator.py | 4 ++-- tests/python/relax/test_from_exported_to_cuda.py | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index d2698d6a1a2d..e8680be24a59 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1157,8 +1157,8 @@ def _sort(self, node: fx.Node) -> relax.Var: # we use argsort to get indices and gather_elements to get values x = self.env[node.args[0]] dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1) - descending = (node.args[2] if len(node.args) > 2 else node.kwargs.get("descending", False)) - + descending = node.args[2] if len(node.args) > 2 else node.kwargs.get("descending", False) + indices = self.block_builder.emit(relax.op.argsort(x, dim, descending)) values = self.block_builder.emit(relax.op.gather_elements(x, indices, axis=dim)) return self.block_builder.emit(relax.Tuple([values, indices])) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index a00dab01e56c..f2b65fdb7f21 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -112,14 +112,13 @@ def forward(self, x): @tvm.testing.parametrize_targets("cuda") def test_sort(target, dev): class SortModel(nn.Module): - def forward(self, x): A, _ = torch.sort(x, dim=0, descending=True) B, _ = torch.sort(x, dim=1, descending=False) - return A + B + return A + B torch_module = SortModel().eval() - raw_data = np.array([[4,1,13],[-30,1,3],[4,0,10]]).astype("float32") + raw_data = np.array([[4, 1, 13], [-30, 1, 3], [4, 0, 10]]).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) From 9ed717e302e71b0665785606c288d05fa7b9a632 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Thu, 17 Apr 2025 09:02:00 -0400 Subject: [PATCH 10/12] remove print --- python/tvm/relax/frontend/torch/exported_program_translator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 6d34a6452c82..f20d7567dde0 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -555,7 +555,6 @@ def from_exported_program( self.env[node] = getattr(exported_program.graph_module, node.target) elif node.op == "call_function": func_name = node.target.__name__ - print("about to grab and apply function !!!!!!!!!!!!!!!!!!!", func_name) self.env[node] = self.convert_map[func_name](node) else: raise ValueError(f"Unsupported op {node.op}") From 68bf17767f4e53f7f17e3732fe4be54ab047afb0 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Thu, 17 Apr 2025 09:10:22 -0400 Subject: [PATCH 11/12] test values and indices --- .../python/relax/test_from_exported_to_cuda.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index f2b65fdb7f21..4a94039cb853 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -111,14 +111,26 @@ def forward(self, x): @tvm.testing.parametrize_targets("cuda") def test_sort(target, dev): - class SortModel(nn.Module): + raw_data = np.array([[4, 1, 13], [-30, 1, 3], [4, 0, 10]]).astype("float32") + + # Test values + class SortModelValues(nn.Module): def forward(self, x): A, _ = torch.sort(x, dim=0, descending=True) B, _ = torch.sort(x, dim=1, descending=False) return A + B - torch_module = SortModel().eval() - raw_data = np.array([[4, 1, 13], [-30, 1, 3], [4, 0, 10]]).astype("float32") + torch_module = SortModelValues().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + # Test indices + class SortModelIndices(nn.Module): + def forward(self, x): + _, A = torch.sort(x, dim=0, descending=True) + _, B = torch.sort(x, dim=1, descending=False) + return A + B + + torch_module = SortModelIndices().eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) From 775c491af15d428f8f51be5ff5e3d7b70c444b6e Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sat, 19 Apr 2025 11:07:42 -0400 Subject: [PATCH 12/12] frontend test --- tests/python/relax/test_frontend_from_fx.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index a962de8a3237..f1ea92dfb499 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -4437,11 +4437,20 @@ def forward(self, x): class Expected: @R.function def main( - inp_0: R.Tensor((5, 3), dtype="float32"), - ) -> R.Tensor((5, 3), dtype="float32"): + inp_0: R.Tensor((5, 3), dtype="float32") + ) -> R.Tuple(R.Tensor((5, 3), dtype="float32"), R.Tensor((5, 3), dtype="int32")): with R.dataflow(): - lv: R.Tensor((5, 3), dtype="float32") = R.sort(inp_0, axis=1, descending=True) - gv: R.Tensor((5, 3), dtype="float32") = lv + lv: R.Tensor((5, 3), dtype="int32") = R.argsort( + inp_0, axis=1, descending=True, dtype="int32" + ) + lv1: R.Tensor((5, 3), dtype="float32") = R.gather_elements(inp_0, lv, axis=1) + lv2: R.Tuple(R.Tensor((5, 3), dtype="float32"), R.Tensor((5, 3), dtype="int32")) = ( + lv1, + lv, + ) + gv: R.Tuple( + R.Tensor((5, 3), dtype="float32"), R.Tensor((5, 3), dtype="int32") + ) = lv2 R.output(gv) return gv