Skip to content

Commit 31c95f1

Browse files
kavin-sai-krishnaShiboXing
authored andcommitted
[BugFix][Relax][Pytorch] Incorrect Handling of In-Place Ops in FX-Based TVM Frontend (apache#17875)
1 parent 1004fd8 commit 31c95f1

File tree

1 file changed

+49
-1
lines changed

1 file changed

+49
-1
lines changed

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,54 @@ def convert(node: fx.Node) -> relax.Var:
132132

133133
return convert
134134

135+
########## Binary Ops ##############
136+
137+
def _binary_op_inplace(self, relax_op: Callable, intrinsic_op: Callable) -> Callable:
138+
from torch import fx
139+
140+
def convert(node: fx.Node) -> relax.Var:
141+
def promote_binary_op_args(lhs, rhs):
142+
if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr):
143+
return lhs, rhs
144+
elif isinstance(lhs, relax.Expr):
145+
assert isinstance(lhs.struct_info, relax.TensorStructInfo)
146+
return lhs, relax.const(rhs, lhs.struct_info.dtype)
147+
elif isinstance(rhs, relax.Expr):
148+
assert isinstance(rhs.struct_info, relax.TensorStructInfo)
149+
return relax.const(lhs, rhs.struct_info.dtype), rhs
150+
else:
151+
assert False
152+
153+
def call_binary_op(op, lhs, rhs):
154+
lhs, rhs = promote_binary_op_args(lhs, rhs)
155+
return self.block_builder.emit(op(lhs, rhs))
156+
157+
lhs, rhs = self.retrieve_args(node)
158+
if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
159+
output = call_binary_op(relax_op, lhs, rhs)
160+
self.env[node.args[0]] = output
161+
return output
162+
163+
elif isinstance(lhs, relax.expr.Constant):
164+
output = call_binary_op(
165+
relax_op, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype)
166+
)
167+
self.env[node.args[0]] = output
168+
return output
169+
170+
elif isinstance(rhs, relax.expr.Constant):
171+
output = call_binary_op(
172+
relax_op, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs
173+
)
174+
self.env[node.args[0]] = output
175+
return output
176+
177+
output = intrinsic_op(lhs, rhs)
178+
self.env[node.args[0]] = output
179+
return output
180+
181+
return convert
182+
135183
########## Neural Network ##########
136184

137185
def _adaptive_avg_pool2d_module(self, node: fx.Node) -> relax.Var:
@@ -679,7 +727,7 @@ def create_convert_map(
679727
# binary
680728
"add": self._binary_op(relax.op.add, operator.add),
681729
"and_": self._binary_op(relax.op.bitwise_and, operator.and_),
682-
"bitwise_or_": self._binary_op(relax.op.bitwise_or, operator.or_),
730+
"bitwise_or_": self._binary_op_inplace(relax.op.bitwise_or, operator.or_),
683731
"bitwise_or": self._binary_op(relax.op.bitwise_or, operator.or_),
684732
"eq": self._binary_op(relax.op.equal, operator.eq),
685733
"floordiv": self._binary_op(relax.op.floor_divide, operator.floordiv),

0 commit comments

Comments
 (0)