9 changes: 7 additions & 2 deletions pytensor/link/jax/dispatch/subtensor.py
@@ -77,6 +77,8 @@ def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):

@jax_funcify.register(AdvancedIncSubtensor)
def jax_funcify_AdvancedIncSubtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
Review comment (Member): why getattr? The op should always have it


if getattr(op, "set_instead_of_inc", False):

def jax_fn(x, indices, y):
@@ -87,8 +89,11 @@ def jax_fn(x, indices, y):
def jax_fn(x, indices, y):
return x.at[indices].add(y)

def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn):
return jax_fn(x, ilist, y)
def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
indices = indices_from_subtensor(ilist, idx_list)
if len(indices) == 1:
indices = indices[0]
return jax_fn(x, indices, y)

return advancedincsubtensor

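For context (not part of the diff): this wrapper bottoms out in jax.numpy's functional index-update API. A minimal standalone sketch of the two code paths, using a hand-written index tuple instead of one rebuilt by indices_from_subtensor:

import jax.numpy as jnp

x = jnp.zeros((3, 3))
rows = jnp.array([0, 2])

# set_instead_of_inc=True corresponds to .set(...): the selected entries are overwritten
x_set = x.at[rows].set(1.0)

# otherwise the update accumulates, which is what .add(...) does
x_inc = x.at[rows, rows].add(5.0)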
44 changes: 23 additions & 21 deletions pytensor/link/numba/dispatch/subtensor.py
@@ -107,28 +107,30 @@ def {function_name}({", ".join(input_names)}):
@numba_funcify.register(AdvancedIncSubtensor)
def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
if isinstance(op, AdvancedSubtensor):
x, y, idxs = node.inputs[0], None, node.inputs[1:]
x, y, tensor_inputs = node.inputs[0], None, node.inputs[1:]
else:
x, y, *idxs = node.inputs

basic_idxs = [
idx
for idx in idxs
if (
isinstance(idx.type, NoneTypeT)
or (isinstance(idx.type, SliceType) and not is_full_slice(idx))
)
]
adv_idxs = [
{
"axis": i,
"dtype": idx.type.dtype,
"bcast": idx.type.broadcastable,
"ndim": idx.type.ndim,
}
for i, idx in enumerate(idxs)
if isinstance(idx.type, TensorType)
]
x, y, *tensor_inputs = node.inputs

# Reconstruct indexing information from idx_list and tensor inputs
basic_idxs = []
adv_idxs = []
input_idx = 0

for i, entry in enumerate(op.idx_list):
if isinstance(entry, slice):
# Basic slice index
basic_idxs.append(entry)
elif isinstance(entry, Type):
# Advanced tensor index
if input_idx < len(tensor_inputs):
idx_input = tensor_inputs[input_idx]
adv_idxs.append({
"axis": i,
"dtype": idx_input.type.dtype,
"bcast": idx_input.type.broadcastable,
"ndim": idx_input.type.ndim,
})
input_idx += 1

# Special implementation for consecutive integer vector indices
if (
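The loop above pairs op.idx_list entries with the node's inputs by position: slice entries stay as basic indices, and every other entry consumes the next tensor input. A self-contained sketch of that pairing, with illustrative names, an object() stand-in for PyTensor's Type, and without the bounds check the PR code carries:

def split_idx_list(idx_list, tensor_inputs):
    """Pair the non-slice entries of idx_list with tensor inputs, in order.

    Slice entries are basic indices and consume no input; every other entry
    is an advanced index backed by the next tensor input.
    """
    basic, advanced = [], []
    inputs = iter(tensor_inputs)
    for axis, entry in enumerate(idx_list):
        if isinstance(entry, slice):
            basic.append((axis, entry))
        else:
            advanced.append((axis, next(inputs)))
    return basic, advanced


# x[0:2, ivec, :] -> one advanced index on axis 1; object() stands in for a tensor Type
basic, advanced = split_idx_list([slice(0, 2), object(), slice(None)], ["ivec_input"])
assert [axis for axis, _ in advanced] == [1]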
21 changes: 15 additions & 6 deletions pytensor/link/pytorch/dispatch/subtensor.py
@@ -63,7 +63,10 @@ def makeslice(start, stop, step):
@pytorch_funcify.register(AdvancedSubtensor1)
@pytorch_funcify.register(AdvancedSubtensor)
def pytorch_funcify_AdvSubtensor(op, node, **kwargs):
def advsubtensor(x, *indices):
idx_list = getattr(op, "idx_list", None)

def advsubtensor(x, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices)
return x[indices]

@@ -102,12 +105,14 @@ def inc_subtensor(x, y, *flattened_indices):
@pytorch_funcify.register(AdvancedIncSubtensor)
@pytorch_funcify.register(AdvancedIncSubtensor1)
def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
inplace = op.inplace
ignore_duplicates = getattr(op, "ignore_duplicates", False)

if op.set_instead_of_inc:

def adv_set_subtensor(x, y, *indices):
def adv_set_subtensor(x, y, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices)
if isinstance(op, AdvancedIncSubtensor1):
op._check_runtime_broadcasting(node, x, y, indices)
@@ -120,7 +125,8 @@ def adv_set_subtensor(x, y, *indices):

elif ignore_duplicates:

def adv_inc_subtensor_no_duplicates(x, y, *indices):
def adv_inc_subtensor_no_duplicates(x, y, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices)
if isinstance(op, AdvancedIncSubtensor1):
op._check_runtime_broadcasting(node, x, y, indices)
@@ -132,13 +138,16 @@ def adv_inc_subtensor_no_duplicates(x, y, *indices):
return adv_inc_subtensor_no_duplicates

else:
if any(isinstance(idx.type, SliceType) for idx in node.inputs[2:]):
# Check if we have slice indexing in idx_list
has_slice_indexing = any(isinstance(entry, slice) for entry in idx_list) if idx_list else False
if has_slice_indexing:
raise NotImplementedError(
"IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch"
)

def adv_inc_subtensor(x, y, *indices):
# Not needed because slices aren't supported
def adv_inc_subtensor(x, y, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
# Not needed because slices aren't supported in this path
# check_negative_steps(indices)
if not inplace:
x = x.clone()
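For reference, the torch primitives these set/increment branches typically map to, shown standalone rather than as the PR's wrappers (the body of adv_inc_subtensor is truncated above, so this is a sketch of the usual calls, not a claim about the hidden lines):

import torch

x = torch.zeros(4)
idx = torch.tensor([0, 1, 1])
y = torch.tensor([1.0, 2.0, 3.0])

# set_instead_of_inc: plain advanced-indexing assignment
x_set = x.clone()
x_set[idx] = y

# increment with duplicate indices accumulated
x_inc = x.clone()
x_inc.index_put_((idx,), y, accumulate=True)  # index 1 receives 2.0 + 3.0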
78 changes: 71 additions & 7 deletions pytensor/tensor/rewriting/subtensor.py
@@ -228,7 +228,18 @@ def local_replace_AdvancedSubtensor(fgraph, node):
return

indexed_var = node.inputs[0]
indices = node.inputs[1:]
tensor_inputs = node.inputs[1:]

# Reconstruct indices from idx_list and tensor inputs
indices = []
input_idx = 0
for entry in node.op.idx_list:
if isinstance(entry, slice):
indices.append(entry)
elif isinstance(entry, Type):
if input_idx < len(tensor_inputs):
indices.append(tensor_inputs[input_idx])
input_idx += 1

axis = get_advsubtensor_axis(indices)

@@ -255,7 +266,18 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):

res = node.inputs[0]
val = node.inputs[1]
indices = node.inputs[2:]
tensor_inputs = node.inputs[2:]

# Reconstruct indices from idx_list and tensor inputs
indices = []
input_idx = 0
for entry in node.op.idx_list:
if isinstance(entry, slice):
indices.append(entry)
elif isinstance(entry, Type):
if input_idx < len(tensor_inputs):
indices.append(tensor_inputs[input_idx])
input_idx += 1

axis = get_advsubtensor_axis(indices)

@@ -1751,9 +1773,22 @@ def ravel_multidimensional_bool_idx(fgraph, node):
x[eye(3, dtype=bool)].set(y) -> x.ravel()[eye(3).ravel()].set(y).reshape(x.shape)
"""
if isinstance(node.op, AdvancedSubtensor):
x, *idxs = node.inputs
x = node.inputs[0]
tensor_inputs = node.inputs[1:]
else:
x, y, *idxs = node.inputs
x, y = node.inputs[0], node.inputs[1]
tensor_inputs = node.inputs[2:]

# Reconstruct indices from idx_list and tensor inputs
idxs = []
input_idx = 0
for entry in node.op.idx_list:
if isinstance(entry, slice):
idxs.append(entry)
elif isinstance(entry, Type):
if input_idx < len(tensor_inputs):
idxs.append(tensor_inputs[input_idx])
input_idx += 1

if any(
(
@@ -1791,12 +1826,41 @@ def ravel_multidimensional_bool_idx(fgraph, node):
new_idxs[bool_idx_pos] = raveled_bool_idx

if isinstance(node.op, AdvancedSubtensor):
new_out = node.op(raveled_x, *new_idxs)
# Create new AdvancedSubtensor with updated idx_list
new_idx_list = list(node.op.idx_list)
new_tensor_inputs = list(tensor_inputs)

# Update the tensor_inputs for the raveled boolean index (idx_list entries are unchanged)
input_idx = 0
for i, entry in enumerate(node.op.idx_list):
if isinstance(entry, Type):
if input_idx == bool_idx_pos:
new_tensor_inputs[input_idx] = raveled_bool_idx
input_idx += 1

new_out = AdvancedSubtensor(new_idx_list)(raveled_x, *new_tensor_inputs)
else:
# Create new AdvancedIncSubtensor with updated idx_list
new_idx_list = list(node.op.idx_list)
new_tensor_inputs = list(tensor_inputs)

# Update the tensor_inputs for the raveled boolean index
input_idx = 0
for i, entry in enumerate(node.op.idx_list):
if isinstance(entry, Type):
if input_idx == bool_idx_pos:
new_tensor_inputs[input_idx] = raveled_bool_idx
input_idx += 1

# The dimensions of y that correspond to the boolean indices
# must already be raveled in the original graph, so we don't need to do anything to it
new_out = node.op(raveled_x, y, *new_idxs)
# But we must reshape the output to math the original shape
new_out = AdvancedIncSubtensor(
new_idx_list,
inplace=node.op.inplace,
set_instead_of_inc=node.op.set_instead_of_inc,
ignore_duplicates=node.op.ignore_duplicates
)(raveled_x, y, *new_tensor_inputs)
# But we must reshape the output to match the original shape
new_out = new_out.reshape(x_shape)

return [copy_stack_trace(node.outputs[0], new_out)]
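The same idx_list reconstruction loop now appears in local_replace_AdvancedSubtensor, local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1, and ravel_multidimensional_bool_idx. A possible shared helper, sketched here with a hypothetical name rather than anything defined in the PR:

def expand_indices(idx_list, tensor_inputs):
    """Rebuild the full index sequence from idx_list and the node's tensor inputs.

    Slice entries are kept as-is; every non-slice entry consumes the next
    tensor input. Assumes the number of non-slice entries matches len(tensor_inputs).
    """
    inputs = iter(tensor_inputs)
    return [entry if isinstance(entry, slice) else next(inputs) for entry in idx_list]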