diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 3e81ff1f0bfe..888a3cef22ae 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -407,6 +407,20 @@ def call_binary_op(op, lhs, rhs): return convert + def _fmod(self, node: fx.Node): + args = self.retrieve_args(node) + lhs = args[0] + rhs = args[1] + if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr): + return self.block_builder.emit(relax.op.mod(lhs, rhs)) + elif isinstance(lhs, relax.Expr): + rhs = relax.const(rhs, lhs.struct_info.dtype) + elif isinstance(rhs, relax.Expr): + lhs = relax.const(lhs, rhs.struct_info.dtype) + else: + assert False + return self.block_builder.emit(relax.op.mod(lhs, rhs)) + def _rsub(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) lhs = args[0] diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index a3ab575c4b78..e9e510e19bc1 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -344,6 +344,8 @@ def create_convert_map( "eq.Scalar": self._binary_op(relax.op.equal, operator.eq), "eq.Tensor": self._binary_op(relax.op.equal, operator.eq), "floor_divide.default": self._binary_op(relax.op.floor_divide, operator.floordiv), + "fmod.Scalar": self._fmod, + "fmod.Tensor": self._fmod, "logaddexp.default": self._binary_op(relax.op.log_add_exp, torch.logaddexp), "ge.Scalar": self._binary_op(relax.op.greater_equal, operator.ge), "ge.Tensor": self._binary_op(relax.op.greater_equal, operator.ge), diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 18dba2d988f2..16c9cef6e5cd 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -731,6 +731,7 @@ def create_convert_map( "bitwise_or": self._binary_op(relax.op.bitwise_or, operator.or_), "eq": self._binary_op(relax.op.equal, operator.eq), "floordiv": self._binary_op(relax.op.floor_divide, operator.floordiv), + "fmod": self._fmod, "ge": self._binary_op(relax.op.greater_equal, operator.ge), "gt": self._binary_op(relax.op.greater, operator.gt), "iadd": self._binary_op(relax.op.add, operator.add), diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index e3b6f4ad9c17..9e3ae8d2217f 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -852,6 +852,7 @@ def main( (torch.ops.aten.mul_, R.multiply), (operator.truediv, R.divide), (operator.floordiv, R.floor_divide), + (torch.ops.aten.fmod, R.mod), (operator.pow, R.power), (operator.mod, R.floor_mod), (operator.and_, R.bitwise_and), diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 4003202d4f55..ef116a2051ca 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -1653,6 +1653,7 @@ def main( (operator.mul, R.multiply), (operator.truediv, R.divide), (operator.floordiv, R.floor_divide), + (torch.ops.aten.fmod, R.mod), (operator.pow, R.power), (operator.mod, R.floor_mod), ]