
Commit 6bd55f0

[Relax][PyTorch] full.default, full_like.default, ones.default (#17832)
* unit test
* full.default
* linting
* ones ok
* tests for ones, full, and full_like work
1 parent 4e41b42 commit 6bd55f0

4 files changed: +89 −33 lines changed

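For context, this is how the new handlers are reached end to end: export a module with torch.export, then hand the ExportedProgram to the Relax frontend. A minimal sketch (the module mirrors the unit test added below; from_exported_program is the public entry point backed by exported_program_translator.py):

import torch
from torch import nn
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program

class FullModel(nn.Module):
    def forward(self, x):
        return torch.full((2, 3), 3.141592)

# Export to an ExportedProgram, then translate; full.default is now handled.
ep = export(FullModel().eval(), (torch.randn(3, 3),))
mod = from_exported_program(ep)
mod.show()  # prints the Relax IRModule containing the full op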

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 38 additions & 0 deletions
@@ -1271,6 +1271,28 @@ def _fill(self, node: fx.Node) -> relax.Var:
         value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype)
         return self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype))
 
+    def _full(self, node: fx.Node) -> relax.Var:
+        import torch
+
+        args = self.retrieve_args(node)
+        size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],))
+        dtype = self._convert_data_type(
+            node.kwargs.get("dtype", torch.get_default_dtype()), self.env
+        )
+        value = args[1] if isinstance(args[1], relax.expr.Constant) else relax.const(args[1], dtype)
+        return self.block_builder.emit(
+            relax.op.full(
+                size,
+                value,
+                dtype,
+            )
+        )
+
+    def _full_like(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        fill_value = relax.const(node.args[1])
+        return self.block_builder.emit(relax.op.full_like(x, fill_value))
+
     def _index_select(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         dim = node.args[1]

@@ -1292,6 +1314,22 @@ def _new_ones(self, node: fx.Node) -> relax.Var:
             )
         )
 
+    def _ones(self, node: fx.Node) -> relax.Var:
+        import torch
+
+        args = self.retrieve_args(node)
+        size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],))
+        dtype = self._convert_data_type(
+            node.kwargs.get("dtype", torch.get_default_dtype()), self.env
+        )
+        return self.block_builder.emit(
+            relax.op.full(
+                size,
+                relax.const(1, dtype),
+                dtype,
+            )
+        )
+
     ########## DataType ##########
 
     def _to(self, node: fx.Node) -> relax.Var:
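To see what _full and _ones emit, here is a minimal hand-built equivalent using relax.BlockBuilder directly; a sketch of the resulting IR, not the translator's own code path:

from tvm import relax

bb = relax.BlockBuilder()
with bb.function("main", params=[]):
    with bb.dataflow():
        # The same call _full makes: relax.op.full(shape, fill_value, dtype)
        size = relax.ShapeExpr([2, 3])
        value = relax.const(3.141592, "float32")
        lv = bb.emit(relax.op.full(size, value, "float32"))
        gv = bb.emit_output(lv)
    bb.emit_func_output(gv)
print(bb.get())  # shows an R.full((2, 3), ...) call in the printed module

_ones is the same pattern with relax.const(1, dtype) as the fill value.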

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 3 additions & 0 deletions
@@ -442,10 +442,13 @@ def create_convert_map(
             "empty.memory_format": self._empty,
             "empty_like.default": self._empty_like,
             "fill.Scalar": self._fill,
+            "full.default": self._full,
+            "full_like.default": self._full_like,
             "index_select.default": self._index_select,
             "lift_fresh_copy.default": self._to_copy,
             "new_ones.default": self._new_ones,
             "one_hot.default": self._one_hot,
+            "ones.default": self._ones,
             # datatype
             "to.dtype": self._to,
             "to.dtype_layout": self._to,

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 0 additions & 33 deletions (the _full and _ones handlers deleted here are the ones added to base_fx_graph_translator.py above, so both frontends now share them)
@@ -468,23 +468,6 @@ def _inplace_fill(self, node: fx.Node) -> relax.Var:
         self.env[node.args[0]] = filled
         return filled
 
-    def _full(self, node: fx.Node) -> relax.Var:
-        import torch
-
-        args = self.retrieve_args(node)
-        size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],))
-        dtype = self._convert_data_type(
-            node.kwargs.get("dtype", torch.get_default_dtype()), self.env
-        )
-        value = args[1] if isinstance(args[1], relax.expr.Constant) else relax.const(args[1], dtype)
-        return self.block_builder.emit(
-            relax.op.full(
-                size,
-                value,
-                dtype,
-            )
-        )
-
     def _inplace_masked_fill(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         mask = self.env[node.args[1]]

@@ -527,22 +510,6 @@ def _masked_scatter(self, node: fx.Node) -> relax.Var:
         mask = self.block_builder.emit(relax.op.broadcast_to(mask, x.struct_info.shape))
         return self.block_builder.emit(relax.op.where(mask, gathered_source, x))
 
-    def _ones(self, node: fx.Node) -> relax.Var:
-        import torch
-
-        args = self.retrieve_args(node)
-        size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],))
-        dtype = self._convert_data_type(
-            node.kwargs.get("dtype", torch.get_default_dtype()), self.env
-        )
-        return self.block_builder.emit(
-            relax.op.full(
-                size,
-                relax.const(1, dtype),
-                dtype,
-            )
-        )
-
     def _one_hot(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         num_classes = node.args[1] if len(node.args) > 1 else node.kwargs.get("num_classes")

tests/python/relax/test_from_exported_to_cuda.py

Lines changed: 48 additions & 0 deletions
@@ -63,6 +63,54 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev):
     np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
 
 
+@tvm.testing.parametrize_targets("cuda")
+def test_full(target, dev):
+    class FullModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return torch.full((2, 3), 3.141592)
+
+    torch_module = FullModel().eval()
+
+    raw_data = np.random.rand(3, 3).astype("float32")
+
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_full_like(target, dev):
+    class FullLike(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.fill_value = 7.0
+
+        def forward(self, x):
+            return torch.full_like(x, self.fill_value)
+
+    torch_module = FullLike().eval()
+    raw_data = np.random.rand(2, 3).astype("float32")
+
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_ones(target, dev):
+    class FullModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return torch.ones((2, 3))
+
+    torch_module = FullModel().eval()
+
+    raw_data = np.random.rand(1, 1).astype("float32")
+
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
 @tvm.testing.parametrize_targets("cuda")
 def test_tensor_clamp(target, dev):
     class ClampBothTensor(torch.nn.Module):
