Skip to content

Commit 111ddf7

Browse files
authored
[Relax][PyTorch] Support eye op for ExportedProgram importer (#17864)
1 parent b3d3a7a commit 111ddf7

File tree

3 files changed

+49
-0
lines changed

3 files changed

+49
-0
lines changed

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1416,6 +1416,13 @@ def _empty_like(self, node: fx.Node) -> relax.Var:
14161416
x = self.env[node.args[0]]
14171417
return self.block_builder.emit(relax.op.zeros_like(x))
14181418

1419+
def _eye(self, node: fx.Node) -> relax.Var:
1420+
args = self.retrieve_args(node)
1421+
n = args[0]
1422+
m = args[1] if len(args) > 1 else n
1423+
dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
1424+
return self.block_builder.emit(relax.op.eye(n, m, dtype=dtype))
1425+
14191426
def _fill(self, node: fx.Node) -> relax.Var:
14201427
args = self.retrieve_args(node)
14211428
x = args[0]

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,8 @@ def create_convert_map(
453453
"clone.default": lambda node: self.env[node.args[0]],
454454
"empty.memory_format": self._empty,
455455
"empty_like.default": self._empty_like,
456+
"eye.default": self._eye,
457+
"eye.m": self._eye,
456458
"fill.Scalar": self._fill,
457459
"full.default": self._full,
458460
"full_like.default": self._full_like,

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4377,5 +4377,45 @@ def main(
43774377
verify_model(Narrow(), example_args, {}, Expected)
43784378

43794379

4380+
def test_eye():
4381+
class Eye1(Module):
4382+
def forward(self, input):
4383+
return torch.eye(3, 5, dtype=torch.float32)
4384+
4385+
@tvm.script.ir_module
4386+
class Expected1:
4387+
@R.function
4388+
def main(
4389+
input: R.Tensor((3, 5), dtype="float32")
4390+
) -> R.Tuple(R.Tensor((3, 5), dtype="float32")):
4391+
with R.dataflow():
4392+
lv: R.Tensor((3, 5), dtype="float32") = R.eye(3, 5, dtype="float32")
4393+
gv: R.Tuple(R.Tensor((3, 5), dtype="float32")) = (lv,)
4394+
R.output(gv)
4395+
return gv
4396+
4397+
class Eye2(Module):
4398+
def forward(self, input):
4399+
return torch.eye(5, dtype=torch.float32)
4400+
4401+
@tvm.script.ir_module
4402+
class Expected2:
4403+
@R.function
4404+
def main(
4405+
input: R.Tensor((5,), dtype="float32")
4406+
) -> R.Tuple(R.Tensor((5, 5), dtype="float32")):
4407+
with R.dataflow():
4408+
lv: R.Tensor((5, 5), dtype="float32") = R.eye(5, dtype="float32")
4409+
gv: R.Tuple(R.Tensor((5, 5), dtype="float32")) = (lv,)
4410+
R.output(gv)
4411+
return gv
4412+
4413+
example_args1 = (torch.randn(3, 5, dtype=torch.float32),)
4414+
verify_model(Eye1(), example_args1, {}, Expected1)
4415+
4416+
example_args2 = (torch.randn(5, dtype=torch.float32),)
4417+
verify_model(Eye2(), example_args2, {}, Expected2)
4418+
4419+
43804420
if __name__ == "__main__":
43814421
tvm.testing.main()

0 commit comments

Comments
 (0)