This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Conversation

@vfdev-5 (Contributor) commented Mar 29, 2022

Fixes #626

Description:

  • Fixing tensor.numpy on wrapped tensors
  • Added a test

vfdev-5 added 2 commits March 29, 2022 11:51
Comment on lines +109 to +128
level = _C.maybe_get_level(tensor)
if level == -1:
    return _old_numpy(tensor)

if _C.is_functionaltensor(tensor):
    # Since we're unwrapping the FunctionalTensorWrapper, we need to make sure
    # that it's up to date first
    torch._sync(tensor)

value = _C.get_unwrapped(tensor)
dl_enabled = _C.tls_set_is_included()
try:
    # Temporarily disable kDynamicLayerFrontModeKey/kDynamicLayerBackModeKey as included dispatch keys
    if dl_enabled:
        _C._set_dynamic_layer_keys_included(False)
    return value.numpy()
finally:
    # Re-enable kDynamicLayerFrontModeKey/kDynamicLayerBackModeKey as included dispatch keys
    if dl_enabled:
        _C._set_dynamic_layer_keys_included(True)
Contributor
Hmm, so this is a little more complicated than that, I think.

When someone calls .numpy() under vmap, we probably want to error out. Otherwise some weird things might happen:

def f(x):
  return torch.tensor(x.numpy())

x = torch.randn(B)
vmap(f)(x) # returns a Tensor of size B, B -- is that what we want?
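The concern above can be simulated in plain NumPy (a hedged sketch: `f_leaky` and the manual stacking loop are stand-ins for what vmap would effectively compute if `.numpy()` silently returned the whole batch instead of erroring):

```python
import numpy as np

B = 3
x = np.random.randn(B)

def f_leaky(xi):
    # Pretend .numpy() inside vmap leaked the full batch `x`
    # instead of the single element `xi` this call maps over.
    return np.array(x)

# vmap(f)(x) would effectively stack the B per-example results,
# each of which is itself size B:
out = np.stack([f_leaky(x[i]) for i in range(B)])
print(out.shape)  # (3, 3), i.e. (B, B) rather than the expected (B,)
```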

Contributor

When someone calls .numpy() under the grad transform, we should support it (as long as there are no vmaps involved). I'm not sure what the best way to support this is... one thing we can do is keep unwrapping the Tensor and checking that no BatchedTensors are involved.
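The unwrapping check could be sketched like this (purely illustrative: `WrappedTensor`, `BatchedTensor`, and `GradTrackingTensor` are hypothetical stand-ins for functorch's wrapper subclasses, not its real API):

```python
class WrappedTensor:
    """Stand-in for a functorch wrapper holding a tensor one level down."""
    def __init__(self, value):
        self.value = value

class BatchedTensor(WrappedTensor):
    """Stand-in for the wrapper vmap puts around its inputs."""

class GradTrackingTensor(WrappedTensor):
    """Stand-in for the wrapper the grad transform uses."""

def numpy_is_safe(tensor):
    # Walk down the wrapper stack; hitting a BatchedTensor at any level
    # means we are under vmap, so .numpy() should raise rather than unwrap.
    while isinstance(tensor, WrappedTensor):
        if isinstance(tensor, BatchedTensor):
            return False
        tensor = tensor.value
    return True

print(numpy_is_safe(GradTrackingTensor(1.0)))                 # True: grad only
print(numpy_is_safe(GradTrackingTensor(BatchedTensor(1.0))))  # False: vmap involved
```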

In the long term we want a better fix for this, perhaps one that involves making the PyTorch dispatcher recognize .numpy() as an operation.

Successfully merging this pull request may close these issues.

Issue with tensor.numpy() for wrapped tensors
3 participants