Skip to content

Commit 4e41b42

Browse files
authored
[Relax][PyTorch] Support narrow and broadcast_to ops for ExportedProgram importer (#17830)
* Update exported_program_translator.py * Update test_frontend_from_exported_program.py * Update test_frontend_from_exported_program.py * Update test_frontend_from_exported_program.py
1 parent 820642b commit 4e41b42

File tree

2 files changed

+59
-0
lines changed

2 files changed

+59
-0
lines changed

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,13 @@ def _upsample_nearest2d(self, node: fx.node) -> relax.Var:
202202

203203
########## Manipulation ##########
204204

205+
def _narrow(self, node: fx.Node) -> relax.Var:
206+
x = self.env[node.args[0]]
207+
dim = node.args[1]
208+
start = node.args[2]
209+
length = node.args[3]
210+
return self.block_builder.emit(relax.op.strided_slice(x, [dim], [start], [length]))
211+
205212
def _select(self, node: fx.Node) -> relax.Var:
206213
x = self.env[node.args[0]]
207214
dim = node.args[1]
@@ -390,6 +397,7 @@ def create_convert_map(
390397
"where.self": self._where,
391398
# tensor manipulation
392399
"argsort.default": self._argsort,
400+
"broadcast_to.default": self._broadcast_to,
393401
"cat.default": self._cat,
394402
"chunk.default": self._chunk,
395403
"clamp.Tensor": self._clamp,
@@ -402,6 +410,7 @@ def create_convert_map(
402410
"flatten.using_ints": self._flatten,
403411
"flip.default": self._flip,
404412
"gather.default": self._gather,
413+
"narrow.default": self._narrow,
405414
"permute.default": self._permute,
406415
"repeat.default": self._repeat,
407416
"select.int": self._select,

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3856,5 +3856,55 @@ def main(
38563856
verify_model(DynamicModel(), example_args, {}, Expected, dynamic_shapes=dynamic_shapes)
38573857

38583858

3859+
def test_broadcast_to():
3860+
class BroadcastTo(Module):
3861+
def forward(self, x):
3862+
return torch.broadcast_to(x, (5, 3))
3863+
3864+
@tvm.script.ir_module
3865+
class Expected:
3866+
@R.function
3867+
def main(
3868+
x: R.Tensor((5, 1), dtype="float32")
3869+
) -> R.Tuple(R.Tensor((5, 3), dtype="float32")):
3870+
with R.dataflow():
3871+
lv: R.Tensor((5, 3), dtype="float32") = R.broadcast_to(x, R.shape([5, 3]))
3872+
gv: R.Tuple(R.Tensor((5, 3), dtype="float32")) = (lv,)
3873+
R.output(gv)
3874+
3875+
return gv
3876+
3877+
example_args = (torch.randn(5, 1, dtype=torch.float32),)
3878+
verify_model(BroadcastTo(), example_args, {}, Expected)
3879+
3880+
3881+
def test_narrow():
3882+
class Narrow(Module):
3883+
def forward(self, x):
3884+
return torch.narrow(x, 1, 0, 2)
3885+
3886+
@tvm.script.ir_module
3887+
class Expected:
3888+
@R.function
3889+
def main(
3890+
x: R.Tensor((5, 3), dtype="float32")
3891+
) -> R.Tuple(R.Tensor((5, 2), dtype="float32")):
3892+
with R.dataflow():
3893+
lv: R.Tensor((5, 2), dtype="float32") = R.strided_slice(
3894+
x,
3895+
(R.prim_value(1),),
3896+
(R.prim_value(0),),
3897+
(R.prim_value(2),),
3898+
assume_inbound=False,
3899+
)
3900+
gv: R.Tuple(R.Tensor((5, 2), dtype="float32")) = (lv,)
3901+
R.output(gv)
3902+
3903+
return gv
3904+
3905+
example_args = (torch.randn(5, 3, dtype=torch.float32),)
3906+
verify_model(Narrow(), example_args, {}, Expected)
3907+
3908+
38593909
if __name__ == "__main__":
38603910
tvm.testing.main()

0 commit comments

Comments
 (0)