Skip to content

Commit 8fce8de

Browse files
tlopexShiboXing
authored andcommitted
[Relax][PyTorch] Add support for eye op in fx graph (apache#17908)
1 parent 2ddd42f commit 8fce8de

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -821,6 +821,7 @@ def create_convert_map(
821821
"clone": lambda node: self.env[node.args[0]],
822822
"empty": self._empty,
823823
"empty_like": self._empty_like,
824+
"eye": self._eye,
824825
"fill": self._fill,
825826
"fill_": self._inplace_fill,
826827
"full": self._full,

tests/python/relax/test_frontend_from_fx.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5342,5 +5342,23 @@ def main(
53425342
verify_model(BFloat16Model(), [([10, 10], "bfloat16"), ([10, 10], "bfloat16")], {}, Expected)
53435343

53445344

5345+
def test_eye():
5346+
import numpy as np
5347+
5348+
class Eye(Module):
5349+
def forward(self, input):
5350+
return torch.eye(3)
5351+
5352+
graph_model = fx.symbolic_trace(Eye())
5353+
mod = from_fx(graph_model, [([3, 3], "float32")])
5354+
assert len(mod["main"].body.blocks) == 1
5355+
assert len(mod["main"].body.blocks[0].bindings) == 1
5356+
assert isinstance(mod["main"].body.blocks[0].bindings[0].value, relax.Constant)
5357+
tvm.testing.assert_allclose(
5358+
mod["main"].body.blocks[0].bindings[0].value.data.numpy(),
5359+
np.eye(3, dtype="float32"),
5360+
)
5361+
5362+
53455363
if __name__ == "__main__":
53465364
tvm.testing.main()

0 commit comments

Comments
 (0)