Skip to content

Commit fe1b228

Browse files
[Relax][Pytorch] Add support for bitwise_or op support (#17871)
This PR adds support for the bitwise OR operation used in the Mistral/Mistral-3B-Instruct model.
1 parent b5b0337 commit fe1b228

File tree

4 files changed

+10
-0
lines changed

4 files changed

+10
-0
lines changed

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,10 @@ def create_convert_map(
328328
# binary
329329
"add.Tensor": self._binary_op(relax.op.add, operator.add),
330330
"add_.Tensor": self._binary_op(relax.op.add, operator.add),
331+
"bitwise_or_.Scalar": self._binary_op(relax.op.bitwise_or, operator.or_),
332+
"bitwise_or.Scalar": self._binary_op(relax.op.bitwise_or, operator.or_),
333+
"bitwise_or_.Tensor": self._binary_op(relax.op.bitwise_or, operator.or_),
334+
"bitwise_or.Tensor": self._binary_op(relax.op.bitwise_or, operator.or_),
331335
"div.Tensor": self._binary_op(relax.op.divide, operator.truediv),
332336
"eq.Scalar": self._binary_op(relax.op.equal, operator.eq),
333337
"eq.Tensor": self._binary_op(relax.op.equal, operator.eq),

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,8 @@ def create_convert_map(
679679
# binary
680680
"add": self._binary_op(relax.op.add, operator.add),
681681
"and_": self._binary_op(relax.op.bitwise_and, operator.and_),
682+
"bitwise_or_": self._binary_op(relax.op.bitwise_or, operator.or_),
683+
"bitwise_or": self._binary_op(relax.op.bitwise_or, operator.or_),
682684
"eq": self._binary_op(relax.op.equal, operator.eq),
683685
"floordiv": self._binary_op(relax.op.floor_divide, operator.floordiv),
684686
"ge": self._binary_op(relax.op.greater_equal, operator.ge),

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -845,6 +845,8 @@ def main(
845845
operator_binary_1 = [
846846
(operator.add, R.add),
847847
(torch.ops.aten.add_, R.add),
848+
(torch.ops.aten.bitwise_or, R.bitwise_or),
849+
(torch.ops.aten.bitwise_or_, R.bitwise_or),
848850
(operator.sub, R.subtract),
849851
(operator.mul, R.multiply),
850852
(torch.ops.aten.mul_, R.multiply),

tests/python/relax/test_frontend_from_fx.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1769,6 +1769,8 @@ def main(
17691769

17701770

17711771
operator_binary_3 = [
1772+
(torch.ops.aten.bitwise_or_, R.bitwise_or),
1773+
(torch.ops.aten.bitwise_or, R.bitwise_or),
17721774
(operator.lshift, R.left_shift),
17731775
(operator.rshift, R.right_shift),
17741776
(operator.and_, R.bitwise_and),

0 commit comments

Comments
 (0)