@@ -132,6 +132,54 @@ def convert(node: fx.Node) -> relax.Var:
132
132
133
133
return convert
134
134
135
+ ########## Binary Ops ##############
136
+
137
+ def _binary_op_inplace (self , relax_op : Callable , intrinsic_op : Callable ) -> Callable :
138
+ from torch import fx
139
+
140
+ def convert (node : fx .Node ) -> relax .Var :
141
+ def promote_binary_op_args (lhs , rhs ):
142
+ if isinstance (lhs , relax .Expr ) and isinstance (rhs , relax .Expr ):
143
+ return lhs , rhs
144
+ elif isinstance (lhs , relax .Expr ):
145
+ assert isinstance (lhs .struct_info , relax .TensorStructInfo )
146
+ return lhs , relax .const (rhs , lhs .struct_info .dtype )
147
+ elif isinstance (rhs , relax .Expr ):
148
+ assert isinstance (rhs .struct_info , relax .TensorStructInfo )
149
+ return relax .const (lhs , rhs .struct_info .dtype ), rhs
150
+ else :
151
+ assert False
152
+
153
+ def call_binary_op (op , lhs , rhs ):
154
+ lhs , rhs = promote_binary_op_args (lhs , rhs )
155
+ return self .block_builder .emit (op (lhs , rhs ))
156
+
157
+ lhs , rhs = self .retrieve_args (node )
158
+ if isinstance (lhs , relax .Var ) or isinstance (rhs , relax .Var ):
159
+ output = call_binary_op (relax_op , lhs , rhs )
160
+ self .env [node .args [0 ]] = output
161
+ return output
162
+
163
+ elif isinstance (lhs , relax .expr .Constant ):
164
+ output = call_binary_op (
165
+ relax_op , lhs , relax .const (rhs , dtype = lhs .struct_info .dtype )
166
+ )
167
+ self .env [node .args [0 ]] = output
168
+ return output
169
+
170
+ elif isinstance (rhs , relax .expr .Constant ):
171
+ output = call_binary_op (
172
+ relax_op , relax .const (lhs , dtype = rhs .struct_info .dtype ), rhs
173
+ )
174
+ self .env [node .args [0 ]] = output
175
+ return output
176
+
177
+ output = intrinsic_op (lhs , rhs )
178
+ self .env [node .args [0 ]] = output
179
+ return output
180
+
181
+ return convert
182
+
135
183
########## Neural Network ##########
136
184
137
185
def _adaptive_avg_pool2d_module (self , node : fx .Node ) -> relax .Var :
@@ -679,7 +727,7 @@ def create_convert_map(
679
727
# binary
680
728
"add" : self ._binary_op (relax .op .add , operator .add ),
681
729
"and_" : self ._binary_op (relax .op .bitwise_and , operator .and_ ),
682
- "bitwise_or_" : self ._binary_op (relax .op .bitwise_or , operator .or_ ),
730
+ "bitwise_or_" : self ._binary_op_inplace (relax .op .bitwise_or , operator .or_ ),
683
731
"bitwise_or" : self ._binary_op (relax .op .bitwise_or , operator .or_ ),
684
732
"eq" : self ._binary_op (relax .op .equal , operator .eq ),
685
733
"floordiv" : self ._binary_op (relax .op .floor_divide , operator .floordiv ),
0 commit comments