Description
Expected behavior
When using torch.fx to convert a PyTorch model containing in-place operations (e.g., bitwise_or_), the resulting IR in TVM should accurately reflect the updated tensor and return the modified value.
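For context (an illustration, not part of the original report): PyTorch's in-place ops return the updated tensor itself, so a faithful conversion must return the OR'ed result rather than the pre-update input. A small standalone check of the expected semantics:

import torch

a = torch.tensor([[1, 0], [0, 1]], dtype=torch.int32)
b = torch.tensor([[0, 1], [0, 1]], dtype=torch.int32)
out = a.bitwise_or_(b)   # in-place: a is updated to a | b
assert out is a          # the return value aliases the updated tensor
print(a)                 # tensor([[1, 1], [0, 1]], dtype=torch.int32)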
Actual behavior
Currently, the FX-based tracing results in incorrect IR where the original input tensor is returned, instead of the tensor updated by the in-place operation. This leads to a semantic mismatch.
Example:
# Original PyTorch model
import torch
from torch.nn import Module

class Model(Module):
    def forward(self, input: torch.Tensor, other: torch.Tensor):
        input.bitwise_or_(other)
        return input
Produces this incorrect FX-derived IR:
@R.function
def main(inp_0: R.Tensor((128, 128), dtype="int32"), inp_1: R.Tensor((128, 128), dtype="int32")) -> R.Tensor((128, 128), dtype="int32"):
    with R.dataflow():
        lv: R.Tensor((128, 128), dtype="int32") = R.bitwise_or(inp_0, inp_1)
        gv: R.Tensor((128, 128), dtype="int32") = inp_0  # Incorrect: should return lv
        R.output(gv)
    return gv
Whereas using exported_program gives the correct representation:
@R.function
def main(input: R.Tensor((128, 128), dtype="int32"), other: R.Tensor((128, 128), dtype="int32")) -> R.Tuple(R.Tensor((128, 128), dtype="int32")):
    with R.dataflow():
        lv: R.Tensor((128, 128), dtype="int32") = R.bitwise_or(input, other)
        gv: R.Tuple(R.Tensor((128, 128), dtype="int32")) = (lv,)
        R.output(gv)
    return gv
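For reference, a minimal reproduction sketch (not from the original report; it assumes the tvm.relax.frontend.torch converters from_fx and from_exported_program, whose exact signatures may vary across TVM and PyTorch versions):

import torch
from torch import fx
from torch.nn import Module
from tvm.relax.frontend.torch import from_fx, from_exported_program

class Model(Module):
    def forward(self, input: torch.Tensor, other: torch.Tensor):
        input.bitwise_or_(other)
        return input

# FX path: symbolically trace the model and convert; this currently yields
# the incorrect IR above, which returns the stale input tensor.
graph_module = fx.symbolic_trace(Model())
input_info = [((128, 128), "int32"), ((128, 128), "int32")]
mod_fx = from_fx(graph_module, input_info)
mod_fx.show()

# ExportedProgram path: yields the expected IR that returns the updated value.
example_inputs = (
    torch.randint(0, 2, (128, 128), dtype=torch.int32),
    torch.randint(0, 2, (128, 128), dtype=torch.int32),
)
exported_program = torch.export.export(Model(), example_inputs)
mod_ep = from_exported_program(exported_program)
mod_ep.show()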
- frontend
- bug
cc @shingjan