Skip to content

Commit 3f1625b

Browse files
kavin-sai-krishnaShiboXing
authored andcommitted
[Relax][Pytorch] Add masked_fill op support in ExportedProgram (apache#17850)
* Add masked_fill support in exportedProgram * Fix lint issues
1 parent 47e13fc commit 3f1625b

File tree

4 files changed

+42
-17
lines changed

4 files changed

+42
-17
lines changed

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1349,6 +1349,23 @@ def _index_select(self, node: fx.Node) -> relax.Var:
13491349
index = self.env[node.args[2]]
13501350
return self.block_builder.emit(relax.op.take(x, index, dim))
13511351

1352+
def _inplace_masked_fill(self, node: fx.Node) -> relax.Var:
1353+
x = self.env[node.args[0]]
1354+
mask = self.env[node.args[1]]
1355+
value = node.args[2]
1356+
rx_value = relax.const(value)
1357+
values = self.block_builder.emit(relax.op.full_like(x, rx_value))
1358+
output = self.block_builder.emit(relax.op.where(mask, values, x))
1359+
self.env[node.args[0]] = output
1360+
return output
1361+
1362+
def _masked_fill(self, node: fx.Node) -> relax.Var:
1363+
x = self.env[node.args[0]]
1364+
mask = self.env[node.args[1]]
1365+
rx_value = relax.const(node.args[2])
1366+
values = self.block_builder.emit(relax.op.full_like(x, rx_value))
1367+
return self.block_builder.emit(relax.op.where(mask, values, x))
1368+
13521369
def _new_ones(self, node: fx.Node) -> relax.Var:
13531370
args = self.retrieve_args(node)
13541371
self_var = args[0]

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,7 @@ def create_convert_map(
448448
"full_like.default": self._full_like,
449449
"index_select.default": self._index_select,
450450
"lift_fresh_copy.default": self._to_copy,
451+
"masked_fill.Scalar": self._masked_fill,
451452
"new_ones.default": self._new_ones,
452453
"one_hot.default": self._one_hot,
453454
"ones.default": self._ones,

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -476,23 +476,6 @@ def _inplace_fill(self, node: fx.Node) -> relax.Var:
476476
self.env[node.args[0]] = filled
477477
return filled
478478

479-
def _inplace_masked_fill(self, node: fx.Node) -> relax.Var:
480-
x = self.env[node.args[0]]
481-
mask = self.env[node.args[1]]
482-
value = node.args[2]
483-
rx_value = relax.const(value)
484-
values = self.block_builder.emit(relax.op.full_like(x, rx_value))
485-
output = self.block_builder.emit(relax.op.where(mask, values, x))
486-
self.env[node.args[0]] = output
487-
return output
488-
489-
def _masked_fill(self, node: fx.Node) -> relax.Var:
490-
x = self.env[node.args[0]]
491-
mask = self.env[node.args[1]]
492-
rx_value = relax.const(node.args[2])
493-
values = self.block_builder.emit(relax.op.full_like(x, rx_value))
494-
return self.block_builder.emit(relax.op.where(mask, values, x))
495-
496479
def _masked_scatter(self, node: fx.Node) -> relax.Var:
497480
x = self.env[node.args[0]]
498481
mask = self.env[node.args[1]]

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3260,6 +3260,30 @@ def main(
32603260
verify_model(Fill(), example_args, {}, Expected)
32613261

32623262

3263+
def test_masked_fill():
3264+
class Masked_Fill(Module):
3265+
def forward(self, input: torch.Tensor, mask: torch.Tensor):
3266+
return torch.masked_fill(input, mask, 0)
3267+
3268+
@tvm.script.ir_module
3269+
class Expected:
3270+
@R.function
3271+
def main(
3272+
input: R.Tensor((128, 128), dtype="float32"), mask: R.Tensor((128, 128), dtype="bool")
3273+
) -> R.Tuple(R.Tensor((128, 128), dtype="float32")):
3274+
with R.dataflow():
3275+
lv: R.Tensor((128, 128), dtype="float32") = R.full_like(
3276+
input, R.const(0, "int32"), dtype="void"
3277+
)
3278+
lv1: R.Tensor((128, 128), dtype="float32") = R.where(mask, lv, input)
3279+
gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv1,)
3280+
R.output(gv)
3281+
return gv
3282+
3283+
example_args = (torch.randn(128, 128, dtype=torch.float32), torch.rand(128, 128) < 0.5)
3284+
verify_model(Masked_Fill(), example_args, {}, Expected)
3285+
3286+
32633287
def test_new_ones():
32643288
class NewOnes(Module):
32653289
def forward(self, x):

0 commit comments

Comments
 (0)