[aot_eager] [hf_Longformer] Cannot view a tensor with shape #93428

Description

@anijain2305

🐛 Describe the bug

No response

Error logs

Traceback (most recent call last):
  File "/scratch/anijain/work/pytorch/repro.py", line 52, in <module>
    res = run_fwd_maybe_bwd(opt_mod, args)
  File "/scratch/anijain/work/pytorch/torch/_dynamo/debug_utils.py", line 505, in run_fwd_maybe_bwd
    out = gm(args)
  File "/scratch/anijain/work/pytorch/functorch/_src/aot_autograd.py", line 321, in g
    return f(*args)
  File "/scratch/anijain/work/pytorch/torch/nn/modules/module.py", line 1427, in _call_impl
    return forward_call(*input, **kwargs)
  File "/scratch/anijain/work/pytorch/torch/_dynamo/eval_frame.py", line 66, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/scratch/anijain/work/pytorch/torch/_dynamo/eval_frame.py", line 174, in _fn
    return fn(*args, **kwargs)
  File "/scratch/anijain/work/pytorch/repro.py", line 25, in forward
    def forward(self, transpose, as_strided, transpose_8):
  File "/scratch/anijain/work/pytorch/torch/_dynamo/eval_frame.py", line 174, in _fn
    return fn(*args, **kwargs)
  File "/scratch/anijain/work/pytorch/functorch/_src/aot_autograd.py", line 951, in forward
    return compiled_f(
  File "/scratch/anijain/work/pytorch/functorch/_src/aot_autograd.py", line 937, in new_func
    compiled_fn = create_aot_dispatcher_function(
  File "/scratch/anijain/work/torchdynamo/torchdynamo/utils.py", line 86, in time_wrapper
    r = func(*args, **kwargs)
  File "/scratch/anijain/work/pytorch/functorch/_src/aot_autograd.py", line 657, in create_aot_dispatcher_function
    aot_dispatch_autograd(flat_fn, fake_flat_tensor_args, aot_config)
  File "/scratch/anijain/work/pytorch/functorch/_src/aot_autograd.py", line 486, in aot_dispatch_autograd
    fx_g = make_fx(
  File "/scratch/anijain/work/pytorch/torch/fx/experimental/proxy_tensor.py", line 657, in wrapped
    t = dispatch_trace(wrap_key(func, args, fx_tracer), tracer=fx_tracer, concrete_args=tuple(phs))
  File "/scratch/anijain/work/pytorch/torch/_dynamo/eval_frame.py", line 174, in _fn
    return fn(*args, **kwargs)
  File "/scratch/anijain/work/pytorch/torch/fx/experimental/proxy_tensor.py", line 417, in dispatch_trace
    graph = tracer.trace(root, concrete_args)
  File "/scratch/anijain/work/pytorch/torch/_dynamo/eval_frame.py", line 174, in _fn
    return fn(*args, **kwargs)
  File "/scratch/anijain/work/pytorch/torch/fx/_symbolic_trace.py", line 739, in trace
    (self.create_arg(fn(*args)),),
  File "/scratch/anijain/work/pytorch/torch/fx/_symbolic_trace.py", line 614, in flatten_fn
    tree_out = root_fn(*tree_args)
  File "/scratch/anijain/work/pytorch/torch/fx/experimental/proxy_tensor.py", line 431, in wrapped
    out = f(*tensors)
  File "/scratch/anijain/work/pytorch/functorch/_src/aot_autograd.py", line 189, in inner
    outs = f(*f_args, **f_kwargs)
  File "/scratch/anijain/work/pytorch/functorch/_src/aot_autograd.py", line 257, in joint_forward_backward
    backward_out = torch.autograd.grad(
  File "/scratch/anijain/work/pytorch/torch/autograd/__init__.py", line 300, in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/scratch/anijain/work/pytorch/torch/fx/experimental/proxy_tensor.py", line 457, in __torch_dispatch__
    return self.inner_torch_dispatch(func, types, args, kwargs)
  File "/scratch/anijain/work/pytorch/torch/fx/experimental/proxy_tensor.py", line 482, in inner_torch_dispatch
    out = proxy_call(self, func, args, kwargs)
  File "/scratch/anijain/work/pytorch/torch/fx/experimental/proxy_tensor.py", line 321, in proxy_call
    out = func(*args, **kwargs)
  File "/scratch/anijain/work/pytorch/torch/_ops.py", line 285, in __call__
    return self._op(*args, **kwargs or {})
  File "/scratch/anijain/work/pytorch/torch/_subclasses/fake_tensor.py", line 887, in __torch_dispatch__
    r = func(*args, **kwargs)
  File "/scratch/anijain/work/pytorch/torch/_ops.py", line 285, in __call__
    return self._op(*args, **kwargs or {})
  File "/scratch/anijain/work/pytorch/torch/_refs/__init__.py", line 3851, in view
    return _reshape_view_helper(a, *shape, allow_copy=False)
  File "/scratch/anijain/work/pytorch/torch/_refs/__init__.py", line 3137, in _reshape_view_helper
    raise ValueError(msg)
ValueError: Cannot view a tensor with shape torch.Size([4, 12, 1024, 513]) and strides (6303744, 513, 6156, 1) as a tensor with shape (48, 4, 256, 513)!
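
The layout in the error message explains the failure: the incoming gradient is a transpose(2, 1) of a contiguous (4, 1024, 12, 513) tensor, and merging its first two dimensions back into 48 is impossible without a copy, which the allow_copy=False view in torch._refs refuses. A minimal sketch of the stride check (the eager error text differs slightly from the _refs ValueError above):

import torch

# Shape and strides taken verbatim from the error message.
g = torch.empty_strided((4, 12, 1024, 513), (6303744, 513, 6156, 1))

# Merging dims 0 and 1 needs stride[0] == size[1] * stride[1],
# i.e. 6303744 == 12 * 513, which is false, so no copy-free view exists.
g.reshape(48, 4, 256, 513)  # ok: reshape copies when it must
g.view(48, 4, 256, 513)     # raises: view may not copy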

Minified repro


from math import inf
import torch
from torch import tensor, device
import torch.fx as fx
import torch._dynamo
from torch._dynamo.testing import rand_strided
from torch._dynamo.debug_utils import run_fwd_maybe_bwd
from torch._dynamo.debug_utils import same_two_models

# REPLACEABLE COMMENT FOR TESTING PURPOSES

args = [((1024, 4, 768), (768, 786432, 1), torch.float32, 'cuda', True), ((48, 3, 512, 64), (64, 786432, 3072, 1), torch.float16, 'cuda', True), ((4, 1024, 1, 513), (525312, 513, 525312, 1), torch.float16, 'cuda', False)]
args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args]


from torch.nn import Linear
class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.self_self_layer_0_attention_self_key = Linear(in_features=768, out_features=768, bias=True).cuda()

    def forward(self, transpose, as_strided, transpose_8):
        self_self_layer_0_attention_self_key = self.self_self_layer_0_attention_self_key(transpose);  transpose = None
        view_1 = self_self_layer_0_attention_self_key.view(1024, 4, 12, 64);  self_self_layer_0_attention_self_key = None
        transpose_2 = view_1.transpose(0, 1);  view_1 = None
        transpose_4 = transpose_2.transpose(1, 2);  transpose_2 = None
        reshape_1 = transpose_4.reshape(48, 1024, 64);  transpose_4 = None
        view_3 = reshape_1.view(48, 2, 512, 64);  reshape_1 = None
        as_strided_1 = view_3.as_strided(size = [48, 3, 512, 64], stride = [64, 786432, 3072, 1]);  view_3 = None
        einsum = torch.functional.einsum('bcxd,bcyd->bcxy', (as_strided, as_strided_1));  as_strided = as_strided_1 = None
        pad = torch.nn.functional.pad(einsum, (0, 0, 0, 1));  einsum = None
        view_4 = pad.view(48, 3, 512, 513);  pad = None
        new_empty = view_4.new_empty((48, 4, 256, 513))
        getitem_3 = view_4[(slice(None, None, None), 0, slice(None, 255, None), slice(-255, None, None))];  view_4 = None
        new_empty[(slice(None, None, None), 0, slice(1, 256, None), slice(1, 256, None))] = getitem_3;  setitem_3 = new_empty;  getitem_3 = None
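        # Note: the gradient reaching view_5 below arrives as a transpose of a
        # contiguous (4, 1024, 12, 513) tensor, i.e. shape (4, 12, 1024, 513)
        # with strides (6303744, 513, 6156, 1); viewing that back to
        # (48, 4, 256, 513) without a copy is exactly what fails in the
        # traceback above.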
        view_5 = new_empty.view(4, 12, 1024, 513);  new_empty = None
        transpose_5 = view_5.transpose(2, 1);  view_5 = None
        transpose_5 += transpose_8;  iadd = transpose_5;  transpose_5 = transpose_8 = None
        return (iadd,)


mod = Repro()
opt_mod = torch._dynamo.optimize("aot_eager")(mod)


with torch.cuda.amp.autocast(enabled=True):
    ref = run_fwd_maybe_bwd(mod, args)
    res = run_fwd_maybe_bwd(opt_mod, args)
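
same_two_models is imported above but never called; once the crash is fixed, a minimal accuracy check could be appended along these lines (a sketch, assuming the usual (model, opt_model, args) calling convention of the minifier helpers):

with torch.cuda.amp.autocast(enabled=True):
    assert same_two_models(mod, opt_mod, args), "eager and aot_eager disagree"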


cc @ezyang @msaroufim @wconstab @bdhirsh @zou3519 @soumith @ngimel

Labels

module: aotdispatch (umbrella label for AOTAutograd issues), module: pt2-dispatcher (PT2 dispatcher-related issues, e.g. aotdispatch, functionalization, faketensor, custom-op), oncall: pt2, triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
