Skip to content

Commit 2ddd42f

Browse files
Pratheesh-04-MCWShiboXing
authored andcommitted
Add op support for new_zeros op in Exported Program and fx graph frontend (apache#17911)
1 parent a2e4732 commit 2ddd42f

File tree

5 files changed

+69
-0
lines changed

5 files changed

+69
-0
lines changed

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1549,6 +1549,25 @@ def _new_ones(self, node: fx.Node) -> relax.Var:
15491549
)
15501550
)
15511551

1552+
def _new_zeros(self, node: fx.Node) -> relax.Var:
1553+
args = self.retrieve_args(node)
1554+
input_tensor = args[0]
1555+
size = (
1556+
args[1]
1557+
if isinstance(args[1], (list, tuple))
1558+
else (args[1],)
1559+
if len(args[1:]) == 1
1560+
else args[1:]
1561+
)
1562+
size = relax.ShapeExpr(size)
1563+
return self.block_builder.emit(
1564+
relax.op.full(
1565+
size,
1566+
relax.const(0, input_tensor.struct_info.dtype),
1567+
input_tensor.struct_info.dtype,
1568+
)
1569+
)
1570+
15521571
def _ones(self, node: fx.Node) -> relax.Var:
15531572
import torch
15541573

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,7 @@ def create_convert_map(
484484
"masked_fill.Scalar": self._masked_fill,
485485
"masked_fill_.Scalar": self._inplace_masked_fill,
486486
"new_ones.default": self._new_ones,
487+
"new_zeros.default": self._new_zeros,
487488
"one_hot.default": self._one_hot,
488489
"ones.default": self._ones,
489490
"ones_like.default": lambda node: self.block_builder.emit(

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,7 @@ def create_convert_map(
829829
"masked_fill": self._masked_fill,
830830
"masked_scatter": self._masked_scatter,
831831
"new_ones": self._new_ones,
832+
"new_zeros": self._new_zeros,
832833
"ones": self._ones,
833834
"one_hot": self._one_hot,
834835
"ones_like": lambda node: self.block_builder.emit(

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3775,6 +3775,29 @@ def main(
37753775
verify_model(NewOnes(), example_args, {}, expected1)
37763776

37773777

3778+
def test_new_zeros():
3779+
class NewZeros(torch.nn.Module):
3780+
def forward(self, x):
3781+
return x.new_zeros(1, 128, 128)
3782+
3783+
@tvm.script.ir_module
3784+
class expected1:
3785+
@R.function
3786+
def main(
3787+
x: R.Tensor((1, 128, 128), dtype="float32")
3788+
) -> R.Tuple(R.Tensor((1, 128, 128), dtype="float32")):
3789+
with R.dataflow():
3790+
lv: R.Tensor((1, 128, 128), dtype="float32") = R.full(
3791+
R.shape([1, 128, 128]), R.const(0, "float32"), dtype="float32"
3792+
)
3793+
gv: R.Tuple(R.Tensor((1, 128, 128), dtype="float32")) = (lv,)
3794+
R.output(gv)
3795+
return gv
3796+
3797+
example_args = (torch.randn(1, 128, 128, dtype=torch.float32),)
3798+
verify_model(NewZeros(), example_args, {}, expected1)
3799+
3800+
37783801
def test_to_copy():
37793802
# float
37803803
class ToFloat(Module):

tests/python/relax/test_frontend_from_fx.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3352,6 +3352,31 @@ def main(x: R.Tensor((1, 2, 3), dtype="float32")) -> R.Tensor((1, 2, 3), dtype="
33523352
verify_model(NewOnes(), input_info, {}, expected1)
33533353

33543354

3355+
def test_new_zeros():
3356+
input_info = [([1, 128, 128], "float32")]
3357+
3358+
class NewZeros(Module):
3359+
def forward(self, x):
3360+
return x.new_zeros(1, 128, 128)
3361+
3362+
@tvm.script.ir_module
3363+
class expected:
3364+
@R.function
3365+
def main(
3366+
x: R.Tensor((1, 128, 128), dtype="float32")
3367+
) -> R.Tensor((1, 128, 128), dtype="float32"):
3368+
# block 0
3369+
with R.dataflow():
3370+
lv: R.Tensor((1, 128, 128), dtype="float32") = R.full(
3371+
(1, 128, 128), R.const(0.0, "float32"), dtype="float32"
3372+
)
3373+
gv: R.Tensor((1, 128, 128), dtype="float32") = lv
3374+
R.output(gv)
3375+
return gv
3376+
3377+
verify_model(NewZeros(), input_info, {}, expected)
3378+
3379+
33553380
def test_expand():
33563381
input_info = [([1, 2, 3, 4], "float32")]
33573382

0 commit comments

Comments
 (0)