
[Bug] [FRONTEND][PYTORCH][FxGraph] Incorrect Handling of In-Place Ops in FX-Based TVM Frontend #17874

@kavin-sai-krishna

Description


Expected behavior

When a PyTorch model containing in-place operations (e.g., bitwise_or_) is traced with torch.fx and converted by TVM's FX frontend, the resulting Relax IR should reflect the in-place update and return the modified tensor.

Actual behavior

Currently, the FX-based frontend produces incorrect IR: it returns the original input tensor instead of the tensor updated by the in-place operation, so the converted model's output no longer matches PyTorch's.

Example:

# Original PyTorch model
import torch
from torch.nn import Module

class Model(Module):
    def forward(self, input: torch.Tensor, other: torch.Tensor):
        input.bitwise_or_(other)  # mutates `input` in place
        return input  # expected: the updated tensor

Converting this model through the FX frontend produces the following incorrect IR:

@R.function
def main(inp_0: R.Tensor((128, 128), dtype="int32"), inp_1: R.Tensor((128, 128), dtype="int32")) -> R.Tensor((128, 128), dtype="int32"):
    with R.dataflow():
        lv: R.Tensor((128, 128), dtype="int32") = R.bitwise_or(inp_0, inp_1)
        gv: R.Tensor((128, 128), dtype="int32") = inp_0  # Incorrect: should return lv
        R.output(gv)
    return gv
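The gv = inp_0 line above is consistent with the FX graph that the frontend receives. The following is a minimal sketch, assuming only PyTorch is installed (no TVM), that traces the example model with torch.fx.symbolic_trace and inspects the result: the in-place op is recorded as a call_method node with no users, and the graph's output node refers to the input placeholder rather than to that node. A translator that simply maps the output node's argument to a value would therefore return the untouched input unless it tracks the mutation.

```python
import torch
import torch.fx
from torch.nn import Module


class Model(Module):
    def forward(self, input: torch.Tensor, other: torch.Tensor):
        input.bitwise_or_(other)
        return input


# Trace symbolically; no real tensors flow through forward().
gm = torch.fx.symbolic_trace(Model())
print(gm.graph)

# Locate the in-place op node and the graph's output node.
inplace_node = next(
    n for n in gm.graph.nodes
    if n.op == "call_method" and n.target == "bitwise_or_"
)
output_node = next(n for n in gm.graph.nodes if n.op == "output")

# The returned value is the placeholder for `input`, not the bitwise_or_ node,
# and nothing in the graph consumes the in-place node's result.
returned = output_node.args[0]
print("returned node:", returned.op, returned.name)
print("users of bitwise_or_:", len(inplace_node.users))

# Eager PyTorch, by contrast, returns the mutated tensor.
a = torch.tensor([1, 2], dtype=torch.int32)
b = torch.tensor([4, 8], dtype=torch.int32)
out = Model()(a.clone(), b)
print(out)
```

Eager execution shows the expected semantics: the returned tensor is the mutated input, here [5, 10], while the traced graph only links that result to the output implicitly, through the mutation of the placeholder.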

By contrast, converting the same model through the ExportedProgram frontend (torch.export) gives the correct representation, returning lv inside a tuple:

@R.function
def main(input: R.Tensor((128, 128), dtype="int32"), other: R.Tensor((128, 128), dtype="int32")) -> R.Tuple(R.Tensor((128, 128), dtype="int32")):
    with R.dataflow():
        lv: R.Tensor((128, 128), dtype="int32") = R.bitwise_or(input, other)
        gv: R.Tuple(R.Tensor((128, 128), dtype="int32")) = (lv,)
        R.output(gv)
    return gv
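The difference between the two IR listings above suggests one way a graph translator could handle in-place ops: emit the out-of-place op, then rebind the mutated operand so that every later use, including the return, sees the new value. The following is a toy, stdlib-only sketch of that idea, not TVM's actual from_fx implementation; the node tuples, the translate function, and the INPLACE_OPS table are all hypothetical.

```python
from typing import Dict, List, Tuple

# Hypothetical toy graph: each node is (name, op, args). "placeholder" nodes are
# inputs; "output" returns the value currently bound to its single argument.
Node = Tuple[str, str, Tuple[str, ...]]

# Hypothetical table mapping in-place ops to their out-of-place equivalents.
INPLACE_OPS = {"bitwise_or_": "bitwise_or", "add_": "add"}


def translate(nodes: List[Node], rebind_inplace: bool) -> List[str]:
    """Lower the toy graph into SSA-style statements."""
    env: Dict[str, str] = {}  # graph node name -> emitted value name
    stmts: List[str] = []
    counter = 0
    for name, op, args in nodes:
        if op == "placeholder":
            env[name] = name
        elif op == "output":
            stmts.append(f"gv = {env[args[0]]}")
        else:
            pure_op = INPLACE_OPS.get(op, op)
            value = f"lv{counter}"
            counter += 1
            operands = ", ".join(env[a] for a in args)
            stmts.append(f"{value} = {pure_op}({operands})")
            env[name] = value
            if rebind_inplace and op in INPLACE_OPS:
                # The first operand was mutated: later references must see the new value.
                env[args[0]] = value
    return stmts


# The graph torch.fx records for the Model in this issue.
graph: List[Node] = [
    ("input", "placeholder", ()),
    ("other", "placeholder", ()),
    ("bitwise_or_", "bitwise_or_", ("input", "other")),
    ("output", "output", ("input",)),
]

print(translate(graph, rebind_inplace=False))  # ['lv0 = bitwise_or(input, other)', 'gv = input']
print(translate(graph, rebind_inplace=True))   # ['lv0 = bitwise_or(input, other)', 'gv = lv0']
```

Without rebinding, the sketch reproduces the incorrect gv = inp_0 pattern; with rebinding, it matches the ExportedProgram result. A real frontend would likely also need to account for aliasing through views, which this toy ignores.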
  • frontend
  • bug

cc @shingjan

Metadata


Assignees: No one assigned

Labels: needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug
