Skip to content

Commit 0bcaf4d

Browse files
Deivanayaki-S authored and committed
[Relax][PyTorch] Add RSub Op Support for Exported Program and FX graph (apache#17849)
* add rsub op support into exported and fx graph frontend
* fix trailing whitespace issue
* fix lint issues in test scripts

Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki>
1 parent 3f1625b commit 0bcaf4d

File tree

5 files changed

+89
-0
lines changed

5 files changed

+89
-0
lines changed

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,16 @@ def call_binary_op(op, lhs, rhs):
407407

408408
return convert
409409

def _rsub(self, node: fx.Node) -> relax.Var:
    """Convert ``torch.rsub(input, other, *, alpha=1)`` to Relax.

    ``rsub`` is the reversed subtraction: it computes
    ``other - alpha * input``, so the emitted op subtracts the first
    argument (``input``) from the second (``other``).
    """
    args = self.retrieve_args(node)
    lhs = args[0]  # `input` tensor
    rhs = args[1]  # `other`: tensor, or a python scalar for rsub.Scalar

    # A python-number `other` is lifted to a Relax constant.
    # NOTE(review): relax.const infers dtype from the python value —
    # presumably this matches the tensor dtype; confirm for fp16 inputs.
    if isinstance(rhs, (int, float)):
        rhs = relax.const(rhs)

    # torch.rsub carries an `alpha` multiplier on `input` (positional in
    # the rsub.Scalar overload, keyword otherwise). Dropping it would
    # silently compute `other - input` for alpha != 1.
    alpha = args[2] if len(args) > 2 else node.kwargs.get("alpha", 1)
    if alpha != 1:
        lhs = self.block_builder.emit(relax.op.multiply(lhs, relax.const(alpha)))

    return self.block_builder.emit(relax.op.subtract(rhs, lhs))
410420
########## Linear Algebra ##########
411421

412422
def _linalg_vector_norm(self, node: fx.Node) -> relax.Var:

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,8 @@ def create_convert_map(
305305
"relu_.default": self._unary_op(relax.op.nn.relu),
306306
"round.default": self._round,
307307
"rsqrt.default": self._unary_op(relax.op.rsqrt),
308+
"rsub.Tensor": self._rsub,
309+
"rsub.Scalar": self._rsub,
308310
"selu.default": self._unary_op(relax.op.nn.selu),
309311
"sigmoid.default": self._unary_op(relax.op.sigmoid),
310312
"sign.default": self._unary_op(relax.op.sign),

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,7 @@ def create_convert_map(
692692
"pow": self._binary_op(relax.op.power, operator.pow),
693693
"or_": self._binary_op(relax.op.bitwise_or, operator.or_),
694694
"rshift": self._binary_op(relax.op.right_shift, operator.rshift),
695+
"rsub": self._rsub,
695696
"sub": self._binary_op(relax.op.subtract, operator.sub),
696697
"truediv": self._binary_op(relax.op.divide, operator.truediv),
697698
"xor": self._binary_op(relax.op.bitwise_xor, operator.xor),

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -935,6 +935,7 @@ def test_binary3():
935935
torch.randn(10, 10, dtype=torch.float32),
936936
torch.randn(10, 10, dtype=torch.float32),
937937
)
938+
example_args2 = (torch.randn(10, 10, dtype=torch.float32),)
938939

939940
# Max
940941
class Max1(Module):
@@ -976,6 +977,42 @@ def main(
976977

977978
verify_model(Min1(), example_args1, {}, expected_min1)
978979

980+
# RSub
981+
class RSub1(Module):
982+
def forward(self, x, y):
983+
return torch.rsub(x, y)
984+
985+
class RSub2(Module):
986+
def forward(self, x):
987+
return torch.rsub(x, 5.0)
988+
989+
@tvm.script.ir_module
990+
class expected_rsub1:
991+
@R.function
992+
def main(
993+
x: R.Tensor((10, 10), dtype="float32"), y: R.Tensor((10, 10), dtype="float32")
994+
) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
995+
with R.dataflow():
996+
lv: R.Tensor((10, 10), dtype="float32") = R.subtract(y, x)
997+
gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
998+
R.output(gv)
999+
return gv
1000+
1001+
@tvm.script.ir_module
1002+
class expected_rsub2:
1003+
@R.function
1004+
def main(
1005+
x: R.Tensor((10, 10), dtype="float32")
1006+
) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
1007+
with R.dataflow():
1008+
lv: R.Tensor((10, 10), dtype="float32") = R.subtract(R.const(5.0, "float32"), x)
1009+
gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
1010+
R.output(gv)
1011+
return gv
1012+
1013+
verify_model(RSub1(), example_args1, {}, expected_rsub1)
1014+
verify_model(RSub2(), example_args2, {}, expected_rsub2)
1015+
9791016

9801017
def test_batchnorm2d():
9811018
class BatchNorm2d(Module):

tests/python/relax/test_frontend_from_fx.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1738,6 +1738,45 @@ def main(
17381738
verify_model(Binary2(op), input_info2, {}, expected_binary2)
17391739

17401740

1741+
# RSub
def test_rsub():
    # torch.rsub lowering: rsub(input, other) == other - input, so the
    # expected Relax swaps the operands into R.subtract(other, input).
    input_info1 = [([10, 10], "float32"), ([10, 10], "float32")]
    input_info2 = [([10, 10], "float32")]

    # Tensor `other`: y - x.
    class RSub1(Module):
        def forward(self, x, y):
            return torch.rsub(x, y)

    # Python-scalar `other`: 5.0 - x (scalar lifted to R.const).
    class RSub2(Module):
        def forward(self, x):
            return torch.rsub(x, 5.0)

    @tvm.script.ir_module
    class expected_rsub1:
        @R.function
        def main(
            x: R.Tensor((10, 10), dtype="float32"), y: R.Tensor((10, 10), dtype="float32")
        ) -> R.Tensor((10, 10), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((10, 10), dtype="float32") = R.subtract(y, x)
                gv: R.Tensor((10, 10), dtype="float32") = lv
                R.output(gv)
            return gv

    @tvm.script.ir_module
    class expected_rsub2:
        @R.function
        def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((10, 10), dtype="float32") = R.subtract(R.const(5.0, "float32"), x)
                gv: R.Tensor((10, 10), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(RSub1(), input_info1, {}, expected_rsub1)
    verify_model(RSub2(), input_info2, {}, expected_rsub2)
1778+
1779+
17411780
def test_size():
17421781
input_info = [([1, 3, 10, 10], "float32")]
17431782

0 commit comments

Comments
 (0)