Commit 0e88580

mshr-h authored and ShiboXing committed
[Relax][PyTorch] Support torch.bfloat16 dtype in pytorch frontend (apache#17894)
support bfloat16 dtype in pytorch frontend
1 parent 7bca542 commit 0e88580
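
For context, here is a minimal sketch (not part of this commit) of how the new bfloat16 support can be exercised end to end through the exported-program path. It assumes the from_exported_program entry point of the Relax PyTorch frontend and a PyTorch build with torch.export; names such as AddModel are illustrative.

# Sketch only: drive the frontend with bfloat16 inputs (assumes
# tvm.relax.frontend.torch.from_exported_program and torch.export are available).
import torch
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program

class AddModel(torch.nn.Module):
    def forward(self, lhs, rhs):
        return lhs + rhs

args = (
    torch.randn(10, 10, dtype=torch.bfloat16),
    torch.randn(10, 10, dtype=torch.bfloat16),
)
exported = export(AddModel(), args)    # torch.export.ExportedProgram
mod = from_exported_program(exported)  # Relax IRModule with bfloat16 tensors
print(mod)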

File tree

3 files changed: +51 −0 lines changed

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 2 additions & 0 deletions
@@ -58,6 +58,8 @@ def _convert_data_type(input_type: Union[str, torch.dtype], env: Optional[Dict]
         return "float32"
     elif input_type in ["float16", "torch.float16", torch.float16]:
         return "float16"
+    elif input_type in ["bfloat16", "torch.bfloat16", torch.bfloat16]:
+        return "bfloat16"
     elif input_type in ["int64", "torch.int64", torch.int64]:
         return "int64"
     elif input_type in ["int32", "torch.int32", torch.int32]:
tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 27 additions & 0 deletions
@@ -4869,5 +4869,32 @@ def main(
     verify_model(Linspace(), example_args, {}, Expected)


+def test_bfloat16():
+    # TODO(mshr-h): Add tests for all the dtypes supported in fx frontend
+    example_args = (
+        torch.randn(10, 10, dtype=torch.bfloat16),
+        torch.randn(10, 10, dtype=torch.bfloat16),
+    )
+
+    class BFloat16Model(Module):
+        def forward(self, lhs: torch.Tensor, rhs: torch.Tensor):
+            return torch.ops.aten.add(lhs, rhs)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(
+            lhs: R.Tensor((10, 10), dtype="bfloat16"),
+            rhs: R.Tensor((10, 10), dtype="bfloat16"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="bfloat16")):
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="bfloat16") = relax.op.add(lhs, rhs)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="bfloat16")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(BFloat16Model(), example_args, {}, expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

tests/python/relax/test_frontend_from_fx.py

Lines changed: 22 additions & 0 deletions
@@ -5249,5 +5249,27 @@ def main(
     verify_model(Norm(p, dim=dim, keepdim=keepdim), input_info, {}, expected)


+def test_bfloat16():
+    # TODO(mshr-h): Add tests for all the dtypes supported in EP frontend
+    class BFloat16Model(Module):
+        def forward(self, lhs: torch.Tensor, rhs: torch.Tensor):
+            return torch.ops.aten.add(lhs, rhs)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            lhs: R.Tensor((10, 10), dtype="bfloat16"),
+            rhs: R.Tensor((10, 10), dtype="bfloat16"),
+        ) -> R.Tensor((10, 10), dtype="bfloat16"):
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="bfloat16") = relax.op.add(lhs, rhs)
+                gv: R.Tensor((10, 10), dtype="bfloat16") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(BFloat16Model(), [([10, 10], "bfloat16"), ([10, 10], "bfloat16")], {}, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
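
For comparison with the exported-program sketch above, the FX path driven by this second test can be exercised roughly as follows. This is a sketch assuming the from_fx entry point of the Relax PyTorch frontend, with input_info pairing each input shape with its dtype string ("bfloat16" here); AddModel is illustrative.

# Sketch only: FX tracing path with bfloat16 inputs (assumes
# tvm.relax.frontend.torch.from_fx is available).
import torch
from torch import fx
from tvm.relax.frontend.torch import from_fx

class AddModel(torch.nn.Module):
    def forward(self, lhs, rhs):
        return lhs + rhs

graph_module = fx.symbolic_trace(AddModel())
input_info = [([10, 10], "bfloat16"), ([10, 10], "bfloat16")]
mod = from_fx(graph_module, input_info)  # Relax IRModule with bfloat16 tensors
print(mod)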
